package be.ac.vub.bsb.cooccurrence.core;

import be.ac.ulb.bigre.pathwayinference.core.core.PathwayinferenceConstants;
import be.ac.ulb.bigre.pathwayinference.core.util.Groups;
import be.ac.vub.bsb.cooccurrence.cmd.OptionNames;
import be.ac.vub.bsb.cooccurrence.measures.Matrix;
import be.ac.vub.bsb.cooccurrence.measures.MatrixToolsProvider;
import be.ac.vub.bsb.cooccurrence.measures.NaNTreatmentProvider;
import be.ac.vub.bsb.cooccurrence.resampling.CrossValidator;
import be.ac.vub.bsb.cooccurrence.resampling.IBootstrapper;
import be.ac.vub.bsb.cooccurrence.util.ArrayTools;
import be.ac.vub.bsb.cooccurrence.util.MatrixMetadataGroupManager;
import be.ac.vub.bsb.cooccurrence.util.RConnectionProvider;
import cern.colt.matrix.DoubleMatrix1D;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.cli.HelpFormatter;
import org.rosuda.REngine.REXPMismatchException;
import org.rosuda.REngine.REngineException;
import org.rosuda.REngine.RList;
import org.rosuda.REngine.Rserve.RserveException;

/* JADX WARN: Classes with same name are omitted:
  input_file:be/ac/vub/bsb/cooccurrence/core/CooccurrenceFromModelNetworkBuilder.class
 */
/* loaded from: input_file:lib/be_ac_vub_bsb_cooccurrence.jar:be/ac/vub/bsb/cooccurrence/core/CooccurrenceFromModelNetworkBuilder.class */
public class CooccurrenceFromModelNetworkBuilder extends CooccurrenceNetworkBuilder {
    // Names of the edge data annotations written by addEdgeAttributes.
    public static String COEFFICIENT_ATTRIBUTE = "coefficient";
    public static String MODELQUALITY_ATTRIBUTE = "model_quality";
    public static String GROUP_ATTRIBUTE = "group";
    // Predefined R model formulas; "y" is the response and "x" the predictor(s) assigned to R in doLinearFit.
    public static String SIMPLE_LINEAR = "y ~ x";
    public static String SIMPLE_LINEAR_NO_INTERCEPT = "y ~ x - 1";
    public static String COMPOSED_LINEAR = "y ~ I(mean(y)+x)";
    public static String[] SUPPORTED_PREDEFINED_MODELS = {SIMPLE_LINEAR, SIMPLE_LINEAR_NO_INTERCEPT, COMPOSED_LINEAR};
    // The only formula type accepted by doMultiRegression / crossValidate; any other raises IllegalArgumentException.
    public static String LINEAR_FORMULAR_TYPE = "linear";
    // Score types. AKAIKE is reported as 1/AIC by doGLMBoost and doGML; R2 and R2_adj are computed by computeRSquare.
    public static String AKAIKE = "AIC";
    public static String R_SQUARE = "R2";
    public static String R_SQUARE_ADJUSTED = "R2_adj";
    public static String[] SUPPORTED_SCORES = {AKAIKE, R_SQUARE, R_SQUARE_ADJUSTED};
    // Error-distribution families; fillFamilyMaps translates them into the spelling expected by each R function.
    public static String GAUSSIAN_FAMILY = "gaussian";
    public static String BINOMIAL_FAMILY = "binomial";
    public static String POISSON_FAMILY = IBootstrapper.POISSON_PROBAB;
    public static String GAMMA_FAMILY = "Gamma";
    public static String LAPLACE_FAMILY = "Laplace()";
    public static String[] SUPPORTED_FAMILIES = {GAUSSIAN_FAMILY, BINOMIAL_FAMILY, POISSON_FAMILY, GAMMA_FAMILY, LAPLACE_FAMILY};
    // R modelling functions that doLinearFit can dispatch to.
    public static String GLMBOOST_R = "glmboost";
    public static String GBM_R = "gbm";
    public static String GLM_R = "glm";
    public static String LM_R = "lm";
    public static String[] R_PACKAGES = {LM_R, GLM_R, GBM_R, GLMBOOST_R};
    // Defaults used to initialise the instance fields below.
    public static String DEFAULT_FORMULA = SIMPLE_LINEAR;
    public static String DEFAULT_FORMULA_TYPE = LINEAR_FORMULAR_TYPE;
    public static String DEFAULT_FAMILY = GAUSSIAN_FAMILY;
    public static String DEFAULT_SCORE_TYPE = AKAIKE;
    public static String DEFAULT_R_PACKAGE = GLMBOOST_R;
    public static int DEFAULT_BOOST_ITERATIONS = 100;
    // Default for _spearmanFilterThreshold, which doMultiRegression compares against
    // MatrixToolsProvider.getSpearmanUsingJSC; presumably a p-value cut-off (TODO confirm).
    public static double DEFAULT_SPEARMAN_THRESHOLD = 0.05d;
    // Model formula; doLinearFit passes it to R as the variable "form".
    private String _formula = DEFAULT_FORMULA;
    private String _formularType = DEFAULT_FORMULA_TYPE;
    private String _family = DEFAULT_FAMILY;
    // Not referenced in the visible part of this class.
    private String _groupAttribute = "";
    private String _rFunction = DEFAULT_R_PACKAGE;
    // Passed to glmboost as boost_control(mstop=...).
    private int _boostIterations = DEFAULT_BOOST_ITERATIONS;
    private boolean _spearmanFilter = false;
    private double _spearmanFilterThreshold = DEFAULT_SPEARMAN_THRESHOLD;
    // Number of cross-validation folds; doMultiRegression only cross-validates when this is > 0.
    private int _crossvalidateFold = 0;
    private boolean _discardZeroRSquare = false;
    private boolean _displayGroupEdgesSeparately = false;
    // When true, doLinearFit centres the response in R (y=y-mean(y)).
    private boolean _subtractMeanFromResponse = false;
    // Not referenced in the visible part of this class.
    private boolean _allAgainstAll = false;
    // Coefficients of the most recent fit, keyed by factor name (plus INTERCEPT); filled by parseRListIntoMap.
    private Map<String, Double> _rowNameVsCoeffiCurrentMap = new HashMap();
    // Family-name translations per R function, filled by fillFamilyMaps.
    private Map<String, String> _familyVsGLMName = new HashMap();
    private Map<String, String> _familyVsGLMBoostName = new HashMap();
    private Map<String, String> _familyVsGBMName = new HashMap();
    // Edges ("source->target") already rejected by the Spearman filter in doMultiRegression.
    private Set<String> _spearmanFilteredEdges = new HashSet();
    // Substring identifying the intercept among R coefficient names.
    private static final String INTERCEPT = "Intercept";
    // Not referenced in the visible part of this class.
    private static final boolean R_SQUARE_AS_SUMOFSQUARES = true;

    /**
     * Creates a model-based co-occurrence builder backed by a fresh, empty {@link Matrix}.
     */
    public CooccurrenceFromModelNetworkBuilder() {
        super.setCooccurrenceMethod(CooccurrenceNetworkBuilder.MODELLING);
        super.setMatrix(new Matrix());
        finishModelBuilderSetup();
    }

    /**
     * Creates a model-based co-occurrence builder for the given matrix.
     *
     * @param matrix the matrix handed to the parent builder
     */
    public CooccurrenceFromModelNetworkBuilder(Matrix matrix) {
        super.setCooccurrenceMethod(CooccurrenceNetworkBuilder.MODELLING);
        super.setMatrix(matrix);
        finishModelBuilderSetup();
    }

    /**
     * Creates a model-based co-occurrence builder whose matrix is loaded via {@link Matrix#readMatrix}.
     *
     * @param str first argument of {@code Matrix.readMatrix} (presumably a file location — confirm)
     * @param z second argument of {@code Matrix.readMatrix}
     */
    public CooccurrenceFromModelNetworkBuilder(String str, boolean z) {
        super.setCooccurrenceMethod(CooccurrenceNetworkBuilder.MODELLING);
        final Matrix loaded = new Matrix();
        loaded.readMatrix(str, z);
        super.setMatrix(loaded);
        finishModelBuilderSetup();
    }

    /** Common tail of every constructor: initialise the network, the family maps and the default score type. */
    private void finishModelBuilderSetup() {
        super.initCooccurrenceNetwork();
        fillFamilyMaps();
        setReturnType(DEFAULT_SCORE_TYPE);
    }

    /**
     * Fills the per-R-function translation tables from the generic family names to the spelling each
     * R function expects ({@code glmboost}: mboost family objects, {@code glm}: family names,
     * {@code gbm}: distribution names).
     */
    private void fillFamilyMaps() {
        this._familyVsGLMBoostName.put(GAUSSIAN_FAMILY, "Gaussian()");
        this._familyVsGLMBoostName.put(BINOMIAL_FAMILY, "Binomial()");
        this._familyVsGLMBoostName.put(POISSON_FAMILY, "Poisson()");
        this._familyVsGLMBoostName.put(LAPLACE_FAMILY, LAPLACE_FAMILY);
        // glm takes the family names verbatim (the duplicate GAUSSIAN entry has been removed).
        this._familyVsGLMName.put(GAUSSIAN_FAMILY, GAUSSIAN_FAMILY);
        this._familyVsGLMName.put(BINOMIAL_FAMILY, BINOMIAL_FAMILY);
        this._familyVsGLMName.put(POISSON_FAMILY, POISSON_FAMILY);
        this._familyVsGLMName.put(GAMMA_FAMILY, GAMMA_FAMILY);
        this._familyVsGBMName.put(BINOMIAL_FAMILY, "bernoulli");
        this._familyVsGBMName.put(GAUSSIAN_FAMILY, "gaussian");
        this._familyVsGBMName.put(POISSON_FAMILY, IBootstrapper.POISSON_PROBAB);
        this._familyVsGBMName.put(LAPLACE_FAMILY, "laplace");
    }

    /**
     * Replaces {@link #_rowNameVsCoeffiCurrentMap} with the values of a named R list (coefficient name to value).
     *
     * <p>An entry whose name contains {@link #INTERCEPT} is stored under {@code INTERCEPT} and is not
     * matched against factor names. Any other entry is first matched exactly against the factors in
     * {@code list}. A match is either the factor name itself or the factor name prefixed with {@code "x"}
     * or {@code "x."}, which are the prefixes R produces for matrix / data-frame predictor columns. This
     * prevents e.g. {@code xBact10} from overwriting the coefficient of {@code Bact1}. Only when no exact
     * match exists does the legacy substring match apply, which is still needed for composed terms such as
     * {@code I(mean(y) + x)}. Unnamed list elements are skipped.
     *
     * @param rList named R list of coefficients / relevances
     * @param list factor names to collect
     * @throws REXPMismatchException if a matched element cannot be converted to a double
     */
    private void parseRListIntoMap(RList rList, List<String> list) throws REXPMismatchException {
        this._rowNameVsCoeffiCurrentMap = new HashMap<String, Double>();
        for (int i = 0; i < rList.size(); i++) {
            String keyAt = rList.keyAt(i);
            if (keyAt == null) {
                // RList.keyAt returns null for unnamed elements; they cannot be attributed to a factor.
                continue;
            }
            if (keyAt.contains(INTERCEPT)) {
                // Recorded regardless of whether any factors were requested.
                this._rowNameVsCoeffiCurrentMap.put(INTERCEPT, Double.valueOf(rList.at(i).asDouble()));
                continue;
            }
            String exactFactor = findExactFactorMatch(keyAt, list);
            if (exactFactor != null) {
                this._rowNameVsCoeffiCurrentMap.put(exactFactor, Double.valueOf(rList.at(i).asDouble()));
            } else {
                for (String str : list) {
                    if (keyAt.contains(str)) {
                        this._rowNameVsCoeffiCurrentMap.put(str, Double.valueOf(rList.at(i).asDouble()));
                    }
                }
            }
        }
    }

    /**
     * Returns the factor whose name equals {@code key} (preferred), or whose name prefixed with
     * {@code "x"} or {@code "x."} equals {@code key}; {@code null} if there is none.
     */
    private static String findExactFactorMatch(String key, List<String> factors) {
        for (String factor : factors) {
            if (key.equals(factor)) {
                return factor;
            }
        }
        for (String factor : factors) {
            if (key.equals("x" + factor) || key.equals("x." + factor)) {
                return factor;
            }
        }
        return null;
    }

    /**
     * Computes the coefficient of determination in R from the session variables {@code y} and
     * {@code y.predicted}.
     *
     * <p>For score type {@link #R_SQUARE} the plain R^2 is returned. For any other score type the
     * adjusted R^2 is returned, using {@code p = length(coef(out)) - 1} and {@code n = dim(data)[[1]]}.
     * A negative adjusted value is clamped to zero with a warning.
     *
     * @return R^2 or adjusted R^2
     */
    private double computeRSquare() throws RserveException, REXPMismatchException {
        final String[] sumOfSquaresSteps = {
            "SStot=sum((y-mean(y))^2)",
            "SSerr=sum((y-y.predicted)^2)",
            "R2=1-(SSerr/SStot)"
        };
        for (String step : sumOfSquaresSteps) {
            super.getRConnection().voidEval(step);
        }
        final double rSquare = super.getRConnection().eval("R2").asDouble();
        this._logger.debug("R^2=" + rSquare);
        if (getScoreType().equals(R_SQUARE)) {
            return rSquare;
        }
        super.getRConnection().voidEval("p=length(coef(out))-1");
        super.getRConnection().voidEval("n=dim(data)[[1]]");
        final double adjusted = super.getRConnection().eval("1-(1-R2)*((n-1)/(n-p-1))").asDouble();
        if (adjusted < 0.0d) {
            this._logger.warn("Adjusted R^2 was smaller than zero and has been set to zero!");
            return 0.0d;
        }
        return adjusted;
    }

    /**
     * Fits a boosted regression tree with dismo's {@code gbm.step} on the R data frame {@code data}.
     *
     * <p>The data frame's columns are renamed to {@code 1..ncol} before fitting, and the last column is
     * used as the response.
     *
     * @param list factor names passed to {@link #parseRListIntoMap}
     * @return R^2 / adjusted R^2 of the fitted values, or NaN for unsupported scores (including AIC)
     */
    private double doGBMViaDismo(List<String> list) throws REngineException, REXPMismatchException {
        if (!RConnectionProvider.DISMO_LOADED) {
            RConnectionProvider.LOAD_DISMO = true;
            RConnectionProvider.loadDismo();
        }
        super.getRConnection().assign(OptionNames.errorDistribution, getFamily());
        final String[] fitSteps = {
            "tempmat=as.matrix(data)",
            "rownames=colnames(data)",
            "colnames(tempmat)=1:ncol(tempmat)",
            "data=as.data.frame(tempmat)",
            "out<-gbm.step(data = data, gbm.x=1:(ncol(tempmat)), gbm.y=ncol(tempmat), family = family)"
        };
        for (String step : fitSteps) {
            super.getRConnection().voidEval(step);
        }
        parseRListIntoMap(super.getRConnection().eval("as.list(summary(out,plotit=F))").asList(), list);
        final String scoreType = getScoreType();
        if (scoreType.equals(AKAIKE)) {
            this._logger.error("Akaike score not supported for gbm via dismo!");
            return Double.NaN;
        }
        if (!scoreType.equals(R_SQUARE) && !scoreType.equals(R_SQUARE_ADJUSTED)) {
            this._logger.error("Score " + scoreType + " not supported. Supported scores are " + ArrayTools.stringArrayToString(SUPPORTED_SCORES, ", "));
            return Double.NaN;
        }
        super.getRConnection().voidEval("y.predicted = out$fitted");
        return computeRSquare();
    }

