"""One-off ONNX graph-surgery script for the Perch v2 model.

Loads ``perch_v2_opt3.onnx`` and rewrites it:

* bypasses Reshape nodes feeding/following MatMul nodes (keeping one
  Reshape in front of ReduceMax, re-targeted to an explicit shape),
* removes Expand nodes by routing their consumers to their input,
* clears intermediate shapes and marks the leading dimension of the
  graph inputs/outputs as a symbolic ``"batch"`` dimension,
* runs the onnxscript optimizer and strips metadata/doc strings,
* renames and reorders the graph outputs to match the TFLite model,

then saves the result as ``perch_v2_opt4.onnx``.
"""

import onnx_ir as ir
import onnx_ir.passes.common
import onnxscript
# Explicit submodule import: `onnxscript.optimizer` is used below and
# attribute access alone may rely on implicit/lazy loading.
import onnxscript.optimizer

m = ir.load("perch_v2_opt3.onnx")

for node in m.graph:
    if node.op_type == "MatMul":
        print(node)
        producer = node.inputs[0].producer()
        if producer.op_type == "Reshape":
            # Bypass the Reshape: feed the MatMul directly from the
            # Reshape's own input.
            node.replace_input_with(0, producer.inputs[0])
        # Snapshot uses() before rewiring: replace_all_uses_with() below
        # adds new uses to node.outputs[0] while we are iterating them.
        for usage in list(node.outputs[0].uses()):
            if usage.node.op_type == "Reshape":
                reshape_usages = list(usage.node.outputs[0].uses())
                # Keep the last Reshape (the one feeding ReduceMax):
                # instead of removing it, give it an explicit target
                # shape as a new initializer.
                if reshape_usages[0].node.op_type == "ReduceMax":
                    target_shape = ir.val(
                        "reshape_shape",
                        const_value=ir.tensor([-1, 16, 4, 14795, 4]),
                    )
                    m.graph.initializers.add(target_shape)
                    usage.node.replace_input_with(1, target_shape)
                    continue
                # Otherwise drop the Reshape by forwarding all of its
                # consumers to the MatMul output.
                reshape_node = usage.node
                reshape_node.outputs[0].replace_all_uses_with(node.outputs[0])
    # Remove Expand nodes entirely: route their consumers to their input.
    if node.op_type == "Expand":
        print(node)
        node.outputs[0].replace_all_uses_with(node.inputs[0])

# Clean up any nodes orphaned by the rewiring above.
onnx_ir.passes.common.RemoveUnusedNodesPass()(m)

# Clear all intermediate shapes so the optimizer re-infers them.
for node in m.graph:
    for value in node.outputs:
        if value.is_graph_output():
            continue
        value.shape = None

# Make the leading dimension symbolic ("batch") on the graph I/O.
m.graph.inputs[0].shape = ir.Shape(["batch", *m.graph.inputs[0].shape[1:]])
for graph_output in m.graph.outputs:
    graph_output.shape = ir.Shape(["batch", *graph_output.shape[1:]])

# 1 GiB limits so constant folding is not skipped for large tensors.
onnxscript.optimizer.optimize(
    m, input_size_limit=1024 * 1024 * 1024, output_size_limit=1024 * 1024 * 1024
)
onnx_ir.passes.common.ClearMetadataAndDocStringPass()(m)

# Rename outputs to match the tflite model, then swap the first two so
# the final order is: embedding, spatial_embedding, spectrogram, label.
m.graph.outputs[0].name = "spatial_embedding"
m.graph.outputs[1].name = "embedding"
m.graph.outputs[2].name = "spectrogram"
m.graph.outputs[3].name = "label"
m.graph.outputs[0], m.graph.outputs[1] = m.graph.outputs[1], m.graph.outputs[0]

ir.save(m, "perch_v2_opt4.onnx")