TensorFlow-Lite Backend

We will use this example project to show how to run AI inference with a TensorFlow Lite model in WasmEdge and Rust.

Prerequisite

Besides the regular WasmEdge and Rust requirements, please make sure that you have the Wasi-NN plugin with TensorFlow Lite installed.

Quick Start

Because the example already includes a WASM file compiled from the Rust code, we can use the WasmEdge CLI to run the example directly.

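First, git clone the WasmEdge-WASINN-examples repository and change into the example directory. Then run the precompiled WASM file with the model and the sample image.

git clone https://github.com/second-state/WasmEdge-WASINN-examples.git
cd WasmEdge-WASINN-examples/tflite-birds_v1-image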
wasmedge --dir .:. wasmedge-wasinn-example-tflite-bird-image.wasm lite-model_aiy_vision_classifier_birds_V1_3.tflite bird.jpg

If everything goes well, you should have the terminal output:

Read graph weights, size in bytes: 3561598
Loaded graph into wasi-nn with ID: 0
Created wasi-nn execution context with ID: 0
Read input tensor, size in bytes: 150528
Executed graph inference
1.) [166](198)Aix galericulata
2.) [158](2)Coccothraustes coccothraustes
3.) [34](1)Gallus gallus domesticus
4.) [778](1)Sitta europaea
5.) [819](1)Anas platyrhynchos

Build and Run the example from Rust source code

Let's build the WASM file from the Rust source code.

First, git clone the WasmEdge-WASINN-examples repository and change into the Rust project directory.

git clone https://github.com/second-state/WasmEdge-WASINN-examples.git
cd WasmEdge-WASINN-examples/tflite-birds_v1-image/rust/

Second, use cargo to build the example project.

cargo build --target wasm32-wasi --release

The output WASM file is target/wasm32-wasi/release/wasmedge-wasinn-example-tflite-bird-image.wasm.

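Next, let's use WasmEdge to identify your images. Run the command from a directory that contains the WASM file, the model file, and the image.

wasmedge --dir .:. wasmedge-wasinn-example-tflite-bird-image.wasm lite-model_aiy_vision_classifier_birds_V1_3.tflite input.jpg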

You can replace input.jpg with your image file.

Improve performance

AOT mode is much faster. To use it, compile the WASM file first:

wasmedgec rust/target/wasm32-wasi/release/wasmedge-wasinn-example-tflite-bird-image.wasm wasmedge-wasinn-example-tflite-bird-image.wasm
wasmedge --dir .:. wasmedge-wasinn-example-tflite-bird-image.wasm lite-model_aiy_vision_classifier_birds_V1_3.tflite bird.jpg

Understand the code

The main.rs is the complete example Rust source.

First, read the command-line arguments and load the model file into memory:

let args: Vec<String> = env::args().collect();
let model_bin_name: &str = &args[1]; // File name for the tflite model
let image_name: &str = &args[2]; // File name for the input image

let weights = fs::read(model_bin_name).unwrap();
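The snippet above indexes args directly, so it panics if an argument is missing. As a minimal sketch of a friendlier alternative, the hypothetical helper parse_args below (it is not part of the example project) checks the argument count and returns a usage message instead:

```rust
use std::env;

/// Hypothetical helper, not part of the example project: returns the model
/// path and the image path, or a usage message if the argument count is wrong.
fn parse_args(args: &[String]) -> Result<(&str, &str), String> {
    match args {
        // args[0] is the program name, followed by exactly two paths.
        [_, model, image] => Ok((model.as_str(), image.as_str())),
        _ => Err(format!(
            "usage: {} <model.tflite> <image>",
            args.first().map(String::as_str).unwrap_or("program")
        )),
    }
}

fn main() {
    let args: Vec<String> = env::args().collect();
    match parse_args(&args) {
        Ok((model, image)) => println!("model: {model}, image: {image}"),
        Err(usage) => eprintln!("{usage}"),
    }
}
```

The same pattern could wrap fs::read so that a missing model file produces a readable error rather than an unwrap panic.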

We use a helper function to convert the input image into tensor data (the tensor type is U8):

fn image_to_tensor(path: String, height: u32, width: u32) -> Vec<u8> {
    // Decode the image and resize it to the model's input size
    let pixels = Reader::open(path).unwrap().decode().unwrap();
    let dyn_img: DynamicImage = pixels.resize_exact(width, height, image::imageops::Triangle);
    // Convert to 8-bit RGB
    let rgb_img = dyn_img.to_rgb8();
    // Return the pixel values as a flat byte vector
    rgb_img.into_raw()
}

And use this helper function to convert the input image:

let tensor_data = image_to_tensor(image_name.to_string(), 224, 224);
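The "Read input tensor, size in bytes: 150528" line in the Quick Start output follows from this conversion: a 224 x 224 RGB image with one byte per channel occupies 224 * 224 * 3 bytes. The sketch below checks that arithmetic and shows how a pixel maps to an offset in the flat buffer, which the image crate stores row by row with the three channels of each pixel interleaved. The function names tensor_len and pixel_offset are illustrative only and do not appear in the example.

```rust
/// Number of bytes in a single U8 image tensor (illustrative helper).
fn tensor_len(height: usize, width: usize, channels: usize) -> usize {
    height * width * channels
}

/// Offset of channel `c` of the pixel at column `x`, row `y` in a
/// row-major buffer with interleaved channels (illustrative helper).
fn pixel_offset(x: usize, y: usize, c: usize, width: usize, channels: usize) -> usize {
    (y * width + x) * channels + c
}

fn main() {
    let len = tensor_len(224, 224, 3);
    println!("tensor size in bytes: {len}");

    // The last channel of the bottom-right pixel is the final byte.
    let last = pixel_offset(223, 223, 2, 224, 3);
    println!("offset of the last byte: {last}");
}
```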

Now we can start our inference with WASI-NN:

// load model
let graph = unsafe {
    wasi_nn::load(
        &[&weights],
        wasi_nn::GRAPH_ENCODING_TENSORFLOWLITE,
        wasi_nn::EXECUTION_TARGET_CPU,
    )
    .unwrap()
};
// initialize the computation context
let context = unsafe { wasi_nn::init_execution_context(graph).unwrap() };
// initialize the input tensor (batch, height, width, channels; U8 data)
let tensor = wasi_nn::Tensor {
    dimensions: &[1, 224, 224, 3],
    r#type: wasi_nn::TENSOR_TYPE_U8,
    data: &tensor_data,
};
// set_input
unsafe {
    wasi_nn::set_input(context, 0, tensor).unwrap();
}
// Execute the inference.
unsafe {
    wasi_nn::compute(context).unwrap();
}
// retrieve output
let mut output_buffer = vec![0f32; 1001];
unsafe {
    wasi_nn::get_output(
        context,
        0,
        &mut output_buffer[..] as *mut [f32] as *mut u8,
        (output_buffer.len() * 4).try_into().unwrap(),
    )
    .unwrap();
}

Here, wasi_nn::GRAPH_ENCODING_TENSORFLOWLITE selects the TensorFlow Lite backend, and wasi_nn::EXECUTION_TARGET_CPU runs the computation on the CPU.

Finally, we sort the output and then print the top-5 classification results:

let results = sort_results(&output_buffer);
for i in 0..5 {
    println!(
        " {}.) [{}]({:.4}){}",
        i + 1,
        results[i].0,
        results[i].1,
        imagenet_classes::IMAGENET_CLASSES[results[i].0]
    );
}
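The sort_results helper is called above but its body is not shown in this walkthrough. As a rough sketch of what such a helper might look like (an assumption, not the example's actual implementation), the version below pairs each score with its class index and sorts the pairs in descending order of score. It is written generically so that it works for both integer and floating-point scores, and it assumes the scores contain no NaN values.

```rust
/// Hypothetical stand-in for the example's sort_results helper: pairs each
/// score with its class index and sorts by score, highest first.
/// Panics if two scores cannot be compared (for example, a NaN).
fn sort_results<T: PartialOrd + Copy>(buffer: &[T]) -> Vec<(usize, T)> {
    let mut results: Vec<(usize, T)> = buffer.iter().copied().enumerate().collect();
    // Compare b against a to get descending order; the sort is stable,
    // so equal scores keep their original index order.
    results.sort_by(|a, b| {
        b.1.partial_cmp(&a.1)
            .expect("scores must be comparable (no NaN)")
    });
    results
}

fn main() {
    // Made-up scores for four classes, for illustration only.
    let scores = [1u8, 198, 2, 0];
    for (rank, (class, score)) in sort_results(&scores).iter().take(3).enumerate() {
        println!("{}.) [{}]({})", rank + 1, class, score);
    }
}
```

Returning (index, score) tuples matches how the loop above reads results[i].0 as the class index and results[i].1 as the score.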