星期日, 十月 26, 2014

python的决策树和随机森林

Python的决策树和随机森林

决策树模型是一种简单易用的非参数分类器。它不需要对数据有任何的先验假设,计算速度较快,结果容易解释,而且稳健性强,对噪声数据和缺失数据不敏感。下面示范用titanic中的数据集为做决策树分类,目标变量为survive。

第一步:读取数据

In [2]:
%pylab inline
import pandas as pd
df = pd.read_csv('titanic.csv')
df.head()
#df.info()
Populating the interactive namespace from numpy and matplotlib

Out[2]:
survived pclass name sex age sibsp parch ticket fare cabin embarked
0 0 3 Braund, Mr. Owen Harris male 22 1 0 A/5 21171 7.2500 NaN S
1 1 1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38 1 0 PC 17599 71.2833 C85 C
2 1 3 Heikkinen, Miss. Laina female 26 0 0 STON/O2. 3101282 7.9250 NaN S
3 1 1 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35 1 0 113803 53.1000 C123 S
4 0 3 Allen, Mr. William Henry male 35 0 0 373450 8.0500 NaN S

第二步:数据整理

  • 只取出三个自变量
  • 将将age缺失值进行补全
  • 将pclass变量转为三个哑变量
  • 将sex转为0-1变量
In [3]:
subdf = df[['pclass','sex','age']]
y = df.survived
# sklearn中的Imputer也可以
age = subdf['age'].fillna(value=subdf.age.mean())
# sklearn OneHotEncoder也可以
pclass = pd.get_dummies(subdf['pclass'],prefix='pclass')
sex = (subdf['sex']=='male').astype('int')
X = pd.concat([pclass,age,sex],axis=1)
X.head()
Out[3]:
pclass_1 pclass_2 pclass_3 age sex
0 0 0 1 22 1
1 1 0 0 38 0
2 0 0 1 26 0
3 1 0 0 35 0
4 0 0 1 35 1

第三步:建模

  • 数据切分为train和test
In [4]:
from sklearn.cross_validation import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=33)
  • 使用决策树观察在检验集表现
In [6]:
from sklearn import tree
clf = tree.DecisionTreeClassifier(criterion='entropy', max_depth=3,min_samples_leaf=5)
clf = clf.fit(X_train,y_train)
print "准确率为:{:.2f}".format(clf.score(X_test,y_test))
准确率为:0.83

  • 观察各变量的重要性
In [7]:
clf.feature_importances_
Out[7]:
array([ 0.08398076,  0.        ,  0.23320717,  0.10534824,  0.57746383])
  • 使用更多指标来评估模型
In [17]:
from sklearn import metrics
def measure_performance(X,y,clf, show_accuracy=True, 
                        show_classification_report=True, 
                        show_confusion_matrix=True):
    y_pred=clf.predict(X)   
    if show_accuracy:
        print "Accuracy:{0:.3f}".format(metrics.accuracy_score(y,y_pred)),"\n"

    if show_classification_report:
        print "Classification report"
        print metrics.classification_report(y,y_pred),"\n"
        
    if show_confusion_matrix:
        print "Confusion matrix"
        print metrics.confusion_matrix(y,y_pred),"\n"
        
measure_performance(X_test,y_test,clf, show_classification_report=True, show_confusion_matrix=True)
Accuracy:0.834 

Classification report
             precision    recall  f1-score   support

          0       0.85      0.88      0.86       134
          1       0.81      0.76      0.79        89

avg / total       0.83      0.83      0.83       223


Confusion matrix
[[118  16]
 [ 21  68]] 


  • 使用交叉验证来评估模型
In [8]:
from sklearn import cross_validation
scores1 = cross_validation.cross_val_score(clf, X, y, cv=10)
scores1
Out[8]:
array([ 0.82222222,  0.82222222,  0.7752809 ,  0.87640449,  0.82022472,
        0.76404494,  0.7752809 ,  0.76404494,  0.83146067,  0.78409091])

第三步:决策树画图

  • 需要安装GraphViz'
In [9]:
import pydot,StringIO
dot_data = StringIO.StringIO() 
In [14]:
tree.export_graphviz(clf, out_file=dot_data, feature_names=['age','sex','1st_class','2nd_class','3rd_class']) 
dot_data.getvalue()
pydot.graph_from_dot_data(dot_data.getvalue())
graph = pydot.graph_from_dot_data(dot_data.getvalue()) 
#graph.write_png('titanic.png') 
#from IPython.core.display import Image 
#Image(filename='titanic.png')

第四步:使用随机森林进行比较

In [11]:
from sklearn.ensemble import RandomForestClassifier
clf2 = RandomForestClassifier(n_estimators=1000,random_state=33)
clf2 = clf2.fit(X_train,y_train)
scores2 = cross_validation.cross_val_score(clf2,X, y, cv=10)
clf2.feature_importances_
Out[11]:
array([ 0.05526809,  0.02266161,  0.08156048,  0.46552672,  0.37498309])
In [12]:
scores2.mean(), scores1.mean()
Out[12]:
(0.81262938372488946, 0.80352769265690616)

4 条评论:

  1. titanic.csv 你的数据在哪?能分享吗?

    回复删除
  2. 请问可以把ipynb的文件分享到nbviewer.ipython.org吗?上传到github或者贴到gist即可。

    这样可以直接查看生成的图片,代码也更整洁。gist 的代码也可以嵌入到文章。

    回复删除
    回复
    1. 谢谢你的建议,老的ipynb没有保存,新的博客已经这样处理了

      删除