AI中pass架构设计优化
AI中pass架構(gòu)設(shè)計(jì)優(yōu)化
Relay 和 TVM IR,包含一系列優(yōu)化passes,可提高模型的性能指標(biāo),例如平均推理,內(nèi)存占用,或特定設(shè)備的功耗。有一套標(biāo)準(zhǔn)優(yōu)化,及特定機(jī)器學(xué)習(xí)的優(yōu)化,包括常量折疊,死代碼消除,算子布局更改,算子融合,緩沖區(qū)處理和循環(huán)轉(zhuǎn)換等。這些passes中的每一個(gè)都構(gòu)造為一個(gè) ir-to -ir 轉(zhuǎn)換,使用在遍歷期間和/或前收集的分析結(jié)果。
隨著 TVM 的快速發(fā)展,對(duì)管理這些pass的更系統(tǒng),更有效的方法的需求,變得越來(lái)越明顯。此外,管理跨 TVM 堆棧不同層(例如 Relay 和 tir)pass的通用框架,為開(kāi)發(fā)人員快速構(gòu)建原型,將實(shí)現(xiàn)的pass插入系統(tǒng),鋪平了道路。
本節(jié)描述了基礎(chǔ)架構(gòu)設(shè)計(jì),利用產(chǎn)品編譯器,管理優(yōu)化pass,及構(gòu)建層的深度學(xué)習(xí)框架。
例如,許多現(xiàn)有的產(chǎn)品編譯器,如 GCC 和 LLVM,都采用pass管理器,有效管理pass的執(zhí)行。最初管理 pass 很簡(jiǎn)單,因?yàn)?pass 的數(shù)量很少,成熟的編譯器,將包含數(shù)百個(gè)單獨(dú)的 pass。通常,外部用戶(hù)希望正確調(diào)度自定義pass,無(wú)需修改單個(gè)手工制作的pass順序。
現(xiàn)代深度學(xué)習(xí)框架,如 Pytorch 和 MXNet Gluon,分別通過(guò)Sequential和Block,啟用pass-style 層構(gòu)建方案的趨勢(shì)。有了這樣的結(jié)構(gòu),這些現(xiàn)代框架能夠方便將模塊/層添加到容器中,輕松地構(gòu)建神經(jīng)網(wǎng)絡(luò)。
Relay pass infra 的設(shè)計(jì),很大程度上受到 LLVM 中,使用的分層pass管理器和流行的深度學(xué)習(xí)框架中,使用的塊式容器的啟發(fā)。pass基礎(chǔ)架構(gòu)的主要目標(biāo)包括:
- 實(shí)現(xiàn)更好的優(yōu)化編程調(diào)度。允許用戶(hù)靈活地定制和構(gòu)建優(yōu)化pass。
- 提供一種用戶(hù)友好的方式來(lái)調(diào)試優(yōu)化pass。
- 減輕開(kāi)發(fā)人員手動(dòng)和分別解決pass之間的依賴(lài)關(guān)系。
- 為開(kāi)發(fā)人員簡(jiǎn)化新pass的實(shí)施。例如,允許用戶(hù)在 Python 中,實(shí)現(xiàn)一個(gè) pass,讓 pass infra 操縱執(zhí)行。
設(shè)計(jì)
專(zhuān)注于為用戶(hù)提供易于擴(kuò)展的功能,讓用戶(hù)可以快速添加新pass,不會(huì)失去向后兼容性。該設(shè)計(jì)包含后端和前端。前者實(shí)現(xiàn)了 pass infra 的主要邏輯。后者為用戶(hù)提供了簡(jiǎn)單的 API 進(jìn)行交互,允許用戶(hù)快速創(chuàng)建優(yōu)化pass。
C++ 后端
提供一個(gè)PassInfo對(duì)象,包含pass所需的基本信息。name是傳遞名稱(chēng),opt_level指示,在哪個(gè)優(yōu)化級(jí)別啟用pass, required表示執(zhí)行特定pass,所需的pass(有關(guān)更多詳細(xì)信息,參閱include/tvm/ir/transform.h)。例如,在注冊(cè)pass期間,pass開(kāi)發(fā)人員,可以指定pass的名稱(chēng),將執(zhí)行的優(yōu)化級(jí)別和/或所需的pass。在用戶(hù)提供的優(yōu)化級(jí)別下運(yùn)行時(shí),是否需要執(zhí)行某個(gè) pass, opt_level可用于幫助 pass infra 識(shí)別。 required字段,可以被 pass infra 使用,解決 pass 依賴(lài)。
class PassInfoNode : public Object {
String name;
int opt_level;
Array required;
};
傳遞上下文
PassContext攜帶用于優(yōu)化pass的有用信息。例如,包含錯(cuò)誤報(bào)告系統(tǒng),可以提供有關(guān)優(yōu)化失敗原因的診斷。PassContext旨在替換舊的BuildConfig,用于幫助用戶(hù)配置編譯選項(xiàng),包括優(yōu)化級(jí)別和必需/禁用的pass等。例如,可能有一個(gè)配置, opt_level=3使用disabled_pass=xx提供的某些禁用的pass,執(zhí)行所有PassContext。可以將所有pass,放在opt_level=3,排除禁用pass列表中的那些。PassContext提供了一種檢測(cè)所有pass的方法。
PassContext包含優(yōu)化pass的有用信息。例如,包含錯(cuò)誤報(bào)告系統(tǒng),可以提供有關(guān)優(yōu)化失敗原因的診斷。PassContext設(shè)計(jì)用于替換舊的BuildConfig,該配置用于幫助用戶(hù)配置編譯選項(xiàng),包括優(yōu)化級(jí)別和必需/禁用的pass等。例如,可能有一個(gè)配置,該配置使用PassContext提供的disabled_pass=xx,在opt_level=3執(zhí)行所有pass,一些禁用的pass,使用disabled_pass=xx。現(xiàn)在,可以在opt_level=3時(shí),對(duì)所有pass,進(jìn)行全局排序,排除禁用pass列表中的pass。PassContext提供了一種對(duì)所有pass,進(jìn)行檢測(cè)的方法。
這個(gè)類(lèi)是為用戶(hù)設(shè)計(jì)的,使用語(yǔ)法編寫(xiě)Python,在特定配置下執(zhí)行優(yōu)化。用戶(hù)可以通過(guò)PassContext::Current(),線(xiàn)程安全的方式,獲得特定程序范圍內(nèi),可用的上下文,線(xiàn)程本地存儲(chǔ)PassContextThreadLocalStore,保存創(chuàng)建的pass context對(duì)象。將提供示例來(lái)說(shuō)明,如何使用C++和Python API,使用pass context,創(chuàng)建編譯pass。
class PassContextNode: public Object {
public:
int opt_level{2};
tvm::Arraytvm::Expr required_pass;
tvm::Arraytvm::Expr disabled_pass;
mutable Optional diag_ctx;
Map<String, ObjectRef> config;
Arrayinstrument::PassInstrument instruments;
};
class PassContext : public NodeRef {
public:
TVM_DLL static PassContext Create();
TVM_DLL static PassContext Current();
TVM_DLL void InstrumentEnterPassContext();
TVM_DLL void InstrumentExitPassContext();
TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;
/* Other fields are omitted. */
private:
// The entry of a pass context scope.
TVM_DLL void EnterWithScope();
// The exit of a pass context scope.
TVM_DLL void ExitWithScope();
// Classes to get the Python with like syntax.
friend class tvm::With;
};
struct PassContextThreadLocalEntry {
/*! \brief The default pass context. /
PassContext default_context;
/! \brief The current pass context. */
std::stack context_stack;
PassContextThreadLocalEntry() {
default_context = PassContext(make_node());
}
};
/! \brief The thread-local store to hold the pass context. /
typedef dmlc::ThreadLocalStore
PassContextThreadLocalStore;
pass構(gòu)建
pass infra以分層方式設(shè)計(jì),可以在 Relay/tir 程序的,不同粒度下工作。PassNode引入了一個(gè)純虛擬類(lèi),作為不同優(yōu)化pass的基礎(chǔ)。包含幾個(gè)必須由子類(lèi)在模塊,函數(shù)或pass序列級(jí)別實(shí)現(xiàn)的虛擬方法。
class PassNode : Object {
virtual PassInfo Info() const = 0;
virtual Module operator()(const IRModule& mod
const PassContext& pass_ctx) const = 0;
};
函子顯示必須如何實(shí)現(xiàn)pass,始終在 IRModule特定上下文下工作。所有pass都以ModuletoModule方式設(shè)計(jì)。由 pass infra 控制的優(yōu)化,將始終更新整個(gè)模塊。
已經(jīng)創(chuàng)建了幾個(gè)子類(lèi),實(shí)現(xiàn)不同類(lèi)型的優(yōu)化pass,例如,函數(shù)級(jí)pass,模塊級(jí)pass和順序pass。每個(gè)子類(lèi)本身都可以充當(dāng)pass管理器。例如,可以收集所需的pass執(zhí)行,或基于給定的元數(shù)據(jù)構(gòu)建,依賴(lài)關(guān)系圖。完整定義可以在src/relay/ir/transform.cc和src/ir/transform.cc 中找到。
模塊級(jí)pass
模塊級(jí)pass主要用于全局和pass間優(yōu)化 (IPO),類(lèi)似于 LLVM 中使用的模塊pass。Relay 中一些典型的 pass,需要一個(gè)模塊的全局圖片,比如 A-normal form 轉(zhuǎn)換和 lambda 提升等,都屬于這個(gè)集合。在此級(jí)別,用戶(hù)甚至可以在模塊中,添加和/或刪除功能。所有pass
class ModulePassNode : PassNode {
PassInfo pass_info;
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func;
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
// Other members/methods are omitted
};
pass_info維護(hù)模塊級(jí)pass所需的信息。pass_func勾勒出真正的優(yōu)化。例如,可能需要對(duì)模塊執(zhí)行死代碼消除。可以在pass_func中實(shí)現(xiàn)算法,在模塊上運(yùn)行。刪除死代碼,包括模塊中未使用的函數(shù)。該字段被設(shè)計(jì)為一個(gè)打包函數(shù),可以在 C++ 和 Python 中,實(shí)現(xiàn)優(yōu)化。
函數(shù)級(jí)pass
函數(shù)級(jí)pass,用于為給定的 Relay/tir 模塊,實(shí)現(xiàn)各種函數(shù)內(nèi)級(jí)優(yōu)化。一次從模塊的函數(shù)列表中,獲取一個(gè)函數(shù),進(jìn)行優(yōu)化,生成重寫(xiě)的 Relay Function或 tir PrimFunc。大多數(shù)pass可以歸入這一類(lèi),如Relay中,常見(jiàn)子表達(dá)式消除和推理簡(jiǎn)化,及tir中的矢量化和扁平化存儲(chǔ)等。
此級(jí)別的pass范圍是 Relay 函數(shù),或 tir 原始函數(shù)。無(wú)法通過(guò)pass,添加或刪除函數(shù)。
class FunctionPassNode : PassNode {
PassInfo pass_info;
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
bool SkipFunction(const Function& func) const;
// Other members/methods are omitted…
};
pass_info與剛剛在模塊pass中描述的相同。pass_func需要一個(gè)函數(shù),進(jìn)行優(yōu)化,需要一個(gè)模塊,可能會(huì)報(bào)告錯(cuò)誤。一個(gè)函數(shù)可以用“SkipOptimization”注釋,在優(yōu)化pass中忽略。
連續(xù)passes
SequentialPass與 Pytorch 類(lèi)似,nn.Sequential包含許多用于執(zhí)行的pass。
class SequentialPassNode : PassNode {
PassInfo pass_info;
// Passes need to be executed.
Array passes;
bool PassEnabled(const PassInfo& info) const;
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
};
僅放置了在Relay中的少數(shù)pass。例如,FoldScaleAxis要求在內(nèi)部調(diào)度ForwardFoldScaleAxis和BackwardFoldScaleAxis。建議首先完成BackwardFoldScaleAxis。該pass是SequentialPass的理想候選。
下面的代碼顯示了如何調(diào)用順序過(guò)程中的各個(gè)pass。使用pass列表中,在一個(gè)順序pass中,執(zhí)行每個(gè)pass。
Module SequentialNode::operator()(const Module& module,
const PassContext& pass_ctx) const {
Module mod = module;
for (const Pass& pass : passes) {
ICHECK(pass.defined()) << “Found undefined pass for optimization.”;
const PassInfo& pass_info = pass->Info();
if (!PassEnabled(pass_info)) continue;
for (const auto& it : pass_info->required) {
const auto name = it.astvm::ir::StringImm();
ICHECK(name);
mod = GetPass(name->value)(mod, pass_ctx);
}
mod = pass(mod, pass_ctx);
}
return mod;
}
在調(diào)用pass時(shí),先檢查是否啟用了pass。檢查用戶(hù)是否明確禁用pass,是否被用戶(hù)指定為必需pass完成的。如果不確定,是否啟用了此pass,opt_level將進(jìn)行檢查。只有當(dāng)優(yōu)化級(jí)別不低于PassContext中,配置的優(yōu)化級(jí)別時(shí),才會(huì)啟用執(zhí)行pass。
要執(zhí)行pass,先需要使用pass名稱(chēng),在 TVM 打包函數(shù)注冊(cè)表中,檢索已注冊(cè)的pass。這是可能的,因?yàn)槊總€(gè)pass,都注冊(cè)了一個(gè) API 端點(diǎn),將在后面展示。
Pass GetPass(const std::string& pass_name) {
using tvm::runtime::Registry;
std::string fpass_name = “relay._transform.” + pass_name;
const auto f = Registry::Get(fpass_name);
ICHECK(f != nullptr) << "Cannot find " << fpass_name
<< "to create the pass " << pass_name;
return (*f)();
}
提供了一些輔助函數(shù),創(chuàng)建上述每種類(lèi)型的pass。這些幫助程序,暴露給 Python 前端,使用 Python API,創(chuàng)建特定的 pass 對(duì)象。
Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
int opt_level,
String name,
Array required);
Pass CreatePrimFuncPass(
const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
String name,
Array required);
Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level,
String name,
Array required);
Pass Sequential(tvm::Array passes, PassInfo pass_info);
pass注冊(cè)
不同級(jí)別pass的概念和用于編譯的context,可以輕松注冊(cè)pass,以 const 折疊為例。這個(gè)pass已經(jīng)實(shí)現(xiàn),折疊 Relay 函數(shù)中的常量(在 src/relay/transforms/fold_constant.cc 中找到)。
提供了一個(gè) API,執(zhí)行ExprtoExpr轉(zhuǎn)換。
Expr FoldConstant(const Expr& expr);
為了將這個(gè)pass注冊(cè)到pass infra,先需要決定這個(gè)pass,在哪個(gè)級(jí)別執(zhí)行。由于常量折疊,發(fā)生在單個(gè)函數(shù)上,應(yīng)該直觀(guān)FunctionPass通過(guò) CreateFunctionPass. 將pass_func作為打包函數(shù)返回,該函數(shù)在IRModule 中的每個(gè)函數(shù)上調(diào)用Exprto ExprAPI。{}表示此pass,不需要先決條件。否則,pass開(kāi)發(fā)人員必須識(shí)別列出。
使用名稱(chēng) relay._transform.FoldConstant,注冊(cè)一個(gè)pass API 端點(diǎn) 。這個(gè)pass成為注冊(cè)表中的一個(gè)條目,可以由C++(如GetPass上面的)和Python訪(fǎng)問(wèn)。
namespace transform {
Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast(FoldConstant(f));
};
return CreateFunctionPass(pass_func, 2, “FoldConstant”, {});
}
TVM_REGISTER_GLOBAL(“relay._transform.FoldConstant”)
.set_body_typed(FoldConstant);
} // namespace transform
為了允許其它 C++ 模塊應(yīng)用這個(gè)pass,在include/tvm/relay/transform.h 中聲明了一個(gè)自由函數(shù), 如下所示:
TVM_DLL Pass FoldConstant();
pass儀器
Pass Instrument 是一種分析pass本身的機(jī)制。例如,可以使用基礎(chǔ)架構(gòu),了解一次pass需要多少時(shí)間和內(nèi)存,或者一次pass,如何轉(zhuǎn)換 IR 模塊。
生命周期中的四個(gè)儀器點(diǎn)PassContext。
TVM_DLL void InstrumentEnterPassContext();
TVM_DLL void InstrumentExitPassContext();
TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;
當(dāng)輸入PassContext實(shí)例的范圍時(shí),立即調(diào)用InstrumentEnterPassContext。
InstrumentExitPassContext在離開(kāi)PassContext的作用域時(shí)被調(diào)用,或者在執(zhí)行過(guò)程中發(fā)生異常。當(dāng)tvm.transform.PassContext中的OverrideU instruments重寫(xiě)儀器時(shí),會(huì)調(diào)用此方法。
在執(zhí)行前,調(diào)用InstrumentBeforePass。如果運(yùn)行pass,在執(zhí)行后調(diào)用InstrumentAfterPass。這種行為就像:
if (pass_ctx.InstrumentBeforePass(ir_module, pass_info)) {
new_ir_module = run_pass(ir_module, pass_ctx);
pass_ctx.InstrumentAfterPass(new_ir_module, pass_info);
return new_ir_module;
}
該P(yáng)assInstrument接口允許在上述四種方法中運(yùn)行任意代碼。多個(gè)PassInstrument實(shí)例,可以注冊(cè)到一個(gè) PassContext。PassInstrument實(shí)例按照instruments傳遞給參數(shù)序列,依次調(diào)用 PassContext。
PassInstrument 提供以下接口:
namespace instrument {
class PassInstrumentNode : public Object {
public:
String name;
virtual void EnterPassContext() const = 0;
virtual void ExitPassContext() const = 0;
virtual bool ShouldRun(const IRModule& mod, const transform::PassInfo& info) const = 0;
virtual void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const = 0;
virtual void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const = 0;
/* Other fields are omitted. */
};
class PassInstrument : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(PassInstrument, ObjectRef, PassInstrumentNode);
};
} // namespace instrument
提供Python前端,以PassInstrument快速實(shí)現(xiàn)。
在 PassContext中, PassInstrument實(shí)例的調(diào)用順序是這樣的:
with PassContext(instruments=[pi]) # pi = a PassInstrument implementation.
pi.EnterPassContext()
if pi.ShouldRun(Pass1):pi.RunBeforePass()Pass1()pi.RunAfterPass()if pi.ShouldRun(Pass2):pi.RunBeforePass()Pass2()pi.RunAfterPass()pi.ExitPassContext()
介紹一下PassInstrument接口和PassContext方法的關(guān)系。有關(guān)更多詳細(xì)信息,參閱 ( src/ir/transform.cc )。
? InstrumentEnterPassContext
o EnterPassContext()按instruments傳遞給PassContext 的順序執(zhí)行。
o 當(dāng)異常發(fā)生時(shí),PassContext通過(guò)清除所有注冊(cè)的PassInstrument實(shí)例,禁用pass檢測(cè)。
o PassContext執(zhí)行ExitPassContext(),成功完成的每個(gè)PassInstrument實(shí)例的方法EnterPassContext()
o 例如,如果PassInstrumentA,B,C,注冊(cè)到 PassContext,A 完成,EnterPassContext(),B 拋出異常, C 永遠(yuǎn)不會(huì)執(zhí)行;ExitPassContext()A 的執(zhí)行。
? InstrumentExitPassContext
o 每個(gè)PassInstrument的實(shí)例ExitPassContext(),執(zhí)行順序是instruments傳遞給PassContext.
o 當(dāng)異常發(fā)生時(shí),instruments被清除。
o PassInstrument拋出異常后,注冊(cè)的實(shí)例,不執(zhí)行ExitPassContext。
? InstrumentBeforePass
o 如果pass未列為必需pass,執(zhí)行ShouldRun。
o RunBeforePass如果傳球沒(méi)有被ShouldRun 阻擋,按照instruments的順序執(zhí)行。
o InstrumentBeforePass返回一個(gè)布爾值,指示是否應(yīng)該運(yùn)行傳遞。
o 當(dāng)異常發(fā)生時(shí),立即拋出。依靠 Python Context ManagerPassContext安全退出(ExitPassContext每個(gè)儀器都會(huì)運(yùn)行的含義。對(duì)于 C++,參閱include/tvm/support/with.h。)
? InstrumentAfterPass
o RunAfterPass按instruments傳遞給 PassContext的順序執(zhí)行。
o 當(dāng)異常發(fā)生時(shí),立即拋出。依靠 Python Context Manager 或Withclass( include/tvm/support/with.h ),安全退出PassContext
build儀器
有幾種內(nèi)置工具,標(biāo)有TODO 的,沒(méi)有實(shí)現(xiàn)。
? PassTimingInstrument(參見(jiàn)src/ir/instrument.cc)
o 分析pass的執(zhí)行時(shí)間。
? PrintIRBefore(TODO)
o 在pass轉(zhuǎn)換前,打印 IR 模塊。如果在pass周?chē)迦?#xff0c;tvm.transform.PrintIR()可以達(dá)到這個(gè)目的。但是,使用PassInstrument,不需要修改passes的順序。
? 打印后(待辦事項(xiàng))
o 在pass轉(zhuǎn)換后,打印 IR 模塊。
Python前端
前端只需要一些簡(jiǎn)單的 API。例如,可以為用戶(hù)提供以下 API,創(chuàng)建和執(zhí)行一個(gè) pass(完整的實(shí)現(xiàn)在python/tvm/relay/transform/transform.py和 python/tvm/ir/transform.py 中提供)。后端接收信息,決定應(yīng)該使用哪個(gè)函數(shù),創(chuàng)建 Pass 對(duì)象。
PassContext
Python 前端為_(kāi)_enter____exit__current 提供了一個(gè)包裝器,通過(guò)覆蓋和PassContext,啟用with語(yǔ)法。為用戶(hù)提供了一種靜態(tài)方法,獲取在一定范圍內(nèi)使用的Context。
@tvm._ffi.register_object(“transform.PassContext”)
class PassContext(tvm.runtime.Object):
def enter(self):
_transform.EnterPassContext(self)
return self
def __exit__(self, ptype, value, trace, config):_transform.ExitPassContext(self)@staticmethod
def current():"""Return the current pass context."""return _transform.GetCurrentPassContext()
PassContext用于配置編譯選項(xiàng),包括優(yōu)化級(jí)別和必需/禁用的pass。可以帶一個(gè)配置字典,以便不同的pass,可以方便地獲取pass的數(shù)據(jù),如回退設(shè)備信息和循環(huán)展開(kāi)的步驟/深度等。為了能夠獲取所需的配置,必須通過(guò) TVM_REGISTER_PASS_CONFIG_OPTION注冊(cè)密鑰。例如,使用以下內(nèi)容,循環(huán)展開(kāi)pass
TVM_REGISTER_PASS_CONFIG_OPTION(“tir.UnrollLoop”, UnrollLoopConfig);
更多細(xì)節(jié),參考src/tir/transforms/unroll_loop.cc。
pass對(duì)象
Pass是所有pass對(duì)象的基類(lèi)。這里的所有方法,都只是在后端實(shí)現(xiàn)的簡(jiǎn)單包裝器。為了用戶(hù)方便地與 Python 中的基類(lèi),進(jìn)行交互定義的。在 pass 基類(lèi)中只定義了__call__,使子類(lèi)成為可調(diào)用對(duì)象,可以很容易調(diào)用(例如,pass_xx(arg))執(zhí)行。
@register_relay_node
class Pass(RelayNode):
def call(self, mod):
return _transform.RunPass(self, mod)
提供了一些輔助 API,從 Python 前端,輕松創(chuàng)建pass,讓pass基礎(chǔ)控制執(zhí)行。例如module_pass,function_pass和sequential提供給用戶(hù),以便可以定制pass。
對(duì)于在 C++ 后端實(shí)現(xiàn)的所有 pass,分別在python/tvm/ir/transform.py和 python/tvm/relay/transform/transform.py 中,提供了相應(yīng)的 Python API 。例如,const 折疊有一個(gè) Python API,如下所示:
def FoldConstant():
return _transform.FoldConstant()
可以構(gòu)建一個(gè)pass through裝飾:
@relay.transform.module_pass(opt_level=2)
def transform(mod, ctx):
tp = relay.TensorType((10,), “float32”)
x = relay.var(“x”, tp)
gv = relay.GlobalVar(“abs”)
func = relay.Function([x], relay.abs(x))
new_mod = tvm.IRModule({gv: func})
new_mod.update(mod)
return new_mod
module_pass = transform
assert isinstance(module_pass, transform.ModulePass)
assert module_pass.info.opt_level == 2
在transform功能增加了一個(gè)abs與輸入模塊的功能,可能是在模塊級(jí)的任何定制的優(yōu)化。創(chuàng)建module_pass后,應(yīng)用于任何 Relay 模塊。例如,可以構(gòu)建一個(gè)空模塊,應(yīng)用pass添加一個(gè)abs 函數(shù)。
mod = tvm.IRModule()
mod = module_pass(mod)
提供function_pass功能,一個(gè)示例函數(shù)級(jí)pass,可以寫(xiě)成如下:
@relay.transform.function_pass(opt_level=1)
class TestReplaceFunc:
def init(self, new_func):
self.new_func = new_func
def transform_function(self, func, mod, ctx):
# Just for demo purposes
# Transform func to new_func
return self.new_func
x = relay.var(“x”, shape=(10, 20))
f1 = relay.Function([x], x)
f2 = relay.Function([x], relay.log(x))
fpass is now a special pass that replaces every
function to f1
fpass = TestReplaceFunc(f1)
Now every function in input_mod is replaced by f1
res_mod = fpass(input_mod)
可以不使用裝飾器,直接注冊(cè)pass,調(diào)用。有關(guān)如何自定義優(yōu)化pass,調(diào)試 Relay 和 tir pass 的更多示例,參閱 use pass infra教程。
pass儀器
可以通過(guò)在實(shí)現(xiàn)以下方法的類(lèi)上,使用pass_instrument decorator(python/tvm/ir/instrument.py),實(shí)現(xiàn)PassInstrument。建議使用pass_instrument decorator,實(shí)現(xiàn)PassInstrument,不是重寫(xiě)或子類(lèi)化。
? enter_pass_ctx
o 該方法在進(jìn)入PassContext時(shí)運(yùn)行。
? exit_pass_ctx
o 此方法在退出PassContext時(shí)運(yùn)行。
? should_run
o 此方法在執(zhí)行pass前運(yùn)行,返回一個(gè)布爾值,指示是否應(yīng)運(yùn)行pass。
? run_before_pass
o 如果應(yīng)該運(yùn)行一次pass,在pass執(zhí)行之前運(yùn)行此方法。
? run_after_pass
o 此方法在執(zhí)行一次pass后,立即運(yùn)行。
PassInstrument實(shí)例可以通過(guò)tvm.transform.PassContext中的參數(shù) instruments注冊(cè)。
use pass instrument提供了如何使用 Python API,實(shí)現(xiàn)PassInstrument的示例。
覆蓋當(dāng)前 PassContext 中的儀器
提供了current PassContext覆蓋instruments 的override_instruments方法。例如,如果在沒(méi)有顯式創(chuàng)建 new 的情況下,運(yùn)行 pass PassContext,可以通過(guò)以下方式注冊(cè)PassInstrument到全局中PassContext:
cur_pass_ctx = tvm.transform.PassContext.current()
override PassInstrument instances
cur_pass_ctx.override_instruments([pass_inst])
mod = pass_seq(mod)
result = pass_inst.get_result()
當(dāng)override_instruments調(diào)用時(shí),舊PassInstrument實(shí)例的方法exit_pass_ctx會(huì)調(diào)用。然后new PassInstrument的enter_pass_ctx方法調(diào)用。
參考鏈接:
https://tvm.apache.org/docs/dev/pass_infra.html#pass-infra
總結(jié)
以上是生活随笔為你收集整理的AI中pass架构设计优化的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: Relay IR表示
- 下一篇: LLVM语法语义指令特性