Version: 0.2.2

Prepare Custom Model

In this tutorial, we will learn how to prepare a custom PyTorch model for use in a PlayTorch demo.

This section walks you step by step through exporting a ScriptModule, together with its model weights, for the PyTorch Mobile Lite Interpreter runtime, which PlayTorch uses to run inference with ML models. As an example, we will export the MobileNet V3 (small) model.

Set up Python virtual environment

We recommend running the Python scripts in a virtual environment. Python's built-in venv module creates and activates one with the following commands (the activation command shown is for macOS and Linux shells).

python3 -m venv venv
source venv/bin/activate
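
To confirm that the environment is actually active before installing anything, a small check like the following can help. This is a minimal sketch that relies on how venv-created environments set sys.prefix; the helper name in_virtual_env is our own and not part of PlayTorch or PyTorch.

```python
import sys


def in_virtual_env() -> bool:
    """Return True when the interpreter runs inside a venv-created environment."""
    # Inside a venv, sys.prefix points at the environment directory while
    # sys.base_prefix still points at the base Python installation.
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    status = "active" if in_virtual_env() else "not active"
    print(f"virtual environment is {status} (sys.prefix = {sys.prefix})")
```

If the check reports that the environment is not active, run the activation command again in the same shell before continuing.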

Install torch and torchvision dependencies

The Python script requires torch and torchvision. Use the Python package manager (pip3 or simply pip in a virtual environment) to install both dependencies.

pip3 install torch torchvision
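
After installation, it can be useful to record which versions of the two packages ended up in the environment, for example to reproduce the export later. The sketch below uses only the standard library (importlib.metadata, available in Python 3.8 and later); the helper name installed_version is hypothetical.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Report the two dependencies that the export script imports.
    for name in ("torch", "torchvision"):
        version = installed_version(name)
        print(f"{name}: {version if version else 'not installed'}")
```

This approach reads package metadata instead of importing torch, so it also works as a quick check when an import would be slow or fail.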

Export the MobileNet V3 Model

The following script will download the MobileNet V3 model from the PyTorch Hub, optimize it for mobile use, and save it as a .ptl file, the format used by the PyTorch Mobile Lite Interpreter runtime.

export_model.py
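import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile

# Load MobileNet V3 (small) with pretrained weights. Newer torchvision
# releases deprecate `pretrained=True` in favor of the `weights` argument.
model = torchvision.models.mobilenet_v3_small(pretrained=True)
# Switch to inference mode (affects layers such as dropout and batch norm).
model.eval()

# Compile to TorchScript, apply mobile-specific optimizations, and save
# in the format expected by the Lite Interpreter runtime.
scripted_model = torch.jit.script(model)
optimized_model = optimize_for_mobile(scripted_model)
optimized_model._save_for_lite_interpreter("mobilenet_v3_small.ptl")

print("model successfully exported")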

Create the export_model.py file, add the Python script above, and then run it to create mobilenet_v3_small.ptl.

python3 export_model.py
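
Before uploading the result, a quick sanity check on the file can catch an interrupted or failed export. The sketch below assumes the default zip-based serialization, in which a Lite Interpreter export is a zip archive containing a bytecode.pkl record; it is a heuristic only and does not prove the model will run. The helper name looks_like_lite_model is hypothetical.

```python
import zipfile
from pathlib import Path


def looks_like_lite_model(path: str) -> bool:
    """Heuristic check that `path` is a zip-based Lite Interpreter export."""
    file = Path(path)
    if not file.is_file() or not zipfile.is_zipfile(file):
        return False
    with zipfile.ZipFile(file) as archive:
        # Assumption: Lite Interpreter exports keep their bytecode in a
        # record named bytecode.pkl, nested under a top-level folder.
        return any(name.endswith("bytecode.pkl") for name in archive.namelist())


if __name__ == "__main__":
    model_path = "mobilenet_v3_small.ptl"
    if looks_like_lite_model(model_path):
        print(f"{model_path} looks like a Lite Interpreter model")
    else:
        print(f"{model_path} is missing or is not a Lite Interpreter archive")
```

A stronger check would load the file with PyTorch and run a dummy input through it; the version above has the advantage of needing only the standard library.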

Upload the resulting mobilenet_v3_small.ptl file to a publicly accessible server, and then head over to the image classification tutorial to see how you can use it with PlayTorch. Alternatively, you can place the mobilenet_v3_small.ptl file in the appropriate directory in your project.
