
PyTorch Backend

This example project shows how to run AI inference with a PyTorch model in WasmEdge using Rust.

Prerequisite

Besides the regular WasmEdge and Rust requirements, make sure that you have the WASI-NN plugin with the PyTorch backend installed.

Quick Start

Because the example already includes a WASM file compiled from the Rust code, you can run it directly with the WasmEdge CLI.

First, clone the WasmEdge-WASINN-examples repository:

```bash
git clone https://github.com/second-state/WasmEdge-WASINN-examples.git
cd WasmEdge-WASINN-examples/pytorch-mobilenet-image
# Make sure libtorch is installed and its library directory is included in `LD_LIBRARY_PATH`.
# If WasmEdge and its plugins are not installed in the default location, set the
# `WASMEDGE_PLUGIN_PATH` environment variable to the directory containing the WASI-NN plugin.
wasmedge --dir .:. wasmedge-wasinn-example-mobilenet-image.wasm mobilenet.pt input.jpg
```

If everything works, you should see the following terminal output:

```text
Read torchscript binaries, size in bytes: 14376924
Loaded graph into wasi-nn with ID: 0
Created wasi-nn execution context with ID: 0
Read input tensor, size in bytes: 602112
Executed graph inference
1.) [954](20.6681)banana
2.) [940](12.1483)spaghetti squash
3.) [951](11.5748)lemon
4.) [950](10.4899)orange
5.) [953](9.4834)pineapple, ananas
```

Build and Run the example from Rust source code

Let's build the WASM file from the Rust source code.

First, clone the WasmEdge-WASINN-examples repository and enter the Rust project directory:

```bash
git clone https://github.com/second-state/WasmEdge-WASINN-examples.git
cd WasmEdge-WASINN-examples/pytorch-mobilenet-image/rust
```

Then, use cargo to build the example project:

```bash
cargo build --target wasm32-wasi --release
```

The output WASM file is target/wasm32-wasi/release/wasmedge-wasinn-example-mobilenet-image.wasm.

Next, use WasmEdge to classify your images. Run the command from a directory that contains the WASM file, mobilenet.pt, and input.jpg:

```bash
wasmedge --dir .:. wasmedge-wasinn-example-mobilenet-image.wasm mobilenet.pt input.jpg
```

You can replace input.jpg with your image file.

Improve performance

Ahead-of-time (AOT) compiled code runs much faster than the interpreter. To use AOT mode, compile the WASM file first and then run the compiled output:

```bash
wasmedgec wasmedge-wasinn-example-mobilenet-image.wasm out.wasm
wasmedge --dir .:. out.wasm mobilenet.pt input.jpg
```

Understand the code

The file main.rs contains the complete Rust source of the example.

First, read the TorchScript model file into memory:

```rust
use std::env;
use std::fs;

let args: Vec<String> = env::args().collect();
let model_bin_name: &str = &args[1]; // File name for the PyTorch model
let image_name: &str = &args[2]; // File name for the input image

let weights = fs::read(model_bin_name).unwrap();
```
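The snippet above indexes `args` directly, so it panics with an out-of-bounds error if an argument is missing. As a hedged alternative, here is a std-only sketch that validates the command line before the model is read. The helper name `parse_args` and its usage message are hypothetical and not part of the example repository.

```rust
/// Hypothetical helper: extract the model and image paths from the
/// command-line arguments (args[0] is the program name).
fn parse_args(args: &[String]) -> Result<(String, String), String> {
    match args {
        // Ignore any extra trailing arguments.
        [_, model, image, ..] => Ok((model.clone(), image.clone())),
        _ => Err(format!(
            "usage: {} <model.pt> <image.jpg>",
            args.first().map(String::as_str).unwrap_or("program")
        )),
    }
}
```

In `main`, this could be called as `let (model_bin_name, image_name) = parse_args(&env::args().collect::<Vec<_>>())?;` when `main` returns a `Result<(), String>`, so a missing argument prints a usage line instead of panicking.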

Use a helper function to convert the input image into tensor data (the tensor element type is F32):

```rust
use std::fs::File;
use std::io::Read;

// Load an image, resize it to `width` x `height`, normalize each RGB channel
// with the ImageNet mean/std, and return the pixels as raw F32 bytes in
// planar (channel-first, CHW) order.
fn image_to_tensor(path: String, height: u32, width: u32) -> Vec<u8> {
    let mut file_img = File::open(path).unwrap();
    let mut img_buf = Vec::new();
    file_img.read_to_end(&mut img_buf).unwrap();
    let img = image::load_from_memory(&img_buf).unwrap().to_rgb8();
    // `resize` takes the new width first, then the new height.
    let resized =
        image::imageops::resize(&img, width, height, image::imageops::FilterType::Triangle);
    // Normalized pixels in interleaved (HWC) order: R, G, B, R, G, B, ...
    let mut flat_img: Vec<f32> = Vec::new();
    for rgb in resized.pixels() {
        flat_img.push((rgb[0] as f32 / 255. - 0.485) / 0.229);
        flat_img.push((rgb[1] as f32 / 255. - 0.456) / 0.224);
        flat_img.push((rgb[2] as f32 / 255. - 0.406) / 0.225);
    }
    let pixel_count = flat_img.len() / 3;
    let mut u8_f32_arr: Vec<u8> = vec![0; flat_img.len() * 4];
    // Re-order to CHW and serialize each f32 as 4 native-endian bytes.
    for c in 0..3 {
        for i in 0..pixel_count {
            let u8_bytes = flat_img[i * 3 + c].to_ne_bytes();
            let offset = (c * pixel_count + i) * 4;
            u8_f32_arr[offset..offset + 4].copy_from_slice(&u8_bytes);
        }
    }
    u8_f32_arr
}
```
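The core of `image_to_tensor` is the normalization plus the HWC to CHW reordering, which can be checked without decoding a real image. The following std-only sketch reproduces that step on an in-memory list of RGB pixels; the names `pixels_to_chw_bytes` and `read_chw` are hypothetical, and the mean/std constants are the same ImageNet values used above.

```rust
/// ImageNet per-channel mean and standard deviation (same values as `image_to_tensor`).
const MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const STD: [f32; 3] = [0.229, 0.224, 0.225];

/// Normalize interleaved RGB pixels and write them as native-endian f32
/// bytes in planar CHW order: all R values, then all G, then all B.
fn pixels_to_chw_bytes(pixels: &[[u8; 3]]) -> Vec<u8> {
    let n = pixels.len();
    let mut out = vec![0u8; n * 3 * 4];
    for (i, px) in pixels.iter().enumerate() {
        for c in 0..3 {
            let v = (px[c] as f32 / 255.0 - MEAN[c]) / STD[c];
            let offset = (c * n + i) * 4;
            out[offset..offset + 4].copy_from_slice(&v.to_ne_bytes());
        }
    }
    out
}

/// Read back the f32 stored for channel `c` of pixel `i` (n = pixel count).
fn read_chw(bytes: &[u8], n: usize, c: usize, i: usize) -> f32 {
    let offset = (c * n + i) * 4;
    let mut b = [0u8; 4];
    b.copy_from_slice(&bytes[offset..offset + 4]);
    f32::from_ne_bytes(b)
}
```

Because the planes are stored one after another, the value for channel `c` of pixel `i` lives at byte offset `(c * n + i) * 4`, which is the layout the `[1, 3, 224, 224]` tensor shape expects.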

Then use this helper function to convert the input image:

```rust
let tensor_data = image_to_tensor(image_name.to_string(), 224, 224);
```
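The byte length of `tensor_data` has to match the tensor dimensions passed to WASI-NN: the product of the dimensions times 4 bytes per F32 element. For `[1, 3, 224, 224]` that is 602,112 bytes, the same size reported in the sample output. A small sketch of that check; the helper name `f32_tensor_bytes` is hypothetical.

```rust
/// Number of bytes an F32 tensor with the given dimensions occupies.
fn f32_tensor_bytes(dimensions: &[u32]) -> usize {
    dimensions.iter().map(|&d| d as usize).product::<usize>() * std::mem::size_of::<f32>()
}
```

For example, `assert_eq!(tensor_data.len(), f32_tensor_bytes(&[1, 3, 224, 224]));` placed before building the tensor catches a resize or layout mistake early.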

Now we can start our inference with WASI-NN:

```rust
// Load the TorchScript model with the PyTorch backend on the CPU.
let graph = unsafe {
    wasi_nn::load(
        &[&weights],
        wasi_nn::GRAPH_ENCODING_PYTORCH,
        wasi_nn::EXECUTION_TARGET_CPU,
    )
    .unwrap()
};
// Initialize the computation context.
let context = unsafe { wasi_nn::init_execution_context(graph).unwrap() };
// Describe the input tensor: batch 1, 3 channels, 224 x 224 pixels, F32.
let tensor = wasi_nn::Tensor {
    dimensions: &[1, 3, 224, 224],
    type_: wasi_nn::TENSOR_TYPE_F32,
    data: &tensor_data,
};
// Bind the tensor to input index 0.
unsafe {
    wasi_nn::set_input(context, 0, tensor).unwrap();
}
// Execute the inference.
unsafe {
    wasi_nn::compute(context).unwrap();
}
// Copy output 0 into an f32 buffer sized for 1001 scores.
let mut output_buffer = vec![0f32; 1001];
unsafe {
    wasi_nn::get_output(
        context,
        0,
        &mut output_buffer[..] as *mut [f32] as *mut u8,
        (output_buffer.len() * 4).try_into().unwrap(),
    )
    .unwrap();
}
```
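The `get_output` call above writes raw bytes straight into the `f32` buffer through a pointer cast. If you instead receive the output into a byte buffer (for example `vec![0u8; 1001 * 4]`, passing its `as_mut_ptr()` and length), the bytes have to be decoded into floats afterwards. A hedged, std-only sketch of that decoding step, assuming the backend writes native-endian f32 values; `bytes_to_f32s` is a hypothetical helper name.

```rust
/// Decode a buffer of native-endian f32 bytes into a Vec<f32>.
/// Trailing bytes that do not form a whole f32 are ignored.
fn bytes_to_f32s(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|chunk| {
            let mut b = [0u8; 4];
            b.copy_from_slice(chunk);
            f32::from_ne_bytes(b)
        })
        .collect()
}
```

This avoids reinterpreting a `[f32]` slice as `*mut u8`, at the cost of one extra copy of the output.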

Here, wasi_nn::GRAPH_ENCODING_PYTORCH selects the PyTorch backend, and wasi_nn::EXECUTION_TARGET_CPU runs the computation on the CPU.

Finally, we sort the output and then print the top-5 classification results:

```rust
let results = sort_results(&output_buffer);
for i in 0..5 {
    println!(
        " {}.) [{}]({:.4}){}",
        i + 1,
        results[i].0,
        results[i].1,
        imagenet_classes::IMAGENET_CLASSES[results[i].0]
    );
}
```