如何写新的Python OP
如何寫新的Python OP
Paddle 通過 py_func 接口支持在Python端自定義OP。 py_func的設(shè)計原理在于Paddle中的Tensor可以與numpy數(shù)組可以方便的互相轉(zhuǎn)換,從而可以使用Python中的numpy API來自定義一個Python OP。
py_func接口概述
py_func 具體接口為:
def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
pass
其中,
? x 是Python Op的輸入變量,可以是單個 Tensor | tuple[Tensor] | list[Tensor] 。多個Tensor以tuple[Tensor]或list[Tensor]的形式傳入。
? out 是Python Op的輸出變量,可以是單個 Tensor | tuple[Tensor] | list[Tensor],也可以是Numpy Array 。
? func 是Python Op的前向函數(shù)。在運行網(wǎng)絡(luò)前向時,框架會調(diào)用 out = func(*x) ,根據(jù)前向輸入 x 和前向函數(shù) func 計算前向輸出 out。在 func 建議先主動將Tensor轉(zhuǎn)換為numpy數(shù)組,方便靈活的使用numpy相關(guān)的操作,如果未轉(zhuǎn)換成numpy,則可能某些操作無法兼容。
? backward_func 是Python Op的反向函數(shù)。若 backward_func 為 None ,則該Python Op沒有反向計算邏輯; 若 backward_func 不為 None,則框架會在運行網(wǎng)路反向時調(diào)用 backward_func 計算前向輸入 x 的梯度。
? skip_vars_in_backward_input 為反向函數(shù) backward_func 中不需要的輸入,可以是單個 Tensor | tuple[Tensor] | list[Tensor] 。
如何使用py_func編寫Python Op
以下以tanh為例,介紹如何利用 py_func 編寫Python Op。
? 第一步:定義前向函數(shù)和反向函數(shù)
前向函數(shù)和反向函數(shù)均由Python編寫,可以方便地使用Python與numpy中的相關(guān)API來實現(xiàn)一個自定義的OP。
若前向函數(shù)的輸入為 x_1, x_2, …, x_n ,輸出為y_1, y_2, …, y_m,則前向函數(shù)的定義格式為:
def foward_func(x_1, x_2, …, x_n):
…
return y_1, y_2, …, y_m
默認情況下,反向函數(shù)的輸入?yún)?shù)順序為:所有前向輸入變量 + 所有前向輸出變量 + 所有前向輸出變量的梯度,因此對應(yīng)的反向函數(shù)的定義格式為:
def backward_func(x_1, x_2, …, x_n, y_1, y_2, …, y_m, dy_1, dy_2, …, dy_m):
…
return dx_1, dx_2, …, dx_n
若反向函數(shù)不需要某些前向輸入變量或前向輸出變量,可設(shè)置 skip_vars_in_backward_input 進行排除(步驟三中會敘述具體的排除方法)。
注:,x_1, …, x_n為輸入的多個Tensor,請以tuple(Tensor)或list[Tensor]的形式在py_func中傳入。建議先主動將Tensor通過numpy.array轉(zhuǎn)換為數(shù)組,否則Python與numpy中的某些操作可能無法兼容使用在Tensor上。
利用numpy的相關(guān)API完成tanh的前向函數(shù)和反向函數(shù)編寫。下面給出多個前向與反向函數(shù)定義的示例:
import numpy as np
前向函數(shù)1:模擬tanh激活函數(shù)
def tanh(x):
# 可以直接將Tensor作為np.tanh的輸入?yún)?shù)
return np.tanh(x)
前向函數(shù)2:將兩個2-D Tenosr相加,輸入多個Tensor以list[Tensor]或tuple(Tensor)形式
def element_wise_add(x, y):
# 必須先手動將Tensor轉(zhuǎn)換為numpy數(shù)組,否則無法支持numpy的shape操作
x = np.array(x)
y = np.array(y)
if x.shape != y.shape:raise AssertionError("the shape of inputs must be the same!")result = np.zeros(x.shape, dtype='int32')
for i in range(len(x)):for j in range(len(x[0])):result[i][j] = x[i][j] + y[i][j]return result
前向函數(shù)3:可用于調(diào)試正在運行的網(wǎng)絡(luò)(打印值)
def debug_func(x):
# 可以直接將Tensor作為print的輸入?yún)?shù)
print(x)
前向函數(shù)1對應(yīng)的反向函數(shù),默認的輸入順序為:x、out、out的梯度
def tanh_grad(x, y, dy):
# 必須先手動將Tensor轉(zhuǎn)換為numpy數(shù)組,否則"+/-"等操作無法使用
return np.array(dy) * (1 - np.square(np.array(y)))
注意,前向函數(shù)和反向函數(shù)的輸入均是 Tensor 類型,輸出可以是Numpy Array或 Tensor。 由于 Tensor 實現(xiàn)了Python的buffer protocol協(xié)議,因此即可通過 numpy.array 直接將 Tensor 轉(zhuǎn)換為numpy Array來進行操作,也可直接將 Tensor 作為numpy函數(shù)的輸入?yún)?shù)。但建議先主動轉(zhuǎn)換為numpy Array,則可以任意的使用python與numpy中的所有操作(例如”numpy array的+/-/shape”)。
tanh的反向函數(shù)不需要前向輸入x,因此我們可定義一個不需要前向輸入x的反向函數(shù),并在后續(xù)通過 skip_vars_in_backward_input 進行排除 :
def tanh_grad_without_x(y, dy):
return np.array(dy) * (1 - np.square(np.array(y)))
? 第二步:創(chuàng)建前向輸出變量
需調(diào)用 Program.current_block().create_var 創(chuàng)建前向輸出變量。在創(chuàng)建前向輸出變量時,必須指明變量的名稱name、數(shù)據(jù)類型dtype和維度shape。
import paddle
paddle.enable_static()
def create_tmp_var(program, name, dtype, shape):
return program.current_block().create_var(name=name, dtype=dtype, shape=shape)
in_var = paddle.static.data(name=‘input’, dtype=‘float32’, shape=[-1, 28, 28])
手動創(chuàng)建前向輸出變量
out_var = create_tmp_var(paddle.static.default_main_program(), name=‘output’, dtype=‘float32’, shape=[-1, 28, 28])
? 第三步:調(diào)用 py_func 組建網(wǎng)絡(luò)
py_func 的調(diào)用方式為:
paddle.static.nn.py_func(func=tanh, x=in_var, out=out_var, backward_func=tanh_grad)
若我們不希望在反向函數(shù)輸入?yún)?shù)中出現(xiàn)前向輸入,則可使用 skip_vars_in_backward_input 進行排查,簡化反向函數(shù)的參數(shù)列表。
paddle.static.nn.py_func(func=tanh, x=in_var, out=out_var, backward_func=tanh_grad_without_x,
skip_vars_in_backward_input=in_var)
至此,使用 py_func 編寫Python Op的步驟結(jié)束??梢耘c使用其他Op一樣進行網(wǎng)路訓(xùn)練/預(yù)測。
注意事項
? py_func 的前向函數(shù)和反向函數(shù)內(nèi)部不應(yīng)調(diào)用 paddle.xx組網(wǎng)接口 ,因為前向函數(shù)和反向函數(shù)是在網(wǎng)絡(luò)運行時調(diào)用的,而 paddle.xx 是在組建網(wǎng)絡(luò)的階段調(diào)用 。
? skip_vars_in_backward_input 只能跳過前向輸入變量和前向輸出變量,不能跳過前向輸出的梯度。
? 若某個前向輸出變量沒有梯度,則 backward_func 將接收到 None 的輸入。若某個前向輸入變量沒有梯度,則我們應(yīng)在 backward_func 中主動返回 None。
總結(jié)
以上是生活随笔為你收集整理的如何写新的Python OP的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: C++ OP相关注意事项
- 下一篇: 如何在框架外部自定义C++ OP