package oliver.neuralnetwork;

import bi.kevin.Layer;
import java.io.File;
import java.io.Serializable;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import oliver.logic.impl.NeuralNetworkSettings;
import oliver.logic.impl.TrainedNeuralNetwork;
import oliver.map.Heatmap;
import oliver.statistics.BasicStats;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/**
 * Base class for builders that turn the data selected in a {@link NeuralNetworkSettings} into
 * normalized input/label matrices plus an EC2NN layer description, and then build and train a
 * network from them (see {@link #buildAndTrainNeuralNetwork()}).
 *
 * <p>Rows whose value is NaN in any selected numeric input or in the label column (as reported by
 * {@link Heatmap#getExtraColumnValues}) are excluded from every dataset.
 *
 * <p>No synchronization is performed; {@link #prepareDataset()} overwrites the public dataset
 * fields in place.
 */
public abstract class NeuralNetworkBuilder {
    /** Network configuration: selected inputs/label, hidden layers and network parameters. */
    protected final NeuralNetworkSettings settings;

    /** Excluded row indices, keyed by the column whose value was NaN in that row. */
    private final Map<String, List<Integer>> rowsToExcludeByCause;

    /** De-duplicated union of every excluded row index, in first-seen order. */
    protected final List<Integer> allRowsToExclude = new ArrayList<Integer>();

    /** Every usable row; one row per sample, one column per entry of {@link #inputColumnHeaders}. */
    public double[][] masterDataset = null;
    /** Usable rows not selected for training, same column layout as {@link #masterDataset}. */
    public double[][] testingDataset = null;
    /** Randomly selected training rows, same column layout as {@link #masterDataset}. */
    public double[][] trainingDataset = null;
    /** Labels for {@link #masterDataset}, one single-value row per sample. */
    public double[][] masterLabels = null;
    /** Labels for {@link #testingDataset}, one single-value row per sample. */
    public double[][] testingLabels = null;
    /** Labels for {@link #trainingDataset}, one single-value row per sample. */
    public double[][] trainingLabels = null;
    /** Input column headers, sorted; categorical inputs appear as {@code "input:value"}. */
    public String[] inputColumnHeaders = null;
    /** Label column headers (currently always exactly one after {@link #prepareDataset()}). */
    public String[] labelColumnHeaders = null;
    /** Output directory supplied by the caller; only stored by this base class. */
    protected final File outputDir;

    /** Well-known file names used by the neural-network tooling, with human-readable descriptions. */
    public enum NnFile {
        TrainingDataH5("train.h5", "Training data file (.h5)"),
        TestingDataH5("test.h5", "Testing data file (.h5)"),
        MasterDataH5("master.h5", "master data file (.h5)"),
        TrainingLoc("training_files", "File specifying the location of the training data h5 file"),
        TestingLoc("testing_files", "File specifying the location of the testing data h5 file"),
        MasterLoc("master_files", "File specifying the location of the master data h5 file"),
        TrainingModelPrototxt("train_net.prototxt", "Neural net model for training/testing"),
        PredictModelPrototxt("predict_net.prototxt", "Neural net model for analyze/predict"),
        SolverFile("solver", "solver file");

        /** File name on disk. */
        public final String filename;
        /** Human-readable description of the file's purpose. */
        public final String description;

        NnFile(String str, String str2) {
            this.filename = str;
            this.description = str2;
        }
    }

    /**
     * Creates a builder and computes which rows must be excluded because of NaN values.
     *
     * @param neuralNetworkSettings the network configuration; must select at least one input
     * @param file output directory, stored in {@link #outputDir}
     * @throws Exception if no input is selected
     */
    public NeuralNetworkBuilder(NeuralNetworkSettings neuralNetworkSettings, File file) throws Exception {
        this.settings = neuralNetworkSettings;
        this.outputDir = file;
        // Fail fast, before doing any heatmap work.
        if (neuralNetworkSettings.listSelectedInputs().isEmpty()) {
            throw new Exception("You must select at least one input");
        }
        List<String> columnsToCheck = new ArrayList<String>(neuralNetworkSettings.listSelectedInputs());
        columnsToCheck.add(neuralNetworkSettings.getSelectedLabel());
        this.rowsToExcludeByCause = initRowsToExclude(columnsToCheck, neuralNetworkSettings.getHeatmap());

        // LinkedHashSet de-duplicates in O(1) per row while keeping first-seen order.
        Set<Integer> uniqueRows = new LinkedHashSet<Integer>();
        for (List<Integer> rows : this.rowsToExcludeByCause.values()) {
            uniqueRows.addAll(rows);
        }
        this.allRowsToExclude.addAll(uniqueRows);
    }

