package be.ac.vub.bsb.cooccurrence.analysis;

import be.ac.vub.bsb.cooccurrence.cmd.OptionNames;
import be.ac.vub.bsb.cooccurrence.measures.Matrix;
import be.ac.vub.bsb.cooccurrence.measures.MatrixToolsProvider;
import be.ac.vub.bsb.cooccurrence.resampling.CrossValidator;
import be.ac.vub.bsb.cooccurrence.util.ArrayTools;
import be.ac.vub.bsb.cooccurrence.util.IRConnectionManager;
import be.ac.vub.bsb.cooccurrence.util.PlotTools;
import be.ac.vub.bsb.cooccurrence.util.RConnectionProvider;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import org.rosuda.REngine.REXPMismatchException;
import org.rosuda.REngine.REngineException;
import org.rosuda.REngine.Rserve.RConnection;

/* JADX WARN: Classes with same name are omitted:
  input_file:be/ac/vub/bsb/cooccurrence/analysis/MultivariateRegression.class
 */
/* loaded from: input_file:lib/be_ac_vub_bsb_cooccurrence.jar:be/ac/vub/bsb/cooccurrence/analysis/MultivariateRegression.class */
/**
 * Regresses a response matrix on a factor matrix in R through an Rserve {@link RConnection}.
 *
 * <p>Both matrices hold one sample per row (their row counts must match). When no family is set,
 * an ordinary {@code lm} fit is used and adjusted R-squared values are collected per response;
 * otherwise {@code glm} is fitted with the given family name. With a positive cross-validation
 * fold, the merged data are partitioned by a {@link CrossValidator}; each iteration fits on the
 * training part and computes prediction residuals on the held-out part. In that mode the stored
 * residuals, coefficients, AIC and residual sum of squares reflect the last iteration only.
 *
 * <p>If a connection was supplied through {@link #setRConnection(RConnection)} it is used and
 * left open. Otherwise a connection is obtained from {@link RConnectionProvider} for the run and
 * closed when the run ends.
 */
public class MultivariateRegression implements IRConnectionManager {
    /** glm family name for binomial errors. */
    public static final String BINOMIAL = "binomial";
    /** glm family name for Gaussian errors. */
    public static final String GAUSSIAN = "gaussian";
    /** glm family name for Poisson errors (formerly misspelled "possion", which R cannot resolve). */
    public static final String POISSON = "poisson";
    /** Cross-validation score identifier: residual sum of squares. */
    public static final String CV_RSS = "RSS";
    /** Cross-validation score identifier: Akaike information criterion. */
    public static final String CV_AKAIKE = "Akaike";
    /** Cross-validation score used unless {@link #setCvScore(String)} is called. */
    public static final String DEFAULT_CV_SCORE = CV_RSS;

    private DoubleMatrix1D _adjustedRSquares;
    // NOTE(review): allocated in initCV() but neither written nor read in this class; presumably
    // meant to hold one score per CV iteration (selected by getCvScore()) — confirm with subclasses.
    protected DoubleMatrix1D _goodnessOfFitCVIterations;
    private RConnection _rConnection;
    private Matrix _response = new Matrix();
    private Matrix _factors = new Matrix();
    private String _family = "";
    private int _crossValidationFold = 0;
    private Matrix _residuals = new Matrix();
    private Matrix _coefficients = new Matrix();
    private double _residualSumOfSquares = Double.NaN;
    private double _aic = Double.NaN;
    // NOTE(review): stored and exposed, but not consulted by this class.
    private String _cvScore = DEFAULT_CV_SCORE;
    protected CrossValidator _crossValidator = new CrossValidator();
    protected List<Matrix> _BCVIterations = new ArrayList<Matrix>();
    protected Matrix _trainMatrix = new Matrix();
    protected Matrix _testMatrix = new Matrix();
    private boolean _rConnectionSet = false;

