运行时数据获取
運行時數(shù)據(jù)獲取
OneFlow 提供了 oneflow.watch 與 oneflow.watch_diff 接口,可以通過他們注冊回調(diào)函數(shù),以方便在作業(yè)函數(shù)運行過程中獲取張量數(shù)據(jù)或梯度。
使用流程
想要獲取作業(yè)函數(shù)運行時的數(shù)據(jù)或者梯度,其基本流程如下:
? 編寫回調(diào)函數(shù),回調(diào)函數(shù)的參數(shù)需要用注解方式表明監(jiān)控的數(shù)據(jù)類型,回調(diào)函數(shù)內(nèi)部邏輯由用戶自己實現(xiàn)
? 在定義作業(yè)函數(shù)時,通過 oneflow.watch 或 oneflow.watch_diff 注冊回調(diào)函數(shù),前者獲取張量數(shù)據(jù)本身,后者獲取對應(yīng)的梯度
? 在作業(yè)函數(shù)運行時,OneFlow 框架會在適當?shù)臅r機,調(diào)用之前注冊的回調(diào),將監(jiān)控的數(shù)據(jù)傳遞給回調(diào)函數(shù),并執(zhí)行回調(diào)函數(shù)中的邏輯
以 oneflow.watch 為例,以下偽代碼展示了使用過程:
def my_watch(x: T):
#處理x
@global_function()
def foo() -> T:
#定義網(wǎng)絡(luò)等 …
oneflow.watch(x, my_watch)
#…
以上的 T 即 oneflow.typing 中的數(shù)據(jù)類型,如 oneflow.typing.Numpy。
以下將用實際例子展示 watch 與 watch_diff 的使用方法
watch 使用例子
下面是一段完整的例子,用于展示如何使用 OneFlow 的 oneflow.watch 功能獲取網(wǎng)絡(luò)中間層的數(shù)據(jù)。
代碼
代碼:test_watch.py
運行該程序:
python3 test_watch.py
能夠得到類似下面的輸出:
in: [ 0.15727027 0.45887455 0.10939325 0.66666406 -0.62354755]
out: [0.15727027 0.45887455 0.10939325 0.66666406 0. ]
代碼解讀
在例子中,關(guān)注的是 ReluJob 里面的 y,所以調(diào)用 flow.watch(y, watch_handler)去監(jiān)控 y。oneflow.watch 需要兩個參數(shù):
? 第一個參數(shù)就是關(guān)注的對象 y;
? 第二個參數(shù)是一個回調(diào)函數(shù),OneFlow 在調(diào)用設(shè)備資源執(zhí)行 ReluJob 的時候會將 y 的計算結(jié)果作為參數(shù)傳遞給這個回調(diào)函數(shù)。而定義的回調(diào)函數(shù) watch_handler 的邏輯函數(shù),將得到的參數(shù)打印出來。
用戶通過自定義回調(diào)函數(shù),在回調(diào)函數(shù)中按照自己的需求處理 OneFlow 運行時從設(shè)備中拿到的數(shù)據(jù)。
watch_diff 使用例子
下面是一段完整的例子,用于展示如何使用 OneFlow 的 oneflow.watch_diff 功能獲取網(wǎng)絡(luò)中間層的梯度。
代碼
代碼:test_watch_diff.py
運行該程序:
python3 test_watch_diff.py
能夠得到類似下面的輸出:
[ …
[ 1.39966095e-03 3.49164731e-03 3.31605263e-02 4.50417027e-03
7.73609674e-04 4.89911772e-02 2.47627571e-02 7.65468649e-05
-1.18361652e-01 1.20161276e-03]] (100, 10) float32
代碼解讀
以上通過 oneflow.watch_diff 獲取梯度的例子,其流程與 通過 oneflow.watch 獲取張量數(shù)據(jù)的例子是類似的。
首先,定義了回調(diào)函數(shù):
def watch_diff_handler(blob: tp.Numpy):
print(“watch_diff_handler:”, blob, blob.shape, blob.dtype)
然后,在作業(yè)函數(shù)中使用 oneflow.watch_diff 注冊以上的回調(diào)函數(shù):
flow.watch_diff(logits, watch_diff_handler)
在 OneFlow 運行時, OneFlow 框架就會調(diào)用 watch_diff_handler,并且將以上的 logits 對應(yīng)的梯度傳遞給 watch_diff_handler。
總結(jié)
- 上一篇: OFRecord 图片文件制数据集
- 下一篇: VS Code 调试 OneFlow