package org.reactome.factorgraph;

import cern.colt.matrix.impl.AbstractFormatter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.random.EmpiricalDistribution;
import org.apache.log4j.PropertyConfigurator;
import org.jfree.chart.axis.Axis;
import org.junit.Test;
import org.reactome.factorgraph.CLGFactor;
import org.reactome.factorgraph.ContinuousVariable;
import org.reactome.factorgraph.common.PGMConfiguration;
import org.reactome.r3.util.FileUtility;

/* loaded from: input_file:caBIGR3-minimal-3.0.jar:org/reactome/factorgraph/LBPTester.class */
public class LBPTester {
    private LoopyBeliefPropagation lbp;

    public LBPTester() {
        PropertyConfigurator.configure("resources/log4j.properties");
        this.lbp = new LoopyBeliefPropagation();
    }

    @Test
    public void testMaxProductInference() throws InferenceCannotConvergeException {
        PropertyConfigurator.configure("resources/log4j.properties");
        FactorGraph factorGraph = new FactorGraph();
        Variable variable = new Variable(2);
        variable.setName("A");
        Variable variable2 = new Variable(2);
        variable2.setName("B");
        Factor factor = new Factor();
        ArrayList arrayList = new ArrayList();
        arrayList.add(variable);
        factor.setVariables(arrayList);
        factor.setValues(new double[]{0.4d, 0.6d});
        factorGraph.addFactor(factor);
        ArrayList arrayList2 = new ArrayList();
        arrayList2.add(variable);
        arrayList2.add(variable2);
        Factor factor2 = new Factor();
        factor2.setVariables(arrayList2);
        factor2.setValues(new double[]{0.1d, 0.55d, 0.9d, 0.45d});
        factorGraph.addFactor(factor2);
        factorGraph.validatVariables();
        System.out.println("Example 13.1:");
        testMaxProd(factorGraph);
        factor2.setValues(new double[]{0.3d, 0.4d, 0.3d, 0.0d});
        FactorGraph factorGraph2 = new FactorGraph();
        factorGraph2.addFactor(factor2);
        factorGraph2.validatVariables();
        System.out.println("\nExample 13.10:");
        testMaxProd(factorGraph2);
        FactorGraph factorGraph3 = new FactorGraph();
        Factor factor3 = new Factor();
        ArrayList arrayList3 = new ArrayList();
        arrayList3.add(variable);
        arrayList3.add(variable2);
        factor3.setVariables(arrayList3);
        factor3.setValues(parseValues("1, 2, 2, 1"));
        factorGraph3.addFactor(factor3);
        Variable variable3 = new Variable(2);
        variable3.setName("C");
        Factor factor4 = new Factor();
        ArrayList arrayList4 = new ArrayList();
        arrayList4.add(variable2);
        arrayList4.add(variable3);
        factor4.setVariables(arrayList4);
        factor4.setValues(parseValues("1, 2, 2, 1"));
        factorGraph3.addFactor(factor4);
        Factor factor5 = new Factor();
        ArrayList arrayList5 = new ArrayList();
        arrayList5.add(variable);
        arrayList5.add(variable3);
        factor5.setVariables(arrayList5);
        factor5.setValues(parseValues("1, 2, 2, 1"));
        factorGraph3.addFactor(factor5);
        factorGraph3.validatVariables();
        System.out.println("\nExample 13.11:");
        testMaxProd(factorGraph3);
    }

