package org.reactome.factorgraph;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.log4j.Logger;

/* loaded from: input_file:caBIGR3-minimal-2.0.jar:org/reactome/factorgraph/RandomRestartEM.class */
/**
 * An {@link ExpectationMaximization} variant that runs EM several times, each
 * time starting from freshly randomized factor values, and keeps the factor
 * values from the run that reported the highest log-likelihood.
 */
public class RandomRestartEM extends ExpectationMaximization {
    private static final Logger logger = Logger.getLogger(RandomRestartEM.class);
    /** Number of EM runs performed by {@link #learn}; kept positive by {@link #setRestart}. */
    private int restart = 10;

    /**
     * Snapshot of the factor values produced by one EM run, together with the
     * log-likelihood that run returned.
     */
    private static final class LearnResult {
        private final Map<EMFactor, double[]> factorToValues;
        private final double logLikelihood;

        private LearnResult(Map<EMFactor, double[]> factorToValues, double logLikelihood) {
            this.factorToValues = factorToValues;
            this.logLikelihood = logLikelihood;
        }
    }

    /**
     * Sets how many random restarts {@link #learn} performs.
     *
     * @param i the number of restarts; values {@code <= 0} are ignored and the
     *          current setting is kept
     */
    public void setRestart(int i) {
        if (i <= 0) {
            return;
        }
        this.restart = i;
    }

    /**
     * @return the number of random restarts {@link #learn} performs (always positive)
     */
    public int getRestart() {
        return this.restart;
    }

    /**
     * Runs EM {@link #getRestart()} times. Each run starts from random factor
     * values. After all runs, the factors are left holding the values from the
     * run with the highest log-likelihood.
     *
     * <p>Ties go to the earliest run. Runs that report {@code NaN} are chosen
     * only if every run reported {@code NaN}.
     *
     * @param factorGraph the graph to learn on
     * @param list        the factors whose values are learned; they are
     *                    mutated in place
     * @return the log-likelihood of the selected run
     * @throws InferenceCannotConvergeException if any EM run fails to converge
     */
    @Override // org.reactome.factorgraph.ExpectationMaximization
    public double learn(FactorGraph factorGraph, List<EMFactor> list) throws InferenceCannotConvergeException {
        LearnResult best = null;
        for (int run = 0; run < this.restart; run++) {
            logger.info("Random parameters:");
            for (EMFactor factor : list) {
                factor.randomFactorValues();
                logFactorValues(factor);
            }
            double logLikelihood = super.learn(factorGraph, list);
            logger.info("Parameters learned:");
            for (EMFactor factor : list) {
                logFactorValues(factor);
            }
            LearnResult current = snapshot(logLikelihood, list);
            if (isBetter(current, best)) {
                best = current;
            }
        }
        if (best == null) {
            // Unreachable while restart > 0 (enforced by setRestart); guard anyway.
            throw new IllegalStateException("No EM run was performed: restart = " + this.restart);
        }
        // Restore the winning parameters; the snapshot arrays are private copies.
        for (EMFactor factor : list) {
            factor.setValues(best.factorToValues.get(factor));
        }
        return best.logLikelihood;
    }

    /** Logs one factor's name and current values in the "name: [v1, v2, ...]" format. */
    private void logFactorValues(EMFactor factor) {
        logger.info(String.valueOf(factor.getName()) + ": " + Arrays.toString(factor.getValues()));
    }

    /** Copies every factor's current values so later runs cannot overwrite them. */
    private LearnResult snapshot(double logLikelihood, List<EMFactor> list) {
        Map<EMFactor, double[]> factorToValues = new HashMap<EMFactor, double[]>();
        for (EMFactor factor : list) {
            double[] values = factor.getValues();
            factorToValues.put(factor, Arrays.copyOf(values, values.length));
        }
        return new LearnResult(factorToValues, logLikelihood);
    }

    /**
     * Decides whether {@code candidate} should replace {@code best}. A
     * candidate wins if there is no best yet, if it is not NaN while the best
     * is NaN, or if its log-likelihood is strictly greater. Strict comparison
     * keeps the earliest run on ties.
     */
    private static boolean isBetter(LearnResult candidate, LearnResult best) {
        if (best == null) {
            return true;
        }
        if (Double.isNaN(best.logLikelihood)) {
            return !Double.isNaN(candidate.logLikelihood);
        }
        return candidate.logLikelihood > best.logLikelihood;
    }
}
