Commit 849c83a

[CI] test chunked prefill more (sgl-project#5798)

1 parent d73ddeb, commit 849c83a
15 files changed: 212 additions, 97 deletions

.github/workflows/pr-test.yml (+1)

@@ -123,6 +123,7 @@ jobs:
       timeout-minutes: 10
       run: |
         cd test/srt
+        python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_bs1_small
         python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_bs1_default

     - name: Benchmark online latency

docs/backend/server_arguments.md (+10, -20)

@@ -54,20 +54,21 @@ Please consult the documentation below and [server_args.py](https://github.com/s

 | Arguments | Description | Defaults |
 |----------|-------------|---------|
-| `model_path` | Path to the model that will be served. | None |
-| `tokenizer_path` | Defaults to the `model_path`. | None |
+| `model_path` | The path of the model weights. This can be a local folder or a Hugging Face repo ID. | None |
+| `tokenizer_path` | The path of the tokenizer. Defaults to the `model_path`. | None |
 | `tokenizer_mode` | See [different mode](https://huggingface.co/docs/transformers/en/main_classes/tokenizer). | `auto` |
-| `load_format` | The format the weights are loaded in. | `auto` |
-| `trust_remote_code` | If `true`, will use locally cached config files, otherwise use remote configs in HuggingFace. | `False` |
-| `dtype` | Dtype used for the model. | `bfloat16` |
-| `kv_cache_dtype` | Dtype of the kv cache. | `dtype` |
-| `context_length` | The number of tokens our model can process *including the input*. Note that extending the default might lead to strange behavior. | None |
+| `load_format` | The format of the model weights to load. | `auto` |
+| `trust_remote_code` | Whether or not to allow for custom models defined on the Hub in their own modeling files. | `False` |
+| `dtype` | Dtype used for the model. | `auto` |
+| `kv_cache_dtype` | Dtype of the kv cache. | `auto` |
+| `context_length` | The model's maximum context length. Defaults to None (will use the value from the model's config.json instead). Note that extending the default might lead to strange behavior. | None |
 | `device` | The device we put the model. | None |
-| `chat_template` | The chat template to use. See [multi-modal templates](https://docs.sglang.ai/backend/openai_api_vision.ipynb#Chat-Template). **Make sure the correct `chat_template` is passed, or performance degradation may occur!!!!** | None |
+| `device` | The device we put the model. | None |
+| `served_model_name` | Override the model name returned by the v1/models endpoint in OpenAI API server. | None |
 | `is_embedding` | Set to `true` to perform [embedding](./openai_api_embeddings.ipynb) / [encode](https://docs.sglang.ai/backend/native_api#Encode-(embedding-model)) and [reward](https://docs.sglang.ai/backend/native_api#Classify-(reward-model)) tasks. | `False` |
 | `revision` | Adjust if a specific version of the model should be used. | None |
 | `skip_tokenizer_init` | Set to `true` to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. See [example](https://github.com/sgl-project/sglang/blob/main/examples/runtime/token_in_token_out/). | `False` |
-| `json_model_override_args` | Override model config with the provided JSON. | `"{}"` |
+| `json_model_override_args` | A dictionary in JSON string format used to override default model configurations. | `"{}"` |
 | `disable_fast_image_processor` | Adopt base image processor instead of fast image processor (which is by default). See [details](https://huggingface.co/docs/transformers/main/en/main_classes/image_processor#image-processor). | `False` |

 ## Serving: HTTP & API

@@ -188,17 +189,6 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `speculative_eagle_topk` | The number of top candidates we keep for verification at each step for [Eagle](https://arxiv.org/html/2406.16858v1). | None |
 | `speculative_token_map` | Optional, the path to the high frequency token list of [FR-Spec](https://arxiv.org/html/2502.14856v1), used for accelerating [Eagle](https://arxiv.org/html/2406.16858v1). | None |

-## Double Sparsity
-
-| Arguments | Description | Defaults |
-|----------|-------------|---------|
-| `enable_double_sparsity` | Enables [double sparsity](https://arxiv.org/html/2408.07092v2) which increases throughput. | `False` |
-| `ds_channel_config_path` | The double sparsity config. See [a guide on how to generate the config for your model](https://github.com/andy-yang-1/DoubleSparse/tree/main/config). | None |
-| `ds_heavy_channel_num` | Number of channel indices to keep for each layer. | `32` |
-| `ds_heavy_token_num` | Number of tokens used for attention during decode. Skip sparse decoding if `min_seq_len` in batch is less than this number. | `256` |
-| `ds_heavy_channel_type` | The type of heavy channels. Options are `q`, `k` or `qk`. | `qk` |
-| `ds_sparse_decode_threshold` | Don't apply sparse decoding if `max_seq_len` in batch < this threshold. | `4096` |
-
 ## Debug options

 *Note: We recommend to stay with the defaults and only use these options for debugging for best possible performance.*
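For reference, the arguments documented above map onto CLI flags of the server launcher (underscores become dashes). A minimal launch sketch, assuming `python3 -m sglang.launch_server` as the entry point; the model ID, context length, and JSON override value are placeholders, not values taken from this commit:

    import subprocess

    # Launch the server with a few of the arguments documented above.
    # Model ID, context length, and the JSON override are illustrative only.
    command = [
        "python3", "-m", "sglang.launch_server",
        "--model-path", "meta-llama/Llama-3.1-8B-Instruct",
        "--context-length", "8192",
        "--json-model-override-args", '{"rope_scaling": null}',
        "--port", "30000",
    ]
    server = subprocess.Popen(command)  # serves requests until terminated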

python/sglang/srt/model_executor/model_runner.py (+1, -1)

@@ -975,7 +975,7 @@ def init_cuda_graphs(self):
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
-            f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
+            f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
         )

     def apply_torch_tp(self):

python/sglang/srt/server_args.py (+2, -1)

@@ -426,7 +426,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument(
             "--skip-tokenizer-init",
             action="store_true",
-            help="If set, skip init tokenizer and pass input_ids in generate request",
+            help="If set, skip init tokenizer and pass input_ids in generate request.",
         )
         parser.add_argument(
             "--enable-tokenizer-batch-encode",

@@ -565,6 +565,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
             "name, a tag name, or a commit id. If unspecified, will use "
             "the default version.",
         )
+
         # Memory and scheduling
         parser.add_argument(
             "--mem-fraction-static",

python/sglang/test/send_one.py (+84, -28)

@@ -6,11 +6,56 @@
 """
 
 import argparse
+import dataclasses
 import json
 
 import requests
 
 
+@dataclasses.dataclass
+class BenchArgs:
+    host: str = "localhost"
+    port: int = 30000
+    batch_size: int = 1
+    temperature: float = 0.0
+    max_new_tokens: int = 512
+    frequency_penalty: float = 0.0
+    presence_penalty: float = 0.0
+    json: bool = False
+    return_logprob: bool = False
+    prompt: str = (
+        "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
+    )
+    image: bool = False
+    stream: bool = False
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--host", type=str, default=BenchArgs.host)
+        parser.add_argument("--port", type=int, default=BenchArgs.port)
+        parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
+        parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
+        parser.add_argument(
+            "--max-new-tokens", type=int, default=BenchArgs.max_new_tokens
+        )
+        parser.add_argument(
+            "--frequency-penalty", type=float, default=BenchArgs.frequency_penalty
+        )
+        parser.add_argument(
+            "--presence-penalty", type=float, default=BenchArgs.presence_penalty
+        )
+        parser.add_argument("--json", action="store_true")
+        parser.add_argument("--return-logprob", action="store_true")
+        parser.add_argument("--prompt", type=str, default=BenchArgs.prompt)
+        parser.add_argument("--image", action="store_true")
+        parser.add_argument("--stream", action="store_true")
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace):
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        return cls(**{attr: getattr(args, attr) for attr in attrs})
+
+
 def send_one_prompt(args):
     if args.image:
         args.prompt = (

@@ -20,20 +65,42 @@ def send_one_prompt(args):
     else:
         image_data = None
 
-    response = requests.post(
-        "http://localhost:30000/generate",
-        json={
-            "text": args.prompt,
-            "image_data": image_data,
-            "sampling_params": {
-                "temperature": args.temperature,
-                "max_new_tokens": args.max_new_tokens,
-                "frequency_penalty": args.frequency_penalty,
-                "presence_penalty": args.presence_penalty,
-            },
-            "return_logprob": args.return_logprob,
-            "stream": args.stream,
+    prompt = args.prompt
+
+    if args.json:
+        prompt = (
+            "Human: What is the capital of France and how is that city like. "
+            "Give me 3 trivial information about that city. "
+            "Write in a format of json.\nAssistant:"
+        )
+        json_schema = "$$ANY$$"
+        json_schema = (
+            '{"type": "object", "properties": {"population": {"type": "integer"}}}'
+        )
+    else:
+        json_schema = None
+
+    if args.batch_size > 1:
+        prompt = [prompt] * args.batch_size
+
+    json_data = {
+        "text": prompt,
+        "image_data": image_data,
+        "sampling_params": {
+            "temperature": args.temperature,
+            "max_new_tokens": args.max_new_tokens,
+            "frequency_penalty": args.frequency_penalty,
+            "presence_penalty": args.presence_penalty,
+            "json_schema": json_schema,
+            "stop": ["Question", "Assistant:", "<|separator|>", "<|eos|>"],
         },
+        "return_logprob": args.return_logprob,
+        "stream": args.stream,
+    }
+
+    response = requests.post(
+        f"http://{args.host}:{args.port}/generate",
+        json=json_data,
         stream=args.stream,
     )
 
@@ -47,6 +114,9 @@ def send_one_prompt(args):
     else:
         ret = response.json()
 
+    if args.batch_size > 1:
+        ret = ret[0]
+
     latency = ret["meta_info"]["e2e_latency"]
 
     if "spec_verify_ct" in ret["meta_info"]:

@@ -68,21 +138,7 @@ def send_one_prompt(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--temperature", type=float, default=0.0)
-    parser.add_argument("--max-new-tokens", type=int, default=512)
-    parser.add_argument("--frequency-penalty", type=float, default=0.0)
-    parser.add_argument("--presence-penalty", type=float, default=0.0)
-    parser.add_argument("--return-logprob", action="store_true")
-    parser.add_argument(
-        "--prompt",
-        type=str,
-        default="Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
-    )
-    parser.add_argument(
-        "--image",
-        action="store_true",
-    )
-    parser.add_argument("--stream", action="store_true")
+    BenchArgs.add_cli_args(parser)
     args = parser.parse_args()
 
     send_one_prompt(args)
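A quick way to exercise the refactored script, shown as a sketch that assumes a server is already listening on the default host and port: construct `BenchArgs` directly (all fields are defined above) and call `send_one_prompt`.

    from sglang.test.send_one import BenchArgs, send_one_prompt

    # Assumes an sglang server is already running at localhost:30000.
    args = BenchArgs(batch_size=4, max_new_tokens=128, return_logprob=True)
    send_one_prompt(args)

    # Equivalent CLI invocation, using the flags defined in BenchArgs.add_cli_args:
    #   python3 -m sglang.test.send_one --batch-size 4 --max-new-tokens 128 --return-logprob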

python/sglang/test/test_utils.py (+38)

@@ -732,6 +732,44 @@ def run_bench_one_batch(model, other_args):
     return output_throughput
 
 
+def run_bench_offline_throughput(model, other_args):
+    command = [
+        "python3",
+        "-m",
+        "sglang.bench_offline_throughput",
+        "--num-prompts",
+        "1",
+        "--dataset-name",
+        "random",
+        "--random-input-len",
+        "256",
+        "--random-output-len",
+        "256",
+        "--model-path",
+        model,
+        *[str(x) for x in other_args],
+    ]
+
+    print(f"{command=}")
+    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+    try:
+        stdout, stderr = process.communicate()
+        output = stdout.decode()
+        error = stderr.decode()
+        print(f"Output: {output}", flush=True)
+        print(f"Error: {error}", flush=True)
+
+        output_throughput = -1
+        for line in output.split("\n"):
+            if "Last generation throughput (tok/s):" in line:
+                output_throughput = float(line.split(":")[-1])
+    finally:
+        kill_process_tree(process.pid)
+
+    return output_throughput
+
+
 def lcs(X, Y):
     m = len(X)
     n = len(Y)
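The new helper is called the same way as `run_bench_one_batch`, as the tests below show. A minimal sketch; the model name here is a placeholder (the CI tests pass `DEFAULT_MODEL_NAME_FOR_TEST`) and a GPU environment that can load it is assumed:

    from sglang.test.test_utils import run_bench_offline_throughput

    # Placeholder model ID; replace with a model available in your environment.
    throughput = run_bench_offline_throughput(
        "meta-llama/Llama-3.1-8B-Instruct",
        ["--cuda-graph-max-bs", "2"],
    )
    # -1 means the "Last generation throughput (tok/s):" line was not found in the output.
    assert throughput > 0, f"{throughput=}"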

test/srt/models/test_dummy_grok_models.py (+1, -1)

@@ -26,7 +26,7 @@ def test_dummy_grok_1(self):
         )
 
         if is_in_ci():
-            assert output_throughput > 0, f"{output_throughput=}"
+            self.assertGreater(output_throughput, 0)
 
 
 if __name__ == "__main__":

test/srt/models/test_vlm_models.py (+5, -2)

@@ -64,7 +64,7 @@ def run_mmmu_eval(
     model = "openai_compatible"
     tp = 1
     tasks = "mmmu_val"
-    batch_size = 1
+    batch_size = 2
     log_suffix = "openai_compatible"
     os.makedirs(output_path, exist_ok=True)
 
@@ -125,6 +125,9 @@ def test_vlm_mmmu_benchmark(self):
                 "--chat-template",
                 model.chat_template,
                 "--trust-remote-code",
+                "--cuda-graph-max-bs",
+                "32",
+                "--enable-multimodal",
                 "--mem-fraction-static",
                 str(self.parsed_args.mem_fraction_static),  # Use class variable
             ],
@@ -171,7 +174,7 @@
         "--mem-fraction-static",
         type=float,
         help="Static memory fraction for the model",
-        default=0.6,
+        default=0.8,
     )
 
     # Parse args intended for unittest

test/srt/test_bench_one_batch.py (+17, -5)

@@ -3,16 +3,28 @@
 from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST,
     DEFAULT_MOE_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     CustomTestCase,
     is_in_ci,
+    run_bench_offline_throughput,
     run_bench_one_batch,
     write_github_step_summary,
 )
 
+# We use `run_bench_offline_throughput`` instead of `run_bench_one_batch` for most cases
+# because `run_bench_offline_throughput`` has overlap scheduler.
+
 
 class TestBenchOneBatch(CustomTestCase):
-    def test_bs1_default(self):
+
+    def test_bs1_small(self):
         output_throughput = run_bench_one_batch(
+            DEFAULT_SMALL_MODEL_NAME_FOR_TEST, ["--cuda-graph-max-bs", "2"]
+        )
+        self.assertGreater(output_throughput, 50)
+
+    def test_bs1_default(self):
+        output_throughput = run_bench_offline_throughput(
             DEFAULT_MODEL_NAME_FOR_TEST, ["--cuda-graph-max-bs", "2"]
         )
 
@@ -24,26 +36,26 @@ def test_bs1_default(self):
         self.assertGreater(output_throughput, 135)
 
     def test_moe_tp2_bs1(self):
-        output_throughput = run_bench_one_batch(
+        output_throughput = run_bench_offline_throughput(
             DEFAULT_MOE_MODEL_NAME_FOR_TEST, ["--tp", "2", "--cuda-graph-max-bs", "2"]
         )
 
         if is_in_ci():
             write_github_step_summary(
-                f"### test_moe_tp2_bs1\n"
+                f"### test_moe_tp2_bs1 (Mixtral-8x7B)\n"
                 f"output_throughput: {output_throughput:.2f} token/s\n"
             )
         self.assertGreater(output_throughput, 125)
 
     def test_torch_compile_tp2_bs1(self):
-        output_throughput = run_bench_one_batch(
+        output_throughput = run_bench_offline_throughput(
             DEFAULT_MODEL_NAME_FOR_TEST,
             ["--tp", "2", "--enable-torch-compile", "--cuda-graph-max-bs", "2"],
         )
 
         if is_in_ci():
             write_github_step_summary(
-                f"### test_torch_compile_tp2_bs1\n"
+                f"### test_torch_compile_tp2_bs1 (Mixtral-8x7B)\n"
                 f"output_throughput: {output_throughput:.2f} token/s\n"
             )
         self.assertGreater(output_throughput, 220)
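These are the cases that the CI step in .github/workflows/pr-test.yml above runs one at a time. A sketch of reproducing a single case locally, assuming it is executed from test/srt on a machine with the required GPUs:

    # Shell equivalent used by the CI step:
    #   cd test/srt && python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_bs1_small
    import unittest

    from test_bench_one_batch import TestBenchOneBatch

    suite = unittest.TestSuite()
    suite.addTest(TestBenchOneBatch("test_bs1_small"))
    unittest.TextTestRunner(verbosity=2).run(suite)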
