When an Android Developer Meets TensorFlow
Preface
When I wrote this article I didn't actually know how to use TensorFlow at all, so take everything here with a big grain of salt. I'm an Android developer and I don't know Python (please don't look down on me), hence the title "When an Android Developer Meets TensorFlow". The article has little real substance. It is only meant to knock on the door of machine learning.
Calling TensorFlow from Java
As mentioned above, I don't know Python, so this article avoids anything Python-related, and Java naturally became my language of choice.
The core of Google's open-source TensorFlow is written in C++, so Java can use it naturally, just with an extra JNI layer in between. Since I work with Gradle every day, I chose Gradle rather than Maven as the build tool.
I have to praise IntelliJ IDEA once again. Today I discovered that IntelliJ IDEA 2017.1 can convert Maven dependencies into Gradle dependencies automatically: paste a Maven dependency into a Gradle file and it is converted on the spot, with no manual rewriting. Time to witness the magic.
Maven dependency
```xml
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
    <version>1.1.0</version>
</dependency>
```
The converted Gradle dependency:
```groovy
dependencies {
    compile 'org.tensorflow:tensorflow:1.1.0'
}
```
To run the Java program, apply the application plugin and set mainClassName. The corresponding class is created later on.
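```groovy
apply plugin: 'application'
apply plugin: 'idea'

mainClassName = "com.lizhangqu.application.Main"
sourceCompatibility = 1.8

// Gradle needs a repository to resolve org.tensorflow:tensorflow from.
repositories {
    mavenCentral()
}

dependencies {
    compile 'org.tensorflow:tensorflow:1.1.0'
}
```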
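Now for something harder: based on LabelImage.java, let's build an image recognition tool.
First, download the pre-trained model inception5h.zip and extract its contents into src/main/resources/model, as shown in the figure.
Then download any image to be recognized. Here I use this one, a big mountain.
Put it in src/main/resources/pic, as shown in the figure.
Next, create a Main class, copy in the code from LabelImage.java, and change its main method to: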
```java
public static void main(String[] args) {
    // Model download:
    // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
    String modelPath = Main.class.getClassLoader().getResource("model").getPath();
    String picPath = Main.class.getClassLoader().getResource("pic").getPath();

    byte[] graphDef = readAllBytesOrExit(Paths.get(modelPath, "tensorflow_inception_graph.pb"));
    List<String> labels =
            readAllLinesOrExit(Paths.get(modelPath, "imagenet_comp_graph_label_strings.txt"));
    byte[] imageBytes = readAllBytesOrExit(Paths.get(picPath, "moutain.jpg"));

    try (Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
        float[] labelProbabilities = executeInceptionGraph(graphDef, image);
        int bestLabelIdx = maxIndex(labelProbabilities);
        System.out.println(
                String.format(
                        "BEST MATCH: %s (%.2f%% likely)",
                        labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
    }
}
```
The change is simple: instead of taking the paths as command-line arguments, the files are read from the resources directory.
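One caveat with `getResource(...).getPath()`: `URL.getPath()` returns the path still percent-encoded, so if the project sits in a directory whose name contains spaces (and possibly other special characters), `Paths.get` ends up looking for a literal `%20` and the read fails. A minimal sketch of a more robust lookup, assuming the resources live on the file system rather than inside a jar; `ResourcePaths`, `resourceDir` and `toPath` are hypothetical helpers, not part of LabelImage.java:

```java
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Path;
import java.nio.file.Paths;

class ResourcePaths {
    // Resolves a classpath resource directory such as "model" or "pic" to a file-system Path.
    // Only works while the resources sit on the file system, not when they are packed into a jar.
    static Path resourceDir(String name) {
        URL url = ResourcePaths.class.getClassLoader().getResource(name);
        if (url == null) {
            throw new IllegalArgumentException("Resource not found on classpath: " + name);
        }
        return toPath(url);
    }

    // Going through toURI() decodes escapes such as %20, which URL.getPath() leaves untouched.
    static Path toPath(URL url) {
        try {
            return Paths.get(url.toURI());
        } catch (URISyntaxException e) {
            throw new IllegalArgumentException("Not a valid URI: " + url, e);
        }
    }

    public static void main(String[] args) throws Exception {
        URL url = new URL("file:/tmp/my%20project/model");
        System.out.println(url.getPath()); // /tmp/my%20project/model (still encoded)
        System.out.println(toPath(url));   // /tmp/my project/model on Unix-like systems
    }
}
```

In `Main` this would mean replacing the `String modelPath = ...getPath()` lines with something like `Path modelDir = ResourcePaths.resourceDir("model")` and then `modelDir.resolve("tensorflow_inception_graph.pb")`.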
The complete Main class:
```java
package com.lizhangqu.application;

import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class Main {
    public static void main(String[] args) {
        // Model download:
        // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
        String modelPath = Main.class.getClassLoader().getResource("model").getPath();
        String picPath = Main.class.getClassLoader().getResource("pic").getPath();

        byte[] graphDef = readAllBytesOrExit(Paths.get(modelPath, "tensorflow_inception_graph.pb"));
        List<String> labels =
                readAllLinesOrExit(Paths.get(modelPath, "imagenet_comp_graph_label_strings.txt"));
        byte[] imageBytes = readAllBytesOrExit(Paths.get(picPath, "moutain.jpg"));

        try (Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
            float[] labelProbabilities = executeInceptionGraph(graphDef, image);
            int bestLabelIdx = maxIndex(labelProbabilities);
            System.out.println(
                    String.format(
                            "BEST MATCH: %s (%.2f%% likely)",
                            labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
        }
    }

    private static Tensor constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
        try (Graph g = new Graph()) {
            GraphBuilder b = new GraphBuilder(g);
            // Some constants specific to the pre-trained model at:
            // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
            //
            // - The model was trained with images scaled to 224x224 pixels.
            // - The colors, represented as R, G, B in 1-byte each were converted to
            //   float using (value - Mean)/Scale.
            final int H = 224;
            final int W = 224;
            final float mean = 117f;
            final float scale = 1f;

            // Since the graph is being constructed once per execution here, we can use a constant for the
            // input image. If the graph were to be re-used for multiple input images, a placeholder would
            // have been more appropriate.
            final Output input = b.constant("input", imageBytes);
            final Output output =
                    b.div(
                            b.sub(
                                    b.resizeBilinear(
                                            b.expandDims(
                                                    b.cast(b.decodeJpeg(input, 3), DataType.FLOAT),
                                                    b.constant("make_batch", 0)),
                                            b.constant("size", new int[] {H, W})),
                                    b.constant("mean", mean)),
                            b.constant("scale", scale));
            try (Session s = new Session(g)) {
                return s.runner().fetch(output.op().name()).run().get(0);
            }
        }
    }

    private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) {
        try (Graph g = new Graph()) {
            g.importGraphDef(graphDef);
            try (Session s = new Session(g);
                    Tensor result = s.runner().feed("input", image).fetch("output").run().get(0)) {
                final long[] rshape = result.shape();
                if (result.numDimensions() != 2 || rshape[0] != 1) {
                    throw new RuntimeException(
                            String.format(
                                    "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                                    Arrays.toString(rshape)));
                }
                int nlabels = (int) rshape[1];
                return result.copyTo(new float[1][nlabels])[0];
            }
        }
    }

    private static int maxIndex(float[] probabilities) {
        int best = 0;
        for (int i = 1; i < probabilities.length; ++i) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return best;
    }

    private static byte[] readAllBytesOrExit(Path path) {
        try {
            return Files.readAllBytes(path);
        } catch (IOException e) {
            System.err.println("Failed to read [" + path + "]: " + e.getMessage());
            System.exit(1);
        }
        return null;
    }

    private static List<String> readAllLinesOrExit(Path path) {
        try {
            return Files.readAllLines(path, Charset.forName("UTF-8"));
        } catch (IOException e) {
            System.err.println("Failed to read [" + path + "]: " + e.getMessage());
            System.exit(1); // non-zero status signals the failure
        }
        return null;
    }

    // In the fullness of time, equivalents of the methods of this class should be auto-generated from
    // the OpDefs linked into libtensorflow_jni.so. That would match what is done in other languages
    // like Python, C++ and Go.
    static class GraphBuilder {
        GraphBuilder(Graph g) {
            this.g = g;
        }

        Output div(Output x, Output y) {
            return binaryOp("Div", x, y);
        }

        Output sub(Output x, Output y) {
            return binaryOp("Sub", x, y);
        }

        Output resizeBilinear(Output images, Output size) {
            return binaryOp("ResizeBilinear", images, size);
        }

        Output expandDims(Output input, Output dim) {
            return binaryOp("ExpandDims", input, dim);
        }

        Output cast(Output value, DataType dtype) {
            return g.opBuilder("Cast", "Cast").addInput(value).setAttr("DstT", dtype).build().output(0);
        }

        Output decodeJpeg(Output contents, long channels) {
            return g.opBuilder("DecodeJpeg", "DecodeJpeg")
                    .addInput(contents)
                    .setAttr("channels", channels)
                    .build()
                    .output(0);
        }

        Output constant(String name, Object value) {
            try (Tensor t = Tensor.create(value)) {
                return g.opBuilder("Const", name)
                        .setAttr("dtype", t.dataType())
                        .setAttr("value", t)
                        .build()
                        .output(0);
            }
        }

        private Output binaryOp(String type, Output in1, Output in2) {
            return g.opBuilder(type, type).addInput(in1).addInput(in2).build().output(0);
        }

        private Graph g;
    }
}
```
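`maxIndex` only reports the single best label. To see the runner-ups as well, which helps when the best match is not very confident, a small helper can rank all indices. This is a hypothetical sketch: `TopLabels`, `topK` and `describe` are illustrative names, it relies on the same assumption as `labels.get(bestLabelIdx)` above (line i of the label file belongs to probability i), and the numbers in `main` are made up, not real model output:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

class TopLabels {
    // Returns the indices of the k highest probabilities, best first (k is capped at the array length).
    static int[] topK(float[] probabilities, int k) {
        Integer[] idx = new Integer[probabilities.length];
        for (int i = 0; i < idx.length; i++) {
            idx[i] = i;
        }
        // Sort indices by descending probability; object sorting is stable, so ties keep index order.
        Arrays.sort(idx, (a, b) -> Float.compare(probabilities[b], probabilities[a]));
        int n = Math.min(k, idx.length);
        int[] top = new int[n];
        for (int i = 0; i < n; i++) {
            top[i] = idx[i];
        }
        return top;
    }

    // Formats the top-k matches like the "BEST MATCH" line; labels.get(i) belongs to probabilities[i].
    static List<String> describe(float[] probabilities, List<String> labels, int k) {
        List<String> lines = new ArrayList<>();
        for (int i : topK(probabilities, k)) {
            lines.add(String.format(Locale.ROOT, "%s (%.2f%% likely)", labels.get(i), probabilities[i] * 100f));
        }
        return lines;
    }

    public static void main(String[] args) {
        // Made-up numbers for illustration, not real model output.
        float[] probs = {0.05f, 0.5823f, 0.20f, 0.1677f};
        List<String> labels = Arrays.asList("valley", "alp", "volcano", "cliff");
        describe(probs, labels, 3).forEach(System.out::println);
    }
}
```

Sorting an index array costs O(N log N), which is negligible for the roughly one thousand labels of this model; `Locale.ROOT` keeps the decimal point stable regardless of the machine's locale.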
Let's run it from the command line:
```shell
./gradlew run
```
Here is the output:
```text
BEST MATCH: alp (58.23% likely)
```
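Wait, what on earth is "alp"? Let me look it up in a dictionary:
```text
alp  UK [ælp]  US [ælp]
n. high mountain
```
Right: the model says there is a 58.23% probability that this picture shows a mountain, and it is indeed a mountain. Of course, recognition accuracy depends directly on the pre-trained model: the better the model, the higher the accuracy. Don't ask me what the code actually does, because I don't know either. As I said at the start, when I wrote this article I still didn't know how to use TensorFlow.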
Calling TensorFlow from Android
If Java can call it, Android can too, at least to some extent.
Add the dependency
```groovy
compile 'org.tensorflow:tensorflow-android:1.2.0-rc0'
```
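Set minSdkVersion to 19, because the code uses APIs above that level. If you want 14 instead, remove the code that needs the higher API levels yourself, mainly the calls to android.os.Trace; removing them does not affect normal use.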
```groovy
android {
    defaultConfig {
        minSdkVersion 19
    }
}
```
Use the same pre-trained model. This time put its files under assets/model and the image to be recognized under assets/pic, as shown in the figure.
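In the Activity code further down, the model and label files are addressed as `file:///android_asset/model/...`, while the picture is opened with `context.getAssets().open("pic/apple.jpg")`, a path relative to the assets directory. As far as I can tell, the copied `TensorFlowImageClassifier` strips the `file:///android_asset/` prefix before handing the name to `AssetManager`, so both forms point into `assets/`. A tiny plain-Java sketch of that mapping; `AssetPaths` and `toAssetPath` are hypothetical names, not part of the demo:

```java
class AssetPaths {
    static final String ASSET_PREFIX = "file:///android_asset/";

    // "file:///android_asset/model/x.pb" -> "model/x.pb"; paths without the prefix are returned unchanged.
    static String toAssetPath(String uriOrPath) {
        return uriOrPath.startsWith(ASSET_PREFIX)
                ? uriOrPath.substring(ASSET_PREFIX.length())
                : uriOrPath;
    }

    public static void main(String[] args) {
        System.out.println(toAssetPath("file:///android_asset/model/imagenet_comp_graph_label_strings.txt"));
        System.out.println(toAssetPath("pic/apple.jpg"));
    }
}
```

So putting the files at `assets/model/tensorflow_inception_graph.pb` and `assets/pic/apple.jpg` is what makes both the `MODEL_FILE` constant and the `"pic/apple.jpg"` path resolve.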
This time, though, the image to recognize is different: a big apple.
Again we copy some code: from android/demo, copy the two classes Classifier.java and TensorFlowImageClassifier.java (I won't paste them here).
Then, based on ClassifierActivity.java, recognize assets/pic/apple.jpg:
```java
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.Matrix;
import android.os.Bundle;
import android.os.Handler;
import android.os.Looper;
import android.support.v7.app.AppCompatActivity;
import android.view.View;
import android.widget.Button;
import android.widget.TextView;

import java.io.IOException;
import java.io.InputStream;
import java.util.List;

public class MainActivity extends AppCompatActivity {
    private TextView result;
    private Button btn;
    private Classifier classifier;

    private static final int INPUT_SIZE = 224;
    private static final int IMAGE_MEAN = 117;
    private static final float IMAGE_STD = 1;
    private static final String INPUT_NAME = "input";
    private static final String OUTPUT_NAME = "output";

    private static final String MODEL_FILE = "file:///android_asset/model/tensorflow_inception_graph.pb";
    private static final String LABEL_FILE = "file:///android_asset/model/imagenet_comp_graph_label_strings.txt";

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        btn = (Button) findViewById(R.id.btn);
        result = (TextView) findViewById(R.id.result);

        classifier = TensorFlowImageClassifier.create(
                getAssets(),
                MODEL_FILE,
                LABEL_FILE,
                INPUT_SIZE,
                IMAGE_MEAN,
                IMAGE_STD,
                INPUT_NAME,
                OUTPUT_NAME);

        btn.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View v) {
                // Run inference on a background thread to keep the UI responsive.
                new Thread(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            Bitmap croppedBitmap = getBitmap(getApplicationContext(), "pic/apple.jpg", INPUT_SIZE);
                            final List<Classifier.Recognition> results = classifier.recognizeImage(croppedBitmap);
                            // Views may only be touched on the main thread.
                            new Handler(Looper.getMainLooper()).post(new Runnable() {
                                @Override
                                public void run() {
                                    result.setText("results:" + results);
                                }
                            });
                        } catch (IOException e) {
                            e.printStackTrace();
                        }
                    }
                }).start();
            }
        });
    }

    // Decodes an image from assets and scales it to size x size pixels.
    private static Bitmap getBitmap(Context context, String path, int size) throws IOException {
        Bitmap bitmap;
        InputStream inputStream = context.getAssets().open(path);
        try {
            bitmap = BitmapFactory.decodeStream(inputStream);
        } finally {
            inputStream.close();
        }
        int width = bitmap.getWidth();
        int height = bitmap.getHeight();
        float scaleWidth = ((float) size) / width;
        float scaleHeight = ((float) size) / height;
        Matrix matrix = new Matrix();
        matrix.postScale(scaleWidth, scaleHeight);
        return Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, true);
    }
}
```
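`TensorFlowImageClassifier` turns the scaled Bitmap into the model's float input. As far as I can tell from the demo source, it reads the pixels with `Bitmap.getPixels` and converts each channel with `(value - imageMean) / imageStd`, the same normalization as `(value - Mean)/Scale` in the desktop example (here `IMAGE_MEAN = 117` and `IMAGE_STD = 1`, matching `mean` and `scale` in `Main`). Below is a plain-Java sketch of that step without Android classes, so it runs on a desktop JVM; `PixelNormalizer` is a hypothetical name:

```java
import java.util.Arrays;

class PixelNormalizer {
    // Converts packed ARGB ints (the format Bitmap.getPixels fills in) into an interleaved
    // RGB float array [r0, g0, b0, r1, g1, b1, ...]; the alpha channel is dropped.
    // Each channel becomes (value - mean) / std.
    static float[] normalize(int[] argbPixels, int mean, float std) {
        float[] out = new float[argbPixels.length * 3];
        for (int i = 0; i < argbPixels.length; i++) {
            int p = argbPixels[i];
            out[i * 3] = (((p >> 16) & 0xFF) - mean) / std;    // red
            out[i * 3 + 1] = (((p >> 8) & 0xFF) - mean) / std; // green
            out[i * 3 + 2] = ((p & 0xFF) - mean) / std;        // blue
        }
        return out;
    }

    public static void main(String[] args) {
        int[] pixels = {0xFFFF0000, 0xFFFFFFFF}; // one opaque red pixel, one opaque white pixel
        System.out.println(Arrays.toString(normalize(pixels, 117, 1f)));
        // prints [138.0, -117.0, -117.0, 138.0, 138.0, 138.0]
    }
}
```

For the 224x224 input used here, this yields 224 × 224 × 3 = 150,528 floats per image.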
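The recognition results are shown in the screenshot.
The top match is granny smith. What does that mean? A quick lookup says it is a green apple variety. I didn't know whether to laugh or cry, because this is a red apple. With such a small pre-trained model, I don't expect high accuracy anyway.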
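Cross-compiling TensorFlow with the NDK
Above we used the library org.tensorflow:tensorflow-android:1.2.0-rc0. It is still worth understanding where it comes from, so let's build it ourselves.
TensorFlow is built with Bazel and depends on some Python libraries, so install those first: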
```shell
brew install bazel
brew install swig
brew install python
sudo easy_install pip
sudo pip install six numpy wheel
```
If you later hit errors about various missing pieces of the environment, Google them and install whatever is missing.
Clone the TensorFlow code
```shell
git clone --recurse-submodules https://github.com/tensorflow/tensorflow.git
```
Edit the WORKSPACE file in the TensorFlow project root
Uncomment the following entries:
```python
# Uncomment and update the paths in these entries to build the Android demo.
android_sdk_repository(
    name = "androidsdk",
    api_level = 23,
    # Ensure that you have the build_tools_version below installed in the
    # SDK manager as it updates periodically.
    build_tools_version = "25.0.2",
    # Replace with path to Android SDK on your system
    path = "/Users/lizhangqu/AndroidSDK",
)
#
# Android NDK r12b is recommended (higher may cause issues with Bazel)
android_ndk_repository(
    name="androidndk",
    path="/Users/lizhangqu/AndroidNDK/android-ndk-r12b",
    # This needs to be 14 or higher to compile TensorFlow.
    # Note that the NDK version is not the API level.
    api_level=14)
```
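Then set path in android_sdk_repository to the Android SDK directory on your machine, and path in android_ndk_repository to your Android NDK directory.
Pay attention to the NDK version: the official recommendation is r12b. In my experience the build fails with the ndk-bundle that ships inside the Android SDK, so just use r12b. Download link: android-ndk-r12b-darwin-x86_64.zip
Build the C++ part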
```shell
bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so \
   --crosstool_top=//external:android/crosstool \
   --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
   --cpu=armeabi-v7a
```
To build the .so for other CPU architectures, change armeabi-v7a to the corresponding value, for example x86_64.
The built library is at bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so, as shown in the figure.
Copy libtensorflow_inference.so somewhere else as a backup, because it gets deleted when you build the Java part in the next step.
Build the Java part
```shell
bazel build //tensorflow/contrib/android:android_tensorflow_inference_java
```
The built jar is at bazel-bin/tensorflow/contrib/android/libandroid_tensorflow_inference_java.jar, as shown in the figure.
Combine libandroid_tensorflow_inference_java.jar with libtensorflow_inference.so and publish them to Maven, and you get the org.tensorflow:tensorflow-android:1.2.0-rc0 dependency used above.
Building the desktop Java version of TensorFlow
Without further ado: it is much like the NDK cross-compilation. The build script:
```shell
./configure
bazel build --config opt \
  //tensorflow/java:tensorflow \
  //tensorflow/java:libtensorflow_jni
```
The build outputs are in bazel-bin/tensorflow/java (see the figure). That directory contains:
- libtensorflow.jar
- libtensorflow_jni.so (Linux), libtensorflow_jni.dylib (macOS) or tensorflow_jni.dll (Windows; note that a DLL cannot be built on a Mac)
At compile time, add libtensorflow.jar to the classpath:
```shell
javac -cp bazel-bin/tensorflow/java/libtensorflow.jar ...
```
At runtime, add libtensorflow.jar to the classpath and point java.library.path at the directory containing libtensorflow_jni:
```shell
java -cp bazel-bin/tensorflow/java/libtensorflow.jar \
  -Djava.library.path=bazel-bin/tensorflow/java \
  ...
```
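At runtime the JVM searches the directories on java.library.path for the platform-specific file name that `System.mapLibraryName` produces from a base name, which gives exactly the libtensorflow_jni.so / libtensorflow_jni.dylib / tensorflow_jni.dll variants listed above. If the `-Djava.library.path` setting does not seem to take effect, a quick check like the sketch below can help. `JniLocator` is a hypothetical helper, and "tensorflow_jni" is assumed to be the base name the TensorFlow Java bindings load:

```java
import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Optional;
import java.util.regex.Pattern;

class JniLocator {
    // Looks for the platform-specific file for libName (via System.mapLibraryName) in each
    // directory of a path string formatted like java.library.path.
    static Optional<Path> find(String libraryPath, String libName) {
        String fileName = System.mapLibraryName(libName);
        for (String dir : libraryPath.split(Pattern.quote(File.pathSeparator))) {
            if (dir.isEmpty()) {
                continue; // skip empty segments such as a trailing separator
            }
            Path candidate = Paths.get(dir, fileName);
            if (Files.isRegularFile(candidate)) {
                return Optional.of(candidate);
            }
        }
        return Optional.empty();
    }

    public static void main(String[] args) {
        String path = System.getProperty("java.library.path", "");
        System.out.println("Looking for " + System.mapLibraryName("tensorflow_jni"));
        System.out.println(find(path, "tensorflow_jni")
                .map(p -> "found: " + p)
                .orElse("not found on java.library.path: " + path));
    }
}
```

Running it with the same `-Djava.library.path=bazel-bin/tensorflow/java` flag as above should report the library if the directory is right.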
Summary
Of course, in most cases there is no need to build TensorFlow yourself: the prebuilt libraries are enough.
After writing all this, I still don't know TensorFlow~
http://fucknmb.com/2017/06/02/%E5%BD%93Android%E5%BC%80%E5%8F%91%E8%80%85%E9%81%87%E8%A7%81TensorFlow/