package juicebox.tools.utils.norm;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.stream.DoubleStream;
import juicebox.data.ContactRecord;
import juicebox.data.basics.ListOfDoubleArrays;
import juicebox.data.basics.ListOfFloatArrays;
import juicebox.data.basics.ListOfIntArrays;
import juicebox.data.iterator.IteratorContainer;
import juicebox.windowui.NormalizationType;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.jfree.chart.axis.Axis;

/* loaded from: input_file:juicebox/tools/utils/norm/NormalizationCalculations.class */
/**
 * Computes per-bin normalization vectors (KR, VC, SCALE/MMBA, NONE) for a contact matrix whose
 * records are streamed from an {@link IteratorContainer}.
 *
 * <p>Every pass over the matrix requests a fresh contact-record iterator from the container, so
 * the matrix is never held here. Records are treated as one triangle of a symmetric matrix: a
 * record with {@code binX != binY} contributes to both of its bins.
 */
public class NormalizationCalculations {

    /** Residual tolerance handed to the KR solver. */
    private static final double KR_TOLERANCE = 1.0E-6d;
    /** KR inner iteration: lower bound ("delta") allowed for entries of the update vector y. */
    private static final double KR_DELTA = 0.1d;
    /** KR solver gives up (returns null) after this many outer iterations without progress. */
    private static final int KR_MAX_STALLED_ITERATIONS = 100;
    /** computeKR: maximum number of solver attempts, each possibly with more rows excluded. */
    private static final int KR_MAX_ATTEMPTS = 6;
    /** computeKR: scaled KR entries below this value cause their row to be excluded. */
    private static final double KR_MIN_SCALED_VALUE = 0.01d;

    private final long matrixSize;
    private final boolean isEnoughMemory;
    private final IteratorContainer ic;

    /**
     * @param iteratorContainer source of contact-record iterators, matrix size and memory hint
     */
    public NormalizationCalculations(IteratorContainer iteratorContainer) {
        this.ic = iteratorContainer;
        this.matrixSize = iteratorContainer.getMatrixSize();
        this.isEnoughMemory = iteratorContainer.getIsThereEnoughMemoryForNormCalculation();
    }

    /**
     * Computes {@code A * vector} for the symmetric matrix A restricted to the bins kept by
     * {@code offset}.
     *
     * @param offset   maps each full-matrix bin to its compact index, or -1 if the bin is excluded
     * @param iterator fresh iterator over the matrix's contact records
     * @param vector   compact-indexed input vector
     * @return a new compact-indexed vector, same length as {@code vector}, holding the product
     */
    private static ListOfDoubleArrays sparseMultiplyFromContactRecords(ListOfIntArrays offset,
                                                                        Iterator<ContactRecord> iterator,
                                                                        ListOfDoubleArrays vector) {
        ListOfDoubleArrays product = new ListOfDoubleArrays(vector.getLength());
        while (iterator.hasNext()) {
            ContactRecord record = iterator.next();
            int row = offset.get(record.getBinX());
            int col = offset.get(record.getBinY());
            if (row == -1 || col == -1) {
                continue; // at least one bin is excluded from the sub-matrix
            }
            float counts = record.getCounts();
            product.addTo(row, vector.get(col) * counts);
            if (row != col) {
                // Mirror the off-diagonal record into the other triangle.
                product.addTo(col, vector.get(row) * counts);
            }
        }
        return product;
    }

