Tensorflow C++使用ops::BatchMatMul实现特征批量乘法
生活随笔
收集整理的這篇文章主要介紹了
Tensorflow C++使用ops::BatchMatMul实现特征批量乘法
小編覺得挺不錯的，現在分享給大家，幫大家做個參考。
本例主要測試 TensorFlow C++ API 中的 ops::BatchMatMul 算子（對兩個張量按批次逐一做矩陣乘法）。
整體來說這個算子比較簡單，難點在於官網沒有給出示例，TensorFlow 自帶的單元測試也寫得不夠完善。
話不多說，上代碼。
代碼結構如下，共三個文件：conanfile.txt（依賴聲明）、CMakeLists.txt（構建腳本）和 tf_math2_test.cpp（測試代碼）。
conanfile.txt
[requires]
gtest/1.10.0
glog/0.4.0
protobuf/3.9.1
eigen/3.4.0
dataframe/1.20.0
opencv/3.4.17
boost/1.76.0
abseil/20210324.0
xtensor/0.23.10

[generators]
cmake

CMakeLists.txt
# Minimum is 3.8: CMAKE_CXX_STANDARD 17 requires 3.8 and
# pkg_search_module(... IMPORTED_TARGET) requires 3.6. The previous value (3.3)
# understated this, and CMake 4.x rejects any minimum below 3.5 outright.
cmake_minimum_required(VERSION 3.8)
project(test_math_ops)

# Make locally installed pkg-config files (Arrow/Parquet) discoverable.
set(ENV{PKG_CONFIG_PATH} "$ENV{PKG_CONFIG_PATH}:/usr/local/lib/pkgconfig/")
set(CMAKE_CXX_STANDARD 17)
# Always emit debug info. add_definitions() is intended for preprocessor
# definitions; compiler flags belong in add_compile_options().
add_compile_options(-g)

# Dependencies declared in conanfile.txt (Conan "cmake" generator output).
include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()

find_package(TensorflowCC REQUIRED)
find_package(PkgConfig REQUIRED)

pkg_search_module(PKG_PARQUET       REQUIRED IMPORTED_TARGET parquet)
pkg_search_module(PKG_ARROW         REQUIRED IMPORTED_TARGET arrow)
pkg_search_module(PKG_ARROW_COMPUTE REQUIRED IMPORTED_TARGET arrow-compute)
pkg_search_module(PKG_ARROW_CSV     REQUIRED IMPORTED_TARGET arrow-csv)
pkg_search_module(PKG_ARROW_DATASET REQUIRED IMPORTED_TARGET arrow-dataset)
pkg_search_module(PKG_ARROW_FS      REQUIRED IMPORTED_TARGET arrow-filesystem)
pkg_search_module(PKG_ARROW_JSON    REQUIRED IMPORTED_TARGET arrow-json)

set(ARROW_INCLUDE_DIRS
    ${PKG_PARQUET_INCLUDE_DIRS}
    ${PKG_ARROW_INCLUDE_DIRS}
    ${PKG_ARROW_COMPUTE_INCLUDE_DIRS}
    ${PKG_ARROW_CSV_INCLUDE_DIRS}
    ${PKG_ARROW_DATASET_INCLUDE_DIRS}
    ${PKG_ARROW_FS_INCLUDE_DIRS}
    ${PKG_ARROW_JSON_INCLUDE_DIRS})

set(INCLUDE_DIRS
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include
    ${ARROW_INCLUDE_DIRS})

set(ARROW_LIBS
    PkgConfig::PKG_PARQUET
    PkgConfig::PKG_ARROW
    PkgConfig::PKG_ARROW_COMPUTE
    PkgConfig::PKG_ARROW_CSV
    PkgConfig::PKG_ARROW_DATASET
    PkgConfig::PKG_ARROW_FS
    PkgConfig::PKG_ARROW_JSON)

include_directories(${INCLUDE_DIRS})

# Every *.cpp next to this file becomes its own test executable (see loop below).
file(GLOB test_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)

