package org.genemania.engine.core.integration.calculators;

import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.Vector;
import org.apache.log4j.Logger;
import org.genemania.engine.Constants;
import org.genemania.engine.cache.DataCache;
import org.genemania.engine.core.MatrixUtils;
import org.genemania.engine.core.integration.CombineNetworksOnly;
import org.genemania.engine.core.integration.FeatureList;
import org.genemania.engine.core.integration.Solver;
import org.genemania.engine.core.integration.attribute.QueryEnrichedAttributeScorer;
import org.genemania.engine.core.integration.gram.BasicGramBuilder;
import org.genemania.engine.core.integration.gram.GramEditor;
import org.genemania.engine.exception.WeightingFailedException;
import org.genemania.exception.ApplicationException;
import org.genemania.util.ProgressReporter;

/* loaded from: input_file:org/genemania/engine/core/integration/calculators/BranchSpecificCalculator.class */
/**
 * Network weight calculator for a specific {@link Constants.CombiningMethod}.
 *
 * <p>Starting from the cached Gram matrices (KtK and the method's KtT), it performs four steps:
 * <ol>
 *   <li>Restricts the matrices via {@link GramEditor} to the features shared with this query's
 *       feature list.</li>
 *   <li>Extends them via {@link BasicGramBuilder} with query features the cache does not have.</li>
 *   <li>Solves for network weights with {@link Solver}.</li>
 *   <li>Combines the networks using those weights.</li>
 * </ol>
 * If the solver fails, the calculator falls back to averaged weights.
 */
public class BranchSpecificCalculator extends AbstractNetworkWeightCalculator {

    /** Combining method; its {@code toString()} selects the KtT target and co-annotation set. */
    Constants.CombiningMethod method;

    /** Format of the key returned by {@link #getParameterKey()}: method name, then network list. */
    public static final String PARAM_KEY_FORMAT = "%s-%s";

    private static final Logger logger = Logger.getLogger(BranchSpecificCalculator.class);

    /** Threshold passed to {@link QueryEnrichedAttributeScorer} when building the feature list. */
    private static final int MIN_QUERY_GENES_PER_ATTRIBUTE = 1;

    /**
     * Creates a calculator for the given combining method.
     *
     * <p>All arguments except {@code combiningMethod} are forwarded unchanged to
     * {@link AbstractNetworkWeightCalculator}. NOTE(review): their individual meanings
     * (namespace, network ids, attribute group ids, organism id, label vector, ...) are defined
     * by the superclass constructor, which is not visible here. Confirm before relying on them.
     *
     * @param combiningMethod the combining method whose target/co-annotation data is used
     * @throws ApplicationException if the superclass constructor fails
     */
    public BranchSpecificCalculator(String str, DataCache dataCache, Collection<Collection<Long>> collection, Collection<Long> collection2, long j, Vector vector, int i, Constants.CombiningMethod combiningMethod, ProgressReporter progressReporter) throws ApplicationException {
        super(str, dataCache, collection, collection2, j, vector, i, progressReporter);
        this.method = combiningMethod;
    }

    /**
     * Reports weighting progress, then computes weights and the combined network.
     *
     * @throws ApplicationException if weighting or combining fails
     */
    @Override
    public void process() throws ApplicationException {
        this.progress.setStatus(Constants.PROGRESS_WEIGHTING_MESSAGE);
        this.progress.setProgress(1);
        computeNewResult(queryHasUserNetworks());
    }

