package jsat.datatransform.visualization;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransform;
import jsat.distributions.Normal;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.vectorcollection.VPTree;
import jsat.linear.vectorcollection.VPTreeMV;
import jsat.math.FastMath;
import jsat.math.FunctionBase;
import jsat.math.optimization.stochastic.Adam;
import jsat.math.rootfinding.Zeroin;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.AtomicDouble;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.XORWOW;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/datatransform/visualization/TSNE.class */
/**
 * t-distributed Stochastic Neighbor Embedding (t-SNE) into two dimensions.
 * Input affinities are computed only over each point's nearest neighbors, the
 * repulsive forces are approximated with a Barnes-Hut quadtree, and the
 * embedding is optimized with {@link Adam} for a fixed number of iterations.
 */
public class TSNE implements VisualizationTransform {
    /** multiplier applied to the input affinities during the early-exaggeration phase */
    private double alpha = 4.0d;
    /** fraction of the iterations, counted from the start, that use early exaggeration */
    private double exaggeratedPortion = 0.25d;
    /** distance metric used in the input space */
    private DistanceMetric dm = new EuclideanDistance();
    /** number of optimization iterations */
    private int T = 1000;
    /** target perplexity of each point's conditional neighbor distribution */
    private double perplexity = 30.0d;
    /** Barnes-Hut accuracy / speed trade-off: smaller values open more cells */
    private double theta = 0.5d;
    /** embedding dimension; the quadtree below only handles 2 */
    private int s = 2;

    /**
     * Point-region quadtree over a 2-D embedding stored as a flat array, where
     * point {@code i} has coordinates {@code (z[2*i], z[2*i+1])}. Every cell keeps
     * the number of points it holds and the sum of their coordinates, so its
     * center of mass is available for the Barnes-Hut approximation.
     */
    public class Quadtree {
        public Node root = new Node();

        /**
         * One quadtree cell covering {@code [minX, maxX) x [minY, maxY)}. A cell is
         * an empty leaf, a leaf holding one point (plus any near-duplicates merged
         * into it), or an internal node with exactly four children.
         */
        public class Node implements Iterable<Node> {
            /** index of the point stored in this leaf, or -1 for empty leaves and internal nodes */
            public int indx = -1;
            /** sum of the x coordinates of every point in this cell (merged duplicates included) */
            public double x_mass = 0.0d;
            /** sum of the y coordinates of every point in this cell (merged duplicates included) */
            public double y_mass = 0.0d;
            /** number of points in this cell, merged duplicates included */
            public int N_cell = 0;
            public double minX;
            public double maxX;
            public double minY;
            public double maxY;
            public Node NW;
            public Node NE;
            public Node SE;
            public Node SW;

            public Node() {
            }

            public Node(double minX, double maxX, double minY, double maxY) {
                this.minX = minX;
                this.maxX = maxX;
                this.minY = minY;
                this.maxY = maxY;
            }

            /** Returns true if point {@code i} lies inside this cell's half-open bounds. */
            public boolean contains(int i, double[] z) {
                double x = z[i * 2];
                double y = z[i * 2 + 1];
                return minX <= x && x < maxX && minY <= y && y < maxY;
            }

            /**
             * Adds point {@code i} to this subtree with the given multiplicity.
             *
             * @param weight how many points {@code i} stands for: 1 for a fresh point,
             *               more when a leaf holding merged duplicates is pushed down
             * @param i      index of the point
             * @param z      flat coordinate array
             */
            public void insert(int weight, int i, double[] z) {
                // Weighted so that a leaf holding several merged points keeps a correct
                // center of mass (x_mass / N_cell) when it is pushed into a child.
                x_mass += weight * z[i * 2];
                y_mass += weight * z[i * 2 + 1];
                N_cell += weight;

                if (NW == null && indx < 0) {
                    indx = i; // empty leaf: just remember the point
                    return;
                }

                if (indx >= 0
                        && Math.abs(z[indx * 2] - z[i * 2]) < 1e-13
                        && Math.abs(z[indx * 2 + 1] - z[i * 2 + 1]) < 1e-13) {
                    return; // near-duplicate of the stored point: only count and mass grow
                }

                if (NW == null) {
                    split();
                    // The stored point accounts for all of this cell's weight except
                    // the weight that was just added for point i.
                    insertIntoChild(N_cell - weight, indx, z);
                    indx = -1;
                }
                insertIntoChild(weight, i, z);
            }