# Shared helper sources, compiled once into a library that every test links.
file(GLOB APP_SOURCES
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/tensor_testutil.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/queue_runner.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/coordinator.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/status.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/death_handler/impl/*.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/df/impl/*.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/arr_/impl/*.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/img_util/impl/*.cpp)

add_library(${PROJECT_NAME}_lib SHARED ${APP_SOURCES})
target_link_libraries(${PROJECT_NAME}_lib PUBLIC
    ${CONAN_LIBS}
    TensorflowCC::TensorflowCC
    ${ARROW_LIBS})

foreach(test_file ${test_file_list})
    file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${test_file})
    # Executable name = source file name without ".cpp".
    string(REPLACE ".cpp" "" test_name ${filename})
    add_executable(${test_name} ${test_file})
    target_link_libraries(${test_name} PUBLIC ${PROJECT_NAME}_lib)
endforeach()

tf_math2_test.cpp
#include <string> #include <vector> #include <glog/logging.h> #include "death_handler/death_handler.h" #include "tf_/tensor_testutil.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/client/client_session.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/cc/training/coordinator.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/error_codes.pb.h" #include "tensorflow/core/protobuf/queue_runner.pb.h" #include "tensorflow/core/public/session.h"using namespace tensorflow;int main(int argc, char** argv) {FLAGS_log_dir = "./";FLAGS_alsologtostderr = true;// 日志級別 INFO, WARNING, ERROR, FATAL 的值分別為0、1、2、3FLAGS_minloglevel = 0;Debug::DeathHandler dh;google::InitGoogleLogging("./logs.log");::testing::InitGoogleTest(&argc, argv);int ret = RUN_ALL_TESTS();return ret; }TEST(TfArthimaticTests, BatchMatMul) {// BatchMatMul 測試// Refers to: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul// 2 * 1 * 2// 2 * 2 * 3// = // 2 * 1 * 3Scope root = Scope::NewRootScope();auto left_ = test::AsTensor<int>({1, 2, 3, 4}, {2, 1, 2});/*** @brief Left param* {{1, 2},* {3, 4}}*/auto right_ = test::AsTensor<int>({1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8}, {2, 2, 3});/*** @brief Right param* {{{1, 2, 3}, {4, 1, 2}},* {{3, 4, 5}, {6, 7, 8}}}*//*** @brief Result* {{9, 4, 7},* {33, 40, 47}}*/auto batch_op = ops::BatchMatMul(root, left_, right_);ClientSession session(root);std::vector<Tensor> outputs;session.Run({batch_op.output}, &outputs);test::PrintTensorValue<int>(std::cout, outputs[0]);test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({9, 4, 7, 33, 40, 47}, {2, 1, 
3})); }TEST(TfArthimaticTests, BatchMatMulAdjXY) {// BatchMatMul 測試// Refers to: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul// 2 * 1 * 2// 2 * 2 * 3// = // 2 * 1 * 3Scope root = Scope::NewRootScope();auto left_ = test::AsTensor<int>({1, 2, 3, 4}, {2, 2, 1});/*** @brief Left param* {{{1}, * {2}},* {{3},* {4}}}*/auto right_ = test::AsTensor<int>({1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8}, {2, 3, 2});/*** @brief Right param* {{{1, 2}, * {3, 4}, * {1, 2}}, ** {{3, 4}, * {5, 6}, * {7, 8}} * }*//*** @brief Result* {{5, 11, 5},* {25, 39, 53}}*/auto attrs = ops::BatchMatMul::AdjX(true).AdjY(true);auto batch_op = ops::BatchMatMul(root, left_, right_, attrs);ClientSession session(root);std::vector<Tensor> outputs;session.Run({batch_op.output}, &outputs);test::PrintTensorValue<int>(std::cout, outputs[0]);test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({5, 11, 5, 25, 39, 53}, {2, 1, 3})); }程序輸出如下,代表兩個算子均測試通過。
總結
以上是生活随笔為你收集整理的Tensorflow C++使用ops::BatchMatMul实现特征批量乘法的全部內容，希望文章能夠幫你解決所遇到的問題。
- 上一篇: STM32驱动3.97寸TFT液晶触摸屏
- 下一篇: RF PA