# Test Plan for /v1/score API (SGLang-JAX)
## Overview

The JAX version currently lacks dedicated tests for the /v1/score API, while the PyTorch version has comprehensive coverage (14 test methods plus integration tests). This plan aims to achieve feature parity and add JAX-specific tests.

---

## Phase 1: Core Functionality Tests

Location: `test/srt/test_score_api.py`
Test Class: `TestScoreAPI`
### 1.1 Accuracy & Consistency Tests

**test_score_consistency_with_hf**
- Compare JAX scoring results against a HuggingFace reference implementation
- Test cases:
  - Default case: `query="I pledge allegiance"`, `items=["", " to"]`
  - Item-first case: `query=" is a city"`, `items=["Tokyo", "Japan"]`, `item_first=True`
- Validation: 1% relative tolerance (matching PyTorch)
- JAX-specific: verify on-device logprob extraction produces correct results

**test_score_numerical_stability** (NEW, JAX-specific)
- Test with bfloat16 vs. float32 precision
- Verify numerical stability across different batch sizes
- Check that JAX sharding does not affect accuracy
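The 1% relative-tolerance criterion above can be sketched with the stdlib `math.isclose`. The score values here are illustrative placeholders, not real model outputs:

```python
import math

def within_relative_tolerance(reference, candidate, rel_tol=0.01):
    """True if every score pair agrees within rel_tol (1% by default)."""
    return all(
        math.isclose(r, c, rel_tol=rel_tol)
        for ref_row, cand_row in zip(reference, candidate)
        for r, c in zip(ref_row, cand_row)
    )

# Illustrative values only: HF reference vs. a hypothetical JAX output.
hf_scores = [[0.72, 0.28], [0.65, 0.35]]
jax_scores = [[0.7205, 0.2795], [0.649, 0.351]]

print(within_relative_tolerance(hf_scores, jax_scores))  # → True
```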
### 1.2 Batch Processing Tests

**test_score_batch_handling**
- Test batch sizes: 1, 2, 4, 8, 16, 32
- Validate:
  - Correct number of outputs
  - Proper list structure: `list[list[float]]`
  - All scores are floats in [0, 1]
  - Scores sum to 1.0 (when `apply_softmax=True`)

**test_score_large_batch**
- Test with 20+ items (stress test)
- Verify memory efficiency

**test_score_prefill_only_optimization** (NEW, JAX-specific)
- Verify requests use `max_new_tokens=0` internally
- Confirm the `is_prefill_only=True` flag is set
- Check that no decode pass occurs (performance optimization)
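The structural assertions for batch handling can be sketched in plain Python. The logprob values are illustrative; the validator mirrors the checks listed above, not the server's actual code:

```python
import math

def softmax(logprobs):
    """Normalize a row of logprobs into probabilities."""
    exps = [math.exp(x) for x in logprobs]
    total = sum(exps)
    return [e / total for e in exps]

def validate_score_output(scores, num_items, num_labels, apply_softmax=True):
    """Structural checks from test_score_batch_handling."""
    assert len(scores) == num_items                    # one row per item
    for row in scores:
        assert len(row) == num_labels                  # one score per label token
        assert all(isinstance(s, float) for s in row)
        assert all(0.0 <= s <= 1.0 for s in row)
        if apply_softmax:
            assert math.isclose(sum(row), 1.0, rel_tol=1e-6)

# Illustrative batch: 2 items, 2 label tokens, softmax applied.
raw_logprobs = [[-0.3, -1.5], [-0.9, -0.6]]
scores = [softmax(row) for row in raw_logprobs]
validate_score_output(scores, num_items=2, num_labels=2)
```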
### 1.3 Input Format Tests

**test_score_text_input**
- Query: string; items: list of strings
- Label token IDs: `[1, 2, 3]`

**test_score_token_input**
- Query: `list[int]`; items: `list[list[int]]`
- Pre-tokenized inputs

**test_score_mixed_input**
- Mix text and token inputs where allowed

**test_score_single_item**
- Edge case: a single item

**test_score_empty_item**
- Edge case: an empty string in the items list
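The two input formats differ only in the types of `query` and `items`. A sketch of both payload shapes, using field names from this plan (the token IDs in the token-input example are illustrative, not real Qwen tokenizations):

```python
# Text input: query is a string, items are strings.
text_request = {
    "model": "Qwen/Qwen2.5-0.5B-Instruct",
    "query": "The capital of France is",
    "items": ["Paris", "London"],
    "label_token_ids": [9454, 2753],
}

# Token input: query is list[int], items are list[list[int]] (pre-tokenized).
# The IDs below are placeholders for illustration.
token_request = {
    "model": "Qwen/Qwen2.5-0.5B-Instruct",
    "query": [785, 6722, 315],
    "items": [[59604], [68408]],
    "label_token_ids": [9454, 2753],
}

assert isinstance(text_request["query"], str)
assert all(isinstance(i, str) for i in text_request["items"])
assert isinstance(token_request["query"], list)
assert all(isinstance(i, list) for i in token_request["items"])
```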
### 1.4 Feature Tests

**test_score_apply_softmax**
- Test `apply_softmax=True` (normalized probabilities)
- Test `apply_softmax=False` (raw logprobs)
- Verify the difference between modes

**test_score_item_first**
- Test `item_first=False`: `f"{query}{item}"`
- Test `item_first=True`: `f"{item}{query}"`
- Verify results differ appropriately

**test_score_different_label_tokens**
- Test with 1, 2, 4, 8, 16 label token IDs
- Verify output dimensions match the input

**test_score_unicode_multilingual**
- Test with Unicode characters
- Test with non-English languages
- Ensure the tokenizer handles them correctly
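The `apply_softmax` distinction the first test verifies can be illustrated in pure Python, assuming raw mode returns the label-token logprobs unchanged (the logprob values are made up for illustration):

```python
import math

label_logprobs = [-0.4, -2.1, -3.0]  # illustrative raw logprobs for 3 label tokens

# apply_softmax=False: raw logprobs are returned as-is
# (all <= 0; they need not sum to anything in particular).
raw = label_logprobs

# apply_softmax=True: renormalize over the label set only.
exps = [math.exp(x) for x in label_logprobs]
probs = [e / sum(exps) for e in exps]

assert all(x <= 0 for x in raw)
assert math.isclose(sum(probs), 1.0, rel_tol=1e-9)
assert probs[0] > probs[1] > probs[2]  # softmax preserves the ordering
```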
### 1.5 Error Handling Tests

**test_score_invalid_token_ids**
- Invalid token IDs (negative, >= vocab_size)
- Should return an error or handle it gracefully

**test_score_missing_parameters**
- Missing `label_token_ids`
- Missing `query` or `items`
- Should raise `ValidationError`

**test_score_empty_inputs**
- Empty `items` list
- Empty `label_token_ids`
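A hypothetical pre-flight validator mirroring the checks above (the real server raises pydantic `ValidationError`s; this standalone sketch uses `ValueError` so it runs without dependencies):

```python
def validate_score_request(query, items, label_token_ids, vocab_size):
    """Hypothetical pre-flight checks mirroring the error-handling tests."""
    if query is None or items is None:
        raise ValueError("query and items are required")
    if not items:
        raise ValueError("items must be non-empty")
    if not label_token_ids:
        raise ValueError("label_token_ids must be non-empty")
    for tid in label_token_ids:
        if not (0 <= tid < vocab_size):
            raise ValueError(f"label token id {tid} outside [0, {vocab_size})")

validate_score_request("q", ["a"], [1, 2], vocab_size=32000)    # ok
try:
    validate_score_request("q", ["a"], [-5], vocab_size=32000)  # negative id
except ValueError as e:
    print(e)
```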
---
## Phase 2: HTTP Endpoint Tests

Location: `test/srt/openai_server/basic/test_openai_server.py`
Test Class: `TestOpenAIV1Score` (add to existing file)

### 2.1 HTTP Integration Tests

**test_v1_score_text_input**

```python
def test_v1_score_text_input(self):
    response = requests.post(
        f"{self.base_url}/v1/score",
        json={
            "model": self.model,
            "query": "The capital of France is",
            "items": ["Paris", "London", "Berlin"],
            "label_token_ids": [9454, 2753],  # Yes/No tokens
            "apply_softmax": True,
        },
    )
    self.assertEqual(response.status_code, 200)
    data = response.json()
    self.assertEqual(data["object"], "scoring")
    self.assertEqual(len(data["scores"]), 3)  # 3 items
```
**test_v1_score_token_input**
- Pre-tokenized inputs via HTTP

**test_v1_score_error_handling**
- Test 400 errors for invalid inputs
- Test 422 for validation errors

**test_v1_score_usage_info**
- Verify the `usage` field is populated correctly
- Check `prompt_tokens`, `completion_tokens`, `total_tokens`

### 2.2 OpenAI Client Tests

**test_score_with_openai_client** (if supported)
- Use the OpenAI Python client library
- Verify compatibility
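The usage-info check can be sketched against a sample payload. The response shape below is assembled from the fields this plan references (`object`, `scores`, token counts); the exact schema is an assumption, and the numbers are illustrative:

```python
# Illustrative /v1/score response; exact schema is an assumption.
sample_response = {
    "object": "scoring",
    "scores": [[0.8, 0.2], [0.4, 0.6], [0.1, 0.9]],
    "usage": {"prompt_tokens": 18, "completion_tokens": 0, "total_tokens": 18},
}

def check_usage(resp):
    usage = resp["usage"]
    assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"]
    # Scoring is prefill-only, so no completion tokens are expected.
    assert usage["completion_tokens"] == 0

assert sample_response["object"] == "scoring"
assert len(sample_response["scores"]) == 3
check_usage(sample_response)
```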
---
## Phase 3: Protocol Validation Tests

Location: `test/srt/openai_server/basic/test_protocol.py`
Test Class: `TestScoringProtocol` (add to existing file)

**test_scoring_request_validation**

```python
def test_scoring_request_validation(self):
    # Valid request
    request = ScoringRequest(
        model="test-model",
        query="test query",
        items=["item1", "item2"],
        label_token_ids=[1, 2, 3],
        apply_softmax=True,
        item_first=False,
    )
    self.assertEqual(request.query, "test query")

    # Missing required fields
    with self.assertRaises(ValidationError):
        ScoringRequest(model="test-model")
```
**test_scoring_response_serialization**
- Test JSON serialization/deserialization
- Verify `exclude_none=True` behavior

**test_scoring_request_type_hints**
- Test that modern Python 3.10+ type hints work correctly
- `str | list[int] | None` vs. `Optional[Union[str, List[int]]]`
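The `exclude_none=True` behavior under test drops fields whose value is `None` before serialization. A dependency-free illustration with a plain dict (the real response models use pydantic):

```python
import json

# A hypothetical response where the usage field was not populated.
response_fields = {
    "object": "scoring",
    "scores": [[0.7, 0.3]],
    "usage": None,
}

# Same effect as pydantic's exclude_none=True: None-valued fields are dropped.
serialized = json.dumps({k: v for k, v in response_fields.items() if v is not None})
print(serialized)  # → {"object": "scoring", "scores": [[0.7, 0.3]]}

round_tripped = json.loads(serialized)
assert "usage" not in round_tripped
```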
---
## Phase 4: JAX-Specific Tests

Location: `test/srt/test_score_api_jax_features.py` (NEW)
Test Class: `TestScoreAPIJAXFeatures`

### 4.1 On-Device Extraction Tests

**test_score_on_device_extraction**
- Verify the `logprobs.at[i, token_ids].get()` pattern is used
- Check sharding is correct (`NamedSharding`)
- Validate no unnecessary host transfers

**test_score_sharding_correctness**
- Test with a TPU/multi-device setup
- Verify scores are identical across sharding strategies
### 4.2 Performance Tests

**test_score_prefill_only_performance**
- Benchmark with/without the prefill-only optimization
- Verify no decode phase occurs
- Measure latency improvement

**test_score_memory_efficiency**
- Verify only target-token logprobs are extracted (not the full vocabulary)
- Memory usage per item: O(num_labels) vs. O(vocab_size)
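The memory argument can be made concrete with a host-side analogue of the gather (the on-device version uses JAX's `.at[...].get()` indexing; the toy vocabulary and logprob values here are for illustration only):

```python
# For each item we keep len(label_token_ids) values instead of a full
# vocab-sized row: O(num_items * num_labels), not O(num_items * vocab_size).
vocab_size = 8          # toy vocabulary for illustration
label_token_ids = [1, 5]

# One full logprob row per item (what the optimization avoids keeping).
full_rows = [
    [-3.0, -0.5, -4.0, -2.2, -5.0, -1.1, -6.0, -2.8],
    [-2.0, -1.9, -3.3, -0.7, -4.1, -2.5, -5.5, -3.6],
]

# Gather only the label-token entries from each row.
extracted = [[row[tid] for tid in label_token_ids] for row in full_rows]

assert extracted == [[-0.5, -1.1], [-1.9, -2.5]]
assert len(extracted[0]) == len(label_token_ids) < vocab_size
```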
### 4.3 Integration with Radix Cache

**test_score_with_prefix_caching**
- Run multiple scoring requests with the same prefix
- Verify the radix cache is utilized
- Check `cache_miss_count` metrics

**test_score_parallel_sampling_caching**
- Test the parallel sampling use case (`tokenizer_manager.py:545-554`)
- Verify the prefix is cached once and reused for multiple samples
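The cached-once, reused-many-times behavior under test can be illustrated with a toy memo. This is a plain dictionary stand-in, not SGLang's radix cache:

```python
class PrefixCounter:
    """Toy prefix cache: counts misses to show the prefix is computed once."""

    def __init__(self):
        self.cache = {}
        self.miss_count = 0

    def prefill(self, prefix):
        if prefix not in self.cache:
            self.miss_count += 1
            self.cache[prefix] = f"kv-state({prefix})"  # stand-in for KV state
        return self.cache[prefix]

cache = PrefixCounter()
query = "The capital of France is"
for item in ["Paris", "London", "Berlin"]:
    cache.prefill(query)  # same prefix for every item

assert cache.miss_count == 1  # prefix computed once, then reused
```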
---
## Phase 5: Benchmark & Utilities

Location: `test/srt/bench_score.py` (NEW)

### Performance Benchmarking Tool

```python
# Measure throughput:
# - Requests per second
# - Latency (p50, p90, p99)
# - Token processing rate

# Test scenarios:
# - Varying batch sizes: 1, 4, 16, 64
# - Varying label token counts: 2, 8, 32
# - Different query lengths: 50, 200, 1000 tokens
```
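The latency percentiles can be computed with the stdlib. A sketch of the reporting helper, with illustrative timings in place of real measured requests:

```python
import statistics

def latency_percentiles(latencies_ms):
    """p50/p90/p99 from a list of request latencies (milliseconds)."""
    q = statistics.quantiles(latencies_ms, n=100, method="inclusive")
    return {"p50": q[49], "p90": q[89], "p99": q[98]}

# Illustrative timings; a real bench_score.py would collect these from
# timed /v1/score requests.
latencies = [12.0, 13.5, 11.8, 14.2, 50.3, 12.9, 13.1, 12.4, 13.8, 12.2]
pcts = latency_percentiles(latencies)

assert pcts["p50"] < pcts["p90"] < pcts["p99"]
print(pcts)
```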
---
## Test Infrastructure Setup

### File Structure

```text
test/srt/
├── test_score_api.py               # NEW: Core functionality tests
├── test_score_api_jax_features.py  # NEW: JAX-specific tests
├── bench_score.py                  # NEW: Performance benchmarks
└── openai_server/basic/
    ├── test_openai_server.py       # MODIFIED: Add TestOpenAIV1Score class
    └── test_protocol.py            # MODIFIED: Add TestScoringProtocol class
```
### Test Utilities (in `test_score_api.py`)

A sketch of the two helpers, filled in under the assumption that the HF reference uses `transformers` with a causal LM, and that the prompt construction matches the server's `item_first` concatenation. The last assertions in `compare_scores` assume softmax-normalized scores:

```python
import math

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def compute_hf_reference_scores(
    query: str,
    items: list[str],
    label_token_ids: list[int],
    model_name: str,
    apply_softmax: bool = True,
    item_first: bool = False,
) -> list[list[float]]:
    """Generate reference scores using HuggingFace."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    scores = []
    for item in items:
        prompt = f"{item}{query}" if item_first else f"{query}{item}"
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        with torch.no_grad():
            logits = model(input_ids).logits[0, -1]  # logits for the last token
        logprobs = torch.log_softmax(logits, dim=-1)
        target = [logprobs[tid].item() for tid in label_token_ids]
        if apply_softmax:
            exps = [math.exp(x) for x in target]
            target = [e / sum(exps) for e in exps]
        scores.append(target)
    return scores


def compare_scores(
    reference_scores: list[list[float]],
    test_scores: list[list[float]],
    tolerance: float = 0.01,
) -> None:
    """Compare scores with relative tolerance (assumes apply_softmax=True)."""
    assert len(reference_scores) == len(test_scores)
    for ref_row, test_row in zip(reference_scores, test_scores):
        assert len(ref_row) == len(test_row)
        for ref, test in zip(ref_row, test_row):
            assert math.isclose(ref, test, rel_tol=tolerance)
            assert 0.0 <= test <= 1.0
        # Softmax-normalized rows should sum to 1.
        assert math.isclose(sum(test_row), 1.0, rel_tol=1e-4)
```
---

## Implementation Priority

### Tier 1 - Critical (implement first)
1. `test_score_consistency_with_hf` - accuracy validation
2. `test_score_batch_handling` - core functionality
3. `test_score_text_input` / `test_score_token_input` - input formats
4. `test_v1_score_text_input` - HTTP endpoint
5. `test_score_invalid_token_ids` - error handling

### Tier 2 - Important (implement second)
6. `test_score_apply_softmax` - feature coverage
7. `test_score_item_first` - feature coverage
8. `test_score_prefill_only_optimization` - JAX optimization
9. `test_score_different_label_tokens` - edge cases
10. `test_scoring_request_validation` - protocol validation

### Tier 3 - Nice to Have (implement third)
11. `test_score_numerical_stability` - JAX precision
12. `test_score_with_prefix_caching` - integration
13. `test_score_unicode_multilingual` - i18n
14. Performance benchmarks
15. Sharding correctness tests
---

## Test Configuration

### Models to Test
- Primary: `DEFAULT_MODEL_NAME_FOR_TEST` (Qwen2.5-0.5B-Instruct)
- Secondary: a larger model for stress tests (if needed)

### Test Parameters
- Label tokens: use `[9454, 2753]` (Yes/No for Qwen models)
- Tolerance: 1% relative difference (matching PyTorch)
- Batch sizes: 1, 2, 4, 8, 16, 32
- Query lengths: short (10-50), medium (100-200), long (500-1000) tokens

### CI Integration
- Add to the test suites in `test/srt/run_suite.py`
- Register for TPU CI runs
- Estimated time: 5-10 minutes
---

## Expected Outcomes

1. **Feature parity**: match PyTorch test coverage (14+ tests)
2. **JAX advantages**: demonstrate JAX-specific optimizations
3. **Confidence**: validate accuracy against the HuggingFace reference
4. **Regression prevention**: catch bugs early in CI/CD
5. **Documentation**: tests serve as usage examples
---

## Success Criteria

- ✅ All Tier 1 tests passing
- ✅ Accuracy within 1% of the HuggingFace reference
- ✅ HTTP endpoint tests passing
- ✅ Protocol validation tests passing
- ✅ Performance benchmarks showing the prefill-only optimization works
- ✅ Tests integrated into the CI pipeline