Skip to content

Instantly share code, notes, and snippets.

@segeljakt
Created August 31, 2023 14:58
Show Gist options
  • Select an option

  • Save segeljakt/998191053231e80ff002e706f488b113 to your computer and use it in GitHub Desktop.

Select an option

Save segeljakt/998191053231e80ff002e706f488b113 to your computer and use it in GitHub Desktop.
import torch
import onnxruntime as ort
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode
import json
import urllib.request
from PIL import ImageOps

# Run ResNet-18 inference on one image via ONNX Runtime and print the
# predicted ImageNet class. Companion to the Rust port and export script.

# Load the exported ONNX model.
resnet_session = ort.InferenceSession("resnet18.onnx")

# Load an actual image with PIL (substitute the filename).
image = Image.open("resnet18-images/cats.jpg")

# Preprocessing: shortest side -> 256 (nearest-neighbour, chosen to match the
# Rust port), centre-crop to 224x224, scale pixels to [0, 1] in CHW order.
# NOTE(review): ImageNet mean/std normalisation is omitted here, apparently to
# keep parity with the Rust implementation — accuracy will differ from the
# canonical torchvision recipe; confirm this is intentional.
preprocess = transforms.Compose([
    transforms.Resize(256, interpolation=InterpolationMode.NEAREST),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# Preprocess the image and add a batch dimension: (1, 3, 224, 224).
input_tensor = preprocess(image)
input_batch = input_tensor.unsqueeze(0)

def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array, detaching from autograd if needed."""
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# Run inference, feeding the batch under the model's declared input name.
ort_inputs = {resnet_session.get_inputs()[0].name: to_numpy(input_batch)}
ort_outs = resnet_session.run(None, ort_inputs)
output = ort_outs[0]  # raw logits for the batch

# Download ImageNet labels ({"0": ["n01440764", "tench"], ...}).
url = 'https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json'
class_idx = json.loads(urllib.request.urlopen(url).read().decode('utf-8'))

# Convert to a more usable format: positional index -> human-readable label.
idx_to_label = [class_idx[str(k)][1] for k in range(len(class_idx))]

# Argmax over the class dimension; softmax is unnecessary for top-1.
_, predicted_class_idx = torch.max(torch.from_numpy(output), 1)

# Fix: index the list with a plain Python int via .item() — the original
# indexed with a 1-element tensor, which only works by accident (__index__).
predicted_class_label = idx_to_label[predicted_class_idx.item()]
print(f'Predicted class index: {predicted_class_idx.item()}')
print(f'Predicted class label: {predicted_class_label}')
use ndarray::Array1;
use ndarray::CowArray;
use ort::tensor::OrtOwnedTensor;
use ort::Environment;
use ort::ExecutionProvider;
use ort::GraphOptimizationLevel;
use ort::SessionBuilder;
use ort::Value;
/// Runs ResNet-18 inference on `cats.jpg` with ONNX Runtime and prints the
/// predicted ImageNet class, mirroring the companion Python script.
///
/// # Errors
/// Propagates any error from model loading, image decoding, file I/O, JSON
/// parsing, or inference.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    tracing_subscriber::fmt::init();

    // Build the ONNX Runtime environment and session (CPU only).
    let env = Environment::builder()
        .with_name("resnet18")
        .with_execution_providers([ExecutionProvider::CPU(Default::default())])
        .build()?
        .into_arc();
    let session = SessionBuilder::new(&env)?
        .with_optimization_level(GraphOptimizationLevel::Level1)?
        .with_intra_threads(1)?
        .with_model_from_file("resnet18.onnx")?;

    // Decode the input image.
    let data = std::fs::read("cats.jpg")?;
    let image = image::load_from_memory(&data)?;

    // Resize so the *shortest* side becomes 256, preserving aspect ratio —
    // this matches torchvision's `Resize(256)` in the Python script.
    // Bug fix: the original always forced height to 256, which shrank the
    // shortest side below 256 for portrait images and could underflow the
    // centre-crop arithmetic below. Nearest-neighbour matches Python.
    let (w, h) = (image.width() as f32, image.height() as f32);
    let (new_w, new_h) = if w <= h {
        (256, ((256.0 * h) / w) as u32)
    } else {
        (((256.0 * w) / h) as u32, 256)
    };
    let image = image.resize_to_fill(new_w, new_h, image::imageops::FilterType::Nearest);

    // Centre-crop to the model's 224x224 input and convert to RGB8.
    let image = image
        .crop_imm(
            (image.width() - 224) / 2,
            (image.height() - 224) / 2,
            224,
            224,
        )
        .to_rgb8();

    // Scale pixels to [0, 1] as f32, laid out channel-major (CHW).
    // Bug fix: the original flattened interleaved RGB (HWC) and then reshaped
    // straight to (1, 3, 224, 224), scrambling the channel planes relative to
    // Python's `ToTensor` (which emits CHW). Build the CHW buffer explicitly.
    let mut chw = Vec::with_capacity(3 * 224 * 224);
    for channel in 0..3 {
        for pixel in image.pixels() {
            chw.push(pixel[channel] as f32 / 255.0);
        }
    }
    let array = Array1::from(chw).into_shape((1, 3, 224, 224))?;
    let array = CowArray::from(array);
    let array = array.into_dyn();

    // Run inference.
    let x = vec![Value::from_array(session.allocator(), &array)?];
    let y = session.run(x)?;

    // Load the class-index mapping ({"0": ["n01440764", "tench"], ...}).
    let classes = std::fs::read_to_string("imagenet_class_index.json")?;
    let classes = json::parse(&classes)?;

    // Argmax over the logits, logging each running maximum as it improves.
    let y: OrtOwnedTensor<f32, _> = y[0].try_extract()?;
    let mut max_score = f32::MIN;
    let mut max_idx = 0;
    for (idx, score) in y.view().iter().enumerate() {
        if *score > max_score {
            let label = &classes[idx.to_string()][1];
            println!("Class index: {idx} ({label}), score: {score}");
            max_score = *score;
            max_idx = idx;
        }
    }
    println!("Predicted class index: {}", max_idx);
    // Bug fix: select element [1] (the human-readable label) as the loop above
    // does — the original printed the whole ["wnid", "label"] pair here,
    // unlike the Python script.
    let label = &classes[max_idx.to_string()][1];
    println!("Predicted class label: {}", label);
    Ok(())
}
import torch
import torchvision.models as models

# Export a pretrained ResNet-18 to ONNX so the inference scripts can load it.
model = models.resnet18(pretrained=True)
model.eval()  # inference mode: fixes batch-norm / dropout behaviour for export

# Dummy input fixing the traced (batch, channels, height, width) shape.
x = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,                     # model being run
    x,                         # example input used to trace the graph
    "resnet18.onnx",           # where to save the model (file or file-like object)
    export_params=True,        # embed the trained parameter weights in the file
    opset_version=10,          # the ONNX opset version to target
    do_constant_folding=True,  # fold constant subgraphs at export time
    input_names=["input"],     # the model's input names
    output_names=["output"],   # the model's output names
    dynamic_axes={             # variable-length axes: allow any batch size
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment