package jsat.parameters;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.evaluation.Accuracy;
import jsat.classifiers.evaluation.ClassificationScore;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.regression.Regressor;
import jsat.regression.evaluation.MeanSquaredError;
import jsat.regression.evaluation.RegressionScore;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/parameters/ModelSearch.class */
/**
 * Base class for searches over the parameters of a wrapped model. The wrapped
 * model may be a {@link Classifier}, a {@link Regressor}, or both, and must
 * implement {@link Parameterized}. Once a trained model has been stored in
 * {@link #trainedClassifier} / {@link #trainedRegressor}, calls to
 * {@link #classify(DataPoint)} and {@link #regress(DataPoint)} are delegated
 * to it.
 */
public abstract class ModelSearch implements Classifier, Regressor {
    /** The un-trained classifier being searched over, or {@code null} if the base model is not a classifier. */
    protected Classifier baseClassifier;
    /** The classifier that {@link #classify(DataPoint)} delegates to; {@code null} until one is set. */
    protected Classifier trainedClassifier;
    /** The un-trained regressor being searched over, or {@code null} if the base model is not a regressor. */
    protected Regressor baseRegressor;
    /** The regressor that {@link #regress(DataPoint)} delegates to; {@code null} until one is set. */
    protected Regressor trainedRegressor;
    /** The parameters of the base model that are part of the search. */
    protected List<Parameter> searchParams;
    /** The number of folds, as given to the constructor. */
    protected int folds;
    /** Score used to compare classification models; defaults to {@link Accuracy}. */
    protected ClassificationScore classificationTargetScore = new Accuracy();
    /** Score used to compare regression models; defaults to {@code new MeanSquaredError(true)}. */
    protected RegressionScore regressionTargetScore = new MeanSquaredError(true);
    /** Flag exposed via {@link #setTrainModelsInParallel(boolean)}; defaults to {@code true}. */
    protected boolean trainModelsInParallel = true;
    /** Flag exposed via {@link #setTrainFinalModel(boolean)}; defaults to {@code true}. */
    protected boolean trainFinalModel = true;
    /** Flag exposed via {@link #setReuseSameCVFolds(boolean)}; defaults to {@code true}. */
    protected boolean reuseSameCVFolds = true;

    /**
     * Creates a new search around the given regressor. If the regressor is
     * also a {@link Classifier}, it is used as the base classifier as well.
     *
     * @param regressor the base regressor, which must implement {@link Parameterized}
     * @param i the number of folds
     * @throws FailedToFitException if {@code regressor} is not {@link Parameterized}
     */
    public ModelSearch(Regressor regressor, int i) {
        if (!(regressor instanceof Parameterized)) {
            throw new FailedToFitException("Given regressor does not support parameterized alterations");
        }
        this.baseRegressor = regressor;
        if (regressor instanceof Classifier) {
            this.baseClassifier = (Classifier) regressor;
        }
        this.searchParams = new ArrayList<Parameter>();
        this.folds = i;
    }

    /**
     * Creates a new search around the given classifier. If the classifier is
     * also a {@link Regressor}, it is used as the base regressor as well.
     *
     * @param classifier the base classifier, which must implement {@link Parameterized}
     * @param i the number of folds
     * @throws FailedToFitException if {@code classifier} is not {@link Parameterized}
     */
    public ModelSearch(Classifier classifier, int i) {
        if (!(classifier instanceof Parameterized)) {
            throw new FailedToFitException("Given classifier does not support parameterized alterations");
        }
        this.baseClassifier = classifier;
        if (classifier instanceof Regressor) {
            this.baseRegressor = (Regressor) classifier;
        }
        this.searchParams = new ArrayList<Parameter>();
        this.folds = i;
    }

    /**
     * Copy constructor. The base and trained models are cloned, the search
     * parameters are re-resolved by name against the cloned base model, and
     * the fold count and boolean configuration flags are carried over.
     *
     * @param modelSearch the search to copy
     * @throws IllegalArgumentException if a search parameter of {@code modelSearch}
     *         cannot be found by name on the cloned base model
     */
    public ModelSearch(ModelSearch modelSearch) {
        if (modelSearch.baseClassifier != null) {
            this.baseClassifier = modelSearch.baseClassifier.mo509clone();
            // Keep both base references pointing at the same cloned object.
            if (this.baseClassifier instanceof Regressor) {
                this.baseRegressor = (Regressor) this.baseClassifier;
            }
        } else {
            this.baseRegressor = modelSearch.baseRegressor.mo509clone();
            if (this.baseRegressor instanceof Classifier) {
                this.baseClassifier = (Classifier) this.baseRegressor;
            }
        }
        if (modelSearch.trainedClassifier != null) {
            this.trainedClassifier = modelSearch.trainedClassifier.mo509clone();
        }
        if (modelSearch.trainedRegressor != null) {
            this.trainedRegressor = modelSearch.trainedRegressor.mo509clone();
        }
        // The Parameter objects are bound to the source's base model, so look
        // up the equivalent parameters on the freshly cloned base model.
        this.searchParams = new ArrayList<Parameter>();
        for (Parameter sourceParam : modelSearch.searchParams) {
            this.searchParams.add(getParameterByName(sourceParam.getName()));
        }
        this.folds = modelSearch.folds;
        // Carry over the configuration flags so a copy behaves like its source
        // instead of silently reverting to the defaults.
        this.trainModelsInParallel = modelSearch.trainModelsInParallel;
        this.trainFinalModel = modelSearch.trainFinalModel;
        this.reuseSameCVFolds = modelSearch.reuseSameCVFolds;
        // NOTE(review): classificationTargetScore / regressionTargetScore are still
        // left at their defaults. They should probably be copied too, but sharing
        // the same score instance between copies may be unsafe if scores hold
        // state - confirm the score types' copy mechanism before changing this.
    }

