TensorFlow on Mobile: How to Load Your Own Trained MNIST Model onto an Android Phone
This article builds on the official demo, modifying it to show how to add a deep-learning model to an ordinary app. Integrating a model into an app amounts to feeding the data the app collects into the model, processing it, and returning the result to the app for further handling. Only the processing step involves TensorFlow, and that step is the focus of this article, so the walkthrough is attached to a concrete app.
1. Environment Setup
To run on an Android phone, the app first needs a working TensorFlow environment; see my previous post on setting up TensorFlow for details. The downloaded demo code can be built with Android Studio; iterate until it compiles successfully.
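As one alternative to the setup in the previous post, the prebuilt TensorFlow Android inference library can be pulled in through Gradle (a minimal sketch; the version shown is an assumption and should be pinned to whatever matches your project):

```groovy
dependencies {
    // Prebuilt TensorFlow Android inference library; provides
    // org.tensorflow.contrib.android.TensorFlowInferenceInterface
    // used by the classifier below.
    implementation 'org.tensorflow:tensorflow-android:1.13.1'
}
```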
2. Model Preparation
While training the MNIST model, name the input nodes and save the model as a .pb file. Naming the nodes lets you refer to the input and output tensors by name when calling the model. Note that every variable defined with a placeholder needs an explicit name: placeholder values are supplied at feed time, and without a name there is no way to pass a value in when the model is used, so loading the model will fail with a missing-input error.
For example, when building the model I defined keep_prob as a 32-bit float placeholder but did not name it. Calling the .pb file then failed with an error about a missing 32-bit float value (I did not take a screenshot while debugging, so the exact message is not reproduced here).
Building the .pb file: this is mainly a matter of code, so it is explained through code directly. The original MNIST training code is in my earlier post (original MNIST code); the changes are as follows:
Taking the four screenshots above in order 1, 2, 3, 4, they are explained as follows:

1. Defines the input. Since any image we supply opens as a matrix, the input is a 28x28x1 tensor rather than a flattened 784-dimensional vector. With this input format, an image only needs to be resized to 28x28 and converted to grayscale.

2. Defines softmax and output: one yields the probability of each class, the other the final predicted value. Call whichever one your use case needs.

3. Since the input is now a 28x28x1 tensor, the model's input handling is modified to match.

4. Saves the whole result as a .pb file.
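The four modifications above can be sketched in TF 1.x-style Python roughly as follows. This is a minimal sketch, not the full training script: the single dense layer stands in for your own network, and the training loop is elided. What matters is that every placeholder is named and that the graph is frozen into a .pb file with the output node names used later on Android.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# 1. Take a 28x28x1 image tensor as input (not a flat 784-vector),
#    and name every placeholder so it can be fed by name later.
x = tf.placeholder(tf.float32, [None, 28, 28, 1], name="input")
keep_prob = tf.placeholder(tf.float32, name="keep_prob")  # named, per the note above

# 3. Build the network on the 28x28x1 input; this one dense layer is a
#    placeholder for your own model's layers.
flat = tf.reshape(x, [-1, 784])
logits = tf.layers.dense(flat, 10)

# 2. Name the output nodes: per-class probabilities and the predicted digit.
softmax = tf.nn.softmax(logits, name="softmax")
output = tf.argmax(softmax, 1, name="output")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... training would happen here ...

    # 4. Freeze variables into constants and save the graph as a .pb file.
    graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ["softmax", "output"])
    with tf.gfile.GFile("mnist.pb", "wb") as f:
        f.write(graph_def.SerializeToString())
```

The node names "input", "keep_prob", "softmax", and "output" are exactly the strings the Android side passes to feed and fetch.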
3. Model Invocation
First place the .pb file under assets, and create a txt file there whose lines are the labels 0 through 9. Then add a class that wraps the MNIST model, as follows:
package org.tensorflow.demo;

import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.os.Trace;
import android.util.Log;

import org.tensorflow.Operation;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Vector;

/** A classifier specialized to label images using TensorFlow. */
public class TensorFlowMnistClassifier implements Classifier {
    private static final String TAG = "TensorFlowImageClassifier";

    // Only return this many results with at least this confidence.
    private static final int MAX_RESULTS = 10;
    private static final float THRESHOLD = 0.1f;

    // Config values.
    private String inputName;
    private String outputName;
    private String keep_pro;
    private int inputSize;
    //private int numClass;

    // Pre-allocated buffers.
    private Vector<String> labels = new Vector<String>();
    private int[] intValues;
    private float[] floatValues;
    private float[] floatKeep;
    private float[] outputs;
    private String[] outputNames;

    private boolean logStats = false;

    private TensorFlowInferenceInterface inferenceInterface;

    private TensorFlowMnistClassifier() {}

    /**
     * Initializes a native TensorFlow session for classifying images.
     *
     * @param assetManager The asset manager to be used to load assets.
     * @param modelFilename The filepath of the model GraphDef protocol buffer.
     * @param labelFilename The filepath of label file for classes.
     * @param inputSize The input size. A square image of inputSize x inputSize is assumed.
     * @param inputName The label of the image input node.
     * @param outputName The label of the output node.
     * @throws IOException
     */
    public static Classifier create(
            AssetManager assetManager,
            String modelFilename,
            String labelFilename,
            int inputSize,
            String inputName,
            String outputName,
            int numClass) {
        TensorFlowMnistClassifier c = new TensorFlowMnistClassifier();
        c.inputName = inputName;
        c.outputName = outputName;

        // Read the label names into memory.
        // TODO(andrewharp): make this handle non-assets.
        String actualFilename = labelFilename.split("file:///android_asset/")[1];
        Log.i(TAG, "Reading labels from: " + actualFilename);
        BufferedReader br = null;
        try {
            br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
            String line;
            while ((line = br.readLine()) != null) {
                c.labels.add(line);
            }
            br.close();
        } catch (IOException e) {
            throw new RuntimeException("Problem reading label file!", e);
        }

        c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
        c.inputSize = inputSize;
        //c.numClass = numClass;
        c.keep_pro = "keep_prob";

        // Pre-allocate buffers.
        c.outputNames = new String[] {outputName};
        c.intValues = new int[inputSize * inputSize];
        c.floatValues = new float[inputSize * inputSize];
        c.floatKeep = new float[1];
        c.outputs = new float[numClass];
        return c;
    }

    @Override
    public List<Recognition> recognizeImage(final Bitmap bitmap) {
        // Log this method so that it can be analyzed with systrace.
        Trace.beginSection("recognizeImage");

        Trace.beginSection("preprocessBitmap");
        // Preprocess the image data from 0-255 int to normalized float based
        // on the provided parameters.
        bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
        for (int i = 0; i < intValues.length; ++i) {
            final int val = intValues[i];
            // Convert each input pixel to grayscale.
            final int r = (val >> 16) & 0xff;
            final int g = (val >> 8) & 0xff;
            final int b = val & 0xff;
            floatValues[i] = (float) (0.3 * r + 0.59 * g + 0.11 * b);
        }
        Trace.endSection();

        floatKeep[0] = (float) 1.0;

        // Copy the input data into TensorFlow.
        Trace.beginSection("feed");
        inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 1);
        //inferenceInterface.feed(keep_pro, floatKeep, 1);
        Trace.endSection();

        // Run the inference call.
        Trace.beginSection("run");
        inferenceInterface.run(outputNames, logStats);
        Trace.endSection();

        // Copy the output Tensor back into the output array and store it.
        Trace.beginSection("fetch");
        inferenceInterface.fetch(outputName, outputs);
        Trace.endSection();

        // Find the best classifications.
        PriorityQueue<Recognition> pq =
            new PriorityQueue<Recognition>(
                10,
                new Comparator<Recognition>() {
                    @Override
                    public int compare(Recognition lhs, Recognition rhs) {
                        // Intentionally reversed to put high confidence at the head of the queue.
                        return Float.compare(rhs.getConfidence(), lhs.getConfidence());
                    }
                });
        for (int i = 0; i < outputs.length; ++i) {
            if (outputs[i] > THRESHOLD) {
                pq.add(new Recognition(
                    "" + i, labels.size() > i ? labels.get(i) : "unknown", outputs[i], null));
            }
        }
        final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
        int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
        for (int i = 0; i < recognitionsSize; ++i) {
            recognitions.add(pq.poll());
        }
        Trace.endSection(); // "recognizeImage"
        return recognitions;
    }

    @Override
    public void enableStatLogging(boolean logStats) {
        this.logStats = logStats;
    }

    @Override
    public String getStatString() {
        return inferenceInterface.getStatString();
    }

    @Override
    public void close() {
        inferenceInterface.close();
    }
}

The code for calling this class is as follows:
package org.tensorflow.demo;

import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.graphics.Canvas;
import android.graphics.Matrix;
import android.graphics.Paint;
import android.graphics.Typeface;
import android.media.ImageReader.OnImageAvailableListener;
import android.os.SystemClock;
import android.util.Size;
import android.util.TypedValue;

import org.tensorflow.demo.OverlayView.DrawCallback;
import org.tensorflow.demo.env.BorderedText;
import org.tensorflow.demo.env.ImageUtils;
import org.tensorflow.demo.env.Logger;

import java.util.List;
import java.util.Vector;

public class MnistActivity extends CameraActivity implements OnImageAvailableListener {
    private static final Logger LOGGER = new Logger();

    protected static final boolean SAVE_PREVIEW_BITMAP = false;

    private ResultsView resultsView;

    private Bitmap rgbFrameBitmap = null;
    private Bitmap croppedBitmap = null;
    private Bitmap cropCopyBitmap = null;

    private long lastProcessingTimeMs;

    private static final int INPUT_SIZE = 28;
    private static final String INPUT_NAME = "input";
    private static final String OUTPUT_NAME = "softmax";
    private static final int NUM_CLASS = 10;

    private static final String MODEL_FILE = "file:///android_asset/mnist.pb";
    private static final String LABEL_FILE = "file:///android_asset/mnist.txt";

    private static final boolean MAINTAIN_ASPECT = true;

    private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480);

    private Integer sensorOrientation;
    private Classifier classifier;
    private Matrix frameToCropTransform;
    private Matrix cropToFrameTransform;

    private BorderedText borderedText;

    @Override
    protected int getLayoutId() {
        return R.layout.camera_connection_fragment;
    }

    @Override
    protected Size getDesiredPreviewFrameSize() {
        return DESIRED_PREVIEW_SIZE;
    }

    private static final float TEXT_SIZE_DIP = 10;

    @Override
    public void onPreviewSizeChosen(final Size size, final int rotation) {
        final float textSizePx = TypedValue.applyDimension(
            TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
        borderedText = new BorderedText(textSizePx);
        borderedText.setTypeface(Typeface.MONOSPACE);

        classifier =
            TensorFlowMnistClassifier.create(
                getAssets(),
                MODEL_FILE,
                LABEL_FILE,
                INPUT_SIZE,
                INPUT_NAME,
                OUTPUT_NAME,
                NUM_CLASS);

        previewWidth = size.getWidth();
        previewHeight = size.getHeight();

        sensorOrientation = rotation - getScreenOrientation();
        LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation);

        LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
        rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
        croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Config.ARGB_8888);

        frameToCropTransform = ImageUtils.getTransformationMatrix(
            previewWidth, previewHeight,
            INPUT_SIZE, INPUT_SIZE,
            sensorOrientation, MAINTAIN_ASPECT);

        cropToFrameTransform = new Matrix();
        frameToCropTransform.invert(cropToFrameTransform);

        addCallback(
            new DrawCallback() {
                @Override
                public void drawCallback(final Canvas canvas) {
                    renderDebug(canvas);
                }
            });
    }

    @Override
    protected void processImage() {
        rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);
        final Canvas canvas = new Canvas(croppedBitmap);
        canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);

        // For examining the actual TF input.
        if (SAVE_PREVIEW_BITMAP) {
            ImageUtils.saveBitmap(croppedBitmap);
        }
        runInBackground(
            new Runnable() {
                @Override
                public void run() {
                    final long startTime = SystemClock.uptimeMillis();
                    final List<Classifier.Recognition> results = classifier.recognizeImage(croppedBitmap);
                    lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
                    LOGGER.i("Detect: %s", results);
                    cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
                    if (resultsView == null) {
                        resultsView = (ResultsView) findViewById(R.id.results);
                    }
                    resultsView.setResults(results);
                    requestRender();
                    readyForNextImage();
                }
            });
    }

    @Override
    public void onSetDebug(boolean debug) {
        classifier.enableStatLogging(debug);
    }

    private void renderDebug(final Canvas canvas) {
        if (!isDebug()) {
            return;
        }
        final Bitmap copy = cropCopyBitmap;
        if (copy != null) {
            final Matrix matrix = new Matrix();
            final float scaleFactor = 2;
            matrix.postScale(scaleFactor, scaleFactor);
            matrix.postTranslate(
                canvas.getWidth() - copy.getWidth() * scaleFactor,
                canvas.getHeight() - copy.getHeight() * scaleFactor);
            canvas.drawBitmap(copy, matrix, new Paint());

            final Vector<String> lines = new Vector<String>();
            if (classifier != null) {
                String statString = classifier.getStatString();
                String[] statLines = statString.split("\n");
                for (String line : statLines) {
                    lines.add(line);
                }
            }

            lines.add("Frame: " + previewWidth + "x" + previewHeight);
            lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
            lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
            lines.add("Rotation: " + sensorOrientation);
            lines.add("Inference time: " + lastProcessingTimeMs + "ms");

            borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines);
        }
    }
}

Then add the following to AndroidManifest.xml:
<activity android:name="org.tensorflow.demo.MnistActivity"
    android:screenOrientation="portrait"
    android:label="@string/activity_name_mnist">
  <intent-filter>
    <action android:name="android.intent.action.MAIN" />
    <category android:name="android.intent.category.LAUNCHER" />
    <category android:name="android.intent.category.LEANBACK_LAUNCHER" />
  </intent-filter>
</activity>

Finally, build and run the whole project; the generated APK will include a TF Mnist launcher icon with the corresponding functionality.
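For reference, the grayscale preprocessing that recognizeImage performs on each ARGB pixel can be sketched in Python. This is an illustrative helper (the name argb_to_gray is mine, not from the demo); the 0.3/0.59/0.11 weights are the same ones used in the Java loop above.

```python
def argb_to_gray(pixels):
    """Convert packed 32-bit ARGB pixels to grayscale floats,
    mirroring the per-pixel loop in TensorFlowMnistClassifier."""
    gray = []
    for val in pixels:
        r = (val >> 16) & 0xFF  # red channel
        g = (val >> 8) & 0xFF   # green channel
        b = val & 0xFF          # blue channel
        gray.append(0.3 * r + 0.59 * g + 0.11 * b)
    return gray

# A pure white and a pure black pixel (alpha = 0xFF); white maps to
# approximately 255.0 and black to 0.0.
print(argb_to_gray([0xFFFFFFFF, 0xFF000000]))
```

Note that, like the Java code, this leaves values in the 0-255 range rather than normalizing them to [0, 1]; whatever range the model was trained on is the range that must be fed at inference time.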
4. Notes
The main changes in the code above are the model file, the input preprocessing in the construction step, and the output handling. For any model, graph construction and execution work the same way; when porting, the things to check are whether the input data format at run time matches what the model expects, and what data the model outputs and in what form. How the output is processed afterwards depends on the specific business requirements.
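As a small illustration of the output handling described above, the threshold-and-top-k selection that recognizeImage applies to the softmax scores can be sketched as follows (the function name top_results is mine; THRESHOLD and MAX_RESULTS default to the same values as in the Java class):

```python
def top_results(outputs, labels, threshold=0.1, max_results=10):
    """Mirror the PriorityQueue logic in TensorFlowMnistClassifier:
    keep scores above the threshold, sort by descending confidence,
    and return at most max_results (label, confidence) pairs."""
    kept = [
        (labels[i] if i < len(labels) else "unknown", score)
        for i, score in enumerate(outputs)
        if score > threshold
    ]
    kept.sort(key=lambda pair: pair[1], reverse=True)
    return kept[:max_results]

labels = [str(d) for d in range(10)]
scores = [0.01, 0.05, 0.2, 0.0, 0.0, 0.6, 0.0, 0.14, 0.0, 0.0]
print(top_results(scores, labels))  # [('5', 0.6), ('2', 0.2), ('7', 0.14)]
```

The Java version reaches the same result with a reversed-comparator PriorityQueue; a sort-then-slice is the idiomatic equivalent in Python.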