星期五, 七月 06, 2012

谈一谈支持向量机分类器


支持向量机(Support Vector Machine)名字听起来很炫,功能也很炫,但公式理解起来常有眩晕感。所以本文尝试不用一个公式来说明SVM的原理,以保证不吓跑一个读者。理解SVM有四个关键名词:分离超平面、最大边缘超平面、软边缘、核函数

  • 分离超平面separating hyperplane:处理分类问题的时候需要一个决策边界,好象楚河汉界一样,在界这边我们判别A,在界那边我们判别B。这种决策边界将两类事物相分离,而线性的决策边界就是分离超平面。
  • 最大边缘超平面(Maximal Margin Hyperplane):分离超平面可以有很多个,怎么找最好的那个呢,SVM的作法是找一个“最中间”的。换句话说,就是这个平面要尽量和两边保持距离,以留足余量,减小泛化误差,保证稳健性。或者用中国人的话讲叫做“执中”。以江河为国界的时候,就是以航道中心线为界,这个就是最大边缘超平面的体现。在数学上找到这个最大边缘超平面的方法是一个二次规划问题。
  • 软边缘(Soft Margin):但世界上没这么美的事,很多情况下都是“你中有我,我中有你”的混杂状态。不大可能用一个平面完美的分离两个类别。在线性不可分情况下就要考虑软边缘了。软边缘可以破例允许个别样本跑到其它类别的地盘上去。但要使用参数来权衡两端,一个是要保持最大边缘的分离,另一个要使这种破例不能太离谱。这种参数就是对错误分类的惩罚程度C。
  • 核函数(Kernel Function),为了解决完美分离的问题,SVM还提出一种思路,就是将原始数据映射到高维空间中去,直觉上可以感觉高维空间中的数据变的稀疏,有利于“分清敌我”。那么映射的方法就是使用“核函数”。如果这种“核技术”选择得当,高维空间中的数据就变得容易线性分离了。而且可以证明,总是存在一种核函数能将数据集映射成可分离的高维数据。看到这里各位不要过于兴奋,映射到高维空间中并非是有百利而无一害的。维数过高的害处就是会出现过度拟合。
所以选择合适的核函数以及软边缘参数C就是训练SVM的重要因素。一般来讲,核函数越复杂,模型越偏向于拟合过度。在参数C方面,它可以看作是LASSO算法中的lambda的倒数,C越大模型越偏向于拟合过度,反之则拟合不足。实际问题中怎么选呢?用人类最古老的办法,试错。


常用的核函数有如下种类:
  • Linear:使用它的话就成为线性向量机,效果基本等价于Logistic回归。但它可以处理变量极多的情况,例如文本挖掘。
  • polynomial:多项式核函数,适用于图像处理问题。
  • Radial basis,高斯核函数,最流行易用的选择。参数包括了sigma,其值若设置过小,会有过度拟合出现。
  • sigmoid:反曲核函数,多用于神经网络的激活函数。
好吧,理论说了一大堆,关键得在R里面出手。R语言中可以用e1071包中的svm函数建模,而另一个kernlab包中则包括了更多的核方法函数,本例用其中的ksvm函数,来说明参数C的作用和核函数的选择。我们先人为构造一个线性不可分的数据,先用线性核函数来建模,其参数C取值为1。然后我们用图形来观察建模结果,下图是根据线性SVM得到各样本的判别值等高线图(判别值decision value相当于Logistic回归中的X,X取0时为决策边界)。可以清楚的看到决策边界为线性,中间的决策边缘显示为白色区域,有相当多的样本落入此区域。
下面为了更好的拟合,我们加大了C的取值,这样如下图所示。可以预料到,当加大惩罚参数后决策边缘缩窄,也使训练误差减少,但仍有个别样本未被正确的分类。
最后我们换用高斯核函数,这样得到的非线性决策边界。所有的样本都得到了正确的分类。
在实际运用中,为了寻找最优参数我们还可以用caret包来配合建模,并且如同前文那样使用多重交叉检验来评价模型。还需要注意一点SVM建模最好先标准化处理。最后来总结一下SVM的优势:
  • 可用于分类、回归和异常检验
  • 可以发现全局最优解
  • 可以用参数来控制过度拟合问题

代码如下:
# 构造数据
x1 <- seq(0,pi,length.out=100)
y1 <- sin(x1) + 0.1*rnorm(100)
x2 <- 1.5+ seq(0,pi,length.out=100)
y2 <- cos(x2) + 0.1*rnorm(100)
data <- data.frame(c(x1,x2),c(y1,y2),c(rep(1, 100), rep(-1, 100)))
names(data) <- c('x1','x2','y')
data$y <- factor(data$y)
 
# 使用线性核函数,不能很好的划分数据
model1 <- ksvm(y~.,data=data,kernel='vanilladot',C=0.1)
plot(model1,data=data)
# 加大惩罚参数,决策边缘缩窄,使训练误差减小
model2 <- ksvm(y~.,data=data,kernel='vanilladot',C=100)
plot(model2,data=data)
# 使用高斯核函数,正确的分类
model3 <- ksvm(y~.,data=data,kernel='rbfdot')
plot(model3,data=data)
 
# 10折交叉检验训练iris数据,选择最优参数C为0.5
fitControl <- trainControl(method = "repeatedcv", number = 10, repeats = 3,returnResamp = "all")
model <- train(Species~., data=iris,method='svmRadialCost',trControl = fitControl)

参考资料:
http://www.autonlab.org/tutorials/svm15.pdf
http://www.jstatsoft.org/v15/i09/paper
http://www.broadinstitute.org/annotation/winter_course_2006/index_files/Noble%202006%20SVM%20tutorial%20Nat%20Biotech.pdf
http://cran.r-project.org/web/packages/kernlab/vignettes/kernlab.pdf

21 条评论:

  1. 看了楼主的很多文章,受益匪浅!谢谢楼主。不过请问您有没有对mcmc方法的一些学习和认识,能否写一篇文章呢?

    回复删除
  2. fitControl <- trainControl(method = "repeatedcv", number = 10, repeats = 3,returnResamp = "all")
    model <- train(Species~., data=iris,method='svmRadialCost',trControl = fitControl)

    报错,请问trainControl和train都是kernlab的函数么?

    回复删除
    回复
    1. 不好意思,漏掉了一行,还要加载一个包,没有的话要安装。
      library(caret)

      删除
  3. 你好,看到您的博客对我学习启发很大,谢谢,最近在使用train做向量机交叉验证,每次有如下警告:At least one of the class levels are not valid R variables names; This may cause errors if class probabilities are generated because the variables names will be converted to: X0, X1。不知道有何影响,交叉验证的最优参数只有C么,感觉交叉验证函数就在0.25,0.5,1中取C值,而如果是高斯函数参数sigma却没找到最优的验证取值额,谢谢了。

    回复删除
    回复
    1. 设置method = "svmRadial" 可以同时针对 cost 和 sigma 进行优化, 对于被优化参数的步长可以自己定义,Gridlen=n 或者 通过createGrid自定义参数网格, 甚至method 整个函数都可以自定义,caret包是按s4-method写,很灵活的,具体可以参考文档。

      偶然看到博主的R博客,文笔深入浅出,很是细腻,面向平民,高低兼顾,让我这样的R菜鸟也是很受启发。
      为博主"just for fun"精神鼓掌!

      删除
    2. 楼上的兄弟对caret包很熟悉啊,谢谢鼓励。

      删除
  4. 博主,这篇文章最上边那个图是怎么做的啊?

    回复删除
  5. 请问:svm函数训练完成样本数据后,如何能获得超平面的参数。

    回复删除
    回复
    1. 那个model结果里可能有,好长时间没看这个包了。

      删除
  6. 你好,

    我在使用你的方法进行SVM参数调优, 在plot时出错:plot function only supports binary classification, 是不是因为我的分类数目>2导致的?有什么解决方法么?

    回复删除
    回复
    1. 对,分类数只能是二元的,如果是123三类的,可以先将12类转为一类进行分类,完了再对12类进行第二次分类。

      删除
    2. 我改成两类之后,在plot时还是出错,错误信息是: Error in `[.data.frame`(expand.grid(lis), , labels(terms(x))) :
      undefined columns selected
      In addition: Warning message:
      In names(lis)[1:2] <- colnames(sub) :
      number of items to replace is not a multiple of replacement length

      删除
    3. 在R里头一步步的来,多用print观察中间结果。或者把数据发给我看一下。

      删除
    4. 请问我怎样把数据发给你呢?

      删除
  7. 试着认真读完肖凯兄的每一篇博文,我看的第一个R语言视频就是肖凯兄您的

    回复删除