package org.reactome.factorgraph;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:caBIGR3-minimal-2.0.jar:org/reactome/factorgraph/AbstractInferencer.class */
/**
 * Base class for {@link Inferencer} implementations.
 *
 * <p>Holds the factor graph, the observed evidence and the iteration settings used by
 * subclasses. It provides:
 * <ul>
 *   <li>a log-partition estimate computed from the current variable and factor beliefs;</li>
 *   <li>helpers that clamp observed variables by temporarily adding single-variable
 *       evidence factors to the graph.</li>
 * </ul>
 */
public abstract class AbstractInferencer implements Inferencer {
    /** Graph inference runs on; set via {@link #setFactorGraph(FactorGraph)}. */
    protected FactorGraph factorGraph;
    /** Observed state index per clamped variable; may be {@code null} or empty. */
    protected Map<Variable, Integer> observation;
    /**
     * Evidence factors added by {@link #attachObservation()} that are still in the graph.
     * Removal is driven by this list, so cleanup works even if {@link #observation} changed
     * after attaching.
     */
    private final List<Factor> observationFactors = new ArrayList<>();
    /** Tolerance setting exposed to subclasses (default 1e-6); not read in this class. */
    protected double tolerance = 1.0E-6d;
    /** Iteration cap exposed to subclasses (default 10000); not read in this class. */
    protected int maxIteration = 10000;
    /** Iteration counter maintained by subclasses; exposed via {@link #getIteration()}. */
    protected int iteration;
    /** Difference value maintained by subclasses; exposed via {@link #getMaxDiff()}. */
    protected double maxDiff;
    /** Debug flag for subclasses; not read in this class. */
    protected boolean debug;

    @Override // org.reactome.factorgraph.Inferencer
    public void setFactorGraph(FactorGraph factorGraph) {
        this.factorGraph = factorGraph;
    }

    @Override // org.reactome.factorgraph.Inferencer
    public FactorGraph getFactorGraph() {
        return this.factorGraph;
    }

    @Override // org.reactome.factorgraph.Inferencer
    public void setObservation(Map<Variable, Integer> map) {
        this.observation = map;
    }

    @Override // org.reactome.factorgraph.Inferencer
    public Map<Variable, Integer> getObservation() {
        return this.observation;
    }

    /** Returns the tolerance setting. */
    public double getTolerance() {
        return this.tolerance;
    }

    /** Sets the tolerance setting. */
    public void setTolerance(double d) {
        this.tolerance = d;
    }

    /** Returns the iteration cap. */
    public int getMaxIteration() {
        return this.maxIteration;
    }

    /** Sets the iteration cap. */
    public void setMaxIteration(int i) {
        this.maxIteration = i;
    }

    /** Returns the iteration counter as last recorded by the subclass. */
    public int getIteration() {
        return this.iteration;
    }

    /** Returns the debug flag. */
    public boolean getDebug() {
        return this.debug;
    }

    /** Sets the debug flag. */
    public void setDebug(boolean z) {
        this.debug = z;
    }

    /** Returns the difference value as last recorded by the subclass. */
    public double getMaxDiff() {
        return this.maxDiff;
    }

    @Override // org.reactome.factorgraph.Inferencer
    public abstract void runInference() throws InferenceCannotConvergeException;

