package animo.fitting.levenbergmarquardt;

import animo.core.analyser.LevelResult;
import animo.core.graph.Graph;
import animo.fitting.levenbergmarquardt.LevenbergMarquardtFitter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;
import org.ejml.data.D1Matrix64F;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;
import org.ejml.ops.SpecializedOps;

/* loaded from: input_file:animo/fitting/levenbergmarquardt/LevenbergMarquardt.class */
/**
 * Levenberg-Marquardt least-squares optimiser.
 *
 * <p>Adjusts a parameter vector so that {@link Function#compute} evaluated on the inputs
 * {@code X} matches the targets {@code Y} as closely as possible, where the cost is the
 * mean squared residual {@code ||f(param, X) - Y||^2 / rows(X)}. The Jacobian is estimated
 * numerically with forward differences of step {@link #DELTA}.
 *
 * <p>Instances reuse internal work matrices between calls, so a single instance must not be
 * used from several threads at the same time.
 */
public class LevenbergMarquardt {
    /** Maximum number of outer iterations (each recomputes the gradient and Hessian). */
    private static final int MAX_ITERATIONS = 20;
    /** Number of damping (lambda) values tried within one outer iteration. */
    private static final int LAMBDA_TRIALS = 5;
    /** Outer iterations stop once the last accepted step improved the cost by no more than this. */
    private static final double MIN_IMPROVEMENT = 1.0E-6d;
    /** Factor by which lambda is divided after a successful step and multiplied after a failed one. */
    private static final double LAMBDA_FACTOR = 10.0d;

    /** Lowest-cost parameters seen so far, including candidate steps that were not accepted. */
    private DenseMatrix64F minimumCostParameters;
    private Function func;
    private double initialCost;
    private double finalCost;
    /** Optional progress reporter / cancellation source; may be {@code null}. */
    private LevenbergMarquardtFitter.LMSwingWorker swingWorker;
    /** Forward-difference step used for the numerical Jacobian. */
    public double DELTA = 1.0E-8d;
    /** Optimisation stops immediately once a candidate's cost falls below this value. */
    public double MIN_COST = 0.001d;
    private double minCost = Double.NaN;
    private double initialLambda = 1.0d;

    // Work matrices, (re)shaped in configure() and reused across iterations.
    private DenseMatrix64F temp0 = new DenseMatrix64F(1, 1);
    private DenseMatrix64F temp1 = new DenseMatrix64F(1, 1);
    private DenseMatrix64F tempDH = new DenseMatrix64F(1, 1);
    private DenseMatrix64F jacobian = new DenseMatrix64F(1, 1);
    private DenseMatrix64F param = new DenseMatrix64F(1, 1);
    private DenseMatrix64F d = new DenseMatrix64F(1, 1);
    private DenseMatrix64F H = new DenseMatrix64F(1, 1);
    private DenseMatrix64F negDelta = new DenseMatrix64F(1, 1);
    private DenseMatrix64F tempParam = new DenseMatrix64F(1, 1);
    private DenseMatrix64F A = new DenseMatrix64F(1, 1);

    /** The model being fitted. */
    public interface Function {
        /**
         * Evaluates the model.
         *
         * @param param  current parameter vector
         * @param x      model inputs (column vector)
         * @param output destination for the model outputs; must be filled with one value per
         *               row of {@code x}
         */
        void compute(DenseMatrix64F param, DenseMatrix64F x, DenseMatrix64F output);
    }

    /**
     * @param function the model whose parameters will be fitted
     */
    public LevenbergMarquardt(Function function) {
        this.func = function;
    }

    /** @return the cost of the initial parameters passed to the last {@link #optimize} call */
    public double getInitialCost() {
        return this.initialCost;
    }

    /** @return the cost of the parameters held in {@link #getParameters()} after {@link #optimize} */
    public double getFinalCost() {
        return this.finalCost;
    }

    /**
     * Returns the internal parameter vector (not a copy). After {@link #optimize} returns it
     * holds the fitted parameters.
     */
    public DenseMatrix64F getParameters() {
        return this.param;
    }

