package org.ujmp.core.doublematrix.calculation.general.statistical;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.ujmp.core.Matrix;
import org.ujmp.core.benchmark.BenchmarkConfig;
import org.ujmp.core.doublematrix.DenseDoubleMatrix2D;
import org.ujmp.core.doublematrix.DoubleMatrix2D;
import org.ujmp.core.doublematrix.calculation.AbstractDoubleCalculation;
import org.ujmp.core.enums.ValueType;
import org.ujmp.core.intmatrix.IntMatrix2D;
import org.ujmp.core.intmatrix.impl.DefaultDenseIntMatrix2D;
import org.ujmp.core.util.MathUtil;

/* loaded from: input_file:WEB-INF/lib/ujmp-core-0.3.0.jar:org/ujmp/core/doublematrix/calculation/general/statistical/MutualInformation.class */
/**
 * Pairwise mutual information between the columns of a matrix.
 *
 * <p>Viewed as a calculation, entry {@code (a, b)} is the mutual information, in bits, between
 * column {@code a} and column {@code b} of the source matrix. It is estimated from the empirical
 * (relative-frequency) joint distribution of the two columns' values over all rows. The result is
 * a square {@code columnCount x columnCount} matrix.
 */
public class MutualInformation extends AbstractDoubleCalculation {
    private static final long serialVersionUID = -4891250637894943873L;

    /** Natural logarithm of 2; dividing a natural-log quantity by it converts nats to bits. */
    private static final double LN2 = Math.log(2.0d);

    /**
     * Creates the calculation for the given matrix.
     *
     * @param matrix the source matrix whose columns are compared pairwise
     */
    public MutualInformation(Matrix matrix) {
        super(matrix);
    }

    /**
     * Returns the mutual information between two columns of the source matrix.
     *
     * @param jArr coordinates; {@code jArr[0]} and {@code jArr[1]} are the two column indices
     * @return the mutual information in bits, as computed by {@link #calculate(long, long, Matrix)}
     */
    @Override
    public double getDouble(long... jArr) {
        return calculate(jArr[0], jArr[1], getSource());
    }

    /**
     * Returns the result size: one row and one column per column of the source matrix.
     *
     * @return {@code {columnCount, columnCount}} of the source matrix
     */
    @Override
    public long[] getSize() {
        return new long[]{getSource().getColumnCount(), getSource().getColumnCount()};
    }

    /**
     * Estimates the mutual information between two columns of {@code matrix}.
     *
     * <p>Every distinct {@code double} value is its own category. Equality follows
     * {@link Double#equals}, so {@code 0.0} and {@code -0.0} are different categories and all NaNs
     * fall into one category. Probabilities are relative frequencies over all rows. Only value
     * pairs that actually occur are visited, so the cost is linear in the number of rows.
     *
     * @param j index of the first column
     * @param j2 index of the second column
     * @param matrix the data; each row is one joint observation
     * @return {@code sum over observed (x, y) of p(x,y) * log2(p(x,y) / (p(x) * p(y)))};
     *         {@code 0.0} when the matrix has no rows
     */
    public static final double calculate(long j, long j2, Matrix matrix) {
        long rowCount = matrix.getRowCount();
        Map<Double, Long> xCounts = new HashMap<Double, Long>();
        Map<Double, Long> yCounts = new HashMap<Double, Long>();
        // Joint counts nested by x, then y, so the summation below touches only observed pairs
        // instead of every combination of marginal values.
        Map<Double, Map<Double, Long>> jointCounts = new HashMap<Double, Map<Double, Long>>();

        for (long row = 0; row < rowCount; row++) {
            Double x = Double.valueOf(matrix.getAsDouble(row, j));
            Double y = Double.valueOf(matrix.getAsDouble(row, j2));
            increment(xCounts, x);
            increment(yCounts, y);
            Map<Double, Long> yCountsForX = jointCounts.get(x);
            if (yCountsForX == null) {
                yCountsForX = new HashMap<Double, Long>();
                jointCounts.put(x, yCountsForX);
            }
            increment(yCountsForX, y);
        }

        double n = rowCount;
        double mutualInformation = 0.0d;
        for (Map.Entry<Double, Map<Double, Long>> xEntry : jointCounts.entrySet()) {
            double px = xCounts.get(xEntry.getKey()).longValue() / n;
            for (Map.Entry<Double, Long> yEntry : xEntry.getValue().entrySet()) {
                double py = yCounts.get(yEntry.getKey()).longValue() / n;
                double pxy = yEntry.getValue().longValue() / n;
                mutualInformation += pxy * MathUtil.log2(pxy / (px * py));
            }
        }
        return mutualInformation;
    }

