Machine Learning: CART Classification and Regression Trees
Contents
- Preface
- 1. Introduction to CART Regression Trees
- 2. Pruning Strategies
- 3. Model Trees
- 4. Comparing Linear Regression, Regression Trees, and Model Trees
- 5. Visualizing with Tkinter
Preface
Although many problems can be solved well with linear methods, plenty of real-world problems are nonlinear, and a linear model cannot fit such data well. In those cases tree-based regression can be used instead. This post therefore walks through CART, tree pruning, model trees, and related algorithms.
1. Introduction to CART Regression Trees
A traditional decision tree is a greedy algorithm: it makes the locally best split at each step without asking whether the result is globally optimal. It also splits too aggressively: once a feature has been used, it is never used again. Finally, it cannot handle continuous features directly, and discretizing them can destroy the intrinsic structure of a continuous variable.
CART (Classification And Regression Trees) can do both classification and regression. When deciding how to split a node, CART uses binary splits to handle continuous variables: given a feature and a split value, samples whose feature value is greater than the threshold go to the left subtree, and the rest go to the right subtree. When a node can no longer be split, its value is either a single constant (regression tree) or a linear model (model tree).
Load the dataset, one row per sample, into a matrix:
```python
def loadDataSet(fileName):      # general function to parse tab-delimited floats
    dataMat = []                # assume the last column is the target value
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float, curLine))  # map all elements to float()
        dataMat.append(fltLine)
    return dataMat
```
Split the dataset according to a given feature and a corresponding feature value:
feature is the column index of the splitting feature and value is the threshold: rows whose feature value is greater than the threshold go into mat0, the rest into mat1.
```python
def binSplitDataSet(dataSet, feature, value):
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1
```
leafType is the function that creates leaf nodes and errType is the error-measuring function. Every node is stored as a dict containing the keys spInd, spVal, left, and right.
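For illustration, a node built by createTree (defined next) might look like the following sketch; the feature index and values here are hypothetical:

```python
# Hypothetical two-leaf regression tree: split on feature 0 at threshold 0.5,
# with a constant prediction stored on each side.
tree = {'spInd': 0, 'spVal': 0.5, 'left': 1.02, 'right': -0.04}
```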
```python
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    # assume dataSet is a NumPy matrix so we can use array filtering
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)  # choose the best split
    if feat == None: return val  # if the split hit a stop condition, return a leaf value
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree
```
To choose a split we need a measure of disorder for continuous target values (classification trees use information entropy or Gini impurity). Here we use the total variance of the data: the mean squared deviation multiplied by the number of samples.
Iterating over every feature and every candidate feature value, the pair that minimizes the total variance of the two resulting subsets becomes the split feature and split threshold.
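As a quick sanity check of this criterion (toy data assumed), the total variance equals the residual sum of squares around the mean:

```python
import numpy as np

# Toy targets: total variance == sum of squared deviations from the mean.
y = np.array([1.0, 1.1, 0.9, 5.0, 5.2])
total_var = np.var(y) * len(y)        # variance times sample count, as in regErr
rss = np.sum((y - y.mean()) ** 2)     # residual sum of squares around the mean
assert abs(total_var - rss) < 1e-9
```

chooseBestSplit below implements the search for the best split: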
```python
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    tolS = ops[0]; tolN = ops[1]
    # exit cond 1: if all the target values are identical, quit and return a leaf
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m, n = shape(dataSet)
    # the best feature is chosen by the reduction in RSS error from the mean
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n - 1):
        for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # exit cond 2: if the decrease (S - bestS) is below the threshold, don't split
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # exit cond 3: if either side is too small, don't split
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue  # the best feature to split on and the value used
```
tolS is the minimum allowed drop in error: if a split reduces the error by less than tolS, the improvement is too small and we return a leaf directly.
tolN is the minimum number of samples per branch: if a split would leave either child with fewer than tolN samples, the split is too fine-grained and we return a leaf.
leafType is the leaf-creation function; here a leaf simply stores the mean of the target values:
```python
def regLeaf(dataSet):
    # returns the value used for each leaf: the mean of the targets
    return mean(dataSet[:, -1])
```
errType is the error-estimation function; here we use the total variance, i.e. the mean squared deviation times the sample count:
```python
def regErr(dataSet):
    return var(dataSet[:, -1]) * shape(dataSet)[0]
```
If all target values at a node are identical, the node cannot be split any further and becomes a leaf:
```python
if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
```
Iterate over every feature and every corresponding feature value, compute the total variance of each candidate split, and keep the best feature and threshold:
```python
for featIndex in range(n - 1):
    for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]):
```
Plot the distribution of the sample set:
```python
def plotarr(arr):
    import matplotlib.pyplot as plt
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(arr[:, 0].flatten().A[0], arr[:, 1].flatten().A[0])
    plt.show()
```
A test run looks like this:
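A minimal driver for such a run (a sketch: 'ex00.txt' follows the naming of the book's sample data and may differ, and a module-level `from numpy import *` is assumed):

```python
# Load a simple two-column dataset, plot it, and grow a regression tree.
myDat = mat(loadDataSet('ex00.txt'))  # assumed filename
plotarr(myDat)
myTree = createTree(myDat)
print(myTree)
```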
Loading another dataset:
gives the following CART regression result:
2. Pruning Strategies
When a regression tree has too many leaf nodes it is prone to overfitting and its generalization performance drops. Pruning guards against this, and it comes in two forms: pre-pruning and post-pruning.
The chooseBestSplit function from Section 1 already performs pre-pruning through its tolS and tolN parameters, via these two exit conditions:
```python
# inside chooseBestSplit (Section 1): the two pre-pruning exits
if (S - bestS) < tolS: return None, leafType(dataSet)   # exit cond 2
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  # exit cond 3
    return None, leafType(dataSet)
```
If a split reduces the dataset's error by less than tolS, the node is not split; if a split would leave a leaf with fewer than tolN samples, it is likewise rejected. This, however, places heavy demands on the choice of tolS and tolN, which is often hard to get right.
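For example (a sketch using the ex2.txt dataset that also appears below), loose parameters grow a bushy tree while a huge tolS suppresses almost every split:

```python
# Effect of the pre-pruning parameters ops = (tolS, tolN).
myDat2 = mat(loadDataSet('ex2.txt'))
bushy = createTree(myDat2, ops=(0, 1))       # tolS=0, tolN=1: maximal tree
coarse = createTree(myDat2, ops=(10000, 4))  # huge tolS: may collapse to one leaf
```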
后剪枝:將數據分為訓練集與測試集,首先構建一顆完整樹,然后依次尋找葉子結點,用測試集來判斷將葉子結點合并是否能降低測試誤差,若能則采取后剪枝。
- Split the test data according to the already-built tree.
- If either child is itself a subtree, continue the pruning process recursively within that subset.
- Compute the error after merging the two leaf nodes.
- Compute the error without merging.
- If merging lowers the error, merge the two leaves.
Checking whether a node is a subtree amounts to checking whether it is a dict:
```python
def isTree(obj):
    return (type(obj).__name__ == 'dict')
```
Collapse a tree from the bottom up, returning the mean value of its leaves:
```python
def getMean(tree):
    if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0
```
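Applied to the hypothetical two-leaf tree sketched in Section 1, getMean collapses it to a single constant:

```python
# Collapsing the hypothetical tree yields (1.02 + (-0.04)) / 2 = 0.49.
tree = {'spInd': 0, 'spVal': 0.5, 'left': 1.02, 'right': -0.04}
print(getMean(tree))  # 0.49
```

With these helpers in place, the post-pruning itself: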
```python
def prune(tree, testData):
    if shape(testData)[0] == 0: return getMean(tree)  # no test data: collapse the tree
    # if either branch is a subtree, split the test data and keep pruning
    if (isTree(tree['right']) or isTree(tree['left'])):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
    # if both children are now leaves, see if we can merge them
    if not isTree(tree['left']) and not isTree(tree['right']):
        # split the test data at this node
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        # error without merging
        errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
                       sum(power(rSet[:, -1] - tree['right'], 2))
        # error after merging: the node value becomes the mean of the two children
        treeMean = (tree['left'] + tree['right']) / 2.0
        # total squared error between the true y values and the merged prediction
        errorMerge = sum(power(testData[:, -1] - treeMean, 2))
        if errorMerge < errorNoMerge:
            print("merging")
            return treeMean
        else: return tree
    else: return tree
```
Load a dataset and grow the fullest possible regression tree by setting tolS=0 and tolN=1:
```python
>>> dd = mat(regTrees.loadDataSet('ex2.txt'))
>>> mt = regTrees.createTree(dd, ops=(0, 1))
```
The resulting regression tree:
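Pruning it against a held-out set might then look like this (a sketch; 'ex2test.txt' follows the book's file-naming convention and may differ):

```python
>>> ddTest = mat(regTrees.loadDataSet('ex2test.txt'))  # assumed filename
>>> mt2 = regTrees.prune(mt, ddTest)
```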
剪枝之后:
有一部分節點被剪掉。
3. Model Trees
In the CART tree above, each leaf holds a single constant. A model tree instead allows each leaf to hold a linear model, so the tree as a whole fits the data with a piecewise linear function.
Fit a linear model to a dataset:
```python
def linearSolve(dataSet):  # helper function used in two places
    m, n = shape(dataSet)
    X = mat(ones((m, n))); Y = mat(ones((m, 1)))
    X[:, 1:n] = dataSet[:, 0:n-1]  # copy of the data with a 1 in the 0th (bias) column
    Y = dataSet[:, -1]             # strip out Y
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\n\
        try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y
```
It first reshapes the data matrix, prepending a column of ones for the bias term, then solves the normal equation ws = (X^T X)^{-1} X^T Y for the regression coefficients ws directly.
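A tiny numeric check of that normal equation (toy data assumed):

```python
import numpy as np

# Fit y = 1 + 1*x exactly; ws should come out as [[1.], [1.]].
X = np.array([[1., 1.], [1., 2.], [1., 3.]])  # bias column plus one feature
Y = np.array([[2.], [3.], [4.]])
ws = np.linalg.inv(X.T @ X) @ (X.T @ Y)
print(ws)
```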
When a node becomes a leaf, the model tree stores the weight vector ws there:
```python
def modelLeaf(dataSet):
    # create a linear model and return its coefficients
    ws, X, Y = linearSolve(dataSet)
    return ws
```
With a linear model at each leaf, the total error is measured by the sum of squared residuals:
```python
def modelErr(dataSet):
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat, 2))
```
Load a dataset and test:
```python
def testmodel():
    tt = mat(loadDataSet('exp2.txt'))
    return createTree(tt, modelLeaf, modelErr, (1, 10))
```
Distribution of the dataset:
4. Comparing Linear Regression, Regression Trees, and Model Trees
Train each model on the same data, then compare their performance on a common test set.
Producing predictions from a regression tree or a model tree:
```python
def treeForeCast(tree, inData, modelEval=regTreeEval):
    if not isTree(tree): return modelEval(tree, inData)
    if inData[tree['spInd']] > tree['spVal']:
        if isTree(tree['left']):
            return treeForeCast(tree['left'], inData, modelEval)
        else:
            return modelEval(tree['left'], inData)
    else:
        if isTree(tree['right']):
            return treeForeCast(tree['right'], inData, modelEval)
        else:
            return modelEval(tree['right'], inData)
```
tree is the trained tree, inData is the row vector of the sample to predict, and modelEval selects the leaf type. With modelEval=regTreeEval the leaves hold constants, so the prediction is simply the leaf value. With modelEval=modelTreeEval the leaves hold linear weight vectors, so the prediction is the test sample (with a bias term prepended) multiplied by those weights. (Note that regTreeEval and modelTreeEval, shown next, must be defined before treeForeCast in the actual module, since regTreeEval is used as a default argument.)
```python
def regTreeEval(model, inDat):
    return float(model)

def modelTreeEval(model, inDat):
    n = shape(inDat)[1]
    X = mat(ones((1, n + 1)))
    X[:, 1:n+1] = inDat
    return float(X * model)
```
Return the predictions for an entire test set as a column vector:
```python
def createForeCast(tree, testData, modelEval=regTreeEval):
    m = len(testData)
    yHat = mat(zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForeCast(tree, mat(testData[i]), modelEval)
    return yHat
```
Use the correlation coefficient to measure how well each model fits the test data:
```python
def regtree():
    traindata = mat(loadDataSet('bikeSpeedVsIq_train.txt'))
    testdata = mat(loadDataSet('bikeSpeedVsIq_test.txt'))
    mt = createTree(traindata, ops=(1, 20))
    yHat = createForeCast(mt, testdata[:, 0])
    return corrcoef(yHat, testdata[:, 1], rowvar=0)[0, 1]

def modeltree():
    traindata = mat(loadDataSet('bikeSpeedVsIq_train.txt'))
    testdata = mat(loadDataSet('bikeSpeedVsIq_test.txt'))
    mt = createTree(traindata, modelLeaf, modelErr, ops=(1, 20))
    yHat = createForeCast(mt, testdata[:, 0], modelTreeEval)
    return corrcoef(yHat, testdata[:, 1], rowvar=0)[0, 1]

def reg():
    traindata = mat(loadDataSet('bikeSpeedVsIq_train.txt'))
    testdata = mat(loadDataSet('bikeSpeedVsIq_test.txt'))
    ws, x, y = linearSolve(traindata)
    yHat = [0] * shape(testdata)[0]
    for i in range(shape(testdata)[0]):
        yHat[i] = testdata[i, 0] * ws[1, 0] + ws[0, 0]
    return corrcoef(yHat, testdata[:, 1], rowvar=0)[0, 1]
```
The model tree achieves the best fit.
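A minimal driver to print the three scores side by side (the closer the correlation coefficient is to 1, the better the fit):

```python
# Compare the three models on the same train/test split.
print('regression tree:', regtree())
print('model tree:     ', modeltree())
print('linear model:   ', reg())
```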
5. Visualizing with Tkinter
Use the tkinter library to build a small GUI for visualizing how the models fit the data.
```python
from numpy import *
from tkinter import *  # the module name is lowercase in Python 3
import regTrees

import matplotlib
matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure

def reDraw(tolS, tolN):
    reDraw.f.clf()  # clear the figure
    reDraw.a = reDraw.f.add_subplot(111)
    if chkBtnVar.get():
        if tolN < 2: tolN = 2
        # draw a model tree
        myTree = regTrees.createTree(reDraw.rawDat, regTrees.modelLeaf,
                                     regTrees.modelErr, (tolS, tolN))
        yHat = regTrees.createForeCast(myTree, reDraw.testDat,
                                       regTrees.modelTreeEval)
    else:
        # draw a regression tree
        myTree = regTrees.createTree(reDraw.rawDat, ops=(tolS, tolN))
        yHat = regTrees.createForeCast(myTree, reDraw.testDat)
    # convert the matrix columns to 1-D arrays before plotting the data
    reDraw.a.scatter(reDraw.rawDat[:, 0].flatten().A[0],
                     reDraw.rawDat[:, 1].flatten().A[0], s=5)  # scatter for the data set
    reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0)  # plot for yHat
    reDraw.canvas.draw()  # draw() replaces the old show()

def getInputs():
    try:
        tolN = int(tolNentry.get())
    except:
        tolN = 10
        print("enter Integer for tolN")
        tolNentry.delete(0, END)
        tolNentry.insert(0, '10')
    try:
        tolS = float(tolSentry.get())
    except:
        tolS = 1.0
        print("enter Float for tolS")
        tolSentry.delete(0, END)
        tolSentry.insert(0, '1.0')
    return tolN, tolS

def drawNewTree():
    tolN, tolS = getInputs()  # get values from the Entry boxes
    reDraw(tolS, tolN)

root = Tk()
reDraw.f = Figure(figsize=(5, 4), dpi=100)  # create the canvas
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
reDraw.canvas.draw()  # show() has been renamed draw()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)
Label(root, text="tolN").grid(row=1, column=0)
tolNentry = Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0, '10')
Label(root, text="tolS").grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0, '1.0')
Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3)
chkBtnVar = IntVar()
chkBtn = Checkbutton(root, text="Model Tree", variable=chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2)
reDraw.rawDat = mat(regTrees.loadDataSet('sine.txt'))
reDraw.testDat = arange(min(reDraw.rawDat[:, 0]), max(reDraw.rawDat[:, 0]), 0.01)
reDraw(1.0, 10)
root.mainloop()
```
Because of changes in Python 3, the book's original code needs the following modifications:
1. `from tkinter import *`: the module name is lowercase in Python 3.
2. `reDraw.canvas.draw()`: the FigureCanvasTkAgg object now has a draw() method instead of show().
3. `reDraw.a.scatter(reDraw.rawDat[:,0].flatten().A[0], reDraw.rawDat[:,1].flatten().A[0], s=5)`: the matrix columns must be converted to 1-D arrays before plotting the data distribution.
Varying tolS and tolN redraws the fit as shown below:
Summary
This post walked through building a CART regression tree, pre- and post-pruning it, extending it to model trees with linear leaves, comparing the three regression approaches by correlation coefficient, and wrapping everything in a Tkinter GUI.