I'm doing a test run of the Gradient Boosting Machine algorithm on the iris data with the caret package.

library(caret)
library(gbm)
data(iris)

set.seed(123)
inTraining <- createDataPartition(iris$Species, p = .75, list = FALSE)
training <- iris[ inTraining,]
testing  <- iris[-inTraining,]

gbmGrid <-  expand.grid(interaction.depth = c(1, 2, 3), 
                        n.trees = (1:10)*1000, 
                        shrinkage = c(0.001, 0.005, 0.01, 0.05, 0.1),
                        n.minobsinnode = c(1, 2, 5, 10, 15, 20))
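
For scale, this grid contains 3 × 10 × 5 × 6 = 900 candidate models, and the repeated cross-validation below refits each one on every resample:

nrow(gbmGrid)
# [1] 900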

fitControl <- trainControl(
  classProbs = TRUE,
  method = "repeatedcv",
  number = 10,
  repeats = 10,
  allowParallel = TRUE)

set.seed(234)
gbmFit2 <- train(Species ~ ., 
                 data = training, 
                 method = "gbm", 
                 trControl = fitControl, 
                 verbose = FALSE, 
                 tuneGrid = gbmGrid)
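
The winning hyperparameter combination can be inspected afterwards with caret's standard accessor:

gbmFit2$bestTune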

I'm achieving excellent accuracy metrics; however, the predicted probabilities for the Species values in the test data are fairly evenly split. I expected GBM to return predicted probabilities of 90%+ for the correctly predicted Species value, rather than in the 35%-40% range.

predict(gbmFit2, newdata=testing, type="prob")
     setosa versicolor virginica
1 0.3826163  0.3086751 0.3087086
2 0.3826643  0.3086374 0.3086983
3 0.3826681  0.3086355 0.3086964
4 0.3811067  0.3114695 0.3074237
5 0.3811067  0.3114695 0.3074237
...
32 0.3077245  0.3568080 0.3354674
33 0.3153934  0.3275473 0.3570593
34 0.3097463  0.3525782 0.3376756
35 0.3065883  0.3151160 0.3782957
36 0.3078244  0.3122151 0.3799605
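
For reference, the accuracy claim can be confirmed with a confusion matrix on the test set (a quick sanity check against the fit above):

confusionMatrix(predict(gbmFit2, newdata = testing), testing$Species)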

Did I misspecify my model?

RobertF

1 Answer

I'm getting good results by applying Platt scaling to the predicted probabilities for each of the iris Species classes from the Gradient Boosting Machine model. Because this is a three-class problem, I'm using multinomial logistic regression (via nnet::multinom) rather than binomial logistic regression.

library(nnet)

# GBM's predicted class probabilities on the test set
predict_gbm <- predict(gbmFit2, newdata = testing, type = "prob")
iris_preds <- cbind(testing, predict_gbm)

# Calibrate: regress the true class on the GBM probabilities
multinom_iris_calib <- multinom(Species ~ setosa + versicolor + virginica,
                                data = iris_preds)

# Fitted (calibrated) probabilities for the test observations
predict_multinom_iris_calib <- cbind(testing, fitted(multinom_iris_calib))
predict_multinom_iris_calib[, 5:8]
           Species       setosa  versicolor    virginica
    1       setosa 0.9924330938 0.007566906 3.189546e-12
    5       setosa 0.9924869997 0.007513000 3.122451e-12
    7       setosa 0.9924908536 0.007509146 3.117173e-12
    13      setosa 0.9897471351 0.010252865 6.159896e-12
    14      setosa 0.9897471351 0.010252865 6.159896e-12
    19      setosa 0.9924455961 0.007554404 3.160900e-12
    20      setosa 0.9924369750 0.007563025 3.184155e-12
    26      setosa 0.9897471351 0.010252865 6.159896e-12
    30      setosa 0.9908937928 0.009106207 4.749171e-12
    34      setosa 0.9924330938 0.007566906 3.189546e-12
    44      setosa 0.9924920468 0.007507953 3.113992e-12
    47      setosa 0.9924330938 0.007566906 3.189546e-12
    53  versicolor 0.0111471614 0.430054907 5.587979e-01
    59  versicolor 0.0027915088 0.879689514 1.175190e-01
    62  versicolor 0.0067711749 0.941213020 5.201580e-02
    64  versicolor 0.0041798273 0.913490187 8.232999e-02
    66  versicolor 0.0072934320 0.944243978 4.846259e-02
    72  versicolor 0.0012548141 0.780996058 2.177491e-01
    76  versicolor 0.0064162669 0.938643715 5.494002e-02
    78  versicolor 0.0172668357 0.499051229 4.836819e-01
    80  versicolor 0.0008895809 0.738176308 2.609341e-01
    85  versicolor 0.0068591641 0.944825113 4.831572e-02
    89  versicolor 0.0068698892 0.944839179 4.829093e-02
    99  versicolor 0.0008884476 0.737538242 2.615733e-01
    104  virginica 0.0003896792 0.020226793 9.793835e-01
    105  virginica 0.0003855347 0.019409330 9.802051e-01
    106  virginica 0.0003543512 0.018632033 9.810136e-01
    107  virginica 0.0019856015 0.851050054 1.469643e-01
    116  virginica 0.0005269779 0.023493314 9.759797e-01
    119  virginica 0.0003167019 0.019191993 9.804913e-01
    126  virginica 0.0004884086 0.022714154 9.767974e-01
    127  virginica 0.0007718212 0.263689615 7.355386e-01
    135  virginica 0.0140170968 0.477610130 5.083728e-01
    139  virginica 0.0016218328 0.358758265 6.396199e-01
    148  virginica 0.0004809264 0.023503371 9.760157e-01
    149  virginica 0.0008173819 0.030148626 9.690340e-01
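
To calibrate probabilities for new observations, the fitted calibration model can be applied to fresh GBM output (a sketch; ideally the calibrator would be fit on data not used to train the GBM):

# New data passes through the GBM first, then the multinom calibrator
new_probs <- predict(gbmFit2, newdata = testing, type = "prob")
predict(multinom_iris_calib, newdata = new_probs, type = "probs")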
RobertF