python pytorch fft: My Journey Through the PyTorch Source Code
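1. Motivation
I once ran into a case where the prelu in someone's model produced different results in our in-house inference engine than in its original framework, PyTorch. In theory both sides implement the same algorithm, but the parameters had gone through model conversion and were adjusted along the way. To find out whether the initial parameters were already passed incorrectly, whether they were changed again further downstream, or whether subtle differences in the final algorithm implementation caused the mismatch, I decided to look at how PyTorch handles prelu from top to bottom.
But I kept following the code until I lost the trail, and only then realized that PyTorch is a genuinely complex project. Still, as the documentary A Bite of China puts it, the harsher the environment, the richer the reward. To make future tracing easier, I decided to statically map out PyTorch's code structure, using PReLU as the example. After several days of tinkering I also gained a better understanding of how to build a Python package that contains C/C++ code, which was an unexpected bonus.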
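2. The Journey
First, the import path torch.nn.PReLU tells us that PReLU should live under torch\nn\. It is not defined directly in that directory, but the __init__.py there shows that it actually lives in torch\nn\modules\activation.py. The PReLU class ultimately calls the prelu function imported from torch\nn\functional.py. Following that lead, we find prelu, which looks like this: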
def prelu(input, weight):
    # type: (Tensor, Tensor) -> Tensor
    if not torch.jit.is_scripting():
        if type(input) is not Tensor and has_torch_function((input,)):
            return handle_torch_function(prelu, (input,), input, weight)
    return torch.prelu(input, weight)
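If you execute this code in your head for an ordinary Tensor in eager mode, the first if condition holds (we are not scripting) while the second does not (the input is a plain Tensor). So to see the actual algorithm, we have to look at torch.prelu(). All right, let's keep going...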
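To make that hand evaluation concrete, here is a minimal sketch of the same control flow. Everything in it is a simplified, hypothetical stand-in (the `Tensor` and `TensorLike` classes, `is_scripting`, `has_torch_function`, `core_prelu`); none of it is PyTorch's real implementation. It only mimics which branch is taken.

```python
# Simplified stand-ins (assumptions) for the PyTorch pieces that prelu() touches.
class Tensor:
    """Stand-in for torch.Tensor."""


class TensorLike:
    """A non-Tensor type that opts in to overriding torch functions."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return "handled by override"


def is_scripting():
    # Assumption: ordinary eager execution, so we are not being scripted.
    return False


def has_torch_function(args):
    # Rough approximation: any argument whose type defines __torch_function__.
    return any(hasattr(type(a), "__torch_function__") for a in args)


def core_prelu(input, weight):
    # Placeholder for the C++-backed torch.prelu.
    return "computed by torch.prelu"


def prelu(input, weight):
    if not is_scripting():  # first if: True in eager mode
        if type(input) is not Tensor and has_torch_function((input,)):
            return type(input).__torch_function__(
                prelu, (type(input),), (input, weight))
    return core_prelu(input, weight)  # plain tensors end up here


plain_result = prelu(Tensor(), Tensor())
override_result = prelu(TensorLike(), Tensor())
print(plain_result)     # computed by torch.prelu
print(override_result)  # handled by override
```

With a plain `Tensor` the inner check fails immediately, which is why the trail leads on to `torch.prelu`.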
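After some searching you will find that the Python code of the torch package contains no definition of prelu. Just as despair sets in, however, we notice the following lines in the torch package's __init__.py: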
# pytorch\torch\__init__.py
# For brevity, unrelated code is omitted; see pytorch\torch\__init__.py for the full source.
try:
    # _initExtension is chosen (arbitrarily) as a sentinel.
    from torch._C import _initExtension
# ... (the rest of the try statement is omitted in this excerpt)
__all__ += [name for name in dir(_C)
            if name[0] != '_' and
            not name.endswith('Base')]
if TYPE_CHECKING:
    # Some type signatures pulled in from _VariableFunctions here clash with
    # signatures already imported. For now these clashes are ignored; see
    # PR #43339 for details.
    from torch._C._VariableFunctions import *  # type: ignore
for name in dir(_C._VariableFunctions):
    if name.startswith('__'):
        continue
    globals()[name] = getattr(_C._VariableFunctions, name)
    __all__.append(name)
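This is the whole village's last hope. We know that the names in __all__ are the API a module intentionally exposes.
What does that mean? Although prelu is no longer defined anywhere in plain sight, these lines show that a whole batch of unidentified APIs is quietly imported into the torch namespace, and the prelu we have been longing for may well be among them.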
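The re-export loop above can be tried in isolation. The sketch below uses only the standard library; `_VariableFunctionsClass`, `fake_torch`, and the toy `prelu`/`relu` bodies are hypothetical stand-ins for illustration, not PyTorch code.

```python
import types


# Stand-in for the object that torch._C exposes as _VariableFunctions.
class _VariableFunctionsClass:
    @staticmethod
    def prelu(x, weight):
        return x if x >= 0 else weight * x

    @staticmethod
    def relu(x):
        return max(x, 0.0)


_C = types.SimpleNamespace(_VariableFunctions=_VariableFunctionsClass())

# A throwaway module playing the role of the `torch` package namespace.
fake_torch = types.ModuleType("fake_torch")
fake_torch.__all__ = []

for name in dir(_C._VariableFunctions):
    if name.startswith('__'):
        continue
    # Same idea as `globals()[name] = ...` in torch/__init__.py.
    setattr(fake_torch, name, getattr(_C._VariableFunctions, name))
    fake_torch.__all__.append(name)

print(fake_torch.__all__)            # ['prelu', 'relu']
print(fake_torch.prelu(-2.0, 0.25))  # -0.5
```

After the loop, `fake_torch.prelu` exists even though no `def prelu` appears in the module's own source, which is the situation we face with `torch.prelu`.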
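So how do we confirm this guess from such a faint clue? This is where a key piece of Python knowledge comes in: C/C++ extensions. (See the articles "Writing Python Modules in C: Introduction" and "Calling C++ from Python: An Introduction to pybind11" for more.)
Python C/C++ extensions follow a fixed format: once we find the module's initialization entry point, we can follow it to every function the module exposes to the Python interpreter. In Python 3 the initialization function is named PyInit_<name>, where <name> is the name of the module. For example, the import from torch._C seen earlier means that the torch package must have a submodule named _C, so its initialization function should be PyInit__C, and searching for that name leads us to the module entry point. There is also another way, which is to read the description of the extension in setup.py: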
# pytorch\setup.py
main_sources = ["torch/csrc/stub.c"]
C = Extension("torch._C",
              libraries=main_libraries,
              sources=main_sources,
              language='c',
              extra_compile_args=main_compile_args + extra_compile_args,
              include_dirs=[],
              library_dirs=library_dirs,
              extra_link_args=extra_link_args + main_link_args + make_relative_rpath_args('lib'))
extensions.append(C)
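The naming rule for the initialization function is simple enough to write down. The helper below is a hypothetical illustration (`init_symbol` is not part of any library) and assumes ASCII module names; it derives the symbol to search for from the extension name given to `Extension(...)`.

```python
def init_symbol(qualified_name: str) -> str:
    """Return the init function name to look for in an extension module.

    Only the last component of the dotted name is used (ASCII names assumed).
    """
    short_name = qualified_name.rpartition(".")[2]
    return "PyInit_" + short_name


symbol = init_symbol("torch._C")
print(symbol)  # PyInit__C
```

Note the double underscore: one comes from the `PyInit_` prefix and one from the module name `_C` itself.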
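Whether by searching or by reading setup.py, we end up at the module initialization function PyInit__C(void) in pytorch\torch\csrc\stub.c. By following the initModule() function it calls, we can find out exactly which APIs are exposed to the Python interpreter.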
// pytorch\torch\csrc\stub.c
PyMODINIT_FUNC PyInit__C(void)
{
  return initModule();
}
// pytorch\torch\csrc\Module.cpp
initModule()
Step into initModule() and look around, and you will find that the _C module still has no Python interface for prelu. What now? Don't panic. From the earlier analysis of torch\__init__.py we know there is still hope: the _VariableFunctions submodule under _C. This really is the last hope. There is no other road left, so we grit our teeth and keep searching. After an epic and arduous hunt, we spot _VariableFunctions in the call chain initModule()->THPVariable_initModule(module)->torch::autograd::initTorchFunctions(module). Aha, simple!
void initTorchFunctions(PyObject* module) {
  if (PyType_Ready(&THPVariableFunctions) < 0) {
    throw python_error();
  }
  Py_INCREF(&THPVariableFunctions);
  // Steals
  Py_INCREF(&THPVariableFunctions);
  if (PyModule_AddObject(module, "_VariableFunctionsClass", reinterpret_cast<PyObject*>(&THPVariableFunctions)) < 0) {
    throw python_error();
  }
  // PyType_GenericNew returns a new reference
  THPVariableFunctionsModule = PyType_GenericNew(&THPVariableFunctions, Py_None, Py_None);
  // PyModule_AddObject steals a reference
  if (PyModule_AddObject(module, "_VariableFunctions", THPVariableFunctionsModule) < 0) {
    throw python_error();
  }
}
But wait, don't celebrate too early! Look at the interfaces exposed by _VariableFunctions and you will find that the one we want is simply not there, as the following code shows:
static PyMethodDef torch_functions[] = {
  {"arange", castPyCFunctionWithKeywords(THPVariable_arange),
    METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"as_tensor", castPyCFunctionWithKeywords(THPVariable_as_tensor),
    METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"dsmm", castPyCFunctionWithKeywords(THPVariable_mm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"from_numpy", THPVariable_from_numpy, METH_STATIC | METH_O, NULL},
  {"full", castPyCFunctionWithKeywords(THPVariable_full), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"hsmm", castPyCFunctionWithKeywords(THPVariable_hspmm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"nonzero", castPyCFunctionWithKeywords(THPVariable_nonzero), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"randint", castPyCFunctionWithKeywords(THPVariable_randint), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"range", castPyCFunctionWithKeywords(THPVariable_range), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"saddmm", castPyCFunctionWithKeywords(THPVariable_sspaddmm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"sparse_coo_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"_sparse_coo_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_coo_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"_validate_sparse_coo_tensor_args", castPyCFunctionWithKeywords(THPVariable__validate_sparse_coo_tensor_args), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"spmm", castPyCFunctionWithKeywords(THPVariable_mm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"tensor", castPyCFunctionWithKeywords(THPVariable_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"get_device", castPyCFunctionWithKeywords(THPVariable_get_device), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"numel", castPyCFunctionWithKeywords(THPVariable_numel), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  ${py_method_defs}
  {NULL}
};
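There is no trace of prelu in the code above. Could prelu somehow bypass the C/C++ extension mechanism and be used directly from Python, so that it does not need to appear here? The answer is no. There is only one road up Mount Hua, and programs do not play by hidden rules. So with the trail gone cold, the authors must have used black magic, and as a Muggle I am out of options. This article might as well end here...
Wait. Something odd seems to have slipped into the C code above: ${py_method_defs}. This does not look like C/C++ syntax at all; it looks more like something from a shell script. Is it a new language feature? After a laborious search I found no such syntax in C/C++. If it is not valid syntax, leaving it in C/C++ code would surely break compilation, and yet there it is. That leaves only one explanation: it is a placeholder, and real code will replace it later!
What next? Search! A global search for the keyword py_method_defs eventually shows that a Python script does replace this placeholder, and that as a result of the replacement the prelu we have been looking for finally appears in the _VariableFunctions module. Case closed.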
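The placeholder idea can be reproduced with the standard library. This is only a sketch under stated assumptions: it uses `string.Template` as a stand-in for whatever template class PyTorch's generator uses, the template text is an abbreviated copy of the table above, and `method_def` is a hypothetical helper whose output format is illustrative rather than the exact generated code.

```python
from string import Template

# Abbreviated stand-in for the method table template shown above.
TEMPLATE = Template("""\
static PyMethodDef torch_functions[] = {
  {"numel", castPyCFunctionWithKeywords(THPVariable_numel), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  ${py_method_defs}
  {NULL}
};
""")


def method_def(name: str) -> str:
    # Hypothetical helper: one table entry per operator name.
    return ('{"%s", castPyCFunctionWithKeywords(THPVariable_%s), '
            'METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},' % (name, name))


generated = TEMPLATE.substitute(
    py_method_defs="\n  ".join(method_def(n) for n in ["prelu", "relu"]))
print(generated)
```

The printed text is ordinary C++ again: the `${py_method_defs}` line has been replaced by one table entry per operator, including one for `prelu`.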
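But as in police work, a single piece of evidence is not enough; you need further evidence to form a complete chain before it becomes convincing. The search told us that prelu ends up in _VariableFunctions, but how it gets there is still vague: when is the placeholder replaced, and who invokes the script that replaces it?
In fact, all of this can be traced, and the trail again starts in setup.py. In the main function of setup.py, before setup() is called, there is a call to a function named build_deps(). That function eventually runs CMake for the target platform, which builds according to the CMakeLists.txt in the repository root. The root CMakeLists.txt in turn pulls in the CMakeLists.txt under the caffe2 directory (add_subdirectory(caffe2)), and caffe2/CMakeLists.txt invokes the Python script that performs code generation, as shown below: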
Figure: how the code generation script gets invoked
# pytorch\caffe2\CMakeLists.txt
add_custom_command(
  OUTPUT
    ${TORCH_GENERATED_CODE}
  COMMAND
    "${PYTHON_EXECUTABLE}" tools/setup_helpers/generate_code.py
      --declarations-path "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
      --native-functions-path "aten/src/ATen/native/native_functions.yaml"
      --nn-path "aten/src"
      $<$<BOOL:${INTERN_DISABLE_AUTOGRAD}>:--disable-autograd>
      $<$<BOOL:${SELECTED_OP_LIST}>:--selected-op-list-path="${SELECTED_OP_LIST}">
      --force_schema_registration
  # ... (remaining arguments omitted in this excerpt)
)
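The main code generation flow is shown in the code blocks below. Roughly, main() parses the arguments passed to the script and hands them to generate_code(). Given the arguments used in caffe2/CMakeLists.txt (no --subset), all three gen_*() functions in generate_code() are called. gen_autograd_python() in turn calls gen_python_functions.gen(), which calls a function named create_python_bindings(), and that function is where the code is actually generated.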
Figure: call flow of the code generator
# tools/setup_helpers/generate_code.py
def generate_code(ninja_global=None,
                  declarations_path=None,
                  nn_path=None,
                  native_functions_path=None,
                  install_dir=None,
                  subset=None,
                  disable_autograd=False,
                  force_schema_registration=False,
                  operator_selector=None):
    # (Imports and the setup of autograd_gen_dir, autograd_dir and
    # python_install_dir are omitted from this excerpt.)
    if subset == "pybindings" or not subset:
        gen_autograd_python(
            declarations_path or DECLARATIONS_PATH,
            native_functions_path or NATIVE_FUNCTIONS_PATH,
            autograd_gen_dir,
            autograd_dir)
    if operator_selector is None:
        operator_selector = SelectiveBuilder.get_nop_selector()
    if subset == "libtorch" or not subset:
        gen_autograd(
            declarations_path or DECLARATIONS_PATH,
            native_functions_path or NATIVE_FUNCTIONS_PATH,
            autograd_gen_dir,
            autograd_dir,
            disable_autograd=disable_autograd,
            operator_selector=operator_selector,
        )
    if subset == "python" or not subset:
        gen_annotated(
            native_functions_path or NATIVE_FUNCTIONS_PATH,
            python_install_dir,
            autograd_dir)
def main():
    parser = argparse.ArgumentParser(description='Autogenerate code')
    parser.add_argument('--declarations-path')
    parser.add_argument('--native-functions-path')
    parser.add_argument('--nn-path')
    parser.add_argument('--ninja-global')
    parser.add_argument('--install_dir')
    parser.add_argument(
        '--subset',
        help='Subset of source files to generate. Can be "libtorch" or "pybindings". Generates both when omitted.'
    )
    parser.add_argument(
        '--disable-autograd',
        default=False,
        action='store_true',
        help='It can skip generating autograd related code when the flag is set',
    )
    parser.add_argument(
        '--selected-op-list-path',
        help='Path to the YAML file that contains the list of operators to include for custom build.',
    )
    parser.add_argument(
        '--operators_yaml_path',
        help='Path to the model YAML file that contains the list of operators to include for custom build.',
    )
    parser.add_argument(
        '--force_schema_registration',
        action='store_true',
        help='force it to generate schema-only registrations for ops that are not '
        'listed on --selected-op-list'
    )
    options = parser.parse_args()
    generate_code(
        options.ninja_global,
        options.declarations_path,
        options.nn_path,
        options.native_functions_path,
        options.install_dir,
        options.subset,
        options.disable_autograd,
        options.force_schema_registration,
        # options.selected_op_list
        operator_selector=get_selector(options.selected_op_list_path, options.operators_yaml_path),
    )
if __name__ == "__main__":
    main()
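The claim that all three generators run follows from the `subset == ... or not subset` checks. The sketch below isolates just that logic with a cut-down argument parser; `selected_generators` is a hypothetical helper written for illustration and only returns the names of the functions that would be called.

```python
import argparse


def selected_generators(subset):
    """Mirror the `subset == ... or not subset` checks from generate_code()."""
    steps = []
    if subset == "pybindings" or not subset:
        steps.append("gen_autograd_python")
    if subset == "libtorch" or not subset:
        steps.append("gen_autograd")
    if subset == "python" or not subset:
        steps.append("gen_annotated")
    return steps


parser = argparse.ArgumentParser(description="Autogenerate code (sketch)")
parser.add_argument("--native-functions-path")
parser.add_argument("--subset")
parser.add_argument("--disable-autograd", default=False, action="store_true")

# Arguments similar to the ones passed from caffe2/CMakeLists.txt: no --subset.
options = parser.parse_args(
    ["--native-functions-path", "aten/src/ATen/native/native_functions.yaml"])

print(options.native_functions_path)        # dashes in option names become underscores
print(selected_generators(options.subset))  # all three generators
print(selected_generators("pybindings"))    # ['gen_autograd_python']
```

Because `--subset` is absent, `options.subset` is `None`, `not subset` is true, and every branch is taken.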
# pytorch\tools\autograd\gen_autograd.py
def gen_autograd_python(aten_path, native_functions_path, out, autograd_dir):
    from .load_derivatives import load_derivatives
    differentiability_infos = load_derivatives(
        os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path)
    template_path = os.path.join(autograd_dir, 'templates')
    # Generate Functions.h/cpp
    from .gen_autograd_functions import gen_autograd_functions_python
    gen_autograd_functions_python(
        out, differentiability_infos, template_path)
    # Generate Python bindings
    from . import gen_python_functions
    deprecated_path = os.path.join(autograd_dir, 'deprecated.yaml')
    gen_python_functions.gen(
        out, native_functions_path, deprecated_path, template_path)
# pytorch\tools\autograd\gen_python_functions.py
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Main Function
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_path: str) -> None:
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    methods = load_signatures(native_yaml_path, deprecated_yaml_path, method=True)
    create_python_bindings(
        fm, methods, is_py_variable_method, None, 'python_variable_methods.cpp', method=True)
    functions = load_signatures(native_yaml_path, deprecated_yaml_path, method=False)
    create_python_bindings(
        fm, functions, is_py_torch_function, 'torch', 'python_torch_functions.cpp', method=False)
    create_python_bindings(
        fm, functions, is_py_nn_function, 'torch.nn', 'python_nn_functions.cpp', method=False)
    create_python_bindings(
        fm, functions, is_py_fft_function, 'torch.fft', 'python_fft_functions.cpp', method=False)
    create_python_bindings(
        fm, functions, is_py_linalg_function, 'torch.linalg', 'python_linalg_functions.cpp', method=False)
def create_python_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: Optional[str],
    filename: str,
    *,
    method: bool,
) -> None:
    """Generates Python bindings to ATen functions"""
    py_methods: List[str] = []
    py_method_defs: List[str] = []
    py_forwards: List[str] = []
    grouped: Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
    for pair in pairs:
        if pred(pair.function):
            grouped[pair.function.func.name.name].append(pair)
    for name in sorted(grouped.keys(), key=lambda x: str(x)):
        overloads = grouped[name]
        py_methods.append(method_impl(name, module, overloads, method=method))
        py_method_defs.append(method_def(name, module, overloads, method=method))
        py_forwards.extend(forward_decls(name, overloads, method=method))
    fm.write_with_template(filename, filename, lambda: {
        'generated_comment': '@' + f'generated from {fm.template_dir}/{filename}',
        'py_forwards': py_forwards,
        'py_methods': py_methods,
        'py_method_defs': py_method_defs,
    })
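The core of create_python_bindings() is a filter, group, and sort pass over operator signatures. The sketch below shows that pattern on toy data. The `Signature` tuple, the `is_torch_function` predicate, and the schema strings are all hypothetical illustrations; they are not copied from native_functions.yaml or from PyTorch's generator.

```python
from collections import defaultdict
from typing import Dict, List, NamedTuple


class Signature(NamedTuple):
    base_name: str  # operator name without any overload suffix
    schema: str     # illustrative schema text only


def is_torch_function(sig: Signature) -> bool:
    # Hypothetical predicate: keep everything except the fft_* operators.
    return not sig.base_name.startswith("fft_")


signatures = [
    Signature("relu", "relu(Tensor self) -> Tensor"),
    Signature("add", "add.Tensor(Tensor self, Tensor other) -> Tensor"),
    Signature("prelu", "prelu(Tensor self, Tensor weight) -> Tensor"),
    Signature("add", "add.Scalar(Tensor self, Scalar other) -> Tensor"),
    Signature("fft_fft", "fft_fft(Tensor self) -> Tensor"),
]

# Group all overloads of an operator under its base name.
grouped: Dict[str, List[Signature]] = defaultdict(list)
for sig in signatures:
    if is_torch_function(sig):
        grouped[sig.base_name].append(sig)

# One binding per base name, emitted in sorted order.
summary = [(name, len(grouped[name])) for name in sorted(grouped)]
print(summary)  # [('add', 2), ('prelu', 1), ('relu', 1)]
```

Each base name yields exactly one table entry regardless of how many overloads it has, which is why a single `prelu` entry would show up in the generated method table.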
Finally, by reading native_functions.yaml and tracing the code that loads it, we find that the prelu entry in native_functions.yaml ends up in the file generated from the python_torch_functions.cpp template. In other words, it is generated by this call:
create_python_bindings(
    fm, functions, is_py_torch_function, 'torch', 'python_torch_functions.cpp', method=False)
The generation process as a whole is quite tedious. Tracing it layer by layer shows that the generated code exposes a function named at::<name> to Python. For our prelu, for example, the API exposed to Python ultimately calls a function named at::prelu() to do the real computation. So where is this at::<name> (for example at::prelu()) defined?
Same story, same trick! Python scripts again take the contents of native_functions.yaml and use the various templates under pytorch\aten\src\ATen\templates to generate the actual C++ source files. The end result is at::<name>, and inside that function the Dispatcher class is used to look up a handle to the target function. The function handles that can normally be used are managed through a class called Library. Using RegisterSchema.cpp as a template, the Python script generates the code that registers these target functions, and the registration goes through the Library class via a macro named TORCH_LIBRARY.
#define TORCH_LIBRARY(ns, m) \
  static void TORCH_LIBRARY_init_ ## ns (torch::Library&); \
  static const torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_ ## ns ( \
    torch::Library::DEF, \
    &TORCH_LIBRARY_init_ ## ns, \
    #ns, c10::nullopt, __FILE__, __LINE__ \
  ); \
  void TORCH_LIBRARY_init_ ## ns (torch::Library& m)
class TorchLibraryInit final {
private:
  using InitFn = void(Library&);
  Library lib_;
public:
  TorchLibraryInit(Library::Kind kind, InitFn* fn, const char* ns, c10::optional<c10::DispatchKey> k, const char* file, uint32_t line)
    : lib_(kind, ns, k, file, line) {
    fn(lib_);
  }
};
Figure: overview of the components that make up PyTorch
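3. Summary
PyTorch feels very Pythonic to use, but in reality Python is just a sugar coating wrapped around C++ code for convenience. It is pleasant to use but hard to read, especially when you trace the code statically, because much of the C++ glue between the Python C/C++ interface and the actual logic is generated by Python scripts. At this point the overall trail is clear, and what remains is to read the implementation details.
To be honest, executing Python code in your head and then trying to understand the C++ it would produce is exhausting, and it costs hair too. So I have decided to let the computer generate the C++ code first and then continue with the finer details, such as how exactly each operator is registered in Library.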
4. Bonus
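I honestly suspect that we live inside a virtual machine. Why? Because examples of the spatial and temporal locality principles used in computers are everywhere. Just as I finished writing this post, I stumbled upon a blog post in which a PyTorch engineer explains PyTorch internals, which should be a great help for further code reading. If you can't wait, go read it: http://blog.ezyang.com/2019/05/pytorch-internals/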
與50位技術(shù)專(zhuān)家面對(duì)面20年技術(shù)見(jiàn)證,附贈(zèng)技術(shù)全景圖總結(jié)
以上是生活随笔為你收集整理的python pytorch fft_看PyTorch源代码的心路历程的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: TGA 2023 年度最佳游戏提名公布:
- 下一篇: python调用高德api路径规划_Py