When an Android Developer Meets TensorFlow
Preface
To be honest, when I wrote this article I didn't even know how to use TensorFlow yet, so feel free to take everything here with a grain of salt. I'm an Android developer, and I don't know Python (don't judge me), hence the title "When an Android Developer Meets TensorFlow". There's nothing groundbreaking in this article; it's just me knocking on the door of machine learning.
Calling TensorFlow from Java
As I said, I'm a developer who doesn't know Python, so this article won't touch any Python at all; Java naturally became my language of choice.
The core of Google's open-source TensorFlow is written in C++, so Java can use it just fine, with one extra JNI layer in between. Given how much I work with Gradle day to day, I chose Gradle as the build tool instead of Maven.
Here I have to praise IntelliJ IDEA again: I just discovered that version 2017.1 can automatically convert a Maven dependency into a Gradle one. Paste the Maven snippet into a Gradle file and it converts itself; no more manual translation. Time to witness the magic.
The Maven dependency:

```xml
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
    <version>1.1.0</version>
</dependency>
```
The converted Gradle dependency:

```groovy
dependencies {
    compile 'org.tensorflow:tensorflow:1.1.0'
}
```
To run the Java program, apply the application plugin and specify mainClassName (the corresponding class is created below):

```groovy
apply plugin: 'application'
apply plugin: 'idea'

mainClassName = "com.lizhangqu.application.Main"
sourceCompatibility = 1.8

dependencies {
    compile 'org.tensorflow:tensorflow:1.1.0'
}
```
Now for something with a bit more challenge: following LabelImage.java, let's build an image recognition tool.
First, download the pre-trained model inception5h.zip and unpack its contents into the src/main/resources/model directory.
Then grab any image to recognize. I'm using this one, a big mountain.
Put it in the src/main/resources/pic directory.
Then create a Main class, copy over the LabelImage.java code, and change its main function to:
```java
public static void main(String[] args) {
    // Model download:
    // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
    String modelPath = Main.class.getClassLoader().getResource("model").getPath();
    String picPath = Main.class.getClassLoader().getResource("pic").getPath();
    byte[] graphDef = readAllBytesOrExit(Paths.get(modelPath, "tensorflow_inception_graph.pb"));
    List<String> labels =
        readAllLinesOrExit(Paths.get(modelPath, "imagenet_comp_graph_label_strings.txt"));
    byte[] imageBytes = readAllBytesOrExit(Paths.get(picPath, "moutain.jpg"));
    try (Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
        float[] labelProbabilities = executeInceptionGraph(graphDef, image);
        int bestLabelIdx = maxIndex(labelProbabilities);
        System.out.println(
            String.format(
                "BEST MATCH: %s (%.2f%% likely)",
                labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
    }
}
```
The change is simple: instead of taking the paths as command-line arguments, it reads them from the resources directory.
The complete Main class:
```java
package com.lizhangqu.application;

import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;

import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class Main {

    public static void main(String[] args) {
        // Model download:
        // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
        String modelPath = Main.class.getClassLoader().getResource("model").getPath();
        String picPath = Main.class.getClassLoader().getResource("pic").getPath();
        byte[] graphDef = readAllBytesOrExit(Paths.get(modelPath, "tensorflow_inception_graph.pb"));
        List<String> labels =
            readAllLinesOrExit(Paths.get(modelPath, "imagenet_comp_graph_label_strings.txt"));
        byte[] imageBytes = readAllBytesOrExit(Paths.get(picPath, "moutain.jpg"));
        try (Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
            float[] labelProbabilities = executeInceptionGraph(graphDef, image);
            int bestLabelIdx = maxIndex(labelProbabilities);
            System.out.println(
                String.format(
                    "BEST MATCH: %s (%.2f%% likely)",
                    labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
        }
    }

    private static Tensor constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
        try (Graph g = new Graph()) {
            GraphBuilder b = new GraphBuilder(g);
            // Some constants specific to the pre-trained model at:
            // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
            //
            // - The model was trained with images scaled to 224x224 pixels.
            // - The colors, represented as R, G, B in 1-byte each were converted to
            //   float using (value - Mean)/Scale.
            final int H = 224;
            final int W = 224;
            final float mean = 117f;
            final float scale = 1f;

            // Since the graph is being constructed once per execution here, we can use a constant for the
            // input image. If the graph were to be re-used for multiple input images, a placeholder would
            // have been more appropriate.
            final Output input = b.constant("input", imageBytes);
            final Output output =
                b.div(
                    b.sub(
                        b.resizeBilinear(
                            b.expandDims(
                                b.cast(b.decodeJpeg(input, 3), DataType.FLOAT),
                                b.constant("make_batch", 0)),
                            b.constant("size", new int[]{H, W})),
                        b.constant("mean", mean)),
                    b.constant("scale", scale));
            try (Session s = new Session(g)) {
                return s.runner().fetch(output.op().name()).run().get(0);
            }
        }
    }

    private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) {
        try (Graph g = new Graph()) {
            g.importGraphDef(graphDef);
            try (Session s = new Session(g);
                 Tensor result = s.runner().feed("input", image).fetch("output").run().get(0)) {
                final long[] rshape = result.shape();
                if (result.numDimensions() != 2 || rshape[0] != 1) {
                    throw new RuntimeException(
                        String.format(
                            "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                            Arrays.toString(rshape)));
                }
                int nlabels = (int) rshape[1];
                return result.copyTo(new float[1][nlabels])[0];
            }
        }
    }

    private static int maxIndex(float[] probabilities) {
        int best = 0;
        for (int i = 1; i < probabilities.length; ++i) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return best;
    }

    private static byte[] readAllBytesOrExit(Path path) {
        try {
            return Files.readAllBytes(path);
        } catch (IOException e) {
            System.err.println("Failed to read [" + path + "]: " + e.getMessage());
            System.exit(1);
        }
        return null;
    }

    private static List<String> readAllLinesOrExit(Path path) {
        try {
            return Files.readAllLines(path, Charset.forName("UTF-8"));
        } catch (IOException e) {
            System.err.println("Failed to read [" + path + "]: " + e.getMessage());
            System.exit(0);
        }
        return null;
    }

    // In the fullness of time, equivalents of the methods of this class should be auto-generated from
    // the OpDefs linked into libtensorflow_jni.so. That would match what is done in other languages
    // like Python, C++ and Go.
    static class GraphBuilder {

        private Graph g;

        GraphBuilder(Graph g) {
            this.g = g;
        }

        Output div(Output x, Output y) {
            return binaryOp("Div", x, y);
        }

        Output sub(Output x, Output y) {
            return binaryOp("Sub", x, y);
        }

        Output resizeBilinear(Output images, Output size) {
            return binaryOp("ResizeBilinear", images, size);
        }

        Output expandDims(Output input, Output dim) {
            return binaryOp("ExpandDims", input, dim);
        }

        Output cast(Output value, DataType dtype) {
            return g.opBuilder("Cast", "Cast").addInput(value).setAttr("DstT", dtype).build().output(0);
        }

        Output decodeJpeg(Output contents, long channels) {
            return g.opBuilder("DecodeJpeg", "DecodeJpeg")
                .addInput(contents)
                .setAttr("channels", channels)
                .build()
                .output(0);
        }

        Output constant(String name, Object value) {
            try (Tensor t = Tensor.create(value)) {
                return g.opBuilder("Const", name)
                    .setAttr("dtype", t.dataType())
                    .setAttr("value", t)
                    .build()
                    .output(0);
            }
        }

        private Output binaryOp(String type, Output in1, Output in2) {
            return g.opBuilder(type, type).addInput(in1).addInput(in2).build().output(0);
        }
    }
}
```
Time to run it. From the command line:

```shell
./gradlew run
```
The output is:

```
BEST MATCH: alp (58.23% likely)
```
Wait, what on earth is an "alp"? Dictionary time:

```
alp  n. a high mountain
```
Right, so there's a 58.23% probability that this picture is a big mountain. Which it is. Of course, recognition accuracy depends directly on the pre-trained model: the better the model, the higher the accuracy. Don't ask me what the code means in detail; as I said at the start, I still couldn't use TensorFlow when I wrote this.
Calling TensorFlow from Android
If Java can call it, then to a certain extent Android naturally can too.
Add the dependency:

```groovy
compile 'org.tensorflow:tensorflow-android:1.2.0-rc0'
```
Set minSdkVersion to 19, because the code uses newer APIs, mainly the android.os.Trace class. If you want to stay at 14, remove those calls yourself; dropping them doesn't affect normal use.

```groovy
android {
    defaultConfig {
        minSdkVersion 19
    }
}
```
Same pre-trained model as before, but this time drop it into assets/model, with the image to recognize under assets/pic.
The test image is different this time, though: a big apple.
Again we copy some code: grab Classifier.java and TensorFlowImageClassifier.java from the android/demo directory; I won't paste them here.
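Their code isn't reproduced here, but the shape of the API is simple: Classifier exposes a recognizeImage method that returns Recognition results ordered by confidence. Here is a rough pure-JDK sketch of the Recognition side as an illustration, not the actual file; field names follow the demo, and the real class also carries an optional bounding-box location, omitted here:

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

public class RecognitionSketch {
    // Trimmed-down stand-in for Classifier.Recognition.
    static class Recognition {
        final String id;
        final String title;
        final float confidence;

        Recognition(String id, String title, float confidence) {
            this.id = id;
            this.title = title;
            this.confidence = confidence;
        }
    }

    // recognizeImage returns its results best-first; this reproduces that
    // ordering on a hand-made list.
    static List<Recognition> sortByConfidence(List<Recognition> results) {
        List<Recognition> sorted = new ArrayList<>(results);
        Collections.sort(sorted, new Comparator<Recognition>() {
            @Override
            public int compare(Recognition a, Recognition b) {
                return Float.compare(b.confidence, a.confidence);
            }
        });
        return sorted;
    }

    public static void main(String[] args) {
        List<Recognition> raw = new ArrayList<>();
        raw.add(new Recognition("2", "pomegranate", 0.05f));
        raw.add(new Recognition("1", "apple", 0.80f));
        // Prints the highest-confidence title.
        System.out.println(sortByConfidence(raw).get(0).title);
    }
}
```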
Then, following the code in ClassifierActivity.java, recognize assets/pic/apple.jpg:
```java
public class MainActivity extends AppCompatActivity {

    private TextView result;
    private Button btn;
    private Classifier classifier;

    private static final int INPUT_SIZE = 224;
    private static final int IMAGE_MEAN = 117;
    private static final float IMAGE_STD = 1;
    private static final String INPUT_NAME = "input";
    private static final String OUTPUT_NAME = "output";
    private static final String MODEL_FILE =
            "file:///android_asset/model/tensorflow_inception_graph.pb";
    private static final String LABEL_FILE =
            "file:///android_asset/model/imagenet_comp_graph_label_strings.txt";

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        btn = (Button) findViewById(R.id.btn);
        result = (TextView) findViewById(R.id.result);
        classifier = TensorFlowImageClassifier.create(
                getAssets(),
                MODEL_FILE,
                LABEL_FILE,
                INPUT_SIZE,
                IMAGE_MEAN,
                IMAGE_STD,
                INPUT_NAME,
                OUTPUT_NAME);
        btn.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View v) {
                new Thread(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            Bitmap croppedBitmap =
                                    getBitmap(getApplicationContext(), "pic/apple.jpg", INPUT_SIZE);
                            final List<Classifier.Recognition> results =
                                    classifier.recognizeImage(croppedBitmap);
                            new Handler(Looper.getMainLooper()).post(new Runnable() {
                                @Override
                                public void run() {
                                    result.setText("results:" + results);
                                }
                            });
                        } catch (IOException e) {
                            e.printStackTrace();
                        }
                    }
                }).start();
            }
        });
    }

    private static Bitmap getBitmap(Context context, String path, int size) throws IOException {
        InputStream inputStream = context.getAssets().open(path);
        Bitmap bitmap = BitmapFactory.decodeStream(inputStream);
        inputStream.close();
        int width = bitmap.getWidth();
        int height = bitmap.getHeight();
        float scaleWidth = ((float) size) / width;
        float scaleHeight = ((float) size) / height;
        Matrix matrix = new Matrix();
        matrix.postScale(scaleWidth, scaleHeight);
        return Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, true);
    }
}
```
The recognition result:
The top label is "granny smith". Look it up and it turns out to be a green apple variety. Hilarious, because this is a red apple. With a training model this small, though, you can't expect stellar accuracy.
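Looking only at the single best match also hides how close the runner-up labels were. A small pure-JDK sketch of picking the top-k indices from a probability array like the one the model returns (the probabilities below are made up for illustration):

```java
import java.util.Arrays;
import java.util.Comparator;

public class TopK {
    // Return the indices of the k largest probabilities, highest first.
    static Integer[] topK(final float[] probs, int k) {
        Integer[] idx = new Integer[probs.length];
        for (int i = 0; i < probs.length; i++) {
            idx[i] = i;
        }
        // Sort indices by descending probability.
        Arrays.sort(idx, new Comparator<Integer>() {
            @Override
            public int compare(Integer a, Integer b) {
                return Float.compare(probs[b], probs[a]);
            }
        });
        return Arrays.copyOf(idx, k);
    }

    public static void main(String[] args) {
        // Hypothetical probabilities for four labels.
        float[] probs = {0.10f, 0.58f, 0.02f, 0.30f};
        System.out.println(Arrays.toString(topK(probs, 3)));
    }
}
```

Feeding the indices back through the label list then gives a top-3 instead of a single best match.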
Cross-compiling TensorFlow with the NDK
We used the org.tensorflow:tensorflow-android:1.2.0-rc0 library above; it's worth understanding where it comes from, so let's build it ourselves.
TensorFlow is built with Bazel and depends on a few Python libraries, so install those first:

```shell
brew install bazel
brew install swig
brew install python
sudo easy_install pip
sudo pip install six numpy wheel
```
If the build later fails with all sorts of missing-environment errors, Google them and install whatever is missing.
Clone the TensorFlow source:

```shell
git clone --recurse-submodules https://github.com/tensorflow/tensorflow.git
```
Edit the WORKSPACE file in the TensorFlow project root and uncomment the following block:

```python
# Uncomment and update the paths in these entries to build the Android demo.
android_sdk_repository(
    name = "androidsdk",
    api_level = 23,
    # Ensure that you have the build_tools_version below installed in the
    # SDK manager as it updates periodically.
    build_tools_version = "25.0.2",
    # Replace with path to Android SDK on your system
    path = "/Users/lizhangqu/AndroidSDK",
)

# Android NDK r12b is recommended (higher may cause issues with Bazel)
android_ndk_repository(
    name = "androidndk",
    path = "/Users/lizhangqu/AndroidNDK/android-ndk-r12b",
    # This needs to be 14 or higher to compile TensorFlow.
    # Note that the NDK version is not the API level.
    api_level = 14,
)
```
Then change the path in android_sdk_repository to your local Android SDK directory, and the path in android_ndk_repository to your local Android NDK directory.
Mind the NDK version: the official recommendation is r12b, and sure enough the ndk-bundle that ships with the Android SDK would not build for me, so stick with r12b. Download: android-ndk-r12b-darwin-x86_64.zip.
Build the C++ part:

```shell
bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so \
    --crosstool_top=//external:android/crosstool \
    --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
    --cpu=armeabi-v7a
```

If you need a .so for another CPU architecture, change armeabi-v7a to the corresponding value, for example x86_64.
The built .so ends up at bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so.
Copy libtensorflow_inference.so somewhere safe, because building the Java part in the next step deletes it.
Build the Java part:

```shell
bazel build //tensorflow/contrib/android:android_tensorflow_inference_java
```
The built jar ends up at bazel-bin/tensorflow/contrib/android/libandroid_tensorflow_inference_java.jar.
Bundle libandroid_tensorflow_inference_java.jar together with libtensorflow_inference.so, publish the result to Maven, and you get exactly the org.tensorflow:tensorflow-android:1.2.0-rc0 dependency we used earlier.
Building the desktop Java TensorFlow
Not much to say here; it's similar to the NDK cross-compile. The build commands:

```shell
./configure
bazel build --config opt \
    //tensorflow/java:tensorflow \
    //tensorflow/java:libtensorflow_jni
```
The build outputs land in bazel-bin/tensorflow/java:
- libtensorflow.jar
- libtensorflow_jni.so (Linux), libtensorflow_jni.dylib (macOS), or tensorflow_jni.dll (Windows; note that the dll cannot be built on a Mac)
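If you ever need to pick the native file manually (System.loadLibrary normally resolves the platform prefix and suffix for you), the expected names per OS from the list above can be mapped like this; the helper itself is my own sketch:

```java
public class JniLibName {
    // Map an os.name value to the native library file listed above.
    // Note the Windows name has no "lib" prefix.
    static String jniLibraryName(String osName) {
        String os = osName.toLowerCase();
        if (os.contains("linux")) {
            return "libtensorflow_jni.so";
        }
        if (os.contains("mac")) {
            return "libtensorflow_jni.dylib";
        }
        return "tensorflow_jni.dll";
    }

    public static void main(String[] args) {
        System.out.println(jniLibraryName("Linux"));
        System.out.println(jniLibraryName("Mac OS X"));
        System.out.println(jniLibraryName("Windows 10"));
    }
}
```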
At compile time, add libtensorflow.jar to the classpath:

```shell
javac -cp bazel-bin/tensorflow/java/libtensorflow.jar ...
```
At run time, add libtensorflow.jar to the classpath and point java.library.path at the directory containing the libtensorflow_jni library:

```shell
java -cp bazel-bin/tensorflow/java/libtensorflow.jar \
    -Djava.library.path=bazel-bin/tensorflow/java \
    ...
```
Summary
Of course, under normal circumstances there's no need to build TensorFlow yourself; using the prebuilt libraries is enough.
After writing all this... I still don't know how to use TensorFlow.