
MNIST Digit Classification

In this tutorial we will use a model trained on the MNIST dataset of handwritten digits to predict the number that the user draws.

There are several pieces to this tutorial, so please follow each step carefully. If you get lost, completed examples of each step can be found here.

If you haven't installed the PyTorch Live CLI yet, please follow this tutorial to get started.

Create a new React Native project

We will start by creating a new React Native project with the PyTorch Live (PTL) template using the CLI. Run the following command:

npx torchlive-cli init MNISTClassifier

Once that is done, let's go into our newly created project and run it!

cd MNISTClassifier
npx torchlive-cli run-android

Adding Basic UI

The aim of this tutorial is to help you become more familiar with the PTL core components, so we won't spend time on styling the UI; instead, we provide the layout and styles from the start.

Go ahead and start by copying the following code into the file src/demos/MyDemos.tsx:

src/demos/MyDemos.tsx
import React, {useState} from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {Canvas, CanvasRenderingContext2D} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';

export default function MNISTDemo() {
  // Get safe area insets to account for notches, etc.
  const insets = useSafeAreaInsets();
  const [canvasSize, setCanvasSize] = useState<number>(0);
  // `ctx` is drawing context to draw shapes
  const [ctx, setCtx] = useState<CanvasRenderingContext2D>();

  return (
    <View
      style={styles.container}
      onLayout={event => {
        const {layout} = event.nativeEvent;
        setCanvasSize(Math.min(layout?.width || 0, layout?.height || 0));
      }}>
      <View style={[styles.instruction, {marginTop: insets.top}]}>
        <Text style={styles.label}>Write a number</Text>
        <Text style={styles.label}>Let's test the MNIST model</Text>
      </View>
      <Canvas
        style={{
          height: canvasSize,
          width: canvasSize,
        }}
        onContext2D={setCtx}
      />
      <View style={[styles.resultView]} pointerEvents="none">
        <Text style={[styles.label, styles.secondary]}>
          Highest confidence will go here
        </Text>
        <Text style={[styles.label, styles.secondary]}>
          Second highest will go here
        </Text>
      </View>
    </View>
  );
}

const styles = StyleSheet.create({
  container: {
    height: '100%',
    width: '100%',
    backgroundColor: '#180b3b',
    justifyContent: 'center',
    alignItems: 'center',
  },
  resultView: {
    position: 'absolute',
    bottom: 0,
    alignSelf: 'flex-start',
    flexDirection: 'column',
    padding: 15,
  },
  instruction: {
    position: 'absolute',
    top: 0,
    alignSelf: 'flex-start',
    flexDirection: 'column',
    padding: 15,
  },
  label: {
    fontSize: 16,
    color: '#ffffff',
  },
  secondary: {
    color: '#ffffff99',
  },
});

Now you should see UI that looks exactly like the screenshot below.

[Screenshot: the app after running npx torchlive-cli run-android, showing the basic UI]

Before we add more code, let's take a second to discuss some of what the above code does.

The PyTorch Live Canvas Component

We'll be using the PTL canvas in this tutorial to let the user draw numbers that we will try to classify.

Just like the name suggests, a canvas is a surface that we can programmatically draw on.

In order to draw things on a canvas, we use what is called the canvas context, the ctx state variable in this case.

Note that we haven't used the context to draw anything yet, so our canvas is essentially invisible.

...
export default function MNISTDemo() {
  ...
  const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
  ...
  <Canvas
    style={{
      height: canvasSize,
      width: canvasSize,
    }}
    onContext2D={setCtx}
  />
...

The onLayout Prop

In our code, we use the onLayout prop on the container view to get the dimensions of the screen space we are working with.

Once we have the dimensions of the screen, we take the smaller of the width and height and use that as the side length of our canvas. For example, on a 1080×1920 screen the canvas would be 1080×1080.

This makes sure that our canvas is square and fits within the bounds of our screen in both portrait and landscape.

...
export default function MNISTDemo() {
  // Get safe area insets to account for notches, etc.
  const insets = useSafeAreaInsets();
  const [canvasSize, setCanvasSize] = useState<number>(0);
  ...
  return (
    <View
      style={styles.container}
      onLayout={event => {
        const {layout} = event.nativeEvent;
        setCanvasSize(Math.min(layout?.width || 0, layout?.height || 0));
      }}>
...

Results placeholders

Note that for now we just have placeholder text where we will put our model results. Later on, after we run the model, we will update the text there to display the results.

...
<View style={[styles.resultView]} pointerEvents="none">
  <Text style={[styles.label, styles.secondary]}>
    Highest confidence will go here
  </Text>
  <Text style={[styles.label, styles.secondary]}>
    Second highest will go here
  </Text>
</View>
...

Filling the Canvas

Like we mentioned in the previous section, our canvas is currently completely blank.

Let's change that and make a clear surface for users to draw on.

Here's a short summary of the changes we're introducing:

  1. Import useCallback and useEffect from React.

  2. Define a color for our canvas background (COLOR_CANVAS_BACKGROUND). We'll use a lighter purple to distinguish it from the rest of the screen.

  3. Create a draw function that will fill in our background. We create it with useCallback so that the function is recreated whenever the canvas context or canvas size changes.

    1. Check to make sure context is not null so we have something to draw with.

    2. Set the context's fill style to our canvas background purple (essentially choosing which marker to work with).

    3. Fill in a rectangle that starts at the origin coordinate (0,0) on our canvas (the top left corner) and ends in the bottom right corner of our canvas so it covers the whole thing.

    4. Call the invalidate function on our canvas context to let the screen know that we have drawn new things for it to show.

  4. Trigger draw any time it changes, using the useEffect block. Remember that draw changes every time the canvas context or size changes, so essentially this useEffect runs every time the canvas changes.

note

The useCallback and useEffect functions we just imported, as well as the useState function we already had, are examples of React Hooks. Hooks allow React function components, like our MNISTDemo function component, to remember things.

You'll notice that the last argument to useCallback and useEffect is a list ([]). This is the hook's list of "dependencies": the hook holds onto the value we give it until one of the dependencies changes, at which point it updates the value it remembers.

For more information on React Hooks, head over to the React docs where you can read or watch explanations.
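For a concrete sketch of how dependency lists behave, here is a small hypothetical component (not part of our demo) that mirrors the draw/useEffect pattern we use below:

import {useCallback, useEffect, useState} from 'react';

function Greeter() {
  const [name] = useState('MNIST');

  // `greet` is recreated only when `name` changes; on any other render,
  // useCallback returns the same remembered function.
  const greet = useCallback(() => {
    console.log(`Hello, ${name}!`);
  }, [name]);

  // Runs after the first render and again whenever `greet` changes,
  // which (because of its dependency list) means whenever `name` changes.
  useEffect(() => {
    greet();
  }, [greet]);

  return null;
}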

@@ -1,8 +1,10 @@
-import React, {useState} from 'react';
+import React, {useCallback, useEffect, useState} from 'react';
 import {StyleSheet, Text, View} from 'react-native';
 import {Canvas, CanvasRenderingContext2D} from 'react-native-pytorch-core';
 import {useSafeAreaInsets} from 'react-native-safe-area-context';
 
+const COLOR_CANVAS_BACKGROUND = '#4F25C6';
+
 export default function MNISTDemo() {
   // Get safe area insets to account for notches, etc.
   const insets = useSafeAreaInsets();
@@ -10,6 +12,20 @@
   // `ctx` is drawing context to draw shapes
   const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
 
+  const draw = useCallback(() => {
+    if (ctx != null) {
+      // fill background by drawing a rect
+      ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
+      ctx.fillRect(0, 0, canvasSize, canvasSize);
+
+      ctx.invalidate();
+    }
+  }, [ctx, canvasSize]);
+
+  useEffect(() => {
+    draw();
+  }, [draw]);
+
   return (
     <View
       style={styles.container}

Once you run your app, the My Demos screen should now look like this.

[Screenshot: the app after running npx torchlive-cli run-android, with the canvas filled in light purple]

I know that was a lot of new stuff to simply paint our canvas light purple, but it provides us with a good foundation for when we draw more on our canvas.

Drawing with Touch Input

Now that we have a clear area for the user to draw on, let's make it so they can draw!

Let's go over what we will change to make drawing possible:

  1. Import useRef from React.

  2. Define a color for the trail of the users touch (COLOR_TRAIL_STROKE). We'll use white to make it stand out.

  3. Define a TrailPoint type to keep our data safe, error-free, and easy to use.

  4. Create a ref to a list of TrailPoints called trailRef and set it to an empty list.

  5. Keep track of whether the user has finished drawing with the drawingDone state variable, and initialize it to false.

  6. Add support for drawing the trail to our draw function:

    1. Create a variable called trail and set it to the current value of our trailRef. This is purely so we don't have to write trailRef.current every time we need the trail.
    2. Check to make sure the trail isn't null.
    3. Draw our background to cover anything previously drawn.
    4. Check to make sure our trail has at least 1 point.
    5. Set the context's strokeStyle - you can think of it as picking the marker color we'll draw lines with.
    6. Set the context's line drawing style parameters (lineWidth, lineJoin, lineCap, and miterLimit).
    7. Tell the context to start a line at the first point in the trail.
    8. Loop through points of the trail to add them to the line we are drawing.
    9. Tell the context via the stroke method to actually draw the line that we constructed.
    10. Use the invalidate method to tell the screen we have updates ready to draw.
  7. Create functions for handling when a user touches the canvas (handleStart, handleMove, and handleEnd).

  8. The handleStart function is called when the user first touches the canvas. It is a simple function that does the following:

    1. Set the drawingDone variable to false.
    2. Reset the trailRef to an empty list.
  9. The handleMove function is called each time the device detects that the touch has moved since the last touch event.

    1. Get the coordinates of the new touch location and store them in the position variable.
    2. If there are already points in the trail, only add the new position if it's more than 5 pixels away from the last position; the code compares the squared distance (dx * dx + dy * dy) against 25, which is 5 squared, avoiding a square root. This skips unnecessary points that would slow down the app.
    3. If there are no points in the trail, add the new position.
    4. Trigger the draw function to display the newly updated trail.
  10. The handleEnd function is called when the user's touch is no longer detected on the screen.

    1. Simply set the drawingDone state variable to true.
  11. Set the onTouchStart, onTouchMove, and onTouchEnd props on our <Canvas /> component to handleStart, handleMove, and handleEnd respectively.

@@ -1,26 +1,88 @@
-import React, {useCallback, useEffect, useState} from 'react';
+import React, {useCallback, useEffect, useState, useRef} from 'react';
 import {StyleSheet, Text, View} from 'react-native';
 import {Canvas, CanvasRenderingContext2D} from 'react-native-pytorch-core';
 import {useSafeAreaInsets} from 'react-native-safe-area-context';
 
 const COLOR_CANVAS_BACKGROUND = '#4F25C6';
+const COLOR_TRAIL_STROKE = '#FFFFFF';
+
+type TrailPoint = {
+  x: number;
+  y: number;
+};
 
 export default function MNISTDemo() {
   // Get safe area insets to account for notches, etc.
   const insets = useSafeAreaInsets();
   const [canvasSize, setCanvasSize] = useState<number>(0);
+
   // `ctx` is drawing context to draw shapes
   const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
 
+  const trailRef = useRef<TrailPoint[]>([]);
+  const [drawingDone, setDrawingDone] = useState(false);
+
   const draw = useCallback(() => {
     if (ctx != null) {
-      // fill background by drawing a rect
-      ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
-      ctx.fillRect(0, 0, canvasSize, canvasSize);
+      const trail = trailRef.current;
+      if (trail != null) {
+        // fill background by drawing a rect
+        ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
+        ctx.fillRect(0, 0, canvasSize, canvasSize);
+
+        // Draw the trail
+
+        if (trail.length > 0) {
+          ctx.strokeStyle = COLOR_TRAIL_STROKE;
+          ctx.lineWidth = 25;
+          ctx.lineJoin = 'round';
+          ctx.lineCap = 'round';
+          ctx.miterLimit = 1;
+          ctx.beginPath();
+          ctx.moveTo(trail[0].x, trail[0].y);
+          for (let i = 1; i < trail.length; i++) {
+            ctx.lineTo(trail[i].x, trail[i].y);
+          }
+          ctx.stroke();
+        }
 
-      ctx.invalidate();
+        ctx.invalidate();
+      }
     }
-  }, [ctx, canvasSize]);
+  }, [ctx, canvasSize, trailRef]);
+
+  // handlers for touch events
+  const handleMove = useCallback(
+    async event => {
+      const position: TrailPoint = {
+        x: event.nativeEvent.locationX,
+        y: event.nativeEvent.locationY,
+      };
+      const trail = trailRef.current;
+      if (trail.length > 0) {
+        const lastPosition = trail[trail.length - 1];
+        const dx = position.x - lastPosition.x;
+        const dy = position.y - lastPosition.y;
+        // add a point to trail if distance from last point > 5
+        if (dx * dx + dy * dy > 25) {
+          trail.push(position);
+        }
+      } else {
+        trail.push(position);
+      }
+      draw();
+    },
+    [trailRef, draw],
+  );
+
+  const handleStart = useCallback(() => {
+    setDrawingDone(false);
+    trailRef.current = [];
+  }, [trailRef, setDrawingDone]);
+
+  const handleEnd = useCallback(() => {
+    setDrawingDone(true);
+  }, [setDrawingDone]);
 
   useEffect(() => {
     draw();
@@ -35,7 +97,9 @@
       }}>
       <View style={[styles.instruction, {marginTop: insets.top}]}>
         <Text style={styles.label}>Write a number</Text>
-        <Text style={styles.label}>Let's test the MNIST model</Text>
+        <Text style={styles.label}>
+          Let's see if the AI model will get it right
+        </Text>
       </View>
       <Canvas
         style={{
@@ -43,15 +107,20 @@
           width: canvasSize,
        }}
         onContext2D={setCtx}
+        onTouchMove={handleMove}
+        onTouchStart={handleStart}
+        onTouchEnd={handleEnd}
       />
-      <View style={[styles.resultView]} pointerEvents="none">
-        <Text style={[styles.label, styles.secondary]}>
-          Highest confidence will go here
-        </Text>
-        <Text style={[styles.label, styles.secondary]}>
-          Second highest will go here
-        </Text>
-      </View>
+      {drawingDone && (
+        <View style={[styles.resultView]} pointerEvents="none">
+          <Text style={[styles.label, styles.secondary]}>
+            Highest confidence will go here
+          </Text>
+          <Text style={[styles.label, styles.secondary]}>
+            Second highest will go here
+          </Text>
+        </View>
+      )}
     </View>
   );
 }

Run this code and you should now be able to draw, as shown in the video below.

As you will notice, the drawing seems to glitch out at times, especially as the trail gets longer and longer. Let's fix that next.

info

React Refs

A ref in React is a variable, like state, except that changing it does not cause the component to re-render.

You can get or set the value of a ref via the .current property.

In our code, we access the trail with trailRef.current. We set the trail in our handleStart function to an empty list with trailRef.current = [].
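Here is a minimal, hypothetical sketch of the difference between a ref and state (not part of our demo):

import React, {useRef, useState} from 'react';
import {Button} from 'react-native';

function TapCounter() {
  // Updating a ref's .current does NOT re-render the component.
  const tapsRef = useRef(0);
  // Calling a state setter DOES re-render the component.
  const [shownTaps, setShownTaps] = useState(0);

  return (
    <Button
      title={`Taps: ${shownTaps}`}
      onPress={() => {
        tapsRef.current += 1; // changes silently, no re-render
        setShownTaps(tapsRef.current); // triggers a re-render
      }}
    />
  );
}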

Avoiding Excessive Re-rendering

The glitchiness we see in our code as it stands is because we are asking the screen to refresh before it is ready.

Mobile screens typically refresh 60 times per second (though some newer phones refresh twice as often), which leaves a budget of roughly 1000 / 60 ≈ 16.7 ms per frame. When we display things with React, it takes care of matching our device's refresh rate.

While React renders the <Canvas /> component itself, what we draw inside the canvas is up to us. Luckily, there is a simple way to make sure we don't draw more often than the screen can keep up with.

To address this, we will make a few updates to our code, mainly in the draw function:

  1. Create a ref called animationHandleRef that can hold a number or null, and initialize it to null. We will use this ref to check whether rendering is currently in progress.

  2. Use the animationHandleRef in the draw function to control how often we rerender:

    1. Start the function by checking if the animationHandleRef is set to a non-null value. If it is, we want to end early, because we know the device is already working on rendering.
    2. Wrap our code that does drawing in an inline function that we pass to requestAnimationFrame and set the animationHandleRef's value to what it returns. (Read more about this function in the note following the code.)
    3. After telling our canvas we are ready for a rerender with ctx.invalidate(), clear the animationHandleRef by setting its value to null.
    4. Add animationHandleRef to the draw function's callback dependencies list.
@@ -21,35 +21,40 @@
 
   const trailRef = useRef<TrailPoint[]>([]);
   const [drawingDone, setDrawingDone] = useState(false);
+  const animationHandleRef = useRef<number | null>(null);
 
   const draw = useCallback(() => {
+    if (animationHandleRef.current != null) return;
     if (ctx != null) {
-      const trail = trailRef.current;
-      if (trail != null) {
-        // fill background by drawing a rect
-        ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
-        ctx.fillRect(0, 0, canvasSize, canvasSize);
-
-        // Draw the trail
+      animationHandleRef.current = requestAnimationFrame(() => {
+        const trail = trailRef.current;
+        if (trail != null) {
+          // fill background by drawing a rect
+          ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
+          ctx.fillRect(0, 0, canvasSize, canvasSize);
 
-        if (trail.length > 0) {
+          // Draw the trail
           ctx.strokeStyle = COLOR_TRAIL_STROKE;
           ctx.lineWidth = 25;
           ctx.lineJoin = 'round';
           ctx.lineCap = 'round';
           ctx.miterLimit = 1;
-          ctx.beginPath();
-          ctx.moveTo(trail[0].x, trail[0].y);
-          for (let i = 1; i < trail.length; i++) {
-            ctx.lineTo(trail[i].x, trail[i].y);
+
+          if (trail.length > 0) {
+            ctx.beginPath();
+            ctx.moveTo(trail[0].x, trail[0].y);
+            for (let i = 1; i < trail.length; i++) {
+              ctx.lineTo(trail[i].x, trail[i].y);
+            }
           }
           ctx.stroke();
+          // Need to include this at the end, for now.
+          ctx.invalidate();
+          animationHandleRef.current = null;
         }
-
-        ctx.invalidate();
-      }
+      });
     }
-  }, [ctx, canvasSize, trailRef]);
+  }, [animationHandleRef, ctx, canvasSize, trailRef]);
 
   // handlers for touch events
   const handleMove = useCallback(
info

What does requestAnimationFrame do?

requestAnimationFrame is a utility function that helps us run code when the screen is ready for the next rerender.

Input: a callback function. requestAnimationFrame runs that function right before the screen next refreshes.

Output: a number that serves as an ID for the scheduled callback. You can use that number to cancel the callback if you later decide you don't want to run the code. (We don't need that feature for this tutorial.)
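As a rough sketch of the call shape (the cancel call is shown only to illustrate the returned handle; our demo never cancels):

// Schedule a function to run right before the next screen refresh.
const handle: number = requestAnimationFrame(() => {
  // ... draw a single frame here ...
});

// The returned number identifies the scheduled callback, so we can
// cancel it before it fires if we change our minds.
cancelAnimationFrame(handle);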

Once you have those changes in your code, go ahead and refresh the app and see how much smoother drawing is.

[Screen recording: drawing on the canvas after running npx torchlive-cli run-android, now without glitches]

With silky smooth drawing in place, we are now ready to start working with the MNIST model.

Running the Model

We'll start by creating a React hook that provides a function for running inference on an input image. We'll follow React hooks naming conventions and call ours useMNISTModel.

Let's summarize the changes we're making:

  1. Import Image and MobileModel from react-native-pytorch-core.
  2. Load the model file with the require function and call it mnistModel.
  3. Create a type called MNISTResult with the following properties:
    1. num - a digit from 0 to 9.
    2. score - the confidence the model has in the input image being the given num.
  4. Define a function called useMNISTModel that does the following:
    1. Creates an async React callback function called processImage that takes an Image as a parameter and does the following:
      1. Uses the MobileModel API to execute the mnistModel we loaded, with a set of parameters that tell the model how much of the image to use and what the foreground and background colors are.
      2. Transforms the raw scores into MNISTResult objects.
      3. Sorts the results by score, highest first.
      4. Returns the sorted results.
    2. Returns an object containing the processImage function we just created.
@@ -1,6 +1,11 @@
 import React, {useCallback, useEffect, useState, useRef} from 'react';
 import {StyleSheet, Text, View} from 'react-native';
-import {Canvas, CanvasRenderingContext2D} from 'react-native-pytorch-core';
+import {
+  Canvas,
+  CanvasRenderingContext2D,
+  Image,
+  MobileModel,
+} from 'react-native-pytorch-core';
 import {useSafeAreaInsets} from 'react-native-safe-area-context';
 
 const COLOR_CANVAS_BACKGROUND = '#4F25C6';
@@ -11,6 +16,44 @@
   y: number;
 };
 
+// This is the custom model you have trained. See the tutorial for more on preparing a PyTorch model for mobile.
+const mnistModel = require('../../models/mnist.ptl');
+
+type MNISTResult = {
+  num: number;
+  score: number;
+};
+
+/**
+ * The React hook provides MNIST model inference on an input image.
+ */
+function useMNISTModel() {
+  const processImage = useCallback(async (image: Image) => {
+    // Runs model inference on input image
+    const {
+      result: {scores},
+    } = await MobileModel.execute<{scores: number[]}>(mnistModel, {
+      image,
+      crop_width: 1,
+      crop_height: 1,
+      scale_width: 28,
+      scale_height: 28,
+      colorBackground: COLOR_CANVAS_BACKGROUND,
+      colorForeground: COLOR_TRAIL_STROKE,
+    });
+
+    // Get the score of each number (index), and sort the array by the most likely first.
+    const sortedScore: MNISTResult[] = scores
+      .map((score, index) => ({score: score, num: index}))
+      .sort((a, b) => b.score - a.score);
+    return sortedScore;
+  }, []);
+
+  return {
+    processImage,
+  };
+}
+
 export default function MNISTDemo() {
   // Get safe area insets to account for notches, etc.
   const insets = useSafeAreaInsets();

An even shorter summary: it takes in an Image and gives back a list of sorted results.
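To make that concrete, using the hook looks roughly like this (someImage is a stand-in, and the example scores are made up):

// Inside a component (sketch only):
const {processImage} = useMNISTModel();

// Later, inside an async callback, with an Image in hand:
const results = await processImage(someImage);
// `results` is sorted by confidence, e.g.:
// [{num: 7, score: 0.92}, {num: 1, score: 0.05}, ...]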

But we don't have Images. We just have a trail on a canvas.

In the next section, we'll learn how to create an Image from the contents of our canvas that we can pass to the model.

Creating an Image from our Canvas

We are going to create another hook called useMNISTCanvasInference that uses the hook we just created (useMNISTModel).

This hook will take in the canvasSize and give us back two things:

  1. result - a state variable that holds the sorted list of MNISTResults from the model.
  2. classify - a function that takes in the canvas context, extracts an image from it, processes the image, and then updates the result state variable.

In our classify callback, we use some of the PTL core components, including the newly imported ImageUtil object.

The ImageUtil object allows us to take the imageData we pull from the canvas and turn it into an Image that can be used by our model.

You'll also see that we call the release function on both our imageData and our image variables as soon as we are done using them. This is a vital step to make sure we don't run out of memory on images we no longer need.

@@ -4,6 +4,7 @@
   Canvas,
   CanvasRenderingContext2D,
   Image,
+  ImageUtil,
   MobileModel,
 } from 'react-native-pytorch-core';
 import {useSafeAreaInsets} from 'react-native-safe-area-context';
@@ -54,6 +55,48 @@
   };
 }
 
+/**
+ * The React hook provides MNIST inference using the image data extracted from
+ * a canvas.
+ *
+ * @param canvasSize The size of the square canvas
+ */
+function useMNISTCanvasInference(canvasSize: number) {
+  const [result, setResult] = useState<MNISTResult[]>();
+  const {processImage} = useMNISTModel();
+  const classify = useCallback(
+    async (ctx: CanvasRenderingContext2D) => {
+      // Return immediately if canvas is size 0
+      if (canvasSize === 0) {
+        return null;
+      }
+
+      // Get image data center crop
+      const imageData = await ctx.getImageData(0, 0, canvasSize, canvasSize);
+
+      // Convert image data to image.
+      const image: Image = await ImageUtil.fromImageData(imageData);
+
+      // Release image data to free memory
+      imageData.release();
+
+      // Run MNIST inference on the image
+      const result = await processImage(image);
+
+      // Release image to free memory
+      image.release();
+
+      // Set result state to force re-render of component that uses this hook
+      setResult(result);
+    },
+    [canvasSize, processImage, setResult],
+  );
+  return {
+    result,
+    classify,
+  };
+}
+
 export default function MNISTDemo() {
   // Get safe area insets to account for notches, etc.
   const insets = useSafeAreaInsets();
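One design note on those release calls: if processImage ever threw, the image would leak. A defensive variation (our own sketch, not what the tutorial does) wraps the inference in try/finally so the image is always freed:

const imageData = await ctx.getImageData(0, 0, canvasSize, canvasSize);
const image: Image = await ImageUtil.fromImageData(imageData);
// The image data is no longer needed once we have the Image.
imageData.release();
try {
  // Run MNIST inference on the image
  setResult(await processImage(image));
} finally {
  // Runs even if processImage throws, so the image memory is always freed.
  image.release();
}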

With this second hook, we are ready to run our model with the user created drawings. Let's hook it up in the next section.

Running the Model & Displaying Results

While we add a fair number of lines in this section, they are all simple changes.

Let's cut to the summary:

  1. Create a type called NumberLabelSet so we know what kind of data we have for each digit.
  2. Create a list of NumberLabelSet objects and call it numLabels.
  3. Get the classify method and result state variable by calling useMNISTCanvasInference from within our demo component.
  4. Update the handleEnd function to check for a canvas context and then trigger the model.
  5. Add classify as a dependency to the handleEnd callback function.
  6. Change the text in the results section to reflect the numbers from the model output.
@@ -97,6 +97,54 @@
   };
 }
 
+type NumberLabelSet = {
+  english: string;
+  asciiSymbol: string;
+};
+
+const numLabels: NumberLabelSet[] = [
+  {
+    english: 'zero',
+    asciiSymbol: '🄌',
+  },
+  {
+    english: 'one',
+    asciiSymbol: '➊',
+  },
+  {
+    english: 'two',
+    asciiSymbol: '➋',
+  },
+  {
+    english: 'three',
+    asciiSymbol: '➌',
+  },
+  {
+    english: 'four',
+    asciiSymbol: '➍',
+  },
+  {
+    english: 'five',
+    asciiSymbol: '➎',
+  },
+  {
+    english: 'six',
+    asciiSymbol: '➏',
+  },
+  {
+    english: 'seven',
+    asciiSymbol: '➐',
+  },
+  {
+    english: 'eight',
+    asciiSymbol: '➑',
+  },
+  {
+    english: 'nine',
+    asciiSymbol: '➒',
+  },
+];
+
 export default function MNISTDemo() {
   // Get safe area insets to account for notches, etc.
   const insets = useSafeAreaInsets();
@@ -105,6 +153,8 @@
   // `ctx` is drawing context to draw shapes
   const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
 
+  const {classify, result} = useMNISTCanvasInference(canvasSize);
+
   const trailRef = useRef<TrailPoint[]>([]);
   const [drawingDone, setDrawingDone] = useState(false);
   const animationHandleRef = useRef<number | null>(null);
@@ -173,7 +223,8 @@
 
   const handleEnd = useCallback(() => {
     setDrawingDone(true);
-  }, [setDrawingDone]);
+    if (ctx != null) classify(ctx);
+  }, [setDrawingDone, classify, ctx]);
 
   useEffect(() => {
     draw();
@@ -205,10 +256,16 @@
       {drawingDone && (
         <View style={[styles.resultView]} pointerEvents="none">
           <Text style={[styles.label, styles.secondary]}>
-            Highest confidence will go here
+            {result &&
+              `${numLabels[result[0].num].asciiSymbol} it looks like ${
+                numLabels[result[0].num].english
+              }`}
           </Text>
           <Text style={[styles.label, styles.secondary]}>
-            Second highest will go here
+            {result &&
+              `${numLabels[result[1].num].asciiSymbol} or it might be ${
+                numLabels[result[1].num].english
+              }`}
           </Text>
         </View>
       )}

When you run the code, you should see it display results properly in the bottom left corner like the screen recording below.

[Screen recording: the app after running npx torchlive-cli run-android, showing model results in the bottom left corner]

And with that we have a working MNIST classifier!