            /** Creates the four equally sized children of this leaf. */
            private void split() {
                double halfWidth = (maxX - minX) / 2.0d;
                double halfHeight = (maxY - minY) / 2.0d;
                NW = new Node(minX, minX + halfWidth, minY + halfHeight, maxY);
                NE = new Node(minX + halfWidth, maxX, minY + halfHeight, maxY);
                SW = new Node(minX, minX + halfWidth, minY, minY + halfHeight);
                SE = new Node(minX + halfWidth, maxX, minY, minY + halfHeight);
            }

            /** Inserts point {@code i} into the first child whose bounds contain it, if any. */
            private void insertIntoChild(int weight, int i, double[] z) {
                for (Node child : this) {
                    if (child.contains(i, z)) {
                        child.insert(weight, i, z);
                        return;
                    }
                }
            }

            /** Returns the length of this cell's diagonal. */
            public double diagLen() {
                double w = maxX - minX;
                double h = maxY - minY;
                return Math.sqrt(w * w + h * h);
            }

            /** Iterates over the children in NW, NE, SW, SE order; empty for leaves. */
            @Override
            public Iterator<Node> iterator() {
                if (NW == null)
                    return Collections.<Node>emptyIterator();
                return Arrays.asList(NW, NE, SW, SE).iterator();
            }
        }

        /**
         * Builds a quadtree whose root tightly bounds all points in {@code z}
         * (flat {@code [x0, y0, x1, y1, ...]} layout).
         */
        public Quadtree(double[] z) {
            root.minX = Double.POSITIVE_INFINITY;
            root.minY = Double.POSITIVE_INFINITY;
            root.maxX = Double.NEGATIVE_INFINITY;
            root.maxY = Double.NEGATIVE_INFINITY;
            int n = z.length / 2;
            for (int i = 0; i < n; i++) {
                root.minX = Math.min(root.minX, z[i * 2]);
                root.maxX = Math.max(root.maxX, z[i * 2]);
                root.minY = Math.min(root.minY, z[i * 2 + 1]);
                root.maxY = Math.max(root.maxY, z[i * 2 + 1]);
            }
            // contains() uses half-open bounds, so nudge the maxima up by one ulp
            // to keep the largest coordinates inside the root.
            root.maxX = Math.nextUp(root.maxX);
            root.maxY = Math.nextUp(root.maxY);
            for (int i = 0; i < n; i++)
                root.insert(1, i, z);
        }
    }

    /**
     * Sets the early-exaggeration multiplier.
     *
     * @throws IllegalArgumentException if {@code d} is not a finite positive value
     */
    public void setAlpha(double d) {
        if (d <= 0.0d || Double.isNaN(d) || Double.isInfinite(d)) {
            throw new IllegalArgumentException("alpha must be positive, not " + d);
        }
        this.alpha = d;
    }

    public double getAlpha() {
        return this.alpha;
    }

    /**
     * Sets the target perplexity.
     *
     * @throws IllegalArgumentException if {@code d} is not a finite positive value
     */
    public void setPerplexity(double d) {
        if (d <= 0.0d || Double.isNaN(d) || Double.isInfinite(d)) {
            throw new IllegalArgumentException("perplexity must be positive, not " + d);
        }
        this.perplexity = d;
    }

    public double getPerplexity() {
        return this.perplexity;
    }

    /**
     * Sets the number of optimization iterations.
     *
     * @throws IllegalArgumentException if {@code i} is less than 1
     */
    public void setIterations(int i) {
        // was "i <= 1", which wrongly rejected a single iteration
        if (i < 1) {
            throw new IllegalArgumentException("number of iterations must be positive, not " + i);
        }
        this.T = i;
    }

    public int getIterations() {
        return this.T;
    }

    @Override
    public <Type extends DataSet> Type transform(DataSet<Type> dataSet) {
        return transform(dataSet, new FakeExecutor());
    }

    /**
     * Embeds {@code dataSet} into {@code s} = 2 dimensions and returns a shallow
     * clone of it whose numeric features are replaced by the embedding. Categorical
     * values and weights are kept.
     *
     * @throws IllegalArgumentException if the data set has fewer than 2 points
     */
    @Override
    public <Type extends DataSet> Type transform(DataSet<Type> dataSet, ExecutorService executorService) {
        final int N = dataSet.getSampleSize();
        if (N < 2)
            throw new IllegalArgumentException("t-SNE needs at least 2 data points, not " + N);
        XORWOW rand = new XORWOW();
        // ~3*perplexity neighbors, but at least 1 and never more than the other N-1 points
        final int knn = (int) Math.max(1.0d, Math.min(Math.floor(3.0d * perplexity), N - 1));
        final double[][] affinity = new double[N][knn];
        final int[][] nearMe = new int[N][knn];
        computeP(dataSet, executorService, rand, knn, nearMe, affinity, dm, perplexity);

        // small random initial embedding, point i at y[i*s .. i*s+s-1]
        final double[] y = new Normal(0.0d, 1.0E-4d).sample(N * s, rand);
        final double[] grad = new double[y.length];
        // The loop relies on these wrapping (not copying) y and grad: the gradient is
        // written into grad and the optimizer's update must be visible in y.
        DenseVector yVec = DenseVector.toDenseVec(y);
        DenseVector gradVec = DenseVector.toDenseVec(grad);
        Adam adam = new Adam();
        adam.setup(y.length);

        final int threads = SystemInfo.LogicalCores;
        boolean interrupted = false;
        for (int iter = 0; iter < T; iter++) {
            final boolean exaggerate = iter < T * exaggeratedPortion;
            Arrays.fill(grad, 0.0d);
            final Quadtree tree = new Quadtree(y);

            // Repulsive forces via Barnes-Hut; each task handles a strided subset of points
            // and also accumulates its share of the normalization Z.
            final AtomicDouble Z = new AtomicDouble(0.0d);
            final CountDownLatch repulsiveDone = new CountDownLatch(threads);
            for (int id = 0; id < threads; id++) {
                final int first = id;
                executorService.submit(new Runnable() {
                    @Override
                    public void run() {
                        double[] force = new double[s];
                        double localZ = 0.0d;
                        for (int i = first; i < N; i += threads) {
                            Arrays.fill(force, 0.0d);
                            localZ += computeF_rep(tree.root, i, y, force);
                            for (int k = 0; k < s; k++)
                                inc_z_ij(force[k], i, k, grad, s);
                        }
                        Z.addAndGet(localZ);
                        repulsiveDone.countDown();
                    }
                });
            }
            interrupted |= awaitLatch(repulsiveDone);

            // normalize the repulsive term by 4/Z (epsilon guards against Z == 0)
            final double scale = 4.0d / (Z.get() + 1.0E-13d);
            for (int i = 0; i < grad.length; i++)
                grad[i] *= scale;

            // Attractive forces from the sparse input affinities; each task owns a
            // contiguous block of rows of grad.
            final CountDownLatch attractiveDone = new CountDownLatch(threads);
            for (int id = 0; id < threads; id++) {
                final int block = id;
                executorService.submit(new Runnable() {
                    @Override
                    public void run() {
                        int from = ParallelUtils.getStartBlock(N, block, threads);
                        int to = ParallelUtils.getEndBlock(N, block, threads);
                        for (int i = from; i < to; i++) {
                            for (int n = 0; n < knn; n++) {
                                int j = nearMe[i][n];
                                if (i == j)
                                    continue;
                                double p = affinity[i][n];
                                if (exaggerate)
                                    p *= alpha;
                                double mult = p * q_ijZ(i, j, y, s) * 4.0d;
                                for (int k = 0; k < s; k++)
                                    inc_z_ij(mult * (z_ij(i, k, y, s) - z_ij(j, k, y, s)), i, k, grad, s);
                            }
                        }
                        attractiveDone.countDown();
                    }
                });
            }
            interrupted |= awaitLatch(attractiveDone);

            adam.update(yVec, gradVec, 200.0d);
        }

        Type transformed = (Type) dataSet.shallowClone2();
        // data points are mapped back to their row by identity, not equals()
        final IdentityHashMap<DataPoint, Integer> indexOf = new IdentityHashMap<DataPoint, Integer>(N);
        for (int i = 0; i < N; i++)
            indexOf.put(dataSet.getDataPoint(i), i);
        transformed.applyTransform(new DataTransform() {
            @Override
            public DataPoint transform(DataPoint dataPoint) {
                int row = indexOf.get(dataPoint);
                int dims = TSNE.this.s;
                DenseVector embedded = new DenseVector(dims);
                for (int k = 0; k < dims; k++)
                    embedded.set(k, y[row * dims + k]);
                return new DataPoint(embedded, dataPoint.getCategoricalValues(), dataPoint.getCategoricalData(), dataPoint.getWeight());
            }

            @Override
            public void fit(DataSet dataSet2) {
            }

            @Override
            public DataTransform clone() {
                return this;
            }
        });
        if (interrupted)
            Thread.currentThread().interrupt();
        return transformed;
    }

    /**
     * Waits until {@code latch} reaches zero, even if the calling thread is
     * interrupted in the meantime, so that callers never read arrays that worker
     * tasks are still writing. Each interruption is logged.
     *
     * @return true if the wait was interrupted at least once; the caller should then
     *         restore the thread's interrupt status when it is done
     */
    private static boolean awaitLatch(CountDownLatch latch) {
        boolean interrupted = false;
        while (true) {
            try {
                latch.await();
                return interrupted;
            } catch (InterruptedException e) {
                interrupted = true;
                Logger.getLogger(TSNE.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
            }
        }
    }

    /**
     * Computes the sparse symmetric input affinities.
     *
     * @param d          the data set
     * @param ex         executor that runs the parallel work
     * @param rand       randomness for the vantage-point tree
     * @param knn        number of neighbors per point
     * @param nearMe     output, {@code nearMe[i][n]} is the index of the n-th nearest
     *                   neighbor of point i (the point itself is excluded)
     * @param P          output, {@code P[i][n]} is the symmetric affinity between
     *                   point i and {@code nearMe[i][n]}
     * @param dm         distance metric in the input space
     * @param perplexity target perplexity used to pick each point's bandwidth
     */
    public static void computeP(DataSet d, ExecutorService ex, Random rand, final int knn, final int[][] nearMe, final double[][] P, final DistanceMetric dm, final double perplexity) {
        final List<Vec> vecs = d.getDataVectors();
        final List<Double> accelCache = dm.getAccelerationCache(vecs, ex);
        final int N = vecs.size();
        final int threads = SystemInfo.LogicalCores;
        boolean interrupted = false;

        // 1) knn+1 nearest neighbors of every point. Entry 0 of each result is skipped
        //    when filling nearMe, as it is expected to be the query point itself.
        final VPTreeMV<Vec> tree = new VPTreeMV<Vec>(vecs, dm, VPTree.VPSelection.Random, rand, 2, 1, ex);
        final List<List<? extends VecPaired<Vec, Double>>> neighbors = new ArrayList<List<? extends VecPaired<Vec, Double>>>(N);
        final IdentityHashMap<Vec, Integer> indexOf = new IdentityHashMap<Vec, Integer>(N);
        for (int i = 0; i < N; i++) {
            neighbors.add(null);
            indexOf.put(vecs.get(i), i);
        }
        final CountDownLatch knnDone = new CountDownLatch(threads);
        for (int id = 0; id < threads; id++) {
            final int first = id;
            ex.submit(new Runnable() {
                @Override
                public void run() {
                    for (int i = first; i < N; i += threads) {
                        List<? extends VecPaired<Vec, Double>> found = tree.search(vecs.get(i), knn + 1);
                        neighbors.set(i, found);
                        for (int n = 1; n < found.size(); n++)
                            nearMe[i][n - 1] = indexOf.get(found.get(n).getVector());
                    }
                    knnDone.countDown();
                }
            });
        }
        interrupted |= awaitLatch(knnDone);

        // 2) Global search bracket for the Gaussian bandwidths: from the smallest
        //    nearest-neighbor distance (floored at 1e-9) up to the largest distance
        //    to a point's farthest retained neighbor.
        double minSigma = Double.POSITIVE_INFINITY;
        double maxSigma = 0.0d;
        for (List<? extends VecPaired<Vec, Double>> found : neighbors) {
            double nearest = found.get(1).getPair();
            double farthest = found.get(Math.min(knn, found.size() - 1)).getPair();
            minSigma = Math.min(minSigma, Math.max(nearest, 1.0E-9d));
            maxSigma = Math.max(maxSigma, farthest);
        }
        final AtomicDouble lo = new AtomicDouble(minSigma);
        final AtomicDouble hi = new AtomicDouble(maxSigma);

        // 3) Per-point bandwidth whose conditional distribution hits the target perplexity.
        final double[] sigma = new double[N];
        final CountDownLatch sigmaDone = new CountDownLatch(threads);
        for (int id = 0; id < threads; id++) {
            final int first = id;
            ex.submit(new Runnable() {
                @Override
                public void run() {
                    for (int i = first; i < N; i += threads) {
                        final int target = i;
                        FunctionBase perplexityError = new FunctionBase() {
                            @Override
                            public double f(Vec vec) {
                                return TSNE.perp(target, nearMe, vec.get(0), neighbors, vecs, accelCache, dm) - perplexity;
                            }
                        };
                        boolean retry;
                        do {
                            retry = false;
                            try {
                                sigma[target] = Zeroin.root(0.01d, 100, lo.get(), hi.get(), 0, perplexityError, new double[0]);
                            } catch (ArithmeticException e) {
                                // Presumably the root is not bracketed by [lo, hi]: widen the
                                // shared bracket and retry, or give up with a huge bandwidth
                                // once hi cannot grow any further.
                                if (hi.get() >= Double.MAX_VALUE / 2) {
                                    sigma[target] = 1.0E100d;
                                } else {
                                    retry = true;
                                    lo.set(Math.max(lo.get() / 2.0d, 1.0E-6d));
                                    hi.set(Math.min(hi.get() * 2.0d, Double.MAX_VALUE / 2));
                                }
                            }
                        } while (retry);
                    }
                    sigmaDone.countDown();
                }
            });
        }
        interrupted |= awaitLatch(sigmaDone);

        // 4) Symmetric affinities for every stored neighbor pair.
        final CountDownLatch pDone = new CountDownLatch(threads);
        for (int id = 0; id < threads; id++) {
            final int first = id;
            ex.submit(new Runnable() {
                @Override
                public void run() {
                    for (int i = first; i < N; i += threads) {
                        for (int n = 0; n < knn; n++) {
                            int j = nearMe[i][n];
                            P[i][n] = TSNE.p_ij(i, j, sigma[i], sigma[j], neighbors, vecs, accelCache, dm);
                        }
                    }
                    pDone.countDown();
                }
            });
        }
        interrupted |= awaitLatch(pDone);

        if (interrupted)
            Thread.currentThread().interrupt();
    }

    /**
     * Adds into {@code grad[0..1]} the unnormalized Barnes-Hut repulsive force on
     * point {@code i} from the points in {@code node}, and returns the node's
     * contribution to the normalization Z (the sum of {@code N_cell / (1 + dist²)}
     * over the cells that get summarized).
     */
    public double computeF_rep(Quadtree.Node node, int i, double[] z, double[] grad) {
        if (node == null || node.N_cell == 0 || node.indx == i)
            return 0.0d;
        double x = z[i * 2];
        double y = z[i * 2 + 1];
        double cellSize = Math.max(node.maxX - node.minX, node.maxY - node.minY);
        double centerX = node.x_mass / node.N_cell;
        double centerY = node.y_mass / node.N_cell;
        double distSqrd = (centerX - x) * (centerX - x) + (centerY - y) * (centerY - y);

        // Barnes-Hut: summarize the cell when cellSize / dist < theta. Compared in
        // squared form, so theta must be squared as well (was theta * distSqrd).
        if (node.NW != null && cellSize * cellSize >= theta * theta * distSqrd) {
            double zSum = 0.0d;
            for (Quadtree.Node child : node)
                zSum += computeF_rep(child, i, z, grad);
            return zSum;
        }

        double q = 1.0d / (1.0d + distSqrd);
        double mult = -node.N_cell * q * q;
        grad[0] += mult * (x - centerX);
        grad[1] += mult * (y - centerY);
        return q * node.N_cell;
    }

    /** Adds {@code d} to coordinate {@code k} of point {@code i} in a flat array with {@code s} dims per point. */
    public static void inc_z_ij(double d, int i, int k, double[] z, int s) {
        z[i * s + k] += d;
    }

    /** Returns coordinate {@code k} of point {@code i} in a flat array with {@code s} dims per point. */
    public static double z_ij(int i, int k, double[] z, int s) {
        return z[i * s + k];
    }

    /** Returns the unnormalized Student-t similarity {@code 1 / (1 + ||z_i - z_j||²)}. */
    public static double q_ijZ(int i, int j, double[] z, int s) {
        double denom = 1.0d;
        for (int k = 0; k < s; k++) {
            double diff = z_ij(i, k, z, s) - z_ij(j, k, z, s);
            denom += diff * diff;
        }
        return 1.0d / denom;
    }

    /**
     * Conditional affinity p(j|i): a Gaussian with bandwidth {@code sigma} centered
     * at point i, evaluated at point j and normalized over i's stored neighbors
     * (entries 1.. of {@code neighbors.get(i)}). Returns 0 when {@code j == i}.
     */
    private static double p_j_i(int j, int i, double sigma, List<List<? extends VecPaired<Vec, Double>>> neighbors, List<Vec> vecs, List<Double> accelCache, DistanceMetric dm) {
        if (i == j)
            return 0.0d;
        // entry 0 of j's neighbor list is expected to be point j itself
        Vec vecJ = neighbors.get(j).get(0).getVector();
        double invTwoSigmaSqrd = 1.0d / (2.0d * (sigma * sigma));
        double numerator = 0.0d;
        double denominator = 0.0d;
        boolean jIsNeighbor = false;
        List<? extends VecPaired<Vec, Double>> nearI = neighbors.get(i);
        for (int n = 1; n < nearI.size(); n++) {
            VecPaired<Vec, Double> neighbor = nearI.get(n);
            double dist = neighbor.getPair();
            double k = FastMath.exp(-(dist * dist) * invTwoSigmaSqrd);
            denominator += k;
            if (neighbor.getVector() == vecJ) {
                jIsNeighbor = true;
                numerator = k;
            }
        }
        if (!jIsNeighbor) {
            // j is outside i's neighbor list: compute the distance explicitly
            double dist = dm.dist(i, j, vecs, accelCache);
            numerator = FastMath.exp(-(dist * dist) * invTwoSigmaSqrd);
        }
        return numerator / (denominator + 1.0E-9d);
    }

    /**
     * Symmetric affinity {@code (p(j|i) + p(i|j)) / (2N)}, using bandwidth
     * {@code sigma_i} for point i and {@code sigma_j} for point j.
     */
    public static double p_ij(int i, int j, double sigma_i, double sigma_j, List<List<? extends VecPaired<Vec, Double>>> neighbors, List<Vec> vecs, List<Double> accelCache, DistanceMetric dm) {
        return (p_j_i(j, i, sigma_i, neighbors, vecs, accelCache, dm) + p_j_i(i, j, sigma_j, neighbors, vecs, accelCache, dm)) / (2 * neighbors.size());
    }

    /**
     * Perplexity {@code 2^H} of point i's conditional distribution over its stored
     * neighbors for bandwidth {@code sigma}, where H is the entropy in bits.
     */
    public static double perp(int i, int[][] nearMe, double sigma, List<List<? extends VecPaired<Vec, Double>>> neighbors, List<Vec> vecs, List<Double> accelCache, DistanceMetric dm) {
        double negEntropy = 0.0d;
        for (int j : nearMe[i]) {
            double p = p_j_i(j, i, sigma, neighbors, vecs, accelCache, dm);
            if (p > 0.0d)
                negEntropy += p * FastMath.log2(p);
        }
        return FastMath.pow2(-negEntropy);
    }

    @Override
    public int getTargetDimension() {
        return 2;
    }

    @Override
    public boolean setTargetDimension(int i) {
        return i == 2;
    }
}
