Tensorflow源码解析2 -- 前后端连接的桥梁 - Session
1 Session概述
Session是TensorFlow前后端連接的橋梁。用戶利用session使得client能夠與master的執(zhí)行引擎建立連接,并通過session.run()來觸發(fā)一次計算。它建立了一套上下文環(huán)境,封裝了operation計算以及tensor求值的環(huán)境。
session創(chuàng)建時,系統(tǒng)會分配一些資源,比如graph引用、要連接的計算引擎的名稱等。故計算完畢后,需要使用session.close()關(guān)閉session,避免引起內(nèi)存泄漏,特別是graph無法釋放的問題。可以顯式調(diào)用session.close(),或利用with上下文管理器,或者直接使用InteractiveSession。
session之間采用共享graph的方式來提高運行效率。一個session只能運行一個graph實例,但一個graph可以運行在多個session中。一般情況下,創(chuàng)建session時如果不指定Graph實例,則會使用系統(tǒng)默認(rèn)Graph。常見情況下,我們都是使用一個graph,即默認(rèn)graph。當(dāng)session創(chuàng)建時,不會重新創(chuàng)建graph實例,而是默認(rèn)graph引用計數(shù)加1。當(dāng)session close時,引用計數(shù)減1。只有引用計數(shù)為0時,graph才會被回收。這種graph共享的方式,大大減少了graph創(chuàng)建和回收的資源消耗,優(yōu)化了TensorFlow運行效率。
?
2 默認(rèn)session
op運算和tensor求值時,如果沒有指定運行在哪個session中,則會運行在默認(rèn)session中。通過session.as_default()可以將自己設(shè)置為默認(rèn)session。但個人建議最好還是通過session.run(operator)和session.run(tensor)來進(jìn)行op運算和tensor求值。
operation.run()
operation.run()等價于tf.get_default_session().run(operation)
@tf_export("Operation") class Operation(object):# 通過operation.run()調(diào)用,進(jìn)行operation計算def run(self, feed_dict=None, session=None):_run_using_default_session(self, feed_dict, self.graph, session)def _run_using_default_session(operation, feed_dict, graph, session=None):# 沒有指定session,則獲取默認(rèn)sessionif session is None:session = get_default_session()# 最終還是通過session.run()進(jìn)行運行的。tf中任何運算,都是通過session來run的。# 通過session來建立client和master的連接,并將graph發(fā)送給master,master再進(jìn)行執(zhí)行session.run(operation, feed_dict)tensor.eval()
tensor.eval()等價于tf.get_default_session().run(tensor), 如下
@tf_export("Tensor") class Tensor(_TensorLike):# 通過tensor.eval()調(diào)用,進(jìn)行tensor運算def eval(self, feed_dict=None, session=None):return _eval_using_default_session(self, feed_dict, self.graph, session)def _eval_using_default_session(tensors, feed_dict, graph, session=None):# 如果沒有指定session,則獲取默認(rèn)sessionif session is None:session = get_default_session()return session.run(tensors, feed_dict)默認(rèn)session的管理
tf通過運行時維護(hù)的session本地線程棧,來管理默認(rèn)session。故不同的線程會有不同的默認(rèn)session,默認(rèn)session是線程作用域的。
# session棧 _default_session_stack = _DefaultStack()# 獲取默認(rèn)session的接口 @tf_export("get_default_session") def get_default_session():return _default_session_stack.get_default()# _DefaultStack默認(rèn)session棧是線程相關(guān)的 class _DefaultStack(threading.local):# 默認(rèn)session棧的創(chuàng)建,其實就是一個listdef __init__(self):super(_DefaultStack, self).__init__()self._enforce_nesting = Trueself.stack = [] # 獲取默認(rèn)sessiondef get_default(self):return self.stack[-1] if len(self.stack) >= 1 else None?
3 前端Session類型
session類圖
會話Session的UML類圖如下
分為兩種類型,普通Session和交互式InteractiveSession。InteractiveSession和Session基本相同,區(qū)別在于
Session和InteractiveSession的代碼邏輯不多,主要邏輯均在其父類BaseSession中。主要代碼如下
@tf_export('Session') class Session(BaseSession):def __init__(self, target='', graph=None, config=None):# session創(chuàng)建的主要邏輯都在其父類BaseSession中super(Session, self).__init__(target, graph, config=config)self._default_graph_context_manager = Noneself._default_session_context_manager = None @tf_export('InteractiveSession') class InteractiveSession(BaseSession):def __init__(self, target='', graph=None, config=None):self._explicitly_closed = False# 將自己設(shè)置為default sessionself._default_session = self.as_default()self._default_session.enforce_nesting = False# 自動調(diào)用上下文管理器的__enter__()方法self._default_session.__enter__()self._explicit_graph = graphdef close(self):super(InteractiveSession, self).close()## 省略無關(guān)代碼## 自動調(diào)用上下文管理器的__exit__()方法,避免內(nèi)存泄漏self._default_session.__exit__(None, None, None)self._default_session = NoneBaseSession
BaseSession基本包含了所有的會話實現(xiàn)邏輯。包括會話的整個生命周期,也就是創(chuàng)建 執(zhí)行 關(guān)閉和銷毀四個階段。生命周期后面詳細(xì)分析。BaseSession包含的主要成員變量有g(shù)raph引用,序列化的graph_def, 要連接的tf引擎target,session配置信息config等。
?
4 后端Session類型
在后端master中,根據(jù)前端client調(diào)用tf.Session(target='', graph=None, config=None)時指定的target,來創(chuàng)建不同的Session。target為要連接的tf后端執(zhí)行引擎,默認(rèn)為空字符串。Session創(chuàng)建采用了抽象工廠模式,如果為空字符串,則創(chuàng)建本地DirectSession,如果以grpc://開頭,則創(chuàng)建分布式GrpcSession。類圖如下
DirectSession只能利用本地設(shè)備,將任務(wù)創(chuàng)建到本地的CPU GPU上。而GrpcSession則可以利用遠(yuǎn)端分布式設(shè)備,將任務(wù)創(chuàng)建到其他機器的CPU GPU上,然后通過grpc協(xié)議進(jìn)行通信。grpc協(xié)議是谷歌發(fā)明并開源的遠(yuǎn)程通信協(xié)議。
?
5 Session生命周期
Session作為前后端連接的橋梁,以及上下文運行環(huán)境,其生命周期尤其關(guān)鍵。大致分為4個階段
生命周期方法入口基本都在前端Python的BaseSession中,它會通過swig自動生成的函數(shù)符號映射關(guān)系,調(diào)用C層的實現(xiàn)。
5.1 創(chuàng)建
先從BaseSession類的init方法看起,只保留了主要代碼。
def __init__(self, target='', graph=None, config=None):# graph表示構(gòu)建的圖。TensorFlow的一個session會對應(yīng)一個圖。這個圖包含了所有涉及到的算子# graph如果沒有設(shè)置(通常都不會設(shè)置),則使用默認(rèn)graphif graph is None:self._graph = ops.get_default_graph()else:self._graph = graphself._opened = Falseself._closed = Falseself._current_version = 0self._extend_lock = threading.Lock()# target為要連接的tf執(zhí)行引擎if target is not None:self._target = compat.as_bytes(target)else:self._target = Noneself._delete_lock = threading.Lock()self._dead_handles = []# config為session的配置信息if config is not None:self._config = configself._add_shapes = config.graph_options.infer_shapeselse:self._config = Noneself._add_shapes = Falseself._created_with_new_api = ops._USE_C_API# 調(diào)用C層來創(chuàng)建sessionself._session = Noneopts = tf_session.TF_NewSessionOptions(target=self._target, config=config)self._session = tf_session.TF_NewSession(self._graph._c_graph, opts, status)BaseSession先進(jìn)行成員變量的賦值,然后調(diào)用TF_NewSession來創(chuàng)建session。TF_NewSession()方法由swig自動生成,在bazel-bin/tensorflow/python/pywrap_tensorflow_internal.py中
def TF_NewSession(graph, opts, status):return _pywrap_tensorflow_internal.TF_NewSession(graph, opts, status)_pywrap_tensorflow_internal包含了C層函數(shù)的符號表。在swig模塊import時,會加載pywrap_tensorflow_internal.so動態(tài)鏈接庫,從而得到符號表。在pywrap_tensorflow_internal.cc中,注冊了供Python調(diào)用的函數(shù)的符號表,從而實現(xiàn)Python到C的函數(shù)映射和調(diào)用。
// c++函數(shù)調(diào)用的符號表,Python通過它可以調(diào)用到C層代碼。符號表和動態(tài)鏈接庫由swig自動生成 static PyMethodDef SwigMethods[] = {// .. 省略其他函數(shù)定義// TF_NewSession的符號表,通過這個映射,Python中就可以調(diào)用到C層代碼了。{ (char *)"TF_NewSession", _wrap_TF_NewSession, METH_VARARGS, NULL},// ... 省略其他函數(shù)定義 }最終調(diào)用到c_api.c中的TF_NewSession()
// TF_NewSession創(chuàng)建session的新實現(xiàn),在C層后端代碼中 TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,TF_Status* status) {Session* session;// 創(chuàng)建sessionstatus->status = NewSession(opt->options, &session);if (status->status.ok()) {TF_Session* new_session = new TF_Session(session, graph);if (graph != nullptr) {// 采用了引用計數(shù)方式,多個session共享一個圖實例,效率更高。// session創(chuàng)建時,引用計數(shù)加1。session close時引用計數(shù)減1。引用計數(shù)為0時,graph才會被回收。mutex_lock l(graph->mu);graph->sessions[new_session] = Status::OK();}return new_session;} else {DCHECK_EQ(nullptr, session);return nullptr;} }session創(chuàng)建時,并創(chuàng)建graph,而是采用共享方式,只是引用計數(shù)加1了。這種方式減少了session創(chuàng)建和關(guān)閉時的資源消耗,提高了運行效率。NewSession()根據(jù)前端傳遞的target,使用sessionFactory創(chuàng)建對應(yīng)的TensorFlow::Session實例。
Status NewSession(const SessionOptions& options, Session** out_session) {SessionFactory* factory;const Status s = SessionFactory::GetFactory(options, &factory);// 通過sessionFactory創(chuàng)建多態(tài)的Session。本地session為DirectSession,分布式為GRPCSession*out_session = factory->NewSession(options);if (!*out_session) {return errors::Internal("Failed to create session.");}return Status::OK(); }創(chuàng)建session采用了抽象工廠模式。根據(jù)client傳遞的target,來創(chuàng)建不同的session。如果target為空字符串,則創(chuàng)建本地DirectSession。如果以grpc://開頭,則創(chuàng)建分布式GrpcSession。TensorFlow包含本地運行時和分布式運行時兩種運行模式。
下面來看DirectSessionFactory的NewSession()方法
class DirectSessionFactory : public SessionFactory {public:Session* NewSession(const SessionOptions& options) override {std::vector<Device*> devices;// job在本地執(zhí)行const Status s = DeviceFactory::AddDevices(options, "/job:localhost/replica:0/task:0", &devices);if (!s.ok()) {LOG(ERROR) << s;return nullptr;}DirectSession* session =new DirectSession(options, new DeviceMgr(devices), this);{mutex_lock l(sessions_lock_);sessions_.push_back(session);}return session;}GrpcSessionFactory的NewSession()方法就不詳細(xì)分析了,它會將job任務(wù)創(chuàng)建在分布式設(shè)備上,各job通過grpc協(xié)議通信。
5.2 運行
通過session.run()可以啟動graph的執(zhí)行。入口在BaseSession的run()方法中, 同樣只列出關(guān)鍵代碼
class BaseSession(SessionInterface):def run(self, fetches, feed_dict=None, options=None, run_metadata=None):# fetches可以為單個變量,或者數(shù)組,或者元組。它是圖的一部分,可以是操作operation,也可以是數(shù)據(jù)tensor,或者他們的名字String# feed_dict為對應(yīng)placeholder的實際訓(xùn)練數(shù)據(jù),它的類型為字典result = self._run(None, fetches, feed_dict, options_ptr,run_metadata_ptr)return resultdef _run(self, handle, fetches, feed_dict, options, run_metadata):# 創(chuàng)建fetch處理器fetch_handlerfetch_handler = _FetchHandler(self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)# 經(jīng)過不同類型的fetch_handler處理,得到最終的fetches和targets# targets為要執(zhí)行的operation,fetches為要執(zhí)行的tensor_ = self._update_with_movers(feed_dict_tensor, feed_map)final_fetches = fetch_handler.fetches()final_targets = fetch_handler.targets()# 開始運行if final_fetches or final_targets or (handle and feed_dict_tensor):results = self._do_run(handle, final_targets, final_fetches,feed_dict_tensor, options, run_metadata)else:results = []# 輸出結(jié)果到results中return fetch_handler.build_results(self, results)def _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata):# 將要運行的operation添加到graph中self._extend_graph()# 執(zhí)行一次運行run,會調(diào)用底層C來實現(xiàn)return tf_session.TF_SessionPRunSetup_wrapper(session, feed_list, fetch_list, target_list, status)# 將要運行的operation添加到graph中def _extend_graph(self):with self._extend_lock:if self._graph.version > self._current_version:# 生成graph_def對象,它是graph的序列化表示graph_def, self._current_version = self._graph._as_graph_def(from_version=self._current_version, add_shapes=self._add_shapes)# 通過TF_ExtendGraph將序列化后的graph,也就是graph_def傳遞給后端with errors.raise_exception_on_not_ok_status() as status:tf_session.TF_ExtendGraph(self._session,graph_def.SerializeToString(), status)self._opened = True邏輯還是十分復(fù)雜的,主要有一下幾步
我們分別來看extend和run。
5.2.1 extend添加節(jié)點到graph中
TF_ExtendGraph()會調(diào)用到c_api中,這個邏輯同樣通過swig工具自動生成。下面看c_api.cc中的TF_ExtendGraph()方法
// 增加節(jié)點到graph中,proto為序列化后的graph void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,size_t proto_len, TF_Status* status) {GraphDef g;// 先將proto反序列化,得到client傳遞的graph,放入g中if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {status->status = InvalidArgument("Invalid GraphDef");return;}// 再調(diào)用session的extend方法。根據(jù)創(chuàng)建的不同session類型,多態(tài)調(diào)用不同方法。status->status = s->session->Extend(g); }后端系統(tǒng)根據(jù)生成的Session類型,多態(tài)的調(diào)用Extend方法。如果是本地session,則調(diào)用DirectSession的Extend()方法。如果是分布式session,則調(diào)用GrpcSession的相關(guān)方法。下面來看GrpcSession的Extend方法。
Status GrpcSession::Extend(const GraphDef& graph) {CallOptions call_options;call_options.SetTimeout(options_.config.operation_timeout_in_ms());return ExtendImpl(&call_options, graph); }Status GrpcSession::ExtendImpl(CallOptions* call_options,const GraphDef& graph) {bool handle_is_empty;{mutex_lock l(mu_);handle_is_empty = handle_.empty();}if (handle_is_empty) {// 如果graph句柄為空,則表明graph還沒有創(chuàng)建好,此時extend就等同于createreturn Create(graph);}mutex_lock l(mu_);ExtendSessionRequest req;req.set_session_handle(handle_);*req.mutable_graph_def() = graph;req.set_current_graph_version(current_graph_version_);ExtendSessionResponse resp;// 調(diào)用底層實現(xiàn),來添加節(jié)點到graph中Status s = master_->ExtendSession(call_options, &req, &resp);if (s.ok()) {current_graph_version_ = resp.new_graph_version();}return s; }Extend()方法中要注意的一點是,如果是首次執(zhí)行Extend(), 則要先調(diào)用Create()方法進(jìn)行g(shù)raph的注冊。否則才是執(zhí)行添加節(jié)點到graph中。
5.2.2 run執(zhí)行圖的計算
同樣,Python通過swig自動生成的代碼,來實現(xiàn)對C API的調(diào)用。C層實現(xiàn)在c_api.cc的TF_Run()中。
// session.run()的C層實現(xiàn) void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,// Input tensors,輸入的數(shù)據(jù)tensorconst char** c_input_names, TF_Tensor** c_inputs, int ninputs,// Output tensors,運行計算后輸出的數(shù)據(jù)tensorconst char** c_output_names, TF_Tensor** c_outputs, int noutputs,// Target nodes,要運行的節(jié)點const char** c_target_oper_names, int ntargets,TF_Buffer* run_metadata, TF_Status* status) {// 省略一段代碼TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names,c_outputs, target_oper_names, run_metadata, status); }// 真正的實現(xiàn)了session.run() static void TF_Run_Helper() {RunMetadata run_metadata_proto;// 調(diào)用不同的session實現(xiàn)類的run方法,來執(zhí)行result = session->Run(run_options_proto, input_pairs, output_tensor_names,target_oper_names, &outputs, &run_metadata_proto);// 省略代碼 }最終會調(diào)用創(chuàng)建的session來執(zhí)行run方法。DirectSession和GrpcSession的Run()方法會有所不同。后面很復(fù)雜,就不接著分析了。
5.3 關(guān)閉session
通過session.close()來關(guān)閉session,釋放相關(guān)資源,防止內(nèi)存泄漏。
class BaseSession(SessionInterface):def close(self):tf_session.TF_CloseSession(self._session, status)會調(diào)用到C API的TF_CloseSession()方法。
void TF_CloseSession(TF_Session* s, TF_Status* status) {status->status = s->session->Close(); }最終根據(jù)創(chuàng)建的session,多態(tài)的調(diào)用其Close()方法。同樣分為DirectSession和GrpcSession兩種。
::tensorflow::Status DirectSession::Close() {cancellation_manager_->StartCancel();{mutex_lock l(closed_lock_);if (closed_) return ::tensorflow::Status::OK();closed_ = true;}// 注銷sessionif (factory_ != nullptr) factory_->Deregister(this);return ::tensorflow::Status::OK(); }DirectSessionFactory中的Deregister()方法如下
void Deregister(const DirectSession* session) {mutex_lock l(sessions_lock_);// 釋放相關(guān)資源sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),sessions_.end());}5.4 銷毀session
session的銷毀是由Python的GC自動執(zhí)行的。python通過引用計數(shù)方法來判斷是否回收對象。當(dāng)對象的引用計數(shù)為0,且虛擬機觸發(fā)了GC時,會調(diào)用對象的__del__()方法來銷毀對象。引用計數(shù)法有個很致命的問題,就是無法解決循環(huán)引用問題,故會存在內(nèi)存泄漏。Java虛擬機采用了調(diào)用鏈分析的方式來決定哪些對象會被回收。
class BaseSession(SessionInterface): def __del__(self):# 先close,防止用戶沒有調(diào)用close()try:self.close()# 再調(diào)用c api的TF_DeleteSession來銷毀sessionif self._session is not None:try:status = c_api_util.ScopedTFStatus()if self._created_with_new_api:tf_session.TF_DeleteSession(self._session, status)c_api.cc中的相關(guān)邏輯如下
void TF_DeleteSession(TF_Session* s, TF_Status* status) {status->status = Status::OK();TF_Graph* const graph = s->graph;if (graph != nullptr) {graph->mu.lock();graph->sessions.erase(s);// 如果graph的引用計數(shù)為0,也就是graph沒有被任何session持有,則考慮銷毀graph對象const bool del = graph->delete_requested && graph->sessions.empty();graph->mu.unlock();// 銷毀graph對象if (del) delete graph;}// 銷毀session和TF_Session delete s->session;delete s; }TF_DeleteSession()會判斷graph的引用計數(shù)是否為0,如果為0,則會銷毀graph。然后銷毀session和TF_Session對象。通過Session實現(xiàn)類的析構(gòu)函數(shù),來銷毀session,釋放線程池Executor,資源管理器ResourceManager等資源。
DirectSession::~DirectSession() {for (auto& it : partial_runs_) {it.second.reset(nullptr);}// 釋放線程池Executorfor (auto& it : executors_) {it.second.reset();}for (auto d : device_mgr_->ListDevices()) {d->op_segment()->RemoveHold(session_handle_);}// 釋放ResourceManagerfor (auto d : device_mgr_->ListDevices()) {d->ClearResourceMgr();}// 釋放CancellationManager實例functions_.clear();delete cancellation_manager_;// 釋放ThreadPool for (const auto& p_and_owned : thread_pools_) {if (p_and_owned.second) delete p_and_owned.first;}execution_state_.reset(nullptr);flib_def_.reset(nullptr); }?
6 總結(jié)
Session是TensorFlow的client和master連接的橋梁,client任何運算也是通過session來run。它是client端最重要的對象。在Python層和C++層,均有不同的session實現(xiàn)。session生命周期會經(jīng)歷四個階段,create run close和del。四個階段均由Python前端開始,最終調(diào)用到C層后端實現(xiàn)。由此也可以看到,TensorFlow框架的前后端分離和模塊化設(shè)計是多么的精巧。
原文鏈接
本文為云棲社區(qū)原創(chuàng)內(nèi)容,未經(jīng)允許不得轉(zhuǎn)載。
總結(jié)
以上是生活随笔為你收集整理的Tensorflow源码解析2 -- 前后端连接的桥梁 - Session的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Dubbo下一站:Apache顶级项目
- 下一篇: 离职阿里三年后,他又回来了