package bi.kevin;

import bi.kevin.Caffe;
import com.google.protobuf.TextFormat;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:bi/kevin/NetGenerator.class */
public class NetGenerator {
    private Layer[] netLayers;
    private String userDir;

    public NetGenerator(Layer[] layerArr, String str) {
        this.netLayers = new Layer[0];
        this.userDir = "";
        this.netLayers = layerArr;
        this.userDir = str;
    }

    public String createNet(Boolean bool) {
        Layer[] layerArr = this.netLayers;
        ArrayList arrayList = new ArrayList();
        int i = 0;
        for (Layer layer : layerArr) {
            if (layer.getLayerType() == 4) {
                i++;
            }
        }
        if (i == 0) {
            System.err.println("Requires at least one data layer");
            return "ERROR";
        }
        if (i > 2) {
            System.err.println("Can't have more than two data layers");
            return "ERROR";
        }
        if (i == 2 && (layerArr[0].getPhase() == 0 || layerArr[1].getPhase() == 0)) {
            System.err.println("Phases must be defined for two data layers");
            return "ERROR";
        }
        for (int i2 = 0; i2 < i; i2++) {
            arrayList.add(dataLayerBuilder(layerArr[i2], bool));
        }
        for (int i3 = i; i3 < layerArr.length - 1; i3++) {
            Caffe.LayerParameter.Builder[] hiddenLayerBuilder = hiddenLayerBuilder(layerArr[i3], (i3 - i) + 1);
            hiddenLayerBuilder[0].addBottom(((Caffe.LayerParameter) arrayList.get(arrayList.size() - 1)).getTop(0));
            for (Caffe.LayerParameter.Builder builder : hiddenLayerBuilder) {
                arrayList.add(builder.m1421build());
            }
        }
        Caffe.LayerParameter.Builder[] outputLayerBuilder = outputLayerBuilder(layerArr[layerArr.length - 1], bool);
        outputLayerBuilder[0].addBottom(((Caffe.LayerParameter) arrayList.get(arrayList.size() - 1)).getTop(0));
        for (Caffe.LayerParameter.Builder builder2 : outputLayerBuilder) {
            arrayList.add(builder2.m1421build());
        }
        Caffe.NetParameter.Builder newBuilder = Caffe.NetParameter.newBuilder();
        Iterator it = arrayList.iterator();
        while (it.hasNext()) {
            newBuilder.addLayer((Caffe.LayerParameter) it.next());
        }
        return TextFormat.printToString(newBuilder);
    }

    private Caffe.LayerParameter.Builder[] hiddenLayerBuilder(Layer layer, int i) {
        Caffe.LayerParameter.Builder newBuilder = Caffe.LayerParameter.newBuilder();
        String[] strArr = {"undef", "Sigmoid", "TanH", "ReLU", "HDF5Data"};
        newBuilder.setName("inner" + String.valueOf(i)).setType("InnerProduct").setInnerProductParam(Caffe.InnerProductParameter.newBuilder().setNumOutput(layer.getNeurons()).setWeightFiller(Caffe.FillerParameter.newBuilder().setType("xavier"))).addTop("inner" + String.valueOf(i));
        Caffe.LayerParameter.Builder newBuilder2 = Caffe.LayerParameter.newBuilder();
        newBuilder2.setName(strArr[layer.getLayerType()] + String.valueOf(i)).setType(strArr[layer.getLayerType()]).addTop(strArr[layer.getLayerType()] + String.valueOf(i)).addBottom("inner" + String.valueOf(i));
        Caffe.LayerParameter.Builder[] builderArr = {newBuilder, newBuilder2};
        if (layer.getPhase() == 1) {
            builderArr[0].addInclude(Caffe.NetStateRule.newBuilder().setPhase(Caffe.Phase.TRAIN));
            builderArr[1].addInclude(Caffe.NetStateRule.newBuilder().setPhase(Caffe.Phase.TRAIN));
        } else if (layer.getPhase() == 2) {
            builderArr[0].addInclude(Caffe.NetStateRule.newBuilder().setPhase(Caffe.Phase.TEST));
            builderArr[1].addInclude(Caffe.NetStateRule.newBuilder().setPhase(Caffe.Phase.TEST));
        }
        return builderArr;
    }

