public class LogisticRegressionSGDTrainer extends SingleLabelDatasetTrainer<LogisticRegressionModel>
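Nested classes/interfaces inherited from class DatasetTrainer: DatasetTrainer.EmptyDatasetException
Fields inherited from class DatasetTrainer: envBuilder, environment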
Constructor and Description |
---|
LogisticRegressionSGDTrainer() |
Modifier and Type | Method and Description |
---|---|
<K,V> LogisticRegressionModel | fitWithInitializedDeployingContext(DatasetBuilder<K,V> datasetBuilder, Preprocessor<K,V> extractor) Trains a model on the specified data. |
int | getBatchSize() Get the batch size. |
int | getLocIterations() Get the number of local iterations. |
int | getMaxIterations() Get the maximum number of iterations. |
long | getSeed() Get the seed for the random generator. |
UpdatesStrategy | getUpdatesStgy() Get the update strategy. |
boolean | isUpdateable(LogisticRegressionModel mdl) |
protected <K,V> LogisticRegressionModel | updateModel(LogisticRegressionModel mdl, DatasetBuilder<K,V> datasetBuilder, Preprocessor<K,V> extractor) Trains a new model, taking the previous one as a first approximation. |
LogisticRegressionSGDTrainer | withBatchSize(double batchSize) Set up the batchSize parameter. |
LogisticRegressionSGDTrainer | withLocIterations(double amountOfLocIterations) Set up the number of local iterations of the SGD algorithm. |
LogisticRegressionSGDTrainer | withMaxIterations(double maxIterations) Set up the maximum number of iterations before convergence. |
LogisticRegressionSGDTrainer | withSeed(long seed) Set up the random seed parameter. |
LogisticRegressionSGDTrainer | withUpdatesStgy(UpdatesStrategy updatesStgy) Set up the update strategy. |
fit, fit, fit, fit, fit, fit, getLastTrainedModelOrThrowEmptyDatasetException, identityTrainer, learningEnvironment, update, update, update, update, update, withConvertedLabels, withEnvironmentBuilder
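The tuning parameters listed in the method summary (maximum iterations, batch size, seed) are the usual knobs of mini-batch stochastic gradient descent. The block below is a minimal, self-contained sketch of that technique in plain Java, included only to illustrate what such parameters typically control. It is not GridGain's implementation: the class `MiniBatchLogisticSketch`, its methods, and the `learningRate` argument are hypothetical, and the sketch does not model local iterations or an `UpdatesStrategy`.

```java
import java.util.Random;

/** Hypothetical stand-alone sketch; not part of the GridGain API. */
public class MiniBatchLogisticSketch {
    /** Logistic (sigmoid) function. */
    public static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Predicted probability of label 1; the last element of w is the bias. */
    public static double predict(double[] w, double[] features) {
        double z = w[w.length - 1];
        for (int j = 0; j < features.length; j++)
            z += w[j] * features[j];
        return sigmoid(z);
    }

    /**
     * Fits logistic regression weights with mini-batch SGD.
     * Each iteration samples batchSize rows (with replacement) using a
     * generator seeded with seed, averages the log-loss gradient over
     * the batch, and takes one step of size learningRate.
     */
    public static double[] fit(double[][] x, double[] y, int maxIterations,
                               int batchSize, long seed, double learningRate) {
        int dim = x[0].length;
        double[] w = new double[dim + 1];
        Random rnd = new Random(seed);

        for (int iter = 0; iter < maxIterations; iter++) {
            double[] grad = new double[dim + 1];
            for (int b = 0; b < batchSize; b++) {
                int i = rnd.nextInt(x.length);
                double err = predict(w, x[i]) - y[i];
                for (int j = 0; j < dim; j++)
                    grad[j] += err * x[i][j];
                grad[dim] += err; // gradient with respect to the bias
            }
            for (int j = 0; j <= dim; j++)
                w[j] -= learningRate * grad[j] / batchSize;
        }
        return w;
    }

    public static void main(String[] args) {
        // Tiny separable data set: small values are label 0, large values label 1.
        double[][] x = {{0.0}, {1.0}, {2.0}, {3.0}};
        double[] y = {0.0, 0.0, 1.0, 1.0};
        double[] w = fit(x, y, 2000, 2, 42L, 0.5);
        System.out.println("P(label=1 | x=0) = " + predict(w, new double[] {0.0}));
        System.out.println("P(label=1 | x=3) = " + predict(w, new double[] {3.0}));
    }
}
```

In this sketch a fixed seed makes the sampled batches, and therefore the fitted weights, reproducible across runs, which is the usual reason a trainer exposes a seed setter.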
public <K,V> LogisticRegressionModel fitWithInitializedDeployingContext(DatasetBuilder<K,V> datasetBuilder, Preprocessor<K,V> extractor)
fitWithInitializedDeployingContext in class DatasetTrainer<LogisticRegressionModel,Double>
K - Type of a key in upstream data.
V - Type of a value in upstream data.
datasetBuilder - Dataset builder.
extractor - Extractor of UpstreamEntry into LabeledVector.

protected <K,V> LogisticRegressionModel updateModel(LogisticRegressionModel mdl, DatasetBuilder<K,V> datasetBuilder, Preprocessor<K,V> extractor)
updateModel in class DatasetTrainer<LogisticRegressionModel,Double>
K - Type of a key in upstream data.
V - Type of a value in upstream data.
mdl - Learned model.
datasetBuilder - Dataset builder.
extractor - Extractor of UpstreamEntry into LabeledVector.

public boolean isUpdateable(LogisticRegressionModel mdl)
isUpdateable in class DatasetTrainer<LogisticRegressionModel,Double>
mdl - Model.

public LogisticRegressionSGDTrainer withMaxIterations(double maxIterations)
maxIterations - The parameter value.

public LogisticRegressionSGDTrainer withBatchSize(double batchSize)
batchSize - The size of learning batch.

public LogisticRegressionSGDTrainer withLocIterations(double amountOfLocIterations)
amountOfLocIterations - The parameter value.

public LogisticRegressionSGDTrainer withSeed(long seed)
seed - Seed for random generator.

public LogisticRegressionSGDTrainer withUpdatesStgy(UpdatesStrategy updatesStgy)
updatesStgy - Update strategy.

public UpdatesStrategy getUpdatesStgy()
public int getMaxIterations()
public int getBatchSize()
public int getLocIterations()
public long getSeed()
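
The updateModel summary describes training a new model that takes the previous one as a first approximation, guarded by isUpdateable. The block below is a minimal plain-Java sketch of that warm-start idea, under stated assumptions: the class `WarmStartSketch` and all of its methods are hypothetical, full-batch gradient descent stands in for the trainer's SGD, and the dimension check in its `isUpdateable` is an assumed compatibility rule, not the documented behavior of this class.

```java
/** Hypothetical stand-alone sketch; not part of the GridGain API. */
public class WarmStartSketch {
    /** Assumed rule: previous weights are reusable if they match feature count plus bias. */
    public static boolean isUpdateable(double[] prevWeights, int featureCount) {
        return prevWeights != null && prevWeights.length == featureCount + 1;
    }

    /** Predicted probability of label 1; the last element of w is the bias. */
    public static double predict(double[] w, double[] features) {
        double z = w[w.length - 1];
        for (int j = 0; j < features.length; j++)
            z += w[j] * features[j];
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Runs full-batch gradient descent on the log-loss, starting from a copy of start. */
    public static double[] train(double[] start, double[][] x, double[] y,
                                 int iterations, double learningRate) {
        double[] w = start.clone();
        int dim = x[0].length;
        for (int iter = 0; iter < iterations; iter++) {
            double[] grad = new double[dim + 1];
            for (int i = 0; i < x.length; i++) {
                double err = predict(w, x[i]) - y[i];
                for (int j = 0; j < dim; j++)
                    grad[j] += err * x[i][j];
                grad[dim] += err; // gradient with respect to the bias
            }
            for (int j = 0; j <= dim; j++)
                w[j] -= learningRate * grad[j] / x.length;
        }
        return w;
    }

    /** Warm start: continue from prevWeights when compatible, otherwise start from zeros. */
    public static double[] update(double[] prevWeights, double[][] x, double[] y,
                                  int iterations, double learningRate) {
        int dim = x[0].length;
        double[] start = isUpdateable(prevWeights, dim) ? prevWeights : new double[dim + 1];
        return train(start, x, y, iterations, learningRate);
    }

    /** Average log-loss of weights w on the data set. */
    public static double logLoss(double[] w, double[][] x, double[] y) {
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            double p = predict(w, x[i]);
            sum -= y[i] * Math.log(p) + (1.0 - y[i]) * Math.log(1.0 - p);
        }
        return sum / x.length;
    }

    public static void main(String[] args) {
        double[][] x = {{0.0}, {1.0}, {2.0}, {3.0}};
        double[] y = {0.0, 0.0, 1.0, 1.0};
        double[] first = train(new double[2], x, y, 50, 0.1);
        double[] updated = update(first, x, y, 50, 0.1);
        System.out.println("loss after first fit: " + logLoss(first, x, y));
        System.out.println("loss after update:    " + logLoss(updated, x, y));
    }
}
```

Because the sketch is deterministic, updating a 50-iteration model for 50 more iterations gives the same weights as training for 100 iterations from scratch. A real trainer with random batching would generally not have that exact property.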
GridGain In-Memory Computing Platform : ver. 8.9.14 Release Date : November 5 2024