    /**
     * Fits a gradient boosting model ({@code gbm}) in R and scores it.
     *
     * <p>When the R variable {@code binary} is 0, the formula is {@code y~x.<f1>+x.<f2>+...}, built from
     * {@code list}. Otherwise it is {@code y~x}. Relative influences from {@code summary(out)} are stored
     * as the coefficients.
     *
     * @param list factor names whose relative influences are collected
     * @return R^2 / adjusted R^2 of the predictions, or NaN for unsupported scores (including AIC)
     */
    private double doGBM(List<String> list) throws REngineException, REXPMismatchException {
        if (!RConnectionProvider.GBM_LOADED) {
            RConnectionProvider.LOAD_GBM = true;
            RConnectionProvider.loadGBM();
        }
        if (!getFormula().equals(SIMPLE_LINEAR)) {
            this._logger.error("GBM interface currently only supports formulas of the type " + SIMPLE_LINEAR + "!");
        }
        final String gbmFormula;
        if (super.getRConnection().eval("binary").asInteger() != 0) {
            gbmFormula = "y~x";
        } else {
            final StringBuilder terms = new StringBuilder("y~");
            for (String factor : list) {
                terms.append("x.").append(factor).append('+');
            }
            // Drop the trailing '+'; with an empty factor list this removes the '~' instead.
            terms.setLength(terms.length() - 1);
            gbmFormula = terms.toString();
        }
        super.getRConnection().assign("form", gbmFormula);
        super.getRConnection().voidEval("formula=formula(form)");
        super.getRConnection().assign(OptionNames.errorDistribution, getFamily());
        super.getRConnection().voidEval("out<-gbm(formula, data = data, distribution = family, cv.folds=" + getCrossvalidateFold() + ")");
        final String[] relevanceSteps = {
            "relevances=summary(out,plotit=F)",
            "m=as.matrix(relevances)",
            "values=as.numeric(m[,2])",
            "names=m[,1]",
            "names(values)=names"
        };
        for (String step : relevanceSteps) {
            super.getRConnection().voidEval(step);
        }
        parseRListIntoMap(super.getRConnection().eval("as.list(values)").asList(), list);
        final String scoreType = getScoreType();
        if (scoreType.equals(AKAIKE)) {
            this._logger.error("Akaike score not supported for gbm!");
            return Double.NaN;
        }
        if (scoreType.equals(R_SQUARE) || scoreType.equals(R_SQUARE_ADJUSTED)) {
            super.getRConnection().voidEval("y.predicted = predict.gbm(out, data, n.trees=out$n.trees, type=\"response\")");
            return computeRSquare();
        }
        this._logger.error("Score " + scoreType + " not supported. Supported scores are " + ArrayTools.stringArrayToString(SUPPORTED_SCORES, ", "));
        return Double.NaN;
    }

    /**
     * Fits a boosted GLM with mboost's {@code glmboost} using the R formula variable {@code form} and
     * the data frame {@code data}, then scores it.
     *
     * @param list factor names whose coefficients are collected
     * @return 1/AIC for {@link #AKAIKE}, R^2 / adjusted R^2 for the R-square scores, NaN otherwise
     */
    private double doGLMBoost(List<String> list) throws REngineException, REXPMismatchException {
        double d = Double.NaN;
        if (!RConnectionProvider.MBOOST_LOADED) {
            RConnectionProvider.LOAD_MBOOST = true;
            RConnectionProvider.loadMboost();
        }
        super.getRConnection().voidEval("formula=formula(form)");
        super.getRConnection().voidEval("out<-glmboost(formula, data = data,control = boost_control(mstop=" + getBoostIterations() + "), family = " + getFamily() + ")");
        parseRListIntoMap(super.getRConnection().eval("as.list(coef(out))").asList(), list);
        if (getScoreType().equals(AKAIKE)) {
            // The method must be passed as a string. Unquoted, R looks up a variable named "classical"
            // and fails with "object 'classical' not found".
            d = 1.0d / super.getRConnection().eval("AIC(out,method=\"classical\")").asDouble();
        } else if (getScoreType().equals(R_SQUARE) || getScoreType().equals(R_SQUARE_ADJUSTED)) {
            super.getRConnection().voidEval("y.predicted = predict(out,type=\"response\")");
            d = computeRSquare();
        } else {
            this._logger.error("Score " + getScoreType() + " not supported. Supported scores are " + ArrayTools.stringArrayToString(SUPPORTED_SCORES, ", "));
        }
        return d;
    }

    /**
     * Fits a generalised linear model ({@code glm}) with the R formula variable {@code form} on
     * {@code data}, using the configured family, then scores it.
     *
     * @param list factor names whose coefficients are collected
     * @return 1/{@code out$aic} for {@link #AKAIKE}, R^2 / adjusted R^2 for the R-square scores, NaN otherwise
     */
    private double doGML(List<String> list) throws REngineException, REXPMismatchException {
        final String fitCommand = "out<-glm(form,data=data,family=" + getFamily() + ")";
        super.getRConnection().voidEval(fitCommand);
        parseRListIntoMap(super.getRConnection().eval("as.list(coef(out))").asList(), list);
        final String scoreType = getScoreType();
        if (scoreType.equals(AKAIKE)) {
            return 1.0d / super.getRConnection().eval("out$aic").asDouble();
        }
        if (scoreType.equals(R_SQUARE) || scoreType.equals(R_SQUARE_ADJUSTED)) {
            super.getRConnection().voidEval("y.predicted = predict(out,type=\"response\")");
            return computeRSquare();
        }
        this._logger.error("Score " + scoreType + " not supported. Supported scores are " + ArrayTools.stringArrayToString(SUPPORTED_SCORES, ", "));
        return Double.NaN;
    }

    /**
     * Fits an ordinary linear model ({@code lm}) with the R formula variable {@code form} on
     * {@code data}, then scores it.
     *
     * @param list factor names whose coefficients are collected
     * @return R^2 / adjusted R^2 for the R-square scores; NaN for AIC (unsupported for lm) and unknown scores
     */
    private double doLM(List<String> list) throws REngineException, REXPMismatchException {
        super.getRConnection().voidEval("out<-lm(form,data=data)");
        parseRListIntoMap(super.getRConnection().eval("as.list(coef(out))").asList(), list);
        final String scoreType = getScoreType();
        final boolean rSquareScore = scoreType.equals(R_SQUARE) || scoreType.equals(R_SQUARE_ADJUSTED);
        if (!scoreType.equals(AKAIKE) && rSquareScore) {
            super.getRConnection().voidEval("y.predicted = predict(out,type=\"response\")");
            return computeRSquare();
        }
        if (scoreType.equals(AKAIKE)) {
            this._logger.error("The Akaike criterion is not supported for R function " + LM_R + "!");
        } else {
            this._logger.error("Score " + scoreType + " not supported. Supported scores are " + ArrayTools.stringArrayToString(SUPPORTED_SCORES, ", "));
        }
        return Double.NaN;
    }