    /**
     * Attaches a worker that receives progress updates (0-100) and may request early
     * termination. Pass {@code null} to detach.
     */
    public void setSwingWorker(LevenbergMarquardtFitter.LMSwingWorker lMSwingWorker) {
        this.swingWorker = lMSwingWorker;
    }

    /**
     * Fits the model parameters.
     *
     * @param initParam initial parameter guess
     * @param X         model inputs (column vector)
     * @param Y         target outputs (column vector, same length as {@code X})
     * @return {@code true} if the optimisation finished normally; {@code false} if a linear
     *         solve failed or the swing worker requested termination
     * @throws IllegalArgumentException if {@code X}/{@code Y} are not equally long column vectors
     */
    public boolean optimize(DenseMatrix64F initParam, DenseMatrix64F X, DenseMatrix64F Y) {
        configure(initParam, X, Y);
        this.initialCost = cost(this.param, X, Y);
        this.minCost = this.initialCost;
        // Snapshot the configured column-vector parameters (not the caller's matrix) so the
        // best-so-far copy always has the same shape as this.param.
        this.minimumCostParameters = new DenseMatrix64F(this.param);
        return adjustParam(X, Y, this.initialCost);
    }

    /**
     * Main Levenberg-Marquardt loop.
     *
     * @param X        model inputs
     * @param Y        target outputs
     * @param prevCost cost of the current {@code this.param}
     * @return {@code false} on a failed linear solve or user termination, {@code true} otherwise
     */
    private boolean adjustParam(DenseMatrix64F X, DenseMatrix64F Y, double prevCost) {
        double lambda = this.initialLambda;
        // Improvement achieved by the last accepted step; starts large so the first iteration runs.
        double difference = 1000.0d;

        // Stop after MAX_ITERATIONS or as soon as accepted steps no longer improve the cost
        // noticeably (the decompiled original had this convergence test inverted).
        for (int iter = 0; iter < MAX_ITERATIONS && difference > MIN_IMPROVEMENT; iter++) {
            computeDandH(this.param, X, Y);

            boolean foundBetter = false;
            for (int trial = 0; trial < LAMBDA_TRIALS; trial++) {
                computeA(this.A, this.H, lambda);
                if (!CommonOps.solve(this.A, this.d, this.negDelta)) {
                    return false;
                }
                CommonOps.subtract(this.param, this.negDelta, this.tempParam);
                double cost = cost(this.tempParam, X, Y);

                if (cost < this.minCost) {
                    this.minCost = cost;
                    this.minimumCostParameters.set((D1Matrix64F) this.tempParam);
                }

                if (this.swingWorker != null) {
                    this.swingWorker.setProgresso(Integer.valueOf(progressPercent(cost)));
                    if (this.swingWorker.getMustTerminate()) {
                        // Hand back the best parameters seen so far.
                        this.finalCost = this.minCost;
                        this.param.set((D1Matrix64F) this.minimumCostParameters);
                        return false;
                    }
                }

                if (cost < this.MIN_COST) {
                    this.finalCost = cost;
                    this.param.set((D1Matrix64F) this.tempParam);
                    return true;
                }

                if (cost < prevCost) {
                    // Accept the step and move towards Gauss-Newton.
                    foundBetter = true;
                    this.param.set((D1Matrix64F) this.tempParam);
                    difference = prevCost - cost;
                    prevCost = cost;
                    lambda /= LAMBDA_FACTOR;
                } else {
                    // Reject the step and move towards gradient descent.
                    lambda *= LAMBDA_FACTOR;
                }
            }

            if (!foundBetter) {
                // No damping value improved the cost: no further progress possible.
                break;
            }
        }
        this.finalCost = prevCost;
        return true;
    }

    /**
     * Converts a cost into a 0-100 progress value relative to the initial cost
     * (0 = no improvement, 100 = cost reduced to zero).
     */
    private int progressPercent(double cost) {
        double percent = 100.0d - (100.0d * cost) / this.initialCost;
        // Clamp in floating point before narrowing: a cost far above the initial cost would
        // otherwise overflow the (int) cast and wrap around to an arbitrary value.
        if (!(percent > 0.0d)) { // also maps NaN to 0
            return 0;
        }
        if (percent >= 100.0d) {
            return 100;
        }
        return (int) Math.round(percent);
    }

