In the following, we explain the counterfactuals
workflow for both a classification and a regression task using concrete
use cases.
The Predictor class of the iml package provides the necessary flexibility
to cover classification and regression models fitted with diverse R packages.
In the introduction vignette, we saw models fitted with the mlr3 and
randomForest packages. In the following, we show extensions to a regression
tree fitted with the caret package, the mlr package (a predecessor of mlr3),
the tidymodels framework, and the rpart package itself. For each model, we
generate counterfactuals for the 100th row of the plasma dataset of the
gamlss.data package using the NICE method (NICERegr).
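library("counterfactuals")
library("iml")
library("caret")
data(plasma, package = "gamlss.data")
# Point of interest: the 100th row, which is held out from training
x_interest = plasma[100L, ]
treecaret = caret::train(retplasma ~ ., data = plasma[-100L,], method = "rpart",
  tuneGrid = data.frame(cp = 0.01))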
predcaret = Predictor$new(model = treecaret, data = plasma[-100L,], y = "retplasma")
predcaret$predict(x_interest)
#> .prediction
#> 1 342.9231
nicecaret = NICERegr$new(predcaret, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)
nicecaret$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
#> 1 Counterfactual(s)
#>
#> Desired outcome range: [500, Inf]
#>
#> Head:
#> age sex smokstat bmi vituse calories fat fiber alcohol cholesterol betadiet retdiet betaplasma
#> <int> <fctr> <fctr> <num> <fctr> <num> <num> <num> <num> <num> <int> <int> <int>
#> 1: 46 1 3 35.25969 3 2667.5 131.6 10.1 0 550.5 1210 1291 218
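The argument desired_outcome = c(500, Inf) asks for counterfactuals whose predicted retplasma is at least 500, while the prediction for x_interest is 342.9231. The printed range [500, Inf] suggests a closed interval. The sketch below checks interval membership in base R under that assumption. The helper in_desired_range is hypothetical and is not part of the counterfactuals API.

```r
# Hypothetical helper: is each prediction inside the closed interval
# [desired_outcome[1], desired_outcome[2]]?
in_desired_range = function(pred, desired_outcome) {
  stopifnot(length(desired_outcome) == 2L,
            desired_outcome[1L] <= desired_outcome[2L])
  pred >= desired_outcome[1L] & pred <= desired_outcome[2L]
}

# Prediction for x_interest reported above vs. the lower bound itself
in_desired_range(c(342.9231, 500), desired_outcome = c(500, Inf))
#> [1] FALSE  TRUE
```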
library("tidymodels")
treetm = decision_tree(mode = "regression", engine = "rpart") %>%
  fit(retplasma ~ ., data = plasma[-100L,])
predtm = Predictor$new(model = treetm, data = plasma[-100L,], y = "retplasma")
predtm$predict(x_interest)
#> .pred
#> 1 342.9231
nicetm = NICERegr$new(predtm, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)
nicetm$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
#> 1 Counterfactual(s)
#>
#> Desired outcome range: [500, Inf]
#>
#> Head:
#> age sex smokstat bmi vituse calories fat fiber alcohol cholesterol betadiet retdiet betaplasma
#> <int> <fctr> <fctr> <num> <fctr> <num> <num> <num> <num> <num> <int> <int> <int>
#> 1: 46 1 3 35.25969 3 2667.5 131.6 10.1 0 550.5 1210 1291 218
library("mlr")
#> Warning in fun(pkgname, pkgpath): Packages 'paradox' and 'ParamHelpers' are conflicting and should not be loaded in the same session
#> Warning in fun(pkgname, pkgpath): Packages 'mlr3' and 'mlr' are conflicting and should not be loaded in the same session
task = mlr::makeRegrTask(data = plasma[-100L,], target = "retplasma")
mod = mlr::makeLearner("regr.rpart")
treemlr = mlr::train(mod, task)
predmlr = Predictor$new(model = treemlr, data = plasma[-100L,], y = "retplasma")
predmlr$predict(x_interest)
#> .prediction
#> 1 342.9231
nicemlr = NICERegr$new(predmlr, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)
nicemlr$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
#> 1 Counterfactual(s)
#>
#> Desired outcome range: [500, Inf]
#>
#> Head:
#> age sex smokstat bmi vituse calories fat fiber alcohol cholesterol betadiet retdiet betaplasma
#> <int> <fctr> <fctr> <num> <fctr> <num> <num> <num> <num> <num> <int> <int> <int>
#> 1: 46 1 3 35.25969 3 2667.5 131.6 10.1 0 550.5 1210 1291 218
library("rpart")
treerpart = rpart(retplasma ~ ., data = plasma[-100L,])
predrpart = Predictor$new(model = treerpart, data = plasma[-100L,], y = "retplasma")
predrpart$predict(x_interest)
#> pred
#> 1 342.9231
nicerpart = NICERegr$new(predrpart, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)
nicerpart$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
#> 1 Counterfactual(s)
#>
#> Desired outcome range: [500, Inf]
#>
#> Head:
#> age sex smokstat bmi vituse calories fat fiber alcohol cholesterol betadiet retdiet betaplasma
#> <int> <fctr> <fctr> <num> <fctr> <num> <num> <num> <num> <num> <int> <int> <int>
#> 1: 46 1 3 35.25969 3 2667.5 131.6 10.1 0 550.5 1210 1291 218
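All four models report the same prediction (342.9231) and the same counterfactual. This is consistent with every model being an rpart regression tree with complexity parameter cp = 0.01. caret sets this value explicitly through tuneGrid, and 0.01 is also rpart's own default. The sketch below checks only the rpart part of that claim, using the built-in mtcars data rather than plasma. It assumes, without verifying it here, that the mlr and tidymodels learners pass rpart's defaults through unchanged.

```r
library("rpart")  # ships with R as a recommended package

# Default fit: rpart.control() sets cp = 0.01
fit_default = rpart(mpg ~ ., data = mtcars)
# Same formula with cp = 0.01 set explicitly, as in the caret tuneGrid above
fit_cp = rpart(mpg ~ ., data = mtcars, control = rpart.control(cp = 0.01))

fit_default$control$cp
#> [1] 0.01
# Cross-validation (xval) only affects the cp table, not the grown tree
all.equal(predict(fit_default, mtcars), predict(fit_cp, mtcars))
#> [1] TRUE
```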