package org.reactome.factorgraph;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.log4j.Logger;

/* loaded from: input_file:caBIGR3-minimal-2.0.jar:org/reactome/factorgraph/ExpectationMaximization.class */
/**
 * Learns the values of a set of {@link EMFactor}s in a {@link FactorGraph} by expectation
 * maximization over a list of {@link Observation}s.
 *
 * <p>Each iteration runs an E-step and then an M-step:
 * <ul>
 *   <li>The E-step resets the counts of every learning factor. It runs inference once without
 *       observations and once per observation, and accumulates expected counts via
 *       {@link #expect(EMFactor)}.</li>
 *   <li>The M-step asks every learning factor to update its values from those counts.</li>
 * </ul>
 * Iteration stops after {@link #getMaxIteration()} rounds, or once the relative change in
 * log-likelihood between two consecutive rounds is no longer above {@link #getTolerance()}.
 */
public class ExpectationMaximization {
    private static final Logger logger = Logger.getLogger(ExpectationMaximization.class);
    // Observations used as evidence; must be non-empty when learn() is called.
    private List<Observation> evidences;
    // When true, per-iteration log-likelihoods and learned factor values are logged at INFO level.
    private boolean debug;
    private int maxIteration = 50;
    // Threshold on the relative log-likelihood change used as the stopping criterion.
    private double tolerance = 0.001d;
    private Inferencer inferencer = new LoopyBeliefPropagation();

    /**
     * Enables or disables verbose per-iteration logging.
     *
     * @param z true to log log-likelihoods and learned values after each iteration
     */
    public void setDebug(boolean z) {
        this.debug = z;
    }

    /**
     * @return whether verbose per-iteration logging is enabled
     */
    public boolean getDebug() {
        return this.debug;
    }

    /**
     * Sets the inferencer used to compute log Z and expected counts.
     *
     * @param inferencer the inferencer to use
     */
    public void setInferencer(Inferencer inferencer) {
        this.inferencer = inferencer;
    }

    /**
     * Misspelled alias kept for backward compatibility.
     *
     * @param inferencer the inferencer to use
     * @deprecated use {@link #setInferencer(Inferencer)} instead.
     */
    @Deprecated
    public void setInferenser(Inferencer inferencer) {
        setInferencer(inferencer);
    }

    /**
     * @return the inferencer used during learning
     */
    public Inferencer getInferencer() {
        return this.inferencer;
    }

    /**
     * @return the observations used as evidence (may be null if never set)
     */
    public List<Observation> getEvidences() {
        return this.evidences;
    }

    /**
     * Sets the observations used as evidence. The list is stored by reference, not copied.
     *
     * @param list the observations to learn from
     */
    public void setEvidences(List<Observation> list) {
        this.evidences = list;
    }

    /**
     * @return the maximum number of EM iterations
     */
    public int getMaxIteration() {
        return this.maxIteration;
    }

    /**
     * @param i the maximum number of EM iterations; a value of 0 or less performs no iterations
     */
    public void setMaxIteration(int i) {
        this.maxIteration = i;
    }

    /**
     * @return the relative log-likelihood change at or below which learning stops
     */
    public double getTolerance() {
        return this.tolerance;
    }

    /**
     * @param d the relative log-likelihood change at or below which learning stops
     */
    public void setTolerance(double d) {
        this.tolerance = d;
    }