    /**
     * Solves for a balancing vector of the sub-matrix selected by {@code offset}. The structure
     * follows Knight &amp; Ruiz's Newton / conjugate-gradient matrix-balancing algorithm.
     *
     * @param offset bin-to-compact-index map (-1 = excluded)
     * @param tol    residual tolerance; iteration stops once the squared residual is at most tol^2
     * @param x0     initial guess; it is updated IN PLACE and returned on success
     * @param delta  lower bound for entries of each inner update vector (0 disables the
     *               boundary step)
     * @return {@code x0} holding the balancing vector, or {@code null} if the residual stalled for
     *         {@link #KR_MAX_STALLED_ITERATIONS} outer iterations
     */
    private ListOfDoubleArrays computeKRNormVector(ListOfIntArrays offset, double tol,
                                                   ListOfDoubleArrays x0, double delta) {
        final long n = x0.getLength();
        final ListOfDoubleArrays ones = new ListOfDoubleArrays(n, 1.0d);
        final double g = 0.9d;
        final double etaMax = 0.1d;
        final double stopTol = 0.5d * tol;
        final double rt = Math.pow(tol, 2.0d);
        double eta = etaMax;

        // v = x .* (A x);  rk = 1 - v  (residual of the balancing equations)
        ListOfDoubleArrays v = sparseMultiplyFromContactRecords(offset, getIterator(), x0);
        ListOfDoubleArrays rk = new ListOfDoubleArrays(v.getLength());
        for (long i = 0; i < v.getLength(); i++) {
            v.multiplyBy(i, x0.get(i));
            rk.set(i, 1.0d - v.get(i));
        }
        double rhoKm1 = 0.0d;
        for (double[] chunk : rk.getValues()) {
            for (double r : chunk) {
                rhoKm1 += r * r;
            }
        }
        double rOut = rhoKm1;
        double rOld = rOut;
        int notChanging = 0;

        while (rOut > rt && notChanging < KR_MAX_STALLED_ITERATIONS) { // outer (Newton) iteration
            int k = 0;
            ListOfDoubleArrays y = ones.deepClone();
            ListOfDoubleArrays yNew = new ListOfDoubleArrays(n);
            ListOfDoubleArrays z = new ListOfDoubleArrays(n);
            ListOfDoubleArrays p = new ListOfDoubleArrays(n);
            ListOfDoubleArrays w = new ListOfDoubleArrays(n);
            double rhoKm2 = rhoKm1;
            double innerTol = Math.max(Math.pow(eta, 2.0d) * rOut, rt);

            while (rhoKm1 > innerTol) { // inner (conjugate-gradient) iteration
                k++;
                if (k == 1) {
                    rhoKm1 = 0.0d;
                    for (long i = 0; i < z.getLength(); i++) {
                        double rkVal = rk.get(i);
                        double zVal = rkVal / v.get(i);
                        z.set(i, zVal);
                        rhoKm1 += rkVal * zVal;
                    }
                    p = z.deepClone();
                } else {
                    // p = z + beta * p
                    p.multiplyEverythingBy(rhoKm1 / rhoKm2);
                    for (long i = 0; i < p.getLength(); i++) {
                        p.addTo(i, z.get(i));
                    }
                }

                // w = x .* (A (x .* p)) + v .* p
                ListOfDoubleArrays xp = new ListOfDoubleArrays(n);
                for (long i = 0; i < xp.getLength(); i++) {
                    xp.set(i, x0.get(i) * p.get(i));
                }
                ListOfDoubleArrays axp = sparseMultiplyFromContactRecords(offset, getIterator(), xp);
                double pw = 0.0d;
                for (long i = 0; i < axp.getLength(); i++) {
                    double pVal = p.get(i);
                    double wVal = (x0.get(i) * axp.get(i)) + (v.get(i) * pVal);
                    w.set(i, wVal);
                    pw += pVal * wVal;
                }
                double alpha = rhoKm1 / pw;

                // Tentative step; track its smallest entry to test the cone boundary.
                double minYNew = Double.MAX_VALUE;
                for (long i = 0; i < p.getLength(); i++) {
                    double yVal = y.get(i) + (alpha * p.get(i));
                    yNew.set(i, yVal);
                    if (yVal < minYNew) {
                        minYNew = yVal;
                    }
                }

                if (minYNew <= delta) {
                    // Step would leave the feasible cone: shorten it to the boundary and end
                    // this inner solve. Both exits were lost in the decompiled code, which left
                    // the inner loop unable to terminate on this path.
                    if (delta == 0.0d) {
                        break;
                    }
                    double gamma = Double.MAX_VALUE;
                    for (long i = 0; i < yNew.getLength(); i++) {
                        double ap = alpha * p.get(i);
                        if (ap < 0.0d) {
                            double step = (delta - y.get(i)) / ap;
                            if (step < gamma) {
                                gamma = step;
                            }
                        }
                    }
                    for (long i = 0; i < y.getLength(); i++) {
                        y.addTo(i, gamma * alpha * p.get(i));
                    }
                    break;
                }

                // Accept the step and update residual / preconditioned residual.
                rhoKm2 = rhoKm1;
                rhoKm1 = 0.0d;
                y = yNew.deepClone();
                for (long i = 0; i < y.getLength(); i++) {
                    rk.addTo(i, -alpha * w.get(i));
                    double rkVal = rk.get(i);
                    z.set(i, rkVal / v.get(i));
                    rhoKm1 += rkVal * z.get(i);
                }
            }

            // x = x .* y, then recompute v and the residual.
            for (long i = 0; i < x0.getLength(); i++) {
                x0.multiplyBy(i, y.get(i));
            }
            v = sparseMultiplyFromContactRecords(offset, getIterator(), x0);
            rhoKm1 = 0.0d;
            for (long i = 0; i < v.getLength(); i++) {
                v.multiplyBy(i, x0.get(i));
                double rkVal = 1.0d - v.get(i);
                rk.set(i, rkVal);
                rhoKm1 += rkVal * rkVal;
            }
            if (Math.abs(rhoKm1 - rOut) < 1.0E-6d || Double.isInfinite(rhoKm1)) {
                notChanging++;
            }
            rOut = rhoKm1;

            // Adapt the inner-iteration stopping criterion.
            double rat = rOut / rOld;
            rOld = rOut;
            double rNorm = Math.sqrt(rOut);
            double etaOld = eta;
            eta = g * rat;
            if (g * Math.pow(etaOld, 2.0d) > 0.1d) {
                eta = Math.max(eta, g * Math.pow(etaOld, 2.0d));
            }
            eta = Math.max(Math.min(eta, etaMax), stopTol / rNorm);
        }
        return notChanging >= KR_MAX_STALLED_ITERATIONS ? null : x0;
    }

    /** Returns a fresh iterator over all contact records of the matrix. */
    private Iterator<ContactRecord> getIterator() {
        return this.ic.getNewContactRecordIterator();
    }

    /** @return the container's answer to whether there is enough memory for normalization */
    public boolean isEnoughMemory() {
        return this.isEnoughMemory;
    }

    /**
     * Computes the normalization vector for {@code normalizationType}.
     *
     * <p>KR, VC and SCALE vectors are multiplied by {@link #getSumFactor} before being returned.
     * NONE yields a vector of ones.
     *
     * @return the vector, or {@code null} if the type is unsupported (a message goes to stderr)
     *         or the underlying computation returned null
     */
    public ListOfFloatArrays getNorm(NormalizationType normalizationType) {
        ListOfFloatArrays norm;
        if (normalizationType.usesKR()) {
            norm = computeKR();
        } else if (normalizationType.usesVC()) {
            norm = computeVC();
        } else if (normalizationType.usesSCALE()) {
            norm = computeMMBA();
        } else if (normalizationType.isNONE()) {
            return new ListOfFloatArrays(this.matrixSize, 1.0f);
        } else {
            System.err.println("Not supported for normalization " + normalizationType);
            return null;
        }
        if (norm != null && norm.getLength() > 0) {
            norm.multiplyEverythingBy(getSumFactor(norm));
        }
        return norm;
    }

    /**
     * Computes the vanilla-coverage vector: the sum of counts in each bin's row of the symmetric
     * matrix.
     */
    public ListOfFloatArrays computeVC() {
        // Plain 0.0f; the decompiled code borrowed an unrelated JFreeChart constant that equals it.
        ListOfFloatArrays rowSums = new ListOfFloatArrays(this.matrixSize, 0.0f);
        Iterator<ContactRecord> iterator = getIterator();
        while (iterator.hasNext()) {
            ContactRecord record = iterator.next();
            int binX = record.getBinX();
            int binY = record.getBinY();
            float counts = record.getCounts();
            rowSums.addTo(binX, counts);
            if (binX != binY) {
                rowSums.addTo(binY, counts);
            }
        }
        return rowSums;
    }

