《机器学习实战》chapter03 决策树
生活随笔
收集整理的這篇文章主要介紹了《机器学习实战》chapter03 决策树，小編覺得挺不錯的，現在分享給大家，給大家做個參考。
分類生成決策樹
import operator
import pickle
from math import log


def calcShannonEnt(dataSet):
    """Return the Shannon entropy of the class labels in dataSet.

    Each example is a list whose last element is its class label.
    """
    numEntries = len(dataSet)
    # 1. Count the frequency of each class label.
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    # 2. H = -sum(p * log2(p)) over all classes.
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def createDataSet():
    """Return the toy "fish" dataset and its feature labels."""
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def splitDataSet(dataSet, axis, value):
    """Return the examples whose feature at index `axis` equals `value`,
    with that feature column removed from each example."""
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis] + featVec[axis + 1:]
            retDataSet.append(reducedFeatVec)
    return retDataSet


def choseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the highest information gain,
    or -1 when no split improves on the base entropy."""
    numFeature = len(dataSet[0]) - 1  # last column is the class label
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeature):
        # Distinct values taken by feature i.
        uniqueValues = set(example[i] for example in dataSet)
        # Expected entropy after splitting on feature i.
        newEntropy = 0.0
        for value in uniqueValues:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """Return the most frequent class label (majority vote for leaves)."""
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree as nested dicts.

    NOTE: `labels` is mutated (the chosen feature's label is deleted),
    matching the book's original behaviour — pass a copy if the caller
    needs the list afterwards.
    """
    classList = [example[-1] for example in dataSet]
    # All examples share one class: stop and return it.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # No features left (only the label column): majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = choseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]
    featValues = [example[bestFeat] for example in dataSet]
    for value in set(featValues):
        # Copy labels so recursive calls cannot corrupt this level's list.
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


def classify(inputTree, featLabels, testVec):
    """Walk the tree and return the predicted class for testVec.

    Returns None when testVec's feature value has no branch in the tree
    (the original raised UnboundLocalError in that case).
    """
    classLabel = None
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict:
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel


def storeTree(inputTree, fileName):
    """Pickle the tree to fileName; I/O errors are printed, not raised."""
    try:
        with open(fileName, 'wb') as fw:
            pickle.dump(inputTree, fw)
    except IOError as e:
        print("File Error : " + str(e))


def grabTree(fileName):
    """Load a pickled tree from fileName.

    Fixed: the file handle is now closed (the original leaked it).
    SECURITY: pickle.load must only be used on trusted files.
    """
    with open(fileName, 'rb') as fr:
        return pickle.load(fr)

# 使用Matplotlib注解繪制樹形圖 (Plot the tree with Matplotlib annotations)
import matplotlib.pyplot as plt

# boxstyle = text-box shape; fc (face color) = grey level of the box fill.
decisionNode = dict(boxstyle="round4, pad=0.5", fc="0.8")
leafNode = dict(boxstyle="circle", fc="0.8")
# Arrow drawn from the child node back towards its parent.
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one annotated node at centerPt with an arrow from parentPt.

    Both points are in "axes fraction" coordinates: (0, 0) is the
    bottom-left of the axes, (1, 1) the top-right.  The text is centered
    vertically (va) and horizontally (ha) inside its box.
    """
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords="axes fraction",
                            xytext=centerPt, textcoords="axes fraction",
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)


def getNumLeafs(myTree):
    """Return the number of leaf nodes in the nested-dict tree."""
    numLeafs = 0
    # Python 3: dict.keys() is a view, so convert to list before indexing.
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            # Decision node: recurse into the subtree.
            numLeafs += getNumLeafs(secondDict[key])
        else:
            # Leaf node.
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """Return the depth (number of decision levels) of the tree."""
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            # Decision node with children: recurse.
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            # Leaf node.
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def plotMidText(cntrPt, parentPt, txtString):
    """Write txtString halfway along the parent -> child edge."""
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw the tree rooted at myTree.

    Uses the function attributes set by createPlot: totalW/totalD (tree
    width/depth) and xOff/yOff (the drawing cursor).
    """
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    # Center this subtree over its leaves: xOff + numLeafs/2W + 1/2W.
    # The extra 1/2W compensates for the initial offset of -1/2W.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / (2.0 * plotTree.totalW),
              plotTree.yOff)
    # Edge label between parent and this node.
    plotMidText(cntrPt, parentPt, nodeTxt)
    # Draw the decision node itself.
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Descend one level: yOff -= 1/D.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            # Decision node: recurse.
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # Leaf node: advance xOff by 1/W and draw it.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Restore yOff on the way back up.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    """Create a figure and draw the whole decision tree inTree."""
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # Tree width (leaf count) and depth drive the layout.
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start at -1/2W; each leaf advances by 1/W so leaves end up evenly
    # spread, e.g. 3 leaves at (1/6, 1/2, 5/6), 4 at (1/8, 3/8, 5/8, 7/8).
    plotTree.xOff = -0.5 / plotTree.totalW
    # Root sits at the top of the axes.
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


# Guarded so importing this module (as the test script does) no longer
# pops up the demo plot as a side effect.
if __name__ == '__main__':
    myTree = {'no surfacing': {0: 'no',
                               1: {'flippers': {0: 'no', 1: 'yes'}},
                               3: 'maybe'}}
    createPlot(myTree)

# 測試 (Test)
from chapter3 import treePlotter from chapter3 import treesfr = open('lenses.txt') lenses = [inst.strip().split('\t') for inst in fr.readlines()] lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate'] lensesTree = trees.createTree(lenses, lensesLabels) print(lensesTree) treePlotter.createPlot(lensesTree)總結
以上是生活随笔為你收集整理的《机器学习实战》chapter03 决策树的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 小试牛刀Matplotlib
- 下一篇: 《机器学习实战》chapter04 使用