Created May 8, 2026 21:28
Save bigs/0cd42247989a9835c6c1a2457e5a6c01 to your computer and use it in GitHub Desktop.
ZAYA vLLM patches: streaming tool parser fix and explicit head_dim support
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| diff --git a/tests/tool_parsers/test_zaya_tool_parser.py b/tests/tool_parsers/test_zaya_tool_parser.py | |
| new file mode 100644 | |
| index 000000000..6a7f9a942 | |
| --- /dev/null | |
| +++ b/tests/tool_parsers/test_zaya_tool_parser.py | |
| @@ -0,0 +1,104 @@ | |
| +# SPDX-License-Identifier: Apache-2.0 | |
| +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project | |
| + | |
| +import json | |
| + | |
| +from vllm.entrypoints.openai.chat_completion.protocol import ( | |
| + ChatCompletionRequest, | |
| + ChatCompletionToolsParam, | |
| +) | |
| +from vllm.tool_parsers.zaya_tool_parser import ZayaXMLToolParser | |
| + | |
| + | |
| +def sample_tools(): | |
| + return [ | |
| + ChatCompletionToolsParam( | |
| + type="function", | |
| + function={ | |
| + "name": "get_current_weather", | |
| + "description": "Get the current weather", | |
| + "parameters": { | |
| + "type": "object", | |
| + "properties": { | |
| + "city": {"type": "string", "description": "The city name"}, | |
| + }, | |
| + "required": ["city"], | |
| + }, | |
| + }, | |
| + ), | |
| + ] | |
| + | |
| + | |
| +def test_zaya_streaming_tracks_tool_call_state_and_empty_final_delta(): | |
| + parser = ZayaXMLToolParser(tokenizer=None) # type: ignore[arg-type] | |
| + request = ChatCompletionRequest(model="zaya", messages=[], tools=sample_tools()) | |
| + chunks = [ | |
| + "<zyphra_tool_call>", | |
| + "<function=get_current_weather>", | |
| + "<parameter=city>", | |
| + "Paris", | |
| + "</parameter>", | |
| + "</function>", | |
| + "</zyphra_tool_call>", | |
| + ] | |
| + | |
| + previous_text = "" | |
| + previous_token_ids: list[int] = [] | |
| + tool_state = {"name": None, "arguments": ""} | |
| + | |
| + for idx, delta_text in enumerate(chunks): | |
| + delta_token_ids = [idx] | |
| + current_text = previous_text + delta_text | |
| + current_token_ids = previous_token_ids + delta_token_ids | |
| + | |
| + delta = parser.extract_tool_calls_streaming( | |
| + previous_text=previous_text, | |
| + current_text=current_text, | |
| + delta_text=delta_text, | |
| + previous_token_ids=previous_token_ids, | |
| + current_token_ids=current_token_ids, | |
| + delta_token_ids=delta_token_ids, | |
| + request=request, | |
| + ) | |
| + | |
| + if delta and delta.tool_calls: | |
| + for tool_call in delta.tool_calls: | |
| + if tool_call.function and tool_call.function.name: | |
| + tool_state["name"] = tool_call.function.name | |
| + if tool_call.function and tool_call.function.arguments is not None: | |
| + tool_state["arguments"] += tool_call.function.arguments | |
| + | |
| + previous_text = current_text | |
| + previous_token_ids = current_token_ids | |
| + | |
| + assert tool_state["name"] == "get_current_weather" | |
| + assert json.loads(tool_state["arguments"]) == {"city": "Paris"} | |
| + assert parser.prev_tool_call_arr == [ | |
| + { | |
| + "name": "get_current_weather", | |
| + "arguments": '{"city": "Paris"}', | |
| + } | |
| + ] | |
| + assert parser.streamed_args_for_tool == ['{"city": "Paris"}'] | |
| + | |
| + final_delta = parser.extract_tool_calls_streaming( | |
| + previous_text=previous_text, | |
| + current_text=previous_text, | |
| + delta_text="", | |
| + previous_token_ids=previous_token_ids, | |
| + current_token_ids=previous_token_ids, | |
| + delta_token_ids=[999], | |
| + request=request, | |
| + ) | |
| + | |
| + assert final_delta is not None | |
| + assert final_delta.content is None | |
| + assert final_delta.tool_calls is not None | |
| + assert final_delta.tool_calls[0].function is not None | |
| + assert final_delta.tool_calls[0].function.arguments == "" | |
| + | |
| + | |
| +def test_zaya_streaming_opts_out_of_final_argument_repair(): | |
| + parser = ZayaXMLToolParser(tokenizer=None) # type: ignore[arg-type] | |
| + | |
| + assert not parser.parser_should_check_for_unstreamed_tool_arg_tokens() | |
| diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py | |
| index adcd488a0..0ea8fd0e0 100644 | |
| --- a/vllm/entrypoints/openai/chat_completion/serving.py | |
| +++ b/vllm/entrypoints/openai/chat_completion/serving.py | |
| @@ -1899,6 +1899,13 @@ class OpenAIServingChat(OpenAIServing): | |
| is a tool call with arguments. | |
| """ | |
| + if self.tool_parser: | |
| + should_check = ( | |
| + self.tool_parser.parser_should_check_for_unstreamed_tool_arg_tokens() | |
| + ) | |
| + if not should_check: | |
| + return False | |
| + | |
| return bool( | |
| # if there is a delta message that includes tool calls which | |
| # include a function that has arguments | |
| diff --git a/vllm/tool_parsers/abstract_tool_parser.py b/vllm/tool_parsers/abstract_tool_parser.py | |
| index 75cffd329..ff00bb0fa 100644 | |
| --- a/vllm/tool_parsers/abstract_tool_parser.py | |
| +++ b/vllm/tool_parsers/abstract_tool_parser.py | |
| @@ -118,6 +118,14 @@ class ToolParser: | |
| "AbstractToolParser.extract_tool_calls_streaming has not been implemented!" | |
| ) | |
| + @staticmethod | |
| + def parser_should_check_for_unstreamed_tool_arg_tokens() -> bool: | |
| + """ | |
| + Whether serving should run its generic final-chunk recovery for argument | |
| + text that may have been parsed but not streamed yet. | |
| + """ | |
| + return True | |
| + | |
| class ToolParserManager: | |
| """ | |
| diff --git a/vllm/tool_parsers/zaya_tool_parser.py b/vllm/tool_parsers/zaya_tool_parser.py | |
| index 5818db7d5..442ae9924 100644 | |
| --- a/vllm/tool_parsers/zaya_tool_parser.py | |
| +++ b/vllm/tool_parsers/zaya_tool_parser.py | |
| @@ -1067,6 +1067,10 @@ class ZayaXMLToolParser(ToolParser): | |
| super().__init__(tokenizer) | |
| self.parser = StreamingXMLToolCallParser() | |
| + # Keep vLLM's streaming serving layer in sync with emitted tool deltas. | |
| + self.prev_tool_call_arr: list[dict] = [] | |
| + self.streamed_args_for_tool: list[str] = [] | |
| + | |
| logger.info("vLLM Successfully import tool parser %s !", | |
| self.__class__.__name__) | |
| @@ -1116,6 +1120,8 @@ class ZayaXMLToolParser(ToolParser): | |
| ) -> Union[DeltaMessage, None]: | |
| if not previous_text: | |
| self.parser.reset_streaming_state() | |
| + self.prev_tool_call_arr = [] | |
| + self.streamed_args_for_tool = [] | |
| if request: | |
| self.parser.set_tools(request.tools) | |
| @@ -1128,7 +1134,6 @@ class ZayaXMLToolParser(ToolParser): | |
| self.parser.tool_call_start_token) - current_text.count( | |
| self.parser.tool_call_end_token) | |
| if open_calls == 0 and self.parser.tool_call_index > 0: | |
| - # If current_call_id is None, use last_completed_call_id | |
| call_id = self.parser.current_call_id or \ | |
| self.parser.last_completed_call_id | |
| return DeltaMessage(tool_calls=[ | |
| @@ -1140,4 +1145,42 @@ class ZayaXMLToolParser(ToolParser): | |
| ) | |
| ]) | |
| - return self.parser.parse_single_streaming_chunks(delta_text) | |
| + result = self.parser.parse_single_streaming_chunks(delta_text) | |
| + if result and result.tool_calls: | |
| + for tool_call in result.tool_calls: | |
| + if not tool_call.function: | |
| + continue | |
| + | |
| + tool_index = ( | |
| + tool_call.index if tool_call.index is not None | |
| + else len(self.prev_tool_call_arr) - 1 | |
| + ) | |
| + | |
| + while len(self.prev_tool_call_arr) <= tool_index: | |
| + self.prev_tool_call_arr.append({"name": "", "arguments": ""}) | |
| + while len(self.streamed_args_for_tool) <= tool_index: | |
| + self.streamed_args_for_tool.append("") | |
| + | |
| + if tool_call.function.name: | |
| + self.prev_tool_call_arr[tool_index]["name"] = ( | |
| + tool_call.function.name | |
| + ) | |
| + | |
| + if tool_call.function.arguments is not None: | |
| + self.prev_tool_call_arr[tool_index]["arguments"] += ( | |
| + tool_call.function.arguments | |
| + ) | |
| + self.streamed_args_for_tool[tool_index] += ( | |
| + tool_call.function.arguments | |
| + ) | |
| + | |
| + return result | |
| + | |
| + @staticmethod | |
| + def parser_should_check_for_unstreamed_tool_arg_tokens() -> bool: | |
| + """ | |
| + Zaya XML streams its argument JSON directly as deltas, so the generic | |
| + final-chunk partial JSON repair would re-serialize already streamed | |
| + arguments and can duplicate/escape the tail. | |
| + """ | |
| + return False |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| diff --git a/vllm/model_executor/layers/mamba/cca.py b/vllm/model_executor/layers/mamba/cca.py | |
| index 8d0c3ce26..28b122431 100644 | |
| --- a/vllm/model_executor/layers/mamba/cca.py | |
| +++ b/vllm/model_executor/layers/mamba/cca.py | |
| @@ -64,8 +64,10 @@ class CCA(MambaBase, CustomOp): | |
| self.num_q_heads = int(cca_num_q_heads) | |
| self.num_heads = int(cca_num_heads) | |
| - # Geometry | |
| - self.head_dim = self.hidden_size // self.num_heads | |
| + # Geometry. Newer ZAYA configs carry an explicit head_dim; use it when | |
| + # present instead of deriving from the CCA head count. | |
| + self.head_dim = int(getattr(config, "head_dim", | |
| + self.hidden_size // self.num_heads)) | |
| self.latent_k_dim = self.num_k_heads * self.head_dim | |
| self.latent_q_dim = self.num_q_heads * self.head_dim | |
| self.sqrt_head_dim = np.sqrt(self.head_dim) | |
| diff --git a/vllm/model_executor/models/zaya.py b/vllm/model_executor/models/zaya.py | |
| index e86cf381f..5158a27a9 100644 | |
| --- a/vllm/model_executor/models/zaya.py | |
| +++ b/vllm/model_executor/models/zaya.py | |
| @@ -119,7 +119,8 @@ class ZayaAttention(nn.Module): | |
| self.cca_num_heads = config.num_attention_heads | |
| self.cca_time0 = config.cca_time0 | |
| self.cca_time1 = config.cca_time1 | |
| - self.head_dim = self.hidden_size // self.cca_num_heads | |
| + self.head_dim = getattr(config, "head_dim", | |
| + self.hidden_size // self.cca_num_heads) | |
| self.scale = self.head_dim**-0.5 | |
| self.qkv = CCA( | |
| @@ -702,7 +703,8 @@ class ZayaForCausalLM(nn.Module, HasInnerState, IsHybrid): | |
| conv_kernel_size = hf_config.cca_time0 | |
| num_k_heads = hf_config.num_query_groups | |
| num_q_heads = hf_config.cca_num_q_heads | |
| - head_dim = hf_config.hidden_size // hf_config.num_attention_heads | |
| + head_dim = getattr(hf_config, "head_dim", | |
| + hf_config.hidden_size // hf_config.num_attention_heads) | |
| hidden_size = hf_config.hidden_size | |
| return MambaStateShapeCalculator.cca_state_shape( | |
| diff --git a/vllm/transformers_utils/configs/zaya.py b/vllm/transformers_utils/configs/zaya.py | |
| index ff392811c..2a9577e67 100644 | |
| --- a/vllm/transformers_utils/configs/zaya.py | |
| +++ b/vllm/transformers_utils/configs/zaya.py | |
| @@ -20,6 +20,7 @@ class ZayaConfig(PretrainedConfig): | |
| num_hidden_layers=80, | |
| num_experts=16, | |
| num_attention_heads=16, | |
| + head_dim=None, | |
| activation_func='swiglu', | |
| max_position_embeddings=4096, | |
| norm_epsilon=1e-05, | |
| @@ -67,8 +68,11 @@ class ZayaConfig(PretrainedConfig): | |
| self.num_hidden_layers = num_hidden_layers | |
| self.num_experts = num_experts | |
| self.num_attention_heads = num_attention_heads | |
| - assert self.hidden_size % self.num_attention_heads == 0 | |
| - self.kv_channels = self.hidden_size // self.num_attention_heads | |
| + if head_dim is None: | |
| + assert self.hidden_size % self.num_attention_heads == 0 | |
| + head_dim = self.hidden_size // self.num_attention_heads | |
| + self.head_dim = head_dim | |
| + self.kv_channels = head_dim | |
| self.num_key_value_heads = num_key_value_heads | |
| self.activation_func = activation_func | |
| self.max_position_embeddings = max_position_embeddings |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.