package org.reactome.factorgraph;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math3.random.RandomDataGenerator;
import org.apache.log4j.Logger;
import org.junit.Test;
import org.reactome.r3.util.FileUtility;

/* loaded from: input_file:caBIGR3-minimal-3.0.jar:org/reactome/factorgraph/GibbsSampling.class */
/**
 * Approximate inference on a {@link FactorGraph} by Gibbs sampling.
 *
 * <p>Each run starts {@link #getRestart()} independent chains from a random assignment (drawn with
 * {@link RandomDataGenerator#nextInt(int, int)} over every variable's states). Each chain discards
 * {@link #getBurnin()} sweeps and then records {@code maxIteration} sweeps. A sweep resamples
 * every variable of the graph in turn, conditioned on the current values of all other variables.
 * Variable and factor beliefs are set from the empirical state frequencies pooled over all
 * recorded samples. States never visited receive a pseudo-count of one sample before the
 * distribution is renormalized.
 */
public class GibbsSampling extends AbstractInferencer {
    private static final Logger logger = Logger.getLogger(GibbsSampling.class);
    /** Number of recorded samples per chain used when {@link #setMaxIteration(int)} is not called. */
    private static final int DEFAULT_MAX_ITERATION = 500;
    /**
     * Conditional-distribution tables are cached only for variables whose Markov blanket
     * (including the variable itself) has at most this many variables, bounding the table size.
     */
    private static final int MAX_CACHED_BLANKET_SIZE = 6;
    /** The convergence diagnostic records one running fraction every this many samples. */
    private static final int CONVERGENCE_CHECK_INTERVAL = 100;
    /** Marker stored in cached tables for entries that have not been computed yet. */
    private static final double UNCACHED = -1.0d;
    /** Allowed shortfall of a cumulative distribution below 1 caused by floating-point rounding. */
    private static final double CDF_ROUNDING_TOLERANCE = 1.0e-9d;

    private int burnin;
    private Map<Variable, Factor> varToMBFactor;
    private Map<Variable, Set<Variable>> varToMbVars;
    private int restart = 1;
    private RandomDataGenerator randomizer = new RandomDataGenerator();

    /** Creates a sampler that records {@value #DEFAULT_MAX_ITERATION} samples per chain. */
    public GibbsSampling() {
        this.maxIteration = DEFAULT_MAX_ITERATION;
    }

    /** Replaces the random source (e.g. with a seeded generator for reproducible runs). */
    public void setRandomGenerator(RandomDataGenerator randomDataGenerator) {
        this.randomizer = randomDataGenerator;
    }

    public RandomDataGenerator getRandomGenerator() {
        return this.randomizer;
    }

    /** Sets how many sweeps are discarded at the start of every chain. */
    public void setBurnin(int i) {
        this.burnin = i;
    }

    public int getBurnin() {
        return this.burnin;
    }

    /** Sets how many independent chains are run per inference. */
    public void setRestart(int i) {
        this.restart = i;
    }

    public int getRestart() {
        return this.restart;
    }

    /**
     * Runs all chains, then writes the sampled marginals into the variables' and factors' beliefs.
     *
     * <p>The observation is detached and the continuous factors are added back even when sampling
     * fails, so the shared factor graph is not left in a modified state.
     *
     * @throws InferenceCannotConvergeException as propagated from {@code super.runInference()}
     */
    @Override // org.reactome.factorgraph.AbstractInferencer, org.reactome.factorgraph.Inferencer
    public synchronized void runInference() throws InferenceCannotConvergeException {
        super.runInference();
        resetCache();
        truncateContinuousFactors();
        try {
            attachObservation();
            try {
                List<List<Observation<Integer>>> chains = new ArrayList<>();
                this.iteration = 0;
                for (int chainIndex = 0; chainIndex < this.restart; chainIndex++) {
                    Observation<Integer> start = burn(initializeAssignment());
                    List<Observation<Integer>> chain = new ArrayList<>();
                    sample(start, chain);
                    chains.add(chain);
                }
                // The diagnostic is only reported in debug mode, so only pay for it there.
                if (this.debug) {
                    logger.info("Measurement of Gibbs Sampling: " + calculateConvergence(chains));
                }
                calculateMarginals(chains);
            } finally {
                detachObservation();
            }
        } finally {
            addBackContinuosFactors();
        }
    }