    /**
     * Returns the factor s by which a normalization vector must be multiplied so that the
     * normalized matrix (entries divided by the two bins' scaled vector values) has the same total
     * as the raw counts over the bins with valid (positive, non-NaN) vector entries.
     *
     * <p>The result is NaN or infinite when no record touches two valid bins.
     */
    public double getSumFactor(ListOfFloatArrays listOfFloatArrays) {
        double[] sums = getNormMatrixSumFactor(listOfFloatArrays);
        return Math.sqrt(sums[0] / sums[1]);
    }

    /**
     * Sums the matrix, over both triangles, restricted to records whose two bins have a positive,
     * non-NaN entry in {@code listOfFloatArrays}.
     *
     * @return {@code {sum of counts / (norm[x] * norm[y]), sum of counts}}
     */
    public double[] getNormMatrixSumFactor(ListOfFloatArrays listOfFloatArrays) {
        double rawSum = 0.0d;
        double normalizedSum = 0.0d;
        Iterator<ContactRecord> iterator = getIterator();
        while (iterator.hasNext()) {
            ContactRecord record = iterator.next();
            int binX = record.getBinX();
            int binY = record.getBinY();
            float counts = record.getCounts();
            double normX = listOfFloatArrays.get(binX);
            double normY = listOfFloatArrays.get(binY);
            if (Double.isNaN(normX) || Double.isNaN(normY) || normX <= 0.0d || normY <= 0.0d) {
                continue;
            }
            if (binX == binY) {
                normalizedSum += counts / (normX * normY);
                rawSum += counts;
            } else {
                // Off-diagonal records stand for both triangles, so they count twice.
                normalizedSum += (2.0f * counts) / (normX * normY);
                rawSum += 2.0f * counts;
            }
        }
        return new double[]{normalizedSum, rawSum};
    }

    /** Counts entries of {@code dArr} that are strictly positive (NaN is never counted). */
    public int getNumberOfValidEntriesInVector(double[] dArr) {
        int count = 0;
        for (double value : dArr) {
            if (!Double.isNaN(value) && value > 0.0d) {
                count++;
            }
        }
        return count;
    }

    /**
     * Computes the KR normalization vector, retrying with progressively more rows excluded.
     *
     * <p>Rows are first excluded if their row sum is zero, and later if it falls at or below a
     * percentile (1-4, then 10) of the non-zero row sums. After a successful solve, rows whose
     * scaled KR value is below {@link #KR_MIN_SCALED_VALUE} are excluded and the solve is
     * repeated. Excluded bins are NaN in the result.
     *
     * @return the KR vector (not yet scaled by {@link #getSumFactor}). If
     *         {@link #KR_MAX_ATTEMPTS} attempts do not settle, it is an all-NaN vector.
     */
    public ListOfFloatArrays computeKR() {
        boolean recalculate = true;
        ListOfIntArrays offset = getOffset(0.0d);
        ListOfFloatArrays kr = null;
        int iteration = 1;

        while (recalculate && iteration <= KR_MAX_ATTEMPTS) {
            long newSize = countIncludedBins(offset);
            ListOfDoubleArrays x0 = computeKRNormVector(offset, KR_TOLERANCE,
                    new ListOfDoubleArrays(newSize, 1.0d), KR_DELTA);

            recalculate = false;
            // NOTE(review): attempt 5 always discards its result and moves on to the
            // 10th-percentile cut; this is preserved from the existing behavior.
            if (x0 == null || iteration == 5) {
                recalculate = true;
                offset = iteration < 5 ? getOffset(iteration) : getOffset(10.0d);
            } else {
                // Expand the compact solution back to full size (excluded bins -> NaN).
                kr = new ListOfFloatArrays(this.matrixSize);
                long krIndex = 0;
                for (int[] chunk : offset.getValues()) {
                    for (int compactIndex : chunk) {
                        if (compactIndex == -1) {
                            kr.set(krIndex++, Float.NaN);
                        } else {
                            kr.set(krIndex++, (float) (1.0d / x0.get(compactIndex)));
                        }
                    }
                }

                // Exclude rows with tiny scaled values and renumber the survivors. If no row is
                // excluded, the offsets are rewritten with identical values.
                double sumFactor = getSumFactor(kr);
                int nextIndex = 0;
                for (long i = 0; i < kr.getLength(); i++) {
                    if (kr.get(i) * sumFactor < KR_MIN_SCALED_VALUE) {
                        offset.set(i, -1);
                        recalculate = true;
                    } else if (offset.get(i) != -1) {
                        offset.set(i, nextIndex++);
                    }
                }
            }
            iteration++;
            System.gc(); // the solver allocates many large temporaries per attempt
        }
        if (iteration > KR_MAX_ATTEMPTS && recalculate) {
            kr = new ListOfFloatArrays(this.matrixSize, Float.NaN);
        }
        return kr;
    }

