Skip to content

Instantly share code, notes, and snippets.

@shinh
Last active April 18, 2026 03:58
Show Gist options
  • Select an option

  • Save shinh/34b3f6af69fa7a2cd115c84b5ad476b8 to your computer and use it in GitHub Desktop.

Select an option

Save shinh/34b3f6af69fa7a2cd115c84b5ad476b8 to your computer and use it in GitHub Desktop.
A bug in onnx-tool
#!/usr/bin/env python3
"""Reproduce scoring/runtime behavior for Sqrt vs Expand(Sqrt) with shape [30]."""
from __future__ import annotations
import tempfile
from pathlib import Path
import sys
import numpy as np
import onnx
import onnxruntime as ort
from onnx import TensorProto, helper
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import neurogolf_utils as ng
# Static shape shared by the graph input and the graph output of both repro
# models, and by the test tensor fed to them.
# NOTE(review): presumably (batch, channels, height, width) — confirm.
INPUT_SHAPE = [
    1,
    10,
    30,
    30,
]
def build_sqrt_model(path: Path) -> None:
    """Write a one-node ``Sqrt`` ONNX model to *path*.

    The graph (named ``sqrt_graph``) feeds the FLOAT tensor ``input`` straight
    into a single ``Sqrt`` node whose result is the FLOAT tensor ``output``.
    Both tensors are declared with the static shape ``INPUT_SHAPE``. The model
    imports the default ONNX domain at opset 13.

    Args:
        path: Destination file for the serialized model.
    """
    sqrt_node = helper.make_node("Sqrt", ["input"], ["output"])
    input_info = helper.make_tensor_value_info("input", TensorProto.FLOAT, INPUT_SHAPE)
    output_info = helper.make_tensor_value_info("output", TensorProto.FLOAT, INPUT_SHAPE)

    graph = helper.make_graph(
        nodes=[sqrt_node],
        name="sqrt_graph",
        inputs=[input_info],
        outputs=[output_info],
    )

    default_opset = helper.make_operatorsetid("", 13)
    model = helper.make_model(
        graph,
        producer_name="expand_bug_repro",
        opset_imports=[default_opset],
    )
    onnx.save(model, path)
def build_expand_sqrt_model(path: Path) -> None:
    """Write an ``Expand`` -> ``Sqrt`` ONNX model to *path*.

    ``input`` is expanded against the 1-D INT64 initializer ``shape`` (whose
    single value is 30) into an intermediate tensor ``temp``. ``temp`` then
    passes through ``Sqrt`` to produce ``output``. Input and output are both
    declared as FLOAT with shape ``INPUT_SHAPE``.

    Expand uses bidirectional (numpy-style) broadcasting, and the target
    ``[30]`` matches the trailing dimension of ``INPUT_SHAPE``. So the Expand
    should leave the shape unchanged, and the graph should compute the same
    values as the plain Sqrt model; the repro script checks exactly that.

    Args:
        path: Destination file for the serialized model.
    """
    target_shape = helper.make_tensor(
        name="shape",
        data_type=TensorProto.INT64,
        dims=[1],
        vals=[30],
    )

    expand_node = helper.make_node("Expand", ["input", "shape"], ["temp"])
    sqrt_node = helper.make_node("Sqrt", ["temp"], ["output"])
    input_info = helper.make_tensor_value_info("input", TensorProto.FLOAT, INPUT_SHAPE)
    output_info = helper.make_tensor_value_info("output", TensorProto.FLOAT, INPUT_SHAPE)

    graph = helper.make_graph(
        nodes=[expand_node, sqrt_node],
        name="expand_graph",
        inputs=[input_info],
        outputs=[output_info],
        initializer=[target_shape],
    )

    default_opset = helper.make_operatorsetid("", 13)
    model = helper.make_model(
        graph,
        producer_name="expand_bug_repro",
        opset_imports=[default_opset],
    )
    onnx.save(model, path)
def run_onnxruntime(model_path: Path, x: np.ndarray) -> np.ndarray:
    """Execute the model at *model_path* on ONNX Runtime's CPU provider.

    Args:
        model_path: Path to a serialized ONNX model with an input named
            ``input`` and an output named ``output``.
        x: Array bound to the ``input`` graph input.

    Returns:
        The value of the ``output`` graph output.
    """
    session = ort.InferenceSession(
        str(model_path),
        providers=["CPUExecutionProvider"],
    )
    fetches = ["output"]
    feeds = {"input": x}
    results = session.run(fetches, feeds)
    return results[0]
def main() -> None:
    """Compare the Sqrt and Expand(Sqrt) repro models at runtime and by score.

    Builds both models in a fresh private temporary directory. Runs each one
    through ONNX Runtime on the same deterministic input and prints whether
    the two outputs are exactly equal. Then prints the value
    ``neurogolf_utils.score_network`` returns for each model file. The
    temporary directory and the model files are removed on exit.
    """
    # Deterministic input 0, 1, 2, ... laid out in INPUT_SHAPE. Every value is
    # non-negative, so Sqrt never produces NaN and array_equal is meaningful.
    x = np.arange(np.prod(INPUT_SHAPE), dtype=np.float32).reshape(INPUT_SHAPE)

    # Use a private temporary directory instead of fixed names under "/tmp".
    # This keeps the script portable, avoids clobbering or racing on files
    # shared with other users or concurrent runs, and cleans up afterwards.
    # Everything that reads the model files must stay inside this block.
    with tempfile.TemporaryDirectory(prefix="expand_bug_repro_") as tmpdir:
        tmpdir_path = Path(tmpdir)
        sqrt_path = tmpdir_path / "sqrt.onnx"
        expand_path = tmpdir_path / "expand.onnx"

        build_sqrt_model(sqrt_path)
        build_expand_sqrt_model(expand_path)

        sqrt_out = run_onnxruntime(sqrt_path, x)
        expand_out = run_onnxruntime(expand_path, x)
        print("Sqrt == Expand:", np.array_equal(sqrt_out, expand_out))

        # The two graphs compute the same values; any score difference comes
        # from how score_network treats the extra Expand node.
        print("Sqrt score:", ng.score_network(str(sqrt_path)))
        print("Expand score:", ng.score_network(str(expand_path)))
# Run the repro only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment