package org.cytoscape.cyni.internal.imputationAlgorithms.BPCAFillAlgorithm;

/* loaded from: input_file:org/cytoscape/cyni/internal/imputationAlgorithms/BPCAFillAlgorithm/BPCAFillAlgorithm.class */
/**
 * Bayesian PCA (BPCA) imputation of missing entries in a numeric matrix.
 *
 * <p>The input has {@code numData} rows and {@code iDim} columns; the model starts
 * with {@code fDim = iDim - 1} latent components. Each {@link #doStep()} call runs
 * one iteration: an E-step over every row (re-estimating that row's missing entries
 * in place in {@link #yest}), an M-step updating the model in {@link #u}, and
 * deletion of at most one component whose weight column has collapsed.
 *
 * <p>Instances are not safe for concurrent use: the E- and M-steps share mutable
 * scratch fields.
 *
 * <p>NOTE(review): matrix products are delegated to {@code MatrixUtils}. The
 * formulas in the comments below assume {@code mul(rows, cols, v, M, out)} computes
 * {@code out = v^T M} and {@code mul(a, b, c, A, B, C)} computes {@code C = A B};
 * confirm against MatrixUtils.
 */
class BPCAFillAlgorithm {
    /** Model being fitted: mean {@code mu}, weights {@code W}, precisions, hyperparameters. */
    BPCAUnit u;
    /** Number of rows (observations) in the input matrix. */
    int numData;
    /** Number of columns in the input matrix. */
    int iDim;
    /** Current number of latent components; starts at iDim - 1 and shrinks on deletion. */
    int fDim;
    /** Working copy of the data; missing entries hold their current estimates. */
    double[][] yest;
    /** {@code missingIndex[row][col]} is true when that entry is treated as missing. */
    private boolean[][] missingIndex;
    /** Per-row counts filled by MissingValueHandler.isMissing; 0 means the row is fully observed. */
    private int[] numberOfMissingIndex;
    /** Scratch: Rx minus tau * W_miss^T W_miss for the current row (see calcInvRxo). */
    private double[][] Rxo;
    /** Scratch: inverse of Rxo for the current row. */
    private double[][] invRxo;
    /** Scratch matrix shared by calcInvRx, calcInvRxo and doMStep. */
    private double[][] comp_mat;
    /** Accumulates (y - mu) x^T (plus corrections for missing entries), averaged over rows. */
    private double[][] matT;
    /** Scratch: centred current row, y - mu. */
    private double[] dy;
    /** Scratch: latent estimate for the current row. */
    private double[] x;
    /** Scratch: tau * W^T (y - mu) for the current row. */
    private double[] ex;
    /** Accumulated squared residual, averaged over rows in postEStep. */
    private double trS;
    /** Log-determinant of Rxo for the last row with missing entries; not read in this class. */
    private double logdetRxo;
    /** Not read or written by this class; presumably consulted by a caller. */
    int verbose = 5;
    /** Number of doStep() calls performed so far. */
    int epoch = 0;
    /** Not read by this class; presumably the caller's iteration limit. */
    int maxEpoch = 500;
    /** Not read or written by this class; presumably set by a caller's convergence test. */
    boolean isConverged = false;
    /** Returned by doStep(); this class never sets it to true itself. */
    boolean isFinished = false;
    /** A component is deleted once its squared weight-column norm falls below this value. */
    double FEATURE_DELETION_THRESHOLD = 1.0E-8d;