    /** Clears the per-run caches of Markov-blanket variable sets and conditional tables. */
    public void resetCache() {
        if (this.varToMBFactor == null) {
            this.varToMBFactor = new HashMap<>();
        } else {
            this.varToMBFactor.clear();
        }
        if (this.varToMbVars == null) {
            this.varToMbVars = new HashMap<>();
        } else {
            this.varToMbVars.clear();
        }
    }

    /**
     * Computes the Gelman-Rubin potential scale reduction factor across chains.
     *
     * <p>The statistic is computed on one tracked variable: the first key of the first sample's
     * assignment map. For that variable, the running fraction of samples in state 0 is recorded
     * every {@link #CONVERGENCE_CHECK_INTERVAL} samples.
     *
     * @return the statistic, or {@code NaN} when there are no chains or too few checkpoints
     */
    private double calculateConvergence(List<List<Observation<Integer>>> list) {
        List<List<Double>> traces = new ArrayList<>();
        Variable tracked = null;
        for (List<Observation<Integer>> chain : list) {
            List<Double> trace = new ArrayList<>();
            int zeroCount = 0;
            for (int index = 0; index < chain.size(); index++) {
                Map<Variable, Integer> assignment = chain.get(index).getVariableToAssignment();
                if (tracked == null) {
                    tracked = assignment.keySet().iterator().next();
                }
                if (assignment.get(tracked).intValue() == 0) {
                    zeroCount++;
                }
                int seen = index + 1;
                if (seen % CONVERGENCE_CHECK_INTERVAL == 0) {
                    // Floating-point division: the running fraction lies in [0, 1].
                    trace.add((double) zeroCount / seen);
                }
            }
            traces.add(trace);
        }
        return gelmanRubin(traces);
    }

    /**
     * Gelman-Rubin statistic sqrt(V / W) for equally long traces.
     *
     * <p>W is the mean within-chain variance, B = n / (m - 1) * sum of squared deviations of the
     * chain means, and V = (n - 1) / n * W + B / n. Here m is the number of traces and n is the
     * length of the first trace.
     *
     * @return the statistic, or {@code NaN} for no traces or fewer than two points per trace
     */
    private double gelmanRubin(List<List<Double>> traces) {
        int chainCount = traces.size();
        if (chainCount == 0) {
            return Double.NaN;
        }
        int length = traces.get(0).size();
        if (length < 2) {
            return Double.NaN;
        }
        double[] chainMeans = new double[chainCount];
        double grandMean = 0.0d;
        for (int j = 0; j < chainCount; j++) {
            List<Double> trace = traces.get(j);
            double sum = 0.0d;
            for (double value : trace) {
                sum += value;
            }
            chainMeans[j] = sum / trace.size();
            grandMean += chainMeans[j];
        }
        grandMean /= chainCount;

        double between = 0.0d;
        for (double mean : chainMeans) {
            between += (mean - grandMean) * (mean - grandMean);
        }
        // With a single chain the sum above is exactly zero; the max() only avoids dividing by 0.
        between *= (double) length / Math.max(chainCount - 1, 1);

        double within = 0.0d;
        for (int j = 0; j < chainCount; j++) {
            for (double value : traces.get(j)) {
                within += (value - chainMeans[j]) * (value - chainMeans[j]);
            }
        }
        within /= (double) chainCount * (length - 1);

        double pooled = ((double) (length - 1) / length) * within + between / length;
        return Math.sqrt(pooled / within);
    }

    /** Draws a starting state for every variable with {@code nextInt(0, states - 1)}. */
    private Observation<Integer> initializeAssignment() {
        Observation<Integer> observation = new Observation<>();
        Map<Variable, Integer> assignment = new HashMap<>();
        for (Variable variable : this.factorGraph.getVariables()) {
            // RandomDataGenerator.nextInt bounds are inclusive on both ends.
            assignment.put(variable, Integer.valueOf(this.randomizer.nextInt(0, variable.getStates() - 1)));
        }
        observation.setVariableToAssignment(assignment);
        return observation;
    }

