package jsat.linear.distancemetrics;

import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.linear.MatrixStatistics;
import jsat.linear.Vec;
import jsat.linear.VecOps;
import jsat.math.FunctionBase;
import jsat.math.MathTricks;
import jsat.regression.RegressionDataSet;
import jsat.utils.DoubleList;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/linear/distancemetrics/NormalizedEuclideanDistance.class */
/**
 * A trainable, per-dimension weighted Euclidean distance.
 *
 * <p>Training derives one weight per dimension from the diagonal of the
 * covariance of the training data; distances are then the square root of the
 * weighted sum of squared coordinate differences. The metric can not be used
 * until one of the {@code train} methods has been called (see
 * {@link #needsTraining()}).
 *
 * <p>The metric supports acceleration: a per-vector cache value
 * {@code weightedDot(w, x, x)} is precomputed so that a distance only needs
 * one further weighted dot product.
 */
public class NormalizedEuclideanDistance extends TrainableDistanceMetric {
    private static final long serialVersionUID = 210109457671623688L;

    /**
     * Squares the first coordinate of its argument. Stateless, so a single
     * shared instance is used instead of allocating one per distance call.
     */
    private static final FunctionBase SQUARE_FIRST = new FunctionBase() {
        private static final long serialVersionUID = 3190953661114076430L;

        @Override // jsat.math.Function
        public double f(Vec vec3) {
            return Math.pow(vec3.get(0), 2.0d);
        }
    };

    /**
     * Per-dimension weights used by every distance computation, or
     * {@code null} while the metric is untrained.
     *
     * <p>NOTE(review): the name suggests inverse standard deviations, but the
     * training code passes the covariance diagonal through
     * {@code MathTricks.sqrdFunc} (not a square root) before inverting it.
     * Confirm the intended weighting before relying on the name.
     *
     * <p>NOTE(review): a dimension with zero variance presumably yields a
     * non-finite weight after inversion - verify how callers handle constant
     * columns.
     */
    private Vec invStndDevs;

    /**
     * Trains the metric from a list of vectors: the diagonal covariance
     * around the mean vector is computed, then transformed with
     * {@code MathTricks.sqrdFunc} followed by {@code MathTricks.invsFunc}.
     *
     * @param list the training vectors
     */
    @Override // jsat.linear.distancemetrics.TrainableDistanceMetric
    public <V extends Vec> void train(List<V> list) {
        this.invStndDevs = MatrixStatistics.covarianceDiag(MatrixStatistics.meanVector(list), list);
        this.invStndDevs.applyFunction(MathTricks.sqrdFunc);
        // Inverted once here so that distance computations can multiply by the weights.
        this.invStndDevs.applyFunction(MathTricks.invsFunc);
    }

    /**
     * Same as {@link #train(List)}; the executor is not used.
     *
     * @param list the training vectors
     * @param executorService ignored
     */
    @Override // jsat.linear.distancemetrics.TrainableDistanceMetric
    public <V extends Vec> void train(List<V> list, ExecutorService executorService) {
        train(list);
    }

    /**
     * Trains the metric from a data set, using the second vector returned by
     * {@code dataSet.getColumnMeanVariance()} as the starting point and
     * transforming it in place with {@code MathTricks.sqrdFunc} followed by
     * {@code MathTricks.invsFunc}.
     *
     * @param dataSet the training data
     */
    @Override // jsat.linear.distancemetrics.TrainableDistanceMetric
    public void train(DataSet dataSet) {
        this.invStndDevs = dataSet.getColumnMeanVariance()[1];
        this.invStndDevs.applyFunction(MathTricks.sqrdFunc);
        this.invStndDevs.applyFunction(MathTricks.invsFunc);
    }

    /**
     * Same as {@link #train(DataSet)}; the executor is not used.
     *
     * @param dataSet the training data
     * @param executorService ignored
     */
    @Override // jsat.linear.distancemetrics.TrainableDistanceMetric
    public void train(DataSet dataSet, ExecutorService executorService) {
        train(dataSet);
    }

    /**
     * Trains from a classification data set by delegating to
     * {@link #train(DataSet)}.
     *
     * @param classificationDataSet the training data
     */
    @Override // jsat.linear.distancemetrics.TrainableDistanceMetric
    public void train(ClassificationDataSet classificationDataSet) {
        train((DataSet) classificationDataSet);
    }

    /**
     * Same as {@link #train(ClassificationDataSet)}; the executor is not used.
     *
     * @param classificationDataSet the training data
     * @param executorService ignored
     */
    @Override // jsat.linear.distancemetrics.TrainableDistanceMetric
    public void train(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        train(classificationDataSet);
    }

