Building a TensorFlow Lite for Android App

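This post explains how to use Android Studio to create a new TensorFlow Lite for Android application (app) and run a model that you defined and trained yourself inside it. Two key skills are needed:

  • Writing a TensorFlow-based "machine learning" or "deep learning" program in Python, and converting the "trained" and "validated" result into a TensorFlow Lite model file (.tflite)
  • Writing an Android app in Java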

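Each of these topics could fill one or more books, so neither can be covered in depth in a single blog post. This post therefore focuses on how to build the app. The code itself comes directly from the article "基于Android搭建tensorflow lite,实现官网的Demo以及运行自定义tensorflow模型(二)" (Building TensorFlow Lite on Android: running the official demo and a custom TensorFlow model, Part 2), which implements the "Hello World" of machine learning and deep learning: MNIST handwritten digit recognition. Set up the Android development environment as described in the earlier posts; set up the Python environment for machine learning and deep learning as follows: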

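1. Download and install the Anaconda Python distribution: https://www.anaconda.com/download/. All steps below use macOS as the development host.
2. After installation, open Terminal and create a Python virtual environment: conda create --name tensorflow python=3.6 anaconda
3. Then activate the virtual environment: conda activate tensorflow
4. Inside the virtual environment, install TensorFlow: pip install tensorflow==1.8.0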

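The Python environment is now ready. Keep the pinned version: the scripts below use the TensorFlow 1.x API (tf.placeholder, tf.Session, tf.contrib.lite), which no longer exists in TensorFlow 2.x.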

Next, run the script from Part 1 of the article, "一 在tensorflow 中生成tflite文件" (generate a tflite file in TensorFlow), saved here as tflite_mnist_modle.py:

```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Dec 18 15:10:24 2018

@author: tungyilin
"""

# https://blog.csdn.net/qq_22765745/article/details/80488012
# Part 1: generate the .tflite file in TensorFlow (TensorFlow 1.8 API)

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("mnist", one_hot=True)

# Batch size and number of batches per epoch
batch_size = 100
n_batch = mnist.train.num_examples // batch_size

# Training placeholders: one 28x28 image (784 pixels) and its one-hot label
x = tf.placeholder(tf.float32, [1, 784], name='input_x')
y = tf.placeholder(tf.float32, [1, 10], name='output_y')

# Evaluation placeholders: the whole test set at once
x_test = tf.placeholder(tf.float32, [None, 784], name='input_test_x')
y_test = tf.placeholder(tf.float32, [None, 10], name='input_test_y')

# A simple single-layer network (softmax regression)
W = tf.Variable(tf.zeros([784, 10]), name="W")
b = tf.Variable(tf.zeros([1, 10]), name="b")

prediction = tf.nn.softmax(tf.matmul(x, W) + b)

# Loss (mean squared error) minimized with plain gradient descent
train = tf.train.GradientDescentOptimizer(0.02).minimize(
    tf.reduce_mean(tf.square(y - prediction)))

# Strip the output index from a tensor name, e.g. "output:0" -> "output"
def canonical_name(x):
    return x.name.split(":")[0]

# Test-set accuracy (shares W and b with the training graph)
test_prediction = tf.nn.softmax(tf.matmul(x_test, W) + b)
accuracy = tf.reduce_mean(tf.cast(
    tf.equal(tf.argmax(y_test, 1), tf.argmax(test_prediction, 1)), tf.float32))

init = tf.global_variables_initializer()
out = tf.identity(prediction, name="output")

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(10):
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Feed samples one by one: x and y have a fixed batch dimension of 1
            for index in range(len(batch_xs)):
                xs = batch_xs[index].reshape(1, 784)
                ys = batch_ys[index].reshape(1, 10)
                sess.run(train, feed_dict={x: xs, y: ys})

    acc = sess.run(accuracy, feed_dict={x_test: mnist.test.images,
                                        y_test: mnist.test.labels})
    print("over " + str(acc))

    frozen_tensors = [out]
    out_tensors = [out]

    # Turn the variables into constants, keeping only what "output" depends on
    frozen_graphdef = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, list(map(canonical_name, frozen_tensors)))
    # Convert the frozen graph into a TensorFlow Lite flatbuffer
    tflite_model = tf.contrib.lite.toco_convert(frozen_graphdef, [x], out_tensors)

    with open("writer_model.tflite", "wb") as f:
        f.write(tflite_model)
```

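This code from the article is not optimized: the training loop feeds the samples one at a time (the training placeholders have a fixed batch dimension of 1), so training takes a long time even on a GPU. When it finishes, it writes the TensorFlow Lite model file writer_model.tflite.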
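For one 28x28 image, the exported model computes prediction = softmax(x·W + b): 784 inputs, a 784x10 weight matrix and 10 biases, giving 10 probabilities that sum to 1. The predicted digit is the index of the largest probability. The plain-Java sketch below shows that computation on a toy-sized input; the class name, the sizes (4 "pixels", 3 classes) and the weights are made-up placeholders, not the trained values.

```java
import java.util.Arrays;

// Sketch of softmax regression inference: prediction = softmax(x * W + b).
// Toy sizes and weights are hypothetical; the real model uses 784 inputs and 10 classes.
public class SoftmaxRegressionSketch {

    // logits[j] = b[j] + sum_i x[i] * W[i][j]
    static float[] logits(float[] x, float[][] w, float[] b) {
        float[] z = new float[b.length];
        for (int j = 0; j < b.length; j++) {
            float sum = b[j];
            for (int i = 0; i < x.length; i++) {
                sum += x[i] * w[i][j];
            }
            z[j] = sum;
        }
        return z;
    }

    // Numerically stable softmax: subtract the largest logit before exponentiating
    static float[] softmax(float[] z) {
        float max = Float.NEGATIVE_INFINITY;
        for (float v : z) {
            max = Math.max(max, v);
        }
        double[] e = new double[z.length];
        double total = 0.0;
        for (int j = 0; j < z.length; j++) {
            e[j] = Math.exp(z[j] - max);
            total += e[j];
        }
        float[] p = new float[z.length];
        for (int j = 0; j < z.length; j++) {
            p[j] = (float) (e[j] / total);
        }
        return p;
    }

    // Index of the highest probability = predicted class
    static int argmax(float[] p) {
        int best = 0;
        for (int j = 1; j < p.length; j++) {
            if (p[j] > p[best]) {
                best = j;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        float[] x = {0f, 1f, 1f, 0f};
        float[][] w = {
            {0.1f, 0.0f, 0.0f},
            {0.0f, 0.5f, 0.2f},
            {0.0f, 0.5f, 0.1f},
            {0.3f, 0.0f, 0.0f}
        };
        float[] b = {0f, 0f, 0f};
        float[] p = softmax(logits(x, w, b));
        System.out.println(Arrays.toString(p) + " -> class " + argmax(p)); // class 1
    }
}
```

Note that the winning probability here is well below 1.0. Softmax outputs reach exactly 1.0f only when the logits saturate, so picking the largest entry is a more robust way to read the result than testing for equality with 1.0f.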

Then create a new Android app project in Android Studio.

Once the project is created, modify app/build.gradle:

```groovy
apply plugin: 'com.android.application'

android {
    compileSdkVersion 28
    defaultConfig {
        applicationId "com.tungyilin.tflite_mnist"
        minSdkVersion 23
        targetSdkVersion 28
        versionCode 1
        versionName "1.0"
        testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
    }
    buildTypes {
        release {
            minifyEnabled false
            proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
        }
    }
    // Store .tflite files uncompressed in the APK so they can be memory-mapped
    aaptOptions {
        noCompress "tflite"
    }
}

dependencies {
    implementation fileTree(dir: 'libs', include: ['*.jar'])
    implementation 'com.android.support:appcompat-v7:28.0.0'
    implementation 'com.android.support.constraint:constraint-layout:1.1.3'
    // TensorFlow Lite runtime ("+" takes the latest release; pin a version for reproducible builds)
    implementation 'org.tensorflow:tensorflow-lite:+'
    testImplementation 'junit:junit:4.12'
    androidTestImplementation 'com.android.support.test:runner:1.0.2'
    androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2'
}
```

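In the android block, add aaptOptions { noCompress "tflite" } so that the TensorFlow Lite model file is stored uncompressed in the APK. The app memory-maps the model through AssetManager.openFd, which cannot open a compressed asset, so without this setting loading the model fails. In the dependencies block, add implementation 'org.tensorflow:tensorflow-lite:+' (older projects write this with the deprecated compile configuration) to give the Android app TensorFlow Lite support; see this article for details.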
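The model loader later in this post maps a byte range of the APK straight into memory: the asset's file descriptor supplies a start offset and a length, and FileChannel.map maps exactly that range. That only works when the asset's bytes sit in the APK uncompressed, which is what noCompress guarantees. The sketch below reproduces the same FileChannel.map call on an ordinary file; the class name, the temp file and its "HEADER"/"TRAILER" padding are hypothetical stand-ins for an APK.

```java
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

// Sketch: memory-map a byte range of a file read-only, as loadModelFile() does
// with an asset's start offset and declared length inside the APK.
public class MapRegionSketch {

    static MappedByteBuffer mapRegion(Path file, long offset, long length) throws IOException {
        try (RandomAccessFile raf = new RandomAccessFile(file.toFile(), "r");
             FileChannel channel = raf.getChannel()) {
            // The mapping stays valid after the channel is closed
            return channel.map(FileChannel.MapMode.READ_ONLY, offset, length);
        }
    }

    public static void main(String[] args) throws IOException {
        // Hypothetical container file: the "model" starts at offset 6 and is 11 bytes long
        Path container = Files.createTempFile("container", ".bin");
        Files.write(container, "HEADERmodel-bytesTRAILER".getBytes(StandardCharsets.US_ASCII));

        MappedByteBuffer model = mapRegion(container, 6, 11);
        byte[] payload = new byte[model.remaining()];
        model.get(payload);
        System.out.println(new String(payload, StandardCharsets.US_ASCII)); // model-bytes
    }
}
```

Mapping instead of copying means the model is paged in from the APK on demand rather than read into a Java array first.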

With that done, create the WriterFragment class (WriterFragment.java), which shows the sample digit images and runs the classifier when the button is tapped:

```java
package com.tungyilin.tflite_mnist;

// https://blog.csdn.net/qq_22765745/article/details/80488012
// Fragment that shows the MNIST sample images and runs the classifier on them

import android.app.Fragment;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.view.LayoutInflater;
import android.view.View;
import android.view.ViewGroup;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;

public class WriterFragment extends Fragment implements View.OnClickListener {

    private Button btnStart;
    private Button btnChange;
    private TextView tvContent;
    private ImageView ivNumber;

    private Context context;
    // Resource IDs of the sample digit images
    private int[] imageIds;
    // Index of the image currently shown
    private int currentImageIds;

    public WriterFragment() {
        // Required empty public constructor
    }

    @Override
    public void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
    }

    @Override
    public View onCreateView(LayoutInflater inflater, ViewGroup container,
                             Bundle savedInstanceState) {
        View view = inflater.inflate(R.layout.writer_fragment, container, false);
        context = view.getContext();
        init(view);
        return view;
    }

    private void init(View view) {
        btnStart = (Button) view.findViewById(R.id.btnStart);
        tvContent = (TextView) view.findViewById(R.id.tvContent);
        ivNumber = (ImageView) view.findViewById(R.id.ivNumber);
        btnChange = (Button) view.findViewById(R.id.btnChange);
        btnStart.setOnClickListener(this);
        btnChange.setOnClickListener(this);
        imageIds = new int[]{R.mipmap.mnist_0, R.mipmap.mnist_1, R.mipmap.mnist_2,
                R.mipmap.mnist_3, R.mipmap.mnist_4, R.mipmap.mnist_5,
                R.mipmap.mnist_6, R.mipmap.mnist_7, R.mipmap.mnist_8,
                R.mipmap.mnist_9, R.mipmap.mnist_10, R.mipmap.mnist_11,
                R.mipmap.mnist_12};
        currentImageIds = 0;
        ivNumber.setImageResource(imageIds[currentImageIds]);
    }

    @Override
    public void onClick(View v) {
        switch (v.getId()) {
            case R.id.btnStart:
                WriterIdentify writerIdentify = WriterIdentify.newInstance(context);
                // Decode without density scaling so the bitmap keeps its 28x28 size
                BitmapFactory.Options bfoOptions = new BitmapFactory.Options();
                bfoOptions.inScaled = false;
                Bitmap bitmap = BitmapFactory.decodeResource(getResources(),
                        imageIds[currentImageIds], bfoOptions);
                writerIdentify.run(bitmap);
                tvContent.setText("Result:" + writerIdentify.getResult());
                break;
            case R.id.btnChange:
                // Show the next sample image, wrapping around at the end
                currentImageIds = (currentImageIds + 1) % imageIds.length;
                ivNumber.setImageResource(imageIds[currentImageIds]);
                break;
        }
    }
}
```
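The fragment decodes each image with inScaled = false because the model accepts exactly one 28x28 image, i.e. 784 float32 values or 3136 bytes. If Android treats the PNG as a lower-density resource and scales it up for the screen, the bitmap grows and the input buffer no longer matches. The sketch below does the arithmetic for one hypothetical case, an image treated as 160 dpi decoded on a 480 dpi device; the class name and the simple integer scaling rule are assumptions, and Android's exact rounding may differ.

```java
// Sketch: size check behind BitmapFactory.Options.inScaled = false (hypothetical helper).
public class InputSizeSketch {

    // The model input: one 28x28 grayscale image as 784 float32 values
    static final int MODEL_SIDE = 28;
    static final int BYTES_PER_FLOAT = 4;

    // Bytes that a float32 buffer for a width x height bitmap needs
    static int bufferBytes(int width, int height) {
        return width * height * BYTES_PER_FLOAT;
    }

    // Approximate side length after density scaling by targetDpi / sourceDpi
    // (assumption: plain integer scaling)
    static int scaledSide(int side, int sourceDpi, int targetDpi) {
        return side * targetDpi / sourceDpi;
    }

    // Only an exactly 28x28 bitmap can be fed to the model unchanged
    static boolean fitsModel(int width, int height) {
        return width == MODEL_SIDE && height == MODEL_SIDE;
    }

    public static void main(String[] args) {
        System.out.println("model input bytes: " + bufferBytes(MODEL_SIDE, MODEL_SIDE)); // 3136
        int side = scaledSide(MODEL_SIDE, 160, 480);
        System.out.println("scaled bitmap: " + side + "x" + side + ", "
                + bufferBytes(side, side) + " bytes, fits model: " + fitsModel(side, side));
        // scaled bitmap: 84x84, 28224 bytes, fits model: false
    }
}
```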

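Part 3 of the article, "三 读取MNIST数据集中的数据" (read data from the MNIST dataset), uses a Python script to produce the handwritten digit image files (mnist_get_image.py). The script reads the uncompressed file mnist/train-images-idx3-ubyte; read_data_sets in the previous script saves it as mnist/train-images-idx3-ubyte.gz, so decompress it first (for example with gunzip):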

```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Dec 18 22:03:57 2018

@author: tungyilin
"""

# https://blog.csdn.net/qq_22765745/article/details/80488012
# Part 3: read images from the MNIST dataset

import os
import struct

import numpy as np
from PIL import Image

# Path to the uncompressed MNIST training images (adjust as needed)
data_file = 'mnist/train-images-idx3-ubyte'
# The file is 47040016 bytes: a 16-byte header plus 47040000 pixel bytes
data_file_size = 47040016
data_file_size = str(data_file_size - 16) + 'B'

with open(data_file, 'rb') as f:
    data_buf = f.read()

# Header: four big-endian unsigned 32-bit integers
magic, numImages, numRows, numColumns = struct.unpack_from(
    '>IIII', data_buf, 0)
# Pixels: one unsigned byte each, right after the header
datas = struct.unpack_from(
    '>' + data_file_size, data_buf, struct.calcsize('>IIII'))
datas = np.array(datas).astype(np.uint8).reshape(
    numImages, 1, numRows, numColumns)

# Output directory for the PNG files (adjust as needed)
datas_root = 'images/'
os.makedirs(datas_root, exist_ok=True)

# Save the first 100 training images as mnist_0.png ... mnist_99.png
for ii in range(100):
    print(ii)
    img = Image.fromarray(datas[ii, 0, 0:28, 0:28])
    file_name = datas_root + 'mnist_' + str(ii) + '.png'
    img.save(file_name)
```
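The script's constants follow from the IDX3 layout: a 16-byte header of four big-endian 32-bit integers (magic number 2051, image count, rows, columns), then one unsigned byte per pixel, so the training file holds 16 + 60000 × 28 × 28 = 47040016 bytes. The Java sketch below parses the same layout; the class name and the tiny synthetic file (2 images of 2x3 pixels) are illustrative assumptions, not part of the original post.

```java
import java.nio.ByteBuffer;
import java.util.Arrays;

// Sketch: read one image from an MNIST IDX3 byte array (hypothetical helper).
public class IdxImagesSketch {

    static final int IDX3_MAGIC = 2051;  // 0x00000803: unsigned bytes, 3 dimensions
    static final int HEADER_BYTES = 16;  // four 32-bit integers

    // Return image `index` as rows x cols gray levels in 0..255
    static int[][] readImage(byte[] data, int index) {
        ByteBuffer buf = ByteBuffer.wrap(data); // big-endian by default, like '>IIII'
        int magic = buf.getInt();
        if (magic != IDX3_MAGIC) {
            throw new IllegalArgumentException("not an IDX3 image file, magic = " + magic);
        }
        int count = buf.getInt();
        int rows = buf.getInt();
        int cols = buf.getInt();
        if (index < 0 || index >= count) {
            throw new IndexOutOfBoundsException("image " + index + " of " + count);
        }
        int base = HEADER_BYTES + index * rows * cols;
        int[][] image = new int[rows][cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                image[r][c] = data[base + r * cols + c] & 0xFF; // pixel bytes are unsigned
            }
        }
        return image;
    }

    public static void main(String[] args) {
        // Synthetic file: 2 images of 2x3 pixels (real MNIST: 60000 images of 28x28)
        ByteBuffer file = ByteBuffer.allocate(HEADER_BYTES + 2 * 2 * 3);
        file.putInt(IDX3_MAGIC).putInt(2).putInt(2).putInt(3);
        for (int i = 0; i < 12; i++) {
            file.put((byte) (i * 20));
        }
        System.out.println(Arrays.deepToString(readImage(file.array(), 1)));
        // [[120, 140, 160], [180, 200, 220]]
    }
}
```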

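Part 4 of the article, "四 在android中运行自定的分类器" (run the custom classifier on Android), adds the handwritten digit images to the app's res/mipmap resources and creates the page layout (writer_fragment.xml). The layout refers to the string resources hello_blank_fragment, btnChange and btnClick, so these must also be defined in res/values/strings.xml: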

```xml
<?xml version="1.0" encoding="utf-8"?>

<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context="com.tungyilin.tflite_mnist.WriterFragment"
    >

    <LinearLayout
        android:layout_width="match_parent"
        android:layout_height="match_parent"
        android:orientation="vertical"
        android:gravity="center"
        >
        <ImageView
            android:id="@+id/ivNumber"
            android:layout_width="150dp"
            android:layout_height="150dp"
            />
        <TextView
            android:layout_width="match_parent"
            android:layout_height="30dp"
            />
        <TextView
            android:id="@+id/tvContent"
            android:layout_width="wrap_content"
            android:layout_height="wrap_content"
            android:textSize="24sp"
            android:text="@string/hello_blank_fragment"
            />
        <TextView
            android:layout_width="match_parent"
            android:layout_height="30dp"
            />
        <Button
            android:id="@+id/btnChange"
            android:layout_width="300dp"
            android:layout_height="wrap_content"
            android:text="@string/btnChange"
            />
        <Button
            android:id="@+id/btnStart"
            android:layout_width="300dp"
            android:layout_height="wrap_content"
            android:text="@string/btnClick"
            />
    </LinearLayout>

</FrameLayout>
```

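Finally, following Part 2 of the article, "二 创建自己的分类器" (create your own classifier), create the WriterIdentify class (WriterIdentify.java), which loads writer_model.tflite and runs it on a bitmap: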

```java
package com.tungyilin.tflite_mnist;

// https://blog.csdn.net/qq_22765745/article/details/80488012
// Part 2: create your own classifier

import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.util.Log;

import org.tensorflow.lite.Interpreter;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;

public class WriterIdentify {

    private static final String TAG = "WriterIdentify";

    // Interpreter that runs the .tflite model (the classifier)
    private Interpreter tflite;
    // Model output: one probability per digit, shape [1][10]
    private float[][] labelProbArray = null;

    public static WriterIdentify newInstance(Context context) {
        return new WriterIdentify(context);
    }

    private WriterIdentify(Context context) {
        try {
            tflite = new Interpreter(loadModelFile(context));
        } catch (Exception e) {
            // A missing or compressed model file ends up here; log it instead of ignoring it
            Log.e(TAG, "Failed to load " + getModelPath(), e);
        }
        labelProbArray = new float[1][10];
    }

    public void run(Bitmap bitmap) {
        if (tflite == null) {
            return; // model not loaded; getResult() returns -1
        }
        tflite.run(convertBitmapToByteBuffer(bitmap), labelProbArray);
    }

    // Return the predicted digit: the index of the highest probability (-1 if no model)
    public int getResult() {
        if (tflite == null) {
            return -1;
        }
        int best = 0;
        for (int i = 1; i < labelProbArray[0].length; i++) {
            if (labelProbArray[0][i] > labelProbArray[0][best]) {
                best = i;
            }
        }
        return best;
    }

    // Convert the 28x28 bitmap into a float32 input buffer in native byte order
    private ByteBuffer convertBitmapToByteBuffer(Bitmap bitmap) {
        int width = bitmap.getWidth();
        int height = bitmap.getHeight();
        ByteBuffer tempData = ByteBuffer.allocateDirect(width * height * 4);
        tempData.order(ByteOrder.nativeOrder());
        int[] pixels = getPicturePixel(bitmap);
        for (int pixel : pixels) {
            // Scale 0..255 to 0..1, the range of the MNIST data used for training
            tempData.putFloat(pixel / 255.0f);
        }
        tempData.rewind();
        return tempData;
    }

    // Read the bitmap's pixels as gray levels 0..255
    private int[] getPicturePixel(Bitmap bitmap) {
        int width = bitmap.getWidth();
        int height = bitmap.getHeight();

        // One ARGB int per pixel, width x height in total
        int[] pixels = new int[width * height];
        bitmap.getPixels(pixels, 0, width, 0, 0, width, height);
        for (int i = 0; i < pixels.length; i++) {
            // Keep the lowest 8 bits (blue channel; equals the gray level in a grayscale image)
            pixels[i] = pixels[i] & 0x000000ff;
        }
        return pixels;
    }

    // Memory-map the model file from app/src/main/assets
    private MappedByteBuffer loadModelFile(Context context) throws IOException {
        AssetFileDescriptor fileDescriptor = context.getAssets().openFd(getModelPath());
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }

    private String getModelPath() {
        return "writer_model.tflite";
    }
}
```
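The input conversion packs each pixel's low 8 bits into a float32 in a direct ByteBuffer. The article's original float2byte() helper wrote the float's bytes big-endian and then reversed them, which yields little-endian bytes; on a little-endian device (what ByteOrder.nativeOrder() reports on typical Android ARM and x86 hardware) that is exactly what putFloat writes into a native-order buffer. The plain-Java sketch below checks this and shows both the raw 0..255 values and the 0..1 scaling used by the TensorFlow MNIST training data. The class and method names are hypothetical, and int pixels stand in for Bitmap.getPixels output.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

// Sketch: ARGB pixels -> float32 model input, and the equivalence of the
// hand-written float2byte() byte reversal with ByteBuffer.putFloat.
public class PixelBufferSketch {

    // Low 8 bits of each pixel, times `scale`, as floats in a direct native-order buffer
    static ByteBuffer toInputBuffer(int[] argbPixels, float scale) {
        ByteBuffer buf = ByteBuffer.allocateDirect(argbPixels.length * 4);
        buf.order(ByteOrder.nativeOrder());
        for (int p : argbPixels) {
            buf.putFloat((p & 0xFF) * scale);
        }
        buf.rewind(); // position back to 0 before the buffer is read
        return buf;
    }

    // Byte layout of the original float2byte(): big-endian bytes, then reversed
    static byte[] float2byteOriginal(float f) {
        int fbit = Float.floatToIntBits(f);
        byte[] b = new byte[4];
        for (int i = 0; i < 4; i++) {
            b[i] = (byte) (fbit >> (24 - i * 8));
        }
        for (int i = 0; i < 2; i++) { // reverse in place
            byte t = b[i];
            b[i] = b[3 - i];
            b[3 - i] = t;
        }
        return b;
    }

    // The same float written by putFloat into a little-endian buffer
    static byte[] putFloatLittleEndian(float f) {
        return ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putFloat(f).array();
    }

    public static void main(String[] args) {
        int[] pixels = {0xFFFFFFFF, 0xFF000000, 0xFF808080}; // white, black, mid gray
        ByteBuffer raw = toInputBuffer(pixels, 1f);           // 0..255
        ByteBuffer scaled = toInputBuffer(pixels, 1f / 255f); // 0..1
        System.out.println(raw.getFloat(0) + " " + raw.getFloat(8)); // 255.0 128.0
        System.out.println(scaled.getFloat(4));                      // 0.0
        System.out.println(Arrays.equals(float2byteOriginal(128f), putFloatLittleEndian(128f))); // true
        System.out.println("native order: " + ByteOrder.nativeOrder());
    }
}
```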

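That covers the code from the article. One thing to watch is the Java package name (com.tungyilin.tflite_mnist here): change it in every file to match your own project's package.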

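But that is not the end. The TensorFlow Lite model (writer_model.tflite) must also be copied into the app/src/main/assets directory. In addition, the article's code was built on top of the TensorFlow Lite demo app, so in a newly created project you also need to add a FrameLayout to the main layout (activity_main.xml):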

```xml
<?xml version="1.0" encoding="utf-8"?>
<android.support.constraint.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity"
    >

    <FrameLayout
        android:id="@+id/fragment_frame"
        android:layout_width="match_parent"
        android:layout_height="match_parent"
        >

    </FrameLayout>
</android.support.constraint.ConstraintLayout>
```

Then modify the onCreate method of the MainActivity class (MainActivity.java):

```java
package com.tungyilin.tflite_mnist;

import android.app.FragmentManager;
import android.app.FragmentTransaction;
import android.os.Bundle;
import android.support.v7.app.AppCompatActivity;

public class MainActivity extends AppCompatActivity {

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        // Show WriterFragment inside the FrameLayout declared in activity_main.xml
        WriterFragment fr = new WriterFragment();
        Bundle args = new Bundle();
        fr.setArguments(args);
        FragmentManager fm = getFragmentManager();
        FragmentTransaction fragmentTransaction = fm.beginTransaction();
        fragmentTransaction.replace(R.id.fragment_frame, fr);
        fragmentTransaction.commit();
    }
}
```

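The process is involved, and debugging the problems that come up in the Android app takes patience. All the code mentioned in this post is therefore packaged here; if you would rather not copy and paste from the web page, you can download it and use it directly. The result of running the whole program is shown below: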
