package edu.wisc.game.math;

import edu.wisc.game.formatter.Fmter;
import edu.wisc.game.sql.Main;
import edu.wisc.game.sql.MlcEntry;
import edu.wisc.game.tools.MwByHuman;
import edu.wisc.game.util.ImportCSV;
import edu.wisc.game.util.Util;
import java.io.File;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;
import javax.persistence.EntityManager;
import javax.persistence.Query;

/* loaded from: input_file:edu/wisc/game/math/MannWhitneyComparison.class */
public class MannWhitneyComparison {
    /**
     * Comparison mode fixed at construction time. It selects the query used to load entries,
     * the grouping key, and the wording/layout of the report.
     */
    final Mode mode;
    /** Formatter instance that {@code main} passes to {@code doCompare}. */
    private static Fmter plainFm = new Fmter();

    /** What is being compared; determines how entries are queried, keyed and reported. */
    public enum Mode {
        /**
         * Compare rule sets for one algorithm: entries are selected by nickname and grouped by
         * rule set name.
         */
        CMP_RULES,
        /**
         * Compare rule sets with respect to human performance. {@code mkQuery} rejects this mode,
         * and {@code main} reports it as unsupported.
         */
        CMP_RULES_HUMAN,
        /**
         * Compare algorithms on one rule set: entries are selected by rule set name and grouped by
         * nickname.
         */
        CMP_ALGOS
    }

    /**
     * Creates a comparison helper bound to the given mode.
     *
     * @param mode the comparison mode; stored as-is (no null check is performed here)
     */
    public MannWhitneyComparison(Mode mode) {
        this.mode = mode;
    }

    /**
     * Builds the query that selects the {@code MlcEntry} rows relevant to the current mode.
     *
     * <p>In {@code CMP_ALGOS} mode the rows are filtered by rule set name ({@code str2}); in
     * {@code CMP_RULES} mode they are filtered by nickname ({@code str}).
     *
     * @param entityManager entity manager used to create the query
     * @param str nickname to filter on (used in {@code CMP_RULES} mode)
     * @param str2 rule set name to filter on (used in {@code CMP_ALGOS} mode)
     * @return the query with its single parameter {@code x} already bound
     * @throws IllegalArgumentException if the mode is neither {@code CMP_ALGOS} nor {@code CMP_RULES}
     */
    private Query mkQuery(EntityManager entityManager, String str, String str2) {
        final String jpql;
        final String filterValue;
        if (this.mode == Mode.CMP_ALGOS) {
            jpql = "select m from MlcEntry m where m.ruleSetName=:x";
            filterValue = str2;
        } else if (this.mode == Mode.CMP_RULES) {
            jpql = "select m from MlcEntry m where m.nickname=:x";
            filterValue = str;
        } else {
            throw new IllegalArgumentException("Wrong mode for querying MlcEntry: " + this.mode);
        }
        final Query query = entityManager.createQuery(jpql);
        query.setParameter("x", filterValue);
        return query;
    }

    /**
     * Returns the grouping key of an entry for the current mode: the nickname when comparing
     * algorithms ({@code CMP_ALGOS}), otherwise the rule set name.
     *
     * @param mlcEntry the entry to key
     * @return the value used to group this entry into a comparandum
     */
    private String getKey(MlcEntry mlcEntry) {
        if (this.mode == Mode.CMP_ALGOS) {
            return mlcEntry.getNickname();
        }
        return mlcEntry.getRuleSetName();
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v34, types: [edu.wisc.game.math.Comparandum[], edu.wisc.game.math.Comparandum[][]] */
    public Comparandum[][] mkMlcComparanda(String str, String str2) {
        int intValue;
        EntityManager entityManager = null;
        try {
            entityManager = Main.getNewEM();
            List<MlcEntry> resultList = mkQuery(entityManager, str, str2).getResultList();
            Vector vector = new Vector();
            HashMap hashMap = new HashMap();
            Vector vector2 = new Vector();
            int i = 0;
            Iterator it = resultList.iterator();
            while (it.hasNext()) {
                String key = getKey((MlcEntry) it.next());
                boolean z = hashMap.get(key) == null;
                if (z) {
                    intValue = i;
                    i++;
                } else {
                    intValue = ((Integer) hashMap.get(key)).intValue();
                }
                int i2 = intValue;
                if (z) {
                    hashMap.put(key, Integer.valueOf(i2));
                    vector.add(key);
                    vector2.add(1);
                } else {
                    vector2.set(i2, Integer.valueOf(((Integer) vector2.get(i2)).intValue() + 1));
                }
            }
            MlcEntry[] mlcEntryArr = new MlcEntry[i];
            for (int i3 = 0; i3 < i; i3++) {
                mlcEntryArr[i3] = new MlcEntry[((Integer) vector2.get(i3)).intValue()];
            }
            int[] iArr = new int[i];
            for (MlcEntry mlcEntry : resultList) {
                int intValue2 = ((Integer) hashMap.get(getKey(mlcEntry))).intValue();
                MlcEntry[] mlcEntryArr2 = mlcEntryArr[intValue2];
                int i4 = iArr[intValue2];
                iArr[intValue2] = i4 + 1;
                mlcEntryArr2[i4] = mlcEntry;
            }
            Vector vector3 = new Vector();
            Vector vector4 = new Vector();
            for (int i5 = 0; i5 < i; i5++) {
                boolean z2 = false;
                for (int i6 = 0; i6 < mlcEntryArr[i5].length; i6++) {
                    z2 = z2 || !mlcEntryArr[i5][i6].getLearned();
                }
                (z2 ? vector4 : vector3).add(new Comparandum(getKey(mlcEntryArr[i5][0]), !z2, mlcEntryArr[i5]));
            }
            Comparandum[] comparandumArr = new Comparandum[0];
            ?? r0 = {(Comparandum[]) vector3.toArray(comparandumArr), (Comparandum[]) vector4.toArray(comparandumArr)};
            if (entityManager != null) {
                try {
                    entityManager.close();
                } catch (Exception e) {
                }
            }
            return r0;
        } catch (Throwable th) {
            if (entityManager != null) {
                try {
                    entityManager.close();
                } catch (Exception e2) {
                }
            }
            throw th;
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v106, types: [java.lang.String[], java.lang.String[][]] */
    public String doCompare(String str, String str2, Comparandum[][] comparandumArr, Fmter fmter, File[] fileArr) {
        String str3;
        String[] strArr;
        String str4 = this.mode == Mode.CMP_ALGOS ? str : this.mode == Mode.CMP_RULES ? str2 : "";
        String str5 = this.mode == Mode.CMP_ALGOS ? str2 : this.mode == Mode.CMP_RULES ? str : "";
        String str6 = this.mode == Mode.CMP_ALGOS ? "Results comparison on rule set " : this.mode == Mode.CMP_RULES ? "Comparing rule sets with respect to algo " : "Comparing rule sets with respect to human performance";
        String str7 = str6 + fmter.tt(str5);
        String str8 = str6 + str5;
        try {
            Comparandum[] comparandumArr2 = comparandumArr[0];
            Comparandum[] comparandumArr3 = comparandumArr[1];
            String str9 = "" + fmter.h1(str7);
            double[][] rawMatrix = MannWhitney.rawMatrix(Comparandum.asArray(comparandumArr2));
            double[][] ratioMatrix = MannWhitney.ratioMatrix(rawMatrix);
            double[] dArr = MannWhitney.topEigenVector(ratioMatrix);
            for (int i = 0; i < dArr.length; i++) {
                comparandumArr2[i].setEv(dArr[i]);
            }
            Vector vector = new Vector();
            for (int i2 = 0; i2 < comparandumArr2.length; i2++) {
                vector.add(Integer.valueOf(i2));
            }
            vector.sort((num, num2) -> {
                return (int) Math.signum(dArr[num2.intValue()] - dArr[num.intValue()]);
            });
            Vector<String> vector2 = new Vector<>();
            Vector<String> vector3 = new Vector<>();
            int size = vector.size() + 1;
            String[][] strArr2 = {new String[size], new String[size]};
            for (int i3 = 0; i3 < size; i3++) {
                strArr2[0][i3] = new String[size];
                strArr2[1][i3] = new String[size];
            }
            String[] strArr3 = strArr2[0][0];
            strArr2[1][0][0] = "#key";
            strArr3[0] = "#key";
            for (int i4 = 0; i4 < vector.size(); i4++) {
                int intValue = ((Integer) vector.get(i4)).intValue();
                String str10 = comparandumArr2[intValue].key;
                boolean equals = str10.equals(str4);
                Vector<String> vector4 = new Vector<>();
                Vector<String> vector5 = new Vector<>();
                strArr2[1][0][i4 + 1] = str10;
                strArr2[0][0][i4 + 1] = str10;
                String[] strArr4 = strArr2[0][i4 + 1];
                strArr2[1][i4 + 1][0] = str10;
                strArr4[0] = str10;
                for (int i5 = 0; i5 < vector.size(); i5++) {
                    double[] dArr2 = {rawMatrix[intValue][((Integer) vector.get(i5)).intValue()], ratioMatrix[intValue][((Integer) vector.get(i5)).intValue()]};
                    strArr2[0][i4 + 1][i5 + 1] = "" + dArr2[0];
                    strArr2[1][i4 + 1][i5 + 1] = "" + dArr2[1];
                    vector4.add("" + dArr2[0]);
                    vector5.add(fmter.sprintf("%8.4f", Double.valueOf(dArr2[1])));
                }
                if (equals) {
                    eachStrong(fmter, vector4);
                }
                if (equals) {
                    eachStrong(fmter, vector5);
                }
                vector2.add(fmter.rowTh(str10, "align='right'", vector4));
                vector3.add(fmter.rowTh(str10, "align='right'", vector5));
            }
            String str11 = (((str9 + fmter.h3("Raw M-W matrix")) + fmter.table("", vector2)) + fmter.h3("M-W ratio matrix")) + fmter.table("", vector3);
            if (fileArr != null) {
                for (int i6 = 0; i6 < 2; i6++) {
                    ImportCSV.escapeAndwriteToFile(strArr2[i6], fileArr[i6]);
                }
            }
            String str12 = str11 + fmter.h3(this.mode == Mode.CMP_ALGOS ? "Comparison of algorithms" : "Comparison of rule sets");
            String str13 = this.mode == Mode.CMP_ALGOS ? "#Algo nickname" : "#Rule set name";
            Vector<String> vector6 = new Vector<>();
            ?? r0 = new String[size];
            if (this.mode == Mode.CMP_ALGOS || this.mode == Mode.CMP_RULES) {
                strArr = new String[]{str13, "Learned? (learned/not learned)", "EV score", "Runs", "Avg episodes till learned", "Avg errors till learned", "Avg moves till learned", "Avg error rate"};
                String[] strArr5 = new String[9];
                strArr5[0] = str13;
                strArr5[1] = "Learned";
                strArr5[2] = "Not learned";
                strArr5[3] = "EV score";
                strArr5[4] = "Runs";
                strArr5[5] = "Avg episodes till learned";
                strArr5[6] = "Avg errors till learned";
                strArr5[7] = "Avg moves till learned";
                strArr5[8] = "Avg error rate";
                r0[0] = strArr5;
            } else {
                boolean z = vector.size() > 0 && comparandumArr2[((Integer) vector.get(0)).intValue()].useMDagger;
                String str14 = z ? "m<sub>&dagger;</sub>" : "m<sub>*</sub>";
                strArr = new String[]{str13, "Learned/not learned", "EV score", "Avg " + str14 + " (learners/all)", "min-median-max " + str14 + fmter.brHtml() + "(learners)", "Harmonic mean " + str14 + fmter.brHtml() + "(learners/all)", "Avg error rate"};
                String str15 = z ? "m!" : "m*";
                String[] strArr6 = new String[12];
                strArr6[0] = str13;
                strArr6[1] = "Learned";
                strArr6[2] = "Not learned";
                strArr6[3] = "EV score";
                strArr6[4] = "Avg " + str15 + " on learners";
                strArr6[5] = "Avg " + str15 + " on all";
                strArr6[6] = "min " + str15;
                strArr6[7] = "med " + str15;
                strArr6[8] = "max " + str15;
                strArr6[9] = "Harmonic mean " + str15 + " on learners";
                strArr6[10] = "Harmonic mean " + str15 + " on all";
                strArr6[11] = "Avg error rate";
                r0[0] = strArr6;
            }
            String str16 = "";
            for (String str17 : strArr) {
                str16 = str16 + fmter.th(str17);
            }
            vector6.add(str16);
            for (int i7 = 0; i7 < vector.size(); i7++) {
                int intValue2 = ((Integer) vector.get(i7)).intValue();
                Comparandum comparandum = comparandumArr2[intValue2];
                String str18 = comparandum.key;
                double d = dArr[intValue2];
                boolean equals2 = str18.equals(str4);
                String[] strArr7 = new String[0];
                if (comparandum.mlc != null) {
                    MlcEntry[] mlcEntryArr = comparandum.mlc;
                    int length = mlcEntryArr.length;
                    double d2 = 0.0d;
                    double d3 = 0.0d;
                    double d4 = 0.0d;
                    for (MlcEntry mlcEntry : mlcEntryArr) {
                        d2 += mlcEntry.getErrorsUntilLearned();
                        d4 += mlcEntry.getEpisodesUntilLearned();
                        d3 += mlcEntry.getMovesUntilLearned();
                    }
                    double d5 = d2 / length;
                    double d6 = d4 / length;
                    double d7 = d3 / length;
                    String[] strArr8 = {"Learned (" + length + "/0)", fmter.sprintf("%6.4g", Double.valueOf(d)), "" + length, fmter.sprintf("%5.2f", Double.valueOf(d6)), fmter.sprintf("%6.2f", Double.valueOf(d5)), fmter.sprintf("%6.2f", Double.valueOf(d7)), ""};
                    String[] strArr9 = new String[9];
                    strArr9[0] = str18;
                    strArr9[1] = "" + length;
                    strArr9[2] = "0";
                    strArr9[3] = "" + d;
                    strArr9[4] = "" + length;
                    strArr9[5] = "" + d6;
                    strArr9[6] = "" + d5;
                    strArr9[7] = "" + d7;
                    strArr9[8] = "";
                    r0[i7 + 1] = strArr9;
                    if (equals2) {
                        eachStrong(fmter, strArr8);
                    }
                    vector6.add(fmter.rowTh(str18, "align='right'", strArr8));
                } else {
                    if (comparandum.humanSer == null) {
                        throw new IllegalArgumentException();
                    }
                    int i8 = 0;
                    double d8 = 0.0d;
                    double d9 = 0.0d;
                    int i9 = 0;
                    int i10 = 0;
                    double d10 = 0.0d;
                    double d11 = 0.0d;
                    for (MwByHuman.MwSeries mwSeries : comparandum.humanSer) {
                        double m = comparandum.getM(mwSeries);
                        if (mwSeries.getLearned()) {
                            i8++;
                            d9 += m;
                            d11 += 1.0d / m;
                        }
                        d8 += m;
                        d10 += 1.0d / m;
                        i9 += mwSeries.getTotalMoves();
                        i10 += mwSeries.getTotalErrors();
                    }
                    double[] dArr3 = new double[i8];
                    int i11 = 0;
                    for (MwByHuman.MwSeries mwSeries2 : comparandum.humanSer) {
                        if (mwSeries2.getLearned()) {
                            int i12 = i11;
                            i11++;
                            dArr3[i12] = comparandum.getM(mwSeries2);
                        }
                    }
                    double[] minMedMax = minMedMax(dArr3);
                    double length2 = comparandum.humanSer.length;
                    double d12 = d8 / length2;
                    double d13 = d9 / i8;
                    double d14 = length2 / d10;
                    double d15 = i8 / d11;
                    double d16 = i10 / i9;
                    String formatHumanKey = formatHumanKey(fmter, str18);
                    String[] strArr10 = new String[7];
                    strArr10[0] = formatHumanKey;
                    strArr10[1] = "" + i8 + "/" + (comparandum.humanSer.length - i8);
                    strArr10[2] = fmter.sprintf("%6.4g", Double.valueOf(d));
                    strArr10[3] = fmter.sprintf("%6.2f", Double.valueOf(d13)) + "/" + fmter.sprintf("%6.2f", Double.valueOf(d12));
                    strArr10[4] = dArr3.length > 0 ? ((int) minMedMax[0]) + " - " + fmter.sprintf("%6.1f", Double.valueOf(minMedMax[1])) + " - " + ((int) minMedMax[2]) : "";
                    strArr10[5] = fmter.sprintf("%6.2f", Double.valueOf(d15)) + "/" + fmter.sprintf("%6.2f", Double.valueOf(d14));
                    strArr10[6] = fmter.sprintf("%4.2f", Double.valueOf(d16));
                    int i13 = i7 + 1;
                    String[] strArr11 = new String[12];
                    strArr11[0] = formatHumanKey;
                    strArr11[1] = "" + i8;
                    strArr11[2] = "" + (comparandum.humanSer.length - i8);
                    strArr11[3] = "" + d;
                    strArr11[4] = "" + d13;
                    strArr11[5] = "" + d12;
                    strArr11[6] = dArr3.length > 0 ? "" + minMedMax[0] : "";
                    strArr11[7] = dArr3.length > 0 ? "" + minMedMax[1] : "";
                    strArr11[8] = dArr3.length > 0 ? "" + minMedMax[2] : "";
                    strArr11[9] = "" + d15;
                    strArr11[10] = "" + d14;
                    strArr11[11] = "" + d16;
                    r0[i13] = strArr11;
                    if (equals2) {
                        eachStrong(fmter, strArr10);
                    }
                    vector6.add(fmter.rowExtra("align='right'", strArr10));
                }
            }
            if (fileArr != null) {
                ImportCSV.escapeAndwriteToFile(r0, fileArr[2]);
            }
            vector.clear();
            Vector vector7 = new Vector();
            Vector vector8 = new Vector();
            for (Comparandum comparandum2 : comparandumArr3) {
                String str19 = comparandum2.key;
                boolean equals3 = str19.equals(str4);
                MlcEntry[] mlcEntryArr2 = comparandum2.mlc;
                int length3 = mlcEntryArr2.length;
                double d17 = 0.0d;
                double d18 = 0.0d;
                int i14 = 0;
                for (MlcEntry mlcEntry2 : mlcEntryArr2) {
                    d18 += r0.getTotalErrors();
                    d17 += r0.getTotalMoves();
                    if (mlcEntry2.getLearned()) {
                        i14++;
                    }
                }
                double d19 = d18 / d17;
                String str20 = i14 == 0 ? "Not learned" : "Sometimes learned";
                vector.add(Integer.valueOf(vector7.size()));
                String[] strArr12 = {str20 + " (" + i14 + "/" + (length3 - i14) + ")", "", "", "", "", "", fmter.sprintf("%4.3f", Double.valueOf(d19))};
                if (equals3) {
                    eachStrong(fmter, strArr12);
                }
                vector7.add(fmter.rowTh(str19, "align='right'", strArr12));
                vector8.add(Double.valueOf(d19));
            }
            vector.sort((num3, num4) -> {
                return (int) Math.signum(((Double) vector8.get(num3.intValue())).doubleValue() - ((Double) vector8.get(num4.intValue())).doubleValue());
            });
            Iterator it = vector.iterator();
            while (it.hasNext()) {
                vector6.add((String) vector7.get(((Integer) it.next()).intValue()));
            }
            str3 = str12 + fmter.table("border=\"1\"", vector6);
        } catch (Exception e) {
            str8 = "Error";
            str3 = fmter.para(e.toString()) + fmter.para(fmter.wrap("small", "Details:" + Util.stackToString(e)));
            System.err.println("" + e);
            e.printStackTrace(System.err);
        }
        return fmter.html(str8, str3);
    }

    /**
     * Computes the minimum, median and maximum of the given values.
     *
     * <p>Note: the argument array is sorted in place as a side effect.
     *
     * @param dArr the values; may be empty
     * @return a three-element array {@code {min, median, max}}; all three are {@code NaN} when
     *     the input is empty. For an even number of values the median is the mean of the two
     *     middle values.
     */
    static double[] minMedMax(double[] dArr) {
        final int n = dArr.length;
        if (n == 0) {
            return new double[]{Double.NaN, Double.NaN, Double.NaN};
        }
        Arrays.sort(dArr);
        final int mid = n / 2;
        final double median;
        if (n % 2 == 1) {
            median = dArr[mid];
        } else {
            median = (dArr[mid - 1] + dArr[mid]) / 2.0d;
        }
        final double smallest = dArr[0];
        final double largest = dArr[n - 1];
        return new double[]{smallest, median, largest};
    }

    /**
     * Prints the usage text without an additional message and terminates the JVM
     * (delegates to {@code usage(String)} with {@code null}).
     */
    private static void usage() {
        usage(null);
    }

    /**
     * Prints the command-line usage text to {@code System.err}, followed by an optional
     * message, and terminates the JVM with exit status 1. This method never returns normally.
     *
     * @param str additional message (typically the reason the arguments were rejected), or
     *     {@code null} to print only the usage text
     */
    private static void usage(String str) {
        System.err.println("Usage:");
        System.err.println(" MannWhitneyComparison -mode CMP_ALGOS -rule ruleName [-csvOut dir]");
        // The -algo value is matched against the entry's nickname, not a rule name.
        System.err.println(" MannWhitneyComparison -mode CMP_RULES -algo algoNickname [-csvOut dir]");
        if (str != null) {
            System.err.println(str + "\n");
        }
        System.exit(1);
    }

    /**
     * Command-line entry point: parses the options, loads the comparanda from the database,
     * and prints the comparison report to {@code System.out}.
     *
     * <p>Recognized options (each takes one value): {@code -mode}, {@code -algo}, {@code -rule},
     * {@code -csvOut}. The mode defaults to {@code CMP_RULES}. {@code CMP_RULES} requires
     * {@code -algo}; {@code CMP_ALGOS} requires {@code -rule}; other modes are rejected. Any
     * argument problem prints the usage text and exits with status 1.
     *
     * @param strArr the command-line arguments
     */
    public static void main(String[] strArr) {
        Mode mode = Mode.CMP_RULES;
        String algoNickname = null;
        String ruleSetName = null;
        File csvOutDir = null;
        for (int i = 0; i < strArr.length; i++) {
            final String option = strArr[i];
            final boolean known = option.equals("-mode") || option.equals("-algo")
                    || option.equals("-rule") || option.equals("-csvOut");
            if (!known) {
                usage((option.startsWith("-") ? "Unknown option: " : "Don't know what to do with the argument: ") + option);
                return;
            }
            if (i + 1 >= strArr.length) {
                // A known option given last, without its value, is not an "unknown option".
                usage("Option " + option + " requires a value");
                return;
            }
            final String value = strArr[++i];
            switch (option) {
                case "-mode":
                    try {
                        mode = Mode.valueOf(value.toUpperCase());
                    } catch (IllegalArgumentException e) {
                        usage("Unknown mode: " + value + " (expected one of " + Arrays.toString(Mode.values()) + ")");
                        return;
                    }
                    break;
                case "-algo":
                    algoNickname = value;
                    break;
                case "-rule":
                    ruleSetName = value;
                    break;
                default: // "-csvOut"
                    csvOutDir = new File(value);
                    break;
            }
        }
        if (mode == Mode.CMP_RULES) {
            if (algoNickname == null) {
                usage("In the mode " + mode + ", must supply -algo");
            }
        } else if (mode != Mode.CMP_ALGOS) {
            usage("Mode not supported: " + mode);
        } else if (ruleSetName == null) {
            usage("In the mode " + mode + ", must supply -rule");
        }
        MannWhitneyComparison comparison = new MannWhitneyComparison(mode);
        Comparandum[][] comparanda = comparison.mkMlcComparanda(algoNickname, ruleSetName);
        System.out.println(comparison.doCompare(algoNickname, ruleSetName, comparanda, plainFm, expandCsvOutDir(csvOutDir)));
    }

    /**
     * Replaces every element of the array, in place, with its {@code fmter.strong(...)} form.
     *
     * @param fmter formatter providing the emphasis markup
     * @param strArr cells to emphasize; modified in place
     */
    private static void eachStrong(Fmter fmter, String[] strArr) {
        Arrays.setAll(strArr, idx -> fmter.strong(strArr[idx]));
    }

    /**
     * Replaces every element of the vector, in place, with its {@code fmter.strong(...)} form.
     *
     * @param fmter formatter providing the emphasis markup
     * @param vector cells to emphasize; modified in place
     */
    private static void eachStrong(Fmter fmter, Vector<String> vector) {
        vector.replaceAll(cell -> fmter.strong(cell));
    }

    /**
     * Formats a colon-separated key so that its last component is emphasized with
     * {@code fmter.strong}; the other components are left unchanged.
     *
     * <p>As with {@code String.split}, trailing empty components are dropped before the last
     * component is chosen. A key consisting solely of colons has no components at all and is
     * returned unchanged instead of raising an {@code ArrayIndexOutOfBoundsException}.
     *
     * @param fmter formatter providing the emphasis markup
     * @param str the key, components separated by {@code ':'}
     * @return the key with its last component emphasized
     */
    private static String formatHumanKey(Fmter fmter, String str) {
        String[] parts = str.split(":");
        if (parts.length == 0) {
            // e.g. ":" or "::" — split() yields no components, so there is nothing to emphasize.
            return str;
        }
        int last = parts.length - 1;
        parts[last] = fmter.strong(parts[last]);
        return String.join(":", parts);
    }

    /**
     * Turns an output directory into the three CSV files written by the comparison report.
     *
     * <p>The directory (and any missing parents) is created via {@code mkdirs()}; the result of
     * that call is not checked.
     *
     * @param file the output directory, or {@code null} if no CSV output is wanted
     * @return {@code null} when {@code file} is {@code null}; otherwise the files
     *     {@code raw-wm.csv}, {@code ratio-wm.csv} and {@code ranking.csv} inside the directory,
     *     in that order
     */
    public static File[] expandCsvOutDir(File file) {
        if (file == null) {
            return null;
        }
        file.mkdirs();
        final String[] names = {"raw-wm.csv", "ratio-wm.csv", "ranking.csv"};
        final File[] targets = new File[names.length];
        for (int k = 0; k < names.length; k++) {
            targets[k] = new File(file, names[k]);
        }
        return targets;
    }
}