    /**
     * Validates that factors and response have the same number of rows (samples) and allocates
     * a residual matrix with the response's dimensions. Intended for subclass use.
     *
     * @throws IllegalArgumentException if the row counts differ
     */
    public void init() {
        int sampleCount = getFactors().getMatrix().rows();
        if (sampleCount != getResponse().getMatrix().rows()) {
            throw new IllegalArgumentException("The factor matrix and the response matrix need to have the same number of rows!");
        }
        setResiduals(new Matrix(getResponse().getMatrix().rows(), getResponse().getMatrix().columns()));
        System.out.println("Sample number: " + sampleCount);
        System.out.println("Carrying out regression for " + getFactors().getMatrix().columns() + " factors and " + getResponse().getMatrix().columns() + " responses.");
    }

    /**
     * Prepares a fresh {@link CrossValidator} over the column-wise merge of factors and response
     * (columns shuffled, then transposed) when the cross-validation fold is positive; does nothing
     * otherwise. Intended for subclass use.
     */
    public void initCV() {
        if (getCrossValidationFold() > 0) {
            Matrix shuffled = MatrixToolsProvider.shuffleColumns(MatrixToolsProvider.mergeMatricesColumnWise(getFactors(), getResponse()));
            this._crossValidator = new CrossValidator();
            this._crossValidator.setMatrix(MatrixToolsProvider.getTransposedMatrix(shuffled));
            this._crossValidator.setFold(getCrossValidationFold());
            this._crossValidator.setSubdataSizeGivenFoldNumber();
            this._crossValidator.computeSubdataPartitioning();
            System.out.println("Number of cv iterations: " + this._crossValidator.getIterationNumber());
            this._goodnessOfFitCVIterations = new DenseDoubleMatrix1D(this._crossValidator.getIterationNumber());
        }
    }

    /**
     * Returns the factor and response matrices to fit on.
     *
     * <p>Without cross-validation (fold {@code <= 0}) these are the current factors and response.
     * Otherwise the cross-validator is advanced one step, its training data is transposed, and
     * split back into factors (all columns not named like a response) and response (all columns
     * not named like a factor). Also resets/refreshes {@code _trainMatrix} and
     * {@code _testMatrix}. Intended for subclass use.
     *
     * @return a two-element list: index 0 = factors, index 1 = response
     */
    public List<Matrix> getFactorsAndResponseFromTrainMatrix() {
        Matrix factors;
        Matrix response;
        this._trainMatrix = new Matrix();
        this._testMatrix = new Matrix();
        // "<= 0" matches every other fold check in this class; a negative fold means no CV.
        if (getCrossValidationFold() <= 0) {
            factors = getFactors();
            response = getResponse();
        } else {
            this._crossValidator.crossValidate();
            this._trainMatrix = this._crossValidator.getResampledMatrix();
            if (this._trainMatrix.getMatrix().columns() < 2) {
                System.err.println("Cross-validation training data sub-matrix has less than 2 columns!");
            }
            this._testMatrix = this._crossValidator.getValidationData();
            if (this._testMatrix.getMatrix().columns() < 2) {
                System.err.println("Cross-validation test data sub-matrix has less than 2 columns!");
            }
            this._trainMatrix = MatrixToolsProvider.getTransposedMatrix(this._trainMatrix);
            System.out.println("Number of rows in train matrix: " + this._trainMatrix.getMatrix().rows());
            System.out.println("Number of columns in train matrix: " + this._trainMatrix.getMatrix().columns());
            HashSet<String> responseNames = columnNameSet(getResponse());
            HashSet<String> factorNames = columnNameSet(getFactors());
            System.out.println("Assembling factor matrix...");
            // Dropping the response columns leaves the factors, and vice versa.
            factors = MatrixToolsProvider.getSubMatrixWithoutColNames(this._trainMatrix, responseNames);
            System.out.println("Number of rows in factors extracted from train matrix: " + factors.getMatrix().rows());
            System.out.println("Number of columns in factors extracted from train matrix: " + factors.getMatrix().columns());
            response = MatrixToolsProvider.getSubMatrixWithoutColNames(this._trainMatrix, factorNames);
            System.out.println("Number of rows in responses extracted from train matrix: " + response.getMatrix().rows());
            System.out.println("Number of columns in responses extracted from train matrix: " + response.getMatrix().columns());
            System.out.println("factor names: " + ArrayTools.arrayToString(factors.getColNames(), ", "));
            System.out.println("response names: " + ArrayTools.arrayToString(response.getColNames(), ", "));
        }
        List<Matrix> factorsAndResponse = new ArrayList<Matrix>(2);
        factorsAndResponse.add(factors);
        factorsAndResponse.add(response);
        return factorsAndResponse;
    }

