package edu.princeton.safe.internal;

import cern.jet.stat.Probability;
import edu.princeton.safe.AnnotationProvider;
import edu.princeton.safe.DistanceMetric;
import edu.princeton.safe.GroupingMethod;
import edu.princeton.safe.NeighborhoodScoringMethod;
import edu.princeton.safe.NetworkProvider;
import edu.princeton.safe.ProgressReporter;
import edu.princeton.safe.RestrictionMethod;
import edu.princeton.safe.internal.scoring.RandomizedMemberScoringMethod;
import edu.princeton.safe.io.DomainConsumer;
import edu.princeton.safe.model.Neighborhood;
import java.util.Arrays;
import java.util.List;
import java.util.OptionalDouble;
import java.util.OptionalInt;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apache.commons.math3.distribution.HypergeometricDistribution;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.commons.math3.random.Well44497b;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;

/* loaded from: input_file:safe-core-1.0.0-beta6.jar:edu/princeton/safe/internal/ParallelSafe.class */
/**
 * Static stages of the SAFE pipeline: neighborhood distances and distance thresholds,
 * per-attribute enrichment p-values, the unimodality restriction, per-node top domains
 * (with optional size filtering), and domain grouping. Each stage receives its inputs as
 * explicit arguments; the instance fields only hold the configuration passed to the
 * constructor.
 */
public class ParallelSafe {
    /** Type index for the "highest" (over-represented) direction. */
    private static final int TYPE_HIGHEST = 0;
    /** Type index for the "lowest" direction; skipped for binary annotations. */
    private static final int TYPE_LOWEST = 1;

    NetworkProvider networkProvider;
    AnnotationProvider annotationProvider;
    DistanceMetric distanceMetric;
    RestrictionMethod restrictionMethod;
    GroupingMethod groupingMethod;
    BackgroundMethod backgroundMethod;
    ProgressReporter progressReporter;
    boolean isDistanceThresholdAbsolute;
    double distanceThreshold;
    int empiricalIterations;

    /**
     * Captures the analysis configuration.
     *
     * @param isDistanceThresholdAbsolute whether {@code distanceThreshold} is used as-is
     *                                    (as opposed to a percentile; see {@link #computeDistances})
     * @param distanceThreshold           absolute distance threshold, or percentile of all distances
     * @param empiricalIterations         presumably the number of randomizations used for
     *                                    quantitative scoring — confirm against callers
     */
    public ParallelSafe(NetworkProvider networkProvider,
                        AnnotationProvider annotationProvider,
                        DistanceMetric distanceMetric,
                        BackgroundMethod backgroundMethod,
                        RestrictionMethod restrictionMethod,
                        GroupingMethod groupingMethod,
                        boolean isDistanceThresholdAbsolute,
                        double distanceThreshold,
                        int empiricalIterations,
                        ProgressReporter progressReporter) {
        this.networkProvider = networkProvider;
        this.annotationProvider = annotationProvider;
        this.distanceMetric = distanceMetric;
        this.backgroundMethod = backgroundMethod;
        this.restrictionMethod = restrictionMethod;
        this.groupingMethod = groupingMethod;
        this.progressReporter = progressReporter;
        this.isDistanceThresholdAbsolute = isDistanceThresholdAbsolute;
        this.distanceThreshold = distanceThreshold;
        this.empiricalIterations = empiricalIterations;
    }

    /**
     * Applies {@code restrictionMethod} to the composite map, or, when it is {@code null},
     * marks every attribute as "top" for every type index held by the map.
     */
    public static void computeUnimodality(DefaultEnrichmentLandscape landscape,
                                          DefaultCompositeMap compositeMap,
                                          RestrictionMethod restrictionMethod,
                                          ProgressReporter progressReporter) {
        if (restrictionMethod != null) {
            restrictionMethod.applyRestriction(landscape, compositeMap, progressReporter);
            return;
        }
        int attributeCount = landscape.getAnnotationProvider().getAttributeCount();
        for (int typeIndex = 0; typeIndex < compositeMap.isTop.length; typeIndex++) {
            for (int attributeIndex = 0; attributeIndex < attributeCount; attributeIndex++) {
                // NOTE(review): the decompiled source collapsed both lambda parameters to `i`;
                // the (attribute, type) argument order is assumed — confirm against
                // DefaultCompositeMap.setTop.
                compositeMap.setTop(attributeIndex, typeIndex, true);
            }
        }
    }

    /**
     * Computes top domains for the "highest" type and, for non-binary annotations, also
     * for the "lowest" type.
     *
     * @param minimumDomainSize when positive, domains that are the top domain of fewer
     *                          significant nodes than this are discarded
     */
    public static void computeDomains(DefaultEnrichmentLandscape landscape,
                                      DefaultCompositeMap compositeMap,
                                      int minimumDomainSize,
                                      ProgressReporter progressReporter) {
        computeDomains(landscape, compositeMap, TYPE_HIGHEST, minimumDomainSize, progressReporter);
        if (!landscape.getAnnotationProvider().isBinary()) {
            computeDomains(landscape, compositeMap, TYPE_LOWEST, minimumDomainSize, progressReporter);
        }
    }

    /**
     * Computes the top domain of every neighborhood for one type index, optionally filters
     * small domains, then labels the domains kept in the composite map.
     *
     * @throws IllegalArgumentException if {@code typeIndex} is neither 0 nor 1
     */
    static void computeDomains(DefaultEnrichmentLandscape landscape,
                               DefaultCompositeMap compositeMap,
                               int typeIndex,
                               int minimumDomainSize,
                               ProgressReporter progressReporter) {
        switch (typeIndex) {
            case TYPE_HIGHEST:
                progressReporter.setStatus("Computing highest most significant domains...");
                break;
            case TYPE_LOWEST:
                progressReporter.setStatus("Computing lowest most significant domains...");
                break;
            default:
                throw new IllegalArgumentException("Unexpected type index: " + typeIndex);
        }

        AnnotationProvider annotationProvider = landscape.annotationProvider;
        SignificancePredicate isSignificant =
                Neighborhood.getSignificancePredicate(typeIndex, annotationProvider.getAttributeCount());
        ScoringFunction score = Neighborhood.getScoringFunction(typeIndex);
        List<DefaultDomain> domains = compositeMap.domainsByType[typeIndex];
        if (domains == null) {
            return;
        }

        landscape.neighborhoods.forEach(neighborhood ->
                computeDomains(compositeMap, typeIndex, score, isSignificant, domains, neighborhood));

        if (minimumDomainSize > 0) {
            minimizeDomains(landscape, compositeMap, typeIndex, minimumDomainSize, progressReporter);
        }
        // minimizeDomains may replace domainsByType[typeIndex] with a filtered list; label the
        // domains that are actually retained rather than the pre-filter list.
        DomainLabeller.assignLabels(annotationProvider, compositeMap.domainsByType[typeIndex]);
    }

    /**
     * Picks the top domain for one neighborhood and records it in the composite map.
     *
     * <p>The winner is the domain with the most significant attributes in this neighborhood.
     * Ties are broken by the higher summed score over the domain's attributes; the first such
     * domain in list order wins. For the winner, the maximum single-attribute score is stored
     * in {@code maximumEnrichment}. Nothing is recorded when no domain has a significant
     * attribute, or when the tie-break maximum is NaN ({@code DoubleStream.max} propagates NaN,
     * and NaN never compares equal).
     */
    public static void computeDomains(DefaultCompositeMap compositeMap,
                                      int typeIndex,
                                      ScoringFunction score,
                                      SignificancePredicate isSignificant,
                                      List<DefaultDomain> domains,
                                      DefaultNeighborhood neighborhood) {
        int domainCount = domains.size();
        int[] significantCounts = new int[domainCount];
        double[] cumulativeScores = new double[domainCount];
        for (int d = 0; d < domainCount; d++) {
            int domainIndex = d;
            domains.get(domainIndex).forEachAttribute(attributeIndex -> {
                if (isSignificant.test(neighborhood, attributeIndex)) {
                    significantCounts[domainIndex]++;
                }
                cumulativeScores[domainIndex] += score.get(neighborhood, attributeIndex);
            });
        }

        OptionalInt maxCount = Arrays.stream(significantCounts).max();
        if (!maxCount.isPresent() || maxCount.getAsInt() <= 0) {
            return;
        }
        int bestCount = maxCount.getAsInt();
        int[] candidates = IntStream.range(0, domainCount)
                                    .filter(d -> significantCounts[d] == bestCount)
                                    .toArray();

        OptionalDouble maxScore = Arrays.stream(candidates)
                                        .mapToDouble(d -> cumulativeScores[d])
                                        .max();
        if (!maxScore.isPresent()) {
            return;
        }
        double bestScore = maxScore.getAsDouble();
        OptionalInt winner = Arrays.stream(candidates)
                                   .filter(d -> cumulativeScores[d] == bestScore)
                                   .findFirst();
        if (!winner.isPresent()) {
            return;
        }

        DefaultDomain topDomain = domains.get(winner.getAsInt());
        // The winner has at least one significant attribute, so this max is present.
        double maximumEnrichment = StreamSupport.stream(topDomain.attributeIndexes.spliterator(), false)
                                                .mapToDouble(cursor -> score.get(neighborhood, cursor.value))
                                                .max()
                                                .getAsDouble();
        int nodeIndex = neighborhood.getNodeIndex();
        compositeMap.maximumEnrichment[typeIndex][nodeIndex] = maximumEnrichment;
        compositeMap.topDomain[typeIndex][nodeIndex] = topDomain;
    }

