Tensorflow源码解析6 -- TensorFlow本地运行时
1 概述
TensorFlow后端分為四層,運(yùn)行時(shí)層、計(jì)算層、通信層、設(shè)備層。運(yùn)行時(shí)作為第一層,實(shí)現(xiàn)了session管理、graph管理等很多重要的邏輯,是十分關(guān)鍵的一層。根據(jù)任務(wù)分布的不同,運(yùn)行時(shí)又分為本地運(yùn)行時(shí)和分布式運(yùn)行時(shí)。本地運(yùn)行時(shí),所有任務(wù)運(yùn)行于本地同一進(jìn)程內(nèi)。而分布式運(yùn)行時(shí),則允許任務(wù)運(yùn)行在不同機(jī)器上。
Tensorflow的運(yùn)行,通過(guò)session搭建了前后端溝通的橋梁,前端幾乎所有操作都是通過(guò)session進(jìn)行。session的生命周期由創(chuàng)建、運(yùn)行、關(guān)閉、銷(xiāo)毀組成,前文已經(jīng)詳細(xì)講述過(guò)。可以將session看做TensorFlow運(yùn)行的載體。而TensorFlow運(yùn)行的核心對(duì)象,則是計(jì)算圖Graph。它由計(jì)算算子和計(jì)算數(shù)據(jù)兩部分構(gòu)成,可以完整描述整個(gè)計(jì)算內(nèi)容。Graph的生命周期包括構(gòu)建和傳遞、剪枝、分裂、執(zhí)行等步驟,本文會(huì)詳細(xì)講解。理解TensorFlow的運(yùn)行時(shí),重點(diǎn)就是理解會(huì)話session和計(jì)算圖Graph。
本地運(yùn)行時(shí),client master和worker都在本地機(jī)器的同一進(jìn)程內(nèi),均通過(guò)DirectSession類(lèi)來(lái)描述。由于在同一進(jìn)程內(nèi),三者間可以共享內(nèi)存,通過(guò)DirectSession的相關(guān)函數(shù)實(shí)現(xiàn)調(diào)用。
client前端直接面向用戶,負(fù)責(zé)session的創(chuàng)建,計(jì)算圖Graph的構(gòu)造。并通過(guò)session.run()將Graph序列化后傳遞給master。master收到后,先反序列化得到Graph,然后根據(jù)反向依賴(lài)關(guān)系,得到幾個(gè)最小依賴(lài)子圖,這一步稱(chēng)為剪枝。之后master根據(jù)可運(yùn)行的設(shè)備情況,將子圖分裂到不同設(shè)備上,從而可以并發(fā)執(zhí)行,這一步稱(chēng)為分裂。最后,由每個(gè)設(shè)備上的worker并行執(zhí)行分裂后的子圖,得到計(jì)算結(jié)果后返回。
2 Graph構(gòu)建和傳遞
session.run()開(kāi)啟了后端Graph的構(gòu)建和傳遞。在前文session生命周期的講解中,session.run()時(shí)會(huì)先調(diào)用_extend_graph()將要運(yùn)行的Operation添加到Graph中,然后再啟動(dòng)運(yùn)行過(guò)程。extend_graph()會(huì)先將graph序列化,得到graph_def,然后調(diào)用后端的TF_ExtendGraph()方法。下面我們從c_api.cc中的TF_ExtendGraph()看起。
// 增加節(jié)點(diǎn)到graph中,proto為序列化后的graph void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,size_t proto_len, TF_Status* status) {GraphDef g;// 先將proto轉(zhuǎn)換為GrapDef。graphDef是圖的序列化表示,反序列化在后面。if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {status->status = InvalidArgument("Invalid GraphDef");return;}// 再調(diào)用session的extend方法。根據(jù)創(chuàng)建的不同session類(lèi)型,多態(tài)調(diào)用不同方法。status->status = s->session->Extend(g); }后端系統(tǒng)根據(jù)生成的Session類(lèi)型,多態(tài)的調(diào)用Extend方法。如果是本地session,則調(diào)用DirectSession的Extend()方法。下面看DirectSession的Extend()方法。
Status DirectSession::Extend(const GraphDef& graph) {// 保證線程安全,然后調(diào)用ExtendLocked()mutex_lock l(graph_def_lock_);return ExtendLocked(graph); }// 主要任務(wù)就是創(chuàng)建GraphExecutionState對(duì)象。 Status DirectSession::ExtendLocked(const GraphDef& graph) {bool already_initialized;if (already_initialized) {TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));// 創(chuàng)建GraphExecutionStatestd::unique_ptr<GraphExecutionState> state;TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));execution_state_.swap(state);}return Status::OK(); }最終創(chuàng)建了GraphExecutionState對(duì)象。它主要工作有
構(gòu)造Graph:反序列化GraphDef為Graph
由于client傳遞給master的是序列化后的計(jì)算圖,所以master需要先反序列化。通過(guò)ConvertGraphDefToGraph實(shí)現(xiàn)。代碼在graph_constructor.cc中,如下
Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,const GraphDef& gdef, Graph* g) {ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner,/*return_tensors=*/nullptr, /*return_nodes=*/nullptr,/*missing_unused_input_map_keys=*/nullptr); }編排OP
Operation編排的目的是,將op以最高效的方式,放在合適的硬件設(shè)備上,從而最大限度的發(fā)揮硬件能力。通過(guò)Placer的run()方法進(jìn)行,算法很復(fù)雜,在placer.cc中,我也看得不大懂,就不展開(kāi)了。
3 Graph剪枝
反序列化構(gòu)建好Graph,并進(jìn)行了Operation編排后,master就開(kāi)始對(duì)Graph剪枝了。剪枝就是根據(jù)Graph的輸入輸出列表,反向遍歷全圖,找到幾個(gè)最小依賴(lài)的子圖,從而方便并行計(jì)算。
Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,std::unique_ptr<ClientGraph>* out) {std::unique_ptr<Graph> ng;Status s = OptimizeGraph(options, &ng);if (!s.ok()) {// 1 復(fù)制一份原始的Graphng.reset(new Graph(flib_def_.get()));CopyGraph(*graph_, ng.get());}// 2 剪枝,根據(jù)輸入輸出feed fetch,對(duì)graph進(jìn)行增加節(jié)點(diǎn)或刪除節(jié)點(diǎn)等操作。通過(guò)RewriteGraphForExecution()方法subgraph::RewriteGraphMetadata rewrite_metadata;if (session_options_ == nullptr ||!session_options_->config.graph_options().place_pruned_graph()) {TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(ng.get(), options.feed_endpoints, options.fetch_endpoints,options.target_nodes, device_set_->client_device()->attributes(),options.use_function_convention, &rewrite_metadata));}// 3 處理優(yōu)化選項(xiàng)optimization_optionsGraphOptimizationPassOptions optimization_options;optimization_options.session_options = session_options_;optimization_options.graph = &ng;optimization_options.flib_def = flib.get();optimization_options.device_set = device_set_;TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));// 4 復(fù)制一份ClientGraphstd::unique_ptr<ClientGraph> dense_copy(new ClientGraph(std::move(flib), rewrite_metadata.feed_types,rewrite_metadata.fetch_types));CopyGraph(*ng, &dense_copy->graph);*out = std::move(dense_copy);return Status::OK(); }剪枝的關(guān)鍵在RewriteGraphForExecution()方法中,在subgraph.cc文件中。
Status RewriteGraphForExecution(Graph* g, const gtl::ArraySlice<string>& fed_outputs,const gtl::ArraySlice<string>& fetch_outputs,const gtl::ArraySlice<string>& target_node_names,const DeviceAttributes& device_info, bool use_function_convention,RewriteGraphMetadata* out_metadata) {std::unordered_set<string> endpoints;// 1 構(gòu)建節(jié)點(diǎn)的name_index,從而快速索引節(jié)點(diǎn)。為FeedInputs,FetchOutputs等步驟所使用NameIndex name_index;name_index.reserve(g->num_nodes());for (Node* n : g->nodes()) {name_index[n->name()] = n;}// 2 FeedInputs,添加輸入節(jié)點(diǎn)if (!fed_outputs.empty()) {FeedInputs(g, device_info, fed_outputs, use_function_convention, &name_index, &out_metadata->feed_types);}// 3 FetchOutputs,添加輸出節(jié)點(diǎn)std::vector<Node*> fetch_nodes;if (!fetch_outputs.empty()) {FetchOutputs(g, device_info, fetch_outputs, use_function_convention, &name_index, &fetch_nodes, &out_metadata->fetch_types);}// 4 剪枝,形成若干最小依賴(lài)子圖if (!fetch_nodes.empty() || !target_node_names.empty()) {PruneForTargets(g, name_index, fetch_nodes, target_node_names);}return Status::OK(); }主要有4步
PruneForTargets()從輸出節(jié)點(diǎn)反向搜索,按照BFS廣度優(yōu)先算法,找到若干個(gè)最小依賴(lài)子圖。
static Status PruneForTargets(Graph* g, const subgraph::NameIndex& name_index,const std::vector<Node*>& fetch_nodes,const gtl::ArraySlice<string>& target_nodes) {string not_found;std::unordered_set<const Node*> targets;// 1 AddNodeToTargets添加節(jié)點(diǎn)到targets中,從輸出節(jié)點(diǎn)按照BFS反向遍歷。for (Node* n : fetch_nodes) {AddNodeToTargets(n->name(), name_index, &targets);}// 2 剪枝,得到多個(gè)最小依賴(lài)子圖子圖PruneForReverseReachability(g, targets);// 修正Source和Sink節(jié)點(diǎn)的依賴(lài)邊,將沒(méi)有輸出邊的節(jié)點(diǎn)連接到sink node上FixupSourceAndSinkEdges(g);return Status::OK(); }主要有3步
PruneForReverseReachability()在algorithm.cc文件中,算法就不分析了,總體是按照BFS廣度優(yōu)先算法搜索的。
bool PruneForReverseReachability(Graph* g,std::unordered_set<const Node*> visited) {// 按照BFS廣度優(yōu)先算法,從輸出節(jié)點(diǎn)開(kāi)始,反向搜索節(jié)點(diǎn)的依賴(lài)關(guān)系std::deque<const Node*> queue;for (const Node* n : visited) {queue.push_back(n);}while (!queue.empty()) {const Node* n = queue.front();queue.pop_front();for (const Node* in : n->in_nodes()) {if (visited.insert(in).second) {queue.push_back(in);}}}// 刪除不在"visited"列表中的節(jié)點(diǎn),說(shuō)明最小依賴(lài)子圖不依賴(lài)此節(jié)點(diǎn)std::vector<Node*> all_nodes;all_nodes.reserve(g->num_nodes());for (Node* n : g->nodes()) {all_nodes.push_back(n);}bool any_removed = false;for (Node* n : all_nodes) {if (visited.count(n) == 0 && !n->IsSource() && !n->IsSink()) {g->RemoveNode(n);any_removed = true;}}return any_removed; }4 Graph分裂
剪枝完成后,master即得到了最小依賴(lài)子圖ClientGraph。然后根據(jù)本地機(jī)器的硬件設(shè)備,以及op所指定的運(yùn)行設(shè)備等關(guān)系,將圖分裂為多個(gè)Partition Graph,傳遞到相關(guān)設(shè)備的worker上,從而進(jìn)行并行運(yùn)算。這就是Graph的分裂。
Graph分裂的算法在graph_partition.cc的Partition()方法中。算法比較復(fù)雜,我們就不分析了。圖分裂有兩種
splitbydevice按設(shè)備分裂,也就是將Graph分裂到本地各CPU GPU上。本地運(yùn)行時(shí)只使用按設(shè)備分裂。
static string SplitByDevice(const Node* node) {return node->assigned_device_name(); }splitByWorker 按worker分裂, 也就是將Graph分裂到各分布式任務(wù)上,常用于分布式運(yùn)行時(shí)。分布式運(yùn)行時(shí),圖會(huì)經(jīng)歷兩次分裂。先splitByWorker分裂到各分布式任務(wù)上,一般是各分布式機(jī)器。然后splitbydevice二次分裂到分布式機(jī)器的CPU GPU等設(shè)備上。
static string SplitByWorker(const Node* node) {string task;string device;DeviceNameUtils::SplitDeviceName(node->assigned_device_name(), &task, &device);return task; }5 Graph執(zhí)行
Graph經(jīng)過(guò)master剪枝和分裂后,就可以在本地的各CPU GPU設(shè)備上執(zhí)行了。這個(gè)過(guò)程的管理者叫worker。一般一個(gè)worker對(duì)應(yīng)一個(gè)分裂后的子圖partitionGraph。每個(gè)worker啟動(dòng)一個(gè)執(zhí)行器Executor,入度為0的節(jié)點(diǎn)數(shù)據(jù)依賴(lài)已經(jīng)ready了,故可以并行執(zhí)行。等所有Executor執(zhí)行完畢后,通知執(zhí)行完畢。
各CPU GPU設(shè)備間可能需要數(shù)據(jù)通信,通過(guò)創(chuàng)建send/recv節(jié)點(diǎn)來(lái)解決。數(shù)據(jù)發(fā)送方創(chuàng)建send節(jié)點(diǎn),將數(shù)據(jù)放在send節(jié)點(diǎn)內(nèi),不阻塞。數(shù)據(jù)接收方創(chuàng)建recv節(jié)點(diǎn),從recv節(jié)點(diǎn)中取出數(shù)據(jù),recv節(jié)點(diǎn)中如果沒(méi)有數(shù)據(jù)則阻塞。這又是一個(gè)典型的生產(chǎn)者-消費(fèi)者關(guān)系。
Graph執(zhí)行的代碼邏輯在direct_session.cc文件的DirectSession::Run()方法中。代碼邏輯很長(zhǎng),我們抽取其中的關(guān)鍵部分。
Status DirectSession::Run(const RunOptions& run_options,const NamedTensorList& inputs,const std::vector<string>& output_names,const std::vector<string>& target_nodes,std::vector<Tensor>* outputs,RunMetadata* run_metadata) {// 1 將輸入tensor的name取出,組成一個(gè)列表,方便之后快速索引輸入tensorstd::vector<string> input_tensor_names;input_tensor_names.reserve(inputs.size());for (const auto& it : inputs) {input_tensor_names.push_back(it.first);}// 2 傳遞輸入數(shù)據(jù)給executor,通過(guò)FunctionCallFrame方式。// 2.1 創(chuàng)建FunctionCallFrame,用來(lái)輸入數(shù)據(jù)給executor,并從executor中取出數(shù)據(jù)。FunctionCallFrame call_frame(executors_and_keys->input_types,executors_and_keys->output_types);// 2.2 構(gòu)造輸入數(shù)據(jù)feed_argsgtl::InlinedVector<Tensor, 4> feed_args(inputs.size());for (const auto& it : inputs) {if (it.second.dtype() == DT_RESOURCE) {Tensor tensor_from_handle;ResourceHandleToInputTensor(it.second, &tensor_from_handle);feed_args[executors_and_keys->input_name_to_index[it.first]] = tensor_from_handle;} else {feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;}}// 2.3 將feed_args輸入數(shù)據(jù)設(shè)置到Arg節(jié)點(diǎn)上const Status s = call_frame.SetArgs(feed_args);// 3 開(kāi)始執(zhí)行executor// 3.1 創(chuàng)建run_state, 和IntraProcessRendezvousRunState run_state(args.step_id, &devices_);run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());CancellationManager step_cancellation_manager;args.call_frame = &call_frame;// 3.2 創(chuàng)建ExecutorBarrier,它是一個(gè)執(zhí)行完成的計(jì)數(shù)器。同時(shí)注冊(cè)執(zhí)行完成的監(jiān)聽(tīng)事件executors_done.Notify()const size_t num_executors = executors_and_keys->items.size();ExecutorBarrier* barrier = new ExecutorBarrier(num_executors, run_state.rendez, [&run_state](const Status& ret) {{mutex_lock l(run_state.mu_);run_state.status.Update(ret);}// 所有線程池計(jì)算完畢后,會(huì)觸發(fā)Notify,發(fā)送消息。run_state.executors_done.Notify();});args.rendezvous = run_state.rendez;args.cancellation_manager = &step_cancellation_manager;args.session_state = &session_state_;args.tensor_store = &run_state.tensor_store;args.step_container = &run_state.step_container;args.sync_on_finish = sync_on_finish_;// 3.3 創(chuàng)建executor的運(yùn)行器RunnerExecutor::Args::Runner default_runner = [this,pool](Executor::Args::Closure c) {SchedClosure(pool, std::move(c));};// 3.4 依次啟動(dòng)所有executor,開(kāi)始運(yùn)行for (const auto& item : executors_and_keys->items) {item.executor->RunAsync(args, barrier->Get());}// 3.5 阻塞,收到所有executor執(zhí)行完畢的通知WaitForNotification(&run_state, &step_cancellation_manager, operation_timeout_in_ms_);// 4 接收?qǐng)?zhí)行器執(zhí)行完畢的輸出值if (outputs) {// 4.1 從RetVal節(jié)點(diǎn)中得到輸出值sorted_outputsstd::vector<Tensor> sorted_outputs;const Status s = call_frame.ConsumeRetvals(&sorted_outputs);// 4.2 處理原始輸出sorted_outputs,保存到最終的輸出outputs中outputs->clear();outputs->reserve(sorted_outputs.size());for (int i = 0; i < output_names.size(); ++i) {const string& output_name = output_names[i];if (first_indices.empty() || first_indices[i] == i) {outputs->emplace_back(std::move(sorted_outputs[executors_and_keys->output_name_to_index[output_name]]));} else {outputs->push_back((*outputs)[first_indices[i]]);}}}// 5 保存輸出的tensorrun_state.tensor_store.SaveTensors(output_names, &session_state_));return Status::OK(); }主要步驟如下
6 總結(jié)
本文主要講解了TensorFlow的本地運(yùn)行時(shí),牢牢抓住session和graph兩個(gè)對(duì)象即可。Session的生命周期前文講解過(guò),本文主要講解了Graph的生命周期,包括構(gòu)建與傳遞,剪枝,分裂和執(zhí)行。Graph是TensorFlow的核心對(duì)象,很多問(wèn)題都是圍繞它來(lái)進(jìn)行的,理解它有一定難度,但十分關(guān)鍵。文章中可能有一些理解不正確的地方,希望小伙伴們不吝賜教。
總結(jié)
以上是生活随笔為你收集整理的Tensorflow源码解析6 -- TensorFlow本地运行时的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 资产的特征 资产有哪些特征
- 下一篇: excel怎么把单元格内某个字标红,其他