    /**
     * Adds one to the count stored under {@code key}, treating a missing entry as zero.
     *
     * @param counts the map of counts to update
     * @param key the key to increment
     */
    private static <K> void increment(Map<K, Long> counts, K key) {
        Long current = counts.get(key);
        counts.put(key, Long.valueOf(current == null ? 1L : current.longValue() + 1L));
    }

    /**
     * Converts {@code matrix} to integer values and computes the mutual information of every
     * pair of its columns.
     *
     * @param matrix the data to convert with {@code convert(ValueType.INT)}
     * @return the symmetric {@code columnCount x columnCount} mutual-information matrix, in bits
     * @throws ClassCastException if the converted matrix is not an {@link IntMatrix2D}
     */
    public static DoubleMatrix2D calcNew(Matrix matrix) {
        // The cast selects the IntMatrix2D overload. If convert(...)'s static type is Matrix, an
        // uncast call resolves back to this method and recurses until the stack overflows.
        return calcNew((IntMatrix2D) matrix.convert(ValueType.INT));
    }

    /**
     * Computes the mutual information of every pair of columns of an integer matrix.
     *
     * <p>Each distinct integer value is one category. Values may be negative: bins are offset by
     * the smallest value present. The diagonal holds each column's entropy, because a column
     * compared with itself has mutual information equal to its entropy. Values that are slightly
     * negative because of rounding are clamped to {@code 0}.
     *
     * @param intMatrix2D the data; each row is one joint observation
     * @return the symmetric {@code columnCount x columnCount} mutual-information matrix, in bits.
     *         It is all zeros if the matrix has no rows.
     * @throws IllegalArgumentException if the span {@code max - min + 1} of values does not fit
     *         in an {@code int}
     */
    public static DoubleMatrix2D calcNew(IntMatrix2D intMatrix2D) {
        long columnCount = intMatrix2D.getColumnCount();
        int columns = (int) columnCount;
        int rowCount = (int) intMatrix2D.getRowCount();
        DoubleMatrix2D result = (DoubleMatrix2D) DenseDoubleMatrix2D.Factory.zeros(columnCount, columnCount);
        if (rowCount == 0 || columns == 0) {
            return result;
        }

        // Find the value range so that every value maps to a non-negative bin index.
        int min = Integer.MAX_VALUE;
        int max = Integer.MIN_VALUE;
        for (int c = 0; c < columns; c++) {
            for (int r = 0; r < rowCount; r++) {
                int value = intMatrix2D.getInt(r, c);
                if (value < min) {
                    min = value;
                }
                if (value > max) {
                    max = value;
                }
            }
        }
        long range = (long) max - (long) min + 1L;
        if (range > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("value range too large for binning: [" + min + ", " + max + "]");
        }
        int bins = (int) range;

        // Per-column log marginal probabilities, computed once instead of once per column pair.
        // Bins that never occur stay 0.0 and are never read, because their joint count is zero.
        double[][] logMarginal = new double[columns][bins];
        int[] marginalCounts = new int[bins];
        for (int c = 0; c < columns; c++) {
            Arrays.fill(marginalCounts, 0);
            for (int r = 0; r < rowCount; r++) {
                marginalCounts[intMatrix2D.getInt(r, c) - min]++;
            }
            for (int v = 0; v < bins; v++) {
                if (marginalCounts[v] != 0) {
                    logMarginal[c][v] = Math.log((double) marginalCounts[v] / rowCount);
                }
            }
        }

        // One joint table is reused for all pairs. Cells are reset to zero as they are consumed.
        int[][] joint = new int[bins][bins];
        for (int a = 0; a < columns; a++) {
            for (int b = 0; b <= a; b++) {
                for (int r = rowCount - 1; r >= 0; r--) {
                    joint[intMatrix2D.getInt(r, a) - min][intMatrix2D.getInt(r, b) - min]++;
                }
                double mutualInformation = 0.0d;
                for (int x = bins - 1; x >= 0; x--) {
                    int[] jointRow = joint[x];
                    for (int y = bins - 1; y >= 0; y--) {
                        int count = jointRow[y];
                        if (count != 0) {
                            jointRow[y] = 0;
                            double pxy = (double) count / rowCount;
                            // p(x,y) * (ln p(x,y) - ln p(x) - ln p(y)), converted to bits
                            mutualInformation += (pxy * ((Math.log(pxy) - logMarginal[a][x]) - logMarginal[b][y])) / LN2;
                        }
                    }
                }
                double clamped = mutualInformation < 0.0d ? 0.0d : mutualInformation;
                result.setDouble(clamped, a, b);
                result.setDouble(clamped, b, a);
            }
        }
        return result;
    }
}