    /**
     * Removes domains that are the top domain of fewer than {@code minimumDomainSize} nodes
     * whose maximum enrichment exceeds the enrichment threshold.
     *
     * <p>Survivors are renumbered through {@code DefaultDomain.index} (0..n-1) and stored in
     * {@code domainsByType[typeIndex]}. Dropped domains get {@code index == -1}, and
     * {@code topDomain} entries pointing to them are cleared. Progress counts are reported
     * before and after filtering.
     */
    static void minimizeDomains(DefaultEnrichmentLandscape landscape,
                                DefaultCompositeMap compositeMap,
                                int typeIndex,
                                int minimumDomainSize,
                                ProgressReporter progressReporter) {
        double enrichmentThreshold = Neighborhood.getEnrichmentThreshold(
                landscape.getAnnotationProvider().getAttributeCount());
        double[] maximumEnrichment = compositeMap.maximumEnrichment[typeIndex];
        DefaultDomain[] topDomains = compositeMap.topDomain[typeIndex];
        List<DefaultDomain> domains = compositeMap.domainsByType[typeIndex];
        int domainCount = domains.size();

        // Temporarily number domains by list position so nodes can be tallied per domain.
        for (int d = 0; d < domainCount; d++) {
            domains.get(d).index = d;
        }
        int[] significantNodeCounts = new int[domainCount];
        landscape.neighborhoods.forEach(neighborhood -> {
            int nodeIndex = neighborhood.getNodeIndex();
            DefaultDomain top = topDomains[nodeIndex];
            // Skip nodes without a recorded top domain instead of throwing.
            if (top != null && maximumEnrichment[nodeIndex] > enrichmentThreshold) {
                significantNodeCounts[top.index]++;
            }
        });

        List<DefaultDomain> retained = domains.stream()
                                              .filter(domain -> significantNodeCounts[domain.index] >= minimumDomainSize)
                                              .collect(Collectors.toList());

        // Mark every domain as dropped, then renumber only the survivors.
        for (DefaultDomain domain : domains) {
            domain.index = -1;
        }
        progressReporter.setStatus("Total domains: %d", domainCount);
        progressReporter.setStatus("Total domains (after filtering): %d", retained.size());
        for (int d = 0; d < retained.size(); d++) {
            retained.get(d).index = d;
        }
        compositeMap.domainsByType[typeIndex] = retained;

        progressReporter.setStatus("Nodes with domain: %d", countNonNull(topDomains));
        for (int node = 0; node < topDomains.length; node++) {
            if (topDomains[node] != null && topDomains[node].index == -1) {
                topDomains[node] = null;
            }
        }
        progressReporter.setStatus("Nodes with domain (after filtering): %d", countNonNull(topDomains));
    }

    /** Counts the non-null entries of {@code domains}. */
    private static long countNonNull(DefaultDomain[] domains) {
        return Arrays.stream(domains).filter(domain -> domain != null).count();
    }

    /**
     * Groups domains for the "highest" type and, for non-binary annotations, the "lowest"
     * type, feeding results to the composite map's consumer.
     */
    public static void computeGroups(DefaultEnrichmentLandscape landscape,
                                     DefaultCompositeMap compositeMap,
                                     GroupingMethod groupingMethod,
                                     ProgressReporter progressReporter) {
        AnnotationProvider annotationProvider = landscape.getAnnotationProvider();
        DomainConsumer consumer = compositeMap.getConsumer();
        groupingMethod.group(landscape, compositeMap, TYPE_HIGHEST, consumer, progressReporter);
        if (!annotationProvider.isBinary()) {
            groupingMethod.group(landscape, compositeMap, TYPE_LOWEST, consumer, progressReporter);
        }
    }

