# transformers==4.52.4
# pytorch nightly
"""Export a Qwen3 embedding model to ONNX.

The script replaces transformers' SDPA attention with a simplified version,
exports the model (without pooling) via ``torch.onnx.export``, prefixes ONNX
node/value names with their module scopes, applies the onnxscript ORT fusions
and checks both exported programs against PyTorch.
"""

from __future__ import annotations

import ast
import logging
import typing
from pathlib import Path

import onnx_ir as ir
import onnxscript.rewriter.ort_fusions
import torch
import torch.onnx.testing
import transformers.integrations.sdpa_attention
from onnx_ir.passes import PassResult
from onnx_ir.passes.common import ClearMetadataAndDocStringPass
from transformers import AttentionInterface, AutoModel, AutoTokenizer

logger = logging.getLogger(__name__)

# Node metadata key holding the stringified list of module name scopes
# (written by the torch ONNX exporter).
_NAME_SCOPES_KEY = "pkg.torch.onnx.name_scopes"


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    dropout: float = 0.0,
    scaling: float | None = None,
    is_causal: bool | None = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    """Simplified SDPA attention installed in place of transformers' version.

    Key/value heads are repeated to match the query heads when ``module``
    defines ``num_key_value_groups``. ``attention_mask`` is handed to
    ``scaled_dot_product_attention`` unchanged.

    Args:
        module: The attention module; only ``num_key_value_groups`` is read.
        query, key, value: Attention inputs; ``repeat_kv`` expects 4-D
            (batch, heads, seq, head_dim) tensors for key/value.
        attention_mask: Mask passed straight through as ``attn_mask``.
        dropout: Dropout probability for SDPA.
        scaling: Softmax scale; ``None`` lets SDPA use its default.
        is_causal: Only an explicit ``True`` is honoured, and only when no
            mask is given; otherwise the mask alone defines visibility.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        ``(attn_output, None)`` where ``attn_output`` has dims 1 and 2 swapped
        relative to SDPA's output.
    """
    if hasattr(module, "num_key_value_groups"):
        key = repeat_kv(key, module.num_key_value_groups)
        value = repeat_kv(value, module.num_key_value_groups)

    # With an explicit mask, causality is already encoded in it (PyTorch
    # documents that setting both attn_mask and is_causal is an error). Without
    # a mask, fall back to SDPA's causal masking only if the caller asked.
    use_causal = attention_mask is None and is_causal is True

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=attention_mask,
        dropout_p=dropout,
        scale=scaling,
        is_causal=use_causal,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None


# Patch SDPA attention for transformers. Rebinding the module attribute only
# affects code that looks it up afterwards (not earlier `from ... import`s),
# so also register the function under "sdpa" via AttentionInterface, the
# registry transformers documents for overriding attention implementations.
transformers.integrations.sdpa_attention.sdpa_attention_forward = sdpa_attention_forward
AttentionInterface.register("sdpa", sdpa_attention_forward)


def _get_scoped_prefix(name_scopes: list[str]) -> str:
    """Join nested module scopes into a '/'-separated name prefix.

    Each scope after the first has its predecessor (and any leading '.')
    removed, e.g. ``["model", "model.layers.0", "model.layers.0.mlp"]``
    becomes ``"model/layers.0/mlp"``. An empty list yields ``""``.
    """
    if not name_scopes:
        return ""
    # Remove common prefixes between consecutive scopes
    relative_scopes = [
        scope.removeprefix(parent).lstrip(".")
        for parent, scope in zip(name_scopes, name_scopes[1:])
    ]
    return "/".join([name_scopes[0], *relative_scopes])


class AssignNamesPass(ir.passes.InPlacePass):
    """Prefix node and output-value names with their module scope path.

    For every node carrying ``pkg.torch.onnx.name_scopes`` metadata, the last
    scope (the node's own) is dropped and the remaining scopes are joined with
    ``_get_scoped_prefix``. The node name and the names of its non-graph-output
    values are then prefixed with that path. Graph outputs keep their names.
    """

    def call(self, model: ir.Model) -> PassResult:
        modified = False
        for node in model.graph.all_nodes():
            raw_scopes = node.metadata_props.get(_NAME_SCOPES_KEY)
            if raw_scopes is None:
                continue
            name_scopes = typing.cast("list[str]", ast.literal_eval(raw_scopes))
            if not name_scopes:
                # Nothing to pop; an empty scope list carries no prefix.
                continue
            name_scopes.pop()  # Remove self name
            prefix = _get_scoped_prefix(name_scopes)
            if not prefix:
                continue

            # Rename node
            node.name = f"{prefix}/{node.name}"
            modified = True

            # Rename outputs; graph outputs are the model's public interface.
            for output in node.outputs:
                if output.is_graph_output() or not output.name:
                    continue
                scoped_name = f"{prefix}/{output.name}"
                logger.debug("Renaming %r to %r", output.name, scoped_name)
                output.name = scoped_name
                modified = True
        return PassResult(model, modified)


