public class GDBLearningStrategy extends Object
Modifier and Type | Field and Description |
---|---|
protected IgniteSupplier<DatasetTrainer<? extends Model<Vector,Double>,Double>> | baseMdlTrainerBuilder: Base model trainer builder. |
protected ConvergenceCheckerFactory | checkConvergenceStgyFactory: Convergence check strategy factory. |
protected int | cntOfIterations: Number of iterations. |
protected double[] | compositionWeights: Composition weights. |
protected LearningEnvironment | environment: Learning environment. |
protected IgniteFunction<Double,Double> | externalLbToInternalMapping: External to internal label mapping. |
protected Loss | loss: Loss function used for gradient computation. |
protected double | meanLbVal: Mean label value. |
protected long | sampleSize: Sample size. |
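The fields meanLbVal and compositionWeights together describe the learned composition. The following is a minimal sketch of how they can combine at prediction time, assuming (as a simplification) that the composition predicts the mean label value plus the weighted sum of the base model predictions. Plain Java functions stand in for Ignite's Model<Vector,Double>, and the class WeightedCompositionSketch is hypothetical, not part of the Ignite API.

```java
import java.util.List;
import java.util.function.Function;

/** Hypothetical sketch of a weighted composition prediction (not the Ignite API). */
class WeightedCompositionSketch {
    private final List<Function<double[], Double>> models; // base models h_i
    private final double[] compositionWeights;             // one weight w_i per base model
    private final double meanLbVal;                        // bias added to every prediction

    WeightedCompositionSketch(List<Function<double[], Double>> models,
                              double[] compositionWeights, double meanLbVal) {
        if (models.size() != compositionWeights.length)
            throw new IllegalArgumentException("one weight per base model is required");
        this.models = models;
        this.compositionWeights = compositionWeights;
        this.meanLbVal = meanLbVal;
    }

    /** Returns meanLbVal + sum over i of compositionWeights[i] * h_i(features). */
    double predict(double[] features) {
        double res = meanLbVal;
        for (int i = 0; i < models.size(); i++)
            res += compositionWeights[i] * models.get(i).apply(features);
        return res;
    }

    public static void main(String[] args) {
        List<Function<double[], Double>> models = List.of(x -> x[0], x -> 1.0);
        WeightedCompositionSketch mdl =
            new WeightedCompositionSketch(models, new double[] {0.5, -0.5}, 2.0);
        System.out.println(mdl.predict(new double[] {4.0})); // 3.5
    }
}
```

With these inputs the prediction is 2.0 + 0.5 * 4.0 - 0.5 * 1.0 = 3.5.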
Constructor and Description |
---|
GDBLearningStrategy() |
Modifier and Type | Method and Description |
---|---|
double[] | getCompositionWeights() |
double | getMeanValue() |
protected @NotNull List<Model<Vector,Double>> | initLearningState(GDBTrainer.GDBModel mdlToUpdate): Restores the state of an already learned model, if possible, and sets learning parameters according to that state. |
<K,V> List<Model<Vector,Double>> | learnModels(DatasetBuilder<K,V> datasetBuilder, IgniteBiFunction<K,V,Vector> featureExtractor, IgniteBiFunction<K,V,Double> lbExtractor): Implements the gradient boosting iterations. |
<K,V> List<Model<Vector,Double>> | update(GDBTrainer.GDBModel mdlToUpdate, DatasetBuilder<K,V> datasetBuilder, IgniteBiFunction<K,V,Vector> featureExtractor, IgniteBiFunction<K,V,Double> lbExtractor): Gets the state of the model passed as an argument and compares it with the trainer's training parameters; if they are compatible, updates the model according to the new data and returns the new model. |
GDBLearningStrategy | withBaseModelTrainerBuilder(IgniteSupplier<DatasetTrainer<? extends Model<Vector,Double>,Double>> buildBaseMdlTrainer): Sets the base model trainer builder. |
GDBLearningStrategy | withCheckConvergenceStgyFactory(ConvergenceCheckerFactory factory): Sets the convergence check strategy factory. |
GDBLearningStrategy | withCntOfIterations(int cntOfIterations): Sets the number of iterations. |
GDBLearningStrategy | withCompositionWeights(double[] compositionWeights): Sets the composition weights vector. |
GDBLearningStrategy | withDefaultGradStepSize(double defaultGradStepSize): Sets the default gradient step size. |
GDBLearningStrategy | withEnvironment(LearningEnvironment environment): Sets the learning environment. |
GDBLearningStrategy | withExternalLabelToInternal(IgniteFunction<Double,Double> externalLbToInternal): Sets the mapping from external to internal label representation. |
GDBLearningStrategy | withLossGradient(Loss loss): Sets the loss function. |
GDBLearningStrategy | withMeanLabelValue(double meanLbVal): Sets the mean label value. |
GDBLearningStrategy | withSampleSize(long sampleSize): Sets the sample size. |
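learnModels is described only as the implementation of the gradient boosting iterations. The self-contained sketch below shows what such iterations typically look like under stated assumptions: squared loss, a single numeric feature, regression stumps as base models, the mean label value as the initial prediction, every base model weighted by a fixed gradient step size, and a simple mean-squared-residual threshold standing in for a convergence checker. The names BoostingSketch, fitStump, predict, and precision are hypothetical and do not come from Ignite; the actual Ignite implementation may differ.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.DoubleUnaryOperator;

/** Hypothetical sketch of gradient boosting iterations (plain Java, not Ignite's implementation). */
class BoostingSketch {
    /** Fits a regression stump: one threshold, a constant (mean) value on each side. */
    static DoubleUnaryOperator fitStump(double[] x, double[] target) {
        double bestErr = Double.POSITIVE_INFINITY, bestThr = 0, bestLeft = 0, bestRight = 0;
        for (double thr : x) {
            double sumL = 0, sumR = 0;
            int nL = 0, nR = 0;
            for (int i = 0; i < x.length; i++) {
                if (x[i] <= thr) { sumL += target[i]; nL++; }
                else { sumR += target[i]; nR++; }
            }
            double left = nL > 0 ? sumL / nL : 0, right = nR > 0 ? sumR / nR : 0;
            double err = 0;
            for (int i = 0; i < x.length; i++) {
                double d = target[i] - (x[i] <= thr ? left : right);
                err += d * d;
            }
            if (err < bestErr) { bestErr = err; bestThr = thr; bestLeft = left; bestRight = right; }
        }
        final double t = bestThr, l = bestLeft, r = bestRight;
        return v -> v <= t ? l : r;
    }

    /** Composition prediction: meanLbVal + sum of gradStepSize * h(v) over all base models. */
    static double predict(double meanLbVal, List<DoubleUnaryOperator> models, double gradStepSize, double v) {
        double res = meanLbVal;
        for (DoubleUnaryOperator h : models)
            res += gradStepSize * h.applyAsDouble(v);
        return res;
    }

    /** Runs at most cntOfIterations boosting steps and returns the learned base models. */
    static List<DoubleUnaryOperator> learnModels(double[] x, double[] y, int cntOfIterations,
                                                 double gradStepSize, double precision) {
        double meanLbVal = Arrays.stream(y).average().orElse(0);
        double[] pred = new double[y.length];
        Arrays.fill(pred, meanLbVal); // initial prediction: the mean label value
        List<DoubleUnaryOperator> models = new ArrayList<>();
        for (int it = 0; it < cntOfIterations; it++) {
            // For squared loss (y - F)^2 / 2 the negative gradient w.r.t. F is the residual y - F.
            double[] residual = new double[y.length];
            double mse = 0;
            for (int i = 0; i < y.length; i++) {
                residual[i] = y[i] - pred[i];
                mse += residual[i] * residual[i] / y.length;
            }
            if (mse < precision)
                break; // simple stand-in for a convergence checker
            DoubleUnaryOperator h = fitStump(x, residual);
            models.add(h);
            for (int i = 0; i < y.length; i++)
                pred[i] += gradStepSize * h.applyAsDouble(x[i]); // every model weighted by the step size
        }
        return models;
    }

    public static void main(String[] args) {
        double[] x = {1, 2, 3, 4};
        double[] y = {1, 1, 5, 5};
        List<DoubleUnaryOperator> models = learnModels(x, y, 100, 0.5, 1e-6);
        double meanLbVal = Arrays.stream(y).average().orElse(0);
        System.out.println(models.size()); // 11
        System.out.println(predict(meanLbVal, models, 0.5, 4.0)); // close to 5
    }
}
```

With the inputs in main, each iteration halves the residuals, so training stops after 11 base models, once the mean squared residual drops below 1e-6, well before the limit of 100 iterations.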
protected LearningEnvironment environment
protected int cntOfIterations
protected Loss loss
protected IgniteFunction<Double,Double> externalLbToInternalMapping
protected IgniteSupplier<DatasetTrainer<? extends Model<Vector,Double>,Double>> baseMdlTrainerBuilder
protected double meanLbVal
protected long sampleSize
protected double[] compositionWeights
protected ConvergenceCheckerFactory checkConvergenceStgyFactory
public <K,V> List<Model<Vector,Double>> learnModels(DatasetBuilder<K,V> datasetBuilder, IgniteBiFunction<K,V,Vector> featureExtractor, IgniteBiFunction<K,V,Double> lbExtractor)
Parameters:
datasetBuilder - Dataset builder.
featureExtractor - Feature extractor.
lbExtractor - Label extractor.

public <K,V> List<Model<Vector,Double>> update(GDBTrainer.GDBModel mdlToUpdate, DatasetBuilder<K,V> datasetBuilder, IgniteBiFunction<K,V,Vector> featureExtractor, IgniteBiFunction<K,V,Double> lbExtractor)
Type Parameters:
K - Type of a key in upstream data.
V - Type of a value in upstream data.
Parameters:
mdlToUpdate - Learned model.
datasetBuilder - Dataset builder.
featureExtractor - Feature extractor.
lbExtractor - Label extractor.

@NotNull protected List<Model<Vector,Double>> initLearningState(GDBTrainer.GDBModel mdlToUpdate)
Parameters:
mdlToUpdate - Model to update.

public GDBLearningStrategy withEnvironment(LearningEnvironment environment)
Parameters:
environment - Learning environment.

public GDBLearningStrategy withCntOfIterations(int cntOfIterations)
Parameters:
cntOfIterations - Number of iterations.

public GDBLearningStrategy withLossGradient(Loss loss)
Parameters:
loss - Loss function.

public GDBLearningStrategy withExternalLabelToInternal(IgniteFunction<Double,Double> externalLbToInternal)
Parameters:
externalLbToInternal - External to internal label mapping.

public GDBLearningStrategy withBaseModelTrainerBuilder(IgniteSupplier<DatasetTrainer<? extends Model<Vector,Double>,Double>> buildBaseMdlTrainer)
Parameters:
buildBaseMdlTrainer - Base model trainer builder.

public GDBLearningStrategy withMeanLabelValue(double meanLbVal)
Parameters:
meanLbVal - Mean label value.

public GDBLearningStrategy withSampleSize(long sampleSize)
Parameters:
sampleSize - Sample size.

public GDBLearningStrategy withCompositionWeights(double[] compositionWeights)
Parameters:
compositionWeights - Composition weights.

public GDBLearningStrategy withCheckConvergenceStgyFactory(ConvergenceCheckerFactory factory)
Parameters:
factory - Convergence check strategy factory.

public GDBLearningStrategy withDefaultGradStepSize(double defaultGradStepSize)
Parameters:
defaultGradStepSize - Default gradient step size.

public double[] getCompositionWeights()

public double getMeanValue()
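initLearningState and update let training continue from an already learned model. The following is a minimal sketch of that restoration step, assuming the learned composition exposes its base models, their weights, and its mean label value (bias), and that each newly added model is weighted with the default gradient step size. Composition and WarmStartSketch are hypothetical names; the compatibility check that update performs against the trainer's parameters is not modeled, and Ignite's actual code may differ.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.DoubleUnaryOperator;

/** Hypothetical sketch of restoring boosting state from a learned model (not the Ignite API). */
class WarmStartSketch {
    /** Hypothetical stand-in for an already learned composition. */
    static final class Composition {
        final List<DoubleUnaryOperator> models;
        final double[] weights;
        final double meanLbVal;

        Composition(List<DoubleUnaryOperator> models, double[] weights, double meanLbVal) {
            this.models = models;
            this.weights = weights;
            this.meanLbVal = meanLbVal;
        }
    }

    double meanLbVal;            // left unchanged when there is no model to update
    double[] compositionWeights;

    /**
     * Returns the restored base models and sets meanLbVal and compositionWeights:
     * restored weights are kept, and cntOfIterations new slots get defaultGradStepSize.
     */
    List<DoubleUnaryOperator> initLearningState(Composition mdlToUpdate, int cntOfIterations,
                                                double defaultGradStepSize) {
        List<DoubleUnaryOperator> models = new ArrayList<>();
        if (mdlToUpdate != null) {
            models.addAll(mdlToUpdate.models);
            meanLbVal = mdlToUpdate.meanLbVal; // keep the bias of the learned model
        }
        compositionWeights = new double[models.size() + cntOfIterations];
        for (int i = 0; i < models.size(); i++)
            compositionWeights[i] = mdlToUpdate.weights[i]; // only reached when mdlToUpdate != null
        Arrays.fill(compositionWeights, models.size(), compositionWeights.length, defaultGradStepSize);
        return models;
    }

    public static void main(String[] args) {
        List<DoubleUnaryOperator> prevModels = List.of(v -> v, v -> 1.0);
        Composition prev = new Composition(prevModels, new double[] {0.3, 0.2}, 2.0);
        WarmStartSketch state = new WarmStartSketch();
        List<DoubleUnaryOperator> models = state.initLearningState(prev, 3, 0.1);
        System.out.println(models.size() + " " + Arrays.toString(state.compositionWeights));
        // prints: 2 [0.3, 0.2, 0.1, 0.1, 0.1]
    }
}
```

The restored weights keep their learned values while the three new slots use the step size 0.1; a boosting loop would then append the new base models after the restored ones.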
Ignite Database and Caching Platform : ver. 2.7.2 Release Date : February 6 2019