package bi.kevin;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonArray;
import com.google.gson.JsonDeserializationContext;
import com.google.gson.JsonDeserializer;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParseException;
import com.google.gson.JsonParser;
import com.google.gson.JsonSerializationContext;
import com.google.gson.JsonSerializer;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Iterator;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.api.buffer.DoubleBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cpu.nativecpu.NDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 * Wraps a DL4J {@link MultiLayerNetwork} together with a train/test split of a {@link DataSet}.
 *
 * <p>The network is either built from a project-level {@link Layer} description (dense hidden
 * layers plus an XENT- or MSE-loss output layer) or restored from the JSON produced by
 * {@link #toJson()}. It also computes per-example "importance" arrays by propagating the
 * difference between an example's activations and the activations of an all-zeros reference input
 * back through the layer weights.
 */
public class LocalNetwork {
    /** Layer type code mapped to the sigmoid activation. */
    private static final int SIGMOID_LAYER_TYPE = 0;
    /** Layer type code mapped to the tanh activation. */
    private static final int TANH_LAYER_TYPE = 1;
    /** Layer type code mapped to the relu activation. */
    private static final int RELU_LAYER_TYPE = 2;
    /**
     * Layer type code whose width is replaced by {@code numInputs} when wiring layers.
     * NOTE(review): presumably the input layer; confirm against {@link Layer}.
     */
    private static final int INPUT_LAYER_TYPE = 4;
    /** Phase code of layers that are skipped when building hidden layers (meaning defined by {@link Layer}). */
    private static final int SKIPPED_PHASE = 2;

    private DataSet testData;
    private DataSet trainData;
    /** Only set by the layer-description constructor. */
    private int iterations;
    private int numInputs;
    private int numOutputs;
    /** Only set by the layer-description constructor. */
    private Layer[] layerArray;
    /** Only set by the layer-description constructor; selects XENT (true) or MSE (false) loss. */
    private boolean isClassification;
    private MultiLayerNetwork trainedModel;

    /**
     * Serializable snapshot of a network: its layer-wise configuration plus the flat parameter
     * vector. Static because it never uses the enclosing {@code LocalNetwork}.
     */
    public static class ModelState {
        private final MultiLayerConfiguration conf;
        private final double[] params;

        ModelState(MultiLayerConfiguration multiLayerConfiguration, double[] dArr) {
            this.conf = multiLayerConfiguration;
            this.params = dArr;
        }

        /** @return the network configuration */
        public MultiLayerConfiguration getConf() {
            return this.conf;
        }

        /** @return the flat parameter vector (not copied) */
        public double[] getParams() {
            return this.params;
        }
    }

    /** Reads {@code {"conf": {...}, "params": [...]}} back into a {@link ModelState}. */
    private static class ModelStateDeserializer implements JsonDeserializer<ModelState> {
        @Override
        public ModelState deserialize(JsonElement jsonElement, Type type,
                JsonDeserializationContext jsonDeserializationContext) throws JsonParseException {
            JsonObject root = jsonElement.getAsJsonObject();
            if (!root.has("conf") || !root.has("params")) {
                throw new JsonParseException("ModelState JSON must contain \"conf\" and \"params\"");
            }
            JsonObject conf = root.get("conf").getAsJsonObject();
            JsonArray rawParams = root.get("params").getAsJsonArray();
            double[] params = new double[rawParams.size()];
            for (int i = 0; i < params.length; i++) {
                params[i] = rawParams.get(i).getAsDouble();
            }
            return new ModelState(MultiLayerConfiguration.fromJson(conf.toString()), params);
        }
    }

    /** Writes a {@link ModelState} as {@code {"conf": {...}, "params": [...]}}. */
    private static class ModelStateSerializer implements JsonSerializer<ModelState> {
        @Override
        public JsonElement serialize(ModelState modelState, Type type,
                JsonSerializationContext jsonSerializationContext) {
            JsonObject jsonObject = new JsonObject();
            jsonObject.add("conf", new JsonParser().parse(modelState.getConf().toJson()).getAsJsonObject());
            // Go through the context (i.e. the caller's Gson) rather than a fresh default Gson:
            // a default Gson throws on NaN/Infinity, ignoring serializeSpecialFloatingPointValues().
            jsonObject.add("params", jsonSerializationContext.serialize(modelState.getParams()));
            return jsonObject;
        }
    }

    /**
     * Builds a new, untrained network from a layer description.
     *
     * @param dataSet  data to split into training and test sets
     * @param d        value passed to {@link DataSet#splitTestAndTrain(double)}
     * @param i        number of optimisation iterations per fit
     * @param layerArr layer description: index 0 feeds the first hidden layer, the last entry is
     *                 the output layer; at least two entries are required
     * @param z        {@code true} for XENT (classification) loss, {@code false} for MSE
     * @throws Exception if the layer description is invalid or the network cannot be configured
     */
    public LocalNetwork(DataSet dataSet, double d, int i, Layer[] layerArr, boolean z) throws Exception {
        SplitTestAndTrain split = dataSet.splitTestAndTrain(d);
        this.testData = split.getTest();
        this.trainData = split.getTrain();
        this.iterations = i;
        this.layerArray = layerArr;
        this.numInputs = dataSet.numInputs();
        this.numOutputs = dataSet.numOutcomes();
        this.isClassification = z;
        this.trainedModel = new MultiLayerNetwork(buildNet());
    }

    /**
     * Restores a network previously serialized with {@link #toJson()}.
     *
     * @param dataSet data to split into training and test sets
     * @param d       value passed to {@link DataSet#splitTestAndTrain(double)}
     * @param str     JSON produced by {@link #toJson()}
     */
    public LocalNetwork(DataSet dataSet, double d, String str) {
        SplitTestAndTrain split = dataSet.splitTestAndTrain(d);
        this.testData = split.getTest();
        this.trainData = split.getTrain();
        ModelState modelState = new GsonBuilder()
                .serializeSpecialFloatingPointValues()
                .registerTypeAdapter(ModelState.class, new ModelStateDeserializer())
                .create()
                .fromJson(str, ModelState.class);
        this.trainedModel = new MultiLayerNetwork(modelState.getConf().toJson(),
                new NDArray(new DoubleBuffer(modelState.getParams())));
        // Recover input/output widths from the first and last weight matrices.
        this.numInputs = this.trainedModel.getLayer(0).getParam("W").rows();
        this.numOutputs = this.trainedModel.getLayer(this.trainedModel.getLayers().length - 1)
                .getParam("W").columns();
    }

    /** Maps a layer type code to its DL4J activation name. */
    private String getActivation(Layer layer) {
        switch (layer.getLayerType()) {
            case SIGMOID_LAYER_TYPE:
                return "sigmoid";
            case TANH_LAYER_TYPE:
                return "tanh";
            case RELU_LAYER_TYPE:
                return "relu";
            default:
                throw new IllegalArgumentException("Tried to get activation of a layer without an activation.");
        }
    }

    /** Number of units a layer feeds forward; input-type layers contribute {@code numInputs}. */
    private int widthOf(Layer layer) {
        return layer.getLayerType() == INPUT_LAYER_TYPE ? this.numInputs : layer.getNeurons();
    }

    /** Translates {@link #layerArray} into a DL4J configuration (dense hidden layers + output layer). */
    private MultiLayerConfiguration buildNet() throws Exception {
        if (this.layerArray == null || this.layerArray.length < 2) {
            throw new IllegalArgumentException("Expected at least an input and an output layer");
        }
        NeuralNetConfiguration.ListBuilder list = new NeuralNetConfiguration.Builder()
                .iterations(this.iterations)
                .weightInit(WeightInit.XAVIER)
                .regularization(true)
                .l2(1.0E-4d)
                .updater(Updater.ADADELTA)
                .rho(0.95d)
                .list();

        int layerIndex = 0;
        for (int k = 1; k < this.layerArray.length - 1; k++) {
            Layer previous = this.layerArray[k - 1];
            Layer current = this.layerArray[k];
            if (current.getPhase() == SKIPPED_PHASE || current.getLayerType() == INPUT_LAYER_TYPE) {
                continue;
            }
            list.layer(layerIndex++, new DenseLayer.Builder()
                    .nIn(widthOf(previous))
                    .nOut(current.getNeurons())
                    .activation(getActivation(current))
                    .build());
        }

        Layer outputLayer = this.layerArray[this.layerArray.length - 1];
        LossFunctions.LossFunction loss = this.isClassification
                ? LossFunctions.LossFunction.XENT
                : LossFunctions.LossFunction.MSE;
        // Same width rule as the hidden layers, so a net with no hidden layers is wired to numInputs.
        list.layer(layerIndex, new OutputLayer.Builder(loss)
                .activation(getActivation(outputLayer))
                .nIn(widthOf(this.layerArray[this.layerArray.length - 2]))
                .nOut(this.numOutputs)
                .build());
        return list.pretrain(false).backprop(true).build();
    }

    /**
     * Initialises the network and fits it on the training split.
     *
     * @param iterationListenerArr listeners attached to the network before fitting
     * @return an empty list. NOTE(review): nothing populates this list; confirm whether callers
     *         expect per-iteration {@link ModelInfo} entries
     * @throws Exception propagated from DL4J
     */
    public ArrayList<ModelInfo> trainModel(IterationListener... iterationListenerArr) throws Exception {
        ArrayList<ModelInfo> modelInfos = new ArrayList<>();
        this.trainedModel.init();
        this.trainedModel.setListeners(iterationListenerArr);
        this.trainedModel.fit(this.trainData);
        return modelInfos;
    }

    /** Evaluates the network on the test split. */
    public Evaluation getEval() {
        Evaluation evaluation = new Evaluation(this.numOutputs);
        evaluation.eval(this.testData.getLabels(), this.trainedModel.output(this.testData.getFeatureMatrix()));
        return evaluation;
    }

    public MultiLayerNetwork getTrainedModel() {
        return this.trainedModel;
    }

    public DataSet getTestData() {
        return this.testData;
    }

    public DataSet getTrainData() {
        return this.trainData;
    }

    /** Importance array for every example of the training split. */
    public ArrayList<INDArray> getTrainImportance() {
        return computeImportance(this.trainData);
    }

    /** Importance array for every example of the test split. */
    public ArrayList<INDArray> getTestImportance() {
        return computeImportance(this.testData);
    }

    /**
     * Importance array for every example of {@code dataSet}.
     *
     * @param dataSet examples to explain; its own size determines the number of results
     */
    public ArrayList<INDArray> getImportance(DataSet dataSet) {
        return computeImportance(dataSet);
    }

    /** Shared loop behind the public importance methods; iterates {@code dataSet}'s own examples. */
    private ArrayList<INDArray> computeImportance(DataSet dataSet) {
        ArrayList<INDArray> networkWeights = getNetworkWeights();
        int count = dataSet.numExamples();
        ArrayList<INDArray> importances = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            importances.add(getDatumImportance(networkWeights, dataSet.get(i).getFeatures()));
        }
        return importances;
    }

    /**
     * Propagates (activation - reference activation) differences backwards through the weight
     * matrices, starting from the output layer.
     */
    private INDArray getDatumImportance(ArrayList<INDArray> arrayList, INDArray iNDArray) {
        ArrayList<INDArray> activations = collectActivations(iNDArray);
        ArrayList<INDArray> reference = getRefActivations();
        ArrayList<INDArray> deltas = new ArrayList<>(reference.size());
        for (int i = 0; i < reference.size(); i++) {
            deltas.add(activations.get(i).sub(reference.get(i)));
        }
        INDArray lastWeights = arrayList.get(arrayList.size() - 1);
        INDArray importance = deltas.get(deltas.size() - 1).repeat(0, new int[]{lastWeights.rows()});
        importance.muli(lastWeights);
        for (int k = deltas.size() - 2; k > 0; k--) {
            importance.muli(deltas.get(k).repeat(1, new int[]{importance.columns()}));
            importance = arrayList.get(k - 1).mmul(importance);
        }
        importance.muli(deltas.get(0).repeat(1, new int[]{importance.columns()}));
        return importance;
    }

    /**
     * Runs a forward pass and returns each layer's input followed by the network output.
     * Layer inputs are read after {@code output(...)} has fed {@code input} through the model.
     */
    private ArrayList<INDArray> collectActivations(INDArray input) {
        INDArray output = this.trainedModel.output(input);
        ArrayList<INDArray> activations = new ArrayList<>();
        for (org.deeplearning4j.nn.api.Layer layer : this.trainedModel.getLayers()) {
            activations.add(((BaseLayer) layer).getInput());
        }
        activations.add(output);
        return activations;
    }

    /** Weight matrix ("W") of every layer, in layer order. */
    private ArrayList<INDArray> getNetworkWeights() {
        ArrayList<INDArray> weights = new ArrayList<>();
        for (org.deeplearning4j.nn.api.Layer layer : this.trainedModel.getLayers()) {
            weights.add(layer.getParam("W"));
        }
        return weights;
    }

    /** Activations for the all-zeros reference input (Java zero-initialises the array). */
    private ArrayList<INDArray> getRefActivations() {
        return collectActivations(new NDArray(new DoubleBuffer(new double[this.numInputs])));
    }

    /** Serializes the configuration and current parameters; readable by the JSON constructor. */
    public String toJson() {
        ModelState state = new ModelState(this.trainedModel.getLayerWiseConfigurations(),
                this.trainedModel.params().data().asDouble());
        return new GsonBuilder()
                .serializeSpecialFloatingPointValues()
                .registerTypeAdapter(ModelState.class, new ModelStateSerializer())
                .create()
                .toJson(state);
    }
}