    /**
     * Computes per-attribute enrichment p-values for every neighborhood. Binary annotations
     * use a hypergeometric test; other annotations use randomized member scoring.
     *
     * @param iterations passed to {@link RandomizedMemberScoringMethod}; presumably the number
     *                   of randomizations — confirm
     * @param randomSeed seed for the {@link Well44497b} generator used by the randomization
     */
    public static void computeEnrichment(NetworkProvider networkProvider,
                                         AnnotationProvider annotationProvider,
                                         BackgroundMethod backgroundMethod,
                                         int iterations,
                                         int randomSeed,
                                         ProgressReporter progressReporter,
                                         DefaultEnrichmentLandscape landscape) {
        if (annotationProvider.isBinary()) {
            computeBinaryEnrichment(networkProvider, annotationProvider, progressReporter,
                                    landscape.neighborhoods, backgroundMethod);
        } else {
            NeighborhoodScoringMethod scoringMethod = new RandomizedMemberScoringMethod(
                    annotationProvider, new Well44497b(randomSeed), iterations, networkProvider.getNodeCount());
            computeQuantitativeEnrichment(networkProvider, annotationProvider, scoringMethod,
                                          progressReporter, landscape.neighborhoods);
        }
    }

    /**
     * Computes quantitative enrichment. For each neighborhood and attribute, the observed
     * statistic is the sum of the members' non-NaN values. The p-value is the upper-tail
     * probability of that sum under a normal distribution whose mean and variance are taken
     * from the non-NaN randomized scores. Runs in parallel when the reporter supports it.
     */
    static void computeQuantitativeEnrichment(NetworkProvider networkProvider,
                                              AnnotationProvider annotationProvider,
                                              NeighborhoodScoringMethod scoringMethod,
                                              ProgressReporter progressReporter,
                                              List<? extends Neighborhood> neighborhoods) {
        Stream<? extends Neighborhood> stream = neighborhoods.stream();
        if (progressReporter.supportsParallel()) {
            stream = stream.parallel();
        }
        progressReporter.startNeighborhoodScore(networkProvider, annotationProvider);
        stream.forEach(neighborhood -> {
            int nodeIndex = neighborhood.getNodeIndex();
            for (int attributeIndex = 0; attributeIndex < annotationProvider.getAttributeCount(); attributeIndex++) {
                double observed = sumMemberValues(neighborhood, annotationProvider, attributeIndex);
                double[] randomizedScores = scoringMethod.computeRandomizedScores(neighborhood, attributeIndex);
                SummaryStatistics background = summarizeIgnoringNaN(randomizedScores);
                double pValue = 1.0d - Probability.normal(background.getMean(), background.getVariance(), observed);
                neighborhood.setPValue(attributeIndex, pValue);
                progressReporter.neighborhoodScore(nodeIndex, attributeIndex,
                                                   Neighborhood.computeEnrichmentScore(pValue));
            }
            progressReporter.finishNeighborhood(nodeIndex);
        });
        progressReporter.finishNeighborhoodScore();
    }

    /** Sums the values of {@code attributeIndex} over the neighborhood's members, skipping NaN. */
    private static double sumMemberValues(Neighborhood neighborhood,
                                          AnnotationProvider annotationProvider,
                                          int attributeIndex) {
        double[] total = {0.0d};
        neighborhood.forEachMemberIndex(memberIndex -> {
            double value = annotationProvider.getValue(memberIndex, attributeIndex);
            if (!Double.isNaN(value)) {
                total[0] += value;
            }
        });
        return total[0];
    }

    /** Collects summary statistics over {@code values}, ignoring NaN entries. */
    private static SummaryStatistics summarizeIgnoringNaN(double[] values) {
        SummaryStatistics statistics = new SummaryStatistics();
        for (double value : values) {
            if (!Double.isNaN(value)) {
                statistics.addValue(value);
            }
        }
        return statistics;
    }

