Version: 0.2.2

Object Detection

In this tutorial, you will implement a function that takes in an image and runs it through an object detection model to find what objects are in the image and where.

Viewing this Demo

To view this demo, download the PlayTorch app.

Preview

If you want a sneak peek at what you'll be building, run this Snack by scanning the QR code in the PlayTorch app!

Overview

This tutorial focuses on the machine learning aspects of creating this object detection prototype, so we provide the UI code for you.

We'll go through the following steps:

  1. Create a copy of the starter Expo Snack
  2. Run the project in the PlayTorch app
  3. Create an image handler function
  4. Convert image to tensor ready for inference
  5. Load model and run inference
  6. Format model output
  7. Display results

Copying the Starter Snack

Click the button below to open your own copy of the starter Expo Snack for this tutorial. It will open a new tab to an Expo Snack, the in-browser code editor we will use for this tutorial that allows us to test our code in the PlayTorch app.

Click to Copy Starter Snack

Let's run through a quick overview of everything that you can find in the starter snack.

Snack Contents
screens/
├─ CameraScreen.js
├─ LoadingScreen.js
├─ ResultsScreen.js
App.js
CoCoClasses.json
ObjectDetector.js
package.json

screens/

The screens folder contains three files -- one for each screen of the object detection prototype we are building.

  1. CameraScreen.js - A full screen camera view to capture the image we will run through our object detection model
  2. LoadingScreen.js - A loading screen that lets the user know the image is being processed
  3. ResultsScreen.js - A screen showing the image and the bounding boxes of the detected objects

App.js

This file manages which screen to show and helps pass data between them.

CoCoClasses.json

This file contains the labels for the different classes of objects that our model has been trained to detect.

ObjectDetector.js

This file will contain the code that actually runs our machine learning model.

package.json

This file contains a list of packages that our project depends on so the app knows to download them.

Running the project in the PlayTorch app

Open the PlayTorch app and from the home screen, tap "Scan QR Code".

If you have never done this before, it will ask for camera permissions. Grant the app camera permissions and scan the QR code from the right side of the Snack window.

If you haven't made any changes to the snack, you should see a screen that looks like this (with a different view in your camera of course):

In this starter state, nothing happens when you press the camera's capture button. Let's fix that.

Create an image handler function

In the initial state of the codebase, if we look into our CameraScreen.js file, we will notice:

  • It is expecting an onCapture function as a prop (line 6)
  • It passes the onCapture prop to the <Camera /> component provided by react-native-pytorch-core (line 12)
CameraScreen.js
import * as React from 'react';
import {useRef} from 'react';
import {StyleSheet} from 'react-native';
import {Camera} from 'react-native-pytorch-core';

export default function CameraScreen({onCapture}) {
const cameraRef = useRef(null);
return (
<Camera
ref={cameraRef}
style={styles.camera}
onCapture={onCapture}
targetResolution={{width: 1080, height: 1920}}
/>
);
}

const styles = StyleSheet.create({
camera: {width: '100%', height: '100%'},
});

However, currently the code in our App.js file doesn't pass anything to our <CameraScreen /> component. Make the changes in the code below to your App.js file to create a function for handling captured images.

Here is a quick summary of the changes that we are making:

  1. Create a new async function called handleImage that does the following:
    1. Log the image so we can inspect it
    2. Release the image so it doesn't leak memory
  2. Pass the new handleImage function to the onCapture prop of the <CameraScreen /> component
App.js
import * as React from 'react';
import {useCallback, useRef, useState} from 'react';
import {
ActivityIndicator,
Button,
SafeAreaView,
StyleSheet,
Text,
TouchableOpacity,
View,
} from 'react-native';
import {Camera, Canvas, Image} from 'react-native-pytorch-core';
import detectObjects from './ObjectDetector';
import CameraScreen from './screens/CameraScreen';
import LoadingScreen from './screens/LoadingScreen';
import ResultsScreen from './screens/ResultsScreen';

const ScreenStates = {
CAMERA: 0,
LOADING: 1,
RESULTS: 2,
};

export default function ObjectDetectionExample() {
const [image, setImage] = useState(null);
const [boundingBoxes, setBoundingBoxes] = useState(null);
const [screenState, setScreenState] = useState(ScreenStates.CAMERA);

// Handle the reset button and return to the camera capturing mode
const handleReset = useCallback(async () => {
setScreenState(ScreenStates.CAMERA);
if (image != null) {
await image.release();
}
setImage(null);
setBoundingBoxes(null);
}, [image, setScreenState]);

// This handler function handles the camera's capture event
async function handleImage(capturedImage) {
console.log('Captured image', capturedImage);
await capturedImage.release();
}

return (
<SafeAreaView style={styles.container}>
{screenState === ScreenStates.CAMERA && (
<CameraScreen onCapture={handleImage} />
)}
{screenState === ScreenStates.LOADING && <LoadingScreen />}
{screenState === ScreenStates.RESULTS && (
<ResultsScreen
image={image}
boundingBoxes={boundingBoxes}
onReset={handleReset}
/>
)}
</SafeAreaView>
);
}

const styles = StyleSheet.create({
container: {
flex: 1,
},
});

Run the updated code on the PlayTorch app, press the camera capture button and then check the logs to see what it output.

Open the logs in the Snack window by clicking the settings gear icon at the bottom of the window, enabling the Panel, and clicking the logs tab of the newly opened panel.

You should see something like the following, but with a different ID:

Captured image {"ID":"ABEF4518-A791-45C3-BA5C-C424F96C58EC"}
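
One thing to keep in mind with a handler like this: if anything throws before the release call, the image is never released. A minimal sketch of a more defensive variant using try/finally is shown below. The `handleImageSafely` and `processImage` names and the `fakeImage` object are hypothetical and exist only so the snippet runs outside the PlayTorch app; only `release()` comes from the tutorial code.

```javascript
// Sketch: release the captured image even if processing throws.
// `processImage` is a hypothetical callback standing in for any work done on the image.
async function handleImageSafely(capturedImage, processImage) {
  try {
    return await processImage(capturedImage);
  } finally {
    // Runs on success and on failure, so the image memory is always freed.
    await capturedImage.release();
  }
}

// Hypothetical stand-in for a captured image, for running outside the app.
const fakeImage = {
  released: false,
  async release() {
    this.released = true;
  },
};

handleImageSafely(fakeImage, async (image) => {
  console.log('Captured image', image);
}).then(() => console.log('released:', fakeImage.released));
```

This pattern fits this intermediate step only. Later in the tutorial the image is kept in state for the results screen and released when the user resets.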

Now that we have access to the images that we are capturing, let's write some code to convert them into a format that our ML model can understand.

Convert image to tensor ready for inference

In order to keep our code clean and simple to understand, we will separate all of the code that interacts with the ML model into its own file. That file already exists and is called ObjectDetector.js.

Open that file and notice it is a rather bare file. All that it does at the moment is export an async function called detectObjects that expects an image as a parameter.

ObjectDetector.js
export default async function detectObjects(image) {}

The model we will be using for object detection is called DETR, which stands for Detection Transformer.

The DETR model only knows how to process data in a very specific format. We will be converting our image into a tensor (a multi-dimensional array) and then rearranging the data in that tensor to be formatted just how DETR expects it.

The code below does all the transformations needed to the image. Here's a summary of what is going on:

  1. Import media, torch, and torchvision from the PlayTorch SDK (react-native-pytorch-core)
  2. Create a variable T as a shortcut for torchvision.transforms, since we will be using the transforms a lot
  3. Update the empty detectObjects function to do the following:
    1. Store the dimensions of the image to imageWidth and imageHeight variables
    2. Create a blob (a raw memory store of file contents, which in this case contains the RGB byte values) from the image parameter using the media.toBlob function
    3. Create a tensor with the raw data in the newly created blob and with the shape defined by the array [imageHeight, imageWidth, 3] (the size of the image and 3 for RGB)
    4. Rearrange the order of the tensor so the channels (RGB) are first, then the height and width
    5. Transform all the values in the tensor to be floating point numbers between 0 and 1 instead of byte values between 0 and 255
    6. Create a resize transformation with size 800 called resize
    7. Apply the resize transformation to the tensor
    8. Create a normalization transformation called normalize with mean and standard deviation values that the model was trained on
    9. Apply the normalize transformation to the tensor
    10. Add an extra leading dimension to the tensor with tensor.unsqueeze
    11. Log the final tensor to check its shape, which should be [1, 3, height, width] with the shorter edge of the image resized to 800
ObjectDetector.js
import {media, torch, torchvision} from 'react-native-pytorch-core';

const T = torchvision.transforms;

export default async function detectObjects(image) {
// Get image width and height
const imageWidth = image.getWidth();
const imageHeight = image.getHeight();

// Convert image to blob, which is a byte representation of the image
// in the format height (H), width (W), and channels (C), or HWC for short
const blob = media.toBlob(image);

// Get a tensor from the image blob and also define the shape
// (format) of the data in the blob.
let tensor = torch.fromBlob(blob, [imageHeight, imageWidth, 3]);

// Rearrange the tensor shape to be [CHW]
tensor = tensor.permute([2, 0, 1]);

// Divide the tensor values by 255 to get values between [0, 1]
tensor = tensor.div(255);

// Resize the image tensor so that its shorter edge is 800 pixels
const resize = T.resize(800);
tensor = resize(tensor);

// Normalize the tensor image with mean and standard deviation
const normalize = T.normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]);
tensor = normalize(tensor);

// Unsqueeze adds 1 leading dimension to the tensor
const formattedInputTensor = tensor.unsqueeze(0);
console.log(formattedInputTensor);
}
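
To make the layout change more concrete, here is a minimal sketch in plain JavaScript of what the permute, divide, and normalize steps do to the pixel data. It is an illustration on ordinary typed arrays, not how the PlayTorch SDK implements these operations, and it leaves out the resize step. The function name `hwcBytesToNormalizedChw` is made up for this example.

```javascript
// Sketch in plain JavaScript of what permute([2, 0, 1]), div(255) and
// normalize(mean, std) do to the pixel data. Not the PlayTorch implementation.
function hwcBytesToNormalizedChw(bytes, height, width, mean, std) {
  const channels = 3;
  const out = new Float32Array(channels * height * width);
  for (let y = 0; y < height; y++) {
    for (let x = 0; x < width; x++) {
      for (let c = 0; c < channels; c++) {
        // HWC layout: the channel is the fastest-changing index
        const src = (y * width + x) * channels + c;
        // CHW layout: each channel is stored as a full height x width plane
        const dst = c * height * width + y * width + x;
        const scaled = bytes[src] / 255; // byte value -> [0, 1]
        out[dst] = (scaled - mean[c]) / std[c]; // per-channel normalization
      }
    }
  }
  return out;
}

// A 1 x 2 image: one pure red pixel followed by one pure green pixel.
const pixels = new Uint8Array([255, 0, 0, 0, 255, 0]);
const chw = hwcBytesToNormalizedChw(
  pixels,
  1,
  2,
  [0.485, 0.456, 0.406],
  [0.229, 0.224, 0.225],
);
console.log(Array.from(chw));
```

In the output, the first two values are the red plane, the next two the green plane, and the last two the blue plane.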

In order for us to be able to run and test this, let's make a quick update to the App.js file to make the handleImage function call our updated detectObjects function.

App.js
import * as React from 'react';
import {useCallback, useRef, useState} from 'react';
import {
ActivityIndicator,
Button,
SafeAreaView,
StyleSheet,
Text,
TouchableOpacity,
View,
} from 'react-native';
import {Camera, Canvas, Image} from 'react-native-pytorch-core';
import detectObjects from './ObjectDetector';
import CameraScreen from './screens/CameraScreen';
import LoadingScreen from './screens/LoadingScreen';
import ResultsScreen from './screens/ResultsScreen';

const ScreenStates = {
CAMERA: 0,
LOADING: 1,
RESULTS: 2,
};

export default function ObjectDetectionExample() {
const [image, setImage] = useState(null);
const [boundingBoxes, setBoundingBoxes] = useState(null);
const [screenState, setScreenState] = useState(ScreenStates.CAMERA);

// Handle the reset button and return to the camera capturing mode
const handleReset = useCallback(async () => {
setScreenState(ScreenStates.CAMERA);
if (image != null) {
await image.release();
}
setImage(null);
setBoundingBoxes(null);
}, [image, setScreenState]);

// This handler function handles the camera's capture event
async function handleImage(capturedImage) {
await detectObjects(capturedImage);
await capturedImage.release();
}

return (
<SafeAreaView style={styles.container}>
{screenState === ScreenStates.CAMERA && (
<CameraScreen onCapture={handleImage} />
)}
{screenState === ScreenStates.LOADING && <LoadingScreen />}
{screenState === ScreenStates.RESULTS && (
<ResultsScreen
image={image}
boundingBoxes={boundingBoxes}
onReset={handleReset}
/>
)}
</SafeAreaView>
);
}

const styles = StyleSheet.create({
container: {
flex: 1,
},
});

Now that we are calling the updated detectObjects function, let's test it by running the snack in the PlayTorch app.

When you press the capture button now, you should see something similar to the following in the logs:

▼{dtype:"float32",shape:[…]}
dtype:"float32"
►shape:[1,3,1422,800]
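
The logged shape is consistent with a resize that scales the shorter edge of the image to 800 and keeps the aspect ratio. The sketch below reproduces the numbers for a 1080 x 1920 capture (the targetResolution set in CameraScreen.js). It assumes the longer edge is truncated to an integer, which matches the logged value but is not confirmed by this tutorial. The helper name `resizedShape` is hypothetical.

```javascript
// Sketch: expected [height, width] after resizing so the shorter edge equals `size`.
// Assumption: the longer edge is scaled by the same factor and truncated to an integer.
function resizedShape(height, width, size) {
  if (height <= width) {
    return [size, Math.floor((size * width) / height)];
  }
  return [Math.floor((size * height) / width), size];
}

console.log(resizedShape(1920, 1080, 800)); // [ 1422, 800 ]
```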

With our tensor in the right shape and ready for inference, let's load a model and run it!

Load model and run inference

Now we get to actually add some ML to our app! Let's walk through the updates we are making to the ObjectDetector.js file in the block below.

  1. Add the MobileModel to our imports from the PlayTorch SDK (react-native-pytorch-core)
  2. Create a variable called MODEL_URL and set it equal to the download link for the DETR model (https://github.com/facebookresearch/playtorch/releases/download/v0.1.0/detr_resnet50.ptl)
  3. Create a variable called model and initially set it to null. We put this variable outside of our detectObjects function so we only have to load it once.
  4. Check if the model is still null and if it is, download the model file from the MODEL_URL and then load the model into memory with the newly downloaded file
  5. Run the DETR model on our formattedInputTensor by calling the model.forward function and store the result in a variable called output
  6. Log the output to see what our model has generated
ObjectDetector.js
import {
media,
MobileModel,
torch,
torchvision,
} from 'react-native-pytorch-core';

const T = torchvision.transforms;

const MODEL_URL =
'https://github.com/facebookresearch/playtorch/releases/download/v0.1.0/detr_resnet50.ptl';

let model = null;

export default async function detectObjects(image) {
// Get image width and height
const imageWidth = image.getWidth();
const imageHeight = image.getHeight();

// Convert image to blob, which is a byte representation of the image
// in the format height (H), width (W), and channels (C), or HWC for short
const blob = media.toBlob(image);

// Get a tensor from the image blob and also define the shape
// (format) of the data in the blob.
let tensor = torch.fromBlob(blob, [imageHeight, imageWidth, 3]);

// Rearrange the tensor shape to be [CHW]
tensor = tensor.permute([2, 0, 1]);

// Divide the tensor values by 255 to get values between [0, 1]
tensor = tensor.div(255);

// Resize the image tensor so that its shorter edge is 800 pixels
const resize = T.resize(800);
tensor = resize(tensor);

// Normalize the tensor image with mean and standard deviation
const normalize = T.normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]);
tensor = normalize(tensor);

// Unsqueeze adds 1 leading dimension to the tensor
const formattedInputTensor = tensor.unsqueeze(0);

// Load model if not loaded
if (model == null) {
console.log('Loading model...');
const filePath = await MobileModel.download(MODEL_URL);
model = await torch.jit._loadForMobile(filePath);
console.log('Model successfully loaded');
}

// Run inference
const output = await model.forward(formattedInputTensor);
console.log(output);
}

Now that we have code to download and run the model, let's see what it produces!

Remember that it will be slower the first time you run it because it will need to download the model. After that, it should run much faster. It will also need to redownload the model each time you update the code and it reloads on your device. There are plenty of ways to get around this, but we will not cover them in this tutorial.
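
The `model == null` check loads the model once per app session. If two calls could overlap before the first load finishes, both would see `null` and start their own download. One possible way to guard against that, sketched here with a hypothetical `createLazyLoader` helper and a fake loader in place of the real download, is to cache the promise rather than the loaded model. This does not address the redownload after a code reload.

```javascript
// Sketch: load an expensive resource once and share it between callers.
// `loadFn` is a hypothetical async function, e.g. one that downloads and loads a model.
function createLazyLoader(loadFn) {
  let pending = null;
  return function getResource() {
    if (pending == null) {
      // Cache the promise, not the result, so overlapping calls share one load.
      pending = loadFn().catch((err) => {
        pending = null; // allow a retry after a failed load
        throw err;
      });
    }
    return pending;
  };
}

// Usage with a fake loader that counts how often it runs.
let loadCount = 0;
const getModel = createLazyLoader(async () => {
  loadCount += 1;
  return {name: 'fake-model'};
});

Promise.all([getModel(), getModel()]).then(([a, b]) => {
  console.log(loadCount, a === b); // 1 true
});
```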

When you press the capture button now, after the model is done downloading, you should see something similar to the following in the logs:

Loading model...
Model successfully loaded
▼{pred_logits:{…},pred_boxes:{…}}
▼pred_logits:{dtype:"float32",shape:[…]}
dtype:"float32"
►shape:[1,100,92]
▼pred_boxes:{dtype:"float32",shape:[…]}
dtype:"float32"
►shape:[1,100,4]

Hurray! Our model is running! The output seems a bit confusing at first glance though 🤔

Let's unpack this output and format it for the rest of our app to use.

Format model output

In order to make the output easy to use in the rest of our app, we will have to do some transformations of it. Here's a summary of the changes we're making to the ObjectDetector.js file:

  1. Import the class labels from the CoCoClasses.json file as COCO_CLASSES
  2. Create a variable called probabilityThreshold set to 0.7 meaning we will only consider an object detected if the model has a confidence score of over 0.7 for that detection
  3. Extract the predLogits and predBoxes from the output. Note that "pred" is short for predicted and we use the .squeeze(0) method to get rid of the initial 1 from the shape we saw in the logged output and just have the data in the rest of the tensor.
  4. Grab the number of predictions from the predLogits output object and store it in a variable called numPredictions
  5. Create an empty array called resultBoxes
  6. Loop over each of the model predictions and do the following:
    1. Grab the confidence scores of the current prediction for each of the different labels and store them in a variable called confidencesTensor
    2. Convert the confidencesTensor into a tensor of probabilities between 0 and 1 by using the softmax function and store that in a variable called scores
    3. Find the index of the highest-probability class with the argmax function and save it to a variable called maxIndex (softmax preserves ordering, so the index is the same for the raw confidences and the scores)
    4. Use the maxIndex to grab the highest probability out of the scores tensor and store it in a variable called maxProb
    5. If the maxProb is not above our probabilityThreshold or the maxIndex is beyond the number of classes we have, skip the rest of the loop that adds the prediction to the results with the continue statement
    6. Grab the current box from the predBoxes tensor and store it in the boxTensor variable
    7. Extract the coordinates of the center point of the box (centerX, centerY) as well as the dimensions of the box (boxWidth, boxHeight) from the boxTensor
    8. Calculate the coordinates (x, y) of the top-left corner of the box
    9. Create an array that contains the coordinates of the top-left corner of the box as well as its dimensions and store it in a variable called bounds
    10. Create an object called match that contains the bounds array we just created as well as the class label, which we can grab out of the COCO_CLASSES array by indexing into it with the maxIndex
    11. Add the new match object to our resultBoxes array with the .push method
  7. Log the result boxes so we can inspect them
  8. Return the result boxes so whoever calls this function can use them
ObjectDetector.js
import {
media,
MobileModel,
torch,
torchvision,
} from 'react-native-pytorch-core';
import COCO_CLASSES from './CoCoClasses.json';

const T = torchvision.transforms;

const MODEL_URL =
'https://github.com/facebookresearch/playtorch/releases/download/v0.1.0/detr_resnet50.ptl';

let model = null;
const probabilityThreshold = 0.7;

export default async function detectObjects(image) {
// Get image width and height
const imageWidth = image.getWidth();
const imageHeight = image.getHeight();

// Convert image to blob, which is a byte representation of the image
// in the format height (H), width (W), and channels (C), or HWC for short
const blob = media.toBlob(image);

// Get a tensor from the image blob and also define the shape
// (format) of the data in the blob.
let tensor = torch.fromBlob(blob, [imageHeight, imageWidth, 3]);

// Rearrange the tensor shape to be [CHW]
tensor = tensor.permute([2, 0, 1]);

// Divide the tensor values by 255 to get values between [0, 1]
tensor = tensor.div(255);

// Resize the image tensor so that its shorter edge is 800 pixels
const resize = T.resize(800);
tensor = resize(tensor);

// Normalize the tensor image with mean and standard deviation
const normalize = T.normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]);
tensor = normalize(tensor);

// Unsqueeze adds 1 leading dimension to the tensor
const formattedInputTensor = tensor.unsqueeze(0);

// Load model if not loaded
if (model == null) {
console.log('Loading model...');
const filePath = await MobileModel.download(MODEL_URL);
model = await torch.jit._loadForMobile(filePath);
console.log('Model successfully loaded');
}

// Run inference
const output = await model.forward(formattedInputTensor);

const predLogits = output.pred_logits.squeeze(0);
const predBoxes = output.pred_boxes.squeeze(0);

const numPredictions = predLogits.shape[0];

const resultBoxes = [];

for (let i = 0; i < numPredictions; i++) {
const confidencesTensor = predLogits[i];
const scores = confidencesTensor.softmax(0);
const maxIndex = confidencesTensor.argmax().item();
const maxProb = scores[maxIndex].item();

if (maxProb <= probabilityThreshold || maxIndex >= COCO_CLASSES.length) {
continue;
}

const boxTensor = predBoxes[i];
const [centerX, centerY, boxWidth, boxHeight] = boxTensor.data();
const x = centerX - boxWidth / 2;
const y = centerY - boxHeight / 2;

// Adjust bounds to image size
const bounds = [
x * imageWidth,
y * imageHeight,
boxWidth * imageWidth,
boxHeight * imageHeight,
];

const match = {
objectClass: COCO_CLASSES[maxIndex],
bounds,
};

resultBoxes.push(match);
}

console.log(resultBoxes);
return resultBoxes;
}
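
The scoring logic inside the loop can be tried out on ordinary arrays. Below is a minimal sketch of softmax, argmax, and the threshold check in plain JavaScript. The labels and logits are made up for illustration, and treating the last logit as a "no object" slot is an assumption about why the index check is needed, not something this tutorial states.

```javascript
// Sketch: softmax over one prediction's logits, in plain JavaScript.
function softmax(logits) {
  // Subtract the maximum first for numerical stability
  const max = Math.max(...logits);
  const exps = logits.map((v) => Math.exp(v - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((v) => v / sum);
}

// Index of the largest value (first one wins on ties)
function argmax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i;
  }
  return best;
}

// Hypothetical labels and logits; the extra last logit has no label.
const labels = ['person', 'cup', 'bowl'];
const logits = [0.5, 4.0, 1.0, 0.2];
const threshold = 0.7;

const scores = softmax(logits);
const maxIndex = argmax(scores);
const maxProb = scores[maxIndex];

if (maxProb > threshold && maxIndex < labels.length) {
  console.log(labels[maxIndex], maxProb.toFixed(2));
} else {
  console.log('prediction skipped');
}
```

If the largest logit were the last one, `maxIndex` would be 3, which is outside `labels`, and the prediction would be skipped regardless of its probability.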

That was a lot of changes! Let's test them out to see what our results look like now.

When you press the capture button now, you should see something like this in the logs:

▼[{…},{…},{…}]
▼0:{objectClass:"bowl",bounds:[…]}
objectClass:"bowl"
►bounds:[58.80519703030586,1618.3972656726837,123.18191081285477,98.28389883041382]
►1:{objectClass:"person",bounds:[…]}
►2:{objectClass:"cup",bounds:[…]}
length:3
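
The bounds values are in pixels of the original image. As a sketch of the arithmetic, the helper below (a hypothetical `toPixelBounds` function that mirrors the calculation in detectObjects) converts a normalized center-based box into a top-left-based box in pixels. The example box and image size are made up.

```javascript
// Sketch: convert a normalized [centerX, centerY, width, height] box
// into pixel [x, y, width, height] with (x, y) at the top-left corner.
function toPixelBounds(box, imageWidth, imageHeight) {
  const [centerX, centerY, boxWidth, boxHeight] = box;
  const x = centerX - boxWidth / 2;
  const y = centerY - boxHeight / 2;
  return [
    x * imageWidth,
    y * imageHeight,
    boxWidth * imageWidth,
    boxHeight * imageHeight,
  ];
}

// Hypothetical box centered in an 800 x 1600 image.
console.log(toPixelBounds([0.5, 0.5, 0.25, 0.5], 800, 1600)); // [ 300, 400, 200, 800 ]
```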

Now that we have the results in a much more usable format with class labels and a bounding box, let's display our results in the app!

Display results

In order to display the results in our app, all we need to do is make some simple changes in our App.js file.

Before we make any changes, notice that App.js already has a screenState state variable that chooses which screen we show. Currently we only ever show the camera screen, so we will need to update that state variable. There is also a ScreenStates object that lists the 3 different screen states our app can have.

Notice also that the state variable called boundingBoxes is currently always null. That boundingBoxes state variable is passed to our ResultsScreen, so we will need to update it for our results to be displayed.

With that understanding of how our app works, let's make some updates to our App.js file. All of our changes will be inside our handleImage function. Here's a summary:

  1. Update the image state variable to be the newly capturedImage
  2. Update the screenState to ScreenStates.LOADING in order to display the LoadingScreen while our model downloads and runs
  3. In a try block:
    1. Await the result of the detectObjects function and store it in a variable called newBoxes
    2. Update the boundingBoxes state variable to be the newBoxes
    3. Update the screenState to ScreenStates.RESULTS in order to display the ResultsScreen so we can visually see what objects DETR detected in our image
  4. Add a catch block in case there is an error when we run our model:
    1. Call the handleReset function to put the app back in the CameraScreen so we can take a new photo
App.js
import * as React from 'react';
import {useCallback, useRef, useState} from 'react';
import {
ActivityIndicator,
Button,
SafeAreaView,
StyleSheet,
Text,
TouchableOpacity,
View,
} from 'react-native';
import {Camera, Canvas, Image} from 'react-native-pytorch-core';
import detectObjects from './ObjectDetector';
import CameraScreen from './screens/CameraScreen';
import LoadingScreen from './screens/LoadingScreen';
import ResultsScreen from './screens/ResultsScreen';

const ScreenStates = {
CAMERA: 0,
LOADING: 1,
RESULTS: 2,
};

export default function ObjectDetectionExample() {
const [image, setImage] = useState(null);
const [boundingBoxes, setBoundingBoxes] = useState(null);
const [screenState, setScreenState] = useState(ScreenStates.CAMERA);

// Handle the reset button and return to the camera capturing mode
const handleReset = useCallback(async () => {
setScreenState(ScreenStates.CAMERA);
if (image != null) {
await image.release();
}
setImage(null);
setBoundingBoxes(null);
}, [image, setScreenState]);

// This handler function handles the camera's capture event
async function handleImage(capturedImage) {
setImage(capturedImage);
// Wait for image to process through DETR model and draw resulting image
setScreenState(ScreenStates.LOADING);
try {
const newBoxes = await detectObjects(capturedImage);
setBoundingBoxes(newBoxes);
// Switch to the ResultsScreen to display the detected objects
setScreenState(ScreenStates.RESULTS);
} catch (err) {
// In case something goes wrong, go back to the CameraScreen to take a new picture
handleReset();
}
}

return (
<SafeAreaView style={styles.container}>
{screenState === ScreenStates.CAMERA && (
<CameraScreen onCapture={handleImage} />
)}
{screenState === ScreenStates.LOADING && <LoadingScreen />}
{screenState === ScreenStates.RESULTS && (
<ResultsScreen
image={image}
boundingBoxes={boundingBoxes}
onReset={handleReset}
/>
)}
</SafeAreaView>
);
}

const styles = StyleSheet.create({
container: {
flex: 1,
},
});
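
To see the screen transitions in isolation, here is a minimal sketch of the same flow with the React state setters replaced by plain callbacks, so it can run outside the app. The `runCaptureFlow` and `demo` functions and their parameter names are hypothetical; only the ScreenStates values come from App.js. Unlike the handler above, this sketch also logs the error before resetting.

```javascript
// Sketch of the capture flow with the state setters passed in as plain callbacks.
const ScreenStates = {CAMERA: 0, LOADING: 1, RESULTS: 2};

async function runCaptureFlow(capturedImage, {detect, setScreenState, setBoxes, reset}) {
  setScreenState(ScreenStates.LOADING);
  try {
    const boxes = await detect(capturedImage);
    setBoxes(boxes);
    setScreenState(ScreenStates.RESULTS);
  } catch (err) {
    // Log the failure, then go back to the camera
    console.warn('Detection failed:', err.message);
    await reset();
  }
}

// Record the transitions for a given (fake) detector.
async function demo(detect) {
  const transitions = [];
  await runCaptureFlow({}, {
    detect,
    setScreenState: (s) => transitions.push(s),
    setBoxes: () => {},
    reset: async () => transitions.push(ScreenStates.CAMERA),
  });
  return transitions;
}

demo(async () => []).then((t) => console.log('success:', t)); // success: [ 1, 2 ]
demo(async () => {
  throw new Error('model failed');
}).then((t) => console.log('failure:', t)); // failure: [ 1, 0 ]
```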

With the results of the detectObjects function hooked up, we are now ready to see the full app working. Load the latest changes into the PlayTorch app and let's test it!

When you press the capture button now, after the model finishes running, you should see the picture you took but with labeled boxes around the detected objects, just like the screenshot below:

Wrapping up

We hope you enjoyed getting an object detection model up and running with PlayTorch.

Providing feedback

Was something unclear? Did you run into any bugs? Is there another tutorial or model you'd like to see from us?

We would love to hear how we can improve.

Come chat with us on Discord or file an issue on GitHub.