    /**
     * @return always {@code true}
     */
    @Override // jsat.linear.distancemetrics.TrainableDistanceMetric
    public boolean supportsClassificationTraining() {
        return true;
    }

    /**
     * Trains from a regression data set by delegating to
     * {@link #train(DataSet)}.
     *
     * @param regressionDataSet the training data
     */
    @Override // jsat.linear.distancemetrics.TrainableDistanceMetric
    public void train(RegressionDataSet regressionDataSet) {
        train((DataSet) regressionDataSet);
    }

    /**
     * Same as {@link #train(RegressionDataSet)}; the executor is not used.
     *
     * @param regressionDataSet the training data
     * @param executorService ignored
     */
    @Override // jsat.linear.distancemetrics.TrainableDistanceMetric
    public void train(RegressionDataSet regressionDataSet, ExecutorService executorService) {
        train(regressionDataSet);
    }

    /**
     * @return always {@code true}
     */
    @Override // jsat.linear.distancemetrics.TrainableDistanceMetric
    public boolean supportsRegressionTraining() {
        return true;
    }

    /**
     * @return {@code true} while no weights have been learned yet
     */
    @Override // jsat.linear.distancemetrics.TrainableDistanceMetric
    public boolean needsTraining() {
        return this.invStndDevs == null;
    }

    /**
     * Creates an independent copy of this metric; the learned weights, if
     * any, are cloned rather than shared.
     *
     * @return a new metric in the same trained/untrained state
     */
    @Override // jsat.linear.distancemetrics.TrainableDistanceMetric
    /* renamed from: clone */
    public NormalizedEuclideanDistance mo651clone() {
        NormalizedEuclideanDistance normalizedEuclideanDistance = new NormalizedEuclideanDistance();
        if (this.invStndDevs != null) {
            normalizedEuclideanDistance.invStndDevs = this.invStndDevs.mo524clone();
        }
        return normalizedEuclideanDistance;
    }

    /**
     * Computes the distance between two vectors as the square root of
     * {@code VecOps.accumulateSum} over the learned weights, the two vectors
     * and a squaring function.
     *
     * @param vec the first vector
     * @param vec2 the second vector
     * @return the distance between the two vectors
     */
    @Override // jsat.linear.distancemetrics.DistanceMetric
    public double dist(Vec vec, Vec vec2) {
        return Math.sqrt(VecOps.accumulateSum(this.invStndDevs, vec, vec2, SQUARE_FIRST));
    }

    @Override // jsat.linear.distancemetrics.DistanceMetric
    public boolean isSymmetric() {
        return true;
    }

    @Override // jsat.linear.distancemetrics.DistanceMetric
    public boolean isSubadditive() {
        return true;
    }

    @Override // jsat.linear.distancemetrics.DistanceMetric
    public boolean isIndiscemible() {
        return true;
    }

    /**
     * @return {@link Double#POSITIVE_INFINITY}; the metric reports no finite bound
     */
    @Override // jsat.linear.distancemetrics.DistanceMetric
    public double metricBound() {
        return Double.POSITIVE_INFINITY;
    }

    /**
     * @return always {@code true}
     */
    @Override // jsat.linear.distancemetrics.DistanceMetric
    public boolean supportsAcceleration() {
        return true;
    }

    /**
     * Builds the acceleration cache: one value
     * {@code weightedDot(weights, x, x)} per vector, in list order.
     *
     * @param list the vectors to cache values for
     * @return the cache, with one entry per input vector
     */
    @Override // jsat.linear.distancemetrics.DistanceMetric
    public List<Double> getAccelerationCache(List<? extends Vec> list) {
        DoubleList doubleList = new DoubleList(list.size());
        for (Vec vec : list) {
            doubleList.add(VecOps.weightedDot(this.invStndDevs, vec, vec));
        }
        return doubleList;
    }

    /**
     * Builds the same cache as {@link #getAccelerationCache(List)}, splitting
     * the work into blocks that are submitted to the given executor. Falls
     * back to the serial version when the executor is {@code null} or a
     * {@link FakeExecutor}.
     *
     * <p>If the calling thread is interrupted while waiting, the interruption
     * is logged, the thread's interrupt status is restored, and whatever has
     * been computed so far is returned.
     *
     * @param list the vectors to cache values for
     * @param executorService the executor to run the blocks on, may be {@code null}
     * @return the cache, with one entry per input vector
     */
    @Override // jsat.linear.distancemetrics.DistanceMetric
    public List<Double> getAccelerationCache(final List<? extends Vec> list, ExecutorService executorService) {
        if (executorService == null || (executorService instanceof FakeExecutor)) {
            return getAccelerationCache(list);
        }
        final double[] dArr = new double[list.size()];
        int min = Math.min(SystemInfo.LogicalCores, list.size());
        final CountDownLatch countDownLatch = new CountDownLatch(min);
        for (int i = 0; i < min; i++) {
            final int startBlock = ParallelUtils.getStartBlock(dArr.length, i, min);
            final int endBlock = ParallelUtils.getEndBlock(dArr.length, i, min);
            executorService.submit(new Runnable() { // from class: jsat.linear.distancemetrics.NormalizedEuclideanDistance.2
                @Override // java.lang.Runnable
                public void run() {
                    try {
                        for (int i2 = startBlock; i2 < endBlock; i2++) {
                            Vec vec = list.get(i2);
                            dArr[i2] = VecOps.weightedDot(NormalizedEuclideanDistance.this.invStndDevs, vec, vec);
                        }
                    } finally {
                        // Always release the latch, otherwise a failing block would
                        // leave the caller blocked in await() forever.
                        countDownLatch.countDown();
                    }
                }
            });
        }
        try {
            countDownLatch.await();
        } catch (InterruptedException e) {
            // Restore the interrupt status so callers further up can react to it.
            Thread.currentThread().interrupt();
            Logger.getLogger(NormalizedEuclideanDistance.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
        }
        return DoubleList.view(dArr, dArr.length);
    }

    /**
     * Distance between two vectors of {@code list}, using the acceleration
     * cache when one is supplied.
     *
     * @param i index of the first vector
     * @param i2 index of the second vector
     * @param list the vectors
     * @param list2 the acceleration cache for {@code list}, or {@code null}
     * @return the distance between the two indexed vectors
     */
    @Override // jsat.linear.distancemetrics.DistanceMetric
    public double dist(int i, int i2, List<? extends Vec> list, List<Double> list2) {
        if (list2 == null) {
            return dist(list.get(i), list.get(i2));
        }
        double cross = VecOps.weightedDot(this.invStndDevs, list.get(i), list.get(i2));
        return sqrtNonNegative((list2.get(i).doubleValue() + list2.get(i2).doubleValue()) - (2.0d * cross));
    }

    /**
     * Distance between an indexed vector of {@code list} and an arbitrary
     * vector, using the acceleration cache when one is supplied.
     *
     * @param i index of the vector in {@code list}
     * @param vec the other vector
     * @param list the vectors
     * @param list2 the acceleration cache for {@code list}, or {@code null}
     * @return the distance between the indexed vector and {@code vec}
     */
    @Override // jsat.linear.distancemetrics.DistanceMetric
    public double dist(int i, Vec vec, List<? extends Vec> list, List<Double> list2) {
        if (list2 == null) {
            return dist(list.get(i), vec);
        }
        double self = VecOps.weightedDot(this.invStndDevs, vec, vec);
        double cross = VecOps.weightedDot(this.invStndDevs, list.get(i), vec);
        return sqrtNonNegative((list2.get(i).doubleValue() + self) - (2.0d * cross));
    }

    /**
     * Computes the query information for a vector: a single-element list
     * holding {@code weightedDot(weights, vec, vec)}.
     *
     * @param vec the query vector
     * @return a list with exactly one value
     */
    @Override // jsat.linear.distancemetrics.DistanceMetric
    public List<Double> getQueryInfo(Vec vec) {
        DoubleList doubleList = new DoubleList(1);
        doubleList.add(VecOps.weightedDot(this.invStndDevs, vec, vec));
        return doubleList;
    }

    /**
     * Distance between an indexed vector of {@code list2} and a query vector
     * whose query information was precomputed.
     *
     * @param i index of the vector in {@code list2}
     * @param vec the query vector
     * @param list the query information for {@code vec}; its first element is read
     * @param list2 the vectors
     * @param list3 the acceleration cache for {@code list2}, or {@code null}
     * @return the distance between the indexed vector and {@code vec}
     */
    @Override // jsat.linear.distancemetrics.DistanceMetric
    public double dist(int i, Vec vec, List<Double> list, List<? extends Vec> list2, List<Double> list3) {
        if (list3 == null) {
            return dist(list2.get(i), vec);
        }
        double cross = VecOps.weightedDot(this.invStndDevs, list2.get(i), vec);
        return sqrtNonNegative((list3.get(i).doubleValue() + list.get(0).doubleValue()) - (2.0d * cross));
    }

    /**
     * Square root that treats slightly negative inputs as zero. The cached
     * distance is formed by subtracting nearly equal quantities, which can
     * round just below zero for (almost) identical vectors; without the clamp
     * that would surface as {@code NaN}. A {@code NaN} input still yields
     * {@code NaN}.
     */
    private static double sqrtNonNegative(double squared) {
        return Math.sqrt(Math.max(0.0d, squared));
    }
}