    /**
     * Prepares a BPCA fit for {@code dArr}.
     *
     * <p>Missing entries are located with {@code missingValueHandler}. Each missing
     * entry is first replaced by its column mean over observed entries (0 when the
     * column has no observed entry). The model is then initialized by PCA of that
     * mean-filled matrix.
     *
     * @param dArr input matrix, one row per observation; the first row's length
     *     fixes the column count
     * @param missingValueHandler decides which entries are missing
     * @throws IllegalArgumentException if {@code dArr} has no rows or its first row
     *     has no columns
     */
    public BPCAFillAlgorithm(double[][] dArr, MissingValueHandler missingValueHandler) {
        // Fail early with a clear message instead of an index/array-size exception.
        if (dArr.length == 0 || dArr[0].length == 0) {
            throw new IllegalArgumentException("BPCA needs at least one row and one column, got "
                    + dArr.length + " rows");
        }
        this.numData = dArr.length;
        this.iDim = dArr[0].length;
        this.fDim = this.iDim - 1;
        this.yest = new double[this.numData][this.iDim];
        this.missingIndex = new boolean[this.numData][this.iDim];
        this.numberOfMissingIndex = new int[this.numData];
        missingValueHandler.isMissing(this.numData, this.iDim, dArr, this.missingIndex, this.numberOfMissingIndex);
        this.u = new BPCAUnit(0, this.iDim, this.fDim);
        this.Rxo = new double[this.fDim][this.fDim];
        this.invRxo = new double[this.fDim][this.fDim];
        this.matT = new double[this.iDim][this.fDim];
        this.comp_mat = new double[this.fDim][this.fDim];
        this.dy = new double[this.iDim];
        this.x = new double[this.fDim];
        this.ex = new double[this.fDim];
        // NOTE(review): MatrixUtils.init's purpose is not visible here; it receives the column count.
        MatrixUtils.init(this.iDim);
        // Column means over observed entries only.
        for (int col = this.iDim - 1; col >= 0; col--) {
            int observed = 0;
            double sum = 0.0d;
            for (int row = this.numData - 1; row >= 0; row--) {
                if (!this.missingIndex[row][col]) {
                    observed++;
                    sum += dArr[row][col];
                }
            }
            this.u.mu[col] = observed == 0 ? 0.0d : sum / observed;
        }
        // Seed the working matrix: observed values as given, missing ones at the column mean.
        for (int col = this.iDim - 1; col >= 0; col--) {
            for (int row = this.numData - 1; row >= 0; row--) {
                this.yest[row][col] = this.missingIndex[row][col] ? this.u.mu[col] : dArr[row][col];
            }
        }
        // Hyperparameters read by calcAlpha (galpha0, balpha0), calcInvRx/calcAlpha (gamma)
        // and the tau update in doMStep (gtau0, btau0, gmu0, min_tau, max_tau).
        this.u.gamma = this.numData;
        this.u.galpha0 = 1.0E-10d;
        this.u.balpha0 = 1.0d;
        this.u.gtau0 = 1.0E-10d;
        this.u.btau0 = 1.0d;
        this.u.gmu0 = 0.001d;
        this.u.min_tau = 1.0E-10d;
        this.u.max_tau = 1.0E10d;
        this.u.isVB = true;
        initParameterByPCA();
    }

    /**
     * Returns the working matrix with current estimates for missing entries.
     *
     * @return the internal array itself (not a copy); later doStep() calls update it
     */
    public double[][] getMatrix() {
        return this.yest;
    }

    /** Returns the current noise precision {@code u.tau}. */
    public double getTau() {
        return this.u.tau;
    }

    /** Returns the model being fitted (the live instance, not a copy). */
    public BPCAUnit getModel() {
        return this.u;
    }

    /**
     * Initializes W, tau, alpha, invDw and Rx from a PCA of the mean-filled data.
     *
     * <p>The covariance of {@code yest} about {@code u.mu}, with unit weight per row,
     * is decomposed with {@code MatrixUtils.svdcmp}. Column k of W is
     * {@code sqrt(lambda_k)} times the singular vector of the k-th largest value.
     * The initial tau is the reciprocal of the variance not captured by the fDim
     * leading values, clamped to {@code [min_tau, max_tau]} exactly as doMStep does.
     */
    private void initParameterByPCA() {
        int[] order = new int[this.iDim];
        double[] values = new double[this.iDim];
        double[] sortedValues = new double[this.iDim];
        double[][] vectors = new double[this.iDim][this.iDim];
        double[] rowWeights = new double[this.numData];
        for (int row = this.numData - 1; row >= 0; row--) {
            rowWeights[row] = 1.0d;
        }
        double[][] covariance = new double[this.iDim][this.iDim];
        MatrixUtils.cov(this.numData, this.iDim, this.yest, this.u.mu, rowWeights, covariance);
        // Start from the total variance (trace) and subtract the captured part below.
        double residualVariance = 0.0d;
        for (int k = 0; k < this.iDim; k++) {
            residualVariance += covariance[k][k];
        }
        if (this.fDim > 0) {
            // Presumably Numerical-Recipes style: values into `values`, vectors into `vectors`,
            // and `order` receives the descending permutation — TODO confirm in MatrixUtils.
            MatrixUtils.svdcmp(this.iDim, this.iDim, covariance, values, vectors);
            sortedValues = MatrixUtils.sortDecend(values, order);
        }
        for (int k = 0; k < this.fDim; k++) {
            for (int col = 0; col < this.iDim; col++) {
                this.u.W[col][k] = Math.sqrt(sortedValues[k]) * vectors[col][order[k]];
            }
        }
        for (int k = this.fDim - 1; k >= 0; k--) {
            residualVariance -= sortedValues[k];
        }
        // Rounding (or rank-deficient data) can leave a residual of zero or slightly below.
        // Treat that as "no residual noise" instead of producing an infinite or negative
        // precision, and honour the same bounds the M-step enforces.
        double initialTau = residualVariance > 0.0d ? 1.0d / residualVariance : this.u.max_tau;
        this.u.tau = Math.min(Math.max(initialTau, this.u.min_tau), this.u.max_tau);
        calcAlpha();
        for (int k = 0; k < this.fDim; k++) {
            this.u.invDw[k][k] = (1.0d * this.numData) / this.iDim;
        }
        calcInvRx();
    }

    /**
     * Runs one iteration: E-step over every row (updating missing entries of
     * {@code yest} in place), M-step, then deletion of at most one dead component.
     * Increments {@code epoch}.
     *
     * @return the {@code isFinished} flag, which this class never sets itself
     */
    public boolean doStep() {
        this.epoch++;
        preEStep();
        for (int row = 0; row < this.numData; row++) {
            if (this.numberOfMissingIndex[row] == 0) {
                eStepWithoutMiss(this.yest[row]);
            } else {
                eStepWithMiss(this.yest[row], this.missingIndex[row]);
            }
        }
        postEStep();
        doMStep();
        deleteDeadFeatures();
        return this.isFinished;
    }

    /**
     * Deletes the component with the smallest squared weight-column norm
     * ({@code u.diagWTW}) if that norm is below FEATURE_DELETION_THRESHOLD.
     *
     * <p>At most one component is removed per call. Only norms below 10.0 are
     * candidates. Ties keep the highest index. {@code diagWTW} holds the values from
     * the last calcInvRx call; when called from doStep, that is the start of the
     * E-step, i.e. before the latest M-step changed W.
     *
     * <p>NOTE(review): u.deleteFeature presumably decrements u.fDim; this.fDim is
     * re-synced afterwards. Scratch buffers keep their original size, which assumes
     * MatrixUtils honours the dimensions passed to it.
     */
    public void deleteDeadFeatures() {
        this.fDim = this.u.fDim;
        double smallestNorm = 10.0d;
        int weakest = 0;
        for (int k = this.fDim - 1; k >= 0; k--) {
            if (this.u.diagWTW[k] < smallestNorm) {
                smallestNorm = this.u.diagWTW[k];
                weakest = k;
            }
        }
        if (smallestNorm < this.FEATURE_DELETION_THRESHOLD) {
            this.u.deleteFeature(weakest);
        }
        this.fDim = this.u.fDim;
    }

    /** Resets the E-step accumulators (matT, trS) and refreshes Rx, its inverse and diagWTW. */
    private void preEStep() {
        MatrixUtils.mulScalar(this.iDim, this.fDim, 0.0d, this.matT);
        this.trS = 0.0d;
        calcInvRx();
    }

    /**
     * E-step for a fully observed row.
     *
     * <p>Computes {@code x = invRx * tau * W^T (y - mu)}, then adds
     * {@code ||y - mu||^2} to trS and {@code (y - mu) x^T} to matT.
     *
     * @param y the row (read only)
     */
    private void eStepWithoutMiss(double[] y) {
        for (int col = this.iDim - 1; col >= 0; col--) {
            this.dy[col] = y[col] - this.u.mu[col];
        }
        MatrixUtils.mul(this.iDim, this.fDim, this.dy, this.u.W, this.ex);
        MatrixUtils.mulScalar(this.fDim, this.u.tau, this.ex);
        MatrixUtils.mul(this.fDim, this.fDim, this.ex, this.u.invRx, this.x);
        for (int col = this.iDim - 1; col >= 0; col--) {
            this.trS += this.dy[col] * this.dy[col];
            for (int k = this.fDim - 1; k >= 0; k--) {
                this.matT[col][k] += this.dy[col] * this.x[k];
            }
        }
    }

