BP神经网络处理iris数据集(Pytorch实现)
一.數(shù)據(jù)集介紹
這次數(shù)據(jù)集使用的是iris數(shù)據(jù)集,也稱鳶尾花卉數(shù)據(jù)集,是一類多重變量分析的數(shù)據(jù)集。數(shù)據(jù)集包含150個數(shù)據(jù)樣本,分為3類,每類50個數(shù)據(jù),每個數(shù)據(jù)包含4個屬性??赏ㄟ^花萼長度,花萼寬度,花瓣長度,花瓣寬度4個屬性預(yù)測鳶尾花卉屬于(Setosa,Versicolour,Virginica)三個種類中的哪一類。
該數(shù)據(jù)集進行神經(jīng)網(wǎng)絡(luò)時,輸入是 Sepal.Length(花萼長度), Sepal.Width(花萼寬度),Petal.Length(花瓣長度), Petal.Width(花瓣寬度),輸出為種類,Iris Setosa(山鳶尾)、Iris Versicolour(雜色鳶尾),以及Iris Virginica(維吉尼亞鳶尾)。
二.代碼實現(xiàn)
代碼部分總共為兩個版本,分別是CPU版本和GPU版本。
數(shù)據(jù)集是從sklearn中下載得到:
我們可以看一下該數(shù)據(jù)集的輸出,輸出可以自己看,這就不展示了。
print(data) print(iris_type)之后需要對數(shù)據(jù)進行處理,因為使用到pytorch,我們需要將數(shù)據(jù)轉(zhuǎn)為tensor格式:
input = torch.FloatTensor(dataset['data']) label = torch.LongTensor(dataset['target'])接下來可以定義神經(jīng)網(wǎng)絡(luò)模型:
class BPNerualNetwork(torch.nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Linear(input_size, hidden_size1),nn.ReLU(),nn.Linear(hidden_size1, hidden_size2),nn.ReLU(),nn.Linear(hidden_size2, hidden_size3),nn.ReLU(),nn.Linear(hidden_size3, output_size),nn.LogSoftmax(dim=1))def forward(self, x):x = self.model(x)return x這里我設(shè)置了三層隱藏層,不過你可以自己增減隱藏層,只需要調(diào)用函數(shù)nn.Linear(),激活函數(shù)可以直接設(shè)置,pytorch里面可以直接調(diào)用,我這里使用的是nn.ReLU()函數(shù)。
后續(xù)只需要進行訓(xùn)練就可以了(下面代碼是GPU版本):
三.效果
我使用了matplotlib將準(zhǔn)確率和loss進行展示:
準(zhǔn)確率達到了0.99,可以說效果不錯哦。
具體兩個版本代碼可以點這里下載。
總結(jié)
以上是生活随笔為你收集整理的BP神经网络处理iris数据集(Pytorch实现)的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: SQL Server数据库可疑处理
- 下一篇: 渲染管线概述