    /** Counts bins whose offset is not -1, i.e. the size of the compact sub-matrix. */
    private static long countIncludedBins(ListOfIntArrays offset) {
        long included = 0;
        for (int[] chunk : offset.getValues()) {
            for (int compactIndex : chunk) {
                if (compactIndex != -1) {
                    included++;
                }
            }
        }
        return included;
    }

    /**
     * Builds the bin-to-compact-index map for the KR solver.
     *
     * @param d percentile (0-100) of the non-zero row sums at or below which rows are excluded;
     *          when not positive, only rows whose sum is at or below zero are excluded
     * @return a map where excluded bins are -1 and kept bins are numbered 0, 1, 2, ... in order
     */
    private ListOfIntArrays getOffset(double d) {
        ListOfDoubleArrays rowSums = new ListOfDoubleArrays(this.matrixSize, 0.0d);
        Iterator<ContactRecord> iterator = getIterator();
        while (iterator.hasNext()) {
            ContactRecord record = iterator.next();
            int binX = record.getBinX();
            int binY = record.getBinY();
            float counts = record.getCounts();
            rowSums.addTo(binX, counts);
            if (binX != binY) {
                rowSums.addTo(binY, counts);
            }
        }

        double threshold = 0.0d;
        if (d > 0.0d) {
            DescriptiveStatistics stats = new DescriptiveStatistics();
            for (double[] chunk : rowSums.getValues()) {
                for (double sum : chunk) {
                    if (sum != 0.0d) {
                        stats.addValue(sum);
                    }
                }
            }
            double percentile = stats.getPercentile(d);
            // With no non-zero rows the percentile is NaN, and "sum <= NaN" is always false, so
            // every (empty) row would be kept. Fall back to the zero threshold instead.
            if (!Double.isNaN(percentile)) {
                threshold = percentile;
            }
        }

        ListOfIntArrays offset = new ListOfIntArrays(rowSums.getLength());
        int nextIndex = 0;
        for (long i = 0; i < rowSums.getLength(); i++) {
            if (rowSums.get(i) <= threshold) {
                offset.set(i, -1);
            } else {
                offset.set(i, nextIndex++);
            }
        }
        return offset;
    }

    /** Computes the SCALE vector via {@link ZeroScale}, starting from a vector of ones. */
    public ListOfFloatArrays computeMMBA() {
        return ZeroScale.mmbaScaleToVector(this.ic, new ListOfFloatArrays(this.matrixSize, 1.0f));
    }

    /**
     * Draws a random int in {@code [0, end)} that skips every value in {@code exclusions}.
     *
     * <p>A value is drawn from {@code [0, end - exclusions.size())} and shifted past each
     * exclusion that is at or below it. This requires {@code exclusions} to be sorted ascending
     * and distinct.
     *
     * @param random     source of randomness
     * @param end        exclusive upper bound of the range
     * @param exclusions sorted, distinct values to skip
     * @return the drawn value. If {@code end - exclusions.size()} is not positive, a diagnostic
     *         is printed to stderr and the draw falls back to 0 before the shift.
     */
    public int getRandomWithExclusion(Random random, int end, List<Integer> exclusions) {
        int available = end - exclusions.size();
        int value = 0;
        if (available > 0) {
            value = random.nextInt(available);
        } else {
            // Nothing left to draw from; keep the historical best-effort fallback of 0.
            System.err.println(end + " " + exclusions.size());
        }
        for (int excluded : exclusions) {
            if (value < excluded) {
                break;
            }
            value++;
        }
        return value;
    }
}