    private Caffe.LayerParameter dataLayerBuilder(Layer layer, Boolean bool) {
        Caffe.LayerParameter.Builder newBuilder = Caffe.LayerParameter.newBuilder();
        newBuilder.setName("data").setType("HDF5Data").addTop("data").setHdf5DataParam(Caffe.HDF5DataParameter.newBuilder().setSource(this.userDir + "/" + layer.getDataFile()).setBatchSize(layer.getBatchSize()).m1037build());
        if (!bool.booleanValue()) {
            newBuilder.addTop("label");
        }
        if (layer.getPhase() == 1) {
            newBuilder.addInclude(Caffe.NetStateRule.newBuilder().setPhase(Caffe.Phase.TRAIN));
        } else if (layer.getPhase() == 2) {
            newBuilder.addInclude(Caffe.NetStateRule.newBuilder().setPhase(Caffe.Phase.TEST));
        }
        return newBuilder.m1421build();
    }

    private Caffe.LayerParameter.Builder[] outputLayerBuilder(Layer layer, Boolean bool) {
        Caffe.LayerParameter.Builder newBuilder = Caffe.LayerParameter.newBuilder();
        String[] strArr = {"undef", "Sigmoid", "TanH", "ReLU", "HDF5Data"};
        Caffe.LayerParameter.Builder[] builderArr = bool.booleanValue() ? new Caffe.LayerParameter.Builder[2] : new Caffe.LayerParameter.Builder[4];
        newBuilder.setName("innerBottom").setType("InnerProduct").setInnerProductParam(Caffe.InnerProductParameter.newBuilder().setNumOutput(layer.getNeurons()).setWeightFiller(Caffe.FillerParameter.newBuilder().setType("xavier"))).addTop("innerBottom");
        builderArr[0] = newBuilder;
        Caffe.LayerParameter.Builder newBuilder2 = Caffe.LayerParameter.newBuilder();
        newBuilder2.setName("NeuronBottom").setType(strArr[layer.getLayerType()]).addTop("NeuronBottom").addBottom("innerBottom");
        builderArr[1] = newBuilder2;
        if (!bool.booleanValue()) {
            Caffe.LayerParameter.Builder newBuilder3 = Caffe.LayerParameter.newBuilder();
            newBuilder3.setName("accuracy").setType("Accuracy").addBottom("NeuronBottom").addBottom("label").addTop("accuracy");
            builderArr[2] = newBuilder3;
            Caffe.LayerParameter.Builder newBuilder4 = Caffe.LayerParameter.newBuilder();
            if (layer.getNeurons() == 1) {
                newBuilder4.setName("loss").setType("EuclideanLoss").addBottom("NeuronBottom").addBottom("label").addTop("loss");
                builderArr[3] = newBuilder4;
            } else {
                newBuilder4.setName("loss").setType("SoftmaxWithLoss").addBottom("NeuronBottom").addBottom("label").addTop("loss");
                builderArr[3] = newBuilder4;
            }
            if (layer.getPhase() == 1) {
                builderArr[3].addInclude(Caffe.NetStateRule.newBuilder().setPhase(Caffe.Phase.TRAIN));
            } else if (layer.getPhase() == 2) {
                builderArr[3].addInclude(Caffe.NetStateRule.newBuilder().setPhase(Caffe.Phase.TEST));
            }
        }
        if (layer.getPhase() == 1) {
            builderArr[0].addInclude(Caffe.NetStateRule.newBuilder().setPhase(Caffe.Phase.TRAIN));
            builderArr[1].addInclude(Caffe.NetStateRule.newBuilder().setPhase(Caffe.Phase.TRAIN));
        } else if (layer.getPhase() == 2) {
            builderArr[0].addInclude(Caffe.NetStateRule.newBuilder().setPhase(Caffe.Phase.TEST));
            builderArr[1].addInclude(Caffe.NetStateRule.newBuilder().setPhase(Caffe.Phase.TEST));
        }
        return builderArr;
    }
}