    /**
     * Fits the response {@code doubleMatrix1D2} (R variable {@code y}) on the single predictor
     * {@code doubleMatrix1D} (R variable {@code x}) with the configured R function, and returns the score.
     *
     * <p>Afterwards {@link #_rowNameVsCoeffiCurrentMap} holds only the predictor's coefficient, keyed by
     * {@code str}. If the fit produced no coefficient for {@code x} (e.g. only an intercept), the map is
     * left empty instead of mapping {@code str} to {@code null}.
     *
     * <p>R errors are printed and the connection is closed. The connection is also closed in all cases
     * when {@code isRConnectionSet()} is false.
     *
     * @param doubleMatrix1D predictor values
     * @param doubleMatrix1D2 response values
     * @param str name under which the predictor's coefficient is stored
     * @return the model score, or NaN on error / unsupported R function
     */
    private double doLinearFit(DoubleMatrix1D doubleMatrix1D, DoubleMatrix1D doubleMatrix1D2, String str) {
        double d = Double.NaN;
        List<String> factors = new ArrayList<String>();
        factors.add("x");
        try {
            if (!super.isRConnectionSet()) {
                super.setInternalRConnection(RConnectionProvider.getInstance());
            }
            super.getRConnection().assign("form", getFormula());
            super.getRConnection().assign("x", doubleMatrix1D.toArray());
            super.getRConnection().assign("y", doubleMatrix1D2.toArray());
            if (isSubtractMeanFromResponse().booleanValue()) {
                super.getRConnection().voidEval("y=y-mean(y)");
            }
            super.getRConnection().assign("xname", str);
            super.getRConnection().voidEval("binary=1");
            super.getRConnection().voidEval("data=data.frame(x=x,y=y)");
            if (getRFunction().equals(GLMBOOST_R)) {
                d = doGLMBoost(factors);
            } else if (getRFunction().equals(GLM_R)) {
                d = doGML(factors);
            } else if (getRFunction().equals(LM_R)) {
                d = doLM(factors);
            } else if (getRFunction().equals(GBM_R)) {
                d = doGBM(factors);
            } else {
                this._logger.error("Selected R function (" + getRFunction() + ") not supported! Available R functions are: " + ArrayTools.stringArrayToString(R_PACKAGES, ", "));
            }
            if (!this._rowNameVsCoeffiCurrentMap.isEmpty()) {
                // Re-key the coefficient of "x" under the predictor's real name; the intercept is dropped.
                Map<String, Double> renamed = new HashMap<String, Double>();
                Double coefficient = this._rowNameVsCoeffiCurrentMap.get("x");
                if (coefficient != null) {
                    renamed.put(str, coefficient);
                } else {
                    this._logger.warn("Fitted model contains no coefficient for predictor " + str + "!");
                }
                this._rowNameVsCoeffiCurrentMap = renamed;
            }
        } catch (REXPMismatchException e) {
            e.printStackTrace();
            super.getRConnection().close();
        } catch (REngineException e) {
            // Also covers RserveException, which extends REngineException.
            e.printStackTrace();
            super.getRConnection().close();
        } finally {
            if (!super.isRConnectionSet()) {
                super.getRConnection().close();
            }
        }
        return d;
    }

    /**
     * Fits the response {@code doubleMatrix1D} (R variable {@code y}) on every row of {@code matrix}
     * with the configured R function, and returns the score.
     *
     * <p>The matrix is rebuilt in R as {@code mat} (row-major, with row names) and transposed into
     * {@code x}. The coefficients are left in {@link #_rowNameVsCoeffiCurrentMap} keyed by row name.
     *
     * <p>R errors are printed and the connection is closed. The connection is also closed whenever
     * {@code isRConnectionSet()} is false, now including unchecked exceptions, via {@code finally}.
     *
     * @param matrix predictors, one per row
     * @param doubleMatrix1D response values, one per column of {@code matrix}
     * @return the model score, or NaN on error / unsupported R function
     */
    private double doLinearFit(Matrix matrix, DoubleMatrix1D doubleMatrix1D) {
        double d = Double.NaN;
        double[] dArr = {matrix.getMatrix().rows(), matrix.getMatrix().columns()};
        List<String> factors = new ArrayList<String>();
        for (String rowName : matrix.getRowNames()) {
            factors.add(rowName);
        }
        try {
            if (!super.isRConnectionSet()) {
                super.setInternalRConnection(RConnectionProvider.getInstance());
            }
            super.getRConnection().assign("form", getFormula());
            super.getRConnection().assign("vec", matrix.toDoubleVector());
            super.getRConnection().assign("dim", dArr);
            super.getRConnection().assign("rownames", matrix.getRowNames());
            super.getRConnection().voidEval("mat<-matrix(data=vec, nrow=dim[1], ncol=dim[2], byrow=T)");
            super.getRConnection().voidEval("rownames(mat)=rownames");
            super.getRConnection().voidEval("binary=0");
            super.getRConnection().voidEval("x<-t(mat)");
            super.getRConnection().assign("y", doubleMatrix1D.toArray());
            if (isSubtractMeanFromResponse().booleanValue()) {
                super.getRConnection().voidEval("y=y-mean(y)");
            }
            super.getRConnection().voidEval("data=data.frame(x=x,y=y)");
            if (getRFunction().equals(GLMBOOST_R)) {
                d = doGLMBoost(factors);
            } else if (getRFunction().equals(GLM_R)) {
                d = doGML(factors);
            } else if (getRFunction().equals(LM_R)) {
                d = doLM(factors);
            } else if (getRFunction().equals(GBM_R)) {
                d = doGBM(factors);
            } else {
                this._logger.error("Selected R function (" + getRFunction() + ") not supported! Available R functions are: " + ArrayTools.stringArrayToString(R_PACKAGES, ", "));
            }
        } catch (REXPMismatchException e) {
            e.printStackTrace();
            super.getRConnection().close();
        } catch (REngineException e) {
            // Also covers RserveException, which extends REngineException.
            e.printStackTrace();
            super.getRConnection().close();
        } finally {
            if (!super.isRConnectionSet()) {
                super.getRConnection().close();
            }
        }
        return d;
    }

    /**
     * Scores the current coefficients ({@link #_rowNameVsCoeffiCurrentMap}) on held-out data.
     *
     * <p>Each factor row of {@code matrix} that has a coefficient is sent to R under its own name and
     * collected into a new {@code data} frame. The prediction {@code y.predicted} is then built as a
     * linear combination of {@code data$<factor>} plus the intercept. The test response is taken from
     * the row named {@code "y"}. The data frame is evaluated before the prediction: the prediction
     * refers to {@code data$<factor>}, and {@code computeRSquare} reads the sample count from it. The
     * intercept is a constant term and is never looked up as a matrix row.
     *
     * @param matrix validation data containing the factor rows and a response row named {@code "y"}
     * @return R^2 / adjusted R^2 of the prediction; NaN for AIC, unknown scores, missing coefficients or R errors
     */
    private double myPredict(Matrix matrix) {
        double d = Double.NaN;
        StringBuilder dataFrame = new StringBuilder("data=data.frame(");
        StringBuilder prediction = new StringBuilder("y.predicted=");
        Set<String> arrayToSet = ArrayTools.arrayToSet(matrix.getRowNames());
        try {
            if (!super.isRConnectionSet()) {
                super.setInternalRConnection(RConnectionProvider.getInstance());
            }
            boolean hasTerm = false;
            for (Map.Entry<String, Double> entry : this._rowNameVsCoeffiCurrentMap.entrySet()) {
                String factor = entry.getKey();
                if (factor.equals(INTERCEPT)) {
                    prediction.append(entry.getValue()).append('+');
                    hasTerm = true;
                } else if (arrayToSet.contains(factor)) {
                    prediction.append(entry.getValue()).append("*data$").append(factor).append('+');
                    super.getRConnection().assign(factor, matrix.getMatrix().viewRow(matrix.getIndexOfRowName(factor)).toArray());
                    dataFrame.append(factor).append('=').append(factor).append(',');
                    hasTerm = true;
                } else {
                    this._logger.warn("Factor " + factor + " not contained in test data set!");
                }
            }
            if (!hasTerm) {
                // "y.predicted=" alone is not valid R; there is nothing to predict with.
                this._logger.error("No coefficients available to predict the response; score set to NaN.");
                return d;
            }
            // Drop the trailing '+' of the prediction and the trailing ',' of the data frame.
            prediction.setLength(prediction.length() - 1);
            if (dataFrame.charAt(dataFrame.length() - 1) == ',') {
                dataFrame.setLength(dataFrame.length() - 1);
            }
            dataFrame.append(')');
            String predictionCommand = prediction.toString();
            this._logger.info("Predicted response with formula: " + predictionCommand);
            super.getRConnection().assign("y", matrix.getMatrix().viewRow(matrix.getIndexOfRowName("y")).toArray());
            super.getRConnection().voidEval(dataFrame.toString());
            super.getRConnection().voidEval(predictionCommand);
            if (getScoreType().equals(R_SQUARE) || getScoreType().equals(R_SQUARE_ADJUSTED)) {
                d = computeRSquare();
            } else if (getScoreType().equals(AKAIKE)) {
                this._logger.error("Score type " + AKAIKE + " cannot be used in combination with cross-validation.");
            } else {
                this._logger.error("Score " + getScoreType() + " not supported. Supported scores are " + ArrayTools.stringArrayToString(SUPPORTED_SCORES, ", "));
            }
        } catch (REXPMismatchException e) {
            e.printStackTrace();
        } catch (REngineException e) {
            // Also covers RserveException, which extends REngineException.
            e.printStackTrace();
            super.getRConnection().close();
        } finally {
            if (!super.isRConnectionSet()) {
                super.getRConnection().close();
            }
        }
        return d;
    }