    /**
     * Validates the inputs and sizes the internal work matrices.
     *
     * @param initParam initial parameters; their elements are copied into the internal
     *                  parameter column vector
     * @param X         model inputs (column vector)
     * @param Y         target outputs (column vector, same number of rows as {@code X})
     * @throws IllegalArgumentException if {@code X}/{@code Y} differ in length or are not
     *                                  column vectors
     */
    protected void configure(DenseMatrix64F initParam, DenseMatrix64F X, DenseMatrix64F Y) {
        if (Y.getNumRows() != X.getNumRows()) {
            throw new IllegalArgumentException("Different vector lengths");
        }
        if (Y.getNumCols() != 1 || X.getNumCols() != 1) {
            throw new IllegalArgumentException("Inputs must be a column vector");
        }
        int numParams = initParam.getNumElements();
        int numPoints = Y.getNumRows();

        if (this.param.getNumElements() != numParams) {
            this.param.reshape(numParams, 1, false);
            this.d.reshape(numParams, 1, false);
            this.H.reshape(numParams, numParams, false);
            this.negDelta.reshape(numParams, 1, false);
            this.tempParam.reshape(numParams, 1, false);
            this.A.reshape(numParams, numParams, false);
        }
        this.param.set((D1Matrix64F) initParam);

        this.temp0.reshape(numPoints, 1, false);
        this.temp1.reshape(numPoints, 1, false);
        this.tempDH.reshape(numPoints, 1, false);
        // One row per parameter, one column per data point.
        this.jacobian.reshape(numParams, numPoints, false);
    }

    /**
     * Computes, at {@code param}, the gradient {@code d = J r / N} and the Gauss-Newton Hessian
     * approximation {@code H = J J^T / N}, where {@code r = f(param, X) - Y}, {@code J} is the
     * numerical Jacobian (one row per parameter) and {@code N} is the number of elements of X.
     */
    private void computeDandH(DenseMatrix64F param, DenseMatrix64F X, DenseMatrix64F Y) {
        // Residual r = f(param, X) - Y.
        this.func.compute(param, X, this.tempDH);
        CommonOps.subtractEquals(this.tempDH, Y);

        computeNumericalJacobian(param, X, this.jacobian);

        int numParams = param.getNumElements();
        int numPoints = X.getNumElements();
        for (int p = 0; p < numParams; p++) {
            double sum = 0.0d;
            for (int k = 0; k < numPoints; k++) {
                sum += this.tempDH.get(k, 0) * this.jacobian.get(p, k);
            }
            this.d.set(p, 0, sum / numPoints);
        }

        CommonOps.multTransB(this.jacobian, this.jacobian, this.H);
        CommonOps.scale(1.0d / numPoints, this.H);
    }

    /**
     * Builds the damped system matrix {@code A = H + lambda * I}.
     *
     * @param A      destination matrix
     * @param H      Hessian approximation
     * @param lambda damping factor added to the diagonal
     */
    private void computeA(DenseMatrix64F A, DenseMatrix64F H, double lambda) {
        int numParams = this.param.getNumElements();
        A.set((D1Matrix64F) H);
        for (int i = 0; i < numParams; i++) {
            A.set(i, i, A.get(i, i) + lambda);
        }
    }

    /**
     * Mean squared residual {@code ||f(param, X) - Y||_F^2 / rows(X)}.
     */
    private double cost(DenseMatrix64F param, DenseMatrix64F X, DenseMatrix64F Y) {
        this.func.compute(param, X, this.temp0);
        double norm = SpecializedOps.diffNormF(this.temp0, Y);
        return (norm * norm) / X.numRows;
    }

