炼数成金 门户 商业智能 机器学习 查看内容

原创翻译 | 用LIME来解释复杂的机器学习模型

2018-5-8 10:40| 发布者: 炼数成金_小数| 查看: 12376| 评论: 0

摘要: 机器学习模型做出的分类决策通常很难--如果不是不可能--我们的大脑就很难理解。一些最精确的分类器(如神经网络)的复杂性使它们表现得如此出色—往往比人类获得更好的结果。但这也使得它们本身很难解释,特别是对非 ...

网络 Python 模型 机器学习 神经网络

机器学习模型做出的分类决策通常很难--如果不是不可能--我们的大脑就很难理解。一些最较精确的分类器(如神经网络)的复杂性使它们表现得如此出色—往往比人类获得更好的结果。但这也使得它们本身很难解释,特别是对非数据科学家来说。


特别是,如果我们的目标是开发医疗诊断的机器学习模型,高精度的测试样本可能不足以出售给临床医生。医生和病人都不太相信他们不理解的模型所做的决定。


因此,我们希望能够具体地解释,为何一个模式会将一个有标签的个案分类,例如为何一个乳房肿块样本被归类为“恶性”而非“良性”。

局部可解释模型-不可知论解释(LIME)是一种尝试,使这些复杂的模型至少部分可以理解。该方法已发布

“我为什么要相信你?”解释任何分类器的预测。来自西雅图华盛顿大学的MarcoTulio Ribeiro,Sameer Singh和Carlos Guestrin。

LIME能够解释所有我们可以获得预测概率的模型(在R中,也就是每一个与预测(type=“prob”)一起工作的模型)。它利用了这样一个事实,即线性模型很容易解释,因为它们基于特征和类标签之间的线性关系:将复模型函数用局部拟合线性模型逼近原训练集的排列。


在每一个排列上,一个线性模型被拟合,并给出权重,从而惩罚那些与原始数据更相似的实例的错误分类(正权值支持决策,负权值与决策值相矛盾)。这将给出每个特性对模型所做决定的贡献程度(以及以何种方式)。

lime的代码最初是为Python提供的,但令人敬畏的ThomasLin Pedersen已经在R中创建了一个实现。它不是在CRAN上(我猜想),但您可以通过Github安装它:

devtools::install_github("thomasp85/lime")

我使用的数据是我上一篇文章中的世界幸福指数数据。因此,让我们在这个数据集上来训练一个神经网络以预测三种快乐程度得分:低、中、高。

load("data_15_16.RData") 
# configure multicore library(doParallel) cl <- makeCluster(detectCores()) registerDoParallel(cl) library(caret) 
set.seed(42) index <- createDataPartition(data_15_16$Happiness.Score.l, p = 0.7, list = FALSE) train_data <- data_15_16[index, ] test_data <- data_15_16[-index, ] 
set.seed(42) model_mlp <- caret::train(Happiness.Score.l ~ ., data = train_data, method = "mlp", trControl = trainControl(method = "repeatedcv", number = 10, repeats = 5, verboseIter = FALSE))

解释功能

lime的核心功能是lime()函数。它创建了下一步用来解释模型预测能力的函数。

我们可以给定几个选项,并查阅?lime帮助功能来寻求各种使用细节,但最重要的是考虑:

连续性的特征变量是否应该被剔除?如果是的话,剔除多少个?

在这里,我保持默认的bin_continuous值为TRUE但指定的值5而不是4(4为默认值)同时n_bins值为 5。

library(lime) explain <- lime(train_data, model_mlp, bin_continuous = TRUE, n_bins = 5, n_permutations = 1000) 


现在,让我们来看看如何解释模型。在此,我不打算浏览所有的测试案例,而是随机选择了三个正确和三个错误的预测案例。

pred <- data.frame(sample_id = 1:nrow(test_data), predict(model_mlp, test_data, type = "prob"), actual = test_data$Happiness.Score.l) pred$prediction <- colnames(pred)[3:5][apply(pred[, 3:5], 1, which.max)] pred$correct <- ifelse(pred$actual == pred$prediction, "correct", "wrong") 

Beware that we need to give our test-set data table row names with the sample names or IDs to be displayed in the header of our explanatory plots below.

请注意,我们需要给我们的测试集数据表添加以样例名或ID而起的行名,它们要在下面的示意图的标题中显示。

library(tidyverse) pred_cor <- filter(pred, correct == "correct") pred_wrong <- filter(pred, correct == "wrong") test_data_cor <- test_data %>% mutate(sample_id = 1:nrow(test_data)) %>% filter(sample_id %in% pred_cor$sample_id) %>% sample_n(size = 3) %>% remove_rownames() %>% tibble::column_to_rownames(var = "sample_id") %>% select(-Happiness.Score.l) test_data_wrong <- test_data %>% mutate(sample_id = 1:nrow(test_data)) %>% filter(sample_id %in% pred_wrong$sample_id) %>% sample_n(size = 3) %>% remove_rownames() %>% tibble::column_to_rownames(var = "sample_id") %>% select(-Happiness.Score.l)