    /** Runs {@code burnin} unrecorded sweeps from the given state and returns the final state. */
    public Observation<Integer> burn(Observation<Integer> observation) {
        for (int i = 0; i < this.burnin; i++) {
            observation = sampleOnce(observation);
        }
        return observation;
    }

    /**
     * Appends {@code maxIteration} consecutive sweeps to {@code list}. Each sample is named
     * "Sample" followed by the global {@code iteration} counter, which is advanced per sample.
     */
    public void sample(Observation<Integer> observation, List<Observation<Integer>> list) {
        for (int i = 0; i < this.maxIteration; i++) {
            observation = sampleOnce(observation);
            observation.setName("Sample" + this.iteration);
            list.add(observation);
            this.iteration++;
        }
    }

    /** Sets the number of recorded samples per chain. */
    @Override // org.reactome.factorgraph.AbstractInferencer
    public void setMaxIteration(int i) {
        this.maxIteration = i;
    }

    /**
     * Sets every sampled variable's and factor's belief to the empirical frequencies over all
     * samples of all chains. When some state was never visited, each such state gets 1/total and
     * the distribution is renormalized via the owner's {@code normalize}.
     */
    private void calculateMarginals(List<List<Observation<Integer>>> list) {
        Map<Variable, int[]> variableCounts = new HashMap<>();
        Map<Factor, int[]> factorCounts = new HashMap<>();
        int total = 0;
        for (List<Observation<Integer>> chain : list) {
            for (Observation<Integer> sample : chain) {
                Map<Variable, Integer> assignment = sample.getVariableToAssignment();
                countVariables(variableCounts, assignment);
                countFactors(factorCounts, assignment);
                total++;
            }
        }
        for (Map.Entry<Variable, int[]> entry : variableCounts.entrySet()) {
            Variable variable = entry.getKey();
            double[] belief = new double[variable.getStates()];
            if (fillFrequencies(entry.getValue(), total, belief)) {
                variable.normalize(belief, false, false);
            }
            variable.setBelief(belief);
        }
        for (Map.Entry<Factor, int[]> entry : factorCounts.entrySet()) {
            Factor factor = entry.getKey();
            int[] counts = entry.getValue();
            double[] belief = new double[counts.length];
            if (fillFrequencies(counts, total, belief)) {
                factor.normalize(belief, false, false);
            }
            factor.setBelief(belief);
        }
    }

    /**
     * Writes {@code counts[i] / total} into {@code target}, using {@code 1 / total} for zero
     * counts.
     *
     * @return true when at least one zero count was replaced, so the result needs renormalizing
     */
    private boolean fillFrequencies(int[] counts, int total, double[] target) {
        boolean smoothed = false;
        for (int i = 0; i < counts.length; i++) {
            if (counts[i] == 0) {
                target[i] = 1.0d / total;
                smoothed = true;
            } else {
                // Cast before dividing: integer division would truncate every frequency to 0.
                target[i] = (double) counts[i] / total;
            }
        }
        return smoothed;
    }

    /** Increments, for every factor of the graph, the count of its joint assignment in map2. */
    private void countFactors(Map<Factor, int[]> map, Map<Variable, Integer> map2) {
        Map<Variable, Integer> factorAssignment = new HashMap<>();
        for (Factor factor : this.factorGraph.getFactors()) {
            int[] counts = map.get(factor);
            if (counts == null) {
                counts = new int[factor.getValues().length];
                map.put(factor, counts);
            }
            factorAssignment.clear();
            for (Variable variable : factor.getVariables()) {
                factorAssignment.put(variable, map2.get(variable));
            }
            counts[factor.getIndexForAssignment(factorAssignment)]++;
        }
    }

    /** Increments, for every variable in map2, the count of its assigned state. */
    private void countVariables(Map<Variable, int[]> map, Map<Variable, Integer> map2) {
        for (Map.Entry<Variable, Integer> entry : map2.entrySet()) {
            Variable variable = entry.getKey();
            int[] counts = map.get(variable);
            if (counts == null) {
                counts = new int[variable.getStates()];
                map.put(variable, counts);
            }
            counts[entry.getValue().intValue()]++;
        }
    }