    /**
     * Transposes {@code _testMatrix} in place and splits it into factors and response by column
     * name, mirroring {@link #getFactorsAndResponseFromTrainMatrix()}. Intended for subclass use.
     *
     * @return a two-element list: index 0 = factors, index 1 = response
     */
    public List<Matrix> getFactorsAndResponseFromTestMatrix() {
        this._testMatrix = MatrixToolsProvider.getTransposedMatrix(this._testMatrix);
        HashSet<String> responseNames = columnNameSet(getResponse());
        HashSet<String> factorNames = columnNameSet(getFactors());
        Matrix factors = MatrixToolsProvider.getSubMatrixWithoutColNames(this._testMatrix, responseNames);
        System.out.println("Number of rows in factors extracted from test matrix: " + factors.getMatrix().rows());
        System.out.println("Number of columns in factors extracted from test matrix: " + factors.getMatrix().columns());
        Matrix response = MatrixToolsProvider.getSubMatrixWithoutColNames(this._testMatrix, factorNames);
        System.out.println("Number of rows in responses extracted from test matrix: " + response.getMatrix().rows());
        System.out.println("Number of columns in responses extracted from test matrix: " + response.getMatrix().columns());
        List<Matrix> factorsAndResponse = new ArrayList<Matrix>(2);
        factorsAndResponse.add(factors);
        factorsAndResponse.add(response);
        return factorsAndResponse;
    }

    /** Collects the column names of {@code matrix} into a set. */
    private static HashSet<String> columnNameSet(Matrix matrix) {
        HashSet<String> names = new HashSet<String>();
        for (String name : matrix.getColNames()) {
            names.add(name);
        }
        return names;
    }

    /**
     * Runs the regression (and cross-validation when the fold is positive).
     *
     * <p>R errors ({@link REngineException}) and result-conversion errors
     * ({@link REXPMismatchException}) are printed and end the run early; they are not rethrown.
     * The caller's factor and response matrices are restored afterwards, and an internally
     * acquired R connection is closed; a connection set via {@link #setRConnection(RConnection)}
     * is left open.
     *
     * @throws IllegalArgumentException if factors and response differ in row count
     */
    public void doMultivariateRegression() {
        init();
        setCoefficients(new Matrix(getFactors().getMatrix().columns() + 1, getResponse().getMatrix().columns()));
        setAdjustedRSquares(new DenseDoubleMatrix1D(getResponse().getMatrix().columns()));
        // CV iterations swap the factor/response fields for fold sub-matrices; keep the caller's
        // data so it can be restored once the run is over.
        Matrix originalFactors = getFactors();
        Matrix originalResponse = getResponse();
        boolean ownsConnection = false;
        try {
            int iterations = 1;
            if (getCrossValidationFold() > 0) {
                initCV();
                iterations = this._crossValidator.getIterationNumber();
            }
            if (!isRConnectionSet()) {
                // Acquire once for the whole run (not once per CV iteration) so nothing leaks.
                setInternalRConnection(RConnectionProvider.getInstance());
                ownsConnection = true;
            }
            for (int iteration = 0; iteration < iterations; iteration++) {
                if (getCrossValidationFold() > 0) {
                    System.out.println("Iteration number: " + iteration);
                }
                List<Matrix> train = getFactorsAndResponseFromTrainMatrix();
                setFactors(train.get(0));
                setResponse(train.get(1));
                fitModel();
                if (getCrossValidationFold() > 0) {
                    evaluateOnTestFold();
                }
            }
        } catch (REXPMismatchException e) {
            e.printStackTrace();
        } catch (REngineException e) {
            e.printStackTrace();
        } finally {
            setFactors(originalFactors);
            setResponse(originalResponse);
            // Only close what this run opened; a caller-supplied connection is the caller's.
            if (ownsConnection && getRConnection() != null) {
                getRConnection().close();
            }
        }
    }