    /**
     * Estimates log Z from the current beliefs using the Bethe approximation:
     * <pre>
     *   log Z ~= sum_v (deg(v) - 1) * sum_x b_v(x) log b_v(x)
     *          + sum_f sum_x b_f(x) * (log f(x) - log b_f(x))
     * </pre>
     * Here deg(v) is the number of factors attached to v. Zero beliefs contribute 0
     * (0 log 0 = 0). Factor entries whose value or belief is 0 are skipped.
     *
     * @return the log-partition estimate
     * @throws IllegalStateException if the running sum becomes NaN
     */
    @Override // org.reactome.factorgraph.Inferencer
    public double calculateLogZ() {
        double logZ = 0.0d;
        for (Variable variable : this.factorGraph.getVariables()) {
            double[] belief = variable.getBelief();
            double sumBLogB = 0.0d;
            for (double b : belief) {
                if (b != 0.0d) {
                    sumBLogB += b * Math.log(b);
                }
            }
            logZ += (variable.getFactors().size() - 1) * sumBLogB;
            if (Double.isNaN(logZ)) {
                // Arrays.toString: concatenating a double[] would only print its identity hash.
                throw new IllegalStateException("NaN encountered in variable: " + variable
                        + " with belief " + Arrays.toString(belief));
            }
        }
        for (Factor factor : this.factorGraph.getFactors()) {
            double[] belief = factor.getBelief();
            double[] values = factor.getValues();
            for (int i = 0; i < belief.length; i++) {
                // A zero factor value would give log(0) = -infinity; such entries are skipped.
                if (belief[i] != 0.0d && values[i] != 0.0d) {
                    logZ += belief[i] * (Math.log(values[i]) - Math.log(belief[i]));
                    if (Double.isNaN(logZ)) {
                        throw new IllegalStateException("NaN encountered in factor: " + factor
                                + " with belief " + belief[i]);
                    }
                }
            }
        }
        return logZ;
    }

    /**
     * Clamps every observed variable that belongs to the factor graph.
     *
     * <p>For each such variable, adds a single-variable factor whose value is 1 at the
     * observed state and 0 elsewhere. The factor is registered with both the variable and
     * the graph. Observed variables not in the graph are ignored.
     *
     * <p>Any evidence factors left over from an earlier call that was not followed by
     * {@link #detachObservation()} are removed first, so evidence is never duplicated.
     *
     * @throws IllegalArgumentException if an observed state is {@code null} or outside
     *         {@code [0, variable.getStates())}. Factors attached before the failure stay
     *         tracked and are removed by {@link #detachObservation()}.
     */
    public void attachObservation() {
        removeObservationFactors();
        if (this.observation == null || this.observation.isEmpty()) {
            return;
        }
        Set<Variable> variables = this.factorGraph.getVariables();
        for (Map.Entry<Variable, Integer> entry : this.observation.entrySet()) {
            Variable variable = entry.getKey();
            if (!variables.contains(variable)) {
                continue;
            }
            // Build and validate the values before touching the graph, so a bad state index
            // cannot leave a half-wired factor behind.
            double[] values = indicatorValues(variable, entry.getValue());
            Factor factor = new Factor();
            List<Variable> factorVariables = new ArrayList<>(1);
            factorVariables.add(variable);
            factor.setVariables(factorVariables);
            variable.addFactor(factor);
            factor.setValues(values);
            this.factorGraph.addFactor(factor);
            this.observationFactors.add(factor);
        }
    }

    /**
     * Removes every evidence factor added by {@link #attachObservation()}.
     *
     * <p>The factors are removed from the graph, from their variables, and from the edges
     * pointing at those variables. Which factors to remove is decided by what was actually
     * attached, not by the current observation map. This makes the method safe to call
     * repeatedly, before any attach, or after the observation has been replaced.
     */
    public void detachObservation() {
        removeObservationFactors();
    }

    /** Returns a one-hot value array for {@code state}, validating it against the variable. */
    private static double[] indicatorValues(Variable variable, Integer state) {
        int stateCount = variable.getStates();
        if (state == null || state < 0 || state >= stateCount) {
            throw new IllegalArgumentException("Observed state " + state
                    + " is out of range for variable " + variable + " with " + stateCount
                    + " states");
        }
        double[] values = new double[stateCount];
        values[state] = 1.0d;
        return values;
    }

    /** Unwires and forgets all tracked evidence factors; no-op when none are attached. */
    private void removeObservationFactors() {
        if (this.observationFactors.isEmpty()) {
            return;
        }
        this.factorGraph.getFactors().removeAll(this.observationFactors);
        for (Factor factor : this.observationFactors) {
            for (Variable variable : factor.getVariables()) {
                variable.removeFactor(factor);
            }
            if (factor.getOutEdges() != null) {
                for (Edge edge : factor.getOutEdges()) {
                    Variable target = (Variable) edge.getToNode();
                    target.removeInEdge(edge);
                    target.removeOutEdge(factor);
                }
            }
        }
        this.observationFactors.clear();
    }
}
