package oliver.logic.impl;

import java.io.File;
import java.io.Serializable;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.EnumMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;
import javax.swing.JComponent;
import javax.swing.JTextField;
import oliver.logic.Logic;
import oliver.logic.LogicalBranchingInterface;
import oliver.map.Heatmap;
import oliver.neuralnetwork.RemoteNeuralNetworkBuilder;
import oliver.ui.components.LogSliderWithTextField;
import oliver.ui.components.SliderWithTextField;

/* loaded from: input_file:oliver/logic/impl/NeuralNetworkSettings.class */
/**
 * User-editable configuration for building and training a neural network over the columns of a
 * {@link Heatmap}: which columns are inputs, which column is the label, the hidden-layer layout,
 * numeric training parameters, and the remote host used for training.
 */
public class NeuralNetworkSettings extends Logic {

    /** Maximum number of row labels listed individually per column in the exclusion warning. */
    private static final int MAX_LISTED_ROW_LABELS = 5;

    private final Heatmap heatmap;
    /** Column currently chosen as the label; always a key of {@link #inputsSelected}. */
    private String labelSelected;
    /** Every possible input column (sorted by name) mapped to whether it is currently selected. */
    private final TreeMap<String, Boolean> inputsSelected = new TreeMap<>();
    private final List<HiddenLayerSettings> hiddenLayers = new ArrayList<>();
    private boolean useSimpleNetwork = true;
    private final Map<NetworkParam, Double> networkParamValues = new EnumMap<>(NetworkParam.class);
    private boolean useLocalForTraining = false;
    private String hostUrl = "ec2-54-152-208-18.compute-1.amazonaws.com";
    /** Key file required before training; {@code null} means "not authenticated". */
    private File hostKey = null;

    /**
     * Mutable settings for one hidden layer.
     *
     * <p>Kept as a non-static inner class so existing {@code outer.new HiddenLayerSettings()}
     * call sites continue to compile.
     */
    public class HiddenLayerSettings {
        /** Number of neurons in this layer. */
        public int neuronCount;
        /** Neuron (activation) type used by this layer. */
        public NeuronType neuronType;

        public HiddenLayerSettings() {
        }
    }

    /** Numeric training parameters, each with a default value and an associated UI control type. */
    public enum NetworkParam {
        Training_Batch_Size(100.0d, JTextField.class, true),
        Testing_Batch_Size(100.0d, JTextField.class, true),
        Snapshot(50.0d, JTextField.class, true),
        Training_Iterations(100.0d, JTextField.class, true),
        Test_Interval(50.0d, JTextField.class, true),
        Test_Iterations(50.0d, JTextField.class, true),
        Learning_Rate(0.01d, LogSliderWithTextField.class, false),
        Weight_Decay(0.005d, LogSliderWithTextField.class, false),
        Percent_Training(0.5d, SliderWithTextField.class, false),
        Gamma(0.01d, SliderWithTextField.class, false),
        Power(0.75d, SliderWithTextField.class, false),
        Momentum(0.9d, SliderWithTextField.class, false);

        /** Presumably whether the parameter is exposed when the simple network mode is in use. */
        public final boolean availableInSimpleMode;
        /** Value installed for this parameter by the {@link NeuralNetworkSettings} constructor. */
        private final double defaultValue;
        /** Swing component type, presumably used by the UI to edit this parameter. */
        public final Class<? extends JComponent> uiControlClass;

        NetworkParam(double defaultValue, Class<? extends JComponent> uiControlClass,
                boolean availableInSimpleMode) {
            this.defaultValue = defaultValue;
            this.uiControlClass = uiControlClass;
            this.availableInSimpleMode = availableInSimpleMode;
        }
    }

    /** Activation type of the neurons in a hidden layer. */
    public enum NeuronType {
        TanH,
        Sigmoid,
        ReLU
    }

    /** Kind of result the network produces; not referenced within this class. */
    public enum ResultType {
        Classification,
        Regression
    }

    /**
     * Creates settings over the given candidate input columns.
     *
     * <p>All inputs start deselected, the lexicographically smallest input name becomes the label,
     * and every {@link NetworkParam} is set to its default. No hidden layers are created.
     *
     * @param heatmap heatmap whose extra columns supply the input values
     * @param list names of the candidate input columns; duplicate names are collapsed
     * @throws Error if fewer than two distinct input names are given
     */
    public NeuralNetworkSettings(Heatmap heatmap, List<String> list) {
        for (String input : list) {
            this.inputsSelected.put(input, Boolean.FALSE);
        }
        // Validate the number of DISTINCT names: duplicates collapse into one map key, and a single
        // distinct input would leave nothing to train on once it becomes the label.
        // Error (rather than IllegalArgumentException) is kept for compatibility with callers.
        if (this.inputsSelected.size() < 2) {
            throw new Error("Cannot have less than two inputs");
        }
        this.heatmap = heatmap;
        this.labelSelected = this.inputsSelected.firstKey();
        for (NetworkParam networkParam : NetworkParam.values()) {
            this.networkParamValues.put(networkParam, networkParam.defaultValue);
        }
    }

    /** Returns the heatmap these settings were created for. */
    public Heatmap getHeatmap() {
        return this.heatmap;
    }

    /**
     * Builds and trains a network remotely with the current settings.
     *
     * <p>If some rows must be excluded, the user is asked to confirm before training proceeds.
     *
     * @param file file forwarded to {@link RemoteNeuralNetworkBuilder} (its role is defined there)
     * @return the trained network produced by the builder
     * @throws LogicalBranchingInterface.UserCanceledException if the user declines the exclusion
     * @throws Exception if no host key is set, if local training is requested (unsupported), or if
     *     the builder fails
     */
    public TrainedNeuralNetwork trainNeuralNetwork(File file) throws Exception {
        if (this.hostKey == null) {
            throw new Exception("Not authenticated. You must select a key file under the \"... Host\" tab");
        }
        if (this.useLocalForTraining) {
            throw new Exception("local network training not supported");
        }
        RemoteNeuralNetworkBuilder builder = new RemoteNeuralNetworkBuilder(this, file, this.hostKey);
        int excludedRowCount = builder.getAllRowstoExclude().size();
        if (excludedRowCount > 0) {
            String title = excludedRowCount + " row(s) will need to be excluded";
            String message = describeExcludedRows(title, builder.getRowsToExcludeByCause());
            if (!this.lbi.showConfirmDialog(message + "\nCreate neural network anyways?", title)) {
                throw new LogicalBranchingInterface.UserCanceledException();
            }
        }
        return builder.buildAndTrainNeuralNetwork();
    }

    /**
     * Builds the body of the row-exclusion confirmation message.
     *
     * @param header first line of the message (also used as the dialog title by the caller)
     * @param rowsByColumn column name mapped to the row indices excluded because of NaN values there
     * @return the human-readable message
     */
    private String describeExcludedRows(String header, Map<String, List<Integer>> rowsByColumn) {
        if (rowsByColumn.size() == 1) {
            String column = rowsByColumn.keySet().iterator().next();
            return header + MessageFormat.format(" because of NaN values under column \"{0}\"", column);
        }
        // Several causes (or, defensively, none): one line per column, plus the individual row
        // labels when only a handful of rows are affected for that column.
        StringBuilder message = new StringBuilder(header).append(':');
        for (Map.Entry<String, List<Integer>> entry : rowsByColumn.entrySet()) {
            List<Integer> rows = entry.getValue();
            message.append(MessageFormat.format(
                    "\n{0} row(s) must be excluded because of NaN values under column \"{1}\"",
                    Integer.valueOf(rows.size()), entry.getKey()));
            if (rows.size() <= MAX_LISTED_ROW_LABELS) {
                for (Integer row : rows) {
                    message.append(MessageFormat.format("\n       rowLabel \"{0}\"",
                            this.heatmap.getRowLabel(row.intValue(), true)));
                }
            }
        }
        return message.toString();
    }

    /**
     * Returns the neuron type of hidden layer {@code i}.
     *
     * @throws IndexOutOfBoundsException if {@code i} is not a valid layer index
     */
    public NeuronType getHiddenLayerNeuronType(int i) {
        return this.hiddenLayers.get(i).neuronType;
    }

    /**
     * Returns the neuron count of hidden layer {@code i}.
     *
     * @throws IndexOutOfBoundsException if {@code i} is not a valid layer index
     */
    public int getHiddenLayerNeuronCount(int i) {
        return this.hiddenLayers.get(i).neuronCount;
    }

    /**
     * Appends a hidden layer with default settings and returns it so the caller can customize it.
     */
    public HiddenLayerSettings addHiddenLayer() {
        HiddenLayerSettings layer = getDefaultHiddenLayerSettings();
        this.hiddenLayers.add(layer);
        return layer;
    }

    /**
     * Removes the given hidden layer; does nothing if it is not present.
     *
     * <p>The check runs before mutation, so a rejected removal leaves the layer list unchanged.
     *
     * @throws Error if the removal would leave no hidden layers (kept as Error for compatibility)
     */
    public void removeHiddenLayer(HiddenLayerSettings hiddenLayerSettings) {
        int remaining = this.hiddenLayers.size()
                - (this.hiddenLayers.contains(hiddenLayerSettings) ? 1 : 0);
        if (remaining == 0) {
            throw new Error("Must have at least one hidden layer");
        }
        this.hiddenLayers.remove(hiddenLayerSettings);
    }

    /** Creates a TanH layer whose neuron count equals the number of currently selected inputs. */
    private HiddenLayerSettings getDefaultHiddenLayerSettings() {
        HiddenLayerSettings layer = new HiddenLayerSettings();
        layer.neuronCount = countSelectedInputs();
        layer.neuronType = NeuronType.TanH;
        return layer;
    }

    /** Returns the number of hidden layers. */
    public int countHiddenLayers() {
        return this.hiddenLayers.size();
    }

    /**
     * Returns how many inputs are selected; this may include the label column if it was selected
     * explicitly via {@link #setInputSelected}.
     */
    public int countSelectedInputs() {
        return (int) this.inputsSelected.values().stream().filter(Boolean::booleanValue).count();
    }

    /** Returns a new, modifiable list of every possible input name in sorted order. */
    public List<String> listPossibleInputs() {
        return new ArrayList<>(this.inputsSelected.navigableKeySet());
    }

    /** Returns the names of the selected inputs in sorted order. */
    public List<String> listSelectedInputs() {
        return this.inputsSelected.entrySet().stream()
                .filter(entry -> entry.getValue())
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());
    }

    /** Returns the heatmap's extra-column values for the named input. */
    public Serializable[] getInputValues(String str) {
        return this.heatmap.getExtraColumnValues(str);
    }

    /** Selects every possible input except the current label, which is deselected. */
    public void selectAllPossibleInputs() {
        this.inputsSelected.replaceAll((input, ignored) -> !this.labelSelected.equals(input));
    }

    /**
     * Returns whether the named input is selected.
     *
     * @throws Error if {@code str} is not one of the possible inputs
     */
    public boolean isInputSelected(String str) {
        Boolean selected = this.inputsSelected.get(str);
        if (selected == null) {
            throw new Error(MessageFormat.format("input \"{0}\" is not one of the possible inputs", str));
        }
        return selected.booleanValue();
    }

    /**
     * Selects or deselects the named input.
     *
     * @throws Error if {@code str} is not one of the possible inputs
     */
    public void setInputSelected(String str, boolean z) {
        if (!this.inputsSelected.containsKey(str)) {
            throw new Error(MessageFormat.format("input \"{0}\" is not one of the possible inputs", str));
        }
        this.inputsSelected.put(str, Boolean.valueOf(z));
    }

    /** Returns the name of the column used as the label. */
    public String getSelectedLabel() {
        return this.labelSelected;
    }

    /**
     * Sets the label column. Input selection is not changed, so the new label may also be
     * selected as an input.
     *
     * @throws Error if {@code str} is not one of the possible inputs
     */
    public void setSelectedLabel(String str) {
        if (!this.inputsSelected.containsKey(str)) {
            throw new Error(MessageFormat.format("\"{0}\" is not a valid label because it is not one of the possible inputs", str));
        }
        this.labelSelected = str;
    }

    /** Sets whether the simple network mode is used (defaults to {@code true}). */
    public void setUseSimpleNetwork(boolean z) {
        this.useSimpleNetwork = z;
    }

    /** Returns whether the simple network mode is used. */
    public boolean getUseSimpleNetwork() {
        return this.useSimpleNetwork;
    }

    /**
     * Sets whether training should run locally; {@link #trainNeuralNetwork} currently rejects
     * local training.
     */
    public void setUseLocalForTraining(boolean z) {
        this.useLocalForTraining = z;
    }

    /** Returns whether local training is requested. */
    public boolean getUseLocalForTraining() {
        return this.useLocalForTraining;
    }

    /** Overrides the value of a training parameter. */
    public void setNetworkParamValue(NetworkParam networkParam, double d) {
        this.networkParamValues.put(networkParam, d);
    }

    /** Returns the current value of a training parameter (its default unless overridden). */
    public double getNetworkParamValue(NetworkParam networkParam) {
        return this.networkParamValues.get(networkParam).doubleValue();
    }

    /** Sets the remote host name (defaults to a fixed EC2 host name). */
    public void setHostUrl(String str) {
        this.hostUrl = str;
    }

    /** Sets the key file required before {@link #trainNeuralNetwork} will proceed. */
    public void setHostKey(File file) {
        this.hostKey = file;
    }

    /** Returns the remote host name. */
    public String getHostUrl() {
        return this.hostUrl;
    }

    /** Returns the key file, or {@code null} if none has been set. */
    public File getHostKey() {
        return this.hostKey;
    }
}