    /**
     * Cross-validates the model and returns the lowest score seen, including the full-data score {@code d}.
     *
     * <p>The response is appended to {@code matrix} as row {@code "y"}, the columns are shuffled and
     * partitioned by {@link CrossValidator}. For each fold the model is fitted on the training part and
     * scored on the validation part with {@link #myPredict}. When "discard zero R^2" is enabled, a fold
     * score of exactly 0 is ignored for both R^2 score types. NaN fold scores never lower the result.
     *
     * <p>The coefficients of the full-data fit are restored afterwards, so the network edges are built
     * from them rather than from the last fold.
     *
     * @param matrix predictors (one per row)
     * @param doubleMatrix1D response
     * @param d score of the full-data fit
     * @return the minimum of {@code d} and the accepted fold scores
     */
    private double crossValidate(Matrix matrix, DoubleMatrix1D doubleMatrix1D, double d) {
        Map<String, Double> fullDataCoefficients = new HashMap<String, Double>(this._rowNameVsCoeffiCurrentMap);
        Matrix shuffleColumns = MatrixToolsProvider.shuffleColumns(MatrixToolsProvider.addRowToMatrix(doubleMatrix1D, "y", matrix));
        CrossValidator crossValidator = new CrossValidator();
        crossValidator.setMatrix(shuffleColumns);
        crossValidator.setFold(getCrossvalidateFold().intValue());
        crossValidator.setSubdataSizeGivenFoldNumber();
        crossValidator.computeSubdataPartitioning();
        this._logger.info("Number of cv iterations: " + crossValidator.getIterationNumber());
        ArrayList<String> responseRow = new ArrayList<String>();
        responseRow.add("y");
        try {
            for (int i = 0; i < crossValidator.getIterationNumber(); i++) {
                this._logger.info("Iteration number: " + i);
                crossValidator.crossValidate();
                Matrix resampledMatrix = crossValidator.getResampledMatrix();
                if (resampledMatrix.getMatrix().columns() < 2) {
                    this._logger.error("Cross-validation training data sub-matrix has less than 2 columns!");
                }
                Matrix validationData = crossValidator.getValidationData();
                if (validationData.getMatrix().columns() < 2) {
                    this._logger.error("Cross-validation test data sub-matrix has less than 2 columns!");
                }
                DoubleMatrix1D trainResponse = resampledMatrix.getMatrix().viewRow(resampledMatrix.getIndexOfRowName("y"));
                Matrix trainPredictors = MatrixToolsProvider.getSubmatrixWithoutRows(resampledMatrix, responseRow);
                this._logger.info("Column number in train matrix: " + trainPredictors.getMatrix().columns());
                this._logger.info("Column number in test matrix: " + validationData.getMatrix().columns());
                if (!getFormularType().equals(LINEAR_FORMULAR_TYPE)) {
                    this._logger.fatal("For the moment, only linear models are supported.");
                    throw new IllegalArgumentException("For the moment, only linear models are supported.");
                }
                doLinearFit(trainPredictors, trainResponse);
                double score = myPredict(validationData);
                this._logger.info("score: " + score);
                boolean rSquareScore = getScoreType().equals(R_SQUARE) || getScoreType().equals(R_SQUARE_ADJUSTED);
                boolean discarded = isDiscardZeroRSquare().booleanValue() && score == 0.0d && rSquareScore;
                if (!discarded && score < d) {
                    d = score;
                }
            }
        } finally {
            this._rowNameVsCoeffiCurrentMap = fullDataCoefficients;
        }
        return d;
    }

    /**
     * Annotates an existing edge with model quality, coefficient and (if non-empty) group.
     *
     * <p>Each attribute is written only once. If it is already present, a warning is logged and the
     * stored value is kept. An empty or {@code null} group is simply not written. Edges that are not
     * in the graph are ignored.
     *
     * @param str edge identifier ({@code source->target})
     * @param d model quality (score)
     * @param d2 coefficient
     * @param str2 group label, may be empty
     */
    private void addEdgeAttributes(String str, double d, double d2, String str2) {
        if (!getCooccurrenceNetwork().getGraph().hasArc(str)) {
            return;
        }
        if (getCooccurrenceNetwork().hasDataAnnotation(str, MODELQUALITY_ATTRIBUTE)) {
            this._logger.warn("Edge " + str + " already added before!");
        } else {
            getCooccurrenceNetwork().getDatas().get(0).put(str, MODELQUALITY_ATTRIBUTE, Double.valueOf(d));
        }
        if (getCooccurrenceNetwork().hasDataAnnotation(str, COEFFICIENT_ATTRIBUTE)) {
            this._logger.warn("Edge " + str + " already added before!");
        } else {
            getCooccurrenceNetwork().getDatas().get(0).put(str, COEFFICIENT_ATTRIBUTE, Double.valueOf(d2));
        }
        if (getCooccurrenceNetwork().hasDataAnnotation(str, GROUP_ATTRIBUTE)) {
            this._logger.warn("Edge " + str + " already added before!");
        } else if (str2 != null && !str2.isEmpty()) {
            getCooccurrenceNetwork().getDatas().get(0).put(str, GROUP_ATTRIBUTE, str2);
        }
    }

    /**
     * Adds edges towards {@code str} from the current coefficients if the score passes the lower threshold.
     *
     * <p>The intercept is ignored. Otherwise there are two modes. In the "separately" mode, each factor
     * gets its own edge, typed copresence or mutual exclusion by the coefficient's sign. In the other
     * mode, all factors are merged into one group source node; negative factors are prefixed with
     * {@code "neg"}, and the edge carries the sum of absolute coefficients. No group edge is added when
     * there is no non-intercept coefficient, which would otherwise create a node with an empty name.
     *
     * @param d model score
     * @param str target (response) name
     * @param str2 group label forwarded to {@link #addEdgeAttributes}
     * @return whether the score passed the lower threshold
     */
    private boolean buildEdges(double d, String str, String str2) {
        this._logger.debug("Score " + d);
        if (d > this._maxValue) {
            this._maxValue = d;
        }
        boolean z = d > getLowerThreshold().doubleValue();
        if (!z) {
            return false;
        }
        boolean separately = isDisplayGroupEdgesSeparately().booleanValue();
        StringBuilder groupSource = new StringBuilder();
        double d2 = 0.0d;
        for (Map.Entry<String, Double> entry : this._rowNameVsCoeffiCurrentMap.entrySet()) {
            String factor = entry.getKey();
            if (factor.contains(INTERCEPT)) {
                continue;
            }
            double doubleValue = entry.getValue().doubleValue();
            d2 += Math.abs(doubleValue);
            String type = doubleValue >= 0.0d ? "copresence" : CooccurrenceConstants.MUTUAL_EXCLUSION;
            if (separately) {
                addSpeciesInteraction(factor, str, type, Double.valueOf(d));
                addEdgeAttributes(factor + "->" + str, d, doubleValue, str2);
            } else if (type.equals("copresence")) {
                groupSource.append(CooccurrenceConstants.AND).append(factor);
            } else {
                groupSource.append("_AND_neg").append(factor);
            }
        }
        if (!separately) {
            String source = groupSource.toString();
            // Strip the leading separator literally (replaceFirst would treat it as a regex).
            if (source.startsWith(CooccurrenceConstants.AND)) {
                source = source.substring(CooccurrenceConstants.AND.length());
            }
            if (source.isEmpty()) {
                this._logger.warn("No non-intercept coefficient for response " + str + "; no edge added.");
            } else {
                addSpeciesInteraction(source, str, "->", Double.valueOf(d));
                addEdgeAttributes(source + "->" + str, d, d2, str2);
            }
        }
        return z;
    }