    /**
     * Performs one sweep: a copy of the given assignment is updated variable by variable, so each
     * draw conditions on the values already updated in this sweep. The input is not modified.
     */
    private Observation<Integer> sampleOnce(Observation<Integer> observation) {
        Observation<Integer> next = new Observation<>();
        Map<Variable, Integer> assignment = new HashMap<>(observation.getVariableToAssignment());
        next.setVariableToAssignment(assignment);
        for (Variable variable : this.factorGraph.getVariables()) {
            assignment.put(variable, Integer.valueOf(sampleOnce(calculateProbabilities(variable, assignment))));
        }
        return next;
    }

    /**
     * Draws a state index from a normalized distribution by inverting its cumulative sum.
     *
     * @throws IllegalStateException when the array is not a valid distribution (e.g. contains NaN)
     */
    private int sampleOnce(double[] dArr) {
        // u lies in [0, 1); choose i with F(i-1) <= u < F(i) so zero-probability states are never drawn.
        double u = this.randomizer.nextUniform(0.0d, 1.0d, true);
        double cumulative = 0.0d;
        int lastPositive = -1;
        for (int i = 0; i < dArr.length; i++) {
            cumulative += dArr[i];
            if (u < cumulative) {
                return i;
            }
            if (dArr[i] > 0.0d) {
                lastPositive = i;
            }
        }
        // A normalized distribution may sum to slightly below 1 after rounding; u then falls in the gap.
        if (lastPositive >= 0 && cumulative >= 1.0d - CDF_ROUNDING_TOLERANCE) {
            return lastPositive;
        }
        throw new IllegalStateException("Cannot find a legal state: " + Arrays.toString(dArr));
    }

    /**
     * Returns the conditional distribution of {@code variable} given the other values in
     * {@code map}, proportional to the product of its factors' values. Cached tables are reused
     * when available.
     *
     * <p>NOTE: the cache helpers overwrite {@code map}'s entry for {@code variable}; the sweep
     * replaces it with the sampled state immediately afterwards.
     */
    private double[] calculateProbabilities(Variable variable, Map<Variable, Integer> map) {
        double[] cached = getProbabilities(variable, map);
        if (cached != null) {
            return cached;
        }
        int states = variable.getStates();
        double[] logWeights = new double[states];
        double maxLogWeight = Double.NEGATIVE_INFINITY;
        Map<Variable, Integer> factorAssignment = new HashMap<>();
        for (int state = 0; state < states; state++) {
            double logWeight = 0.0d;
            for (Factor factor : variable.getFactors()) {
                factorAssignment.clear();
                for (Variable other : factor.getVariables()) {
                    factorAssignment.put(other, map.get(other));
                }
                factorAssignment.put(variable, Integer.valueOf(state));
                logWeight += Math.log(factor.getValue(factorAssignment).doubleValue());
            }
            logWeights[state] = logWeight;
            maxLogWeight = Math.max(maxLogWeight, logWeight);
        }
        // Shift by the maximum before exponentiating so long factor products do not underflow to 0.
        double[] probabilities = new double[states];
        for (int state = 0; state < states; state++) {
            probabilities[state] = Math.exp(logWeights[state] - maxLogWeight);
        }
        normalize(probabilities);
        cacheProbabilities(probabilities, variable, map);
        return probabilities;
    }

    /**
     * Stores {@code dArr} in a table over the variable's Markov blanket, keyed by the blanket
     * values in {@code map}. Variables with blankets larger than {@link #MAX_CACHED_BLANKET_SIZE}
     * are not cached. Overwrites {@code map}'s entry for {@code variable}.
     */
    private void cacheProbabilities(double[] dArr, Variable variable, Map<Variable, Integer> map) {
        Set<Variable> blanket = getMBVars(variable);
        if (blanket.size() > MAX_CACHED_BLANKET_SIZE) {
            return;
        }
        Factor table = this.varToMBFactor.get(variable);
        if (table == null) {
            table = new Factor();
            table.setVariables(new ArrayList<>(blanket));
            Arrays.fill(table.getValues(), UNCACHED);
            this.varToMBFactor.put(variable, table);
        }
        for (int state = 0; state < dArr.length; state++) {
            map.put(variable, Integer.valueOf(state));
            table.setValue(Double.valueOf(dArr[state]), map);
        }
    }

