MXNet中x.grad源码追溯
生活随笔
收集整理的這篇文章主要介紹了
MXNet中x.grad源码追溯
小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.
Python測(cè)試代碼如https://zh.gluon.ai/chapter_prerequisite/autograd.html
本文追溯x.grad這一行代碼的調(diào)用
grad調(diào)用的是函數(shù)MXNDArrayGetGrad,/usr/local/lib/python3.7/dist-packages/mxnet-1.5.0-py3.7.egg/mxnet/ndarray/ndarray.py
MXNDArrayGetGrad的源碼依舊是在文件src/c_api/c_api.cc中,
NDArray ret = arr->grad();
ret就是獲取到的梯度
這里grad的源碼文件為src/ndarray/ndarray.cc,
Imperative::AGInfo& info = Imperative::AGInfo::Get(entry_.node);return info.out_grads[0];
這里Imperative::AGInfo::Get的源碼文件為?include/mxnet/imperative.h
return dmlc::get<AGInfo>(node->info);
這里get的源碼文件為3rdparty/dmlc-core/include/dmlc/any.h
return *any::TypeInfo<T>::get_ptr(&(src.data_));
這個(gè)get_ptr調(diào)用的是同文件中的如下代碼:
template<typename T>
class any::TypeOnHeap {public:inline static T* get_ptr(any::Data* data) {return static_cast<T*>(data->pheap);}
回到上面的代碼,那個(gè)entry_是NDArrary類的一個(gè)對(duì)象:
/*! \brief node entry for autograd */nnvm::NodeEntry entry_;
NodeEntry 源碼文件為include/nnvm/node.h,
#大體來(lái)講,梯度就是arr->entry_.node->info.data_.pheap;
總結(jié)
以上是生活随笔為你收集整理的MXNet中x.grad源码追溯的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: MXNET源码中TShape值的获取和打
- 下一篇: mxnet 中的 DepthwiseCo