    /**
     * Sets the flag controlling whether models are trained in parallel.
     *
     * @param z the new value of the flag
     */
    public void setTrainModelsInParallel(boolean z) {
        this.trainModelsInParallel = z;
    }

    /**
     * @return the current value of the train-models-in-parallel flag
     */
    public boolean isTrainModelsInParallel() {
        return this.trainModelsInParallel;
    }

    /**
     * Sets the flag controlling whether a final model is trained.
     *
     * @param z the new value of the flag
     */
    public void setTrainFinalModel(boolean z) {
        this.trainFinalModel = z;
    }

    /**
     * @return the current value of the train-final-model flag
     */
    public boolean isTrainFinalModel() {
        return this.trainFinalModel;
    }

    /**
     * Sets the flag controlling whether the same cross validation folds are reused.
     *
     * @param z the new value of the flag
     */
    public void setReuseSameCVFolds(boolean z) {
        this.reuseSameCVFolds = z;
    }

    /**
     * @return the current value of the reuse-same-CV-folds flag
     */
    public boolean isReuseSameCVFolds() {
        return this.reuseSameCVFolds;
    }

    /**
     * @return the base classifier, or {@code null} if the base model is not a classifier
     */
    public Classifier getBaseClassifier() {
        return this.baseClassifier;
    }

    /**
     * @return the trained classifier, or {@code null} if none has been set
     */
    public Classifier getTrainedClassifier() {
        return this.trainedClassifier;
    }

    /**
     * @return the base regressor, or {@code null} if the base model is not a regressor
     */
    public Regressor getBaseRegressor() {
        return this.baseRegressor;
    }

    /**
     * @return the trained regressor, or {@code null} if none has been set
     */
    public Regressor getTrainedRegressor() {
        return this.trainedRegressor;
    }

    /**
     * Sets the score used to evaluate classification models.
     *
     * @param classificationScore the score to use
     */
    public void setClassificationTargetScore(ClassificationScore classificationScore) {
        this.classificationTargetScore = classificationScore;
    }

    /**
     * @return the score used to evaluate classification models
     */
    public ClassificationScore getClassificationTargetScore() {
        return this.classificationTargetScore;
    }

    /**
     * Sets the score used to evaluate regression models.
     *
     * @param regressionScore the score to use
     */
    public void setRegressionTargetScore(RegressionScore regressionScore) {
        this.regressionTargetScore = regressionScore;
    }

    /**
     * @return the score used to evaluate regression models
     */
    public RegressionScore getRegressionTargetScore() {
        return this.regressionTargetScore;
    }

    /**
     * Looks up a parameter of the base model by name. The base classifier is
     * consulted when present, otherwise the base regressor.
     *
     * @param str the name of the parameter
     * @return the matching parameter, never {@code null}
     * @throws IllegalArgumentException if the base model has no parameter with that name
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public Parameter getParameterByName(String str) throws IllegalArgumentException {
        Parameterized baseModel = (Parameterized) (this.baseClassifier != null ? this.baseClassifier : this.baseRegressor);
        Parameter parameter = baseModel.getParameter(str);
        if (parameter == null) {
            throw new IllegalArgumentException("Parameter " + str + " does not exist");
        }
        return parameter;
    }

    /**
     * Classifies the data point using the trained classifier.
     *
     * @param dataPoint the data point to classify
     * @return the result produced by the trained classifier
     * @throws UntrainedModelException if no trained classifier is available
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.trainedClassifier == null) {
            throw new UntrainedModelException("Model has not yet been trained");
        }
        return this.trainedClassifier.classify(dataPoint);
    }

    /**
     * Performs regression on the data point using the trained regressor.
     *
     * @param dataPoint the data point to regress
     * @return the value produced by the trained regressor
     * @throws UntrainedModelException if no trained regressor is available
     */
    @Override // jsat.regression.Regressor
    public double regress(DataPoint dataPoint) {
        if (this.trainedRegressor == null) {
            throw new UntrainedModelException("Model has not yet been trained");
        }
        return this.trainedRegressor.regress(dataPoint);
    }

    /**
     * Reports whether the base model supports weighted data. The base
     * classifier is consulted when present, otherwise the base regressor.
     *
     * @return the answer of the underlying base model
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        if (this.baseClassifier != null) {
            return this.baseClassifier.supportsWeightedData();
        }
        return this.baseRegressor.supportsWeightedData();
    }

    /**
     * Creates a copy of this search.
     *
     * @return the copy
     */
    @Override // jsat.regression.Regressor
    /* renamed from: clone, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public abstract ModelSearch mo509clone();
}