    /**
     * Computes {@code weights} and {@code combinedMatrix} for this query.
     *
     * @param hasUserNetworks result of {@code queryHasUserNetworks()}; forwarded to
     *        {@code getKtK}/{@code getKtT}
     * @throws ApplicationException if the reduced feature list does not start with the bias
     *         feature, or if a Gram/combination step fails
     */
    void computeNewResult(boolean hasUserNetworks) throws ApplicationException {
        final String methodName = this.method.toString();

        DenseMatrix cachedKtK = getKtK(hasUserNetworks);
        DenseMatrix cachedKtT = getKtT(methodName, hasUserNetworks);
        FeatureList cachedFeatures = this.cache.getKtKFeatures(this.namespace, this.organismId).getFeatures();

        FeatureList queryFeatures = buildFeatureList(new QueryEnrichedAttributeScorer(this.cache, this.label, MIN_QUERY_GENES_PER_ATTRIBUTE), false);
        queryFeatures.addBias();

        // Features present in both the cache and the query (sorted) ...
        FeatureList features = intersect(cachedFeatures, queryFeatures);
        // ... and query features the cached Gram matrices do not cover yet (sorted).
        FeatureList missingFeatures = setdiff(cachedFeatures, queryFeatures);

        // The Gram/solver steps require the bias in column 0. Guard the empty case explicitly so
        // it yields the intended ApplicationException rather than an IndexOutOfBoundsException.
        if (features.isEmpty() || features.get(0).getType() != Constants.NetworkType.BIAS) {
            throw new ApplicationException("internal error: bias must be first column");
        }

        DenseMatrix ktK = GramEditor.RemoveNetworkKtK(cachedKtK, cachedFeatures, features);
        DenseMatrix ktT = GramEditor.RemoveNetworkKtT(cachedKtT, cachedFeatures, features);

        if (!missingFeatures.isEmpty()) {
            if (logger.isDebugEnabled()) {
                logger.debug(String.format("need to update gram for %d features", Integer.valueOf(missingFeatures.size())));
            }
            BasicGramBuilder gramBuilder = new BasicGramBuilder(this.cache, this.namespace, this.organismId, this.progress);
            ktK = gramBuilder.updateBasicKtK(ktK, features, missingFeatures, this.progress);
            ktT = gramBuilder.updateKtT(ktT, features, missingFeatures, this.cache.getCoAnnotationSet(this.organismId, methodName), this.progress);
            // The updated matrices have the new features appended after the existing ones, so
            // the feature list must be extended in the same order.
            features.addAll(missingFeatures);
        }

        scaleKtK(ktK, methodName);
        try {
            this.weights = Solver.solve(ktK, MatrixUtils.extractColumnToVector(ktT, 0), features, this.progress);
        } catch (WeightingFailedException e) {
            // Best-effort fallback is deliberate; keep the stack trace for diagnosis.
            logger.error("weighting calculation failed, falling back to average: " + e.getMessage(), e);
            this.weights = AverageByNetworkCalculator.average(features);
        }

        this.progress.setStatus(Constants.PROGRESS_COMBINING_MESSAGE);
        this.progress.setProgress(2);
        this.combinedMatrix = CombineNetworksOnly.combine(this.weights, this.namespace, this.organismId, this.cache, this.progress);
    }

    /**
     * Returns the distinct features present in both lists, sorted by their natural order.
     *
     * @param featureList first list
     * @param featureList2 second list
     * @return a new sorted list of the common features
     */
    private FeatureList intersect(FeatureList featureList, FeatureList featureList2) {
        FeatureList result = new FeatureList();
        result.addAll(distinctRetaining(featureList, featureList2));
        Collections.sort(result);
        return result;
    }

    /**
     * Returns the distinct features of {@code featureList2} that are NOT in {@code featureList},
     * sorted by their natural order. Note the argument order: the result is
     * {@code featureList2 \ featureList}.
     *
     * @param featureList features to exclude
     * @param featureList2 features to draw from
     * @return a new sorted list of the features only in {@code featureList2}
     */
    private FeatureList setdiff(FeatureList featureList, FeatureList featureList2) {
        FeatureList result = new FeatureList();
        result.addAll(distinctExcluding(featureList2, featureList));
        Collections.sort(result);
        return result;
    }

    /** Returns a new set of the distinct elements of {@code source} also contained in {@code filter}. */
    private static <E> Set<E> distinctRetaining(Collection<E> source, Collection<?> filter) {
        Set<E> result = new HashSet<E>(source);
        result.retainAll(new HashSet<Object>(filter));
        return result;
    }

    /** Returns a new set of the distinct elements of {@code source} not contained in {@code exclude}. */
    private static <E> Set<E> distinctExcluding(Collection<E> source, Collection<?> exclude) {
        Set<E> result = new HashSet<E>(source);
        result.removeAll(new HashSet<Object>(exclude));
        return result;
    }

    /**
     * Returns the cache key for this calculation: the method name and the formatted network
     * list, joined per {@link #PARAM_KEY_FORMAT}.
     *
     * @return the parameter key
     * @throws ApplicationException if attribute groups are part of the query, in which case the
     *         result is not cacheable
     */
    @Override
    public String getParameterKey() throws ApplicationException {
        if (this.attributeGroupIds != null && this.attributeGroupIds.size() > 0) {
            throw new ApplicationException("not cacheable");
        }
        return String.format(PARAM_KEY_FORMAT, this.method.toString(), formattedNetworkList(this.networkIds));
    }
}
