Skip to content

Instantly share code, notes, and snippets.

@shinh
Created April 22, 2026 04:02
Show Gist options
  • Select an option

  • Save shinh/a958fff943b98e1b65464048483a9494 to your computer and use it in GitHub Desktop.

Select an option

Save shinh/a958fff943b98e1b65464048483a9494 to your computer and use it in GitHub Desktop.
Another bug in onnx-tool
#!/usr/bin/env python3
"""Reproduce scoring/runtime behavior for Sqrt vs Einsum(Sqrt) with a [1,1,1,1] tensor."""
from __future__ import annotations
from pathlib import Path
import sys
import numpy as np
import onnx
import onnxruntime as ort
from onnx import TensorProto, helper
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import neurogolf_utils as ng
# Shape of the graph's single input and output tensor (rank 4).
INPUT_SHAPE = [2, 3, 4, 5]
# All-ones shape with the same rank as INPUT_SHAPE, i.e. [1, 1, 1, 1].
ONES_SHAPE = [1] * len(INPUT_SHAPE)
def build_sqrt_model(path: Path) -> None:
    """Write a one-node ONNX model computing ``output = Sqrt(input)`` to *path*.

    Both the graph input ``input`` and the graph output ``output`` are declared
    as FLOAT tensors of shape INPUT_SHAPE.  The model imports the default ONNX
    operator domain at opset 13.
    """
    sqrt_node = helper.make_node("Sqrt", ["input"], ["output"])
    graph_input = helper.make_tensor_value_info("input", TensorProto.FLOAT, INPUT_SHAPE)
    graph_output = helper.make_tensor_value_info("output", TensorProto.FLOAT, INPUT_SHAPE)

    graph = helper.make_graph(
        nodes=[sqrt_node],
        name="sqrt_graph",
        inputs=[graph_input],
        outputs=[graph_output],
    )

    default_domain_opset = helper.make_operatorsetid("", 13)
    model = helper.make_model(
        graph,
        opset_imports=[default_domain_opset],
        producer_name="einsum_bug_repro",
    )
    onnx.save(model, path)
def build_einsum_sqrt_model(path: Path) -> None:
    """Write an ONNX model computing ``Sqrt(Einsum(input, ones))`` to *path*.

    ``ones`` is a FLOAT initializer of shape ONES_SHAPE filled with 1.0.  The
    Einsum equation ``ijkl,ijkl->ijkl`` keeps every index, so it is an
    elementwise product with no summation.  With an all-ones operand, the result
    should therefore match the plain Sqrt model.  This assumes the runtime
    broadcasts the size-1 ONES_SHAPE dims against INPUT_SHAPE; confirm this for
    the runtime in use.

    The graph input/output are FLOAT tensors of shape INPUT_SHAPE, and the model
    imports the default ONNX domain at opset 13.
    """
    # make_tensor takes a flat sequence of values plus the dims to reshape into.
    flat_ones = np.ones(np.prod(ONES_SHAPE), dtype=np.float32)
    ones_initializer = helper.make_tensor(
        name="ones",
        data_type=TensorProto.FLOAT,
        dims=ONES_SHAPE,
        vals=flat_ones,
    )

    einsum_node = helper.make_node(
        "Einsum",
        ["input", "ones"],
        ["temp"],
        equation="ijkl,ijkl->ijkl",
    )
    sqrt_node = helper.make_node("Sqrt", ["temp"], ["output"])
    graph_input = helper.make_tensor_value_info("input", TensorProto.FLOAT, INPUT_SHAPE)
    graph_output = helper.make_tensor_value_info("output", TensorProto.FLOAT, INPUT_SHAPE)

    graph = helper.make_graph(
        nodes=[einsum_node, sqrt_node],
        name="einsum_graph",
        inputs=[graph_input],
        outputs=[graph_output],
        initializer=[ones_initializer],
    )

    default_domain_opset = helper.make_operatorsetid("", 13)
    model = helper.make_model(
        graph,
        opset_imports=[default_domain_opset],
        producer_name="einsum_bug_repro",
    )
    onnx.save(model, path)
def run_onnxruntime(model_path: Path, x: np.ndarray) -> np.ndarray:
    """Run the model at *model_path* once on the CPU provider and return ``output``.

    *x* is fed as the graph input named ``input``.  Only the ``output`` tensor
    is requested from the session, and it is returned as-is.
    """
    cpu_only = ["CPUExecutionProvider"]
    sess = ort.InferenceSession(str(model_path), providers=cpu_only)
    outputs = sess.run(["output"], {"input": x})
    return outputs[0]
def main() -> None:
    """Build both models, run them on the same input, and print a comparison and scores.

    The two models are written under fixed names to the platform temporary
    directory returned by ``tempfile.gettempdir()``.  They are intentionally
    left on disk afterwards so they can be inspected with other tools.
    """
    import tempfile  # local import: only needed to locate the temp directory

    # Values start at 1, so every element is strictly positive before Sqrt.
    x = np.arange(1, np.prod(INPUT_SHAPE) + 1, dtype=np.float32).reshape(INPUT_SHAPE)

    # tempfile.gettempdir() honours TMPDIR/TEMP/TMP and works on Windows.  The
    # previous hard-coded Path("/tmp") resolved to a usually non-existent "\tmp"
    # there, which made onnx.save fail.
    tmpdir_path = Path(tempfile.gettempdir())
    sqrt_path = tmpdir_path / "sqrt_einsum_base.onnx"
    einsum_path = tmpdir_path / "sqrt_with_einsum.onnx"

    build_sqrt_model(sqrt_path)
    build_einsum_sqrt_model(einsum_path)

    sqrt_out = run_onnxruntime(sqrt_path, x)
    einsum_out = run_onnxruntime(einsum_path, x)
    print("Sqrt == Einsum(Sqrt):", np.array_equal(sqrt_out, einsum_out))

    # ng.score_network (from the local neurogolf_utils module) receives each
    # model's path as a string.
    print("Sqrt score:", ng.score_network(str(sqrt_path)))
    print("Einsum(Sqrt) score:", ng.score_network(str(einsum_path)))


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment