Classification Trees with Uncorrelated Linear Discriminant Analysis Terminal Nodes
Source:R/Treee.R
Treee.Rd
This function fits a classification tree where each node has a Uncorrelated Linear Discriminant Analysis (ULDA) model. It can also handle missing values and perform downsampling. The resulting tree can be pruned either through pre-pruning or post-pruning methods.
Usage
Treee(
datX,
response,
ldaType = c("forward", "all"),
nodeModel = c("ULDA", "mode"),
pruneMethod = c("pre", "post"),
numberOfPruning = 10L,
maxTreeLevel = 20L,
minNodeSize = NULL,
pThreshold = NULL,
prior = NULL,
misClassCost = NULL,
missingMethod = c("medianFlag", "newLevel"),
kSample = -1,
verbose = TRUE
)
Arguments
- datX
A data frame of predictor variables.
- response
A vector of response values corresponding to
datX
.- ldaType
A character string specifying the type of LDA to use. Options are
"forward"
for forward ULDA or"all"
for full ULDA. Default is"forward"
.- nodeModel
A character string specifying the type of model used in each node. Options are
"ULDA"
for Uncorrelated LDA, or"mode"
for predicting based on the most frequent class. Default is"ULDA"
.- pruneMethod
A character string specifying the pruning method.
"pre"
performs pre-pruning based on p-value thresholds, and"post"
performs cross-validation-based post-pruning. Default is"pre"
.- numberOfPruning
An integer specifying the number of folds for cross-validation during post-pruning. Default is
10
.- maxTreeLevel
An integer controlling the maximum depth of the tree. Increasing this value allows for deeper trees with more nodes. Default is
20
.- minNodeSize
An integer controlling the minimum number of samples required in a node. Setting a higher value may lead to earlier stopping and smaller trees. If not specified, it defaults to one plus the number of response classes.
- pThreshold
A numeric value used as a threshold for pre-pruning based on p-values. Lower values result in more conservative trees. If not specified, defaults to
0.01
for pre-pruning and0.6
for post-pruning.- prior
A numeric vector of prior probabilities for each class. If
NULL
, the prior is automatically calculated from the data.- misClassCost
A square matrix \(C\), where each element \(C_{ij}\) represents the cost of classifying an observation into class \(i\) given that it truly belongs to class \(j\). If
NULL
, a default matrix with equal misclassification costs for all class pairs is used. Default isNULL
.- missingMethod
A character string specifying how missing values should be handled. Options include
'mean'
,'median'
,'meanFlag'
,'medianFlag'
for numerical variables, and'mode'
,'modeFlag'
,'newLevel'
for factor variables.'Flag'
options indicate whether a missing flag is added, while'newLevel'
replaces missing values with a new factor level.- kSample
An integer specifying the number of samples to use for downsampling during tree construction. Set to
-1
to disable downsampling.- verbose
A logical value. If
TRUE
, progress messages and detailed output are printed during tree construction and pruning. Default isFALSE
.
Value
An object of class Treee
containing the fitted tree, which is a
list of nodes, each an object of class TreeeNode
. Each TreeeNode
contains:
currentIndex
: The node index in the tree.currentLevel
: The depth of the current node in the tree.idxRow
,idxCol
: Row and column indices indicating which part of the original data was used for this node.currentLoss
: The training error for this node.accuracy
: The training accuracy for this node.stopInfo
: Information on why the node stopped growing.proportions
: The observed frequency of each class in this node.prior
: The (adjusted) class prior probabilities used for ULDA or mode prediction.misClassCost
: The misclassification cost matrix used in this node.parent
: The index of the parent node.children
: A vector of indices of this node’s direct children.splitFun
: The splitting function used for this node.nodeModel
: Indicates the model fitted at the node ('ULDA'
or'mode'
).nodePredict
: The fitted model at the node, either a ULDA object or the plurality class.alpha
: The p-value from a two-sample t-test used to evaluate the strength of the split.childrenTerminal
: A vector of indices representing the terminal nodes that are descendants of this node.childrenTerminalLoss
: The total training error accumulated from all nodes listed inchildrenTerminal
.
References
Wang, S. (2024). FoLDTree: A ULDA-Based Decision Tree Framework for Efficient Oblique Splits and Feature Selection. arXiv preprint arXiv:2410.23147. Available at https://arxiv.org/abs/2410.23147.
Wang, S. (2024). A New Forward Discriminant Analysis Framework Based On Pillai's Trace and ULDA. arXiv preprint arXiv:2409.03136. Available at https://arxiv.org/abs/2409.03136.
Examples
fit <- Treee(datX = iris[, -5], response = iris[, 5], verbose = FALSE)
# Use cross-validation to prune the tree
fitCV <- Treee(datX = iris[, -5], response = iris[, 5], pruneMethod = "post", verbose = FALSE)
head(predict(fit, iris)) # prediction
#> [1] "setosa" "setosa" "setosa" "setosa" "setosa" "setosa"
plot(fit) # plot the overall tree
plot(fit, datX = iris[, -5], response = iris[, 5], node = 1) # plot a certain node