PAL算法原理及代码实现
博主發現之前寫的博客都是偏程序方面,而較少涉及數學或算法方面的東西,其實無論什么軟件工具,最終都是為了更好地給理論鋪路搭橋,所以我覺得不應該就某個程序貼個博客,而是在實際算法研究中,將理論描述清晰,再通過工具實現,兩個結合。
??????廢話不多說,最近上臺灣大學的ML課程,說到PLA(perception learning algorithm)算法,涉及到ML的一個入門算法,我花了一些時間消化整理,在這里跟大家分享一下,希望大家再回過頭去看臺灣大學ML課程的時候,能更加如魚得水。
算法具體如下:
??????PLA是一種能夠通過自己學習而不斷改進的分類算法,可將二維或者更高維的數據切分成對應不同的種類(1和-1),假設我們有n個數據樣本,每個數據樣本對應的維度為m,可以表示成如下:
??????對于每個樣本,其對應的類別為1或-1,可表示為如下:
??????我們假設一條直線:
??????其對應為樣本m個維度的系數,這里需要注意的是,我們的目標是求解出W的值,將對應的兩種類別很好地分開,而不是在樣本中做回歸求誤差最小。
??????所以我們的目標是使下面式子成立:
??????其中sign是符號函數,對于所有的正數,返回1,對于所有非正數返回-1.
??????可以通過將表示為而化簡上市,其中,則有如下:???????????????????????????????????????????????????????????????????????????????????????????????????(1)
??????實際過程中上述等式可能沒辦法在一開始就成立,所以當等式不成立的時候,我們需要某種方法來修正過程中的W參數,下面舉個栗子:
??????比如我們計算出來:??????是正的,而卻是負的,從某種意義上來說,W參數是偏大的;而當是負的,而對應的卻是正的,那么W參數是偏小的,那么,我們該如何調整W參數呢?
可以通過如下:
??????這樣我們就可以通過將對應的W參數自主學習調整為越來越靠近正確的W。
也許你會問,為什么這樣通過修改W最后一定會收斂?或者換個說法,為什么通過這樣不斷地變化W參數,最后一定會有一條直線能將樣本較好地分開呢?
??????下面我會證明上面這個問題,也就是證明PLA算法的收斂性:
??????假設存在一條直線能將我們樣本數據很好分類,那么則有:
??????該式對應上文式(1),這里我通過向量表示消除符號過多的問題。
??????為了證明W會朝著靠攏,我們可以構造如下式子:
???????????????????????????????????????????????????????????????????????????????????????????????????(2)
其中我們上文以及假設是正確的分類線,那么意味式(2)中,
則算法在每次迭代修改W時,,那么從向量內積的角度來看,這意味著兩個向量越來越靠近。
??????也許你還會問,兩個向量內積越來越大,除了角度變小的可能外,還有兩個向量越來越大的可能?
下面我會證明其實在W參數學習的過程中其單位長度在不斷變小:
其中我們已經知道和符號相異,那么
則在W自主學習的過程中,其模越來越小,而上述式(2)我們證明了越來越大,那么綜合只有當向量和的角度越來越小時,式(2)才會成立,所以我們證明了自主學習,W會朝著越來越正確的方向變動(即使有時候這種變動我們察覺不出)。
??????PLA算法在多維度分類效果也比較好,收斂速度很快,這里博主用的是雙維度樣本,該樣本在更新1400多次后輸出了對應的結果,代碼質量還有待改進。??????
?
下面是算法的實現(R語言)
#加載ggplot2包
library(ggplot2)
library(plyr)
#PLA數據,取R自帶數據集iris,確保直線下方數據標簽為-1
?????pladata <- data.frame(x1=iris[1:100,1],x2=iris[1:100,2],y=c(rep(1,50),rep(-1,50)))
?????ggplot(data=pladata,aes(x1,x2,col=factor(y)))+geom_point()?????#樣本數據展示
#PLA函數,x表示樣本數據,y為對應類別,initial為w初始值,delta為相對誤差率
PLA <- function(x,y,initial,delta){
???????????w <- initial;n <- length(y);
???????????x <- as.matrix(cbind(x0=rep(1,dim(x)[1L]),x))
???????????error <- 1
???????????while(error > delta){
??????????????if(all(sign(x %*% w)==y)){
???????????????????error <- 0
??????????????}else{
???????????????????xnt <- which(sign(x %*% w)!=y)
???????????????????w <- w + x[xnt[1],] * rep(y[xnt[1]],dim(x)[2L])
???????????????????xnt1 <- which(sign(x %*% w)!=y)
???????????????????error <- length(xnt1)/n
??????????????}
???????}
?????????????names(w) <- paste("w",0:(dim(x)[2L]-1),sep="");print(w);
}
w <- PLA(x=pladata[,1:2],y=pladata[,3],initial=c(1,0,0),delta=0)
#分類結果展示:
names(w) <- NULL
ggplot(data=pladata,aes(x1,x2,col=factor(y)))+
geom_point()+
geom_abline(aes(intercept=(-w[1]/w[3]),slope=(-w[2]/w[3])))
?
??????其中未分類前的散點圖如下:
??????通過自主學習訓練后的結果如下:
C++代碼實現
/*<span style="font-family:Times New Roman;">?
? ? Author: DreamerMonkey?
? ? Time : 5/3/2015?
? ? Title : PLA Algorithm?
*/ ?
#include<iostream> ?
#include<vector> ?
using namespace std; ?
??
//以二維空間為例,x1 x2為屬性 ?
struct Item{ ?
? ? int x0; ?
? ? double x1,x2; ?
? ? int label; ?
}; ?
//權重結構體,w1 w2為屬性x1 x2的權重,初始值全設為0 ?
struct Weight{ ?
? ? double w0,w1,w2;// ?
}Wit0={0,0,0}; ?
??
//符號函數,根據向量內積和的特點判斷是否應該發放信用卡 ?
int sign(double x){ ?
? ? if(x>0) ?
? ? ? ? return 1; ?
? ? else if(x<0) ?
? ? ? ? return -1; ?
? ? else return 0; ?
} ?
//兩個向量的內積 ?
double DotPro(Item item,Weight wight){ ?
? ? return item.x0*wight.w0+item.x1*wight.w1+item.x2*wight.w2; ?
} ?
//更新權重 ?
Weight UpdateWeight(Item item,Weight weight){ ?
? ? Weight newWeight; ?
? ? newWeight.w0=weight.w0+item.x0*item.label; ?
? ? newWeight.w1=weight.w1+item.x1*item.label; ?
? ? newWeight.w2=weight.w2+item.x2*item.label; ?
? ? return newWeight; ?
} ?
int main(){ ?
? ? ??
? ? vector<Item> ivec; ?
? ? Item temp; ?
? ? cout<<"Please input Item.x1-Item.x2;"<<endl; ?
? ? while(cin>>temp.x1>>temp.x2>>temp.label){ ?
? ? ? ? temp.x0=1; ?
? ? ? ? ivec.push_back(temp); ?
? ? } ?
? ? Weight wit=Wit0; ?
? ? for(vector<Item>::iterator iter=ivec.begin();iter!=ivec.end();++iter){ ?
? ? ? ? if((*iter).label!=sign(DotPro(*iter,wit))){ ?
? ? ? ? ? ? wit=UpdateWeight(*iter,wit); ?
? ? ? ? ? ? iter=ivec.begin();//在從頭開始判斷,因為更新權重后可能會導致前面的點出故障,需要從頭再判斷 ?
? ? ? ? } ?
? ? } ?
? ? //打印結果 ?
? ? cout<<wit.w0<<" "<<wit.w1<<" "<<wit.w2<<" "<<endl;</span> ?
??
}
matlab代碼實現
x_1=[120 185 215 275 310 337];
x_2=[110 125 185 250 130 137];
plot(x_1,x_2,'ob','linewidth',3,'markersize',15);?
hold on;
x1=[55 98 115 110 95 122 70 205 225 ];
y1=[90 178 170 225 270 270 310 345 290 ];
plot(x1,y1,'xr','linewidth',3,'markersize',15)
hold on;
negpoints = [55,90,-1;310,130,1;98,178,-1;115,110,1;115,165,-1;185,125,1;110,225,-1;215,185,1;95,270,-1;275,260,1;122,270,-1;70,310,-1;337,137,1;205,345,-1;225,280,-1]
pospoints = [310,130,-1;115,110,-1;185,125,-1;215,185,-1;275,260,-1;337,137,-1]
weight = [0,300,100]
H_value = 0
sig=true
axis([50 350 50 350])
while sig
? ? for i=1:1:15
? ? ? ? sig=false
? ? ? ? q = sign(negpoints(i,3))
? ? ? ? h_x_i = sign(weight(1)+weight(2)*negpoints(i,1)+weight(3)*negpoints(i,2))
? ? ? ? if h_x_i == q
? ? ? ? ? ? if (i==15 && sig==false ) ? ? ? ? ? ?
? ? ? ? ? ? ? ?
? ? ? ? ? ? ? ? x =[50,100,200,250,350]
? ? ? ? ? ? ? ? y = -(weight(2)/weight(3))*x -( weight(1)/weight(3))
? ? ? ? ? ? ? ? plot(x,y,'b'); ? ? ? ? ??
? ? ? ? ? ? ? ? hold on;
? ? ? ? ? ? else
? ? ? ? ? ? ? ? continue
? ? ? ? ? ? end
? ? ? ? else ?
? ? ? ? ? ? sig=true
? ? ? ? ? ? ew1 = weight(2)
? ? ? ? ? ? ew2 = weight(3)
? ? ? ? ? ? weight(1)= (weight(1)+ q*1)
? ? ? ? ? ? weight(2)= (weight(2)+ q*negpoints(i,1))
? ? ? ? ? ? weight(3)= (weight(3)+ q*negpoints(i,2))
? ? ? ? ? ?
? ? ? ? ? ? x =[50,100,200,250,350]
? ? ? ? ? ? x1 =[50,100,200,250,350]
? ? ? ? ? ? y1 = (weight(3)/weight(2))*(x1-200) +200
? ? ? ? ? ? plot(x1,y1,'b'); ? ? ? ? ??
? ? ? ? ? ? hold on;
? ? ? ? ? ? y = -(weight(2)/weight(3))*x -( weight(1)/weight(3))
? ? ? ? ? ? plot(x,y,'r'); ? ? ? ? ??
? ? ? ? ? ? hold on;
? ? ? ? end
? ? end ?
end
轉載于:https://blog.51cto.com/6510827/1854839
總結
以上是生活随笔為你收集整理的PAL算法原理及代码实现的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python登录系统简易框架
- 下一篇: js,jq设置获取属性,样式