public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition,Double>

Abstract gradient boosting trainer. The following classes can be used as regressor trainers: DecisionTreeRegressionTrainer, KNNRegressionTrainer, LinearRegressionLSQRTrainer, RandomForestRegressionTrainer, LinearRegressionSGDTrainer.
In practice, decision trees are the most commonly used base regressors (see DecisionTreeRegressionTrainer).

| Modifier and Type | Class | Description |
|---|---|---|
| static class | GDBTrainer.GDBModel | GDB model. |

Nested classes/interfaces inherited from class DatasetTrainer: DatasetTrainer.EmptyDatasetException
| Modifier and Type | Field | Description |
|---|---|---|
| protected ConvergenceCheckerFactory | checkConvergenceStgyFactory | Check convergence strategy factory. |
| protected Loss | loss | Loss function. |

Fields inherited from class DatasetTrainer: environment
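The loss and checkConvergenceStgyFactory fields suggest that training can stop early once the loss is small enough. The sketch below illustrates one such check in plain Java: it computes the mean squared error between labels and current predictions and reports convergence when that mean falls below a precision threshold. MeanLossConvergenceSketch, error and isConverged are hypothetical names, not Ignite's ConvergenceCheckerFactory API, and squared error is an assumed choice of loss.

```java
/**
 * Hypothetical convergence check (not Ignite's ConvergenceCheckerFactory API):
 * training counts as converged once the mean squared error drops below a precision.
 */
final class MeanLossConvergenceSketch {
    private final double precision;

    MeanLossConvergenceSketch(double precision) {
        this.precision = precision;
    }

    /** Squared error for one sample; the loss function assumed by this sketch. */
    static double error(double label, double prediction) {
        double diff = prediction - label;
        return diff * diff;
    }

    /** Returns true when the mean loss over a non-empty dataset is below the precision. */
    boolean isConverged(double[] labels, double[] predictions) {
        if (labels.length == 0 || labels.length != predictions.length)
            throw new IllegalArgumentException("expected non-empty arrays of equal length");
        double sum = 0.0;
        for (int i = 0; i < labels.length; i++)
            sum += error(labels[i], predictions[i]);
        return sum / labels.length < precision;
    }
}
```

A threshold on the mean loss is only one possible strategy; a check on the change in loss between iterations would plug into the same place.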
| Constructor | Description |
|---|---|
| GDBTrainer(double gradStepSize, Integer cntOfIterations, Loss loss) | Constructs a GDBTrainer instance. |
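The constructor parameters control a gradient boosting loop. The sketch below is a minimal illustration of that loop in plain Java, assuming squared loss, one-dimensional features and decision stumps as the base regressor; it is not Ignite's implementation. The first approximation is the mean label, as described for computeInitialValue; on each of cntOfIterations steps a stump is fitted to the negative loss gradient and added to the composition with weight gradStepSize. GradientBoostingSketch, fitStump and gradient are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.DoubleUnaryOperator;

/** Hypothetical gradient boosting sketch: squared loss, 1-D features, stump regressors. */
final class GradientBoostingSketch {
    private final double gradStepSize;  // weight of each base model in the composition
    private final int cntOfIterations;  // number of boosting iterations

    GradientBoostingSketch(double gradStepSize, int cntOfIterations) {
        this.gradStepSize = gradStepSize;
        this.cntOfIterations = cntOfIterations;
    }

    /** Gradient of the squared loss 0.5 * (prediction - label)^2 w.r.t. the prediction. */
    static double gradient(double label, double prediction) {
        return prediction - label;
    }

    /** Fits a one-split regression stump that minimizes squared error on the targets. */
    static DoubleUnaryOperator fitStump(double[] x, double[] target) {
        double bestErr = Double.POSITIVE_INFINITY;
        double bestThr = 0, bestLeft = 0, bestRight = 0;
        for (double thr : x) {
            double sumL = 0, sumR = 0;
            int nL = 0, nR = 0;
            for (int i = 0; i < x.length; i++) {
                if (x[i] <= thr) { sumL += target[i]; nL++; }
                else { sumR += target[i]; nR++; }
            }
            double meanL = nL == 0 ? 0 : sumL / nL;
            double meanR = nR == 0 ? 0 : sumR / nR;
            double err = 0;
            for (int i = 0; i < x.length; i++) {
                double d = target[i] - (x[i] <= thr ? meanL : meanR);
                err += d * d;
            }
            if (err < bestErr) { bestErr = err; bestThr = thr; bestLeft = meanL; bestRight = meanR; }
        }
        final double t = bestThr, l = bestLeft, r = bestRight;
        return v -> v <= t ? l : r;
    }

    /** Trains the composition and returns it as a single prediction function. */
    DoubleUnaryOperator fit(double[] x, double[] y) {
        // First approximation: the mean label value.
        double mean = Arrays.stream(y).average().orElse(0.0);
        List<DoubleUnaryOperator> models = new ArrayList<>();
        double[] pred = new double[y.length];
        Arrays.fill(pred, mean);

        for (int it = 0; it < cntOfIterations; it++) {
            // Fit the base regressor to the negative gradient (residuals for squared loss).
            double[] negGrad = new double[y.length];
            for (int i = 0; i < y.length; i++)
                negGrad[i] = -gradient(y[i], pred[i]);
            DoubleUnaryOperator stump = fitStump(x, negGrad);
            models.add(stump);
            // Move the current predictions one step along the fitted direction.
            for (int i = 0; i < y.length; i++)
                pred[i] += gradStepSize * stump.applyAsDouble(x[i]);
        }

        // Composition: initial value plus step-weighted sum of base model outputs.
        return v -> {
            double res = mean;
            for (DoubleUnaryOperator m : models)
                res += gradStepSize * m.applyAsDouble(v);
            return res;
        };
    }
}
```

On the toy data x = {1, 2, 3, 4}, y = {1, 1, 5, 5} with gradStepSize = 0.5, each iteration halves the remaining residual, so 20 iterations bring the predictions within about 2e-6 of the labels. A smaller step size generally needs more iterations to reach the same fit.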
| Modifier and Type | Method | Description |
|---|---|---|
| protected abstract @NotNull DatasetTrainer<? extends Model<Vector,Double>,Double> | buildBaseModelTrainer() | Returns the regressor model trainer for one step of GDB. |
| protected boolean | checkState(ModelsComposition mdl) | |
| protected <V,K> IgniteBiTuple<Double,Long> | computeInitialValue(DatasetBuilder<K,V> builder, IgniteBiFunction<K,V,Vector> featureExtractor, IgniteBiFunction<K,V,Double> lbExtractor) | Computes the mean label value as the first approximation. |
| protected abstract double | externalLabelToInternal(double lbl) | Maps the external representation of a label to the internal one. |
| <K,V> ModelsComposition | fit(DatasetBuilder<K,V> datasetBuilder, IgniteBiFunction<K,V,Vector> featureExtractor, IgniteBiFunction<K,V,Double> lbExtractor) | Trains a model on the specified data. |
| protected GDBLearningStrategy | getLearningStrategy() | Returns the learning strategy. |
| protected abstract double | internalLabelToExternal(double lbl) | Maps the internal representation of a label to the external one. |
| protected abstract <V,K> boolean | learnLabels(DatasetBuilder<K,V> builder, IgniteBiFunction<K,V,Vector> featureExtractor, IgniteBiFunction<K,V,Double> lExtractor) | Defines the unique labels in the dataset if needed (useful for classification). |
| protected <K,V> ModelsComposition | updateModel(ModelsComposition mdl, DatasetBuilder<K,V> datasetBuilder, IgniteBiFunction<K,V,Vector> featureExtractor, IgniteBiFunction<K,V,Double> lbExtractor) | Takes the model state passed in the arguments, updates it according to the new data, and returns the new model. |
| GDBTrainer | withCheckConvergenceStgyFactory(ConvergenceCheckerFactory factory) | Sets the convergence check strategy factory (checkConvergenceStgyFactory). |

Methods inherited from class DatasetTrainer: fit, fit, fit, fit, getLastTrainedModelOrThrowEmptyDatasetException, setEnvironment, update, update, update, update, update
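The abstract hooks learnLabels, externalLabelToInternal and internalLabelToExternal let a subclass apply regression-based boosting to classification labels. The sketch below shows one way a binary subclass might implement them, assuming plain double[] labels instead of Ignite's DatasetBuilder and extractors: the two distinct labels map to internal 0.0 and 1.0, and the composition output is thresholded at 0.5 when mapping back. LabelMappingSketch and BinaryLabelMapping are hypothetical classes, and the 0.5 threshold is an assumed design choice rather than Ignite's documented behavior.

```java
import java.util.Arrays;

/** Hypothetical hooks mirroring the abstract GDBTrainer label methods, without Ignite types. */
abstract class LabelMappingSketch {
    /** Inspects the labels before training; returns false if they cannot be used. */
    abstract boolean learnLabels(double[] labels);

    abstract double externalLabelToInternal(double lbl);

    abstract double internalLabelToExternal(double lbl);
}

/** Binary case: the smaller distinct label becomes 0.0, any other label becomes 1.0. */
final class BinaryLabelMapping extends LabelMappingSketch {
    private double firstCls;
    private double secondCls;

    @Override
    boolean learnLabels(double[] labels) {
        double[] distinct = Arrays.stream(labels).distinct().sorted().toArray();
        if (distinct.length != 2)
            return false; // binary boosting in this sketch needs exactly two classes
        firstCls = distinct[0];
        secondCls = distinct[1];
        return true;
    }

    @Override
    double externalLabelToInternal(double lbl) {
        return lbl == firstCls ? 0.0 : 1.0;
    }

    /** Maps a raw composition output back to an external label by thresholding at 0.5. */
    @Override
    double internalLabelToExternal(double lbl) {
        return lbl < 0.5 ? firstCls : secondCls;
    }
}
```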
protected final Loss loss
Loss function.

protected ConvergenceCheckerFactory checkConvergenceStgyFactory
Check convergence strategy factory.

public GDBTrainer(double gradStepSize, Integer cntOfIterations, Loss loss)
Constructs a GDBTrainer instance.
Parameters:
gradStepSize - Gradient step size.
cntOfIterations - Number of learning iterations.
loss - Gradient of the loss function. Its first argument is the sample size, the second is the correct answer, and the third is the current model prediction.

public <K,V> ModelsComposition fit(DatasetBuilder<K,V> datasetBuilder, IgniteBiFunction<K,V,Vector> featureExtractor, IgniteBiFunction<K,V,Double> lbExtractor)
Trains a model on the specified data.
Specified by:
fit in class DatasetTrainer<ModelsComposition,Double>
Type Parameters:
K - Type of a key in upstream data.
V - Type of a value in upstream data.
Parameters:
datasetBuilder - Dataset builder.
featureExtractor - Feature extractor.
lbExtractor - Label extractor.

protected <K,V> ModelsComposition updateModel(ModelsComposition mdl, DatasetBuilder<K,V> datasetBuilder, IgniteBiFunction<K,V,Vector> featureExtractor, IgniteBiFunction<K,V,Double> lbExtractor)
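Takes the model state passed in the arguments, updates it according to the new data, and returns the new model.
Specified by:
updateModel in class DatasetTrainer<ModelsComposition,Double>
Type Parameters:
K - Type of a key in upstream data.
V - Type of a value in upstream data.
Parameters:
mdl - Learned model.
datasetBuilder - Dataset builder.
featureExtractor - Feature extractor.
lbExtractor - Label extractor.

protected boolean checkState(ModelsComposition mdl)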
Specified by:
checkState in class DatasetTrainer<ModelsComposition,Double>
Parameters:
mdl - Model.

protected abstract <V,K> boolean learnLabels(DatasetBuilder<K,V> builder, IgniteBiFunction<K,V,Vector> featureExtractor, IgniteBiFunction<K,V,Double> lExtractor)
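Defines the unique labels in the dataset if needed (useful for classification).
Parameters:
builder - Dataset builder.
featureExtractor - Feature extractor.
lExtractor - Label extractor.

@NotNull protected abstract @NotNull DatasetTrainer<? extends Model<Vector,Double>,Double> buildBaseModelTrainer()
Returns the regressor model trainer for one step of GDB.
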
protected abstract double externalLabelToInternal(double lbl)
Maps the external representation of a label to the internal one.
Parameters:
lbl - Label value.

protected abstract double internalLabelToExternal(double lbl)
Maps the internal representation of a label to the external one.
Parameters:
lbl - Label value.

protected <V,K> IgniteBiTuple<Double,Long> computeInitialValue(DatasetBuilder<K,V> builder, IgniteBiFunction<K,V,Vector> featureExtractor, IgniteBiFunction<K,V,Double> lbExtractor)
Computes the mean label value as the first approximation.
Parameters:
builder - Dataset builder.
featureExtractor - Feature extractor.
lbExtractor - Label extractor.

public GDBTrainer withCheckConvergenceStgyFactory(ConvergenceCheckerFactory factory)
Sets the convergence check strategy factory.
Parameters:
factory - Factory.

protected GDBLearningStrategy getLearningStrategy()
Returns the learning strategy.
Ignite Database and Caching Platform : ver. 2.7.2 Release Date : February 6 2019