星期四, 八月 30, 2012

EM算法的R实现和高斯混合模型

EM(Expectatioin-Maximalization)算法即期望最大算法,被誉为是数据挖掘的十大算法之一。它是在概率模型中寻找参数最大似然估计的算法,其中概率模型依赖于无法观测到的隐变量。最大期望算法经过两个步骤交替进行计算,第一步是计算期望(E),也就是将隐藏变量象能够观测到的一样包含在内,从而计算最大似然的期望值;另外一步是最大化(M),也就是最大化在 E 步上找到的最大似然的期望值从而计算参数的最大似然估计。M 步上找到的参数然后用于另外一个 E 步计算,这个过程不断交替进行。对于信息缺失的数据来说,EM算法是一种极有效的工具,我们先用一个简单例子来理解EM算法,并将其应用到GMM(高斯混合模型)中去。

幼儿园里老师给a,b,c,d四个小朋友发糖吃,但老师有点偏心,不同小朋友得到糖的概率不同,p(a)=0.5, p(b)=miu, p(c)=2*miu, p(d)=0.5-3*miu 如果确定了参数miu,概率分布就知道了。我们可以通过观察样本数据来推测参数。设四人实际得到糖果数目为(a,b,c,d),可以用ML(极大似然),估计出miu=(b+c)/6*(b+c+d),假如某一天四个小朋友分别得到了(40,18,0,23)个糖。根据公式可求出miu为0.1。在信息完整条件下,ML方法是很容易估计参数的。

假设情况再复杂一点:知道c和d二人得到的糖果数,也知道a与b二人的糖果数之和为h,如何来估计出参数miu呢?前面我们知道了,如果观察到a,b,c,d就可以用ML估计出miu。反之,如果miu已知,根据概率期望 a/b=0.5/miu,又有a+b=h。由两个式子可得到 a=0.5*h/(0.5+miu)和b=miu*h/(0.5+miu)。此时我们面临一个两难的境地,a,b需要miu才能求出来,miu需要a,b才能求出来。有点类似岸上的八戒和河里的沙僧的对话:你先上来,你先下来,你先上来......

那么一种思路就是一方先让一步,暂且先抛出一个随机的初值,然后用对方算出的数值反复迭代计算。直到计算结果收敛为止。这就是EM算法的思路,我们来看看用R来实现这种思路:

# 已知条件
h = 20
c = 10
d = 10
 
# 随机初始两个未知量
miu = runif(1,0,1/6)
b = round(runif(1,1,20))
 
iter = 1
nonstop=TRUE
while (nonstop) {
    # E步骤,根据假设的miu来算b
    b = c(b,miu[iter]*h/(0.5+miu[iter]))
    # M步骤,根据上面算出的b再来计算miu
    miu = c(miu,(b[iter+1] + c)/(6*(b[iter+1]+c+d)))
    # 记录循环次数
    iter = iter + 1
    # 如果前后两次的计算结果差距很小则退出
    nonstop = (miu[iter]-miu[iter-1]>10^(-4))
}
print(cbind(miu,b))
当循环到第四次时结果已经收敛,miu为0.094,b为3.18

EM算法在高斯混合模型GMM(Gaussian Mixture Model )中有很重要的用途。简单来讲GMM就是一些高斯分布的组合。如果我们已知观测到的数据的类别,则可以根据ML来估计出GMM的参数。反之,对于没有类别信息一堆数据,如果我们已知GMM的参数,可以很容易用贝叶斯公式将它们归入不同的类中;但尴尬的问题是我们即不知道GMM参数,也不知道观测数据的类别。以下面生成的一维数据为例,我们希望找到这两个高斯分布的参数,同时为这些数据分类。
# 设置模拟参数
miu1 <- 3
miu2 <- -2
sigma1 <- 1
sigma2 <- 2
alpha1 <- 0.4
alpha2 <- 0.6
# 生成两种高斯分布的样本
n <- 5000
x <- rep(0,n)
n1 <- floor(n*alpha1)
n2 <- n - n1
x[1:n1] <- rnorm(n1)*sigma1 + miu1
x[(n1+1):n] <- rnorm(n2)*sigma2 + miu2
hist(x,freq=F)
lines(density(x),col='red')
下面用EM算法来估计GMM的参数。

# 设置初始值
m <- 2
miu <- runif(m)
sigma <- runif(m)
alpha <- c(0.2,0.8)
prob <- matrix(rep(0,n*m),ncol=m)
 
for (step in 1:100){
    # E步骤
    for (j in 1:m){
        prob[,j]<- sapply(x,dnorm,miu[j],sigma[j])
    }
    sumprob <- rowSums(prob)
    prob<- prob/sumprob
 
    oldmiu <- miu
    oldsigma <- sigma
    oldalpha <- alpha
 
    # M步骤
    for (j in 1:m){
        p1 <- sum(prob[ ,j])
        p2 <- sum(prob[ ,j]*x)
        miu[j] <- p2/p1
        alpha[j] <- p1/n
        p3 <- sum(prob[ ,j]*(x-miu[j])^2)
        sigma[j] <- sqrt(p3/p1)
    }
 
    # 变化
    epsilo <- 1e-4
    if (sum(abs(miu-oldmiu))<epsilo &
        sum(abs(sigma-oldsigma))<epsilo &
        sum(abs(alpha-oldalpha))<epsilo) break
    cat('step',step,'miu',miu,'sigma',sigma,'alpha',alpha,'\n')
}
在33次循环之后运算结果趋于稳定,估计的miu为(-2.2,2.8),sigma为(1.82,1.14)

GMM 模型常用于基于模型的聚类分析,GMM中的每一个高斯分布都可以代表数据的一类,整个数据就是多个高斯分布的混合。在R中的mclust包中的Mclust函数可以用来进行基于GMM的聚类分析。下面即是以最常用的iris数据集为例,聚类结果生成的图形就是文章的第一幅图:
library(mclust)
mc <-  Mclust(iris[,1:4], 3)
plot(mc, data=iris[,1:4], what="classification",dimens=c(3,4))
table(iris$Species, mc$classification)

参考资料:
http://architects.dzone.com/articles/machine-learning-algorithms
http://www.autonlab.org/tutorials/gmm14.pdf
http://www.cs.ubbcluj.ro/~csatol/gep_tan/Bishop-CUED-2006.pdf
http://blog.sciencenet.cn/blog-637323-572768.html
http://home.dei.polimi.it/matteucc/Clustering/tutorial_html/mixture.html
http://wenku.baidu.com/view/f780fb0590c69ec3d5bb7538.html
http://www.matrixq.net/2011/09/10218.html

9 条评论:

  1. 请问plot中的dimens是什么意思?

    回复删除
    回复
    1. iris数据集中有四个变量,dimens参数是将第3、4两个变量作为x、y轴变量进行绘图

      删除
  2. 您好,非常感谢分享。

    我尝试了下,觉得E步骤中prob[,j]<- sapply(x,dnorm,miu[j],sigma[j])应该是alpha[j]*sapply(x,dnorm,miu[j],sigma[j])

    这样的话,经过62步收敛,结果为
    step 62 miu -2.01694 2.96893 sigma 2.009187 1.029817 alpha 0.5976662 0.4023338

    谢谢指正!
    裘鹏翔

    回复删除
    回复
    1. 裘兄说的很对,以前修正过代码,在下面,我忘了更新博客.
      https://github.com/xccds/my-r-code/blob/master/EM.R

      删除
  3. plot(mc, data=iris[,1:4], what="classification",dimens=c(3,4))
    这一句要把data那一项删掉,不然报错啊

    回复删除
  4. 估计出miu=(b+c)/6*(b+c+d),,,是怎么得到的?????

    回复删除
  5. 根据概率期望 a/b=0.5/miu?
    能解释一下吗?

    回复删除
  6. 我是R初学者,sapply(x,dnorm,miu[j],sigma[j]) 这条语句不是很懂,sapply()里面不是应该有个函数fun吗?求指导

    回复删除