春节假期邻近,亲朋小聚饭后不可避免须要来点乏味的事件打发工夫,「我画你猜」就是一种很好的消遣形式。然而往年既然提倡「就地过年」,那么无妨就把这样的游戏搬到网上吧,照样能够玩到嗨~

一些童鞋可能还有印象,2018年时,Google推出了《猜画小歌》利用:玩家能够间接与AI进行你画我猜的游戏。通过画出一个房子或者一个猫,AI会推断出各种物品被画出的概率。它的实现得益于深度学习模型在其中的利用,通过深度神经网络的演绎,已经令人头疼的绘画辨认也变得大海捞针。现如今,只有应用一个简略的图片分类模型,咱们便能够轻松的实现绘画辨认。试试看这个在线涂鸦小游戏吧。

在过后,大部分机器学习计算工作仍旧须要依靠网络在云端进行。随着算力的一直增进,机器学习工作曾经能够间接在边缘设施部署,包含各类运行安卓零碎的智能手机。然而,因为安卓自身次要是用Java,部署基于Python的各类深度学习模型变成了一个难题。为了解决这个问题,AWS开发并开源了DeepJavaLibrary (DJL),一个为Java量身定制的深度学习框架。

在下文中,咱们将尝试通过PyTorch预训练模型在在安卓平台构建一个涂鸦绘画的利用。因为总代码量会比拟多,咱们这次会挑重点把最要害的代码实现。大家能够后续参考咱们残缺的我的项目进行构建。

环境配置

为了兼容DJL需要的Java性能,这个我的项目须要Android API 26及以上的版本。能够参考咱们案例配置来节约一些工夫,上面是这个我的项目须要的依赖项:

dependencies {
 implementation 'androidx.appcompat:appcompat:1.2.0'
 implementation 'ai.djl:api:0.7.0'
 implementation 'ai.djl.android:core:0.7.0'
 runtimeOnly 'ai.djl.pytorch:pytorch-engine:0.7.0'
 runtimeOnly 'ai.djl.android:pytorch-native:0.7.0'

咱们将应用DJL提供的API以及PyTorch包。

第一步:创立Layout

咱们能够先创立一个View class以及Layout(如下图)来构建安卓的前端显示界面。

如上图所示,咱们能够在主界面创立两个View指标。PaintView是用来让用户画画的,在右下角ImageView是用来展现用于深度学习推理的图像。同时咱们预留一个按钮来进行画板的清空操作。

第二部:应答绘画动作

在安卓设施上,咱们能够自定义安卓的触摸事件响应来应答用户的各种触控操作。在咱们的状况下,咱们须要定义上面三种工夫响应:

  • touchStart:感应触碰时触发
  • touchMove:当用户在屏幕上挪动手指时触发
  • touchUp:当用户抬起手指时触发

与此同时,咱们用paths来存储用户在画板所绘制的门路。当初看一下实现代码。

重写OnTouchEvent和OnDraw办法

当初咱们重写onTouchEvent来应答各种响应:

@Override
public boolean onTouchEvent(MotionEvent event) {
 float x = event.getX();
 float y = event.getY();
 switch (event.getAction()) {
 case MotionEvent.ACTION_DOWN :
 touchStart(x, y);
 invalidate();
 break;
 case MotionEvent.ACTION_MOVE :
 touchMove(x, y);
 invalidate();
 break;
 case MotionEvent.ACTION_UP :
 touchUp();
 runInference();
 invalidate();
 break;
 }
 return true;
}

如上述代码所示,咱们能够增加一个runInference办法在MotionEvent.ACTION_UP事件响应上。这个办法是用来在用户绘制完后对后果进行推理。在之后的几步中,咱们会解说它的具体实现。

咱们同样须要重写onDraw办法来展现用户绘制的图像:

@Override
protected void onDraw(Canvas canvas) {
 canvas.save();
 this.canvas.drawColor(DEFAULT_BG_COLOR);
 for (Path path : paths) {
 paint.setColor(DEFAULT_PAINT_COLOR);
 paint.setStrokeWidth(BRUSH_SIZE);
 this.canvas.drawPath(path, paint);
 }
 canvas.drawBitmap(bitmap, 0, 0, bitmapPaint);
 canvas.restore();
}

真正的图像会保留在一个Bitmap上。

touchStart

当用户触碰行为开始时,上面的代码会建设一个新的门路同时记录门路中每一个点在屏幕上的坐标。

private void touchStart(float x, float y) {
 path = new Path();
 paths.add(path);
 path.reset();
 path.moveTo(x, y);
 this.x = x;
 this.y = y;
}

touchMove

在手指挪动中,咱们会继续记录坐标点而后将它们形成一个quadratic bezier)。通过肯定的误差阀值来动静优化用户的绘画动作。只有差异超出误差范畴内的动作才会被记录下来。

private void touchMove(float x, float y) {
 if (x < 0 || x > getWidth() || y < 0 || y > getHeight()) {
 return;
 }
 float dx = Math.abs(x - this.x);
 float dy = Math.abs(y - this.y);
 if (dx >= TOUCH_TOLERANCE || dy >= TOUCH_TOLERANCE) {
 path.quadTo(this.x, this.y, (x + this.x) / 2, (y + this.y) / 2);
 this.x = x;
 this.y = y;
 }
}

touchUp

当触控操作完结后,上面的代码会绘制一个门路同时计算最小长方形指标框。

private void touchUp() {
 path.lineTo(this.x, this.y);
 maxBound.add(new Path(path));
}

Step 3:开始推理