    /** Returns true when no glm family is configured, i.e. a plain lm fit is used. */
    private boolean isLinearModel() {
        return getFamily() == null || getFamily().isEmpty();
    }

    /**
     * Transfers the current factors (as X) and response (as Y) to R, fits the model into the R
     * variable {@code out}, and stores residuals, coefficients, AIC and (lm only) adjusted R².
     */
    private void fitModel() throws REngineException, REXPMismatchException {
        PlotTools.transferMatrixToR(getFactors(), "X", this._rConnection);
        PlotTools.transferMatrixToR(getResponse(), "Y", this._rConnection);
        RConnection r = getRConnection();
        // X and Y are resolved from the R session environment; no data frame is needed.
        if (isLinearModel()) {
            r.voidEval("out = lm(formula=Y~X)");
        } else {
            // glm reads the R variable named "family", so assign exactly that name.
            r.assign("family", getFamily());
            r.voidEval("out = glm(formula=Y~X, family=family)");
        }
        getResiduals().setMatrix(r.eval("residuals(out)").asDoubleMatrix());
        getCoefficients().setMatrix(r.eval("coef(out)").asDoubleMatrix());
        // R's logLik.lm rejects multi-response ("mlm") fits; report NaN instead of aborting.
        setAic(r.eval("tryCatch(AIC(out), error=function(e) NA_real_)").asDouble());
        if (isLinearModel()) {
            // Summarise once; each element is a per-response summary.lm. "[9]" returned a list,
            // which REngine cannot convert with asDouble(), so read the named component instead.
            r.voidEval("out.summary=summary(out)");
            int responseCount = getResponse().getMatrix().columns();
            for (int i = 0; i < responseCount; i++) {
                getAdjustedRSquares().set(i, r.eval("out.summary[[" + (i + 1) + "]]$adj.r.squared").asDouble());
            }
        }
    }

    /**
     * Switches factors/response to the held-out fold, predicts it with the coefficients of the
     * last fit, and stores the prediction residuals and their sum of squares.
     */
    private void evaluateOnTestFold() throws REngineException, REXPMismatchException {
        List<Matrix> test = getFactorsAndResponseFromTestMatrix();
        setFactors(test.get(0));
        setResponse(test.get(1));
        PlotTools.transferMatrixToR(getFactors(), "X", this._rConnection);
        PlotTools.transferMatrixToR(getResponse(), "Y", this._rConnection);
        RConnection r = getRConnection();
        r.voidEval("B=coef(out)");
        // cbind(1, X) adds each response's intercept to every row; the former
        // B[1,] + X %*% B[-1,] recycled the intercept vector down the columns instead.
        r.voidEval("Y.pred=cbind(1,X)%*%B");
        getResiduals().setMatrix(r.eval("Y-Y.pred").asDoubleMatrix());
        setResidualSumOfSquares(r.eval("sum((Y-Y.pred)^2)").asDouble());
    }

    public Matrix getResponse() {
        return this._response;
    }

    public void setResponse(Matrix matrix) {
        this._response = matrix;
    }

    /** Returns the glm family name; empty means a plain lm fit. */
    public String getFamily() {
        return this._family;
    }

    /** Sets the glm family name (e.g. {@link #POISSON}); empty or null selects lm. */
    public void setFamily(String str) {
        this._family = str;
    }

    public Matrix getFactors() {
        return this._factors;
    }

    public void setFactors(Matrix matrix) {
        this._factors = matrix;
    }

    /** Returns the cross-validation fold; values {@code <= 0} disable cross-validation. */
    public int getCrossValidationFold() {
        return this._crossValidationFold;
    }

    public void setCrossValidationFold(int i) {
        this._crossValidationFold = i;
    }

    public String getCvScore() {
        return this._cvScore;
    }

    public void setCvScore(String str) {
        this._cvScore = str;
    }

    public Matrix getResiduals() {
        return this._residuals;
    }

    /** Intended for subclass use. */
    public void setResiduals(Matrix matrix) {
        this._residuals = matrix;
    }

    public Matrix getCoefficients() {
        return this._coefficients;
    }

    /** Intended for subclass use. */
    public void setCoefficients(Matrix matrix) {
        this._coefficients = matrix;
    }

    /** Returns the AIC of the last fit, or NaN if none was computed. */
    public double getAic() {
        return this._aic;
    }

    protected void setAic(double d) {
        this._aic = d;
    }

    /** Returns the held-out residual sum of squares of the last CV iteration (NaN without CV). */
    public double getResidualSumOfSquares() {
        return this._residualSumOfSquares;
    }

    public void setResidualSumOfSquares(double d) {
        this._residualSumOfSquares = d;
    }

    /** Returns the adjusted R² per response column (filled for lm fits only). */
    public DoubleMatrix1D getAdjustedRSquares() {
        return this._adjustedRSquares;
    }

    /** Intended for subclass use. */
    public void setAdjustedRSquares(DoubleMatrix1D doubleMatrix1D) {
        this._adjustedRSquares = doubleMatrix1D;
    }

    /** Stores a connection without marking it as caller-supplied. Intended for subclass use. */
    public void setInternalRConnection(RConnection rConnection) {
        this._rConnection = rConnection;
    }

    /** Stores a caller-supplied connection; it will be used but never closed by this class. */
    @Override // be.ac.vub.bsb.cooccurrence.util.IRConnectionManager
    public void setRConnection(RConnection rConnection) {
        this._rConnection = rConnection;
        this._rConnectionSet = true;
    }

    @Override // be.ac.vub.bsb.cooccurrence.util.IRConnectionManager
    public RConnection getRConnection() {
        return this._rConnection;
    }

    @Override // be.ac.vub.bsb.cooccurrence.util.IRConnectionManager
    public boolean isRConnectionSet() {
        return this._rConnectionSet;
    }

    /**
     * Ad-hoc driver: regresses a (transposed) count matrix on a metadata matrix.
     *
     * @param args optional paths: {@code args[0]} factor matrix, {@code args[1]} response matrix
     *             (transposed before use); the author's original local files are used when omitted
     */
    public static void main(String[] args) {
        String factorPath = args.length > 0 ? args[0] : "/Users/karoline/Documents/Documents_Karoline/BSB_Lab/Results/TARA-bio-abio/Input/final_TARA_metadata_matrix.txt";
        String responsePath = args.length > 1 ? args[1] : "/Users/karoline/Documents/Documents_Karoline/BSB_Lab/Results/TARA-bio-abio/Input/final_TARA_prokaryotes_genera_count_matrix.txt";
        Matrix factors = new Matrix();
        factors.readMatrix(factorPath, false);
        Matrix counts = new Matrix();
        counts.readMatrix(responsePath, false);
        Matrix response = MatrixToolsProvider.getTransposedMatrix(counts);
        MultivariateRegression regression = new MultivariateRegression();
        regression.setFactors(factors);
        regression.setResponse(response);
        regression.doMultivariateRegression();
    }
}
