
Image Classification

In this tutorial, you will build an app that can take pictures and classify objects in each image using an on-device image classification model.

If you haven't installed the PyTorch Live CLI yet, please follow this tutorial to get started.

Initialize New Project

Let's start by initializing a new project ImageClassificationTutorial with the PyTorch Live CLI.

npx torchlive-cli init ImageClassificationTutorial
note

The project init can take a few minutes, depending on your internet connection and the speed of your computer.

After completion, navigate to the ImageClassificationTutorial directory created by the init command.

cd ImageClassificationTutorial

Run the project in the Android emulator or iOS Simulator

The run-android and run-ios commands from the PyTorch Live CLI allow you to run the image classification project in the Android emulator or iOS Simulator.

npx torchlive-cli run-android

The app will deploy to and run on a physical Android device instead if one is connected to your computer via USB and has developer mode enabled. There are more details on that in the Get Started tutorial.
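If you are on a Mac, you can use the run-ios command to run the app in the iOS Simulator instead:

npx torchlive-cli run-ios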

tip

Keep the app open and running! Any code change will immediately be reflected after saving.

Image Classification Demo

Let's get started with the UI for the image classification demo. Go ahead and copy the following code into the file src/demos/MyDemos.tsx:

note

The MyDemos.tsx file already contains code. Replace its contents with the code below.

src/demos/MyDemos.tsx
import * as React from 'react';
import {Text, View} from 'react-native';
import {useSafeAreaInsets} from 'react-native-safe-area-context';

export default function ImageClassificationDemo() {
  // Get safe area insets to account for notches, etc.
  const insets = useSafeAreaInsets();
  return (
    <View style={{marginTop: insets.top, marginBottom: insets.bottom}}>
      <Text>Image Classification</Text>
    </View>
  );
}

tip

The app starts with the "Examples" tab open. To see the changes you just made to MyDemos.tsx, tap the "My Demos" tab bar item at the bottom of the screen.

Style the component

Great! Let's add some basic styling to the app UI. The styles change the View component background to #ffffff, make the container view span the maximum available width and height, center its child components horizontally, and add a padding of 20 pixels. The Text component gets a bottom margin to provide spacing between the text label and the Camera component that will be added in the next steps.

@@ -1,13 +1,30 @@
 import * as React from 'react';
-import {Text, View} from 'react-native';
+import {Text, StyleSheet, View} from 'react-native';
 import {useSafeAreaInsets} from 'react-native-safe-area-context';

 export default function ImageClassificationDemo() {
   // Get safe area insets to account for notches, etc.
   const insets = useSafeAreaInsets();
   return (
-    <View style={{marginTop: insets.top, marginBottom: insets.bottom}}>
-      <Text>Image Classification</Text>
+    <View
+      style={[
+        styles.container,
+        {marginTop: insets.top, marginBottom: insets.bottom},
+      ]}>
+      <Text style={styles.label}>Image Classification</Text>
     </View>
   );
 }
+
+const styles = StyleSheet.create({
+  container: {
+    alignItems: 'center',
+    backgroundColor: '#ffffff',
+    display: 'flex',
+    flexGrow: 1,
+    padding: 20,
+  },
+  label: {
+    marginBottom: 10,
+  },
+});

Add camera component

Next, let's add a Camera component to take pictures, which will later be used for ML model inference to classify the object in the picture. The camera also gets a basic style to fill the remaining space in the container.

@@ -1,5 +1,6 @@
 import * as React from 'react';
 import {Text, StyleSheet, View} from 'react-native';
+import {Camera} from 'react-native-pytorch-core';
 import {useSafeAreaInsets} from 'react-native-safe-area-context';

 export default function ImageClassificationDemo() {
@@ -12,6 +13,7 @@
         {marginTop: insets.top, marginBottom: insets.bottom},
       ]}>
       <Text style={styles.label}>Image Classification</Text>
+      <Camera style={styles.camera} />
     </View>
   );
 }
@@ -20,11 +22,14 @@
   container: {
     alignItems: 'center',
     backgroundColor: '#ffffff',
-    display: 'flex',
     flexGrow: 1,
     padding: 20,
   },
   label: {
     marginBottom: 10,
   },
+  camera: {
+    flexGrow: 1,
+    width: '100%',
+  },
 });

Add capture callback to camera

To receive an image whenever the camera capture button is pressed, we add an async handleImage function and set it as the onCapture property of the Camera component. This handleImage function will be called with an image from the camera whenever the capture button is pressed.

As a first step, let's log the image to the console.

caution

The image.release() function call is important to release the memory allocated for the image object. This is a vital step to make sure we don't run out of memory on images we no longer need.

@@ -1,11 +1,19 @@
 import * as React from 'react';
 import {Text, StyleSheet, View} from 'react-native';
-import {Camera} from 'react-native-pytorch-core';
+import {Camera, Image} from 'react-native-pytorch-core';
 import {useSafeAreaInsets} from 'react-native-safe-area-context';

 export default function ImageClassificationDemo() {
   // Get safe area insets to account for notches, etc.
   const insets = useSafeAreaInsets();
+
+  async function handleImage(image: Image) {
+    // Log captured image to Metro console
+    console.log(image);
+    // It is important to release the image to avoid memory leaks
+    image.release();
+  }
+
   return (
     <View
       style={[
@@ -13,7 +21,7 @@
         {marginTop: insets.top, marginBottom: insets.bottom},
       ]}>
       <Text style={styles.label}>Image Classification</Text>
-      <Camera style={styles.camera} />
+      <Camera style={styles.camera} onCapture={handleImage} />
     </View>
   );
 }

Click the camera capture button and check the logged output in the terminal. A JavaScript representation of the image will be logged to the console every time you press the capture button.

Run model inference

Fantastic! Now let's run model inference on a captured image.

We'll require the MobileNet V3 (small) model and add the ImageClassificationResult type for type safety. Then, we call the execute function on the MobileModel object, with the model as the first argument and an object containing the image as the second argument.

Don't forget the await keyword for the MobileModel.execute function call!

Last, let's log the inference result to the console.

@@ -1,15 +1,28 @@
 import * as React from 'react';
 import {Text, StyleSheet, View} from 'react-native';
-import {Camera, Image} from 'react-native-pytorch-core';
+import {Camera, Image, MobileModel} from 'react-native-pytorch-core';
 import {useSafeAreaInsets} from 'react-native-safe-area-context';

+const model = require('../../models/mobilenet_v3_small.ptl');
+
+type ImageClassificationResult = {
+  maxIdx: number;
+  confidence: number;
+};
+
 export default function ImageClassificationDemo() {
   // Get safe area insets to account for notches, etc.
   const insets = useSafeAreaInsets();

   async function handleImage(image: Image) {
-    // Log captured image to Metro console
-    console.log(image);
+    const inferenceResult =
+      await MobileModel.execute<ImageClassificationResult>(model, {
+        image,
+      });
+
+    // Log model inference result to Metro console
+    console.log(inferenceResult);
+
     // It is important to release the image to avoid memory leaks
     image.release();
   }

The logged inference result is a JavaScript object containing the maxIdx (the argmax result), which maps to the top class detected in the image, a confidence value for this class being correct, and inference metrics (i.e., inference time, pack time, unpack time, and total time).
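For illustration, the logged object has roughly the following shape. The values below are made up, and the exact metric field names may differ between versions of react-native-pytorch-core, so treat your own console output as the source of truth:

// Example shape only -- values are illustrative and field names may vary
// slightly; compare with the inferenceResult logged on your own device.
{
  result: {
    maxIdx: 673,      // index of the top class in the label array
    confidence: 0.87, // confidence for the top class, in the range [0, 1]
  },
  metrics: {
    totalTime: 135,   // milliseconds
    packTime: 24,
    inferenceTime: 98,
    unpackTime: 13,
  },
}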

Get top image class

Ok! So, we have a maxIdx number as the inference result (e.g., 673). A raw maxIdx isn't meaningful to the user, so let's get the label for the top class. For this, we need to import the image classes for this model from the MobileNetV3Classes JSON file, which contains an array of 1000 class labels. The maxIdx maps to the label representing the top class.

Here, we require the JSON file into the ImageClasses variable and use it to retrieve the label for the top class via the maxIdx returned from the inference.

Let's see what the maxIdx 673 resolves to by logging the topClass label to the console!

@@ -10,18 +10,25 @@
   confidence: number;
 };

+const ImageClasses = require('../MobileNetV3Classes');
+
 export default function ImageClassificationDemo() {
   // Get safe area insets to account for notches, etc.
   const insets = useSafeAreaInsets();

   async function handleImage(image: Image) {
-    const inferenceResult =
-      await MobileModel.execute<ImageClassificationResult>(model, {
+    const {result} = await MobileModel.execute<ImageClassificationResult>(
+      model,
+      {
         image,
-      });
+      },
+    );
+
+    // Get max index (argmax) result to resolve the top class name
+    const topClass = ImageClasses[result.maxIdx];

-    // Log model inference result to Metro console
-    console.log(inferenceResult);
+    // Log top class to Metro console
+    console.log(topClass);

     // It is important to release the image to avoid memory leaks
     image.release();

It looks like the model classified the image as mouse, computer mouse. The next section will reveal if this is correct!

Show top image class

Instead of having the end user look at a console log, we will render the top image class in the app. We'll add state for the objectClass using a React Hook, and when a class is detected, we'll set the top class as the object class using the setObjectClass function.

The user interface will automatically re-render whenever the setObjectClass function is called with a new value, so you don't have to worry about calling anything else besides this function. On re-render, the objectClass variable will have the new value, so we can use it to render the class name on the screen.

note

React.useState is a React Hook. Hooks allow React function components, like our ImageClassificationDemo function component, to remember things.

For more information on React Hooks, head over to the React docs where you can read or watch explanations.

@@ -16,6 +16,9 @@
   // Get safe area insets to account for notches, etc.
   const insets = useSafeAreaInsets();

+  // Component state that holds the detected object class
+  const [objectClass, setObjectClass] = React.useState<string>('');
+
   async function handleImage(image: Image) {
     const {result} = await MobileModel.execute<ImageClassificationResult>(
       model,
@@ -27,8 +30,8 @@
     // Get max index (argmax) result to resolve the top class name
     const topClass = ImageClasses[result.maxIdx];

-    // Log top class to Metro console
-    console.log(topClass);
+    // Set object class state to be the top class detected in the image
+    setObjectClass(topClass);

     // It is important to release the image to avoid memory leaks
     image.release();
@@ -40,7 +43,7 @@
         styles.container,
         {marginTop: insets.top, marginBottom: insets.bottom},
       ]}>
-      <Text style={styles.label}>Image Classification</Text>
+      <Text style={styles.label}>Object: {objectClass}</Text>
       <Camera style={styles.camera} onCapture={handleImage} />
     </View>
   );

It looks like the model correctly classified the object in the image as a mouse, computer mouse!

Confidence threshold

Nice! The model will return a top class for what it thinks is in the image. However, it's not always 100% confident about each classification, and it therefore returns a confidence value as part of the result. To see what this looks like, have a look at the step where we logged the inferenceResult to the console!

Let's use this confidence value as a threshold, and only show top classes where the model has a confidence higher than 0.3 (the confidence range is [0, 1]).

@@ -27,11 +27,16 @@
       },
     );

-    // Get max index (argmax) result to resolve the top class name
-    const topClass = ImageClasses[result.maxIdx];
+    if (result.confidence > 0.3) {
+      // Get max index (argmax) result to resolve the top class name
+      const topClass = ImageClasses[result.maxIdx];

-    // Set object class state to be the top class detected in the image
-    setObjectClass(topClass);
+      // Set object class state to be the top class detected in the image
+      setObjectClass(topClass);
+    } else {
+      // Reset the object class if confidence value is low
+      setObjectClass('');
+    }

     // It is important to release the image to avoid memory leaks
     image.release();

Frame-by-Frame image processing

As a bonus, you can change the onCapture property to the onFrame property to do frame-by-frame image classification. That way, you don't have to repeatedly press the capture button, and you can move the phone around your place to see what the model can detect correctly.

note

Known problem: If the images aren't immediately processed frame by frame, flip the camera twice.

@@ -49,7 +49,11 @@
         {marginTop: insets.top, marginBottom: insets.bottom},
       ]}>
       <Text style={styles.label}>Object: {objectClass}</Text>
-      <Camera style={styles.camera} onCapture={handleImage} />
+      <Camera
+        style={styles.camera}
+        onFrame={handleImage}
+        hideCaptureButton={true}
+      />
     </View>
   );
 }

Congratulations! You finished your first PyTorch Live tutorial.

Next steps

PyTorch Live comes with three image classification models that are ready to use. In the example code provided in this tutorial, we use mobilenet_v3_small.ptl for inference, but feel free to try out the others by replacing the model require shown below.

const model = require('../../models/mobilenet_v3_small.ptl');
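For example, assuming the project's models directory also contains a larger MobileNet V3 variant and a ResNet-18 model (the file names below are illustrative; check the models directory of your project for the exact names), swapping the model is a one-line change:

// Hypothetical alternatives -- verify the actual file names in your models directory.
// const model = require('../../models/mobilenet_v3_large.ptl');
// const model = require('../../models/resnet18.ptl');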

Challenge

Rank the models from slowest to fastest!

tip

Log the metrics from the inference result to the console or render them on the screen!
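A minimal sketch of how that could look, assuming the object returned by MobileModel.execute exposes the timing metrics next to the result (the metrics field name is an assumption; verify it against the inferenceResult you logged earlier):

async function handleImage(image: Image) {
  // Destructure both the classification result and the timing metrics.
  // The `metrics` field name is an assumption; check your logged inferenceResult.
  const {result, metrics} = await MobileModel.execute<ImageClassificationResult>(
    model,
    {image},
  );
  // Log the confidence and the timing metrics (pack, inference, unpack, total time)
  console.log(result.confidence, metrics);
  // It is important to release the image to avoid memory leaks
  image.release();
}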

Use custom image classification model

You can follow the Prepare Custom Model tutorial to prepare your own classification model that you can plug into the demo code provided here.
