package bi.kevin;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:bi/kevin/DataFormatter.class */
/**
 * Converts a table of numeric rows into ND4J {@link DataSet} objects, splitting each row's
 * columns into a label vector and a feature vector.
 *
 * <p>Columns whose position appears in the label index array become labels; every other column
 * becomes a feature. Within each vector the values keep their original column order.
 */
public class DataFormatter {
    /** Column positions treated as labels (private copy of the constructor argument). */
    private final int[] labelIndicies;

    /** The rows to convert; held by reference, not copied. */
    private final Collection<? extends Collection<Double>> collection;

    /**
     * Creates a formatter over the given rows.
     *
     * @param iArr column positions to treat as labels; each must be a valid position in the
     *     first row. Repeated positions are counted once.
     * @param collection the rows to convert; must be non-empty and its first row must be non-empty
     * @throws Exception if {@code collection} or its first row is empty, or if any label index is
     *     negative or not less than the size of the first row
     */
    public DataFormatter(int[] iArr, Collection<? extends Collection<Double>> collection) throws Exception {
        if (collection.isEmpty()) {
            throw new Exception("Size of all collections must be greater than 0.");
        }
        // Only the first row is inspected here; its width is the reference for index validation.
        int firstRowWidth = collection.iterator().next().size();
        if (firstRowWidth < 1) {
            throw new Exception("Size of all collections must be greater than 0.");
        }
        for (int index : iArr) {
            if (index < 0 || index >= firstRowWidth) {
                throw new Exception("Label indicies not contained in the dataset");
            }
        }
        // Copy so that later changes to the caller's array cannot bypass the validation above.
        this.labelIndicies = iArr.clone();
        this.collection = collection;
    }

    /**
     * Converts every row and stacks the per-row results vertically.
     *
     * @return a {@link DataSet} built from the vertically stacked feature vectors and the
     *     vertically stacked label vectors of all rows
     */
    public DataSet getAllData() {
        List<INDArray> featureRows = new ArrayList<>(this.collection.size());
        List<INDArray> labelRows = new ArrayList<>(this.collection.size());
        for (Collection<Double> row : this.collection) {
            DataSet oneDatum = getOneDatum(row);
            featureRows.add(oneDatum.getFeatureMatrix());
            labelRows.add(oneDatum.getLabels());
        }
        return new DataSet(
                Nd4j.vstack(featureRows.toArray(new INDArray[0])),
                Nd4j.vstack(labelRows.toArray(new INDArray[0])));
    }

    /**
     * Splits a single row into its feature and label vectors.
     *
     * @param collection the row to convert
     * @return a {@link DataSet} holding the row's features and labels
     * @throws IllegalArgumentException if a label index does not fit inside this row
     */
    private DataSet getOneDatum(Collection<Double> collection) {
        int width = collection.size();

        // Mark the label columns once instead of rescanning the index array for every column.
        // Counting only newly marked columns keeps both vectors correctly sized when the
        // index array contains the same position more than once.
        boolean[] isLabel = new boolean[width];
        int labelCount = 0;
        for (int index : this.labelIndicies) {
            if (index >= width) {
                throw new IllegalArgumentException(
                        "Label index " + index + " is outside a row of " + width + " values");
            }
            if (!isLabel[index]) {
                isLabel[index] = true;
                labelCount++;
            }
        }

        INDArray labels = Nd4j.create(labelCount);
        INDArray features = Nd4j.create(width - labelCount);

        int nextLabel = 0;
        int nextFeature = 0;
        int column = 0;
        // Walk the row in iteration order; no List copy is needed because columns are only
        // visited sequentially.
        for (Double value : collection) {
            if (isLabel[column]) {
                labels.putScalar(nextLabel++, value.doubleValue());
            } else {
                features.putScalar(nextFeature++, value.doubleValue());
            }
            column++;
        }
        return new DataSet(features, labels);
    }

    /**
     * Builds the full data set, then fits a {@link NormalizerStandardize} on it and applies that
     * normalizer to the same data set.
     *
     * @return the data set produced by {@link #getAllData()} after the normalizer's
     *     {@code transform} call
     */
    public DataSet getAllDataNormalized() {
        DataSet allData = getAllData();
        NormalizerStandardize normalizerStandardize = new NormalizerStandardize();
        normalizerStandardize.fit(allData);
        normalizerStandardize.transform(allData);
        return allData;
    }
}
