parttree: Decision Tree Visualization [R]

Source: https://github.com/grantmcdermott/parttree

A set of simple functions for visualizing decision tree partitions in R with ggplot2.

Installation

This package is not yet on CRAN, but can be installed from GitHub with:

# install.packages("remotes")
remotes::install_github("grantmcdermott/parttree")

Example

The main function that users will interact with is geom_parttree(). Here’s a simple example.

library(parttree)
library(rpart)
library(ggplot2)

iris_tree = rpart(Species ~ Sepal.Width + Petal.Width, data=iris)

## Let's construct a scatterplot of the original iris data
p = ggplot(data = iris, aes(x=Petal.Width, y=Sepal.Width)) +
  geom_point(aes(col=Species))

## We now add the partitions with geom_parttree()
p +  
  geom_parttree(data = iris_tree, aes(fill=Species), alpha = 0.1) +
  labs(caption = "Note: Points denote observed data. Shaded regions denote tree predictions.")


Limitations and caveats

Supported model classes

Currently, the package only works with decision trees created by the rpart package. However, it does support other packages and modes that call rpart::rpart() as the underlying engine. Here’s an example using the parsnip package.

library(parsnip)
library(titanic) ## Just for a different data set
set.seed(123) ## For consistent jitter

titanic_train$Survived = as.factor(titanic_train$Survived)

## Build our tree using parsnip (but with rpart as the model engine)
ti_tree =
  decision_tree() %>%
  set_engine("rpart") %>%
  set_mode("classification") %>%
  fit(Survived ~ Pclass + Age, data = titanic_train)

## Plot the data and model partitions
titanic_train %>%
  ggplot(aes(x=Pclass, y=Age)) +
  geom_jitter(aes(col=Survived), alpha=0.7) +
  geom_parttree(data = ti_tree, aes(fill=Survived), alpha = 0.1) +
  theme_minimal()
#> Warning: Removed 177 rows containing missing values (geom_point).


Orientation

Under the hood, geom_parttree() calls the companion parttree() function, which coerces the rpart tree object into a data frame that ggplot2 can easily work with. For example, consider our “ti_tree” model from above. Here’s the print output of the raw model.

ti_tree
#> parsnip model object
#> 
#> Fit time:  6ms 
#> n= 891 
#> 
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#> 
#>   1) root 891 342 0 (0.61616162 0.38383838)  
#>     2) Pclass>=2.5 491 119 0 (0.75763747 0.24236253)  
#>       4) Age>=6.5 461 102 0 (0.77874187 0.22125813) *
#>       5) Age< 6.5 30  13 1 (0.43333333 0.56666667) *
#>     3) Pclass< 2.5 400 177 1 (0.44250000 0.55750000)  
#>       6) Age>=17.5 365 174 1 (0.47671233 0.52328767)  
#>        12) Pclass>=1.5 161  66 0 (0.59006211 0.40993789) *
#>        13) Pclass< 1.5 204  79 1 (0.38725490 0.61274510)  
#>          26) Age>=44.5 67  32 0 (0.52238806 0.47761194)  
#>            52) Age>=60.5 14   3 0 (0.78571429 0.21428571) *
#>            53) Age< 60.5 53  24 1 (0.45283019 0.54716981)  
#>             106) Age< 47.5 13   3 0 (0.76923077 0.23076923) *
#>             107) Age>=47.5 40  14 1 (0.35000000 0.65000000) *
#>          27) Age< 44.5 137  44 1 (0.32116788 0.67883212) *
#>       7) Age< 17.5 35   3 1 (0.08571429 0.91428571) *

And here’s what we get after we feed it to parttree().

parttree(ti_tree)
#>   node Survived
#> 1    4        0
#> 2    5        1
#> 3    7        1
#> 4   12        0
#> 5   27        1
#> 6   52        0
#> 7  106        0
#> 8  107        1
#>                                                                                              path
#> 1                                                                    Pclass >= 2.5 --> Age >= 6.5
#> 2                                                                    Pclass >= 2.5 --> Age <  6.5
#> 3                                                                   Pclass <  2.5 --> Age <  17.5
#> 4                                                 Pclass <  2.5 --> Age >= 17.5 --> Pclass >= 1.5
#> 5                                 Pclass <  2.5 --> Age >= 17.5 --> Pclass <  1.5 --> Age <  44.5
#> 6                 Pclass <  2.5 --> Age >= 17.5 --> Pclass <  1.5 --> Age >= 44.5 --> Age >= 60.5
#> 7 Pclass <  2.5 --> Age >= 17.5 --> Pclass <  1.5 --> Age >= 44.5 --> Age <  60.5 --> Age <  47.5
#> 8 Pclass <  2.5 --> Age >= 17.5 --> Pclass <  1.5 --> Age >= 44.5 --> Age <  60.5 --> Age >= 47.5
#>   xmin xmax ymin ymax
#> 1  2.5  Inf  6.5  Inf
#> 2  2.5  Inf -Inf  6.5
#> 3 -Inf  2.5 -Inf 17.5
#> 4  1.5  2.5 17.5  Inf
#> 5 -Inf  1.5 17.5 44.5
#> 6 -Inf  1.5 60.5  Inf
#> 7 -Inf  1.5 44.5 47.5
#> 8 -Inf  1.5 47.5 60.5
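To make the link between the path column and the rectangle bounds concrete, here is a minimal base R sketch. It is not how parttree() is implemented; path_to_rect() is a hypothetical helper written only for illustration, and it assumes paths formatted like the ones printed above, with Pclass on the x-axis and Age on the y-axis.

```r
## Hypothetical helper (not part of parttree): turn one leaf's split path into
## rectangle bounds, with `xvar` on the x-axis and `yvar` on the y-axis.
path_to_rect = function(path, xvar, yvar) {
  ## Start with an unbounded rectangle and tighten it one split at a time
  bounds = c(xmin = -Inf, xmax = Inf, ymin = -Inf, ymax = Inf)
  for (cond in strsplit(path, " --> ", fixed = TRUE)[[1]]) {
    ## Each condition looks like "Age >= 17.5" or "Pclass <  2.5"
    m = regmatches(cond, regexec("^([^ <>=]+) *(>=|<) *([^ ]+)$", cond))[[1]]
    variable = m[2]
    op = m[3]
    value = as.numeric(m[4])
    axis = if (variable == xvar) {
      "x"
    } else if (variable == yvar) {
      "y"
    } else {
      stop("unknown variable: ", variable)
    }
    if (op == ">=") {
      ## A ">=" split raises the lower bound
      key = paste0(axis, "min")
      bounds[key] = max(bounds[key], value)
    } else {
      ## A "<" split lowers the upper bound
      key = paste0(axis, "max")
      bounds[key] = min(bounds[key], value)
    }
  }
  bounds
}

## Leaf node 12 (row 4 above): xmin 1.5, xmax 2.5, ymin 17.5, ymax Inf
rect_12 = path_to_rect("Pclass <  2.5 --> Age >= 17.5 --> Pclass >= 1.5",
                       xvar = "Pclass", yvar = "Age")
```

Each split only ever tightens one side of the rectangle, which is why leaves reached by few splits keep infinite bounds on most sides.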

Again, the resulting data frame is designed to be amenable to a ggplot2 geom layer, with columns like xmin, xmax, etc. specifying aesthetics that ggplot2 recognises. (Fun fact: geom_parttree() is really just a thin wrapper around geom_rect().) The goal of the package is to abstract away these kinds of details from the user, so we can just specify geom_parttree(), with a valid tree object as the data input, and be done with it. However, while this generally works well, it can sometimes lead to unexpected behaviour in terms of plot orientation. That’s because it’s hard to guess ahead of time what the user will specify as the x and y axes/variables in their other plot layers. To see what I mean, let’s redo our titanic plot from earlier, but this time switch the axes in the main ggplot() call.
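As a rough illustration of the "thin wrapper" point, the partitions can be drawn with plain geom_rect(). This is only a sketch: the data frame below is typed in by hand from the parttree(ti_tree) output shown earlier (the name parts is made up here), and the layer does not reproduce whatever extra handling geom_parttree() performs internally.

```r
library(ggplot2)

## Leaf partitions copied by hand from the parttree(ti_tree) output
parts = data.frame(
  node     = c(4, 5, 7, 12, 27, 52, 106, 107),
  Survived = factor(c(0, 1, 1, 0, 1, 0, 0, 1)),
  xmin     = c(2.5, 2.5, -Inf, 1.5, -Inf, -Inf, -Inf, -Inf),
  xmax     = c(Inf, Inf, 2.5, 2.5, 1.5, 1.5, 1.5, 1.5),
  ymin     = c(6.5, -Inf, -Inf, 17.5, 17.5, 60.5, 44.5, 47.5),
  ymax     = c(Inf, 6.5, 17.5, Inf, 44.5, Inf, 47.5, 60.5)
)

## x is Pclass and y is Age, matching the orientation of the earlier plot.
## Infinite bounds simply extend the rectangle to the edge of the panel.
p_rect = ggplot(parts) +
  geom_rect(
    aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax, fill = Survived),
    alpha = 0.1, colour = "black"
  ) +
  labs(x = "Pclass", y = "Age")

p_rect
```

Note that the xmin/xmax columns are tied to Pclass and the ymin/ymax columns to Age, so the rectangles only line up with other layers that put those variables on the same axes.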

titanic_train %>%
  ggplot(aes(x=Age, y=Pclass)) + ## Changed!
  geom_jitter(aes(col=Survived), alpha=0.7) +
  geom_parttree(data = ti_tree, aes(fill=Survived), alpha = 0.1) +
  theme_minimal() +
  labs(title = "Oops!")
#> Warning: Removed 177 rows containing missing values (geom_point).


Normally, this kind of orientation mismatch should be pretty easy to recognize (as is the case here). But it’s admittedly annoying. I’ll try to add better support for catching/avoiding these kinds of errors in a future update, but for the moment: caveat emptor.
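Until then, one possible manual workaround is to swap the x and y bounds of the partition data frame yourself and draw it with geom_rect(). The sketch below is an assumption-laden illustration rather than part of the package: flip_partitions() is a hypothetical helper, and the single-row data frame stands in for the full parttree() output.

```r
## Hypothetical helper (not part of parttree): swap the x and y bounds of a
## partition data frame so it matches a plot whose axes have been switched.
flip_partitions = function(parts) {
  flipped = parts
  flipped$xmin = parts$ymin
  flipped$xmax = parts$ymax
  flipped$ymin = parts$xmin
  flipped$ymax = parts$xmax
  flipped
}

## One leaf from the output above: Pclass in [1.5, 2.5), Age >= 17.5
leaf = data.frame(node = 12, xmin = 1.5, xmax = 2.5, ymin = 17.5, ymax = Inf)

## After flipping, x holds the Age bounds and y holds the Pclass bounds
leaf_flipped = flip_partitions(leaf)

## With parttree and ggplot2 loaded, the flipped partitions could then be
## added to the Age-vs-Pclass plot with something like:
##   geom_rect(data = flip_partitions(parttree(ti_tree)),
##             aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax,
##                 fill = Survived),
##             alpha = 0.1, inherit.aes = FALSE)
```

Later releases of the package appear to offer a flipaxes argument for this purpose; treat that as an assumption and check the documentation of your installed version before relying on it.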

Original article by xsmile. If you repost it, please credit the source: http://www.17bigdata.com/parttree-%e5%86%b3%e7%ad%96%e6%a0%91%e5%8f%af%e8%a7%86%e5%8c%96%e3%80%90r%e3%80%91/