    /**
     * Computes binary enrichment with a hypergeometric test: P(X >= k). Here k is the number
     * of neighborhood members annotated with the attribute. The population and per-attribute
     * success counts come from either the network or the annotation background. Runs in
     * parallel when the reporter supports it.
     *
     * @throws RuntimeException if {@code backgroundMethod} is not a supported value
     */
    static void computeBinaryEnrichment(NetworkProvider networkProvider,
                                        AnnotationProvider annotationProvider,
                                        ProgressReporter progressReporter,
                                        List<? extends Neighborhood> neighborhoods,
                                        BackgroundMethod backgroundMethod) {
        int backgroundSize;
        IntIntFunction backgroundSizeForAttribute;
        switch (backgroundMethod) {
            case Network:
                backgroundSize = annotationProvider.getNetworkNodeCount();
                backgroundSizeForAttribute = attributeIndex -> annotationProvider.getNetworkNodeCountForAttribute(attributeIndex);
                break;
            case Annotation:
                backgroundSize = annotationProvider.getAnnotationNodeCount();
                backgroundSizeForAttribute = attributeIndex -> annotationProvider.getAnnotationNodeCountForAttribute(attributeIndex);
                break;
            default:
                throw new RuntimeException("Unexpected background method");
        }

        Stream<? extends Neighborhood> stream = neighborhoods.stream();
        if (progressReporter.supportsParallel()) {
            stream = stream.parallel();
        }
        progressReporter.startNeighborhoodScore(networkProvider, annotationProvider);

        // Final copies for capture by the lambda below.
        final int populationSize = backgroundSize;
        final IntIntFunction successesForAttribute = backgroundSizeForAttribute;
        stream.forEach(neighborhood -> {
            int nodeIndex = neighborhood.getNodeIndex();
            int memberCount = neighborhood.getMemberCount();
            for (int attributeIndex = 0; attributeIndex < annotationProvider.getAttributeCount(); attributeIndex++) {
                // A null RandomGenerator avoids initialising a sampler that is never used here.
                HypergeometricDistribution distribution = new HypergeometricDistribution(
                        null, populationSize, successesForAttribute.apply(attributeIndex), memberCount);
                int observed = neighborhood.getMemberCountForAttribute(attributeIndex, annotationProvider);
                double pValue = 1.0d - distribution.cumulativeProbability(observed - 1);
                // 1 - CDF can round to a tiny negative number; clamp it to zero.
                if (Double.isFinite(pValue) && pValue < 0.0d) {
                    pValue = 0.0d;
                }
                neighborhood.setPValue(attributeIndex, pValue);
                progressReporter.neighborhoodScore(nodeIndex, attributeIndex,
                                                   Neighborhood.computeEnrichmentScore(pValue));
            }
            progressReporter.finishNeighborhood(nodeIndex);
        });
        progressReporter.finishNeighborhoodScore();
    }

    /** Applies the landscape's maximum distance threshold to every neighborhood. */
    public static void computeNeighborhoods(DefaultEnrichmentLandscape landscape,
                                            NetworkProvider networkProvider,
                                            AnnotationProvider annotationProvider) {
        landscape.neighborhoods.forEach(neighborhood ->
                neighborhood.applyDistanceThreshold(landscape.maximumDistanceThreshold));
    }

    /** Returns the {@code percentile} of all distances across all neighborhoods. */
    static double computeMaximumDistanceThreshold(List<? extends DefaultNeighborhood> neighborhoods,
                                                  double percentile) {
        double[] allDistances = neighborhoods.stream()
                                             .flatMapToDouble(neighborhood -> neighborhood.streamDistances())
                                             .toArray();
        return Util.percentile(allDistances, percentile);
    }

    /**
     * Computes neighborhood distances (sparse neighborhoods for binary annotations, dense
     * otherwise) and sets the maximum distance threshold. Does nothing if the landscape
     * already has neighborhoods.
     *
     * @param isDistanceThresholdAbsolute when true, {@code distanceThreshold} is used as-is;
     *                                    otherwise it is a percentile of all distances
     */
    public static void computeDistances(NetworkProvider networkProvider,
                                        AnnotationProvider annotationProvider,
                                        DistanceMetric distanceMetric,
                                        boolean isDistanceThresholdAbsolute,
                                        double distanceThreshold,
                                        DefaultEnrichmentLandscape landscape) {
        if (landscape.neighborhoods != null) {
            return;
        }
        int nodeCount = networkProvider.getNodeCount();
        int attributeCount = annotationProvider.getAttributeCount();
        landscape.neighborhoods = distanceMetric.computeDistances(networkProvider,
                annotationProvider.isBinary()
                        ? nodeIndex -> new SparseNeighborhood(nodeIndex, attributeCount)
                        : nodeIndex -> new DenseNeighborhood(nodeIndex, nodeCount, attributeCount));
        landscape.maximumDistanceThreshold = isDistanceThresholdAbsolute
                ? distanceThreshold
                : computeMaximumDistanceThreshold(landscape.neighborhoods, distanceThreshold);
    }
}
