Skip to content

Instantly share code, notes, and snippets.

@StefanoLusardi
Last active December 31, 2024 08:39
Show Gist options
  • Select an option

  • Save StefanoLusardi/7a7477f107a1ff0eda68ae135fe1705b to your computer and use it in GitHub Desktop.

Select an option

Save StefanoLusardi/7a7477f107a1ff0eda68ae135fe1705b to your computer and use it in GitHub Desktop.
ONNX Classifier (Resnet18 / Squeezenet)
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
#define STB_IMAGE_RESIZE_IMPLEMENTATION
#include "stb_image_resize2.h"
#include <onnxruntime/core/session/onnxruntime_cxx_api.h>
#include <iostream>
#include <vector>
#include <cmath>
#include <string>
#include <algorithm>
#include <numeric>
#include <memory>
#include <sstream>
void preprocess(const std::string &imageFilepath, const std::vector<int> &inputDims, std::vector<float> &output)
{
int width, height, channels;
unsigned char *imageData = stbi_load(imageFilepath.c_str(), &width, &height, &channels, 3);
if (!imageData)
{
throw std::runtime_error("Failed to load image: " + imageFilepath);
}
// Target dimensions
int targetHeight = inputDims[2];
int targetWidth = inputDims[3];
// Resized image buffer
std::vector<unsigned char> resizedImage(targetHeight * targetWidth * 3);
// Use stb_image_resize2 for resizing
stbir_resize_uint8_linear(
imageData, width, height, 0,
resizedImage.data(), targetWidth, targetHeight, 0,
stbir_pixel_layout(0));
stbi_image_free(imageData); // Free the original image buffer
// Convert to float and normalize
std::vector<float> normalizedImage(targetHeight * targetWidth * 3);
for (int i = 0; i < targetHeight * targetWidth * 3; ++i)
{
float value = resizedImage[i] / 255.0f; // Scale to [0, 1]
if (i % 3 == 0)
{ // R channel
normalizedImage[i] = (value - 0.485f) / 0.229f;
}
else if (i % 3 == 1)
{ // G channel
normalizedImage[i] = (value - 0.456f) / 0.224f;
}
else
{ // B channel
normalizedImage[i] = (value - 0.406f) / 0.225f;
}
}
// Convert HWC to CHW
output.resize(targetHeight * targetWidth * 3);
int imageSize = targetHeight * targetWidth;
for (int h = 0; h < targetHeight; ++h)
{
for (int w = 0; w < targetWidth; ++w)
{
output[0 * imageSize + h * targetWidth + w] = normalizedImage[(h * targetWidth + w) * 3 + 0]; // R
output[1 * imageSize + h * targetWidth + w] = normalizedImage[(h * targetWidth + w) * 3 + 1]; // G
output[2 * imageSize + h * targetWidth + w] = normalizedImage[(h * targetWidth + w) * 3 + 2]; // B
}
}
}
/// Format a tensor shape as "d0xd1x...xdn" (e.g. {1,3,224,224} -> "1x3x224x224").
/// FIX: the previous version computed v.size() - 1 on an unsigned type and
/// indexed v[v.size() - 1], which is undefined behavior for an empty vector
/// (a scalar/zero-rank tensor reports an empty shape). Returns "" for an
/// empty shape instead.
std::string print_shape(const std::vector<std::int64_t> &v)
{
    std::stringstream ss;
    for (std::size_t i = 0; i < v.size(); ++i)
    {
        if (i != 0)
            ss << "x"; // separator between dimensions only
        ss << v[i];
    }
    return ss.str();
}
/// ONNX Runtime logging hook: forwards each runtime log message to stdout,
/// one line per message, ignoring severity/category/location metadata.
/// (Registered via the commented-out Ort::Env constructor in main.)
void log_callback(void *param, OrtLoggingLevel severity, const char *category, const char *logid, const char *code_location, const char *message)
{
    std::cout << message << std::endl;
}
/// Classify a single image with an ONNX image-classification model
/// (SqueezeNet 1.1 by default, ResNet18 via the commented path):
/// load the model, introspect its inputs/outputs, preprocess the image,
/// run inference, and print one score per class.
int main()
{
    // const std::basic_string<ORTCHAR_T> model_path = "../../models/resnet18-v1-7.onnx";
    const std::basic_string<ORTCHAR_T> model_path = "../../models/squeezenet1.1-7.onnx";
    const std::string image_path = "../../img/cat.jpeg";
    const int input_width = 224;
    const int input_height = 224;

    // Ort::Env env(OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, "ONNX", &log_callback, nullptr);
    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ONNX");
    Ort::AllocatorWithDefaultOptions allocator;
    Ort::SessionOptions session_options;
    Ort::Session session(env, model_path.c_str(), session_options);

    // --- Inputs ---
    // input_names holds c_str() views into input_names_str, so the strings
    // must outlive every use of input_names (GetInputNameAllocated frees its
    // buffer when the returned smart pointer is destroyed).
    std::vector<const char *> input_names;
    std::vector<std::string> input_names_str;
    std::vector<std::vector<std::int64_t>> input_shapes;
    for (std::size_t i = 0; i < session.GetInputCount(); i++)
    {
        auto input_name = session.GetInputNameAllocated(i, allocator);
        // Hoist the repeated GetInputTypeInfo(i).GetTensorTypeAndShapeInfo() calls.
        const auto type_and_shape = session.GetInputTypeInfo(i).GetTensorTypeAndShapeInfo();
        auto input_shape = type_and_shape.GetShape();
        const auto input_type = type_and_shape.GetElementType();
        // Dynamic dimensions are reported as -1; clamp them to 1 (single batch).
        for (auto &s : input_shape)
        {
            if (s < 0)
                s = 1;
        }
        // Element count from the clamped shape. FIX: the previous code streamed
        // the TensorTypeAndShapeInfo object itself, which printed an opaque
        // pointer value via its implicit handle conversion, not a count.
        std::int64_t input_count = 1;
        for (const auto s : input_shape)
            input_count *= s;
        input_shapes.emplace_back(input_shape);
        input_names_str.emplace_back(input_name.get());
        std::cout
            << "Input: " << i
            << "\n - name: " << input_name.get() // FIX: stream the char*, not the smart pointer
            << "\n - shape: " << print_shape(input_shape)
            << "\n - element type: " << input_type
            << "\n - element count: " << input_count
            << std::endl;
    }
    for (auto &&s : input_names_str)
    {
        input_names.emplace_back(s.c_str());
    }

    // --- Outputs (same pattern and lifetime rules as the inputs) ---
    std::vector<const char *> output_names;
    std::vector<std::string> output_names_str;
    std::vector<std::vector<std::int64_t>> output_shapes;
    for (std::size_t i = 0; i < session.GetOutputCount(); i++)
    {
        auto output_name = session.GetOutputNameAllocated(i, allocator);
        const auto type_and_shape = session.GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo();
        auto output_shape = type_and_shape.GetShape();
        const auto output_type = type_and_shape.GetElementType();
        for (auto &s : output_shape)
        {
            if (s < 0)
                s = 1;
        }
        std::int64_t output_count = 1;
        for (const auto s : output_shape)
            output_count *= s;
        output_shapes.emplace_back(output_shape);
        output_names_str.emplace_back(output_name.get());
        std::cout
            << "Output: " << i
            << "\n - name: " << output_name.get()
            << "\n - shape: " << print_shape(output_shape)
            << "\n - element type: " << output_type
            << "\n - element count: " << output_count
            << std::endl;
    }
    for (auto &&s : output_names_str)
    {
        output_names.emplace_back(s.c_str());
    }

    // --- Preprocess ---
    std::vector<float> preprocessed_image;
    // NCHW order: {batch, channels, height, width}. FIX: the previous code
    // passed {1, 3, width, height}, which only worked because W == H == 224;
    // preprocess() reads dims[2] as height and dims[3] as width.
    std::vector<int> input_dims = {1, 3, input_height, input_width};
    preprocess(image_path, input_dims, preprocessed_image);

    // --- Inference ---
    Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    std::vector<Ort::Value> input_tensors;
    // The tensor borrows preprocessed_image's buffer; it stays alive through Run().
    input_tensors.emplace_back(
        Ort::Value::CreateTensor<float>(
            memory_info,
            preprocessed_image.data(),
            preprocessed_image.size(),
            input_shapes.at(0).data(),
            input_shapes.at(0).size()));
    std::vector<Ort::Value> output_tensors = session.Run(
        Ort::RunOptions{nullptr},
        input_names.data(),
        input_tensors.data(),
        input_names.size(),
        output_names.data(),
        output_names.size());

    // --- Print class scores ---
    // FIX: read the class count from the actual output tensor instead of a
    // hard-coded 999, which silently dropped the last of the 1000 ImageNet
    // classes (and would be wrong for any other model).
    const std::size_t num_classes = output_tensors.front().GetTensorTypeAndShapeInfo().GetElementCount();
    const float *scores = output_tensors.front().GetTensorMutableData<float>();
    for (std::size_t i = 0; i < num_classes; ++i)
    {
        std::cout << "Score for class [" << i << "] = " << scores[i] << '\n';
    }
    std::cout << std::flush;
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment