package oliver.ui.logicdialog;

import ij.measure.CurveFitter;
import java.awt.Color;
import java.awt.image.BufferedImage;
import java.text.MessageFormat;
import java.util.Map;
import javax.swing.BoxLayout;
import javax.swing.ImageIcon;
import javax.swing.JButton;
import javax.swing.JLabel;
import javax.swing.JMenu;
import javax.swing.JMenuBar;
import javax.swing.JMenuItem;
import javax.swing.JOptionPane;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.JTabbedPane;
import oliver.logic.impl.TrainedNeuralNetwork;
import oliver.map.Heatmap;
import oliver.neuralnetwork.TrainedNetworkFromArchive;
import oliver.statistics.BasicStats;
import oliver.ui.components.ServerHostWithKeySettingsPanel;
import oliver.ui.components.TryWithErrorDialog;
import oliver.ui.mapeditor.HeatmapEditorUi;
import oliver.ui.workspace.HmWorkspace;
import org.fife.ui.rsyntaxtextarea.RSyntaxTextArea;
import org.math.plot.Plot2DPanel;

/* loaded from: input_file:oliver/ui/logicdialog/TrainedNeuralNetworkDialogUi.class */
/**
 * Dialog presenting the results of a {@link TrainedNeuralNetwork}.
 *
 * <p>The dialog contains a tabbed pane with, in order:
 * <ul>
 *   <li>the network's input column headers;</li>
 *   <li>one tab per result image;</li>
 *   <li>one line-plot tab per entry of {@link #LINE_PLOT_OUTPUTS};</li>
 *   <li>one tab per text result;</li>
 *   <li>a "Prediction Host" settings tab.</li>
 * </ul>
 *
 * <p>Below the tabs, a button writes the network's predictions into the given heatmap as an extra
 * column. When the network is a {@link TrainedNetworkFromArchive}, the host settings are pre-filled
 * from the archive and a File menu offering export of the trained network is added.
 */
public class TrainedNeuralNetworkDialogUi extends LogicDialog<TrainedNeuralNetwork> {
    /** Series outputs rendered as line plots, in the order their tabs are added. */
    private static final String[] LINE_PLOT_OUTPUTS = {"train_loss.out", "test_loss.out", "rsquared.out"};

    /** Name of the extra heatmap column that predictions are written to. */
    private static final String PREDICTION_COLUMN = "Prediction";

    /** Initial width and height of the dialog, in pixels. */
    private static final int DIALOG_SIZE = 500;

    private final ServerHostWithKeySettingsPanel jpHostSettings;

    /**
     * Builds the results dialog for a trained network.
     *
     * @param trainedNeuralNetwork the network whose results are shown; {@code parseResults()} is
     *     called on it before any tabs are built
     * @param heatmap the heatmap that the "Predict labels" button writes predictions into
     * @param heatmapEditorUi the parent heatmap editor
     * @param hmWorkspace the owning workspace
     * @throws Exception if parsing the network's results fails
     */
    public TrainedNeuralNetworkDialogUi(TrainedNeuralNetwork trainedNeuralNetwork, Heatmap heatmap, HeatmapEditorUi heatmapEditorUi, HmWorkspace hmWorkspace) throws Exception {
        super(trainedNeuralNetwork, "Trained Neural Network Results", heatmapEditorUi, hmWorkspace);
        trainedNeuralNetwork.parseResults();
        Map<String, BufferedImage> images = trainedNeuralNetwork.getImages();
        Map<String, double[][]> series = trainedNeuralNetwork.getSeries();

        JTabbedPane jTabbedPane = new JTabbedPane();
        jTabbedPane.addTab("Inputs", new JScrollPane(new RSyntaxTextArea(buildColumnHeadersText())));
        addImageTabs(jTabbedPane, images);
        addSeriesTabs(jTabbedPane, series);
        addTextTabs(jTabbedPane, trainedNeuralNetwork.getText());

        this.jpHostSettings = createHostSettingsPanel(trainedNeuralNetwork);
        jTabbedPane.addTab("Prediction Host", new JScrollPane(this.jpHostSettings));

        JButton jButton = new JButton("Predict labels for " + heatmap.getTitle());
        jButton.addActionListener(actionEvent -> tryAddPredictionsToHeatmap(heatmap));

        JPanel jPanel = new JPanel();
        jPanel.setLayout(new BoxLayout(jPanel, BoxLayout.Y_AXIS));
        jPanel.add(jTabbedPane);
        jPanel.add(jButton);
        setContentPane(jPanel);
        setSize(DIALOG_SIZE, DIALOG_SIZE);

        if (trainedNeuralNetwork instanceof TrainedNetworkFromArchive) {
            setJMenuBar(buildFileMenuBar((TrainedNetworkFromArchive) trainedNeuralNetwork));
        }
    }

    /** Adds one scrollable image tab per entry; an entry that fails is reported and skipped. */
    private static void addImageTabs(JTabbedPane tabs, Map<String, BufferedImage> images) {
        for (Map.Entry<String, BufferedImage> entry : images.entrySet()) {
            try {
                tabs.addTab("(image) " + entry.getKey(), new JScrollPane(new JLabel(new ImageIcon(entry.getValue()))));
            } catch (Exception e) {
                new Exception(MessageFormat.format("Could not create image panel for output \"{0}\"", entry.getKey()), e).printStackTrace();
            }
        }
    }

    /** Adds a line-plot tab for each of {@link #LINE_PLOT_OUTPUTS}; missing/invalid series are reported and skipped. */
    private void addSeriesTabs(JTabbedPane tabs, Map<String, double[][]> series) {
        for (String output : LINE_PLOT_OUTPUTS) {
            try {
                tabs.addTab("(series) " + output, buildLinePlotPanel(series.get(output)));
            } catch (Exception e) {
                new Exception(MessageFormat.format("Could not create lineplot panel for output file \"{0}\"", output), e).printStackTrace();
            }
        }
    }

    /** Adds one read-only-style text tab per entry; an entry that fails is reported and skipped. */
    private static void addTextTabs(JTabbedPane tabs, Map<String, String> text) {
        for (Map.Entry<String, String> entry : text.entrySet()) {
            try {
                tabs.addTab("(text) " + entry.getKey(), new JScrollPane(new RSyntaxTextArea(entry.getValue())));
            } catch (Exception e) {
                new Exception(MessageFormat.format("Could not create text panel for output \"{0}\"", entry.getKey()), e).printStackTrace();
            }
        }
    }

    /** Creates the host settings panel, pre-filled from the archive when the network came from one. */
    private static ServerHostWithKeySettingsPanel createHostSettingsPanel(TrainedNeuralNetwork network) {
        if (network instanceof TrainedNetworkFromArchive) {
            TrainedNetworkFromArchive archive = (TrainedNetworkFromArchive) network;
            return new ServerHostWithKeySettingsPanel(archive.getServerUrl(), archive.getServerKey());
        }
        return new ServerHostWithKeySettingsPanel();
    }

    /** Builds a menu bar with a File menu whose single item exports the archived network. */
    private JMenuBar buildFileMenuBar(TrainedNetworkFromArchive archive) {
        JMenuItem exportItem = new JMenuItem("Export trained network to file...");
        exportItem.addActionListener(actionEvent -> {
            TryWithErrorDialog.tryWithErrorDialog(() -> {
                archive.saveResultsToFile();
            }, this, "Error saving trained neural network");
        });
        JMenu fileMenu = new JMenu("File");
        fileMenu.add(exportItem);
        JMenuBar menuBar = new JMenuBar();
        menuBar.add(fileMenu);
        return menuBar;
    }

    /**
     * Writes the network's predictions into {@code heatmap} as the {@link #PREDICTION_COLUMN}
     * extra column and shows it in the parent editor. For archived networks, the server URL and
     * key from the "Prediction Host" tab are applied first. Errors are shown in an error dialog.
     */
    private void tryAddPredictionsToHeatmap(Heatmap heatmap) {
        TryWithErrorDialog.tryWithErrorDialog(() -> {
            if (this.logic instanceof TrainedNetworkFromArchive) {
                TrainedNetworkFromArchive trainedNetworkFromArchive = (TrainedNetworkFromArchive) this.logic;
                trainedNetworkFromArchive.setServerUrl(this.jpHostSettings.getServerUrl());
                trainedNetworkFromArchive.setServerKey(this.jpHostSettings.getKeyFile());
            }
            ((TrainedNeuralNetwork) this.logic).addPredictionsToHeatmap(heatmap, PREDICTION_COLUMN);
            this.hmeParent.addExtraColumnDisplayPanel(PREDICTION_COLUMN);
            JOptionPane.showMessageDialog(this,
                    MessageFormat.format("Added extra column \"{0}\" to heatmap \"{1}\"", PREDICTION_COLUMN, heatmap.getTitle()),
                    "Prediction Successful", JOptionPane.INFORMATION_MESSAGE);
        }, this, "Error doing prediction");
    }

    /**
     * Lists the network's input column headers, one per line, as {@code "<index> - <header>"}.
     * Returns an empty string when the network reports no headers (null array).
     */
    private String buildColumnHeadersText() {
        String[] inputColumnHeaders = ((TrainedNeuralNetwork) this.logic).getInputColumnHeaders();
        if (inputColumnHeaders == null) {
            return "";
        }
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < inputColumnHeaders.length; i++) {
            sb.append(i).append(" - ").append(inputColumnHeaders[i]).append("\n");
        }
        return sb.toString();
    }

    /**
     * Extracts the only series from a series output.
     *
     * @throws IllegalArgumentException if {@code dArr} is null (e.g. the output is missing) or
     *     does not contain exactly one series
     */
    private double[] getSingleSeries(double[][] dArr) {
        if (dArr == null) {
            throw new IllegalArgumentException("series was null");
        }
        if (dArr.length != 1) {
            throw new IllegalArgumentException("Output contains " + dArr.length + " series, expecting 1");
        }
        return dArr[0];
    }

    /**
     * Builds a black line plot of a single series against its sample index (0, 1, 2, ...).
     *
     * @throws Exception if the output does not hold exactly one series or the plot cannot be built
     */
    private JPanel buildLinePlotPanel(double[][] dArr) throws Exception {
        double[] values = getSingleSeries(dArr);
        double[] indices = new double[values.length];
        for (int i = 0; i < indices.length; i++) {
            indices[i] = i;
        }
        // Row 0 = x (index), row 1 = y (value). Presumably getWithoutNaNColumns drops every
        // column (point) containing a NaN in either row — TODO confirm against BasicStats.
        double[][] withoutNaNColumns = BasicStats.getWithoutNaNColumns(new double[][]{indices, values});
        Plot2DPanel plot2DPanel = new Plot2DPanel();
        plot2DPanel.setAxisLabels("", "");
        plot2DPanel.setAutoBounds();
        plot2DPanel.addLinePlot("", Color.BLACK, withoutNaNColumns[0], withoutNaNColumns[1]);
        return plot2DPanel;
    }
}
