GitHub Gist: bigs/0cd42247989a9835c6c1a2457e5a6c01
ZAYA vLLM patches: streaming tool parser fix and explicit head_dim support
diff --git a/tests/tool_parsers/test_zaya_tool_parser.py b/tests/tool_parsers/test_zaya_tool_parser.py
new file mode 100644
index 000000000..6a7f9a942
--- /dev/null
+++ b/tests/tool_parsers/test_zaya_tool_parser.py
@@ -0,0 +1,104 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+
+from vllm.entrypoints.openai.chat_completion.protocol import (
+ ChatCompletionRequest,
+ ChatCompletionToolsParam,
+)
+from vllm.tool_parsers.zaya_tool_parser import ZayaXMLToolParser
+
+
+def sample_tools():
+ return [
+ ChatCompletionToolsParam(
+ type="function",
+ function={
+ "name": "get_current_weather",
+ "description": "Get the current weather",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "city": {"type": "string", "description": "The city name"},
+ },
+ "required": ["city"],
+ },
+ },
+ ),
+ ]
+
+
+def test_zaya_streaming_tracks_tool_call_state_and_empty_final_delta():
+ parser = ZayaXMLToolParser(tokenizer=None) # type: ignore[arg-type]
+ request = ChatCompletionRequest(model="zaya", messages=[], tools=sample_tools())
+ chunks = [
+ "<zyphra_tool_call>",
+ "<function=get_current_weather>",
+ "<parameter=city>",
+ "Paris",
+ "</parameter>",
+ "</function>",
+ "</zyphra_tool_call>",
+ ]
+
+ previous_text = ""
+ previous_token_ids: list[int] = []
+ tool_state = {"name": None, "arguments": ""}
+
+ for idx, delta_text in enumerate(chunks):
+ delta_token_ids = [idx]
+ current_text = previous_text + delta_text
+ current_token_ids = previous_token_ids + delta_token_ids
+
+ delta = parser.extract_tool_calls_streaming(
+ previous_text=previous_text,
+ current_text=current_text,
+ delta_text=delta_text,
+ previous_token_ids=previous_token_ids,
+ current_token_ids=current_token_ids,
+ delta_token_ids=delta_token_ids,
+ request=request,
+ )
+
+ if delta and delta.tool_calls:
+ for tool_call in delta.tool_calls:
+ if tool_call.function and tool_call.function.name:
+ tool_state["name"] = tool_call.function.name
+ if tool_call.function and tool_call.function.arguments is not None:
+ tool_state["arguments"] += tool_call.function.arguments
+
+ previous_text = current_text
+ previous_token_ids = current_token_ids
+
+ assert tool_state["name"] == "get_current_weather"
+ assert json.loads(tool_state["arguments"]) == {"city": "Paris"}
+ assert parser.prev_tool_call_arr == [
+ {
+ "name": "get_current_weather",
+ "arguments": '{"city": "Paris"}',
+ }
+ ]
+ assert parser.streamed_args_for_tool == ['{"city": "Paris"}']
+
+ final_delta = parser.extract_tool_calls_streaming(
+ previous_text=previous_text,
+ current_text=previous_text,
+ delta_text="",
+ previous_token_ids=previous_token_ids,
+ current_token_ids=previous_token_ids,
+ delta_token_ids=[999],
+ request=request,
+ )
+
+ assert final_delta is not None
+ assert final_delta.content is None
+ assert final_delta.tool_calls is not None
+ assert final_delta.tool_calls[0].function is not None
+ assert final_delta.tool_calls[0].function.arguments == ""
+
+
+def test_zaya_streaming_opts_out_of_final_argument_repair():
+ parser = ZayaXMLToolParser(tokenizer=None) # type: ignore[arg-type]
+
+ assert not parser.parser_should_check_for_unstreamed_tool_arg_tokens()
diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py
index adcd488a0..0ea8fd0e0 100644
--- a/vllm/entrypoints/openai/chat_completion/serving.py
+++ b/vllm/entrypoints/openai/chat_completion/serving.py
@@ -1899,6 +1899,13 @@ class OpenAIServingChat(OpenAIServing):
is a tool call with arguments.
"""
+ if self.tool_parser:
+ should_check = (
+ self.tool_parser.parser_should_check_for_unstreamed_tool_arg_tokens()
+ )
+ if not should_check:
+ return False
+
return bool(
# if there is a delta message that includes tool calls which
# include a function that has arguments
diff --git a/vllm/tool_parsers/abstract_tool_parser.py b/vllm/tool_parsers/abstract_tool_parser.py
index 75cffd329..ff00bb0fa 100644
--- a/vllm/tool_parsers/abstract_tool_parser.py
+++ b/vllm/tool_parsers/abstract_tool_parser.py
@@ -118,6 +118,14 @@ class ToolParser:
"AbstractToolParser.extract_tool_calls_streaming has not been implemented!"
)
+ @staticmethod
+ def parser_should_check_for_unstreamed_tool_arg_tokens() -> bool:
+ """
+ Whether serving should run its generic final-chunk recovery for argument
+ text that may have been parsed but not streamed yet.
+ """
+ return True
+
class ToolParserManager:
"""
diff --git a/vllm/tool_parsers/zaya_tool_parser.py b/vllm/tool_parsers/zaya_tool_parser.py
index 5818db7d5..442ae9924 100644
--- a/vllm/tool_parsers/zaya_tool_parser.py
+++ b/vllm/tool_parsers/zaya_tool_parser.py
@@ -1067,6 +1067,10 @@ class ZayaXMLToolParser(ToolParser):
super().__init__(tokenizer)
self.parser = StreamingXMLToolCallParser()
+ # Keep vLLM's streaming serving layer in sync with emitted tool deltas.
+ self.prev_tool_call_arr: list[dict] = []
+ self.streamed_args_for_tool: list[str] = []
+
logger.info("vLLM Successfully import tool parser %s !",
self.__class__.__name__)
@@ -1116,6 +1120,8 @@ class ZayaXMLToolParser(ToolParser):
) -> Union[DeltaMessage, None]:
if not previous_text:
self.parser.reset_streaming_state()
+ self.prev_tool_call_arr = []
+ self.streamed_args_for_tool = []
if request:
self.parser.set_tools(request.tools)
@@ -1128,7 +1134,6 @@ class ZayaXMLToolParser(ToolParser):
self.parser.tool_call_start_token) - current_text.count(
self.parser.tool_call_end_token)
if open_calls == 0 and self.parser.tool_call_index > 0:
- # If current_call_id is None, use last_completed_call_id
call_id = self.parser.current_call_id or \
self.parser.last_completed_call_id
return DeltaMessage(tool_calls=[
@@ -1140,4 +1145,42 @@ class ZayaXMLToolParser(ToolParser):
)
])
- return self.parser.parse_single_streaming_chunks(delta_text)
+ result = self.parser.parse_single_streaming_chunks(delta_text)
+ if result and result.tool_calls:
+ for tool_call in result.tool_calls:
+ if not tool_call.function:
+ continue
+
+ tool_index = (
+ tool_call.index if tool_call.index is not None
+ else len(self.prev_tool_call_arr) - 1
+ )
+
+ while len(self.prev_tool_call_arr) <= tool_index:
+ self.prev_tool_call_arr.append({"name": "", "arguments": ""})
+ while len(self.streamed_args_for_tool) <= tool_index:
+ self.streamed_args_for_tool.append("")
+
+ if tool_call.function.name:
+ self.prev_tool_call_arr[tool_index]["name"] = (
+ tool_call.function.name
+ )
+
+ if tool_call.function.arguments is not None:
+ self.prev_tool_call_arr[tool_index]["arguments"] += (
+ tool_call.function.arguments
+ )
+ self.streamed_args_for_tool[tool_index] += (
+ tool_call.function.arguments
+ )
+
+ return result
+
+ @staticmethod
+ def parser_should_check_for_unstreamed_tool_arg_tokens() -> bool:
+ """
+ Zaya XML streams its argument JSON directly as deltas, so the generic
+ final-chunk partial JSON repair would re-serialize already streamed
+ arguments and can duplicate/escape the tail.
+ """
+ return False
diff --git a/vllm/model_executor/layers/mamba/cca.py b/vllm/model_executor/layers/mamba/cca.py
index 8d0c3ce26..28b122431 100644
--- a/vllm/model_executor/layers/mamba/cca.py
+++ b/vllm/model_executor/layers/mamba/cca.py
@@ -64,8 +64,10 @@ class CCA(MambaBase, CustomOp):
self.num_q_heads = int(cca_num_q_heads)
self.num_heads = int(cca_num_heads)
- # Geometry
- self.head_dim = self.hidden_size // self.num_heads
+ # Geometry. Newer ZAYA configs carry an explicit head_dim; use it when
+ # present instead of deriving from the CCA head count.
+ self.head_dim = int(getattr(config, "head_dim",
+ self.hidden_size // self.num_heads))
self.latent_k_dim = self.num_k_heads * self.head_dim
self.latent_q_dim = self.num_q_heads * self.head_dim
self.sqrt_head_dim = np.sqrt(self.head_dim)
diff --git a/vllm/model_executor/models/zaya.py b/vllm/model_executor/models/zaya.py
index e86cf381f..5158a27a9 100644
--- a/vllm/model_executor/models/zaya.py
+++ b/vllm/model_executor/models/zaya.py
@@ -119,7 +119,8 @@ class ZayaAttention(nn.Module):
self.cca_num_heads = config.num_attention_heads
self.cca_time0 = config.cca_time0
self.cca_time1 = config.cca_time1
- self.head_dim = self.hidden_size // self.cca_num_heads
+ self.head_dim = getattr(config, "head_dim",
+ self.hidden_size // self.cca_num_heads)
self.scale = self.head_dim**-0.5
self.qkv = CCA(
@@ -702,7 +703,8 @@ class ZayaForCausalLM(nn.Module, HasInnerState, IsHybrid):
conv_kernel_size = hf_config.cca_time0
num_k_heads = hf_config.num_query_groups
num_q_heads = hf_config.cca_num_q_heads
- head_dim = hf_config.hidden_size // hf_config.num_attention_heads
+ head_dim = getattr(hf_config, "head_dim",
+ hf_config.hidden_size // hf_config.num_attention_heads)
hidden_size = hf_config.hidden_size
return MambaStateShapeCalculator.cca_state_shape(
diff --git a/vllm/transformers_utils/configs/zaya.py b/vllm/transformers_utils/configs/zaya.py
index ff392811c..2a9577e67 100644
--- a/vllm/transformers_utils/configs/zaya.py
+++ b/vllm/transformers_utils/configs/zaya.py
@@ -20,6 +20,7 @@ class ZayaConfig(PretrainedConfig):
num_hidden_layers=80,
num_experts=16,
num_attention_heads=16,
+ head_dim=None,
activation_func='swiglu',
max_position_embeddings=4096,
norm_epsilon=1e-05,
@@ -67,8 +68,11 @@ class ZayaConfig(PretrainedConfig):
self.num_hidden_layers = num_hidden_layers
self.num_experts = num_experts
self.num_attention_heads = num_attention_heads
- assert self.hidden_size % self.num_attention_heads == 0
- self.kv_channels = self.hidden_size // self.num_attention_heads
+ if head_dim is None:
+ assert self.hidden_size % self.num_attention_heads == 0
+ head_dim = self.hidden_size // self.num_attention_heads
+ self.head_dim = head_dim
+ self.kv_channels = head_dim
self.num_key_value_heads = num_key_value_heads
self.activation_func = activation_func
self.max_position_embeddings = max_position_embeddings