public class GDBLearningStrategy extends Object
Modifier and Type | Field and Description |
---|---|
protected IgniteSupplier<DatasetTrainer<? extends IgniteModel<Vector,Double>,Double>> | baseMdlTrainerBuilder: Base model trainer builder. |
protected ConvergenceCheckerFactory | checkConvergenceStgyFactory: Convergence check strategy factory. |
protected int | cntOfIterations: Count of iterations. |
protected double[] | compositionWeights: Composition weights. |
protected LearningEnvironmentBuilder | envBuilder: Learning environment builder. |
protected IgniteFunction<Double,Double> | externalLbToInternalMapping: External label to internal label mapping. |
protected Loss | loss: Loss function used to compute gradients. |
protected double | meanLbVal: Mean label value. |
protected long | sampleSize: Sample size. |
protected LearningEnvironment | trainerEnvironment: Learning environment used for the trainer. |
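The loss field supplies the gradient that each boosting iteration follows: the next base model is fit to the negative gradient of the loss with respect to the current composition answer. The sketch below illustrates that idea with two textbook losses. LossSketch, HalfSquaredError, LogisticLoss and LossDemo are hypothetical names for illustration only; they are not the Ignite ML Loss interface, and the sketch ignores the sampleSize field.

```java
// Hypothetical loss abstraction for illustration; NOT the Ignite ML Loss interface.
interface LossSketch {
    /** Loss value for one label and one raw model answer. */
    double error(double label, double answer);

    /** Derivative of the loss with respect to the model answer. */
    double gradient(double label, double answer);
}

/** Half squared error: L = 0.5 * (a - y)^2, so dL/da = a - y. */
class HalfSquaredError implements LossSketch {
    public double error(double label, double answer) {
        double d = answer - label;
        return 0.5 * d * d;
    }

    public double gradient(double label, double answer) {
        return answer - label;
    }
}

/** Log loss on a raw score a for labels y in {0, 1}: L = log(1 + e^a) - y * a, dL/da = sigmoid(a) - y. */
class LogisticLoss implements LossSketch {
    public double error(double label, double answer) {
        // Adequate for moderate scores; a production version would guard exp() against overflow.
        return Math.log1p(Math.exp(answer)) - label * answer;
    }

    public double gradient(double label, double answer) {
        return 1.0 / (1.0 + Math.exp(-answer)) - label;
    }
}

class LossDemo {
    public static void main(String[] args) {
        LossSketch sq = new HalfSquaredError();
        LossSketch log = new LogisticLoss();
        // The negative gradient is the pseudo-residual that the next base model is trained on.
        System.out.println(-sq.gradient(3.0, 1.0));
        System.out.println(-log.gradient(1.0, 0.0));
    }
}
```

For half squared error the pseudo-residual is simply the ordinary residual y - a, which is why squared-error boosting is often described as "fitting the residuals".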
Constructor and Description |
---|
GDBLearningStrategy() |
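The class exposes only a no-argument constructor; configuration goes through the with* methods, each of which returns a GDBLearningStrategy, which suggests fluent chaining. A minimal sketch of that pattern, assuming each setter returns the same instance: StrategyConfigSketch is a hypothetical stand-in covering a few scalar parameters, not the Ignite class, and the getter names are invented for the example.

```java
import java.util.Arrays;

// Hypothetical stand-in mirroring a few GDBLearningStrategy setters; NOT the Ignite class.
class StrategyConfigSketch {
    private int cntOfIterations;
    private double defaultGradStepSize;
    private double meanLbVal;
    private long sampleSize;
    private double[] compositionWeights = new double[0];

    // Each setter stores its argument and returns this, so calls can be chained.
    StrategyConfigSketch withCntOfIterations(int cntOfIterations) {
        this.cntOfIterations = cntOfIterations;
        return this;
    }

    StrategyConfigSketch withDefaultGradStepSize(double defaultGradStepSize) {
        this.defaultGradStepSize = defaultGradStepSize;
        return this;
    }

    StrategyConfigSketch withMeanLabelValue(double meanLbVal) {
        this.meanLbVal = meanLbVal;
        return this;
    }

    StrategyConfigSketch withSampleSize(long sampleSize) {
        this.sampleSize = sampleSize;
        return this;
    }

    StrategyConfigSketch withCompositionWeights(double[] compositionWeights) {
        // Defensive copy so later changes to the caller's array do not leak in.
        this.compositionWeights = compositionWeights.clone();
        return this;
    }

    // Hypothetical getters used only to inspect the sketch.
    int cntOfIterations() { return cntOfIterations; }
    double defaultGradStepSize() { return defaultGradStepSize; }
    double meanLbVal() { return meanLbVal; }
    long sampleSize() { return sampleSize; }
    double[] compositionWeights() { return compositionWeights.clone(); }

    public static void main(String[] args) {
        StrategyConfigSketch cfg = new StrategyConfigSketch()
            .withCntOfIterations(100)
            .withDefaultGradStepSize(0.1)
            .withMeanLabelValue(2.5)
            .withSampleSize(1_000L)
            .withCompositionWeights(new double[] {0.1, 0.1, 0.1});
        System.out.println(cfg.cntOfIterations() + " " + Arrays.toString(cfg.compositionWeights()));
    }
}
```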
Modifier and Type | Method and Description |
---|---|
double[] | getCompositionWeights() |
double | getMeanValue() |
protected @NotNull List<IgniteModel<Vector,Double>> | initLearningState(GDBTrainer.GDBModel mdlToUpdate): Restores the state of an already learned model, if possible, and sets the learning parameters according to that state. |
<K,V> List<IgniteModel<Vector,Double>> | learnModels(DatasetBuilder<K,V> datasetBuilder, Preprocessor<K,V> preprocessor): Implementation of gradient boosting iterations. |
<K,V> List<IgniteModel<Vector,Double>> | update(GDBTrainer.GDBModel mdlToUpdate, DatasetBuilder<K,V> datasetBuilder, Preprocessor<K,V> preprocessor): Gets the state of the model passed as an argument and compares it with the trainer's training parameters; if they are consistent, updates the model according to the new data and returns the new model. |
GDBLearningStrategy | withBaseModelTrainerBuilder(IgniteSupplier<DatasetTrainer<? extends IgniteModel<Vector,Double>,Double>> buildBaseMdlTrainer): Sets the base model trainer builder. |
GDBLearningStrategy | withCheckConvergenceStgyFactory(ConvergenceCheckerFactory factory): Sets the convergence check strategy factory. |
GDBLearningStrategy | withCntOfIterations(int cntOfIterations): Sets the count of iterations. |
GDBLearningStrategy | withCompositionWeights(double[] compositionWeights): Sets the composition weights vector. |
GDBLearningStrategy | withDefaultGradStepSize(double defaultGradStepSize): Sets the default gradient step size. |
GDBLearningStrategy | withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder): Sets the learning environment builder. |
GDBLearningStrategy | withExternalLabelToInternal(IgniteFunction<Double,Double> externalLbToInternal): Sets the mapping from external to internal label representation. |
GDBLearningStrategy | withLossGradient(Loss loss): Sets the loss function. |
GDBLearningStrategy | withMeanLabelValue(double meanLbVal): Sets the mean label value. |
GDBLearningStrategy | withSampleSize(long sampleSize): Sets the sample size. |
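learnModels is summarized only as an implementation of gradient boosting iterations. The sketch below shows the generic technique on one-dimensional data with half squared error: start from the mean label value, then for a fixed count of iterations fit a base model to the current residuals (the negative gradient) and add it with a constant step-size weight. BoostingSketch, its decision-stump base learner and all method names are hypothetical; the names echo meanLbVal, cntOfIterations, defaultGradStepSize and compositionWeights, but this is not Ignite's implementation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.DoubleUnaryOperator;

// Hypothetical gradient boosting sketch on 1-D data with half squared error; NOT Ignite's learnModels.
class BoostingSketch {
    final double meanLbVal;
    final List<DoubleUnaryOperator> models = new ArrayList<>();
    final List<Double> weights = new ArrayList<>();

    BoostingSketch(double meanLbVal) {
        this.meanLbVal = meanLbVal;
    }

    /** Composition answer: mean label value plus the weighted base model answers. */
    double predict(double x) {
        double res = meanLbVal;
        for (int i = 0; i < models.size(); i++)
            res += weights.get(i) * models.get(i).applyAsDouble(x);
        return res;
    }

    /** Least-squares decision stump: split at one threshold, predict a constant on each side. */
    static DoubleUnaryOperator fitStump(double[] x, double[] r) {
        double bestSse = Double.POSITIVE_INFINITY;
        double bestThr = 0, bestLeft = 0, bestRight = 0;
        for (double thr : x) {
            double sumL = 0, sumR = 0;
            int nL = 0, nR = 0;
            for (int i = 0; i < x.length; i++) {
                if (x[i] <= thr) { sumL += r[i]; nL++; } else { sumR += r[i]; nR++; }
            }
            double mL = sumL / nL; // nL >= 1 because thr is one of the x values
            double mR = nR > 0 ? sumR / nR : mL;
            double sse = 0;
            for (int i = 0; i < x.length; i++) {
                double d = r[i] - (x[i] <= thr ? mL : mR);
                sse += d * d;
            }
            if (sse < bestSse) { bestSse = sse; bestThr = thr; bestLeft = mL; bestRight = mR; }
        }
        final double split = bestThr, left = bestLeft, right = bestRight;
        return v -> v <= split ? left : right;
    }

    /** Runs cntOfIterations boosting steps, each base model weighted by stepSize. */
    static BoostingSketch learn(double[] x, double[] y, int cntOfIterations, double stepSize) {
        double mean = 0;
        for (double v : y) mean += v;
        BoostingSketch mdl = new BoostingSketch(mean / y.length);
        double[] residuals = new double[y.length];
        for (int it = 0; it < cntOfIterations; it++) {
            // Negative gradient of 0.5 * (a - y)^2 with respect to a is the residual y - a.
            for (int i = 0; i < y.length; i++)
                residuals[i] = y[i] - mdl.predict(x[i]);
            mdl.models.add(fitStump(x, residuals));
            mdl.weights.add(stepSize);
        }
        return mdl;
    }

    public static void main(String[] args) {
        BoostingSketch mdl = learn(new double[] {1, 2, 3, 4}, new double[] {1, 1, 5, 5}, 50, 0.1);
        System.out.printf("%.3f %.3f%n", mdl.predict(1.5), mdl.predict(3.5));
    }
}
```

With a constant step size below 1, each iteration removes only a fraction of the remaining residual, so more iterations are needed but the composition is less prone to overfitting a single base model.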
protected LearningEnvironmentBuilder envBuilder
protected LearningEnvironment trainerEnvironment
protected int cntOfIterations
protected Loss loss
protected IgniteFunction<Double,Double> externalLbToInternalMapping
protected IgniteSupplier<DatasetTrainer<? extends IgniteModel<Vector,Double>,Double>> baseMdlTrainerBuilder
protected double meanLbVal
protected long sampleSize
protected double[] compositionWeights
protected ConvergenceCheckerFactory checkConvergenceStgyFactory
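checkConvergenceStgyFactory supplies the strategy that decides whether boosting can stop before cntOfIterations is reached. One common criterion, assumed here rather than documented on this page, stops once the mean absolute gradient of the loss over the sample drops below a precision threshold. MeanAbsGradientCheckSketch is a hypothetical class that applies this criterion to half squared error; it is not an Ignite ConvergenceCheckerFactory.

```java
// Hypothetical convergence check; NOT an Ignite ConvergenceCheckerFactory implementation.
// Stops boosting once the mean |dL/da| for L = 0.5 * (a - y)^2 falls below a precision threshold.
class MeanAbsGradientCheckSketch {
    private final double precision;

    MeanAbsGradientCheckSketch(double precision) {
        if (precision <= 0)
            throw new IllegalArgumentException("precision must be positive");
        this.precision = precision;
    }

    /** Returns true if the current composition answers are close enough to the labels. */
    boolean isConverged(double[] labels, double[] answers) {
        double sumAbs = 0;
        for (int i = 0; i < labels.length; i++)
            sumAbs += Math.abs(answers[i] - labels[i]); // |a - y| is |dL/da| for half squared error
        return sumAbs / labels.length < precision;
    }

    public static void main(String[] args) {
        MeanAbsGradientCheckSketch check = new MeanAbsGradientCheckSketch(0.05);
        System.out.println(check.isConverged(new double[] {1, 2}, new double[] {1.02, 1.99}));
        System.out.println(check.isConverged(new double[] {1, 2}, new double[] {1.5, 2.0}));
    }
}
```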
public <K,V> List<IgniteModel<Vector,Double>> learnModels(DatasetBuilder<K,V> datasetBuilder, Preprocessor<K,V> preprocessor)
datasetBuilder - Dataset builder.
preprocessor - Upstream preprocessor.
public <K,V> List<IgniteModel<Vector,Double>> update(GDBTrainer.GDBModel mdlToUpdate, DatasetBuilder<K,V> datasetBuilder, Preprocessor<K,V> preprocessor)
K - Type of a key in upstream data.
V - Type of a value in upstream data.
mdlToUpdate - Learned model.
datasetBuilder - Dataset builder.
preprocessor - Upstream preprocessor.
protected @NotNull List<IgniteModel<Vector,Double>> initLearningState(GDBTrainer.GDBModel mdlToUpdate)
mdlToUpdate - Model to update.
public GDBLearningStrategy withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder)
envBuilder - Learning environment builder.
public GDBLearningStrategy withCntOfIterations(int cntOfIterations)
cntOfIterations - Count of iterations.
public GDBLearningStrategy withLossGradient(Loss loss)
loss - Loss function.
public GDBLearningStrategy withExternalLabelToInternal(IgniteFunction<Double,Double> externalLbToInternal)
externalLbToInternal - External to internal label mapping.
public GDBLearningStrategy withBaseModelTrainerBuilder(IgniteSupplier<DatasetTrainer<? extends IgniteModel<Vector,Double>,Double>> buildBaseMdlTrainer)
buildBaseMdlTrainer - Base model trainer builder.
public GDBLearningStrategy withMeanLabelValue(double meanLbVal)
meanLbVal - Mean label value.
public GDBLearningStrategy withSampleSize(long sampleSize)
sampleSize - Sample size.
public GDBLearningStrategy withCompositionWeights(double[] compositionWeights)
compositionWeights - Composition weights.
public GDBLearningStrategy withCheckConvergenceStgyFactory(ConvergenceCheckerFactory factory)
factory - Convergence check strategy factory.
public GDBLearningStrategy withDefaultGradStepSize(double defaultGradStepSize)
defaultGradStepSize - Default gradient step size.
public double[] getCompositionWeights()
public double getMeanValue()
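getMeanValue() and getCompositionWeights() expose the two quantities a gradient-boosted composition typically needs at prediction time: a bias term and one weight per base model. Assuming, as this page does not state explicitly, that the composition answer is the mean label value plus the weighted sum of the base model answers, the aggregation step reduces to the sketch below. WeightedCompositionSketch and aggregate are hypothetical names, not Ignite API.

```java
// Hypothetical aggregation of base model answers; NOT an Ignite class.
// answer = meanValue + sum_i compositionWeights[i] * baseAnswers[i]
class WeightedCompositionSketch {
    static double aggregate(double meanValue, double[] compositionWeights, double[] baseAnswers) {
        if (compositionWeights.length != baseAnswers.length)
            throw new IllegalArgumentException("one weight per base model answer is required");
        double res = meanValue;
        for (int i = 0; i < baseAnswers.length; i++)
            res += compositionWeights[i] * baseAnswers[i];
        return res;
    }

    public static void main(String[] args) {
        // Bias 2.5 (mean label value) and two base models, each weighted by a step size of 0.1.
        double answer = aggregate(2.5, new double[] {0.1, 0.1}, new double[] {1.0, -3.0});
        System.out.println(answer);
    }
}
```

If labels were mapped to an internal representation during training (see withExternalLabelToInternal), the aggregated answer is in that internal representation and would need the reverse mapping before it is reported to the caller.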
GridGain In-Memory Computing Platform : ver. 8.9.15 Release Date : December 3 2024