class Qwen3EmbeddingONNXExporter:
    """Load a Qwen3 embedding model and export it to ONNX."""

    def __init__(self, model_id="Qwen/Qwen3-Embedding-0.6B"):
        self.model_id = model_id
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """Load the Qwen3 model and tokenizer"""
        print(f"Loading {self.model_id}...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(self.model_id, trust_remote_code=True)
        self.model.eval()
        print("Model loaded successfully!")

    def _ensure_loaded(self):
        """Call load_model() if the model or tokenizer has not been loaded yet."""
        if self.model is None or self.tokenizer is None:
            self.load_model()

    def create_dummy_inputs(self, batch_size=2, seq_length=128):
        """Create dummy inputs for ONNX export.

        Tokenizes ``batch_size`` copies of a fixed sentence, padded/truncated
        to exactly ``seq_length`` tokens, and returns the tokenizer output
        (PyTorch tensors). Requires the tokenizer to be loaded.
        """
        dummy_text = ["This is a sample text for ONNX export"] * batch_size
        return self.tokenizer(
            dummy_text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=seq_length,
        )

    def export_to_onnx(self, output_dir="./qwen3-onnx"):
        """Export the model to ONNX under ``output_dir``.

        Loads the model first if needed. Writes the tokenizer files, the model
        config, ``model_pre_fusion.onnx`` (scoped names, before fusions) and
        ``model.onnx`` (after ORT fusions, metadata/doc strings cleared). Each
        ONNX program is checked against PyTorch with
        ``torch.onnx.testing.assert_onnx_program``.
        """
        self._ensure_loaded()
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Save tokenizer and config
        print("Saving tokenizer and config...")
        self.tokenizer.save_pretrained(output_dir)
        self.model.config.save_pretrained(output_dir)

        # Create dummy inputs
        # NOTE(justinchuby): Batch size must be greater than 1 to be captured as dynamic
        dummy_inputs = self.create_dummy_inputs()
        input_ids = dummy_inputs["input_ids"]
        attention_mask = dummy_inputs["attention_mask"]

        # Export to ONNX
        print("Exporting to ONNX...")

        # Wrap the model WITHOUT pooling - TEI will handle pooling
        class ModelWrapper(torch.nn.Module):
            def __init__(self, model):
                super().__init__()
                self.model = model

            def forward(self, input_ids, attention_mask):
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                if hasattr(outputs, "last_hidden_state"):
                    return outputs.last_hidden_state
                return outputs[0]

        wrapped_model = ModelWrapper(self.model)
        wrapped_model.eval()

        onnx_program = torch.onnx.export(
            wrapped_model,
            (input_ids, attention_mask),
            input_names=["input_ids", "attention_mask"],
            output_names=["last_hidden_state"],
            dynamic_shapes={
                "input_ids": {0: "batch", 1: "seq"},
                "attention_mask": {0: "batch", 1: "seq"},
            },
            opset_version=21,
        )
        AssignNamesPass()(onnx_program.model)
        onnx_program.save(output_path / "model_pre_fusion.onnx")
        torch.onnx.testing.assert_onnx_program(onnx_program, atol=1e-4, rtol=1e-4)

        # Optimize for ORT; the second return value is printed as a summary
        # of the applied fusions.
        model, fusion = onnxscript.rewriter.ort_fusions.optimize_for_ort(onnx_program.model)
        print(fusion)

        # For production, remove metadata:
        result = ClearMetadataAndDocStringPass()(model)
        onnx_program.model = result.model
        onnx_program.save(output_path / "model.onnx")
        # Fusions may change numerics slightly, hence the looser tolerance.
        torch.onnx.testing.assert_onnx_program(onnx_program, atol=1e-3, rtol=1e-3)


def main():
    """Export Qwen/Qwen3-Embedding-0.6B to ./qwen3-onnx-1028."""
    exporter = Qwen3EmbeddingONNXExporter()
    exporter.load_model()
    exporter.export_to_onnx(output_dir="./qwen3-onnx-1028")


if __name__ == "__main__":
    main()