Decision Trees
Overview
Decision trees are a simple yet powerful model in supervised machine learning. The main idea is to split the feature space into regions such that the values within each region vary as little as possible. The measure of the values' variation in a region is called the impurity of the region.
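To make the notion of impurity concrete, here is a minimal sketch of the Gini impurity measure (the one used later for classification). The class and method names are illustrative, not part of the Ignite API; labels are assumed to be encoded as class indices.

```java
// Illustrative sketch of the Gini impurity of a region (not Ignite API).
public class GiniImpurity {
    /**
     * Gini impurity: 1 - sum over classes of p_c^2, where p_c is the
     * fraction of samples in the region that belong to class c.
     * 0 means the region is pure; higher values mean more mixed labels.
     */
    public static double gini(int[] labels, int numClasses) {
        int[] counts = new int[numClasses];
        for (int lbl : labels)
            counts[lbl]++;

        double impurity = 1.0;
        for (int c : counts) {
            double p = (double) c / labels.length;
            impurity -= p * p;
        }
        return impurity;
    }
}
```

A pure region (all labels equal) has impurity 0, while a 50/50 mix of two classes has impurity 0.5, the maximum for two classes.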
Apache Ignite provides an implementation of the algorithm optimized for data stored in rows (see partition-based dataset).
Splits are done recursively and every region created from a split can be split further. Therefore, the whole process can be described by a binary tree, where each node is a particular region and its children are the regions derived from it by another split.
Let each sample from a training set belong to some space S and let p_i be a projection on the feature with index i. Then a split by a continuous feature with index i has the form:

S → {s ∈ S : p_i(s) < threshold} ∪ {s ∈ S : p_i(s) ≥ threshold}

and a split by a categorical feature with values from some set X has the form:

S → {s ∈ S : p_i(s) ∈ X_0} ∪ {s ∈ S : p_i(s) ∉ X_0}

Here X_0 is a subset of X.
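The two split forms above can be sketched as simple predicates that send a sample to the left or right child region. The names below are hypothetical, chosen only to mirror the notation (feature index i, threshold, subset X_0); they are not part of the Ignite API.

```java
import java.util.Set;

// Illustrative sketch of the two split forms (hypothetical names, not Ignite API).
public class Splits {
    // Continuous split: a sample goes left iff its i-th feature is below the threshold.
    public static boolean continuousSplit(double[] features, int i, double threshold) {
        return features[i] < threshold;
    }

    // Categorical split: a sample goes left iff its i-th feature value lies in the subset X_0.
    public static boolean categoricalSplit(String[] features, int i, Set<String> x0) {
        return x0.contains(features[i]);
    }
}
```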
The split process stops when either the algorithm has reached the configured maximal depth, or splitting any region no longer yields a significant impurity decrease. Prediction of a value for a point s from S is a traversal of the tree down to the leaf that corresponds to the region containing s, returning the value associated with that leaf.
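The traversal described above can be sketched with a toy tree class (hypothetical, not Ignite's DecisionTreeNode): internal nodes hold a split, leaves hold a value, and prediction descends until a leaf is reached.

```java
// Toy sketch of prediction as tree traversal (hypothetical class, not Ignite API).
public class ToyTree {
    final int featureIdx;      // feature used for the split (internal nodes only)
    final double threshold;    // split threshold (internal nodes only)
    final ToyTree left, right; // children; null for leaves
    final double value;        // predicted value (leaves only)

    // Internal node: splits on a continuous feature.
    ToyTree(int featureIdx, double threshold, ToyTree left, ToyTree right) {
        this.featureIdx = featureIdx;
        this.threshold = threshold;
        this.left = left;
        this.right = right;
        this.value = 0;
    }

    // Leaf: stores the value predicted for its region.
    ToyTree(double value) {
        this.featureIdx = -1;
        this.threshold = 0;
        this.left = null;
        this.right = null;
        this.value = value;
    }

    // Descend the tree until a leaf, then return its value.
    double apply(double[] s) {
        if (left == null)
            return value;
        return s[featureIdx] < threshold ? left.apply(s) : right.apply(s);
    }
}
```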
Model
The model in decision tree classification is represented by the class DecisionTreeNode. We can make a prediction for a given vector of features in the following way:
DecisionTreeNode mdl = ...
double prediction = mdl.apply(observation);
The model is a fully independent object; after training it can be saved, serialized, and restored.
Trainer
The decision tree algorithm can be used for classification or regression, depending on the impurity measure and node instantiation approach.
Classification
The Classification Decision Tree uses the Gini impurity measure and you can use it in the following way:
// Create decision tree classification trainer.
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(
4, // Max depth.
0 // Min impurity decrease.
);
// Train model.
DecisionTreeNode mdl = trainer.fit(
ignite,
upstreamCache,
(k, pnt) -> pnt.coordinates,
(k, pnt) -> pnt.label
);
// Make a prediction.
double prediction = mdl.apply(coordinates);
// Create decision tree classification trainer (for data stored in a local map).
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(
4, // Max depth.
0 // Min impurity decrease.
);
// Train model.
DecisionTreeNode mdl = trainer.fit(
upstreamMap,
10, // Number of partitions.
(k, pnt) -> pnt.coordinates,
(k, pnt) -> pnt.label
);
// Make a prediction.
double prediction = mdl.apply(coordinates);
Regression
The Regression Decision Tree uses the MSE impurity measure and you can use it in the following way:
// Create decision tree regression trainer.
DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(
4, // Max depth.
0 // Min impurity decrease.
);
// Train model.
DecisionTreeNode mdl = trainer.fit(
ignite,
upstreamCache,
(k, pnt) -> pnt.x,
(k, pnt) -> pnt.y
);
// Make a prediction.
double prediction = mdl.apply(x);
// Create decision tree regression trainer (for data stored in a local map).
DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(
4, // Max depth.
0 // Min impurity decrease.
);
// Train model.
DecisionTreeNode mdl = trainer.fit(
upstreamMap,
10, // Number of partitions.
(k, pnt) -> pnt.x,
(k, pnt) -> pnt.y
);
// Make a prediction.
double prediction = mdl.apply(x);
Examples
To see how the decision tree can be used in practice, try the classification example and the regression example, which are available on GitHub and delivered with every Apache Ignite distribution.