package bi.kevin;

import java.util.ArrayList;
import java.util.Iterator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:bi/kevin/Main.class */
public class Main {

    /**
     * Runs the demo in three stages.
     * <ol>
     *   <li>Prints an example network definition generated from a fixed layer list.</li>
     *   <li>Prints two example solver configurations, one default and one customized.</li>
     *   <li>Trains a {@link LocalNetwork} on normalized data, printing progress and accuracy.</li>
     * </ol>
     *
     * <p>A failure in the training stage is rethrown with its original cause attached. As a
     * result, the JVM reports the error and exits with a non-zero status instead of returning
     * normally.
     *
     * @param strArr command-line arguments (unused)
     * @throws IllegalStateException if the local training example fails; the cause is the
     *     original exception
     */
    public static void main(String[] strArr) {
        // The same layer description drives both the generated network text and the local network.
        Layer[] layers = {
            new Layer(4, 0, "training_files", 1, 100),
            new Layer(4, 0, "testing_files", 2, 100),
            new Layer(2, 35),
            new Layer(2, 2)
        };

        printExampleNetwork(layers);
        printExampleSolvers();

        try {
            runLocalTrainingExample(layers);
        } catch (Exception e) {
            // Rethrow (cause preserved) so the failure surfaces as a non-zero exit status rather
            // than a printed stack trace followed by a normal return.
            throw new IllegalStateException("Local network training example failed", e);
        }
    }

    /** Prints the network text generated from {@code layers} (named "test") for both flag values. */
    private static void printExampleNetwork(Layer[] layers) {
        NetGenerator netGenerator = new NetGenerator(layers, "test");
        System.out.println("Example Neural Network:");
        // NOTE(review): the boolean presumably selects a train vs. test variant — confirm in NetGenerator.
        System.out.println(netGenerator.createNet(false));
        System.out.println(netGenerator.createNet(true));
    }

    /** Prints one solver built with default settings and one built with explicit settings. */
    private static void printExampleSolvers() {
        SolverGenerator defaultSolver = new SolverGenerator("train_net.prototxt", "test_net.prototxt");
        System.out.println("Example Solver with Defaults");
        System.out.println(defaultSolver.createSolver());

        // NOTE(review): the meaning of each numeric argument is defined by SolverGenerator (not visible here).
        SolverGenerator customSolver = new SolverGenerator(
                "train_net.prototxt", "test_net.prototxt", 50, 0.1f, 0.5f, 0.5f, 0.0f, 1.0E-5f, "test");
        System.out.println("Example Solver with Customization");
        System.out.println(customSolver.createSolver());
    }

    /**
     * Trains a {@link LocalNetwork} on the normalized data built from {@code Test.getResult()}.
     *
     * <p>Prints the iteration, test score and train score of every {@link ModelInfo} returned by
     * training, then prints the evaluation accuracy. Finally it constructs a second network from
     * the first network's JSON.
     *
     * @param layers layer description shared with the network-text example
     * @throws Exception whatever the project's data/network classes throw
     */
    private static void runLocalTrainingExample(Layer[] layers) throws Exception {
        Nd4j.ENFORCE_NUMERICAL_STABILITY = true;

        DataFormatter formatter = new DataFormatter(new int[] {3, 4}, Test.getResult());
        // NOTE(review): 0.65 looks like a train/test split fraction and 100 an iteration/epoch
        // count — confirm against LocalNetwork.
        LocalNetwork network = new LocalNetwork(formatter.getAllDataNormalized(), 0.65d, 100, layers, true);

        CustomListener listener =
                new CustomListener(network.getTrainData(), network.getTestData(), 10, new ArrayList<>());
        for (ModelInfo info : network.trainModel(listener)) {
            System.out.println(info.getIteration());
            System.out.println(info.getTestScore());
            System.out.println(info.getTrainScore());
        }
        System.out.println(network.getEval().accuracy());

        MultiLayerNetwork trainedModel = network.getTrainedModel();
        // Build a second network on the same data from the first network's JSON.
        LocalNetwork restored = new LocalNetwork(formatter.getAllDataNormalized(), 0.65d, network.toJson());

        // NOTE(review): the results below are discarded. These calls appear to exist only to
        // exercise the APIs, and they may have side effects (e.g. lazy computation), so confirm
        // before removing them.
        restored.getTrainedModel();
        trainedModel.getLayer(1).getParam("W");
        restored.toJson();
        network.getTrainImportance();
        network.getTestImportance();
    }
}
