Created
August 31, 2023 14:58
-
-
Save segeljakt/998191053231e80ff002e706f488b113 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import onnxruntime as ort
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode
import json
import urllib.request
from PIL import ImageOps


def to_numpy(tensor):
    """Return the tensor as a numpy array, detaching first if it tracks gradients."""
    if tensor.requires_grad:
        return tensor.detach().cpu().numpy()
    return tensor.cpu().numpy()


# ONNX Runtime session for the exported ResNet-18 model.
resnet_session = ort.InferenceSession("resnet18.onnx")

# Load a real image with PIL (substitute the filename).
image = Image.open("resnet18-images/cats.jpg")

# Preprocessing: shorter side -> 256 (nearest neighbour), centre crop to
# 224x224, then scale pixel values to [0, 1].
# NOTE(review): no mean/std normalization is applied — confirm the model was
# exported to expect raw [0, 1] inputs.
preprocess = transforms.Compose([
    transforms.Resize(256, interpolation=InterpolationMode.NEAREST),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# Build a batch of one preprocessed image and run inference.
input_batch = preprocess(image).unsqueeze(0)
ort_inputs = {resnet_session.get_inputs()[0].name: to_numpy(input_batch)}
output = resnet_session.run(None, ort_inputs)[0]

# Download the ImageNet class-index mapping and flatten it so position k
# holds the human-readable label for class k.
url = 'https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json'
class_idx = json.loads(urllib.request.urlopen(url).read().decode('utf-8'))
idx_to_label = [class_idx[str(k)][1] for k in range(len(class_idx))]

# Arg-max over the class dimension gives the prediction.
_, predicted_class_idx = torch.max(torch.from_numpy(output), 1)
predicted_class_label = idx_to_label[predicted_class_idx]
print(f'Predicted class index: {predicted_class_idx.item()}')
print(f'Predicted class label: {predicted_class_label}')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use ndarray::Array1; | |
| use ndarray::CowArray; | |
| use ort::tensor::OrtOwnedTensor; | |
| use ort::Environment; | |
| use ort::ExecutionProvider; | |
| use ort::GraphOptimizationLevel; | |
| use ort::SessionBuilder; | |
| use ort::Value; | |
| fn main() -> Result<(), Box<dyn std::error::Error>> { | |
| tracing_subscriber::fmt::init(); | |
| let env = Environment::builder() | |
| .with_name("resnet18") | |
| .with_execution_providers([ExecutionProvider::CPU(Default::default())]) | |
| .build()? | |
| .into_arc(); | |
| let session = SessionBuilder::new(&env)? | |
| .with_optimization_level(GraphOptimizationLevel::Level1)? | |
| .with_intra_threads(1)? | |
| .with_model_from_file("resnet18.onnx")?; | |
| let data = std::fs::read("cats.jpg")?; | |
| let image = image::load_from_memory(&data)?; | |
| let image = image.resize_to_fill( | |
| ((256.0 * image.width() as f32) / (image.height() as f32)) as u32, | |
| 256, | |
| image::imageops::FilterType::Nearest, | |
| ); | |
| let image = image | |
| .crop_imm( | |
| (image.width() - 224) / 2, | |
| (image.height() - 224) / 2, | |
| 224, | |
| 224, | |
| ) | |
| .to_rgb8(); | |
| let vec = image | |
| .pixels() | |
| .flat_map(|rgb| { | |
| [ | |
| rgb[0] as f32 / 255.0, | |
| rgb[1] as f32 / 255.0, | |
| rgb[2] as f32 / 255.0, | |
| ] | |
| }) | |
| .collect::<Vec<_>>(); | |
| let array = Array1::from(vec).into_shape((1, 3, 224, 224))?; | |
| let array = CowArray::from(array); | |
| let array = array.into_dyn(); | |
| let x = vec![Value::from_array(session.allocator(), &array)?]; | |
| let y = session.run(x)?; | |
| let classes = std::fs::read_to_string("imagenet_class_index.json")?; | |
| let classes = json::parse(&classes)?; | |
| let y: OrtOwnedTensor<f32, _> = y[0].try_extract()?; | |
| let mut max_score = f32::MIN; | |
| let mut max_idx = 0; | |
| for (idx, score) in y.view().iter().enumerate() { | |
| if *score > max_score { | |
| let label = &classes[idx.to_string()][1]; | |
| println!("Class index: {idx} ({label}), score: {score}"); | |
| max_score = *score; | |
| max_idx = idx; | |
| } | |
| } | |
| println!("Predicted class index: {}", max_idx); | |
| let label = &classes[max_idx.to_string()]; | |
| println!("Predicted class label: {}", label); | |
| Ok(()) | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import torchvision.models as models

# Fetch the pretrained ImageNet weights and switch to inference mode so
# export traces eval-time behavior (batch-norm in running-stats mode, etc.).
model = models.resnet18(pretrained=True)
model.eval()

# Dummy input fixing the traced tensor shapes; the batch dimension is
# declared dynamic below so any batch size works at runtime.
x = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    x,
    "resnet18.onnx",            # where to save the exported model
    export_params=True,         # embed the trained weights in the file
    opset_version=10,           # ONNX opset version to target
    do_constant_folding=True,   # fold constant subgraphs at export time
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'},
                  'output': {0: 'batch_size'}},
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment