运行时数据获取
運行時數據獲取
OneFlow 提供了 oneflow.watch 與 oneflow.watch_diff 接口,可以通過他們注冊回調函數,以方便在作業函數運行過程中獲取張量數據或梯度。
使用流程
想要獲取作業函數運行時的數據或者梯度,其基本流程如下:
? 編寫回調函數,回調函數的參數需要用注解方式表明監控的數據類型,回調函數內部邏輯由用戶自己實現
? 在定義作業函數時,通過 oneflow.watch 或 oneflow.watch_diff 注冊回調函數,前者獲取張量數據本身,后者獲取對應的梯度
? 在作業函數運行時,OneFlow 框架會在適當的時機,調用之前注冊的回調,將監控的數據傳遞給回調函數,并執行回調函數中的邏輯
以 oneflow.watch 為例,以下偽代碼展示了使用過程:
def my_watch(x: T):
#處理x
@global_function()
def foo() -> T:
#定義網絡等 …
oneflow.watch(x, my_watch)
#…
以上的 T 即 oneflow.typing 中的數據類型,如 oneflow.typing.Numpy。
以下將用實際例子展示 watch 與 watch_diff 的使用方法
watch 使用例子
下面是一段完整的例子,用于展示如何使用 OneFlow 的 oneflow.watch 功能獲取網絡中間層的數據。
代碼
代碼:test_watch.py
運行該程序:
python3 test_watch.py
能夠得到類似下面的輸出:
in: [ 0.15727027 0.45887455 0.10939325 0.66666406 -0.62354755]
out: [0.15727027 0.45887455 0.10939325 0.66666406 0. ]
代碼解讀
在例子中,關注的是 ReluJob 里面的 y,所以調用 flow.watch(y, watch_handler)去監控 y。oneflow.watch 需要兩個參數:
? 第一個參數就是關注的對象 y;
? 第二個參數是一個回調函數,OneFlow 在調用設備資源執行 ReluJob 的時候會將 y 的計算結果作為參數傳遞給這個回調函數。而定義的回調函數 watch_handler 的邏輯函數,將得到的參數打印出來。
用戶通過自定義回調函數,在回調函數中按照自己的需求處理 OneFlow 運行時從設備中拿到的數據。
watch_diff 使用例子
下面是一段完整的例子,用于展示如何使用 OneFlow 的 oneflow.watch_diff 功能獲取網絡中間層的梯度。
代碼
代碼:test_watch_diff.py
運行該程序:
python3 test_watch_diff.py
能夠得到類似下面的輸出:
[ …
[ 1.39966095e-03 3.49164731e-03 3.31605263e-02 4.50417027e-03
7.73609674e-04 4.89911772e-02 2.47627571e-02 7.65468649e-05
-1.18361652e-01 1.20161276e-03]] (100, 10) float32
代碼解讀
以上通過 oneflow.watch_diff 獲取梯度的例子,其流程與 通過 oneflow.watch 獲取張量數據的例子是類似的。
首先,定義了回調函數:
def watch_diff_handler(blob: tp.Numpy):
print(“watch_diff_handler:”, blob, blob.shape, blob.dtype)
然后,在作業函數中使用 oneflow.watch_diff 注冊以上的回調函數:
flow.watch_diff(logits, watch_diff_handler)
在 OneFlow 運行時, OneFlow 框架就會調用 watch_diff_handler,并且將以上的 logits 對應的梯度傳遞給 watch_diff_handler。
總結
- 上一篇: OFRecord 图片文件制数据集
- 下一篇: VS Code 调试 OneFlow