Last active
December 31, 2024 08:39
-
-
Save StefanoLusardi/7a7477f107a1ff0eda68ae135fe1705b to your computer and use it in GitHub Desktop.
ONNX Classifier (Resnet18 / Squeezenet)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
#define STB_IMAGE_RESIZE_IMPLEMENTATION
#include "stb_image_resize2.h"

#include <onnxruntime/core/session/onnxruntime_cxx_api.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <iostream>
#include <memory>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
| void preprocess(const std::string &imageFilepath, const std::vector<int> &inputDims, std::vector<float> &output) | |
| { | |
| int width, height, channels; | |
| unsigned char *imageData = stbi_load(imageFilepath.c_str(), &width, &height, &channels, 3); | |
| if (!imageData) | |
| { | |
| throw std::runtime_error("Failed to load image: " + imageFilepath); | |
| } | |
| // Target dimensions | |
| int targetHeight = inputDims[2]; | |
| int targetWidth = inputDims[3]; | |
| // Resized image buffer | |
| std::vector<unsigned char> resizedImage(targetHeight * targetWidth * 3); | |
| // Use stb_image_resize2 for resizing | |
| stbir_resize_uint8_linear( | |
| imageData, width, height, 0, | |
| resizedImage.data(), targetWidth, targetHeight, 0, | |
| stbir_pixel_layout(0)); | |
| stbi_image_free(imageData); // Free the original image buffer | |
| // Convert to float and normalize | |
| std::vector<float> normalizedImage(targetHeight * targetWidth * 3); | |
| for (int i = 0; i < targetHeight * targetWidth * 3; ++i) | |
| { | |
| float value = resizedImage[i] / 255.0f; // Scale to [0, 1] | |
| if (i % 3 == 0) | |
| { // R channel | |
| normalizedImage[i] = (value - 0.485f) / 0.229f; | |
| } | |
| else if (i % 3 == 1) | |
| { // G channel | |
| normalizedImage[i] = (value - 0.456f) / 0.224f; | |
| } | |
| else | |
| { // B channel | |
| normalizedImage[i] = (value - 0.406f) / 0.225f; | |
| } | |
| } | |
| // Convert HWC to CHW | |
| output.resize(targetHeight * targetWidth * 3); | |
| int imageSize = targetHeight * targetWidth; | |
| for (int h = 0; h < targetHeight; ++h) | |
| { | |
| for (int w = 0; w < targetWidth; ++w) | |
| { | |
| output[0 * imageSize + h * targetWidth + w] = normalizedImage[(h * targetWidth + w) * 3 + 0]; // R | |
| output[1 * imageSize + h * targetWidth + w] = normalizedImage[(h * targetWidth + w) * 3 + 1]; // G | |
| output[2 * imageSize + h * targetWidth + w] = normalizedImage[(h * targetWidth + w) * 3 + 2]; // B | |
| } | |
| } | |
| } | |
/// Formats a shape as its dimensions joined by 'x', e.g. {1, 3, 224, 224}
/// becomes "1x3x224x224".
///
/// @param v Dimensions to format.
/// @return The joined string; an empty shape yields an empty string.
std::string print_shape(const std::vector<std::int64_t> &v)
{
    std::stringstream ss;
    for (std::size_t i = 0; i < v.size(); ++i)
    {
        // Emit the separator before every element except the first, so an
        // empty vector never evaluates size() - 1 (which wraps around).
        if (i != 0)
        {
            ss << "x";
        }
        ss << v[i];
    }
    return ss.str();
}
| void log_callback(void *param, OrtLoggingLevel severity, const char *category, const char *logid, const char *code_location, const char *message) | |
| { | |
| std::cout << "" << message << std::endl; | |
| } | |
| int main() | |
| { | |
| // const std::basic_string<ORTCHAR_T> model_path = "../../models/resnet18-v1-7.onnx"; | |
| const std::basic_string<ORTCHAR_T> model_path = "../../models/squeezenet1.1-7.onnx"; | |
| const std::string image_path = "../../img/cat.jpeg"; | |
| const int input_width = 224; | |
| const int input_height = 224; | |
| const float conf_threshold = 0.5f; | |
| // Ort::Env env(OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, "ONNX", &log_callback, nullptr); | |
| Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ONNX"); | |
| Ort::AllocatorWithDefaultOptions allocator; | |
| Ort::SessionOptions session_options; | |
| Ort::Session session(env, model_path.c_str(), session_options); | |
| // Inputs | |
| std::vector<const char *> input_names; | |
| std::vector<std::string> input_names_str; | |
| std::vector<std::vector<std::int64_t>> input_shapes; | |
| for (std::size_t i = 0; i < session.GetInputCount(); i++) | |
| { | |
| auto input_name = session.GetInputNameAllocated(i, allocator); | |
| auto input_shape = session.GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape(); | |
| auto input_type = session.GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementType(); | |
| auto input_count = session.GetInputTypeInfo(i).GetTensorTypeAndShapeInfo(); | |
| // const auto input_count = session.GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementCount(); | |
| for (auto& s : input_shape) | |
| { | |
| if(s < 0) | |
| s = 1; | |
| } | |
| input_shapes.emplace_back(input_shape); | |
| input_names_str.emplace_back(input_name.get()); | |
| std::cout | |
| << "Input: " << i | |
| << "\n - name: " << input_name | |
| << "\n - shape: " << print_shape(input_shape) | |
| << "\n - element type: " << input_type | |
| << "\n - element count: " << input_count | |
| << std::endl; | |
| } | |
| for (auto&& s : input_names_str) | |
| { | |
| input_names.emplace_back(s.c_str()); | |
| } | |
| // Outputs | |
| std::vector<const char *> output_names; | |
| std::vector<std::string> output_names_str; | |
| std::vector<std::vector<std::int64_t>> output_shapes; | |
| for (std::size_t i = 0; i < session.GetOutputCount(); i++) | |
| { | |
| auto output_name = session.GetOutputNameAllocated(i, allocator); | |
| auto output_shape = session.GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape(); | |
| auto output_type = session.GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementType(); | |
| auto output_count = session.GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo(); | |
| // const auto output_count = session.GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementCount(); | |
| for (auto& s : output_shape) | |
| { | |
| if(s < 0) | |
| s = 1; | |
| } | |
| output_shapes.emplace_back(output_shape); | |
| output_names_str.emplace_back(output_name.get()); | |
| std::cout | |
| << "Output: " << i | |
| << "\n - name: " << output_name | |
| << "\n - shape: " << print_shape(output_shape) | |
| << "\n - element type: " << output_type | |
| << "\n - element count: " << output_count | |
| << std::endl; | |
| } | |
| for (auto&& s : output_names_str) | |
| { | |
| output_names.emplace_back(s.c_str()); | |
| } | |
| std::vector<float> preprocessed_image; | |
| std::vector<int> input_dims = {1, 3, input_width, input_height}; | |
| preprocess(image_path, input_dims, preprocessed_image); | |
| Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); | |
| std::vector<Ort::Value> input_tensors; | |
| input_tensors.emplace_back( | |
| Ort::Value::CreateTensor<float>( | |
| memory_info, | |
| preprocessed_image.data(), | |
| preprocessed_image.size(), | |
| input_shapes.at(0).data(), | |
| input_shapes.at(0).size())); | |
| std::vector<Ort::Value> output_tensors = session.Run( | |
| Ort::RunOptions{nullptr}, | |
| input_names.data(), | |
| input_tensors.data(), | |
| input_names.size(), | |
| output_names.data(), | |
| output_names.size() | |
| ); | |
| float *floatarr = output_tensors.front().GetTensorMutableData<float>(); | |
| for (int i = 0; i < 999; i++) | |
| { | |
| std::cout << "Score for class [" << i << "] = " << floatarr[i] << '\n'; | |
| } | |
| std::cout << std::flush; | |
| return 0; | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment