I. PyTorch Installation
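Download the .whl files and install them locally with pip install.
Pitfall: the tar.bz2 packages have no setup.py. They can be installed with conda, but pip install fails on them, and the installation is also very slow.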
Download page for the PyTorch & torchvision .whl files
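The project requires mobilenet_v3_small, which is not available in older TorchVision releases: it needs torchvision 0.9.0 at minimum, and 0.10.0 (stable) is preferred. I went straight to torch==1.9.0, which pairs with torchvision==0.10.0, so the wheels chosen were torch-1.9.0+cu111-cp37-cp37m-win_amd64.whl and torchvision-0.10.0+cu111-cp37-cp37m-win_amd64.whl (CUDA 11.1, CPython 3.7, 64-bit Windows).
To install, open cmd, change to the directory containing the .whl files, and run pip install "torch-1.9.0+cu111-cp37-cp37m-win_amd64.whl", then do the same for the torchvision wheel.
If pip reports that torchvision or another package has a mismatched version, pip uninstall that package first and then install again.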
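Because the torch and torchvision wheels must agree on release pairing, CUDA build, Python tag, and platform, a quick check before running pip can save a failed install. The sketch below is a hypothetical helper (not part of pip or PyTorch) that splits a wheel filename into its PEP 427 fields and compares the two wheels; the pairing table covers only a few releases and is an assumption to extend from the official torchvision compatibility table.

```python
from typing import NamedTuple, Optional

# torch -> torchvision release pairings (small subset, extend as needed).
TORCHVISION_FOR_TORCH = {
    "1.8.0": "0.9.0",
    "1.8.1": "0.9.1",
    "1.9.0": "0.10.0",
    "1.9.1": "0.10.1",
}


class WheelInfo(NamedTuple):
    name: str
    version: str          # public version, e.g. "1.9.0"
    local: Optional[str]  # local label after "+", e.g. "cu111"
    python_tag: str       # e.g. "cp37"
    abi_tag: str          # e.g. "cp37m"
    platform_tag: str     # e.g. "win_amd64"


def parse_wheel_filename(filename: str) -> WheelInfo:
    """Split {name}-{version}(-{build})?-{python}-{abi}-{platform}.whl."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel: {filename}")
    parts = filename[:-4].split("-")
    if len(parts) == 6:  # optional build tag present; drop it
        parts = parts[:2] + parts[3:]
    if len(parts) != 5:
        raise ValueError(f"unexpected wheel name: {filename}")
    name, full_version, py, abi, plat = parts
    version, _, local = full_version.partition("+")
    return WheelInfo(name, version, local or None, py, abi, plat)


def check_pair(torch_whl: str, vision_whl: str) -> None:
    """Raise ValueError if the two wheels should not be installed together."""
    t = parse_wheel_filename(torch_whl)
    v = parse_wheel_filename(vision_whl)
    expected = TORCHVISION_FOR_TORCH.get(t.version)
    if expected is None:
        raise ValueError(f"no pairing recorded for torch {t.version}")
    if v.version != expected:
        raise ValueError(f"torch {t.version} needs torchvision {expected}, got {v.version}")
    # CUDA label, interpreter, ABI and platform must match as well.
    for field in ("local", "python_tag", "abi_tag", "platform_tag"):
        if getattr(t, field) != getattr(v, field):
            raise ValueError(f"{field} differs: {getattr(t, field)} vs {getattr(v, field)}")


if __name__ == "__main__":
    check_pair(
        "torch-1.9.0+cu111-cp37-cp37m-win_amd64.whl",
        "torchvision-0.10.0+cu111-cp37-cp37m-win_amd64.whl",
    )
    print("wheel pair looks consistent")
```

Checking the filenames rather than the installed packages means the mismatch is caught before pip touches the environment.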
II. HelloWorld
Requirements: Android SDK and Android NDK
- Pretrained model export test
import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile
model = torchvision.models.mobilenet_v3_small(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
optimized_traced_model = optimize_for_mobile(traced_script_module)
optimized_traced_model.save("model.pt")
This generates model.pt, about 9.7 MB, in the directory the script is run from.
- Android test
git clone https://github.com/pytorch/android-demo-app.git
cd android-demo-app/HelloWorldApp
Replace HelloWorldApp/app/src/main/assets/model.pt with the model.pt generated in the first step.
Then open the HelloWorldApp project in Android Studio and choose Build > Build Bundle(s) / APK(s) > Build APK(s).
If the build fails with Could not find org.pytorch:pytorch_android:1.8.0-SNAPSHOT, open app/build.gradle and change it as follows:
apply plugin: 'com.android.application'

android {
    compileSdkVersion 28
    buildToolsVersion "29.0.2"
    defaultConfig {
        applicationId "org.pytorch.helloworld"
        minSdkVersion 21
        //noinspection ExpiredTargetSdkVersion
        targetSdkVersion 28
        versionCode 1
        versionName "1.0"
    }
    buildTypes {
        release {
            minifyEnabled false
        }
    }
}

dependencies {
    implementation 'androidx.appcompat:appcompat:1.1.0'
    // Released 1.9.0 artifacts replace the unavailable 1.8.0-SNAPSHOT; the lite
    // runtime loads models saved with _save_for_lite_interpreter (LiteModuleLoader).
    implementation 'org.pytorch:pytorch_android_lite:1.9.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
}
Modify app/src/main/java/org/pytorch/helloworld/MainActivity.java:
package org.pytorch.helloworld;

import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.util.Log;
import android.widget.ImageView;
import android.widget.TextView;

import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.MemoryFormat;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

import androidx.appcompat.app.AppCompatActivity;

public class MainActivity extends AppCompatActivity {
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        Bitmap bitmap = null;
        Module module = null;
        try {
            // creating bitmap from packaged into app android asset 'image.jpg',
            // app/src/main/assets/image.jpg
            bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
            // loading serialized TorchScript module (lite interpreter format) from
            // packaged into app android asset model.pt, app/src/main/assets/model.pt
            module = LiteModuleLoader.load(assetFilePath(this, "model.pt"));
        } catch (IOException e) {
            Log.e("PytorchHelloWorld", "Error reading assets", e);
            finish();
            // finish() does not end onCreate(); return so the null bitmap/module are not used below
            return;
        }

        // showing image on UI
        ImageView imageView = findViewById(R.id.image);
        imageView.setImageBitmap(bitmap);

        // preparing input tensor: ImageNet mean/std normalization, channels-last memory layout
        final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
                TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB,
                MemoryFormat.CHANNELS_LAST);

        // running the model
        final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();

        // getting tensor content as java array of floats
        final float[] scores = outputTensor.getDataAsFloatArray();

        // searching for the index with maximum score
        float maxScore = -Float.MAX_VALUE;
        int maxScoreIdx = -1;
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] > maxScore) {
                maxScore = scores[i];
                maxScoreIdx = i;
            }
        }

        String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];

        // showing className on UI
        TextView textView = findViewById(R.id.text);
        textView.setText(className);
    }

    /**
     * Copies specified asset to the file in /files app directory and returns this file absolute path.
     *
     * @return absolute file path
     */
    public static String assetFilePath(Context context, String assetName) throws IOException {
        File file = new File(context.getFilesDir(), assetName);
        if (file.exists() && file.length() > 0) {
            return file.getAbsolutePath();
        }

        try (InputStream is = context.getAssets().open(assetName)) {
            try (OutputStream os = new FileOutputStream(file)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
                os.flush();
            }
            return file.getAbsolutePath();
        }
    }
}
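The preprocessing and post-processing in onCreate are simple enough to check off-device. This Python sketch mirrors them under the assumption that bitmapToFloat32Tensor scales each channel to [0, 1] and then applies (x - mean) / std with the TorchVision ImageNet constants; the function names are hypothetical and nothing here calls PyTorch.

```python
from typing import Sequence, Tuple

# Same values as TensorImageUtils.TORCHVISION_NORM_MEAN_RGB / TORCHVISION_NORM_STD_RGB
MEAN_RGB = (0.485, 0.456, 0.406)
STD_RGB = (0.229, 0.224, 0.225)


def normalize_pixel(r: int, g: int, b: int) -> Tuple[float, ...]:
    """Scale 0-255 channels to [0, 1], then apply (x - mean) / std per channel."""
    return tuple((c / 255.0 - m) / s for c, m, s in zip((r, g, b), MEAN_RGB, STD_RGB))


def argmax(scores: Sequence[float]) -> int:
    """Index of the highest score (first one wins on ties), or -1 if empty."""
    max_score = float("-inf")  # stands in for -Float.MAX_VALUE in the Java loop
    max_idx = -1
    for i, score in enumerate(scores):
        if score > max_score:
            max_score = score
            max_idx = i
    return max_idx


if __name__ == "__main__":
    print(normalize_pixel(255, 255, 255))
    print(argmax([0.1, 2.5, -1.0, 2.5]))
```

With mobilenet_v3_small's 1000-class ImageNet output, the returned index is what MainActivity looks up in ImageNetClasses.IMAGENET_CLASSES.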
Modify trace_model.py:
import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile
model = torchvision.models.mobilenet_v3_small(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
optimized_traced_model = optimize_for_mobile(traced_script_module)
optimized_traced_model._save_for_lite_interpreter("app/src/main/assets/model.pt")
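Run the script from a terminal in the HelloWorldApp directory (the output path is relative) to generate app/src/main/assets/model.pt, 9.72 MB (10,199,082 bytes). Then click Sync Now, build the APK, and run it on the emulator; see the project on GitHub for what the result should look like.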
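Mixing the two export paths (save() in the first script, _save_for_lite_interpreter() here) is a common reason LiteModuleLoader fails to load a model. The helper below is a heuristic sketch: it assumes a lite-interpreter archive is a zip file containing a bytecode.pkl entry, which has not been verified for every PyTorch version, so treat a negative result as a hint rather than proof. The function name is hypothetical.

```python
import sys
import zipfile


def is_lite_interpreter_model(path: str) -> bool:
    """Heuristic: True if `path` is a zip archive with a bytecode.pkl entry.

    Assumption: _save_for_lite_interpreter() writes bytecode.pkl into the
    archive, usually under a top-level folder such as "model/".
    """
    if not zipfile.is_zipfile(path):  # also False for a missing file
        return False
    with zipfile.ZipFile(path) as zf:
        return any(
            name == "bytecode.pkl" or name.endswith("/bytecode.pkl")
            for name in zf.namelist()
        )


if __name__ == "__main__":
    target = sys.argv[1] if len(sys.argv) > 1 else "app/src/main/assets/model.pt"
    verdict = "lite" if is_lite_interpreter_model(target) else "not lite (or unknown)"
    print(target, verdict)
```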
III. Image Segmentation
Image segmentation survey
Added features: take a photo, pick an image from the gallery, and Live mode
In AndroidManifest.xml, declare the permissions the app requests:
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
<uses-permission android:name="android.permission.CAMERA" />
In onCreate of MainActivity.java, check for the permissions and request them if needed:
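// Imports: android.Manifest, android.content.pm.PackageManager, androidx.core.app.ActivityCompat, androidx.core.content.ContextCompat, java.util.ArrayList, java.util.List
// Request all missing permissions in one call; a second request issued while the first dialog is open is dropped.
List<String> missing = new ArrayList<>();
if (ContextCompat.checkSelfPermission(this, Manifest.permission.READ_EXTERNAL_STORAGE) != PackageManager.PERMISSION_GRANTED) {
    missing.add(Manifest.permission.READ_EXTERNAL_STORAGE);
}
if (ContextCompat.checkSelfPermission(this, Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
    missing.add(Manifest.permission.CAMERA);
}
if (!missing.isEmpty()) {
    ActivityCompat.requestPermissions(this, missing.toArray(new String[0]), 1);
}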
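Fixing the app crash when picking an image
(1) Check whether the SDK version is 19 (Android 4.4) or higher;
(2) if it is, use Intent.ACTION_PICK to pick the image;
(3) below 19, use Intent.ACTION_GET_CONTENT to pick the image.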
Intent intent;
if (Build.VERSION.SDK_INT < 19) {
    intent = new Intent(Intent.ACTION_GET_CONTENT);
    intent.setType("image/*");
} else {
    intent = new Intent(Intent.ACTION_PICK, android.provider.MediaStore.Images.Media.EXTERNAL_CONTENT_URI);
}
startActivityForResult(intent, 1);
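On devices running Android 6.0 (API 23) or higher, request the permissions at runtime.
On devices running Android 10 (API 29) or higher, request the permissions at runtime and also add android:requestLegacyExternalStorage="true" to the application tag in AndroidManifest.xml.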
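The version rules above can be summarized as a small decision table. This is a plain-Python restatement for reference only: it calls no Android API, and the PickerSetup and picker_setup names and the action strings are hypothetical labels for the Java constants.

```python
from typing import NamedTuple

KITKAT = 19       # Android 4.4
MARSHMALLOW = 23  # Android 6.0
ANDROID_10 = 29   # Android 10


class PickerSetup(NamedTuple):
    intent_action: str             # Intent action used to pick an image
    runtime_permissions: bool      # request permissions at runtime?
    legacy_external_storage: bool  # add requestLegacyExternalStorage="true"?


def picker_setup(sdk_int: int) -> PickerSetup:
    """Restate the text's rules for a device running API level `sdk_int`."""
    return PickerSetup(
        intent_action="ACTION_PICK" if sdk_int >= KITKAT else "ACTION_GET_CONTENT",
        runtime_permissions=sdk_int >= MARSHMALLOW,
        # Per the text, Android 10 and above. On Android 11+ the flag is
        # ignored once the app targets API level 30 or higher.
        legacy_external_storage=sdk_int >= ANDROID_10,
    )


if __name__ == "__main__":
    for level in (18, 21, 23, 29):
        print(level, picker_setup(level))
```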
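IV. Object Detection
Fixing the app crash when picking an image
(1) Check whether the SDK version is 19 (Android 4.4) or higher;
(2) if it is, use Intent.ACTION_PICK to pick the image;
(3) below 19, use Intent.ACTION_GET_CONTENT to pick the image.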
Intent intent;
if (Build.VERSION.SDK_INT < 19) {
    intent = new Intent(Intent.ACTION_GET_CONTENT);
    intent.setType("image/*");
} else {
    intent = new Intent(Intent.ACTION_PICK, android.provider.MediaStore.Images.Media.EXTERNAL_CONTENT_URI);
}
startActivityForResult(intent, 1);
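Fixing crashes caused by incomplete permission requests: request all required permissions at runtime
On devices running Android 6.0 (API 23) or higher, request the permissions at runtime.
On devices running Android 10 (API 29) or higher, request the permissions at runtime and also add android:requestLegacyExternalStorage="true" to the application tag in AndroidManifest.xml.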
Appendix: 15 ARGB colors used to mark image segmentation results
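If you build such a color table, Android represents each color as a 32-bit int in 0xAARRGGBB layout, the same packing android.graphics.Color.argb() produces. This hypothetical Python helper shows that packing and the conversion to the signed value a Java int literal ends up holding; it does not supply the 15 colors themselves.

```python
def argb(a: int, r: int, g: int, b: int) -> int:
    """Pack four 8-bit channels into an unsigned 0xAARRGGBB integer."""
    for c in (a, r, g, b):
        if not 0 <= c <= 255:
            raise ValueError(f"channel out of range: {c}")
    return (a << 24) | (r << 16) | (g << 8) | b


def to_java_int(color: int) -> int:
    """Reinterpret an unsigned 32-bit color as a signed Java int."""
    return color - (1 << 32) if color & 0x80000000 else color


if __name__ == "__main__":
    opaque_red = argb(255, 255, 0, 0)
    print(hex(opaque_red), to_java_int(opaque_red))
```

Fully opaque colors have the top bit set, which is why they appear as negative numbers on the Java side.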