11. Weighted Linear Regression Case Study: Predicting the Age of Abalone
Click the article title to get the source code and notes.
Dataset: https://download.csdn.net/download/weixin_44827418/12553408
1. Import the dataset
Dataset description:
First five rows:

| Sex | Length | Diameter | Height | Whole weight | Shucked weight | Viscera weight | Shell weight | Age |
|---|---|---|---|---|---|---|---|---|
| 1 | 0.455 | 0.365 | 0.095 | 0.5140 | 0.2245 | 0.1010 | 0.150 | 15 |
| 1 | 0.350 | 0.265 | 0.090 | 0.2255 | 0.0995 | 0.0485 | 0.070 | 7 |
| -1 | 0.530 | 0.420 | 0.135 | 0.6770 | 0.2565 | 0.1415 | 0.210 | 9 |
| 1 | 0.440 | 0.365 | 0.125 | 0.5160 | 0.2155 | 0.1140 | 0.155 | 10 |
| 0 | 0.330 | 0.255 | 0.080 | 0.2050 | 0.0895 | 0.0395 | 0.055 | 7 |

Summary statistics:

| Statistic | Sex | Length | Diameter | Height | Whole weight | Shucked weight | Viscera weight | Shell weight | Age |
|---|---|---|---|---|---|---|---|---|---|
| count | 4177.000000 | 4177.000000 | 4177.000000 | 4177.000000 | 4177.000000 | 4177.000000 | 4177.000000 | 4177.000000 | 4177.000000 |
| mean | 0.052909 | 0.523992 | 0.407881 | 0.139516 | 0.828742 | 0.359367 | 0.180594 | 0.238831 | 9.933684 |
| std | 0.822240 | 0.120093 | 0.099240 | 0.041827 | 0.490389 | 0.221963 | 0.109614 | 0.139203 | 3.224169 |
| min | -1.000000 | 0.075000 | 0.055000 | 0.000000 | 0.002000 | 0.001000 | 0.000500 | 0.001500 | 1.000000 |
| 25% | -1.000000 | 0.450000 | 0.350000 | 0.115000 | 0.441500 | 0.186000 | 0.093500 | 0.130000 | 8.000000 |
| 50% | 0.000000 | 0.545000 | 0.425000 | 0.140000 | 0.799500 | 0.336000 | 0.171000 | 0.234000 | 9.000000 |
| 75% | 1.000000 | 0.615000 | 0.480000 | 0.165000 | 1.153000 | 0.502000 | 0.253000 | 0.329000 | 11.000000 |
| max | 1.000000 | 0.815000 | 0.650000 | 1.130000 | 2.825500 | 1.488000 | 0.760000 | 1.005000 | 29.000000 |
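The loading step itself is not shown in the text. The following is only a sketch of how such a table could be read, assuming a tab-separated file without a header row. The helper name `load_abalone` and every column name other than `高度` (height, which the later code uses) are assumptions made for illustration.

```python
import io
import pandas as pd

# Assumed column names; only `高度` (height) is confirmed by the article's code.
COLUMNS = ["性别", "长度", "直径", "高度", "整体重量", "肉重量", "内脏重量", "壳重", "年龄"]

def load_abalone(source):
    """Read a headerless, tab-separated abalone file (path or file-like object)."""
    return pd.read_csv(source, sep="\t", header=None, names=COLUMNS)

# Two rows copied from the table above stand in for the real file.
sample = io.StringIO(
    "1\t0.455\t0.365\t0.095\t0.5140\t0.2245\t0.1010\t0.150\t15\n"
    "1\t0.350\t0.265\t0.090\t0.2255\t0.0995\t0.0485\t0.070\t7\n"
)
abalone_demo = load_abalone(sample)
print(abalone_demo.shape)
print(abalone_demo.describe().loc[["count", "mean"]])
```

With the real file, passing its path instead of the `StringIO` object should give a frame that `describe()` can summarise as in the table above.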
2. Inspect the data distribution
    import numpy as np
    import pandas as pd
    import random
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    plt.rcParams['font.sans-serif'] = ['simhei']   # render Chinese labels
    plt.rcParams['axes.unicode_minus'] = False     # render the minus sign correctly
    %matplotlib inline

    mpl.cm.rainbow(np.linspace(0,1,10))

Output:

    array([[5.00000000e-01, 0.00000000e+00, 1.00000000e+00, 1.00000000e+00],
           [2.80392157e-01, 3.38158275e-01, 9.85162233e-01, 1.00000000e+00],
           [6.07843137e-02, 6.36474236e-01, 9.41089253e-01, 1.00000000e+00],
           [1.66666667e-01, 8.66025404e-01, 8.66025404e-01, 1.00000000e+00],
           [3.86274510e-01, 9.84086337e-01, 7.67362681e-01, 1.00000000e+00],
           [6.13725490e-01, 9.84086337e-01, 6.41213315e-01, 1.00000000e+00],
           [8.33333333e-01, 8.66025404e-01, 5.00000000e-01, 1.00000000e+00],
           [1.00000000e+00, 6.36474236e-01, 3.38158275e-01, 1.00000000e+00],
           [1.00000000e+00, 3.38158275e-01, 1.71625679e-01, 1.00000000e+00],
           [1.00000000e+00, 1.22464680e-16, 6.12323400e-17, 1.00000000e+00]])

    mpl.cm.rainbow(np.linspace(0,1,10))[0]

Output:

    array([0.5, 0. , 1. , 1. ])

    def dataPlot(dataSet):
        m, n = dataSet.shape
        fig = plt.figure(figsize=(8,20), dpi=100)
        colormap = mpl.cm.rainbow(np.linspace(0,1,n))   # one colour per column
        for i in range(n):
            fig_ = fig.add_subplot(n, 1, i+1)
            # a 2-D single-row array gives every point the same RGBA colour;
            # passing the 1-D RGBA sequence directly triggers a Matplotlib warning
            plt.scatter(range(m), dataSet.iloc[:,i].values, s=2, c=colormap[i].reshape(1,-1))
            plt.title(dataSet.columns[i])
            plt.tight_layout(pad=1.2)   # adjust the spacing between subplots

    # run the function to inspect the data distribution:
    dataPlot(abalone)

From the scatter plots of the data distribution we can see:
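1) Except for "sex", the features are clearly arranged in an ordered pattern.
2) The "height" feature contains two outliers.
Based on these observations, we take the following two measures:
1) When splitting the training and test sets, shuffle the original dataset so that samples are picked at random.
2) Remove the outliers in the "height" feature.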
    abalone['高度'] < 0.4

Output:

    0       True
    1       True
    2       True
    3       True
    4       True
            ...
    4172    True
    4173    True
    4174    True
    4175    True
    4176    True
    Name: 高度, Length: 4177, dtype: bool

    aba = abalone.loc[abalone['高度'] < 0.4, :]
    # inspect the distribution of the dataset again
    dataPlot(aba)

2. Split the training and test sets

    """
    Purpose: randomly split the data into a training set and a test set
    Parameters:
        dataSet: original dataset
        rate: proportion of samples used for training
    Returns:
        train, test: the training and test sets
    """
    def randSplit(dataSet, rate):
        # work on a copy with a clean 0..m-1 index, so that .loc[range(n)] below
        # cannot hit labels that were dropped when the outliers were filtered out
        dataSet = dataSet.reset_index(drop=True)
        l = list(dataSet.index)              # collect the index labels in a list
        random.seed(123)                     # fix the random seed
        random.shuffle(l)                    # shuffle the index labels
        dataSet.index = l                    # shuffled labels on the rows = shuffled rows
        m = dataSet.shape[0]                 # total number of samples
        n = int(m*rate)                      # number of training samples
        train = dataSet.loc[range(n), :]     # take the training rows from the shuffled data
        test = dataSet.loc[range(n, m), :]   # take the test rows from the shuffled data
        train.index = range(train.shape[0])  # reset the index of the training set
        test.index = range(test.shape[0])    # reset the index of the test set
        return train, test

    train, test = randSplit(aba, 0.8)
    # explore the training set
    train.head()

| Sex | Length | Diameter | Height | Whole weight | Shucked weight | Viscera weight | Shell weight | Age |
|---|---|---|---|---|---|---|---|---|
| -1 | 0.590 | 0.470 | 0.170 | 0.9000 | 0.3550 | 0.1905 | 0.2500 | 11 |
| 1 | 0.560 | 0.450 | 0.145 | 0.9355 | 0.4250 | 0.1645 | 0.2725 | 11 |
| -1 | 0.635 | 0.535 | 0.190 | 1.2420 | 0.5760 | 0.2475 | 0.3900 | 14 |
| 1 | 0.505 | 0.390 | 0.115 | 0.5585 | 0.2575 | 0.1190 | 0.1535 | 8 |
| 1 | 0.510 | 0.410 | 0.145 | 0.7960 | 0.3865 | 0.1815 | 0.1955 | 8 |
Summary statistics of the original dataset, for comparison:

| Statistic | Sex | Length | Diameter | Height | Whole weight | Shucked weight | Viscera weight | Shell weight | Age |
|---|---|---|---|---|---|---|---|---|---|
| count | 4177.000000 | 4177.000000 | 4177.000000 | 4177.000000 | 4177.000000 | 4177.000000 | 4177.000000 | 4177.000000 | 4177.000000 |
| mean | 0.052909 | 0.523992 | 0.407881 | 0.139516 | 0.828742 | 0.359367 | 0.180594 | 0.238831 | 9.933684 |
| std | 0.822240 | 0.120093 | 0.099240 | 0.041827 | 0.490389 | 0.221963 | 0.109614 | 0.139203 | 3.224169 |
| min | -1.000000 | 0.075000 | 0.055000 | 0.000000 | 0.002000 | 0.001000 | 0.000500 | 0.001500 | 1.000000 |
| 25% | -1.000000 | 0.450000 | 0.350000 | 0.115000 | 0.441500 | 0.186000 | 0.093500 | 0.130000 | 8.000000 |
| 50% | 0.000000 | 0.545000 | 0.425000 | 0.140000 | 0.799500 | 0.336000 | 0.171000 | 0.234000 | 9.000000 |
| 75% | 1.000000 | 0.615000 | 0.480000 | 0.165000 | 1.153000 | 0.502000 | 0.253000 | 0.329000 | 11.000000 |
| max | 1.000000 | 0.815000 | 0.650000 | 1.130000 | 2.825500 | 1.488000 | 0.760000 | 1.005000 | 29.000000 |

Summary statistics of the training set:

| Statistic | Sex | Length | Diameter | Height | Whole weight | Shucked weight | Viscera weight | Shell weight | Age |
|---|---|---|---|---|---|---|---|---|---|
| count | 3340.000000 | 3340.000000 | 3340.000000 | 3340.000000 | 3340.000000 | 3340.000000 | 3340.000000 | 3340.000000 | 3340.000000 |
| mean | 0.060479 | 0.522754 | 0.406886 | 0.138790 | 0.824906 | 0.358151 | 0.179732 | 0.237158 | 9.911976 |
| std | 0.819021 | 0.120300 | 0.099372 | 0.038441 | 0.488535 | 0.222422 | 0.109036 | 0.137920 | 3.223534 |
| min | -1.000000 | 0.075000 | 0.055000 | 0.000000 | 0.002000 | 0.001000 | 0.000500 | 0.001500 | 1.000000 |
| 25% | -1.000000 | 0.450000 | 0.350000 | 0.115000 | 0.439000 | 0.184375 | 0.092000 | 0.130000 | 8.000000 |
| 50% | 0.000000 | 0.540000 | 0.420000 | 0.140000 | 0.796750 | 0.335500 | 0.171000 | 0.232000 | 9.000000 |
| 75% | 1.000000 | 0.615000 | 0.480000 | 0.165000 | 1.147250 | 0.498500 | 0.250500 | 0.325000 | 11.000000 |
| max | 1.000000 | 0.780000 | 0.630000 | 0.250000 | 2.825500 | 1.488000 | 0.760000 | 1.005000 | 27.000000 |

First five rows of the test set:

| Sex | Length | Diameter | Height | Whole weight | Shucked weight | Viscera weight | Shell weight | Age |
|---|---|---|---|---|---|---|---|---|
| 1 | 0.630 | 0.470 | 0.150 | 1.1355 | 0.5390 | 0.2325 | 0.3115 | 12 |
| -1 | 0.585 | 0.445 | 0.140 | 0.9130 | 0.4305 | 0.2205 | 0.2530 | 10 |
| -1 | 0.390 | 0.290 | 0.125 | 0.3055 | 0.1210 | 0.0820 | 0.0900 | 7 |
| 1 | 0.525 | 0.410 | 0.130 | 0.9900 | 0.3865 | 0.2430 | 0.2950 | 15 |
| 1 | 0.625 | 0.475 | 0.160 | 1.0845 | 0.5005 | 0.2355 | 0.3105 | 10 |

Summary statistics of the test set:

| Statistic | Sex | Length | Diameter | Height | Whole weight | Shucked weight | Viscera weight | Shell weight | Age |
|---|---|---|---|---|---|---|---|---|---|
| count | 835.000000 | 835.000000 | 835.000000 | 835.000000 | 835.000000 | 835.000000 | 835.000000 | 835.000000 | 835.000000 |
| mean | 0.022754 | 0.528808 | 0.411737 | 0.140784 | 0.842714 | 0.363370 | 0.183749 | 0.245320 | 10.022754 |
| std | 0.834341 | 0.119166 | 0.098627 | 0.038664 | 0.495990 | 0.218938 | 0.111510 | 0.143925 | 3.230284 |
| min | -1.000000 | 0.130000 | 0.100000 | 0.015000 | 0.013000 | 0.004500 | 0.003000 | 0.004000 | 3.000000 |
| 25% | -1.000000 | 0.450000 | 0.350000 | 0.115000 | 0.458000 | 0.192000 | 0.096500 | 0.132750 | 8.000000 |
| 50% | 0.000000 | 0.550000 | 0.430000 | 0.140000 | 0.810000 | 0.339000 | 0.170500 | 0.235000 | 10.000000 |
| 75% | 1.000000 | 0.620000 | 0.485000 | 0.170000 | 1.177250 | 0.510750 | 0.259250 | 0.337000 | 11.000000 |
| max | 1.000000 | 0.815000 | 0.650000 | 0.250000 | 2.555000 | 1.145500 | 0.590000 | 0.815000 | 29.000000 |
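The split above works by shuffling index labels. As an alternative sketch (not the author's method), the same idea can be written by shuffling row positions with NumPy, which leaves the input frame untouched and does not depend on the index having no gaps. The function name `rand_split_perm` and the toy frame are hypothetical, and because a different random generator is used, it will not reproduce the exact split shown in the tables.

```python
import numpy as np
import pandas as pd

def rand_split_perm(data, rate, seed=123):
    """Shuffle row positions with a seeded generator, then cut at `rate`."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(data))        # random ordering of row positions
    n_train = int(len(data) * rate)           # number of training rows
    train = data.iloc[order[:n_train]].reset_index(drop=True)
    test = data.iloc[order[n_train:]].reset_index(drop=True)
    return train, test

# Toy frame with a gap in its index, similar to `aba` after filtering.
demo = pd.DataFrame({"x": np.arange(10), "y": np.arange(10) * 2}).drop(index=[3])
train_demo, test_demo = rand_split_perm(demo, 0.8)
print(train_demo.shape, test_demo.shape)
```

Every row ends up in exactly one of the two parts, and the original frame keeps its index.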
3. Build helper functions
    '''
    Purpose: take a DataFrame whose last column is the label and
    return the feature matrix and the label matrix
    '''
    def get_Mat(dataSet):
        # np.asmatrix is used because the np.mat alias was removed in NumPy 2.0
        xMat = np.asmatrix(dataSet.iloc[:, :-1].values)
        yMat = np.asmatrix(dataSet.iloc[:, -1].values).T
        return xMat, yMat

    '''
    Purpose: visualise the dataset
    '''
    def plotShow(dataSet):
        xMat, yMat = get_Mat(dataSet)
        plt.scatter(xMat.A[:, 1], yMat.A, c='b', s=5)
        plt.show()

    '''
    Purpose: compute the regression coefficients
    Parameters:
        dataSet: original dataset
    Returns:
        ws: regression coefficients
    '''
    def standRegres(dataSet):
        xMat, yMat = get_Mat(dataSet)
        xTx = xMat.T * xMat
        if np.linalg.det(xTx) == 0:
            print('The matrix is singular and cannot be inverted!')
            return
        ws = xTx.I * (xMat.T * yMat)   # xTx.I is the inverse of xTx
        return ws

    """
    Purpose: compute the sum of squared errors (SSE)
    Parameters:
        dataSet: dataset containing the true values
        regres: function that computes the regression coefficients
    Returns:
        SSE: sum of squared errors
    """
    def sseCal(dataSet, regres):
        xMat, yMat = get_Mat(dataSet)
        ws = regres(dataSet)
        yHat = xMat * ws
        sse = ((yMat.A.flatten() - yHat.A.flatten())**2).sum()
        return sse

Taking the ex0 dataset as an example, check how the functions behave:
    ex0 = pd.read_table("./datas/ex0.txt", header=None)
    ex0.head()

| 0 | 1 | 2 |
|---|---|---|
| 1.0 | 0.067732 | 3.176513 |
| 1.0 | 0.427810 | 3.816464 |
| 1.0 | 0.995731 | 4.550095 |
| 1.0 | 0.738336 | 4.256571 |
| 1.0 | 0.981083 | 4.560815 |
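In `ex0` the first column is the constant 1.0, so it plays the role of the intercept term in the normal equation that `standRegres` solves. The sketch below checks that idea on synthetic data. The coefficients 3.0 and 1.7, the noise level and the sample size are assumptions chosen only for illustration, and `np.linalg.solve` is used instead of an explicit inverse.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(0, 1, size=200)
X = np.column_stack([np.ones_like(x), x])            # column of ones = intercept
y = 3.0 + 1.7 * x + rng.normal(0, 0.05, size=200)    # synthetic line plus noise

# Normal equation w = (X^T X)^{-1} X^T y, solved without forming the inverse.
w_normal = np.linalg.solve(X.T @ X, X.T @ y)

# Reference solution from NumPy's least-squares solver.
w_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)

print("normal equation:", w_normal)
print("lstsq:          ", w_lstsq)
```

Both routes should agree to numerical precision. The explicit check `det(xTx) == 0` in `standRegres` guards against exact singularity only, so a solver such as `lstsq` is the more forgiving choice when features are nearly collinear.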
Build a function that computes the coefficient of determination R²

    """
    Purpose: compute the coefficient of determination R²
    """
    def rSquare(dataSet, regres):
        xMat, yMat = get_Mat(dataSet)
        sse = sseCal(dataSet, regres)
        sst = ((yMat.A - yMat.mean())**2).sum()
        r2 = 1 - sse / sst
        return r2

Again taking the ex0 dataset as an example, check the result:
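    # R² of simple linear regression
    rSquare(ex0, standRegres)

Output:

    0.9731300889856916

    '''
    Purpose: compute the predictions of locally weighted linear regression
    Parameters:
        testMat: test set
        xMat: feature matrix of the training set
        yMat: label matrix of the training set
        k: bandwidth of the Gaussian kernel
    Returns:
        ws: regression coefficients fitted for the last test sample
        yHat: predicted values
    '''
    def LWLR(testMat, xMat, yMat, k=1.0):
        n = testMat.shape[0]                 # number of rows in the test set
        m = xMat.shape[0]                    # number of rows in the training feature matrix
        weights = np.asmatrix(np.eye(m))     # initialise the weight matrix as the identity
        yHat = np.zeros(n)                   # initialise the predictions with zeros
        for i in range(n):
            for j in range(m):
                diffMat = testMat[i] - xMat[j]
                # [0, 0] extracts the scalar from the 1x1 matrix product
                weights[j, j] = np.exp((diffMat * diffMat.T)[0, 0] / (-2 * k**2))
            xTx = xMat.T * (weights * xMat)
            if np.linalg.det(xTx) == 0:
                print('The matrix is singular and cannot be inverted')
                return
            ws = xTx.I * (xMat.T * (weights * yMat))
            yHat[i] = (testMat[i] * ws)[0, 0]
        return ws, yHat

4. Build the weighted linear model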
Because the dataset is large and the computation is extremely slow, only the first 99 samples of the training set (the slice `[:99]`) are used for training here, and the first 99 samples of the test set are used for testing.

    """
    Purpose: plot the training and test SSE curves for different values of k
    """
    def ssePlot(train, test):
        X0, Y0 = get_Mat(train)
        X1, Y1 = get_Mat(test)
        train_sse = []
        test_sse = []
        for k in np.arange(0.2, 10, 0.5):
            ws1, yHat1 = LWLR(X0[:99], X0[:99], Y0[:99], k)
            sse1 = ((Y0[:99].A.T - yHat1)**2).sum()
            train_sse.append(sse1)
            ws2, yHat2 = LWLR(X1[:99], X0[:99], Y0[:99], k)
            sse2 = ((Y1[:99].A.T - yHat2)**2).sum()
            test_sse.append(sse2)
        plt.figure(figsize=(20,8), dpi=100)
        plt.plot(np.arange(0.2,10,0.5), train_sse, color='b')
        plt.plot(np.arange(0.2,10,0.5), test_sse, color='r')
        plt.xlabel('value of k')
        plt.ylabel('SSE')
        plt.legend(['train_sse','test_sse'])

Result:
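    ssePlot(train, test)

The plot is best read from right to left. When k is large, the model is fairly stable. As k decreases, the training SSE gradually falls, and at around k = 2 the training SSE equals the test SSE. As k decreases further, the training SSE keeps shrinking, so the model fits the training set better and better, but its performance on the test set gets worse and worse. This indicates that the model has started to overfit. The plot agrees with the earlier plots for different k values: the three plots for k = 1.0, 0.01 and 0.003 also show that the model gradually overfits as k decreases. A value of k around 2 is therefore the best choice here.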
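The effect of k described above follows from the Gaussian kernel: each training row gets the weight exp(-||x_i - x||² / (2k²)), so a small k lets only the nearest neighbours influence the fit. The following sketch illustrates this on synthetic one-dimensional data. The data, the helper names `gaussian_weights` and `lwlr_point`, and the chosen k values are assumptions for illustration; it uses plain arrays and `np.linalg.solve` rather than the matrix-based `LWLR` above.

```python
import numpy as np

def gaussian_weights(x_query, X, k):
    """Kernel weight of every training row for one query point."""
    sq_dist = np.sum((X - x_query) ** 2, axis=1)
    return np.exp(-sq_dist / (2.0 * k ** 2))

def lwlr_point(x_query, X, y, k):
    """Locally weighted linear prediction for a single query point."""
    w = gaussian_weights(x_query, X, k)
    XtW = X.T * w                          # equals X.T @ diag(w) without the m x m matrix
    theta = np.linalg.solve(XtW @ X, XtW @ y)
    return x_query @ theta

# Synthetic data: a line with a small wiggle and noise (assumed, for illustration).
rng = np.random.default_rng(1)
x = np.sort(rng.uniform(0, 1, 60))
X = np.column_stack([np.ones_like(x), x])
y = 3.0 + 1.7 * x + 0.1 * np.sin(30 * x) + rng.normal(0, 0.05, 60)

sse_by_k = {}
for k in (0.02, 0.1, 10.0):
    fitted = np.array([lwlr_point(row, X, y, k) for row in X])
    sse_by_k[k] = float(np.sum((y - fitted) ** 2))
    print(f"k={k}: training SSE = {sse_by_k[k]:.4f}")
```

With a very large k all weights approach 1 and the prediction approaches ordinary least squares, while a very small k lets the fit follow individual training points, which is the overfitting regime discussed above. Scaling the columns of `X.T` by the weights also avoids building the m × m diagonal matrix, which is one reason the full-size computation in the article is slow.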
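Next we plug k = 2 into the locally weighted linear regression model and look at the predictions:

    train, test = randSplit(aba, 0.8)              # randomly split the dataset into training and test sets
    trainX, trainY = get_Mat(train)                # feature and label matrices of the training set
    testX, testY = get_Mat(test)                   # feature and label matrices of the test set
    ws0, yHat0 = LWLR(testX, trainX, trainY, k=2)

Plot the relationship between the true values and the predicted values: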
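    y = testY.A.flatten()
    plt.scatter(y, yHat0, c='b', s=5);   # the trailing ; has the same effect as plt.show()

In the plot above, the x-axis is the true value and the y-axis is the predicted value. The points form a "trumpet" shape: the predictions rise with the true values, but their spread also widens, which means that the prediction error grows as the true value increases.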
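One way to check a "trumpet" pattern numerically is to compare the mean absolute error across ranges of the true value. The sketch below does this on synthetic stand-in data, because the real predictions are not reproduced here. The function name `error_by_actual`, the number of bins and the noise model are assumptions for illustration.

```python
import numpy as np
import pandas as pd

def error_by_actual(y_true, y_pred, bins=4):
    """Mean absolute error within equal-width bins of the true value."""
    frame = pd.DataFrame({"actual": y_true, "abs_err": np.abs(y_true - y_pred)})
    frame["bin"] = pd.cut(frame["actual"], bins=bins)
    return frame.groupby("bin", observed=True)["abs_err"].mean()

# Synthetic stand-in: noise whose spread grows with the true value.
rng = np.random.default_rng(2)
y_true = rng.uniform(1, 29, size=2000)
y_pred = y_true + rng.normal(0, 0.1 * y_true)

err_table = error_by_actual(y_true, y_pred)
print(err_table)
```

With the article's variables, the call would presumably be `error_by_actual(testY.A.flatten(), yHat0)`; a mean error that rises from the first bin to the last would support the reading of the plot.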
Wrap the computation of SSE and R² in a function so that it can be reused later:

    """
    Purpose: compute the SSE and R² of the weighted linear regression
    """
    def LWLR_pre(dataSet):
        train, test = randSplit(dataSet, 0.8)
        trainX, trainY = get_Mat(train)
        testX, testY = get_Mat(test)
        ws, yHat = LWLR(testX, trainX, trainY, k=2)
        sse = ((testY.A.T - yHat)**2).sum()
        sst = ((testY.A - testY.mean())**2).sum()
        r2 = 1 - sse / sst
        return sse, r2

Check the model's prediction results:
    LWLR_pre(aba)

Output:

    (4152.777097646255, 0.5228101340130846)

The result shows that the SSE exceeds 4000 and R² is only 0.52, so the model does not perform very well.
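The value 0.52 is computed as 1 - SSE/SST, that is, the coefficient of determination rather than a correlation coefficient. The sketch below recomputes both quantities on a tiny hand-made example and then cross-checks the reported numbers against the test-set statistics from the tables above (835 samples, label standard deviation 3.230284). The function name `sse_and_r2` and the four-point example are assumptions for illustration.

```python
import numpy as np

def sse_and_r2(y_true, y_pred):
    """Return the sum of squared errors and R² = 1 - SSE / SST."""
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    sse = np.sum((y_true - y_pred) ** 2)
    sst = np.sum((y_true - y_true.mean()) ** 2)
    return sse, 1.0 - sse / sst

# Tiny hand-checkable example: errors are -1, 0, -1, 3, so SSE = 11.
sse_demo, r2_demo = sse_and_r2([7.0, 9.0, 10.0, 15.0], [8.0, 9.0, 11.0, 12.0])
print(sse_demo, r2_demo)

# Cross-check of the reported result: SST = (n - 1) * std² for the test labels,
# because the summary table reports the sample standard deviation.
n_test, std_test, sse_reported = 835, 3.230284, 4152.777097646255
sst_test = (n_test - 1) * std_test ** 2
r2_check = 1.0 - sse_reported / sst_test
print(round(r2_check, 4))
```

The cross-check lands on roughly 0.523, which is consistent with the reported R². In other words, the model explains only about half of the variance of the test labels.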