Type Parameters:
M - Type of model.
L - Type of a label (truth or prediction).
K - Type of a key in upstream data.
V - Type of a value in upstream data.

public abstract class AbstractCrossValidation<M extends IgniteModel<Vector,L>,L,K,V>
extends Object
Modifier and Type | Class and Description |
---|---|
static class | AbstractCrossValidation.TaskResult - Represents the scores and map of parameters. |
Modifier and Type | Field and Description |
---|---|
protected int | amountOfFolds - Number of folds. |
protected LearningEnvironmentBuilder | envBuilder - Learning environment builder. |
protected LearningEnvironment | environment - Learning environment. |
protected IgniteBiPredicate<K,V> | filter - Filter. |
protected boolean | isRunningOnPipeline - If true, cross-validation runs over the pipeline; otherwise, it runs over the chain of preprocessors and a separate trainer. |
protected UniformMapper<K,V> | mapper - Mapper. |
protected Metric<L> | metric - Metric. |
protected ParamGrid | paramGrid - Parameter grid. |
protected int | parts - Parts. |
protected Pipeline<K,V,Integer,Double> | pipeline - Pipeline. |
protected Preprocessor<K,V> | preprocessor - Preprocessor. |
protected DatasetTrainer<M,L> | trainer - Trainer. |
Constructor and Description |
---|
AbstractCrossValidation() |
protected LearningEnvironmentBuilder envBuilder
protected LearningEnvironment environment
protected DatasetTrainer<M extends IgniteModel<Vector,L>,L> trainer
protected Preprocessor<K,V> preprocessor
protected IgniteBiPredicate<K,V> filter
protected int amountOfFolds
protected int parts
protected ParamGrid paramGrid
protected boolean isRunningOnPipeline
protected UniformMapper<K,V> mapper
public CrossValidationResult tuneHyperParamterers()
public abstract double[] scoreByFolds()
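`tuneHyperParamterers()` searches the parameter grid, and `scoreByFolds()` returns one score per fold. A minimal brute-force grid search over a plain map of parameter values might look like the sketch below. It assumes that higher scores are better and that combinations are ranked by their mean fold score. `GridSearchSketch`, its methods, and the parameter names `maxDepth` and `minImpurityDecrease` are hypothetical. The real method may use a different search strategy or ranking rule.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

/** Illustrative sketch, not Ignite's code: brute-force grid search over fold scores. */
public class GridSearchSketch {
    /** Cartesian product of the per-parameter value arrays, in insertion order. */
    static List<Map<String, Double>> combinations(Map<String, double[]> grid) {
        List<Map<String, Double>> result = new ArrayList<>();
        result.add(new LinkedHashMap<>());
        for (Map.Entry<String, double[]> param : grid.entrySet()) {
            List<Map<String, Double>> next = new ArrayList<>();
            for (Map<String, Double> partial : result)
                for (double value : param.getValue()) {
                    Map<String, Double> extended = new LinkedHashMap<>(partial);
                    extended.put(param.getKey(), value);
                    next.add(extended);
                }
            result = next;
        }
        return result;
    }

    /** Returns the combination whose per-fold scores have the highest mean (higher is better). */
    static Map<String, Double> bestParams(Map<String, double[]> grid,
                                          Function<Map<String, Double>, double[]> scoreByFolds) {
        Map<String, Double> best = null;
        double bestMean = Double.NEGATIVE_INFINITY;
        for (Map<String, Double> params : combinations(grid)) {
            double mean = Arrays.stream(scoreByFolds.apply(params)).average().orElse(Double.NaN);
            if (mean > bestMean) {
                bestMean = mean;
                best = params;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        Map<String, double[]> grid = new LinkedHashMap<>();
        grid.put("maxDepth", new double[] {1, 2, 3});             // hypothetical parameter
        grid.put("minImpurityDecrease", new double[] {0.0, 0.5}); // hypothetical parameter

        // Toy stand-in for training and scoring on three folds.
        Function<Map<String, Double>, double[]> toyScoreByFolds = p -> {
            double s = 1.0 - 0.1 * Math.abs(p.get("maxDepth") - 2) - 0.2 * p.get("minImpurityDecrease");
            return new double[] {s - 0.01, s, s + 0.01};
        };
        System.out.println(bestParams(grid, toyScoreByFolds));
    }
}
```

With the toy scorer, `main` prints `{maxDepth=2.0, minImpurityDecrease=0.0}`, the combination with the highest mean fold score. Averaging across folds means that a single lucky fold cannot decide the winner.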
protected double[] score(Function<IgniteBiPredicate<K,V>,DatasetBuilder<K,V>> datasetBuilderSupplier, BiFunction<IgniteBiPredicate<K,V>,M,LabelPairCursor<L>> testDataIterSupplier)
Parameters:
datasetBuilderSupplier - Dataset builder supplier.
testDataIterSupplier - Test data iterator supplier.

protected double[] scorePipeline(Function<IgniteBiPredicate<K,V>,DatasetBuilder<K,V>> datasetBuilderSupplier, BiFunction<IgniteBiPredicate<K,V>,M,LabelPairCursor<L>> testDataIterSupplier)
Parameters:
datasetBuilderSupplier - Dataset builder supplier.
testDataIterSupplier - Test data iterator supplier.

public AbstractCrossValidation<M,L,K,V> withTrainer(DatasetTrainer<M,L> trainer)
Parameters:
trainer - Trainer.

public AbstractCrossValidation<M,L,K,V> withMetric(Metric<L> metric)
Parameters:
metric - Metric.

public AbstractCrossValidation<M,L,K,V> withPreprocessor(Preprocessor<K,V> preprocessor)
Parameters:
preprocessor - Preprocessor.

public AbstractCrossValidation<M,L,K,V> withFilter(IgniteBiPredicate<K,V> filter)
Parameters:
filter - Filter.

public AbstractCrossValidation<M,L,K,V> withAmountOfFolds(int amountOfFolds)
Parameters:
amountOfFolds - Number of folds.

public AbstractCrossValidation<M,L,K,V> withParamGrid(ParamGrid paramGrid)
Parameters:
paramGrid - Parameter grid.

public AbstractCrossValidation<M,L,K,V> isRunningOnPipeline(boolean runningOnPipeline)
Parameters:
runningOnPipeline - Whether cross-validation runs over the pipeline.

public AbstractCrossValidation<M,L,K,V> withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder)
Parameters:
envBuilder - Learning environment builder.

public AbstractCrossValidation<M,L,K,V> withPipeline(Pipeline<K,V,Integer,Double> pipeline)
Parameters:
pipeline - Pipeline.

public AbstractCrossValidation<M,L,K,V> withMapper(UniformMapper<K,V> mapper)
Parameters:
mapper - Mapper.
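`score` and `scorePipeline` each take two suppliers. The first presumably turns a training-set predicate into a dataset builder. The second turns a test-set predicate plus a trained model into a cursor of (truth, prediction) label pairs for the metric. The self-contained sketch below mimics that flow with plain Java functions, under these assumptions:
- a `Function` plays the role of the dataset builder supplier and returns a trained model directly;
- a `List<double[]>` stands in for `LabelPairCursor`;
- the metric is accuracy.

All names (`FoldScoringSketch`, `score`, `demo`) and the majority-label "model" are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.BiPredicate;
import java.util.function.Function;
import java.util.function.ToDoubleBiFunction;
import java.util.function.ToDoubleFunction;

/** Illustrative sketch, not Ignite's code: per-fold train/evaluate loop. */
public class FoldScoringSketch {
    /**
     * For each fold: build the test-set predicate, train on its complement,
     * then score the model's (truth, prediction) pairs on the test set.
     */
    static <K, V, M> double[] score(int amountOfFolds,
                                    ToDoubleBiFunction<K, V> mapper,
                                    Function<BiPredicate<K, V>, M> trainOn,
                                    BiFunction<BiPredicate<K, V>, M, List<double[]>> labelPairs,
                                    ToDoubleFunction<List<double[]>> metric) {
        double[] scores = new double[amountOfFolds];
        for (int i = 0; i < amountOfFolds; i++) {
            double from = (double) i / amountOfFolds;
            double to = (double) (i + 1) / amountOfFolds;
            BiPredicate<K, V> testSet = (k, v) -> {
                double pnt = mapper.applyAsDouble(k, v);
                return pnt >= from && pnt < to;
            };
            M model = trainOn.apply(testSet.negate());
            scores[i] = metric.applyAsDouble(labelPairs.apply(testSet, model));
        }
        return scores;
    }

    /** Toy run: 100 keys, label 1.0 except keys 90..99, five folds. */
    static double[] demo() {
        Map<Integer, Double> data = new LinkedHashMap<>();
        for (int key = 0; key < 100; key++)
            data.put(key, key < 90 ? 1.0 : 0.0);

        // Hypothetical "trainer": the model is the majority label of the training rows.
        Function<BiPredicate<Integer, Double>, Double> trainOn = filter -> {
            long total = 0, ones = 0;
            for (Map.Entry<Integer, Double> e : data.entrySet())
                if (filter.test(e.getKey(), e.getValue())) {
                    total++;
                    if (e.getValue() == 1.0) ones++;
                }
            return ones * 2 >= total ? 1.0 : 0.0;
        };

        // Stand-in for LabelPairCursor: {truth, prediction} for each test row.
        BiFunction<BiPredicate<Integer, Double>, Double, List<double[]>> labelPairs = (filter, model) -> {
            List<double[]> pairs = new ArrayList<>();
            data.forEach((k, v) -> {
                if (filter.test(k, v))
                    pairs.add(new double[] {v, model});
            });
            return pairs;
        };

        // Accuracy: fraction of pairs whose prediction equals the truth.
        ToDoubleFunction<List<double[]>> accuracy =
            pairs -> pairs.stream().filter(p -> p[0] == p[1]).count() / (double) pairs.size();

        return score(5, (Integer k, Double v) -> k / 100.0, trainOn, labelPairs, accuracy);
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(demo()));
    }
}
```

`demo()` should return `[1.0, 1.0, 1.0, 1.0, 0.5]`. Only the last fold contains the 0.0 labels (keys 90 to 99), and the majority-label model never predicts them. The other folds are scored entirely on rows labeled 1.0.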
GridGain In-Memory Computing Platform : ver. 8.9.14 Release Date : November 5 2024