    /**
     * Runs EM on the given factor graph, updating the given learning factors in place.
     *
     * @param factorGraph the graph handed to the inferencer
     * @param list the factors whose values are learned
     * @return the log-likelihood computed in the last E-step, or 0.0 if no iteration ran
     * @throws IllegalArgumentException if {@code list} is null or empty
     * @throws IllegalStateException if no evidences have been set
     * @throws InferenceCannotConvergeException if the inferencer fails to converge
     */
    public double learn(FactorGraph factorGraph, List<EMFactor> list) throws InferenceCannotConvergeException {
        if (list == null || list.isEmpty()) {
            throw new IllegalArgumentException("No learning factors has been provided.");
        }
        if (this.evidences == null || this.evidences.isEmpty()) {
            throw new IllegalStateException("No evidences have been assigned for the EM learning.");
        }
        this.inferencer.setFactorGraph(factorGraph);
        int iteration = 0;
        double previousLogLikelihood = 0.0d;
        // Start above any tolerance so that at least one iteration runs.
        double relativeChange = Double.MAX_VALUE;
        while (iteration < this.maxIteration && relativeChange > this.tolerance) {
            double logLikelihood = runExpectationStep(list);
            runMaximizationStep(list);
            // There is no previous value to compare against in the first iteration.
            if (iteration > 0) {
                double change = logLikelihood - previousLogLikelihood;
                if (change < 0.0d) {
                    // Exact EM never decreases the likelihood; approximate inference may.
                    logger.warn("During learning loglikelihood decrease by: " + change);
                }
                relativeChange = relativeChange(change, previousLogLikelihood);
                if (Double.isNaN(relativeChange)) {
                    // NaN compares false against the tolerance and would end the loop silently.
                    logger.warn("Loglikelihood is NaN at iteration " + (iteration + 1)
                            + "; stopping EM learning.");
                }
            }
            previousLogLikelihood = logLikelihood;
            iteration++;
            if (this.debug) {
                logDebugIteration(iteration, logLikelihood, relativeChange, list);
            }
        }
        return previousLogLikelihood;
    }

    /**
     * Computes |change / previous|. If previous is exactly 0, it returns |change| instead, so the
     * result is never a spurious NaN or Infinity.
     */
    private static double relativeChange(double change, double previous) {
        if (previous == 0.0d) {
            return Math.abs(change);
        }
        return Math.abs(change / previous);
    }

    /** Logs the log-likelihood and the current values of every learning factor. */
    private void logDebugIteration(int iteration, double logLikelihood, double relativeChange, List<EMFactor> list) {
        // The change is only meaningful from the second iteration on.
        Object difference = iteration > 1 ? Double.valueOf(relativeChange) : "N/A";
        logger.info("Loglikelihood for iteration " + iteration + ": " + logLikelihood + " and difference: " + difference);
        logger.info("Learned parameters:");
        for (EMFactor factor : list) {
            logger.info(String.valueOf(factor.getName()) + ": " + Arrays.toString(factor.getValues()));
        }
    }

    /**
     * E-step: accumulates expected counts over all evidences and returns the summed
     * log-likelihood, computed as logZ(with observation) - logZ(without observation).
     * NOTE(review): the inferencer keeps the last evidence's observation set after this returns.
     */
    private double runExpectationStep(List<EMFactor> list) throws InferenceCannotConvergeException {
        for (EMFactor factor : list) {
            factor.initCounts();
        }
        // log Z of the graph with no variables clamped; shared by every evidence.
        this.inferencer.setObservation(null);
        this.inferencer.runInference();
        double logZ = this.inferencer.calculateLogZ();
        if (this.debug) {
            logger.info("LogZ: " + logZ);
        }
        double logLikelihood = 0.0d;
        for (Observation observation : this.evidences) {
            this.inferencer.setObservation(observation.getVariableToAssignment());
            this.inferencer.runInference();
            for (EMFactor factor : list) {
                expect(factor);
            }
            logLikelihood += this.inferencer.calculateLogZ() - logZ;
        }
        return logLikelihood;
    }

    /**
     * Hook called for each learning factor after inference has run on one observation.
     * The default implementation calls {@link EMFactor#updateCounts()}.
     *
     * @param eMFactor the factor whose counts should be updated
     */
    protected void expect(EMFactor eMFactor) {
        eMFactor.updateCounts();
    }

    /** M-step: lets every learning factor update its values from its accumulated counts. */
    private void runMaximizationStep(List<EMFactor> list) {
        for (EMFactor factor : list) {
            factor.updateFactorValues();
        }
    }
}
