MXNet——symbol
參考資料:有基礎(chǔ)(Pytorch/TensorFlow基礎(chǔ))mxnet+gluon快速入門
symbol
symbol 是一個(gè)重要的概念,可以理解為符號(hào),就像我們平時(shí)使用的代數(shù)符號(hào) x,y,z 一樣。一個(gè)簡(jiǎn)單的類比,一個(gè)函數(shù) \(f(x) = x^{2}\),符號(hào) x 就是 symbol,而具體 x 的值就是 ndarray,關(guān)于 symbol 的是 mxnet.sym,具體可參照官方API文檔
基本操作
- 使用 mxnet.sym.Variable() 傳入名稱可建立一個(gè) symbol
- 使用 mxnet.viz.plot_network(symbol=) 傳入 symbol 可以繪制運(yùn)算圖
帶入 ndarray
使用 mxnet.sym.bind() 方法可以獲得一個(gè)帶入操作數(shù)的對(duì)象,再使用 forward() 方法可運(yùn)算出數(shù)值
x = c.bind(ctx=mx.cpu(),args={"a": mx.nd.ones(5),"b":mx.nd.ones(5)}) result = x.forward() print(result) [ [2. 2. 2. 2. 2.] <NDArray 5 @cpu(0)>]mxnet 的數(shù)據(jù)載入
深度學(xué)習(xí)中數(shù)據(jù)的載入方式非常重要,mxnet 提供了 mxnet.io 的一系列 dataiter 用于處理數(shù)據(jù)載入,詳細(xì)可參照官方API文檔。同時(shí),動(dòng)態(tài)圖接口gluon 也提供了 mxnet.gluon.data 系列的 dataiter 用于數(shù)據(jù)載入,詳細(xì)可參照官方API文檔
mxnet.io 數(shù)據(jù)載入
mxnet.io的數(shù)據(jù)載入核心是 mxnet.io.DataIter 類及其派生類,例如 ndarray 的 iter:NDArrayIter
- 參數(shù) data:傳入一個(gè)(名稱-數(shù)據(jù))的數(shù)據(jù) dict
- 參數(shù) label:傳入一個(gè)(名稱-標(biāo)簽)的標(biāo)簽 dict
- 參數(shù) batch_size:傳入 batch 大小
gluon.data 數(shù)據(jù)載入
gluon 的數(shù)據(jù) API 幾乎與 pytorch 相同,均是 Dataset+DataLoader 的方式:
- Dataset:存儲(chǔ)數(shù)據(jù),使用時(shí)需要繼承該基類并重載 __len__(self) 和 __getitem__(self,idx) 方法
- DataLoader:將 Dataset 變成能產(chǎn)生 batch 的可迭代對(duì)象
網(wǎng)絡(luò)搭建
mxnet 網(wǎng)絡(luò)搭建
mxnet 網(wǎng)絡(luò)搭建類似于 TensorFlow,使用 symbol 搭建出網(wǎng)絡(luò),再用一個(gè) module 封裝
data = mx.sym.Variable('data') # layer1 conv1 = mx.sym.Convolution(data=data, kernel=(5,5), num_filter=32,name="conv1") relu1 = mx.sym.Activation(data=conv1,act_type="relu",name="relu1") pool1 = mx.sym.Pooling(data=relu1,pool_type="max",kernel=(2,2),stride=(2,2),name="pool1") # layer2 conv2 = mx.sym.Convolution(data=pool1, kernel=(3,3), num_filter=64,name="conv2") relu2 = mx.sym.Activation(data=conv2,act_type="relu",name="relu2") pool2 = mx.sym.Pooling(data=relu2,pool_type="max",kernel=(2,2),stride=(2,2),name="pool2") # layer3 fc1 = mx.symbol.FullyConnected(data=mx.sym.flatten(pool2), num_hidden=256,name="fc1") relu3 = mx.sym.Activation(data=fc1, act_type="relu",name="relu3") # layer4 fc2 = mx.symbol.FullyConnected(data=relu3, num_hidden=10,name="fc2") out = mx.sym.SoftmaxOutput(data=fc2, label=mx.sym.Variable("label"),name='softmax') mxnet_model = mx.mod.Module(symbol=out,label_names=["label"],context=mx.gpu()) mx.viz.plot_network(symbol=out)福利:剛剛發(fā)現(xiàn)一個(gè)解決路徑錯(cuò)誤的方法:只需要將 *\Anaconda3\Library\bin\graphviz 添加到 Path 環(huán)境變量之下即可 (安裝后記得重啟,環(huán)境變量修改才可以生效,調(diào)用庫(kù),即可成功)!
轉(zhuǎn)載于:https://www.cnblogs.com/q735613050/p/9315504.html
總結(jié)
以上是生活随笔為你收集整理的MXNet——symbol的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。