    /**
     * Looks up a cached conditional distribution for the blanket values in {@code map}.
     * Overwrites {@code map}'s entry for {@code variable}.
     *
     * @return the cached distribution, or null when no table exists or an entry is uncached
     */
    private double[] getProbabilities(Variable variable, Map<Variable, Integer> map) {
        Factor table = this.varToMBFactor.get(variable);
        if (table == null) {
            return null;
        }
        double[] probabilities = new double[variable.getStates()];
        for (int state = 0; state < probabilities.length; state++) {
            map.put(variable, Integer.valueOf(state));
            double value = table.getValue(map).doubleValue();
            if (value < 0.0d) {
                return null;
            }
            probabilities[state] = value;
        }
        return probabilities;
    }

    /** Returns (and memoizes) the union of the variables of all factors attached to {@code variable}. */
    private Set<Variable> getMBVars(Variable variable) {
        Set<Variable> blanket = this.varToMbVars.get(variable);
        if (blanket != null) {
            return blanket;
        }
        blanket = new HashSet<>();
        for (Factor factor : variable.getFactors()) {
            blanket.addAll(factor.getVariables());
        }
        this.varToMbVars.put(variable, blanket);
        return blanket;
    }

    /** Divides every entry by the array's sum, in place. */
    private void normalize(double[] dArr) {
        double sum = 0.0d;
        for (double value : dArr) {
            sum += value;
        }
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] /= sum;
        }
    }

    @Test
    public void testInitializeAssignment() {
        setFactorGraph(TestUtilities.createSimpleFG());
        Map<Variable, Integer> variableToAssignment = initializeAssignment().getVariableToAssignment();
        for (Variable variable : variableToAssignment.keySet()) {
            System.out.println(variable + ": " + variableToAssignment.get(variable));
        }
    }

    @Test
    public void testInference() throws Exception {
        FileUtility.initializeLogging();
        FactorGraph createSimpleFG = TestUtilities.createSimpleFG();
        System.out.println("FG is a tree: " + createSimpleFG.isTree());
        Variable variable = TestUtilities.getVariable(createSimpleFG, "protein");
        Map<Variable, Integer> observation = new HashMap<>();
        observation.put(variable, 2);
        setObservation(observation);
        System.out.println("Inference using Gibbs sampling:");
        setDebug(true);
        setFactorGraph(createSimpleFG);
        setBurnin(10000);
        setMaxIteration(10000);
        setRestart(5);
        long currentTimeMillis = System.currentTimeMillis();
        runInference();
        System.out.println("Time: " + (System.currentTimeMillis() - currentTimeMillis));
        for (Variable variable2 : createSimpleFG.getVariables()) {
            System.out.println(variable2 + ": " + Arrays.toString(variable2.getBelief()));
        }
        System.out.println("LogZ: " + calculateLogZ());
        System.out.println("\nInference using LBP:");
        LoopyBeliefPropagation loopyBeliefPropagation = new LoopyBeliefPropagation();
        loopyBeliefPropagation.setFactorGraph(createSimpleFG);
        loopyBeliefPropagation.setObservation(observation);
        loopyBeliefPropagation.setDebug(true);
        long currentTimeMillis2 = System.currentTimeMillis();
        loopyBeliefPropagation.runInference();
        System.out.println("Time: " + (System.currentTimeMillis() - currentTimeMillis2));
        for (Variable variable3 : createSimpleFG.getVariables()) {
            System.out.println(variable3 + ": " + Arrays.toString(variable3.getBelief()));
        }
        // NOTE(review): this calls the Gibbs inferencer's calculateLogZ() after LBP has run on the
        // shared graph; confirm whether loopyBeliefPropagation's own value was intended here.
        System.out.println("LogZ: " + calculateLogZ());
    }
}