    private double[] parseValues(String str) {
        String[] split = str.split(", ");
        double[] dArr = new double[split.length];
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] = new Double(split[i]).doubleValue();
        }
        return dArr;
    }

    private void testMaxProd(FactorGraph factorGraph) throws InferenceCannotConvergeException {
        this.lbp.setFactorGraph(factorGraph);
        this.lbp.setInferenceType(InferenceType.MAX_PRODUCT);
        this.lbp.runInference();
        System.out.println("In the probability space:");
        Map<Variable, Integer> findMaximum = this.lbp.findMaximum();
        for (Variable variable : findMaximum.keySet()) {
            System.out.println(String.valueOf(variable.getName()) + ": " + findMaximum.get(variable));
        }
        this.lbp.setUseLogSpace(true);
        this.lbp.runInference();
        System.out.println("\nIn the log space:");
        Map<Variable, Integer> findMaximum2 = this.lbp.findMaximum();
        for (Variable variable2 : findMaximum2.keySet()) {
            System.out.println(String.valueOf(variable2.getName()) + ": " + findMaximum2.get(variable2));
        }
        System.out.println("LogLikelihood: " + factorGraph.getLogLikelihood(findMaximum2));
        System.out.println("\nRun SUM_PRODUCT:");
        this.lbp.setFactorGraph(factorGraph);
        this.lbp.setInferenceType(InferenceType.SUM_PRODUCT);
        this.lbp.runInference();
        outputBelief(factorGraph);
    }

    @Test
    public void testEmpiricalDistribution() {
        NormalDistribution normalDistribution = new NormalDistribution(0.0d, 1.0d);
        double[] dArr = new double[EmpiricalDistribution.DEFAULT_BIN_COUNT];
        for (int i = 0; i < 1000; i++) {
            dArr[i] = normalDistribution.sample();
        }
        EmpiricalDistribution empiricalDistribution = new EmpiricalDistribution();
        empiricalDistribution.load(dArr);
        System.out.println(AbstractFormatter.DEFAULT_ROW_SEPARATOR);
        Arrays.sort(dArr);
        for (int i2 : new int[]{(int) (EmpiricalDistribution.DEFAULT_BIN_COUNT * 0.05d), (int) (EmpiricalDistribution.DEFAULT_BIN_COUNT * 0.45d), (int) (EmpiricalDistribution.DEFAULT_BIN_COUNT * 0.5d), (int) (EmpiricalDistribution.DEFAULT_BIN_COUNT * 0.55d), (int) ((EmpiricalDistribution.DEFAULT_BIN_COUNT * 1.0d) - 1.0d)}) {
            double d = dArr[i2];
            System.out.println(String.valueOf(d) + ": " + empiricalDistribution.cumulativeProbability(d));
        }
    }

    @Test
    public void testRunEmpiricalInference() throws InferenceCannotConvergeException {
        FileUtility.initializeLogging();
        HashSet hashSet = new HashSet();
        Factor createSimpleABFactor = createSimpleABFactor();
        hashSet.add(createSimpleABFactor);
        Variable variable = getVariable(createSimpleABFactor, "A");
        Variable variable2 = getVariable(createSimpleABFactor, "B");
        ContinuousVariable continuousVariable = new ContinuousVariable();
        continuousVariable.setName("X");
        continuousVariable.setDistributionType(ContinuousVariable.DistributionType.TWO_SIDED);
        EmpiricalFactor empiricalFactor = new EmpiricalFactor();
        empiricalFactor.setDiscreteVariable(variable2);
        empiricalFactor.setContinuousVariable(continuousVariable);
        hashSet.add(empiricalFactor);
        FactorGraph factorGraph = new FactorGraph();
        factorGraph.setFactors(hashSet);
        factorGraph.validatVariables();
        NormalDistribution normalDistribution = new NormalDistribution(0.0d, 1.0d);
        double[] dArr = new double[EmpiricalDistribution.DEFAULT_BIN_COUNT];
        for (int i = 0; i < 1000; i++) {
            dArr[i] = normalDistribution.sample();
        }
        EmpiricalDistribution empiricalDistribution = new EmpiricalDistribution();
        empiricalDistribution.load(dArr);
        LoopyBeliefPropagation loopyBeliefPropagation = new LoopyBeliefPropagation();
        loopyBeliefPropagation.setFactorGraph(factorGraph);
        Arrays.sort(dArr);
        int[] iArr = {0, (int) (EmpiricalDistribution.DEFAULT_BIN_COUNT * 0.05d), (int) (EmpiricalDistribution.DEFAULT_BIN_COUNT * 0.25d), (int) (EmpiricalDistribution.DEFAULT_BIN_COUNT * 0.5d), (int) (EmpiricalDistribution.DEFAULT_BIN_COUNT * 0.75d), (int) ((EmpiricalDistribution.DEFAULT_BIN_COUNT * 1.0d) - 1.0d)};
        for (int i2 = 0; i2 < 3; i2++) {
            loopyBeliefPropagation.clearObservation();
            loopyBeliefPropagation.runInference();
            double[] belief = variable.getBelief();
            double[] belief2 = variable2.getBelief();
            System.out.println("Belief for A:");
            for (int i3 = 0; i3 < belief.length; i3++) {
                System.out.println("State " + i3 + ": " + belief[i3]);
            }
            System.out.println("Belief for B:");
            for (int i4 = 0; i4 < belief2.length; i4++) {
                System.out.println("State " + i4 + ": " + belief2[i4]);
            }
            VariableAssignment variableAssignment = new VariableAssignment();
            variableAssignment.setVariable(continuousVariable);
            variableAssignment.setDistribution(empiricalDistribution);
            Observation observation = new Observation();
            observation.addAssignment(variableAssignment);
            loopyBeliefPropagation.setObservation(observation);
            for (int i5 : iArr) {
                double d = dArr[i5];
                variableAssignment.setAssignment(Double.valueOf(d));
                loopyBeliefPropagation.runInference();
                System.out.println("X = " + d);
                System.out.println("Belief for A:");
                for (int i6 = 0; i6 < belief.length; i6++) {
                    System.out.println("State " + i6 + ": " + belief[i6]);
                }
                System.out.println("Belief for B:");
                for (int i7 = 0; i7 < belief2.length; i7++) {
                    System.out.println("State " + i7 + ": " + belief2[i7]);
                }
                System.out.println("Continuous probability: " + empiricalDistribution.cumulativeProbability(d));
            }
            System.out.println();
        }
    }

    private Factor createSimpleABFactor() {
        Variable variable = new Variable(2);
        variable.setName("A");
        Variable variable2 = new Variable(2);
        variable2.setName("B");
        Factor factor = new Factor();
        ArrayList arrayList = new ArrayList();
        arrayList.add(variable);
        arrayList.add(variable2);
        factor.setVariables(arrayList);
        factor.setValues(new double[]{0.8d, 0.6d, 0.2d, 0.4d});
        return factor;
    }

    private Variable getVariable(Factor factor, String str) {
        for (Variable variable : factor.getVariables()) {
            if (variable.getName().equals(str)) {
                return variable;
            }
        }
        return null;
    }

    @Test
    public void testRunGaussianInference() throws InferenceCannotConvergeException {
        FileUtility.initializeLogging();
        HashSet hashSet = new HashSet();
        Factor createSimpleABFactor = createSimpleABFactor();
        hashSet.add(createSimpleABFactor);
        Variable variable = getVariable(createSimpleABFactor, "A");
        Variable variable2 = getVariable(createSimpleABFactor, "B");
        HashMap hashMap = new HashMap();
        hashMap.put(variable, 1);
        hashMap.put(variable2, 0);
        System.out.println("Index for " + hashMap + ": " + createSimpleABFactor.getIndexForAssignment(hashMap));
        ContinuousVariable continuousVariable = new ContinuousVariable();
        continuousVariable.setName("X");
        CLGFactor cLGFactor = new CLGFactor();
        cLGFactor.setDiscreteVariable(variable2);
        cLGFactor.setContinuousVariable(continuousVariable);
        ArrayList arrayList = new ArrayList();
        CLGFactor.CLGFactorDistribution cLGFactorDistribution = new CLGFactor.CLGFactorDistribution((Integer) 0, 0.5d, new NormalDistribution(0.0d, 1.0d));
        CLGFactor.CLGFactorDistribution cLGFactorDistribution2 = new CLGFactor.CLGFactorDistribution((Integer) 1, 0.5d, new NormalDistribution(2.0d, 1.0d));
        arrayList.add(cLGFactorDistribution);
        arrayList.add(cLGFactorDistribution2);
        cLGFactor.setDistributions(arrayList);
        hashSet.add(cLGFactor);
        FactorGraph factorGraph = new FactorGraph();
        factorGraph.setFactors(hashSet);
        factorGraph.validatVariables();
        System.out.println("Total variables: " + factorGraph.getVariables().size());
        System.out.println("Total factors: " + factorGraph.getFactors().size());
        for (Factor factor : factorGraph.getFactors()) {
            System.out.println(factor + ": " + factor.getVariables().size());
        }
        System.out.println();
        LoopyBeliefPropagation loopyBeliefPropagation = new LoopyBeliefPropagation();
        loopyBeliefPropagation.setDebug(true);
        loopyBeliefPropagation.setFactorGraph(factorGraph);
        HashMap hashMap2 = new HashMap();
        float[] fArr = {Axis.DEFAULT_TICK_MARK_INSIDE_LENGTH, 2.0f};
        float[] fArr2 = {2.0f};
        for (int i = 0; i < 1; i++) {
            loopyBeliefPropagation.clearObservation();
            loopyBeliefPropagation.runInference();
            double[] belief = variable.getBelief();
            double[] belief2 = variable2.getBelief();
            System.out.println("Belief for A:");
            for (int i2 = 0; i2 < belief.length; i2++) {
                System.out.println("State " + i2 + ": " + belief[i2]);
            }
            System.out.println("Belief for B:");
            for (int i3 = 0; i3 < belief2.length; i3++) {
                System.out.println("State " + i3 + ": " + belief2[i3]);
            }
            for (float f : fArr2) {
                hashMap2.put(continuousVariable, Float.valueOf(f));
                loopyBeliefPropagation.setObservation(hashMap2);
                loopyBeliefPropagation.runInference();
                System.out.println("X = " + f);
                System.out.println("Belief for A:");
                for (int i4 = 0; i4 < belief.length; i4++) {
                    System.out.println("State " + i4 + ": " + belief[i4]);
                }
                System.out.println("Belief for B:");
                for (int i5 = 0; i5 < belief2.length; i5++) {
                    System.out.println("State " + i5 + ": " + belief2[i5]);
                }
                System.out.println("Continuous probabilities:");
                for (int i6 = 0; i6 < arrayList.size(); i6++) {
                    System.out.println(arrayList.get(i6).getDensity(f));
                }
            }
            System.out.println();
        }
    }

    @Test
    public void testCompetitiveReactions() throws Exception {
        FileUtility.initializeLogging();
        HashMap hashMap = new HashMap();
        for (String str : new String[]{"A", "B", "C", "D", "E", "F", "G"}) {
            Variable variable = new Variable(3);
            variable.setName(str);
            hashMap.put(variable.getName(), variable);
        }
        double[] dArr = {0.99d, 0.005d, 0.005d, 0.005d, 0.99d, 0.005d, 0.005d, 0.005d, 0.99d};
        FactorGraph factorGraph = new FactorGraph();
        factorGraph.addFactor(new Factor((Variable) hashMap.get("A"), (Variable) hashMap.get("B"), dArr));
        factorGraph.addFactor(new Factor((Variable) hashMap.get("B"), (Variable) hashMap.get("C"), dArr));
        factorGraph.addFactor(new Factor((Variable) hashMap.get("A"), (Variable) hashMap.get("D"), dArr));
        factorGraph.addFactor(new Factor((Variable) hashMap.get("D"), (Variable) hashMap.get("E"), dArr));
        factorGraph.addFactor(new Factor((Variable) hashMap.get("D"), (Variable) hashMap.get("F"), dArr));
        factorGraph.addFactor(new Factor((Variable) hashMap.get("A"), (Variable) hashMap.get("G"), dArr));
        factorGraph.validatVariables();
        factorGraph.setIdsInFactors();
        LoopyBeliefPropagation loopyBeliefPropagation = new LoopyBeliefPropagation();
        loopyBeliefPropagation.setFactorGraph(factorGraph);
        loopyBeliefPropagation.runInference();
        System.out.println("Prior:");
        outputBelief(factorGraph);
        System.out.println("\nIf D = 2:");
        Observation observation = new Observation();
        observation.addAssignment((Variable) hashMap.get("D"), 2);
        loopyBeliefPropagation.setObservation(observation);
        loopyBeliefPropagation.runInference();
        outputBelief(factorGraph);
        System.out.println("\nIf F = 2:");
        Observation observation2 = new Observation();
        observation2.addAssignment((Variable) hashMap.get("F"), 2);
        loopyBeliefPropagation.setObservation(observation2);
        loopyBeliefPropagation.runInference();
        outputBelief(factorGraph);
        System.out.println("\nIf D = 2, F = 2:");
        Observation observation3 = new Observation();
        observation3.addAssignment((Variable) hashMap.get("F"), 2);
        observation3.addAssignment((Variable) hashMap.get("D"), 2);
        loopyBeliefPropagation.setObservation(observation3);
        loopyBeliefPropagation.runInference();
        outputBelief(factorGraph);
        System.out.println("\nIf F = 2, G = 1:");
        Observation observation4 = new Observation();
        observation4.addAssignment((Variable) hashMap.get("F"), 2);
        observation4.addAssignment((Variable) hashMap.get("G"), 1);
        loopyBeliefPropagation.setObservation(observation4);
        loopyBeliefPropagation.runInference();
        outputBelief(factorGraph);
    }

    private Factor createFactor(Variable variable, Variable variable2, double[] dArr) {
        Factor factor = new Factor();
        ArrayList arrayList = new ArrayList();
        arrayList.add(variable);
        arrayList.add(variable2);
        factor.setVariables(arrayList);
        factor.setValues(dArr);
        return factor;
    }

    @Test
    public void testRunCompetitiveModel() throws Exception {
        FileUtility.initializeLogging();
        FactorGraph createCompetitiveFB = TestUtilities.createCompetitiveFB();
        createCompetitiveFB.exportFG(System.out);
        System.out.println("\nPrior:");
        performLBP(createCompetitiveFB, null);
        Variable variable = TestUtilities.getVariable(createCompetitiveFB, "A");
        Variable variable2 = TestUtilities.getVariable(createCompetitiveFB, "B");
        Variable variable3 = TestUtilities.getVariable(createCompetitiveFB, "C");
        HashMap hashMap = new HashMap();
        hashMap.put(variable2, 0);
        System.out.println("\nB = 0");
        performLBP(createCompetitiveFB, hashMap);
        hashMap.clear();
        hashMap.put(variable2, 1);
        System.out.println("\nB = 1");
        performLBP(createCompetitiveFB, hashMap);
        hashMap.clear();
        hashMap.put(variable2, 2);
        System.out.println("\nB = 2");
        performLBP(createCompetitiveFB, hashMap);
        hashMap.clear();
        hashMap.put(variable2, 0);
        hashMap.put(variable, 1);
        System.out.println("\nA = 1, B = 0");
        performLBP(createCompetitiveFB, hashMap);
        hashMap.clear();
        hashMap.put(variable2, 1);
        hashMap.put(variable, 1);
        System.out.println("\nA = 1, B = 1");
        performLBP(createCompetitiveFB, hashMap);
        hashMap.clear();
        hashMap.put(variable2, 2);
        hashMap.put(variable, 1);
        System.out.println("\nA = 1, B = 2");
        performLBP(createCompetitiveFB, hashMap);
        hashMap.clear();
        hashMap.put(variable2, 2);
        hashMap.put(variable, 2);
        System.out.println("\nA = 2, B = 2");
        performLBP(createCompetitiveFB, hashMap);
        hashMap.clear();
        hashMap.put(variable2, 1);
        hashMap.put(variable, 1);
        hashMap.put(variable3, 1);
        System.out.println("\nA = 1, B = 1, C = 1:");
        performLBP(createCompetitiveFB, hashMap);
        hashMap.clear();
        hashMap.put(variable2, 2);
        hashMap.put(variable, 1);
        hashMap.put(variable3, 1);
        System.out.println("\nA = 1, B = 2, C = 1:");
        performLBP(createCompetitiveFB, hashMap);
    }

    @Test
    public void testRunFeedbackLoop() throws Exception {
        FileUtility.initializeLogging();
        FactorGraph createFeedbackLoopFG = TestUtilities.createFeedbackLoopFG();
        createFeedbackLoopFG.exportFG(System.out);
        System.out.println("\nPrior:");
        performLBP(createFeedbackLoopFG, null);
        Variable variable = TestUtilities.getVariable(createFeedbackLoopFG, "A");
        TestUtilities.getVariable(createFeedbackLoopFG, "B");
        Variable variable2 = TestUtilities.getVariable(createFeedbackLoopFG, "C");
        HashMap hashMap = new HashMap();
        hashMap.put(variable, 2);
        System.out.println("\nA = 2");
        performLBP(createFeedbackLoopFG, hashMap);
        hashMap.clear();
        hashMap.put(variable, 0);
        System.out.println("\nA == 0");
        performLBP(createFeedbackLoopFG, hashMap);
        hashMap.clear();
        hashMap.put(variable, 1);
        System.out.println("\nA == 1");
        performLBP(createFeedbackLoopFG, hashMap);
        hashMap.clear();
        hashMap.put(variable2, 0);
        System.out.println("\nC == 0");
        performLBP(createFeedbackLoopFG, hashMap);
        hashMap.clear();
        hashMap.put(variable2, 1);
        System.out.println("\nC == 1");
        performLBP(createFeedbackLoopFG, hashMap);
        hashMap.clear();
        hashMap.put(variable2, 2);
        System.out.println("\nC == 2");
        performLBP(createFeedbackLoopFG, hashMap);
    }

    @Test
    public void testRunInference() throws InferenceCannotConvergeException {
        FileUtility.initializeLogging();
        FactorGraph createSimpleFG = TestUtilities.createSimpleFG();
        Variable variable = TestUtilities.getVariable(createSimpleFG, PGMConfiguration.mRNA);
        HashMap hashMap = new HashMap();
        hashMap.put(variable, 2);
        performLBP(createSimpleFG, hashMap);
        this.lbp.setUseLogSpace(true);
        this.lbp.runInference();
        System.out.println("\n\nIn the log space:");
        System.out.println("iteration: " + this.lbp.getIteration());
        System.out.println("maxDiff: " + this.lbp.getMaxDiff());
        outputBelief(createSimpleFG);
    }

    private void performLBP(FactorGraph factorGraph, Map<Variable, Integer> map) throws InferenceCannotConvergeException {
        this.lbp.setObservation(map);
        this.lbp.setFactorGraph(factorGraph);
        this.lbp.setDebug(true);
        this.lbp.runInference();
        System.out.println("In the probability space:");
        System.out.println("iteration: " + this.lbp.getIteration());
        System.out.println("maxDiff: " + this.lbp.getMaxDiff());
        outputBelief(factorGraph);
        System.out.println("LogZ: " + this.lbp.calculateLogZ());
    }

    private void outputBelief(FactorGraph factorGraph) {
        StringBuilder sb = new StringBuilder();
        ArrayList<Variable> arrayList = new ArrayList(factorGraph.getVariables());
        Collections.sort(arrayList, new Comparator<Variable>() { // from class: org.reactome.factorgraph.LBPTester.1
            @Override // java.util.Comparator
            public int compare(Variable variable, Variable variable2) {
                return variable.getName().compareTo(variable2.getName());
            }
        });
        for (Variable variable : arrayList) {
            for (double d : variable.getBelief()) {
                sb.append(d).append("\t");
            }
            System.out.println(String.valueOf(variable.getName()) + ": " + sb.toString());
            sb.setLength(0);
        }
    }
}