上面的解释函数现在可以与我们的测试样本一起使用。我们可以指定的其他选项包括:


我们想在解释函数中使用多少功能?

假设我们有一个100个功能的大型训练集。查看所有的特性并试图理解它们,可能会更令人困惑,而不是提供帮助。而且很多时候,一些非常重要的特性将足以以合理的精度预测测试样本(也请参阅我最后一篇关于OneR的文章)。因此,我们可以选择要使用“N_Feature”选项查看的特征。


我们要如何选择这些功能?

接下来,我们指定了如何找到这个特性子集。如果我们选择n_FERES<=6,而使用权重较高的特性,则默认的auto使用前向选择。我们可以直接选择F feature_select = "forward_selection", feature_select = "highest_weights或feature_select = "lasso_path". 再检查一下?关于LIME的细节。

在我们的示例数据集中,我们只有7个特性并且我想看看前5个。

我还想对所有三类标签的响应变量进行解释(低、中、高幸福值),所以我选择令n_labels = 3。

explanation_cor <- explain(test_data_cor, n_labels = 3, n_features = 5) explanation_wrong <- explain(test_data_wrong, n_labels = 3, n_features = 5) 

它将返回一个整洁的tibble对象,我们可以对他用plot_features()函数进行绘图:

plot_features(explanation_cor, ncol = 3) 

plot_features(explanation_wrong, ncol = 3) 

输出的Tibble图中的信息可用帮助函数?lime和以下代码查看


tibble::glimpse(explanation_cor)

那么,这能告诉我们什么呢?让我们看一下案例22(我们的图正确预测的类的第一行):这个示例已经被正确地预测来自中等幸福组,因为它

2.03和2.32之间的一个异常值 信任/政府腐败分数低于0.05, GDP/经济分数在1.06到1.23之间并且 预期寿命在0.59至0.7岁之间.

从对“高”标签的解释中,我们也可以发现这种情况下家庭得分大于1.12,这更代表了高幸福度的样本情况。

pred %>% filter(sample_id == 22) 
##   sample_id        low   medium       high actual prediction correct
## 1        22 0.02906327 0.847562 0.07429938 medium     medium correct


名为“反乌托邦”的解释函数是此预测最有力的支持特性。Dytopia是一个想象中的国家,它拥有世界上最不快乐的人。建立盲眼的目的是要有一个基准,所有国家都可以在这六个关键变量中的每一个方面进行有利的比较(没有哪个国家的表现比Dytopia差)。 […]

解释图告诉我们每个特性和类标签的值范围,一个代表性的数据点将下降。如果是的话,这将被视为对此预测的支持,如果它没有,则它会被评为矛盾。对于案例22和反乌托邦的特征,数据点2.27属于中等幸福范围(2.03到2.32之间),具有较高的权重。


当我们观察的这个案例是在这一特征的取值范围内时,我们可以看到确实是非常接近中等样本的中位数和远离高和低的样本的中位数。其他支撑特征也显示出同样的趋势。

train_data %>% gather(x, y, Economy..GDP.per.Capita.:Dystopia.Residual) %>% ggplot(aes(x = Happiness.Score.l, y = y))   geom_boxplot(alpha = 0.8, color = "grey")   geom_point(data = gather(test_data[22, ], x, y, Economy..GDP.per.Capita.:Dystopia.Residual), color = "red", size = 3)   facet_wrap(~ x, scales = "free", ncol = 4) 

对案例22的前5个解释特性的概述存储在:

