Skip to content

Commit b855976

Browse files
authored
[Test] Add flashmla attention backend test (sgl-project#5587)
1 parent 56f6589 commit b855976

File tree

3 files changed

+68
-0
lines changed

3 files changed

+68
-0
lines changed

scripts/ci_install_dependency.sh

+3
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,6 @@ pip install cuda-python nvidia-cuda-nvrtc-cu12
3131
# For lmms_evals evaluating MMMU
3232
git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git
3333
pip install -e lmms-eval/
34+
35+
# Install FlashMLA for attention backend tests
36+
pip install git+https://github.com/deepseek-ai/FlashMLA.git

test/srt/run_suite.py

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class TestFile:
5151
TestFile("test_mla_int8_deepseek_v3.py", 389),
5252
TestFile("test_mla_flashinfer.py", 395),
5353
TestFile("test_mla_fp8.py", 153),
54+
TestFile("test_flash_mla_attention_backend.py", 300),
5455
TestFile("test_no_chunked_prefill.py", 108),
5556
TestFile("test_no_overlap_scheduler.py", 216),
5657
TestFile("test_openai_server.py", 149),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""
2+
Usage:
3+
python3 -m unittest test_flash_mla_attention_backend.TestFlashMLAAttnBackend.test_mmlu
4+
"""
5+
6+
import unittest
7+
from types import SimpleNamespace
8+
9+
from sglang.srt.utils import kill_process_tree
10+
from sglang.test.run_eval import run_eval
11+
from sglang.test.test_utils import (
12+
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
13+
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
14+
DEFAULT_URL_FOR_TEST,
15+
is_in_ci,
16+
popen_launch_server,
17+
run_bench_one_batch,
18+
)
19+
20+
21+
class TestFlashMLAAttnBackend(unittest.TestCase):
    """End-to-end checks for the FlashMLA attention backend.

    Covers two paths: a one-batch latency/throughput benchmark and an
    MMLU accuracy smoke test against a launched server.
    """

    def test_latency(self):
        """Benchmark one batch with the flashmla backend; enforce a throughput floor in CI."""
        backend_args = [
            "--attention-backend",
            "flashmla",
            "--enable-torch-compile",
            "--cuda-graph-max-bs",
            "16",
            "--trust-remote-code",
        ]
        throughput = run_bench_one_batch(
            DEFAULT_MLA_MODEL_NAME_FOR_TEST,
            backend_args,
        )

        # The numeric floor is only meaningful on CI hardware, so it is
        # skipped for local runs.
        if is_in_ci():
            self.assertGreater(throughput, 153)

    def test_mmlu(self):
        """Launch a server using the flashmla backend and check the MMLU score."""
        model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        base_url = DEFAULT_URL_FOR_TEST
        server = popen_launch_server(
            model,
            base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--attention-backend", "flashmla", "--trust-remote-code"],
        )

        try:
            eval_args = SimpleNamespace(
                base_url=base_url,
                model=model,
                eval_name="mmlu",
                num_examples=64,
                num_threads=32,
            )
            metrics = run_eval(eval_args)
            self.assertGreaterEqual(metrics["score"], 0.2)
        finally:
            # Always tear down the server process tree, even if the eval
            # raised or the assertion failed.
            kill_process_tree(server.pid)
61+
62+
63+
# Allow running this file directly, e.g.:
#   python3 test_flash_mla_attention_backend.py
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)