    /**
     * Builds and trains the concrete network.
     *
     * @return the trained network
     * @throws Exception on any build or training failure
     */
    public abstract TrainedNeuralNetwork buildAndTrainNeuralNetwork() throws Exception;

    /**
     * (Re)computes the input headers, label headers, and the master/training/testing matrices.
     *
     * <p>Numeric inputs are z-score normalized; categorical inputs become one binary column per
     * distinct value. Excluded rows are dropped, then a random share of the remaining rows
     * (the {@code Percent_Training} network parameter) is assigned to training and the rest to
     * testing.
     *
     * @throws Exception if the selected label is categorical
     */
    public void prepareDataset() throws Exception {
        // Clear every output first so a failure below never leaves a mix of old and new data.
        this.masterDataset = null;
        this.trainingDataset = null;
        this.testingDataset = null;
        this.masterLabels = null;
        this.trainingLabels = null;
        this.testingLabels = null;
        this.inputColumnHeaders = null;
        this.labelColumnHeaders = null;

        // TreeMap keeps the input columns sorted by header, so the column order is deterministic.
        TreeMap<String, double[]> inputColumns = new TreeMap<String, double[]>();
        for (String input : this.settings.listSelectedInputs()) {
            inputColumns.putAll(getDataColumnsForInputLabel(input));
        }
        String selectedLabel = this.settings.getSelectedLabel();
        Map<String, double[]> labelColumns = getDataColumnsForInputLabel(selectedLabel);
        if (labelColumns.size() > 1) {
            throw new Exception(MessageFormat.format("Cannot use \"{0}\" for labels because it is categorical. Currently only scalar-valued labels are supported.", selectedLabel));
        }

        this.inputColumnHeaders = inputColumns.keySet().toArray(new String[0]);
        this.labelColumnHeaders = labelColumns.keySet().toArray(new String[0]);
        List<double[]> inputData = columnsInOrder(inputColumns, this.inputColumnHeaders);
        List<double[]> labelData = columnsInOrder(labelColumns, this.labelColumnHeaders);

        List<Integer> usableRows = listUsableRows(inputData.get(0).length);

        // The training share applies to the usable rows, not to the raw row count; otherwise
        // excluded rows would inflate the training set beyond the configured share.
        // NOTE(review): assumes Percent_Training is a fraction in [0, 1] -- confirm the scale.
        double trainingShare = this.settings.getNetworkParamValue(NeuralNetworkSettings.NetworkParam.Percent_Training);
        int trainingTarget = (int) (usableRows.size() * trainingShare);
        List<Integer> testingRows = new ArrayList<Integer>(usableRows);
        List<Integer> trainingRows = new ArrayList<Integer>();
        while (trainingRows.size() < trainingTarget && !testingRows.isEmpty()) {
            trainingRows.add(popRandom(testingRows));
        }

        this.masterDataset = buildDataSubset(inputData, usableRows);
        this.trainingDataset = buildDataSubset(inputData, trainingRows);
        this.testingDataset = buildDataSubset(inputData, testingRows);
        this.masterLabels = toSingleColumn(buildDataSubset(labelData, usableRows));
        this.trainingLabels = toSingleColumn(buildDataSubset(labelData, trainingRows));
        this.testingLabels = toSingleColumn(buildDataSubset(labelData, testingRows));
    }

    /** Returns the column arrays of {@code columns} in the order given by {@code headers}. */
    private static List<double[]> columnsInOrder(Map<String, double[]> columns, String[] headers) {
        List<double[]> ordered = new ArrayList<double[]>(headers.length);
        for (String header : headers) {
            ordered.add(columns.get(header));
        }
        return ordered;
    }

    /** Returns the row indices in {@code [0, rowCount)} that are not in {@link #allRowsToExclude}. */
    private List<Integer> listUsableRows(int rowCount) {
        // Hash lookup instead of List.contains avoids an O(rows x excluded) scan.
        Set<Integer> excluded = new HashSet<Integer>(this.allRowsToExclude);
        List<Integer> usable = new ArrayList<Integer>(rowCount);
        for (int row = 0; row < rowCount; row++) {
            if (!excluded.contains(row)) {
                usable.add(row);
            }
        }
        return usable;
    }

    /**
     * Converts one settings column into network columns.
     *
     * <p>An all-{@code Double} column yields a single normalized column keyed by {@code str}. Any
     * other column yields one 0/1 indicator column per distinct value, keyed {@code "str:value"}.
     */
    private Map<String, double[]> getDataColumnsForInputLabel(String str) {
        Map<String, double[]> columns = new LinkedHashMap<String, double[]>();
        Serializable[] inputValues = this.settings.getInputValues(str);
        if (isSeriesNumeric(inputValues)) {
            columns.put(str, toNormalized(toDoubleArr(inputValues)));
            return columns;
        }
        // LinkedHashSet keeps the distinct values in first-seen order without an O(n*k) scan.
        Set<Serializable> distinctValues = new LinkedHashSet<Serializable>();
        for (Serializable value : inputValues) {
            distinctValues.add(value);
        }
        for (Serializable value : distinctValues) {
            columns.put(str + ":" + value, toBinary(inputValues, value));
        }
        return columns;
    }

    /**
     * Builds the layer list used for analyze/predict: a data layer reading
     * {@link NnFile#MasterLoc}, the configured hidden layers, then an output layer sized to the
     * label columns. Requires {@link #prepareDataset()} to have run (uses
     * {@link #labelColumnHeaders}).
     */
    public Layer[] buildLayersForAnalyzeAndPredict() {
        int hiddenCount = this.settings.countHiddenLayers();
        Layer[] layers = new Layer[hiddenCount + 2];
        // NOTE(review): argument meanings (4, 0, file, 1) are defined by bi.kevin.Layer -- confirm.
        layers[0] = new Layer(4, 0, NnFile.MasterLoc.filename, 1);
        for (int h = 0; h < hiddenCount; h++) {
            layers[1 + h] = hiddenLayer(h);
        }
        layers[layers.length - 1] = new Layer(2, this.labelColumnHeaders.length);
        return layers;
    }

    /**
     * Builds the layer list used for training: training and testing data layers, the configured
     * hidden layers, then an output layer sized to the label columns. Requires
     * {@link #prepareDataset()} to have run (uses {@link #labelColumnHeaders}).
     */
    public Layer[] buildLayersForTraining() {
        int hiddenCount = this.settings.countHiddenLayers();
        Layer[] layers = new Layer[hiddenCount + 3];
        // NOTE(review): the training batch size is hard-coded to 100, while the testing batch
        // size comes from the Testing_Batch_Size network parameter.
        layers[0] = new Layer(4, 0, NnFile.TrainingLoc.filename, 1, 100);
        layers[1] = new Layer(4, 0, NnFile.TestingLoc.filename, 2, (int) this.settings.getNetworkParamValue(NeuralNetworkSettings.NetworkParam.Testing_Batch_Size));
        for (int h = 0; h < hiddenCount; h++) {
            layers[2 + h] = hiddenLayer(h);
        }
        layers[layers.length - 1] = new Layer(2, this.labelColumnHeaders.length);
        return layers;
    }

    /** Builds the configured hidden layer at {@code index} (neuron type and neuron count). */
    private Layer hiddenLayer(int index) {
        return new Layer(getEc2nnLayerTypeInteger(this.settings.getHiddenLayerNeuronType(index)), this.settings.getHiddenLayerNeuronCount(index));
    }

    /**
     * Gathers the selected rows of column-oriented data into a row-major matrix.
     *
     * @param list the data columns; every column must cover every index in {@code list2}
     * @param list2 the row indices to select, in output order
     * @return a {@code list2.size() x list.size()} matrix where cell {@code [r][c]} is
     *     {@code list.get(c)[list2.get(r)]}
     */
    public static double[][] buildDataSubset(List<double[]> list, List<Integer> list2) {
        int rowCount = list2.size();
        int columnCount = list.size();
        double[][] subset = new double[rowCount][columnCount];
        for (int c = 0; c < columnCount; c++) {
            double[] column = list.get(c);
            for (int r = 0; r < rowCount; r++) {
                subset[r][c] = column[list2.get(r)];
            }
        }
        return subset;
    }

    /** Maps a neuron type to its EC2NN layer type code; throws {@link Error} for unmapped types. */
    private int getEc2nnLayerTypeInteger(NeuralNetworkSettings.NeuronType neuronType) {
        switch (neuronType) {
            case Sigmoid:
                return 1;
            case TanH:
                return 2;
            case ReLU:
                return 3;
            default:
                throw new Error(MessageFormat.format("no EC2NN layer type integer specified for NeuronType \"{0}\"", neuronType.name()));
        }
    }

    /** Returns a deep copy of the excluded rows, keyed by the column that caused the exclusion. */
    public Map<String, List<Integer>> getRowsToExcludeByCause() {
        Map<String, List<Integer>> copy = new HashMap<String, List<Integer>>();
        for (Map.Entry<String, List<Integer>> entry : this.rowsToExcludeByCause.entrySet()) {
            copy.put(entry.getKey(), new ArrayList<Integer>(entry.getValue()));
        }
        return copy;
    }

    /** Returns a copy of the de-duplicated excluded row indices. */
    public List<Integer> getAllRowstoExclude() {
        return new ArrayList<Integer>(this.allRowsToExclude);
    }

    /** Returns true when every element is a {@link Double} (vacuously true for an empty array). */
    protected static boolean isSeriesNumeric(Serializable[] serializableArr) {
        for (Serializable value : serializableArr) {
            if (!(value instanceof Double)) {
                return false;
            }
        }
        return true;
    }

    /**
     * For each named heatmap column that is entirely {@link Double}, collects the row indices
     * holding NaN. Columns that are not numeric or that have no NaN are absent from the result.
     */
    public static Map<String, List<Integer>> initRowsToExclude(List<String> list, Heatmap heatmap) {
        Map<String, List<Integer>> rowsByColumn = new HashMap<String, List<Integer>>();
        for (String column : list) {
            Serializable[] values = heatmap.getExtraColumnValues(column);
            if (!isSeriesNumeric(values)) {
                continue;
            }
            for (int row = 0; row < values.length; row++) {
                if (Double.isNaN((Double) values[row])) {
                    List<Integer> rows = rowsByColumn.get(column);
                    if (rows == null) {
                        rows = new ArrayList<Integer>();
                        rowsByColumn.put(column, rows);
                    }
                    rows.add(row);
                }
            }
        }
        return rowsByColumn;
    }

    /** Removes and returns a uniformly random element of a non-empty list. */
    private Integer popRandom(List<Integer> list) {
        return list.remove((int) (Math.random() * list.size()));
    }

    /**
     * Collapses each row to {@code sum_c row[c] * (c + 1)}: the 1-based index of the hot column
     * for one-hot rows, or the value itself for a single column. An empty matrix yields an empty
     * {@code 0 x 1} result instead of failing.
     */
    private double[][] toSingleColumn(double[][] dArr) {
        double[][] collapsed = new double[dArr.length][1];
        for (int r = 0; r < dArr.length; r++) {
            double sum = 0.0d;
            for (int c = 0; c < dArr[r].length; c++) {
                sum += dArr[r][c] * (c + 1);
            }
            collapsed[r][0] = sum;
        }
        return collapsed;
    }

    /** Returns a 0/1 indicator array marking where {@code values[i]} equals {@code target} (null-safe). */
    private double[] toBinary(Serializable[] serializableArr, Serializable serializable) {
        double[] indicator = new double[serializableArr.length];
        for (int i = 0; i < serializableArr.length; i++) {
            Serializable value = serializableArr[i];
            boolean matches = (value == null) ? (serializable == null) : value.equals(serializable);
            if (matches) {
                indicator[i] = 1.0d;
            }
        }
        return indicator;
    }

    /**
     * Z-score normalizes the values using {@link BasicStats}. Values equal to the mean, and every
     * value of a zero-spread column, map to 0 so no division by a zero std can occur.
     *
     * <p>NOTE(review): rows later excluded for NaN are still part of this array; whether
     * BasicStats skips NaN when computing mean/std should be confirmed.
     */
    private double[] toNormalized(double[] dArr) {
        double mean = BasicStats.getMean(dArr);
        double std = BasicStats.getStd(dArr, mean);
        double[] normalized = new double[dArr.length];
        for (int i = 0; i < dArr.length; i++) {
            normalized[i] = (dArr[i] == mean || std == 0.0d) ? 0.0d : (dArr[i] - mean) / std;
        }
        return normalized;
    }

    /** Unboxes an all-{@link Double} array; callers must check {@link #isSeriesNumeric} first. */
    private double[] toDoubleArr(Serializable[] serializableArr) {
        double[] values = new double[serializableArr.length];
        for (int i = 0; i < serializableArr.length; i++) {
            values[i] = (Double) serializableArr[i];
        }
        return values;
    }
}