as.data.frame(explanation_cor[1:9]) %>% filter(case == "22") 
##    case  label label_prob   model_r2 model_intercept
## 1    22 medium 0.84756196 0.05004205       0.5033729
## 2    22 medium 0.84756196 0.05004205       0.5033729
## 3    22 medium 0.84756196 0.05004205       0.5033729
## 4    22 medium 0.84756196 0.05004205       0.5033729
## 5    22 medium 0.84756196 0.05004205       0.5033729
## 6    22   high 0.07429938 0.06265119       0.2293890
## 7    22   high 0.07429938 0.06265119       0.2293890
## 8    22   high 0.07429938 0.06265119       0.2293890
## 9    22   high 0.07429938 0.06265119       0.2293890
## 10   22   high 0.07429938 0.06265119       0.2293890
## 11   22    low 0.02906327 0.07469729       0.3528088
## 12   22    low 0.02906327 0.07469729       0.3528088
## 13   22    low 0.02906327 0.07469729       0.3528088
## 14   22    low 0.02906327 0.07469729       0.3528088
## 15   22    low 0.02906327 0.07469729       0.3528088
##                          feature feature_value feature_weight
## 1              Dystopia.Residual       2.27394     0.14690100
## 2  Trust..Government.Corruption.       0.03005     0.06308598
## 3       Economy..GDP.per.Capita.       1.13764     0.02944832
## 4       Health..Life.Expectancy.       0.66926     0.02477567
## 5                     Generosity       0.00199    -0.01326503
## 6                         Family       1.23617     0.13629781
## 7                     Generosity       0.00199    -0.07514534
## 8  Trust..Government.Corruption.       0.03005    -0.07574480
## 9              Dystopia.Residual       2.27394    -0.07687559
## 10      Economy..GDP.per.Capita.       1.13764     0.07167086
## 11                        Family       1.23617    -0.14932931
## 12      Economy..GDP.per.Capita.       1.13764    -0.12738346
## 13                    Generosity       0.00199     0.09730858
## 14             Dystopia.Residual       2.27394    -0.07464384
## 15 Trust..Government.Corruption.       0.03005     0.06220305
##                                       feature_desc
## 1         2.025072 < Dystopia.Residual <= 2.320632
## 2        Trust..Government.Corruption. <= 0.051198
## 3  1.064792 < Economy..GDP.per.Capita. <= 1.275004
## 4  0.591822 < Health..Life.Expectancy. <= 0.701046
## 5                           Generosity <= 0.123528
## 6                                1.119156 < Family
## 7                           Generosity <= 0.123528
## 8        Trust..Government.Corruption. <= 0.051198
## 9         2.025072 < Dystopia.Residual <= 2.320632
## 10 1.064792 < Economy..GDP.per.Capita. <= 1.275004
## 11                               1.119156 < Family
## 12 1.064792 < Economy..GDP.per.Capita. <= 1.275004
## 13                          Generosity <= 0.123528
## 14        2.025072 < Dystopia.Residual <= 2.320632
## 15       Trust..Government.Corruption. <= 0.051198

同样,我们可以探究为什么有些预测是错误的。


如果你对更多的机器学习的内容感兴趣,看看我的博客上machine_learning类别列表下面的内容。

sessionInfo() 
## R version 3.3.3 (2017-03-06)
## Platform: x86_64-apple-darwin13.4.0 (64-bit)
## Running under: macOS Sierra 10.12.3
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] parallel  stats     graphics  grDevices utils     datasets  methods  
## [8] base     
## 
## other attached packages:
##  [1] dplyr_0.5.0       purrr_0.2.2       readr_1.1.0      
##  [4] tidyr_0.6.1       tibble_1.3.0      tidyverse_1.1.1  
##  [7] RSNNS_0.4-9       Rcpp_0.12.10      lime_0.1.0       
## [10] caret_6.0-73      ggplot2_2.2.1     lattice_0.20-35  
## [13] doParallel_1.0.10 iterators_1.0.8   foreach_1.4.3    
## 
## loaded via a namespace (and not attached):
##  [1] lubridate_1.6.0    assertthat_0.2.0   glmnet_2.0-5      
##  [4] rprojroot_1.2      digest_0.6.12      psych_1.7.3.21    
##  [7] R6_2.2.0           plyr_1.8.4         backports_1.0.5   
## [10] MatrixModels_0.4-1 stats4_3.3.3       evaluate_0.10     
## [13] httr_1.2.1         hrbrthemes_0.1.0   lazyeval_0.2.0    
## [16] readxl_0.1.1       minqa_1.2.4        SparseM_1.76      
## [19] extrafontdb_1.0    car_2.1-4          nloptr_1.0.4      
## [22] Matrix_1.2-8       rmarkdown_1.4      labeling_0.3      
## [25] splines_3.3.3      lme4_1.1-12        extrafont_0.17    
## [28] stringr_1.2.0      foreign_0.8-67     munsell_0.4.3     
## [31] hunspell_2.3       broom_0.4.2        modelr_0.1.0      
## [34] mnormt_1.5-5       mgcv_1.8-17        htmltools_0.3.5   
## [37] nnet_7.3-12        codetools_0.2-15   MASS_7.3-45       
## [40] ModelMetrics_1.1.0 grid_3.3.3         nlme_3.1-131      
## [43] jsonlite_1.4       Rttf2pt1_1.3.4     gtable_0.2.0      
## [46] DBI_0.6-1          magrittr_1.5       scales_0.4.1      
## [49] stringi_1.1.5      reshape2_1.4.2     xml2_1.1.1        
## [52] tools_3.3.3        forcats_0.2.0      hms_0.3           
## [55] pbkrtest_0.4-7     yaml_2.1.14        colorspace_1.3-2  
## [58] rvest_0.3.2        knitr_1.15.1       haven_1.0.0       
## [61] quantreg_5.29


英文原文:https://www.r-bloggers.com/explaining-complex-machine-learning-models-with-lime/


欢迎加入本站公开兴趣群

商业智能与数据分析群

兴趣范围包括各种让数据产生价值的办法,实际应用案例分享与讨论,分析工具,ETL工具,数据仓库,数据挖掘工具,报表系统等全方位知识

QQ群:81035754


鲜花

握手

雷人

路过

鸡蛋

相关阅读

最新评论

热门频道

  • 大数据
  • 商业智能
  • 量化投资
  • 科学探索
  • 创业

即将开课

 

GMT+8, 2018-5-27 23:17 , Processed in 0.164457 second(s), 26 queries .