为了在安卓设施上进行推理工作,咱们须要实现上面几个工作:

  • 从URL读取模型
  • 构建前解决和后处理过程
  • 从PaintView进行推理工作

为了实现以下指标,咱们尝试构建一个DoodleModel class。在这一步,咱们将介绍一些实现这些工作的关键步骤。

读取模型

DJL内建了一套模型管理系统。开发者能够自定义贮存模型的文件夹。

File dir = getFilesDir();
System.setProperty("DJL_CACHE_DIR", dir.getAbsolutePath());

通过更改DJL_CACHE_DIR属性,模型会被存入相应门路下。

下一步能够通过定义Criteria从指定URL处下载模型。下载的zip文件内蕴含:

  • doodle_mobilenet.pt:PyTorch模型
  • synset.txt:内蕴含分类工作中所有类别的名称
Criteria<Image, Classifications> criteria =
 Criteria.builder()
 .setTypes(Image.class, Classifications.class)
 .optModelUrls("https://djl-ai.s3.amazonaws.com/resources/demo/pytorch/doodle_mobilenet.zip")
 .optTranslator(translator)
 .build();
return ModelZoo.loadModel(criteria);

上述代码同时定义了translator。translator会被用来做图片的前解决和后处理。

最初,如下述代码创立一个Model并用它创立一个Predictor:

@Override
protected Boolean doInBackground(Void... params) {
 try {
 model = DoodleModel.loadModel();
 predictor = model.newPredictor();
 return true;
 } catch (IOException | ModelException e) {
 Log.e("DoodleDraw", null, e);
 }
 return false;
}

更多对于模型加载的信息,请参阅如何加载模型。

用Translator定义前解决和后处理

在DJL中,咱们定义了Translator接口进行前解决和后处理。在DoodleModel中咱们定义了ImageClassificationTranslator来实现Translator:

ImageClassificationTranslator.builder()
 .addTransform(new ToTensor())
 .optFlag(Image.Flag.GRAYSCALE)
 .optApplySoftmax(true).build());

上面咱们具体论述translator所定义的前解决和后处理如何被用在模型的推理步骤中。当创立translator时,外部程序会主动加载synset.txt文件失去做分类工作时所有类别的名称。当模型的predict ()办法被调用时,外部程序会先执行所对应的translator的前解决步骤,而后执行理论推理步骤,最初执行translator的后处理步骤。对于前解决,咱们会将Image转化NDArray,用于作为模型推理过程的输出。对于后处理,咱们对推理输入的后果(NDArray)进行softmax操作。最终返回后果为Classifications的一个实例。

更多对于translator的工作原理以及如何个性化Translator的信息,请参阅Inference with your model。

Run inference from PaintView

最初,咱们来实现之前定义好的runInference办法。

public void runInference() {
 // 拷贝图像
 Bitmap bmp = Bitmap.createBitmap(bitmap);
 // 缩放图像
 bmp = Bitmap.createScaledBitmap(bmp, 64, 64, true);
 // 执行推理工作
 Classifications classifications = model.predict(bmp);
 // 展现输出的图像
 Bitmap present = Bitmap.createScaledBitmap(bmp, imageView.getWidth(), imageView.getHeight(), true);
 imageView.setImageBitmap(present);
 // 展现输入的图像
 if (messageToast != null) {
 messageToast.cancel();
 }
 messageToast = Toast.makeText(getContext(), classifications.toString(), Toast.LENGTH_SHORT);
 messageToast.show();
}

这将会创立一个Toast弹出页面用于展现后果,示例如下:

祝贺你!当初你就创立了一个残缺的Doodle Draw小程序!

Optional: Optimize input

为了失去更高的模型推理准确度,能够通过截取图像来去除无意义的边框局部。

下面右侧的图片会比右边的图片有更好的推理后果,因为它所蕴含的空白边框更少。咱们能够通过Bound类来寻找图片的无效边界,即能把图中所有红色像素点笼罩的最小矩形。在失去x轴最左坐标,y轴最上坐标,以及矩形高度和宽度后,就能够用这些信息截取出咱们想要的图形(如右图所示)实现代码如下:

RectF bound = maxBound.getBound();
int x = (int) bound.left;
int y = (int) bound.top;
int width = (int) Math.ceil(bound.width());
int height = (int) Math.ceil(bound.height());
// 截取局部图像
Bitmap bmp = Bitmap.createBitmap(bitmap, x, y, width, height)

祝贺你!当初你就把握了全副教程内容!期待看到你创立的第一个DoodleDraw安卓游戏!

最初,能够在GitHub找到本教程的残缺案例代码。

对于Deep Java Library

Deep Java Library (DJL)是一个基于Java的深度学习框架,同时反对训练以及推理。DJL博取众长,构建在多个深度学习框架之上(TenserFlow、PyTorch、MXNet等),也同时具备多个框架的低劣个性。咱们能够轻松应用DJL来进行训练而后部署你的模型。

它同时领有着弱小的模型库反对:只需一行便能够轻松读取各种预训练的模型。当初DJL的模型库同时反对高达70个来自GluonCV、HuggingFace、TorchHub以及Keras的模型。

我的项目地址:https://github.com/awslabs/djl/

在最新的版本中DJL 0.7.0增加了对于MXNet 1.7.0、PyTorch 1.6.0、TensorFlow 2.3.0的反对。咱们同时也增加了ONNXRuntime以及PyTorch在安卓平台的反对。

请参阅咱们的GitHub、demo repository、Slack channel和知乎频道获取更多信息!