    /**
     * Forward-difference Jacobian: row {@code i} of {@code jacobian} is
     * {@code (f(param + DELTA * e_i, X) - f(param, X)) / DELTA}.
     *
     * <p>{@code param} is perturbed in place one element at a time; each element is restored to
     * its exact original value afterwards (also if the model throws).
     *
     * @param param    parameters at which to differentiate
     * @param X        model inputs
     * @param jacobian destination, shaped (number of parameters) x rows(X) in row-major storage
     */
    protected void computeNumericalJacobian(DenseMatrix64F param, DenseMatrix64F X, DenseMatrix64F jacobian) {
        double invDelta = 1.0d / this.DELTA;
        int numPoints = X.numRows;

        this.func.compute(param, X, this.temp0);

        for (int i = 0; i < param.getNumElements(); i++) {
            // Save the exact value: (x + DELTA) - DELTA is not always x in floating point.
            double saved = param.data[i];
            param.data[i] = saved + this.DELTA;
            try {
                this.func.compute(param, X, this.temp1);
            } finally {
                param.data[i] = saved;
            }
            // temp1 = (f(param + delta) - f(param)) / DELTA
            CommonOps.add(invDelta, this.temp1, -invDelta, this.temp0, this.temp1);
            System.arraycopy(this.temp1.data, 0, jacobian.data, i * numPoints, numPoints);
        }
    }

    /**
     * Reads a CSV via {@link Graph#readCSVtoLevelResult} and flattens it with
     * {@link #levelResultToMatrix(LevelResult)}.
     *
     * @param csvFileName  passed to {@code Graph.readCSVtoLevelResult} (presumably the CSV path)
     * @param selection    passed to {@code Graph.readCSVtoLevelResult} unchanged
     * @param d            passed to {@code Graph.readCSVtoLevelResult} unchanged
     * @throws IOException if reading the CSV fails
     */
    public static DenseMatrix64F readCSVtoMatrix(String csvFileName, Collection<String> selection, double d) throws IOException {
        return levelResultToMatrix(Graph.readCSVtoLevelResult(csvFileName, selection, d));
    }

    /**
     * Flattens a level result over all its own time indices, with time scale 1.
     */
    public static DenseMatrix64F levelResultToMatrix(LevelResult levelResult) {
        return levelResultToMatrix(levelResult, 1.0d);
    }

    /**
     * Flattens a level result over all its own time indices.
     *
     * @param timeScale each time index is divided by this before looking up concentrations
     */
    public static DenseMatrix64F levelResultToMatrix(LevelResult levelResult, double timeScale) {
        // The scale used to be ignored here (a hard-coded 1.0 was passed instead).
        return levelResultToMatrix(levelResult, timeScale, Collections.<Double>emptyList());
    }

    /**
     * Flattens a level result into a column vector, time-major: for every time index, the
     * concentrations of all reactants (in the order returned by {@code getReactantIds()}).
     *
     * @param levelResult source data
     * @param timeScale   each time index is divided by this before looking up concentrations
     * @param timeIndices time indices to sample; if {@code null} or empty, the level result's
     *                    own time indices are used
     * @return a (timeIndices x reactants) by 1 matrix
     */
    public static DenseMatrix64F levelResultToMatrix(LevelResult levelResult, double timeScale, List<Double> timeIndices) {
        List<String> reactantIds = new ArrayList<String>(levelResult.getReactantIds());
        List<Double> times = (timeIndices != null && !timeIndices.isEmpty())
                ? timeIndices
                : levelResult.getTimeIndices();

        // Construct directly so that empty input yields an empty matrix instead of an exception.
        DenseMatrix64F result = new DenseMatrix64F(reactantIds.size() * times.size(), 1);
        int row = 0;
        for (double time : times) {
            double scaledTime = time / timeScale;
            for (String id : reactantIds) {
                result.set(row++, 0, levelResult.getConcentration(id, scaledTime));
            }
        }
        return result;
    }

    /**
     * Prints a matrix to standard output, one row per line, formatted as {@code [ a, b, c ]}.
     */
    public static void printMatrix(DenseMatrix64F denseMatrix64F) {
        for (int row = 0; row < denseMatrix64F.getNumRows(); row++) {
            StringBuilder line = new StringBuilder("[ ");
            for (int col = 0; col < denseMatrix64F.getNumCols(); col++) {
                if (col > 0) {
                    line.append(", ");
                }
                line.append(denseMatrix64F.get(row, col));
            }
            line.append(" ]");
            System.out.println(line);
        }
    }
}
