将TVM集成到PyTorch上
將TVM集成到PyTorch上
隨著TVM不斷展示出對深度學習執行效率的改進,很明顯PyTorch將從直接利用編譯器堆棧中受益。PyTorch的主要宗旨是提供無縫且強大的集成,而這不會妨礙用戶。為此,PyTorch現在具有基于TVM的官方后端torch_tvm。
用法很簡單:
import torch_tvm
torch_tvm.enable()
PyTorch將嘗試在其JIT編譯過程中,將所有可能的運算符轉換為已知的Relay運算符。
背景
與許多其他ML框架不同,PyTorch公開了一個渴望執行的編程接口。這種編程風格避免了基于圖的元編程,而專注于以Python方式直接控制n維數組(張量)。因此,該框架最初非常適合模型的試驗和開發,但不適用于自動性能優化或部署。為了利用優化的編譯器技術,PyTorch引入了一些較大的更改來解決此問題。
PyTorch 1.0引入了PyTorch IR,PyTorch專用的中間表示形式,用于類似于Relay的模型。可以通過模型跟蹤將PyTorch程序轉換為IR,該跟蹤記錄模型或Python的子集TorchScript的執行。新的TVM后端將PyTorch的IR降低到了Relay,并能夠透明地提高PyTorch的性能,而無需用戶參與。
整合與結果
為了支持Relay,PyTorch JIT添加了兩個功能:自定義轉換過程和自定義子圖解釋器。
當torch_tvm啟用時,可以轉換到中繼PyTorch IR的子圖Expr旨意被標記為繼電器兼容。由于PyTorch IR并不總是包含形狀信息,因此在調用之前,無法以有用的方式編譯任何子圖。
在用戶調用期間,PyTorch JIT運行時將確定輸入形狀信息,并使用新的Relay C ++構建系統編譯先前標記的子圖。根據輸入形狀來緩存編譯,以供后續運行。可以在README中找到更多詳細信息。
torch_tvm建立了一個連續的基準測試系統,該系統正在監視ResNet18在CPU上的性能。對于各種ResNet型號,TVM的性能都是默認PyTorch JIT后端的兩倍以上。在AWS c5n.4xlarge實例上使用16個線程實現的每秒迭代次數(越大越好)。
這些結果令人鼓舞,該項目將繼續致力于,在更多模型上提高CPU推理速度。
未來的工作
現在,PyTorch JIT進行了大量工作來查找其IR的純功能子集,以饋送到Relay。這避免了將別名和控制流信息映射到中繼的需要,但這不是必需的。將更多的PyTorch IR映射到Relay可能會取得性能上的勝利,這是該項目的目標。PyTorch IR在開發過程中正在迅速變化,因此必須謹慎進行。
將做更多的工作來確保PyTorch和TVM代碼之間的切換是有效的。這包括統一線程模型,分配器以及減少與將輸入復制到TVM相關的開銷。
解析
如果已經編寫了PyTorch模型,最簡單的入門方法就是使用torch.jit.trace以下方法
import torch_tvm
from your_model import model, inputs
torch_tvm.enable(opt_level=3)
iters = 100
warmup = 10
Ensure your model is in eval mode and also turn off gradients.
with torch.no_grad():
Use tuned parameters for better performance.
with autotvm.apply_history_best(“test/autotvm_tuning.log”):
# This is where all the compilation happens.
trace_tvm = torch.jit.trace(model, inputs)
# Warmup
for _ in range(warmup):_ = trace_tvm(*inputs)# Benchmark
start = time.time()
for _ in range(iters):_ = trace_tvm(*inputs)
tvm_time = time.time() - startprint("Took {}s to run {} iters".format(tvm_time, iters))
這段代碼大部分來自Benchmarks.py。請注意,用于AVX2 LLVM編譯的調整參數位于存儲庫test/文件夾中。
如果更直接使用Relay,可以通過(隱式)跟蹤或TorchScript,直接從PyTorch函數中提取表達式:
def add(a, b, c):
return a + b + c
via tracing
relay_graph = torch_tvm.to_relay(add, inputs)
@torch.jit.script
def mul(a, b, c):
return a * b * c
via script
relay_graph = torch_tvm.to_relay(mul, inputs)
總結
以上是生活随笔為你收集整理的将TVM集成到PyTorch上的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 自定义Kubernetes调度程序来编排
- 下一篇: 自动生成低精度深度学习算子