    /**
     * Regresses row {@code str} of {@code matrix} on all other rows and returns the model score.
     *
     * <p>Missing values are treated first via {@link NaNTreatmentProvider}; the matrix is skipped if it
     * is too incomplete. With the Spearman filter enabled, predictors are removed when their Spearman
     * value against the response is not below the threshold, or when the edge was rejected earlier.
     * Removed edges are remembered in {@link #_spearmanFilteredEdges}. If cross-validation is
     * configured (and the R function is not gbm), the score is the cross-validated minimum.
     *
     * @param matrix group matrix containing the response row
     * @param str name of the response row
     * @param set group members, used in log messages
     * @return the score, or NaN when the matrix is empty, too incomplete or has no predictors left
     * @throws IllegalArgumentException if the formula type is not linear
     */
    private double doMultiRegression(Matrix matrix, String str, Set<String> set) {
        double d = Double.NaN;
        if (matrix.isEmpty()) {
            this._logger.warn("Submatrix is empty!");
            return d;
        }
        NaNTreatmentProvider.getInstance().setMatrix(matrix);
        Matrix treatMissingValuesInMatrix = NaNTreatmentProvider.getInstance().treatMissingValuesInMatrix();
        if (!NaNTreatmentProvider.getInstance().proceed()) {
            this._logger.warn("Submatrix " + set.toString() + " versus response " + str + " has too many missing values and is skipped.");
            return d;
        }
        DoubleMatrix1D viewRow = treatMissingValuesInMatrix.getMatrix().viewRow(treatMissingValuesInMatrix.getIndexOfRowName(str));
        ArrayList<String> responseRow = new ArrayList<String>();
        responseRow.add(str);
        Matrix submatrixWithoutRows = MatrixToolsProvider.getSubmatrixWithoutRows(treatMissingValuesInMatrix, responseRow);
        if (isSpearmanFilter().booleanValue()) {
            submatrixWithoutRows = applySpearmanFilter(submatrixWithoutRows, viewRow, str);
        }
        if (submatrixWithoutRows.getMatrix().rows() <= 0) {
            this._logger.info("Submatrix " + set.toString() + " has no rows left!");
            return d;
        }
        if (!getFormularType().equals(LINEAR_FORMULAR_TYPE)) {
            this._logger.fatal("For the moment, only linear models are supported.");
            throw new IllegalArgumentException("For the moment, only linear models are supported.");
        }
        d = doLinearFit(submatrixWithoutRows, viewRow);
        if (getCrossvalidateFold().intValue() > 0 && !getRFunction().equals(GBM_R)) {
            d = crossValidate(submatrixWithoutRows, viewRow, d);
        }
        return d;
    }

    /**
     * Returns {@code predictors} without the rows whose edge to {@code response} is rejected by the
     * Spearman filter; rows whose Spearman value is below the threshold are kept.
     */
    private Matrix applySpearmanFilter(Matrix predictors, DoubleMatrix1D responseValues, String response) {
        ArrayList<String> rowsToRemove = new ArrayList<String>();
        for (int i = 0; i < predictors.getMatrix().rows(); i++) {
            String rowName = predictors.getRowName(i);
            String edge = rowName + "->" + response;
            if (this._spearmanFilteredEdges.contains(edge)) {
                this._logger.debug("Spearman filter removes interaction " + rowName + HelpFormatter.DEFAULT_OPT_PREFIX + response);
                rowsToRemove.add(rowName);
            } else if (MatrixToolsProvider.getSpearmanUsingJSC(predictors.getMatrix().viewRow(i), responseValues, true, true) >= getSpearmanFilterThreshold().doubleValue()) {
                this._spearmanFilteredEdges.add(edge);
                this._logger.debug("Spearman filter removes interaction " + rowName + HelpFormatter.DEFAULT_OPT_PREFIX + response);
                rowsToRemove.add(rowName);
            }
        }
        return MatrixToolsProvider.getSubmatrixWithoutRows(predictors, rowsToRemove);
    }