    /**
     * E-step for a row with missing entries.
     *
     * <p>Computes the latent estimate from the observed entries only (missing
     * residuals are zeroed, and Rxo replaces Rx). Each missing entry of {@code y} is
     * then overwritten with {@code mu + W x}. Finally trS and matT are accumulated,
     * including the {@code 1/tau} and {@code W invRxo W^T} corrections for the
     * missing columns.
     *
     * @param y the row; its missing entries are updated in place
     * @param missing per-column missing flags for this row
     */
    private void eStepWithMiss(double[] y, boolean[] missing) {
        for (int col = this.iDim - 1; col >= 0; col--) {
            this.dy[col] = missing[col] ? 0.0d : y[col] - this.u.mu[col];
        }
        this.logdetRxo = calcInvRxo(missing);
        MatrixUtils.mul(this.iDim, this.fDim, this.dy, this.u.W, this.ex);
        MatrixUtils.mulScalar(this.fDim, this.u.tau, this.ex);
        MatrixUtils.mul(this.fDim, this.fDim, this.ex, this.invRxo, this.x);
        // Re-estimate each missing entry as mu + W x and refresh its residual.
        for (int col = this.iDim - 1; col >= 0; col--) {
            if (missing[col]) {
                y[col] = this.u.mu[col];
                for (int k = this.fDim - 1; k >= 0; k--) {
                    y[col] += this.u.W[col][k] * this.x[k];
                }
                this.dy[col] = y[col] - this.u.mu[col];
            }
        }
        double noiseVariance = 1.0d / this.u.tau;
        double rowTrace = 0.0d;
        for (int col = this.iDim - 1; col >= 0; col--) {
            rowTrace += this.dy[col] * this.dy[col];
            if (missing[col]) {
                rowTrace += noiseVariance;
            }
            for (int k = this.fDim - 1; k >= 0; k--) {
                this.matT[col][k] += this.dy[col] * this.x[k];
                if (missing[col]) {
                    // Uncertainty of the imputed entry: adds (W invRxo^T)[col][m] to matT
                    // and W[col] invRxo W[col]^T to the trace.
                    for (int m = this.fDim - 1; m >= 0; m--) {
                        this.matT[col][m] += this.u.W[col][k] * this.invRxo[m][k];
                        rowTrace += this.u.W[col][m] * this.u.W[col][k] * this.invRxo[m][k];
                    }
                }
            }
        }
        this.trS += rowTrace;
    }

    /** Turns the E-step sums in trS and matT into averages over all rows. */
    private void postEStep() {
        this.trS /= this.numData;
        MatrixUtils.mulScalar(this.iDim, this.fDim, 1.0d / this.numData, this.matT);
    }

    /**
     * M-step: updates invDw, W, tau, lntau and alpha from the averaged statistics.
     *
     * <p>{@code u.invDw} first holds {@code I + tau * T^T W} as scratch
     * (T = matT). It is then replaced by the symmetrized inverse of
     * {@code (that) * invRx} (+ {@code alpha/numData} on the diagonal when isVB),
     * and {@code W = T * invDw}. tau is clamped to {@code [min_tau, max_tau]}.
     */
    private void doMStep() {
        for (int j = this.fDim - 1; j >= 0; j--) {
            for (int k = this.fDim - 1; k >= 0; k--) {
                double dot = 0.0d;
                for (int col = this.iDim - 1; col >= 0; col--) {
                    dot += this.matT[col][j] * this.u.W[col][k];
                }
                this.u.invDw[j][k] = dot * this.u.tau;
            }
            this.u.invDw[j][j] += 1.0d;
        }
        MatrixUtils.mul(this.fDim, this.fDim, this.fDim, this.u.invDw, this.u.invRx, this.comp_mat);
        if (this.u.isVB) {
            for (int j = this.fDim - 1; j >= 0; j--) {
                this.comp_mat[j][j] += this.u.alpha[j] / this.numData;
            }
        }
        MatrixUtils.inverse(this.fDim, this.comp_mat, this.u.invDw);
        MatrixUtils.symmetrize(this.u.invDw);
        MatrixUtils.mul(this.iDim, this.fDim, this.fDim, this.matT, this.u.invDw, this.u.W);
        if (this.u.isVB) {
            double numerator = this.iDim + ((2.0d * this.u.gtau0) / this.u.gamma);
            double fitResidual = this.trS - MatrixUtils.matrixInnerProduct(this.iDim, this.fDim, this.u.W, this.matT);
            double priorTerm = ((MatrixUtils.innerProduct(this.iDim, this.u.mu, this.u.mu) * this.u.gmu0) + ((2.0d * this.u.gtau0) / this.u.btau0)) / this.u.gamma;
            this.u.tau = numerator / (fitResidual + priorTerm);
            // Presumably the digamma part of the expected log precision; log(tau) is added below.
            double shape = ((this.iDim * this.u.gamma) / 2.0d) + this.u.gtau0;
            this.u.lntau = SpecialFunctions.digamma(shape) - Math.log(shape);
        } else {
            this.u.tau = this.iDim / (this.trS - MatrixUtils.matrixInnerProduct(this.iDim, this.fDim, this.u.W, this.matT));
            this.u.lntau = 0.0d;
        }
        this.u.tau = Math.min(Math.max(this.u.tau, this.u.min_tau), this.u.max_tau);
        this.u.lntau += Math.log(this.u.tau);
        calcAlpha();
    }

