@hashkanna
Created January 28, 2026 06:15
Test Plan for /v1/score API (SGLang-JAX)
Overview
The JAX version currently lacks dedicated tests for the /v1/score API, while the PyTorch version has comprehensive coverage (14 test methods plus integration tests). This plan aims to achieve feature parity and to add JAX-specific tests.
---
Phase 1: Core Functionality Tests
Location: test/srt/test_score_api.py
Test Class: TestScoreAPI
1.1 Accuracy & Consistency Tests
test_score_consistency_with_hf
- Compare JAX scoring results against HuggingFace reference implementation
- Test cases:
- Default case: query="I pledge allegiance", items=["", " to"]
- Item-first case: query=" is a city", items=["Tokyo", "Japan"], item_first=True
- Validation: 1% relative tolerance (matching PyTorch)
- JAX-specific: Verify on-device logprob extraction produces correct results
test_score_numerical_stability (NEW - JAX specific)
- Test with bfloat16 vs float32 precision
- Verify numerical stability across different batch sizes
- Check that JAX sharding doesn't affect accuracy
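The precision check above can be sketched without loading a model: quantize the label-token logprobs to a lower precision and bound the relative error of the resulting label softmax. A minimal sketch, using float16 as a stand-in since NumPy has no bfloat16 (names and shapes are illustrative):

```python
import numpy as np

def label_softmax(logprobs: np.ndarray) -> np.ndarray:
    """Softmax restricted to the label-token logprobs (one row per item)."""
    shifted = logprobs - logprobs.max(axis=-1, keepdims=True)  # stability shift
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
logprobs = rng.normal(scale=0.5, size=(8, 2)).astype(np.float64)  # 8 items, 2 labels

hi = label_softmax(logprobs)
# Emulate the low-precision path by quantizing the inputs to float16
lo = label_softmax(logprobs.astype(np.float16).astype(np.float64))

rel_err = np.abs(hi - lo) / np.abs(hi)
assert rel_err.max() < 0.01  # within the plan's 1% relative tolerance
```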
1.2 Batch Processing Tests
test_score_batch_handling
- Test batch sizes: 1, 2, 4, 8, 16, 32
- Validate:
- Correct number of outputs
- Proper list structure: list[list[float]]
- All scores are floats in [0, 1]
- Scores sum to 1.0 (when apply_softmax=True)
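The validation bullets above translate directly into assertions. A hedged sketch of a structure check (the helper name is illustrative, not an actual test-suite utility):

```python
def check_score_output(scores, num_items, num_labels, apply_softmax=True):
    """Validate the structural invariants listed above: list[list[float]],
    correct dimensions, [0, 1] range and sum-to-1 when softmax is applied."""
    assert isinstance(scores, list) and len(scores) == num_items
    for row in scores:
        assert isinstance(row, list) and len(row) == num_labels
        for s in row:
            assert isinstance(s, float)
            if apply_softmax:
                assert 0.0 <= s <= 1.0
        if apply_softmax:
            assert abs(sum(row) - 1.0) < 1e-6

check_score_output([[0.7, 0.3], [0.2, 0.8]], num_items=2, num_labels=2)
```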
test_score_large_batch
- Test with 20+ items (stress test)
- Verify memory efficiency
test_score_prefill_only_optimization (NEW - JAX specific)
- Verify requests use max_new_tokens=0 internally
- Confirm is_prefill_only=True flag is set
- Check that no decode pass occurs (performance optimization)
1.3 Input Format Tests
test_score_text_input
- Query: string, Items: list of strings
- Label token IDs: [1, 2, 3]
test_score_token_input
- Query: list[int], Items: list[list[int]]
- Pre-tokenized inputs
test_score_mixed_input
- Mix text and token inputs where allowed
test_score_single_item
- Edge case: 1 item only
test_score_empty_item
- Edge case: Empty string in items list
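For reference, the two request shapes exercised above differ only in how query and items are encoded. Illustrative payloads (the token ids are made-up examples, not real tokenizer output):

```python
# Text input: query is a string, items are strings
text_request = {
    "query": "The capital of France is",
    "items": ["Paris", "London"],
    "label_token_ids": [9454, 2753],
}

# Token input: query is list[int], items are list[list[int]] (pre-tokenized)
token_request = {
    "query": [785, 6722, 315],   # illustrative token ids only
    "items": [[12095], [11157]],
    "label_token_ids": [9454, 2753],
}
```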
1.4 Feature Tests
test_score_apply_softmax
- Test apply_softmax=True (normalized probabilities)
- Test apply_softmax=False (raw logprobs)
- Verify difference between modes
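The expected relationship between the two modes can be checked numerically: with apply_softmax=False the scores are raw log-probabilities (non-positive), and softmax over that label set recovers the normalized mode. A small sketch with made-up values:

```python
import math

raw_logprobs = [-0.4, -1.6]  # example label-token logprobs (apply_softmax=False)

# apply_softmax=True renormalizes over the label set only
denom = sum(math.exp(lp) for lp in raw_logprobs)
probs = [math.exp(lp) / denom for lp in raw_logprobs]

assert all(lp <= 0 for lp in raw_logprobs)   # raw mode: logprobs
assert abs(sum(probs) - 1.0) < 1e-12         # softmax mode: sums to 1
assert probs[0] > probs[1]                   # ordering is preserved
```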
test_score_item_first
- Test item_first=False: f"{query}{item}"
- Test item_first=True: f"{item}{query}"
- Verify results differ appropriately
test_score_different_label_tokens
- Test with 1, 2, 4, 8, 16 label token IDs
- Verify output dimensions match input
test_score_unicode_multilingual
- Test with Unicode characters
- Test with non-English languages
- Ensure tokenizer handles correctly
1.5 Error Handling Tests
test_score_invalid_token_ids
- Invalid token IDs (negative, > vocab_size)
- Should return error or handle gracefully
test_score_missing_parameters
- Missing label_token_ids
- Missing query or items
- Should raise ValidationError
test_score_empty_inputs
- Empty items list
- Empty label_token_ids
---
Phase 2: HTTP Endpoint Tests
Location: test/srt/openai_server/basic/test_openai_server.py
Test Class: TestOpenAIV1Score (Add to existing file)
2.1 HTTP Integration Tests
test_v1_score_text_input
```python
def test_v1_score_text_input(self):
    response = requests.post(
        f"{self.base_url}/v1/score",
        json={
            "model": self.model,
            "query": "The capital of France is",
            "items": ["Paris", "London", "Berlin"],
            "label_token_ids": [9454, 2753],  # Yes/No tokens
            "apply_softmax": True,
        },
    )
    self.assertEqual(response.status_code, 200)
    data = response.json()
    self.assertEqual(data["object"], "scoring")
    self.assertEqual(len(data["scores"]), 3)  # one score list per item
```
test_v1_score_token_input
- Pre-tokenized inputs via HTTP
test_v1_score_error_handling
- Test 400 errors for invalid inputs
- Test 422 for validation errors
test_v1_score_usage_info
- Verify usage field is populated correctly
- Check prompt_tokens, completion_tokens, total_tokens
2.2 OpenAI Client Tests
test_score_with_openai_client (if supported)
- Use OpenAI Python client library
- Verify compatibility
---
Phase 3: Protocol Validation Tests
Location: test/srt/openai_server/basic/test_protocol.py
Test Class: TestScoringProtocol (Add to existing file)
test_scoring_request_validation
```python
def test_scoring_request_validation(self):
    # Valid request
    request = ScoringRequest(
        model="test-model",
        query="test query",
        items=["item1", "item2"],
        label_token_ids=[1, 2, 3],
        apply_softmax=True,
        item_first=False,
    )
    self.assertEqual(request.query, "test query")

    # Missing required fields
    with self.assertRaises(ValidationError):
        ScoringRequest(model="test-model")
```
test_scoring_response_serialization
- Test JSON serialization/deserialization
- Verify exclude_none=True behavior
test_scoring_request_type_hints
- Test modern Python 3.10+ type hints work correctly
- str | list[int] | None vs Optional[Union[str, List[int]]]
---
Phase 4: JAX-Specific Tests
Location: test/srt/test_score_api_jax_features.py (NEW)
Test Class: TestScoreAPIJAXFeatures
4.1 On-Device Extraction Tests
test_score_on_device_extraction
- Verify logprobs.at[i, token_ids].get() pattern is used
- Check sharding is correct (NamedSharding)
- Validate no unnecessary host transfers
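The gather pattern named above can be illustrated in isolation. A minimal sketch (shapes and token ids are made up, and this assumes a working `jax` install):

```python
import jax
import jax.numpy as jnp

vocab_size, num_labels = 16, 2
logits = jnp.arange(vocab_size, dtype=jnp.float32)[None, :]  # (1, vocab)
logprobs = jax.nn.log_softmax(logits, axis=-1)

label_ids = jnp.array([9, 2])
row = 0

# On-device gather: only the label-token logprobs are extracted,
# O(num_labels) per item instead of transferring the full O(vocab) row.
selected = logprobs.at[row, label_ids].get()

assert selected.shape == (num_labels,)
```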
test_score_sharding_correctness
- Test with TPU/multi-device setup
- Verify scores are identical across sharding strategies
4.2 Performance Tests
test_score_prefill_only_performance
- Benchmark with/without prefill-only optimization
- Verify no decode phase occurs
- Measure latency improvement
test_score_memory_efficiency
- Verify only target token logprobs are extracted (not full vocab)
- Memory usage: O(N) vs O(vocab_size)
4.3 Integration with Radix Cache
test_score_with_prefix_caching
- Run multiple scoring requests with same prefix
- Verify radix cache is utilized
- Check cache_miss_count metrics
test_score_parallel_sampling_caching
- Test parallel sampling use case (tokenizer_manager.py:545-554)
- Verify prefix is cached once, reused for multiple samples
---
Phase 5: Benchmark & Utilities
Location: test/srt/bench_score.py (NEW)
Performance Benchmarking Tool
```python
# Measure throughput:
# - Requests per second
# - Latency (p50, p90, p99)
# - Token processing rate
#
# Test scenarios:
# - Varying batch sizes: 1, 4, 16, 64
# - Varying label token counts: 2, 8, 32
# - Different query lengths: 50, 200, 1000 tokens
```
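The latency percentiles for the benchmark can be computed with the standard library. A minimal sketch (the timing loop and request function are placeholders; only the percentile math is shown):

```python
import statistics

def latency_percentiles(latencies_ms: list[float]) -> dict[str, float]:
    """p50/p90/p99 from a list of per-request latencies (milliseconds)."""
    qs = statistics.quantiles(latencies_ms, n=100)  # 99 cut points: q1..q99
    return {"p50": qs[49], "p90": qs[89], "p99": qs[98]}

# Example with synthetic latencies
lat = [10.0 + 0.5 * i for i in range(200)]
pcts = latency_percentiles(lat)
assert pcts["p50"] <= pcts["p90"] <= pcts["p99"]
```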
---
Test Infrastructure Setup
File Structure
```
test/srt/
├── test_score_api.py                 # NEW: Core functionality tests
├── test_score_api_jax_features.py    # NEW: JAX-specific tests
├── bench_score.py                    # NEW: Performance benchmarks
└── openai_server/basic/
    ├── test_openai_server.py         # MODIFIED: Add TestOpenAIV1Score class
    └── test_protocol.py              # MODIFIED: Add TestScoringProtocol class
```
Test Utilities (in test_score_api.py)
```python
def compute_hf_reference_scores(
    query: str,
    items: list[str],
    label_token_ids: list[int],
    model_name: str,
    apply_softmax: bool = True,
    item_first: bool = False,
) -> list[list[float]]:
    """Generate reference scores using a HuggingFace model."""
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    scores = []
    for item in items:
        text = f"{item}{query}" if item_first else f"{query}{item}"
        input_ids = tokenizer(text, return_tensors="pt").input_ids
        with torch.no_grad():
            logits = model(input_ids).logits[0, -1]  # next-token logits
        # Extract only the target label-token logprobs
        logprobs = torch.log_softmax(logits, dim=-1)[label_token_ids]
        if apply_softmax:
            # Renormalize over the label set only
            scores.append(torch.softmax(logprobs, dim=-1).tolist())
        else:
            scores.append(logprobs.tolist())
    return scores


def compare_scores(
    reference_scores: list[list[float]],
    test_scores: list[list[float]],
    tolerance: float = 0.01,
    apply_softmax: bool = True,  # gates the [0, 1] and sum-to-1 checks
) -> None:
    """Compare scores against a reference with relative tolerance."""
    assert len(reference_scores) == len(test_scores)
    for ref_row, test_row in zip(reference_scores, test_scores):
        assert len(ref_row) == len(test_row)
        for ref, test in zip(ref_row, test_row):
            assert abs(test - ref) <= tolerance * abs(ref)
            if apply_softmax:
                assert 0.0 <= test <= 1.0
        if apply_softmax:
            assert abs(sum(test_row) - 1.0) <= tolerance
```
---
Implementation Priority
Tier 1 - Critical (Implement first)
1. test_score_consistency_with_hf - Accuracy validation
2. test_score_batch_handling - Core functionality
3. test_score_text_input / test_score_token_input - Input formats
4. test_v1_score_text_input - HTTP endpoint
5. test_score_invalid_token_ids - Error handling
Tier 2 - Important (Implement second)
6. test_score_apply_softmax - Feature coverage
7. test_score_item_first - Feature coverage
8. test_score_prefill_only_optimization - JAX optimization
9. test_score_different_label_tokens - Edge cases
10. test_scoring_request_validation - Protocol validation
Tier 3 - Nice to Have (Implement third)
11. test_score_numerical_stability - JAX precision
12. test_score_with_prefix_caching - Integration
13. test_score_unicode_multilingual - I18n
14. Performance benchmarks
15. Sharding correctness tests
---
Test Configuration
Models to Test
- Primary: DEFAULT_MODEL_NAME_FOR_TEST (Qwen2.5-0.5B-Instruct)
- Secondary: Larger model for stress tests (if needed)
Test Parameters
- Label tokens: Use [9454, 2753] (Yes/No for Qwen models)
- Tolerance: 1% relative difference (matching PyTorch)
- Batch sizes: 1, 2, 4, 8, 16, 32
- Query lengths: Short (10-50), Medium (100-200), Long (500-1000) tokens
CI Integration
- Add to test suites in test/srt/run_suite.py
- Register for TPU CI runs
- Estimated time: 5-10 minutes
---
Expected Outcomes
1. Feature Parity: Match PyTorch test coverage (14+ tests)
2. JAX Advantages: Demonstrate JAX-specific optimizations
3. Confidence: Validate accuracy against HuggingFace reference
4. Regression Prevention: Catch bugs early in CI/CD
5. Documentation: Tests serve as usage examples
---
Success Criteria
- ✅ All Tier 1 tests passing
- ✅ Accuracy within 1% of HuggingFace reference
- ✅ HTTP endpoint tests passing
- ✅ Protocol validation tests passing
- ✅ Performance benchmarks showing prefill-only optimization works
- ✅ Tests integrated into CI pipeline