    /**
     * Builds the co-occurrence network by fitting regression models between the rows of the input matrix.
     * <p>
     * Depending on the configuration, one of three modes is used:
     * <ul>
     * <li><b>all-against-all</b> ({@link #isAllAgainstAll()}): each row is used as response of a multiple
     * regression over the matrix; edges are only built for non-NaN scores;</li>
     * <li><b>group mode</b> (non-empty {@link #getGroupAttribute()}): for every row carrying the group
     * attribute, every other row is used as response of a multiple regression over the members of that
     * row's group;</li>
     * <li><b>pairwise mode</b> (default): every unordered pair of rows is fitted with {@code doLinearFit},
     * optionally pre-filtered by a Spearman test and cross-validated.</li>
     * </ul>
     * Pairs for which the missing-value treatment does not allow to proceed are skipped. Afterwards the
     * maximum weight is set to 1.0 for R^2-type scores, and {@link #toString()} is stored under the key
     * "Comment" in the first data set of the network.
     *
     * @throws IllegalArgumentException if neither a lower nor an upper threshold is set, or if a non-linear
     *                                  formula type is requested
     */
    @Override // be.ac.vub.bsb.cooccurrence.core.CooccurrenceNetworkBuilder
    public void buildNetwork() {
        if (getLowerThreshold().isNaN() && getUpperThreshold().isNaN()) {
            throw new IllegalArgumentException("You need to set a co-occurrence threshold!");
        }
        // When only one threshold is given, copy it into the other one so that both are defined.
        if (!getUpperThreshold().isNaN() && !getLowerThreshold().isNaN()) {
            this._logger.warn("Both thresholds are set, but only one is required! The value of the upper threshold (" + getUpperThreshold() + ") is selected as threshold on the number of co-occurrences.");
        } else if (getLowerThreshold().isNaN()) {
            setLowerThreshold(getUpperThreshold());
        } else {
            // Only the lower threshold is set (both-NaN was rejected above).
            setUpperThreshold(getLowerThreshold());
        }
        // Translate the family name into the name expected by the selected R function, when known.
        if (getRFunction().equals(GLM_R)) {
            if (this._familyVsGLMName.containsKey(getFamily())) {
                setFamily(this._familyVsGLMName.get(getFamily()));
            } else {
                this._logger.warn("Family " + getFamily() + " not among the known families for R function " + getRFunction() + "!");
            }
        } else if (getRFunction().equals(GLMBOOST_R)) {
            if (this._familyVsGLMBoostName.containsKey(getFamily())) {
                setFamily(this._familyVsGLMBoostName.get(getFamily()));
            } else {
                this._logger.warn("Family " + getFamily() + " not among the known families for R function " + getRFunction() + "!");
            }
        } else if (getRFunction().equals(GBM_R)) {
            if (this._familyVsGBMName.containsKey(getFamily())) {
                setFamily(this._familyVsGBMName.get(getFamily()));
            } else {
                this._logger.warn("Family " + getFamily() + " not among the known families for R function " + getRFunction() + "!");
            }
        }
        // Models are asymmetric (directed) when one row is explained by several other rows.
        setAsymmetricModel(isAllAgainstAll() || !getGroupAttribute().isEmpty());
        this._spearmanFilteredEdges = new HashSet();
        checkThresholds();
        super.initCooccurrenceNetwork();
        getMatrix().setIndicesAsColNames();
        if (!getGroupAttribute().isEmpty()) {
            setNetworkDirected(true);
        }
        Groups groups = new Groups();
        if (!getGroupAttribute().isEmpty()) {
            MatrixMetadataGroupManager groupManager = new MatrixMetadataGroupManager(getMatrix(), getGroupAttribute());
            groupManager.assembleGroups();
            groups = groupManager.getGroups();
        }
        int lastRow = getMatrix().getMatrix().rows() - 1;
        int firstRow = isAllAgainstAll() ? 0 : 1;
        // Upper bound of the inner (response) loop in directed pairwise mode.
        int lastDirectedPartner = 0;
        if (isNetworkDirected() && !isAllAgainstAll()) {
            // Every ordered pair of distinct rows is visited, so both loops must cover all rows.
            // The inner bound used to be lastRow - 1, which silently excluded the last row as response.
            firstRow = 0;
            lastDirectedPartner = lastRow;
        }
        for (int x = firstRow; x <= lastRow; x++) {
            String rowName = getMatrix().getRowName(x);
            this._logger.info("Going through row with index " + x + " and name " + rowName);
            if (isAllAgainstAll()) {
                // No row is removed here; the multi-regression helper presumably drops the response row
                // itself (TODO confirm against doMultiRegression).
                HashSet predictorNames = new HashSet();
                Matrix submatrix = MatrixToolsProvider.getSubmatrixWithoutRows(getMatrix(), predictorNames);
                predictorNames.addAll(ArrayTools.arrayToSet(submatrix.getRowNames()));
                double score = doMultiRegression(submatrix, rowName, predictorNames);
                if (!Double.isNaN(score)) {
                    buildEdges(score, rowName, rowName);
                }
                continue;
            }
            // Undirected: visit each unordered pair once (y < x). Directed: visit all y != x.
            int lastPartner = isNetworkDirected() ? lastDirectedPartner : x - 1;
            for (int y = 0; y <= lastPartner; y++) {
                if (y == x) {
                    continue;
                }
                double score = Double.NaN;
                boolean skip = false;
                String groupName = "";
                String rowName2 = getMatrix().getRowName(y);
                NaNTreatmentProvider.getInstance().setMatrix(getMatrix());
                NaNTreatmentProvider.getInstance().setXRow(x);
                NaNTreatmentProvider.getInstance().setYRow(y);
                DoubleMatrix1D xValues = NaNTreatmentProvider.getInstance().getTreatedXRow();
                DoubleMatrix1D yValues = NaNTreatmentProvider.getInstance().getTreatedYRow();
                if (!NaNTreatmentProvider.getInstance().proceed()) {
                    // No model is fitted for this pair, so no edge must be built from a NaN score.
                    this._logger.warn("Pair " + rowName + " / " + rowName2 + " has too many missing values and is skipped.");
                    skip = true;
                } else if (getGroupAttribute().isEmpty()) {
                    if (!getFormularType().equals(LINEAR_FORMULAR_TYPE)) {
                        this._logger.fatal("For the moment, only linear models are supported.");
                        throw new IllegalArgumentException("For the moment, only linear models are supported.");
                    }
                    double spearmanValue = Double.NaN;
                    if (isSpearmanFilter().booleanValue()) {
                        spearmanValue = MatrixToolsProvider.getSpearmanValueUsingR(xValues, yValues, true, false);
                    }
                    if (spearmanValue < getSpearmanFilterThreshold().doubleValue() || !isSpearmanFilter().booleanValue()) {
                        score = doLinearFit(xValues, yValues, rowName);
                        if (getCrossvalidateFold().intValue() > 0 && !getRFunction().equals(GBM_R)) {
                            // Single-row predictor matrix holding the treated x values.
                            Matrix predictor = new Matrix(1, yValues.size());
                            predictor.setRowName(0, rowName);
                            predictor.setRow(0, xValues.toArray());
                            score = crossValidate(predictor, yValues, score);
                        }
                    } else {
                        skip = true;
                    }
                } else if (getMatrix().hasRowMetaAnnotation(x, getGroupAttribute())) {
                    groupName = groups.getGroupsOfMember(rowName).iterator().next();
                    // Work on a copy: if Groups returns its internal member set, adding the response to it
                    // would alter the group definition and leak rows into later regressions.
                    HashSet<String> regressionRows = new HashSet<String>(groups.getMembersOfGroup(groupName));
                    regressionRows.add(rowName2);
                    score = doMultiRegression(MatrixToolsProvider.getSubMatrix(getMatrix(), regressionRows), rowName2, regressionRows);
                    skip = Double.isNaN(score);
                } else {
                    this._logger.warn("Row " + rowName + " does not contain a value for metadata attribute " + getGroupAttribute() + "!");
                    skip = true;
                }
                if (!skip) {
                    buildEdges(score, rowName2, groupName);
                }
            }
        }
        if (getScoreType().equals(R_SQUARE) || getScoreType().equals(R_SQUARE_ADJUSTED)) {
            this._maxValue = 1.0d;
        }
        this._logger.info("Maximum weight: " + getMaxValue());
        getCooccurrenceNetwork().getDatas().get(0).put(getCooccurrenceNetwork().getGraph().getIdentifier(), "Comment", toString());
    }

    /**
     * Returns the maximum edge weight; {@code buildNetwork()} forces it to 1.0 for R^2-type scores.
     *
     * @return the maximum weight
     */
    public double getMaxValue() {
        return _maxValue;
    }

    /**
     * Marks the model as asymmetric; this builder represents asymmetry by a directed network.
     *
     * @param asymmetric whether the model is asymmetric (must not be null)
     */
    private void setAsymmetricModel(Boolean asymmetric) {
        final boolean directed = asymmetric.booleanValue();
        setNetworkDirected(directed);
    }

    /**
     * Tells whether the model is asymmetric, i.e. whether the network is directed.
     *
     * @return true if the network is directed
     */
    public Boolean getAsymmetricModel() {
        final boolean directed = isNetworkDirected();
        return directed;
    }

    /**
     * Enables all-against-all mode, in which each row is the response of a multiple regression.
     *
     * @param allAgainstAll true to enable all-against-all mode
     */
    public void setAllAgainstAll(boolean allAgainstAll) {
        _allAgainstAll = allAgainstAll;
    }

    /**
     * @return true if all-against-all mode is enabled
     */
    public boolean isAllAgainstAll() {
        return _allAgainstAll;
    }

    /**
     * Sets the model formula, e.g. {@link #SIMPLE_LINEAR}.
     *
     * @param formula the formula string
     */
    public void setFormula(String formula) {
        _formula = formula;
    }

    /**
     * @return the model formula
     */
    public String getFormula() {
        return _formula;
    }

    /**
     * Sets the formula type; model fitting currently only accepts {@link #LINEAR_FORMULAR_TYPE}.
     *
     * @param formulaType the formula type
     */
    public void setFormularType(String formulaType) {
        _formularType = formulaType;
    }

    /**
     * @return the formula type
     */
    public String getFormularType() {
        return _formularType;
    }

    /**
     * Sets whether the mean of the response is subtracted from the response.
     *
     * @param subtract the flag (must not be null)
     */
    public void setSubtractMeanFromResponse(Boolean subtract) {
        _subtractMeanFromResponse = subtract;
    }

    /**
     * @return whether the mean of the response is subtracted from the response
     */
    public Boolean isSubtractMeanFromResponse() {
        return _subtractMeanFromResponse;
    }

    /**
     * Sets the error distribution family; {@code buildNetwork()} translates it into the name expected by
     * the selected R function when that name is known.
     *
     * @param family the family name
     */
    public void setFamily(String family) {
        _family = family;
    }

    /**
     * @return the error distribution family
     */
    public String getFamily() {
        return _family;
    }

    /**
     * Sets the score type; stored as the builder's return type.
     *
     * @param scoreType the score type, e.g. an R^2 variant
     */
    public void setScoreType(String scoreType) {
        setReturnType(scoreType);
    }

    /**
     * @return the score type (the builder's return type)
     */
    public String getScoreType() {
        final String scoreType = getReturnType();
        return scoreType;
    }

    /**
     * Sets the row metadata attribute used to group rows; an empty string disables group mode.
     *
     * @param groupAttribute the attribute name
     */
    public void setGroupAttribute(String groupAttribute) {
        _groupAttribute = groupAttribute;
    }

    /**
     * @return the row metadata attribute used for grouping
     */
    public String getGroupAttribute() {
        return _groupAttribute;
    }

    /**
     * Sets the R function used for fitting (compared against GLM_R, GLMBOOST_R and GBM_R).
     *
     * @param rFunction the R function name
     */
    public void setRFunction(String rFunction) {
        _rFunction = rFunction;
    }

