Decision Tree
Definition
Decision tree learning is a supervised learning approach used in statistics, data mining, and machine learning. In this formalism, a classification or regression decision tree is used as a predictive model to draw conclusions about a set of observations.
How it works
A decision tree starts with a root node, which has no incoming branches. The outgoing branches from the root node feed into internal nodes, also known as decision nodes. Both node types evaluate the available features to split the data into increasingly homogeneous subsets, which terminate in leaf nodes, or terminal nodes. The leaf nodes represent all the possible outcomes within the dataset.
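To make the structure concrete, here is a minimal sketch using scikit-learn: it fits a shallow tree and prints the root node, internal decision nodes, and leaf outcomes. The dataset and depth are illustrative choices, not part of the definition above.

```python
# A minimal sketch: fit a small classification tree and render its
# root, decision nodes, and leaf (terminal) nodes as text.
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
tree = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)

# export_text shows each split test and the class outcome at every leaf.
print(export_text(tree, feature_names=iris.feature_names))
```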
Considerations
While the basic underlying model is that of a decision tree, the node split criteria and the method for identifying splits vary significantly depending on the learning algorithm selected (e.g., CART, ID3, C4.5, C5.0, CHAID, MARS). Extensions like linear and logistic model trees can add further expressiveness.
Key Test Considerations
Machine Learning:
Verify the dataset quality: Check the data to make sure it is free of errors. Quantify the degree of missing values, outliers, and noise in the collected data. If the data quality is low, it may be difficult or impossible to create models and systems with the desired performance.
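A minimal sketch of such a quality check, assuming the data sits in a pandas DataFrame loaded from a hypothetical dataset.csv: it quantifies missing values per column and flags candidate outliers with a simple IQR rule.

```python
# A minimal sketch: quantify missingness and simple outliers before modeling.
import pandas as pd

df = pd.read_csv("dataset.csv")  # hypothetical file path

missing_rate = df.isna().mean()  # fraction of missing values per column
print(missing_rate.sort_values(ascending=False))

numeric = df.select_dtypes("number")
q1, q3 = numeric.quantile(0.25), numeric.quantile(0.75)
iqr = q3 - q1
outliers = ((numeric < q1 - 1.5 * iqr) | (numeric > q3 + 1.5 * iqr)).sum()
print(outliers)  # count of IQR-rule outliers per numeric column
```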
Verify development datasets are representative: Confirm the datasets reflect the expected operational environment and data collection means. Compare distributions of dataset features and labels with exploratory data analysis, and assess the difference between tests on training data and tests on evaluation data (where the evaluation data must be drawn from a representative dataset).
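One way to compare distributions is a two-sample Kolmogorov-Smirnov test per numeric feature. The sketch below assumes pandas DataFrames train_df and eval_df and an illustrative significance cutoff.

```python
# A minimal sketch: flag features whose training and evaluation
# distributions differ, as a representativeness check.
from scipy.stats import ks_2samp

for col in train_df.select_dtypes("number").columns:
    stat, p = ks_2samp(train_df[col].dropna(), eval_df[col].dropna())
    if p < 0.01:  # distributions differ; investigate representativeness
        print(f"{col}: KS={stat:.3f}, p={p:.4f}")
```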
Use a variety of data sets: where available and applicable, to reflect the different operating and environmental conditions that are likely to be encountered.
Use software libraries: and tools built for ML where possible, so that the underlying code is verified by prior use.
Diagnose model errors with domain SMEs: Have problem-domain SMEs investigate model errors to identify conditions under which the model may underperform and to suggest refinements.
Classification:
Use Standard Classification Performance Measures: Not all of the following may be necessary, but each should be considered for use in both the verification (developmental test) and operational test stages (a combined sketch follows this list):
Accuracy: The fraction of predictions that were correct.
Precision: The proportion of positive identifications that were correct.
Recall: The proportion of actual positive cases identified correctly.
F-Measure: Combines the precision and recall into a single score. It is the harmonic mean of the precision and recall.
Receiver Operating Characteristic (ROC) Curve: A ROC curve shows the performance of a classification model at all classification thresholds. It plots the True Positive Rate against the False Positive Rate.
Area Under the ROC Curve (AUC): This measures the two-dimensional area under the ROC Curve. AUC is scale-invariant and classification-threshold invariant.
ROC TP vs FP points: In addition to a single AUC score, examine performance at specific true positive vs. false positive operating points on the curve, since deployment requires selecting a classification threshold appropriate to the application's error costs.
Confusion Matrix: A confusion matrix is a table layout that allows the visualization of the performance of an algorithm. Each row of the matrix represents the instances in an actual class while each column represents the instances in a predicted class, or vice versa. It is a special kind of contingency table, with two dimensions ("actual" and "predicted"), and identical sets of "classes" in both dimensions (each combination of dimension and class is a variable in the contingency table).
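A minimal sketch computing the measures above with scikit-learn, assuming arrays y_true (actual labels), y_pred (predicted labels), and y_score (positive-class scores) from a binary classifier:

```python
# A minimal sketch of the standard classification measures listed above.
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, roc_auc_score, roc_curve,
                             confusion_matrix)

print("accuracy :", accuracy_score(y_true, y_pred))
print("precision:", precision_score(y_true, y_pred))
print("recall   :", recall_score(y_true, y_pred))
print("F-measure:", f1_score(y_true, y_pred))
print("AUC      :", roc_auc_score(y_true, y_score))

fpr, tpr, thresholds = roc_curve(y_true, y_score)  # TP vs FP operating points
print(confusion_matrix(y_true, y_pred))            # rows: actual, cols: predicted
```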
Prediction Bias: The difference between the average of the predicted labels and the average of the labels in the data set. One should check for prediction bias when evaluating the classifier's results (a short sketch follows the list of causes below). Causes of bias can include:
Noisy data set: Errors in the original data can introduce bias, as the collection method itself may have an underlying bias.
Processing bug: Errors in the data pipeline can introduce bias.
Biased training sample (unbalanced samples): Model parameters may be skewed towards majority classes.
Overly strong regularization: The model may be underfitting and too simple.
Proxy variables: Model features may be highly correlated.
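The prediction-bias check itself is brief; this sketch reuses the assumed y_true and y_score from the metrics example above.

```python
# A minimal sketch of the prediction-bias check: compare the mean
# predicted positive score with the observed positive rate.
import numpy as np

prediction_bias = np.mean(y_score) - np.mean(y_true)
print(f"prediction bias: {prediction_bias:+.4f}")  # values near 0 are desirable
```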
Supervised Learning:
Overfitting and Underfitting: Overfitting occurs when the model built corresponds too closely or exactly to a particular set of data, and thus may fail to predict additional data reliably. An overfitted model is a mathematical model that contains more parameters than can be justified by the data. Underfitting occurs when the model built does not adequately capture the patterns in the data. As an example, a linear model will underfit a non-linear dataset.
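The linear-model example can be demonstrated in a few lines. This sketch fits a linear regression and a shallow tree to a synthetic quadratic signal; the linear model's low R-squared shows the underfit.

```python
# A minimal sketch of the underfitting example above: a linear model
# fit to a quadratic signal scores far worse than a depth-limited tree.
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(500, 1))
y = X[:, 0] ** 2 + rng.normal(scale=0.1, size=500)  # non-linear target

print("linear R^2:", LinearRegression().fit(X, y).score(X, y))  # underfits
print("tree   R^2:", DecisionTreeRegressor(max_depth=4).fit(X, y).score(X, y))
```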
Sensitivity: Perform N-fold cross-validation to indicate how sensitive the algorithm is to data variation and to avoid overfitting operational models.
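A minimal N-fold cross-validation sketch with scikit-learn (the Iris dataset and 10 folds are illustrative); a large spread across fold scores indicates sensitivity to data variation.

```python
# A minimal sketch: 10-fold cross-validation of a decision tree.
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
scores = cross_val_score(DecisionTreeClassifier(random_state=0), X, y, cv=10)
print(scores.mean(), scores.std())  # high std suggests instability/overfitting
```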
Decision Tree Learning:
Sensitive to unbalanced classes: Examine the target class balance; decision tree learning algorithms are especially sensitive to unbalanced target classes.
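A minimal sketch of a class-balance check, assuming a label array y; if the counts are skewed, scikit-learn's class_weight="balanced" option is one way to counteract the imbalance.

```python
# A minimal sketch: inspect target class balance and reweight if skewed.
import numpy as np
from sklearn.tree import DecisionTreeClassifier

classes, counts = np.unique(y, return_counts=True)  # `y` assumed: target labels
print(dict(zip(classes, counts)))

# "balanced" weights splits inversely to class frequency
clf = DecisionTreeClassifier(class_weight="balanced", random_state=0)
```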
Consider decision boundaries: Perform exploratory data analysis to determine whether decision boundaries lie along axes of the features. Decision trees are ideal when such axis-aligned boundaries exist.
Decision tree overfitting may require tuning algorithm hyperparameters such as tree depth, max features used, max leaf nodes, etc.
Pruning may result in a more robust model in real-world applications. A combined tuning-and-pruning sketch follows below.
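A minimal sketch addressing both points: a cross-validated grid search over the hyperparameters named above plus scikit-learn's cost-complexity pruning parameter ccp_alpha; X and y are assumed training features and labels.

```python
# A minimal sketch: tune depth, features, leaf count, and pruning strength.
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

param_grid = {
    "max_depth": [3, 5, 10, None],
    "max_features": [None, "sqrt", 0.5],
    "max_leaf_nodes": [None, 16, 64],
    "ccp_alpha": [0.0, 0.001, 0.01],  # larger alpha prunes more aggressively
}
search = GridSearchCV(DecisionTreeClassifier(random_state=0), param_grid, cv=5)
search.fit(X, y)
print(search.best_params_)
```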
Missing values: Inspect the data set to determine if there are missing values and select a means to address them: choose an algorithm that handles them well, impute the values, or eliminate the source of missing values in the data sensors or pipeline.
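A minimal imputation sketch, assuming a pandas DataFrame df with a hypothetical "label" column; median imputation is only one of several reasonable strategies.

```python
# A minimal sketch: quantify missingness, then impute with a simple strategy.
from sklearn.impute import SimpleImputer

print(df.isna().sum())  # missing values per column

X = df.drop(columns=["label"])  # "label" is an assumed target column
X_imputed = SimpleImputer(strategy="median").fit_transform(X)
```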
Platforms, Tools, or Libraries
scikit-learn: implements an optimized version of the CART algorithm (DecisionTreeClassifier and DecisionTreeRegressor); it does not include ID3, C4.5, or C5.0.
Weka: includes J48 (C4.5), SimpleCart (CART), Logistic Model Trees, Naive Bayes Trees, and more.
Validation Approach
- Use operationally relevant data across the range of the application's operating environment.
- Incorporate a continuous validation process to address concept drift and the need to retrain the model and/or check data quality; a minimal drift-check sketch follows below.
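One such continuous check is a population stability index (PSI) over the model's score distribution; this sketch assumes arrays train_scores (scores at training time) and recent_scores (scores in operation), and the 0.25 reading is a common rule of thumb rather than a fixed standard.

```python
# A minimal sketch: PSI between training-time and operational score
# distributions; a drifting distribution signals retraining may be needed.
import numpy as np

def psi(baseline, recent, bins=10):
    """PSI of `recent` vs `baseline`; values above ~0.25 are often read as drift."""
    edges = np.quantile(baseline, np.linspace(0, 1, bins + 1))
    # clip into baseline's range so every sample falls inside the bins
    b = np.histogram(np.clip(baseline, edges[0], edges[-1]), bins=edges)[0]
    r = np.histogram(np.clip(recent, edges[0], edges[-1]), bins=edges)[0]
    b = b / len(baseline) + 1e-6
    r = r / len(recent) + 1e-6
    return np.sum((r - b) * np.log(r / b))

# `train_scores` and `recent_scores` assumed: model outputs then and now.
print("PSI:", psi(train_scores, recent_scores))
```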