mxnet symbol 解析
mxnet symbol類定義:https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/symbol/symbol.py
對于一個symbol,可分為non-grouped和grouped。且symbol具有輸出,和輸出屬性。比如,對于Variable而言,其輸入和輸出就是它自己。對于c = a+b,c的內部有個_plus0 symbol,對于_plus0這個symbol,它的輸入是a,b,輸出是_plus0_output。
class Symbol(SymbolBase):"""Symbol is symbolic graph of the mxnet."""# disable dictionary storage, also do not have parent type.# pylint: disable=no-member其中,Symbol還不是最基礎的類,Symbol類繼承了SymbolBase這個類。
而SymbolBase這個類實際是在
https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/symbol/_internal.py
中引用的,通過以下方式引用:
from .._ctypes.symbol import SymbolBase, _set_symbol_class, _set_np_symbol_class而SymbolBase的定義是在:https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/_ctypes/symbol.py
這里暫時先不管SymbolBase,這應該是是python調用c++接口創建的一個類。
回到Symbol中來,對于mxnet符號式編程而言,定義的任何網絡,或者變量,都是symbol類型,所以,了解這個類就顯得很重要。
Symbol類中有幾類函數:
1、普通函數
2、__xx__ 函數
3、@property 修飾的函數
4、函數名為xx,實際調用op.xx的函數
1、普通函數
attr
根據key返回symbol對應的屬性字符串,只對non-grouped symbols起作用。
list_attr
得到symbol的所有屬性
attr_dict
遞歸的得到symbol和孩子的屬性
_set_attr
通過key-value方式,對attr進行設置
get_internals
獲取symbol的所有內部節點symbol,是一個group類型(包括輸入,輸出節點symbol)。如果我們想階段一個network,應該獲取它某內部節點的輸出,這樣才能作為新增加的symbol的輸入。
get_children
獲取當前symbol輸出節點的inputs
list_arguments
列出當前symbol的所有參數(可以配合call對symbol進行改造)
list_outputs
列出當前smybol的所有輸出,如果當前symbol是grouped類型,回遍歷輸出每一個symbol的輸出
list_auxiliary_states
列出symbol中的輔助狀態參數,比如BN
list_inputs
列出當前symbol的所有輸入參數,和輔助狀態,等價于 list_arguments和 list_auxiliary_states
2、__xx__函數
__repr__
對于gruop symbol,它是沒有name屬性的,print或者回車,結果就是其內部symbol節點的name
__iter__(self):
普通的symbol長度都只有1,只有Grouped 的symbol,長度才大于1:return (self[i] for i in range(len(self)))
算數及邏輯運算:
+,-,*, /,%,abs,**, 取負(-x),==,!=,>,>=,<,<=, # 使用時,要注意Broadcasting 是否支持
__copy__和__deep_copy__
通過deep_copy,創建一個深拷貝,返回輸入對象的一個拷貝,包括它當前所有參數的當前狀態,比如weight,bias等
__call__
表示symbol的實例是一個可調用對象。可以返回一個新的symbol,這個symbol繼承了之前symbol的權重啥的,但是和之前的symbol是不同的對象,可以輸入參數對symbol進行組合。
這里,我改變了b,將其輸入參數的x的值變為了tt。
__getitem__
如果symbol的長度只有1,那么返回的就是它的輸出symbol,如果symbol長度>1,可以通過切片訪問其輸出symbol,返回的也是一個Group symbol。symbol可以分為non-grouped和grouped。
獲取內部節點symbol還可以輸入str,但輸入的str必須屬于list_outputs(),
symbol.py 除了Symbol這個類之外,還有游離在外的函數:
1、 def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None,init=None, stype=None, **kwargs):"""Creates a symbolic variable with specified name. # for back compatibility Variable = var # 調用 mx.sym.var和mx.sym.Variable 等價2、 def Group(symbols, create_fn=Symbol):"""Creates a symbol that contains a collection of other symbols, grouped together.A classic symbol (`mx.sym.Symbol`) will be returned if all the symbols in the listare of that type; a numpy symbol (`mx.sym.np._Symbol`) will be returned if all thesymbols in the list are of that type. A type error will be raised if a list of mixedclassic and numpy symbols are provided.Example------->>> a = mx.sym.Variable('a')>>> b = mx.sym.Variable('b')>>> mx.sym.Group([a,b])<Symbol Grouped>Parameters----------symbols : listList of symbols to be grouped.3、 def load(fname):"""Loads symbol from a JSON file.You also get the benefit being able to directly load/save from cloud storage(S3, HDFS).Returns-------sym : SymbolThe loaded symbol.See Also--------Symbol.save : Used to save symbol into file. # 輸入文件可以是hdfs文件 4、 數學相關函數,輸入可為scalar或者是symbol def pow(base, exp):"""Returns element-wise result of base element raised to powers from exp element.base 和 exp可以是數字或者symbol # def power(base, exp): # 實際調用pow def maximum(left, right): def minimum(left, right): def hypot(left, right): # 返回直角三角形的斜邊 def eye(N, M=0, k=0, dtype=None, **kwargs):"""Returns a new symbol of 2-D shpae, filled with ones on the diagonal and zeros elsewhere. # 返回2D shape的symbol,對角線為1,其余位置為0 def zeros(shape, dtype=None, **kwargs):"""Returns a new symbol of given shape and type, filled with zeros. # 返回一個shape的全0 symbol def ones(shape, dtype=None, **kwargs):"""Returns a new symbol of given shape and type, filled with ones. def full(shape, val, dtype=None, **kwargs):"""Returns a new array of given shape and type, filled with the given value `val`. def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, name=None, dtype=None):"""Returns evenly spaced values within a given interval. def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, name=None, dtype=None):"""Returns evenly spaced values within a given interval. def linspace(start, stop, num, endpoint=True, name=None, dtype=None):"""Return evenly spaced numbers within a specified interval. def histogram(a, bins=10, range=None, **kwargs):"""Compute the histogram of the input data. def split_v2(ary, indices_or_sections, axis=0, squeeze_axis=False):"""Split an array into multiple sub-arrays.總結
以上是生活随笔為你收集整理的mxnet symbol 解析的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 构造函数。
- 下一篇: 云服务器维护包含哪些,云服务器维护内容