    /**
     * @return the R function used for fitting
     */
    public String getRFunction() {
        return _rFunction;
    }

    /**
     * Sets the number of boosting iterations (reported as applying to GLMBOOST_R only).
     *
     * @param iterations the iteration count (must not be null)
     */
    public void setBoostIterations(Integer iterations) {
        _boostIterations = iterations;
    }

    /**
     * @return the number of boosting iterations
     */
    public Integer getBoostIterations() {
        return _boostIterations;
    }

    /**
     * Enables or disables the Spearman pre-filter applied before model fitting.
     *
     * @param enabled the flag (must not be null)
     */
    public void setSpearmanFilter(Boolean enabled) {
        _spearmanFilter = enabled;
    }

    /**
     * @return whether the Spearman pre-filter is enabled
     */
    public Boolean isSpearmanFilter() {
        return _spearmanFilter;
    }

    /**
     * Sets the maximal allowed Spearman p-value; only pairs whose value lies below it are fitted.
     *
     * @param threshold the threshold (must not be null)
     */
    public void setSpearmanFilterThreshold(Double threshold) {
        _spearmanFilterThreshold = threshold;
    }

    /**
     * @return the maximal allowed Spearman p-value
     */
    public Double getSpearmanFilterThreshold() {
        return _spearmanFilterThreshold;
    }

    /**
     * Sets the cross-validation fold; values above 0 enable cross-validation (not applied for GBM_R),
     * 0 disables it.
     *
     * @param fold the fold count (must not be null)
     */
    public void setCrossvalidateFold(Integer fold) {
        _crossvalidateFold = fold;
    }

    /**
     * @return the cross-validation fold (0 = no cross-validation)
     */
    public Integer getCrossvalidateFold() {
        return _crossvalidateFold;
    }

    /**
     * Sets whether models with zero R^2 are discarded during cross-validation.
     *
     * @param discard the flag (must not be null)
     */
    public void setDiscardZeroRSquare(Boolean discard) {
        _discardZeroRSquare = discard;
    }

    /**
     * @return whether zero R^2 models are discarded during cross-validation
     */
    public Boolean isDiscardZeroRSquare() {
        return _discardZeroRSquare;
    }

    /**
     * Sets whether group edges are displayed separately.
     *
     * @param separately the flag (must not be null)
     */
    public void setDisplayGroupEdgesSeparately(Boolean separately) {
        _displayGroupEdgesSeparately = separately;
    }

    /**
     * @return whether group edges are displayed separately
     */
    public Boolean isDisplayGroupEdgesSeparately() {
        return _displayGroupEdgesSeparately;
    }

    /**
     * Returns the names of all configurable parameters: the inherited ones followed by those specific
     * to this model-based builder.
     *
     * @return a new mutable list of parameter names
     */
    @Override // be.ac.vub.bsb.cooccurrence.core.CooccurrenceNetworkBuilder, be.ac.vub.bsb.cooccurrence.core.NetworkInferenceAlgorithm, be.ac.vub.bsb.cooccurrence.util.IMethod
    public List<String> getParameters() {
        List<String> parameters = new ArrayList<String>(super.getParameters());
        final String[] ownParameters = {
            PathwayinferenceConstants.FORMULA,
            "FormulaType",
            "SubtractMeanFromResponse",
            "ErrorDistribution",
            "GroupAttribute",
            "RFunction",
            "BoostIterations",
            "SpearmanFilter",
            "SpearmanFilterThreshold",
            "CrossvalidateFold",
            "DiscardZeroRSquare",
            "DisplayGroupEdgesSeparately"
        };
        for (String parameter : ownParameters) {
            parameters.add(parameter);
        }
        return parameters;
    }

    /**
     * Describes the builder configuration: the superclass description followed by one
     * {@code "# key=value"} line per model setting. The number of Spearman-filtered interactions is
     * included only when at least one interaction was removed.
     *
     * @return the multi-line configuration summary
     */
    @Override // be.ac.vub.bsb.cooccurrence.core.CooccurrenceNetworkBuilder, be.ac.vub.bsb.cooccurrence.core.NetworkInferenceAlgorithm, be.ac.vub.bsb.cooccurrence.util.IMethod
    public String toString() {
        StringBuilder description = new StringBuilder();
        description.append(super.toString());
        description.append("# ").append(PathwayinferenceConstants.FORMULA).append("=").append(getFormula()).append("\n");
        description.append("# Formula is asymmetric=").append(getAsymmetricModel()).append("\n");
        description.append("# Type of formula=").append(getFormularType()).append("\n");
        description.append("# Family=").append(getFamily()).append("\n");
        description.append("# Mean of response is subtracted from response=").append(isSubtractMeanFromResponse()).append("\n");
        description.append("# Score type=").append(getScoreType()).append("\n");
        description.append("# Group attribute=").append(getGroupAttribute()).append("\n");
        description.append("# Group edges displayed separately=").append(isDisplayGroupEdgesSeparately()).append("\n");
        description.append("# R function called=").append(getRFunction()).append("\n");
        description.append("# Boost iterations (for ").append(GLMBOOST_R).append(" only)=").append(getBoostIterations()).append("\n");
        description.append("# Spearman filter=").append(isSpearmanFilter()).append("\n");
        description.append("# Spearman filter maximal allowed p-value=").append(getSpearmanFilterThreshold()).append("\n");
        // The filtered-edge set is assigned in buildNetwork(); it may still be null if toString() is
        // called before that (depends on the field declaration), so guard against an NPE.
        if (this._spearmanFilteredEdges != null && !this._spearmanFilteredEdges.isEmpty()) {
            description.append("# Number of interactions removed by Spearman filter=").append(this._spearmanFilteredEdges.size()).append("\n");
        }
        description.append("# Cross validation fold (0 = no cross-validation)=").append(getCrossvalidateFold()).append("\n");
        description.append("# Zero R^2 models discarded during cross-validation=").append(isDiscardZeroRSquare()).append("\n");
        return description.toString();
    }

    /**
     * Manual driver: reads a matrix, builds an all-against-all network with the GBM R function and
     * adjusted R^2 scores, and prints the configuration summary.
     *
     * @param strArr optional arguments; if present, {@code strArr[0]} is the path of the input matrix,
     *               otherwise the original hard-coded example file is used
     */
    public static void main(String[] strArr) {
        // Default kept so that running without arguments behaves exactly as before.
        String inputPath = "/Users/karoline/Documents/Documents_Karoline/Publications/Review_on_microbial_interactions/ExampleChafffron/Input/presenceAbsenceChaffron-D15-Frows4-Fcols3.txt";
        if (strArr != null && strArr.length > 0) {
            inputPath = strArr[0];
        }
        Matrix matrix = new Matrix();
        matrix.readMatrix(inputPath, false);
        try {
            CooccurrenceFromModelNetworkBuilder cooccurrenceFromModelNetworkBuilder = new CooccurrenceFromModelNetworkBuilder(matrix);
            cooccurrenceFromModelNetworkBuilder.setRConnection(RConnectionProvider.getInstance());
            cooccurrenceFromModelNetworkBuilder.setFamily(BINOMIAL_FAMILY);
            cooccurrenceFromModelNetworkBuilder.setRFunction(GBM_R);
            cooccurrenceFromModelNetworkBuilder.setFormula(SIMPLE_LINEAR);
            cooccurrenceFromModelNetworkBuilder.setAllAgainstAll(true);
            cooccurrenceFromModelNetworkBuilder.setScoreType(R_SQUARE_ADJUSTED);
            cooccurrenceFromModelNetworkBuilder.setLowerThreshold(Double.valueOf(0.9d));
            cooccurrenceFromModelNetworkBuilder.setSpearmanFilter(true);
            cooccurrenceFromModelNetworkBuilder.buildNetwork();
            System.out.println(cooccurrenceFromModelNetworkBuilder.toString());
        } catch (RserveException e) {
            System.out.println(e.toString());
        }
    }
}