    /**
     * Rebuilds {@code u.Rx = I + tau * W^T W (+ (iDim/gamma) * invDw when isVB)}.
     *
     * <p>Also records {@code diagWTW[j] = ||W[:, j]||^2} and computes
     * {@code u.invRx} and {@code u.logdetRx} via MatrixUtils.logDetWithInverse. Only
     * the upper triangle of invDw is read, and the result is mirrored, so invDw is
     * treated as symmetric.
     */
    private void calcInvRx() {
        double priorScale = this.iDim / this.u.gamma;
        for (int j = this.fDim - 1; j >= 0; j--) {
            for (int k = this.fDim - 1; k >= j; k--) {
                double dot = 0.0d;
                for (int col = this.iDim - 1; col >= 0; col--) {
                    dot += this.u.W[col][j] * this.u.W[col][k];
                }
                if (j == k) {
                    this.u.diagWTW[j] = dot;
                }
                double entry = dot * this.u.tau;
                if (this.u.isVB) {
                    entry += priorScale * this.u.invDw[j][k];
                }
                this.u.Rx[j][k] = entry;
                this.u.Rx[k][j] = entry;
                this.comp_mat[j][k] = entry;
                this.comp_mat[k][j] = entry;
            }
            this.u.Rx[j][j] += 1.0d;
            this.comp_mat[j][j] += 1.0d;
        }
        this.u.logdetRx = MatrixUtils.logDetWithInverse(this.fDim, this.comp_mat, this.u.invRx);
    }

    /**
     * Updates the per-component precisions:
     * {@code alpha[k] = (2*galpha0 + iDim) / (tau*||W[:,k]||^2 + (iDim/gamma)*invDw[k][k] + galpha0/balpha0)}.
     */
    private void calcAlpha() {
        double priorScale = this.iDim / this.u.gamma;
        for (int k = 0; k < this.fDim; k++) {
            double normSq = 0.0d;
            for (int col = 0; col < this.iDim; col++) {
                normSq += this.u.W[col][k] * this.u.W[col][k];
            }
            this.u.alpha[k] = ((2.0d * this.u.galpha0) + this.iDim) / (((this.u.tau * normSq) + (priorScale * this.u.invDw[k][k])) + (this.u.galpha0 / this.u.balpha0));
        }
    }

    /**
     * Builds {@code Rxo = Rx - tau * W_miss^T W_miss} for one row and computes its
     * inverse into invRxo.
     *
     * @param missing per-column missing flags for the row
     * @return the log-determinant reported by MatrixUtils.logDetWithInverse
     */
    private double calcInvRxo(boolean[] missing) {
        for (int j = this.fDim - 1; j >= 0; j--) {
            for (int k = this.fDim - 1; k >= j; k--) {
                double missingDot = 0.0d;
                for (int col = this.iDim - 1; col >= 0; col--) {
                    if (missing[col]) {
                        missingDot += this.u.W[col][j] * this.u.W[col][k];
                    }
                }
                double entry = this.u.Rx[j][k] - (missingDot * this.u.tau);
                this.Rxo[j][k] = entry;
                this.Rxo[k][j] = entry;
                this.comp_mat[j][k] = entry;
                this.comp_mat[k][j] = entry;
            }
        }
        return MatrixUtils.logDetWithInverse(this.fDim, this.comp_mat, this.invRxo);
    }
}
