Skip to content

Commit a3e4e9b

Browse files
authored
Better PD initialization (sgl-project#5751)
1 parent 6d4d3bc commit a3e4e9b

File tree

5 files changed

+141
-25
lines changed

5 files changed

+141
-25
lines changed

python/sglang/srt/disaggregation/mini_lb.py

+74-23
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,45 @@
33
"""
44

55
import asyncio
6+
import dataclasses
7+
import logging
68
import random
79
import urllib
810
from itertools import chain
9-
from typing import List
11+
from typing import List, Optional
1012

1113
import aiohttp
1214
import orjson
1315
import uvicorn
1416
from fastapi import FastAPI, HTTPException
1517
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
1618

19+
from sglang.srt.disaggregation.utils import PDRegistryRequest
1720

21+
22+
def setup_logger():
23+
logger = logging.getLogger("pdlb")
24+
logger.setLevel(logging.INFO)
25+
26+
formatter = logging.Formatter(
27+
"[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s",
28+
datefmt="%Y-%m-%d %H:%M:%S",
29+
)
30+
31+
handler = logging.StreamHandler()
32+
handler.setFormatter(formatter)
33+
logger.addHandler(handler)
34+
35+
return logger
36+
37+
38+
logger = setup_logger()
39+
40+
41+
@dataclasses.dataclass
1842
class PrefillConfig:
19-
def __init__(self, url: str, bootstrap_port: int):
20-
self.url = url
21-
self.bootstrap_port = bootstrap_port
43+
url: str
44+
bootstrap_port: Optional[int] = None
2245

2346

2447
class MiniLoadBalancer:
@@ -28,6 +51,10 @@ def __init__(self, prefill_configs: List[PrefillConfig], decode_servers: List[st
2851
self.decode_servers = decode_servers
2952

3053
def select_pair(self):
54+
# TODO: return some message instead of panic
55+
assert len(self.prefill_configs) > 0, "No prefill servers available"
56+
assert len(self.decode_servers) > 0, "No decode servers available"
57+
3158
prefill_config = random.choice(self.prefill_configs)
3259
decode_server = random.choice(self.decode_servers)
3360
return prefill_config.url, prefill_config.bootstrap_port, decode_server
@@ -47,7 +74,7 @@ async def generate(
4774
session.post(f"{decode_server}/{endpoint}", json=modified_request),
4875
]
4976
# Wait for both responses to complete. Prefill should end first.
50-
prefill_response, decode_response = await asyncio.gather(*tasks)
77+
_, decode_response = await asyncio.gather(*tasks)
5178

5279
return ORJSONResponse(
5380
content=await decode_response.json(),
@@ -268,6 +295,32 @@ async def get_models():
268295
raise HTTPException(status_code=500, detail=str(e))
269296

270297

298+
@app.post("/register")
299+
async def register(obj: PDRegistryRequest):
300+
if obj.mode == "prefill":
301+
load_balancer.prefill_configs.append(
302+
PrefillConfig(obj.registry_url, obj.bootstrap_port)
303+
)
304+
logger.info(
305+
f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
306+
)
307+
elif obj.mode == "decode":
308+
load_balancer.decode_servers.append(obj.registry_url)
309+
logger.info(f"Registered decode server: {obj.registry_url}")
310+
else:
311+
raise HTTPException(
312+
status_code=400,
313+
detail="Invalid mode. Must be either PREFILL or DECODE.",
314+
)
315+
316+
logger.info(
317+
f"#Prefill servers: {len(load_balancer.prefill_configs)}, "
318+
f"#Decode servers: {len(load_balancer.decode_servers)}"
319+
)
320+
321+
return Response(status_code=200)
322+
323+
271324
def run(prefill_configs, decode_addrs, host, port):
272325
global load_balancer
273326
load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
@@ -279,15 +332,16 @@ def run(prefill_configs, decode_addrs, host, port):
279332

280333
parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
281334
parser.add_argument(
282-
"--prefill", required=True, help="Comma-separated URLs for prefill servers"
335+
"--prefill", type=str, default=[], nargs="+", help="URLs for prefill servers"
283336
)
284337
parser.add_argument(
285-
"--prefill-bootstrap-ports",
286-
help="Comma-separated bootstrap ports for prefill servers",
287-
default="8998",
338+
"--decode", type=str, default=[], nargs="+", help="URLs for decode servers"
288339
)
289340
parser.add_argument(
290-
"--decode", required=True, help="Comma-separated URLs for decode servers"
341+
"--prefill-bootstrap-ports",
342+
type=int,
343+
nargs="+",
344+
help="Bootstrap ports for prefill servers",
291345
)
292346
parser.add_argument(
293347
"--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
@@ -297,22 +351,19 @@ def run(prefill_configs, decode_addrs, host, port):
297351
)
298352
args = parser.parse_args()
299353

300-
prefill_urls = args.prefill.split(",")
301-
bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")]
302-
303-
if len(bootstrap_ports) == 1:
304-
bootstrap_ports = bootstrap_ports * len(prefill_urls)
354+
bootstrap_ports = args.prefill_bootstrap_ports
355+
if bootstrap_ports is None:
356+
bootstrap_ports = [None] * len(args.prefill)
357+
elif len(bootstrap_ports) == 1:
358+
bootstrap_ports = bootstrap_ports * len(args.prefill)
305359
else:
306-
if len(bootstrap_ports) != len(prefill_urls):
360+
if len(bootstrap_ports) != len(args.prefill):
307361
raise ValueError(
308362
"Number of prefill URLs must match number of bootstrap ports"
309363
)
310-
exit(1)
311-
312-
prefill_configs = []
313-
for url, port in zip(prefill_urls, bootstrap_ports):
314-
prefill_configs.append(PrefillConfig(url, port))
315364

316-
decode_addrs = args.decode.split(",")
365+
prefill_configs = [
366+
PrefillConfig(url, port) for url, port in zip(args.prefill, bootstrap_ports)
367+
]
317368

318-
run(prefill_configs, decode_addrs, args.host, args.port)
369+
run(prefill_configs, args.decode, args.host, args.port)

python/sglang/srt/disaggregation/utils.py

+44-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
from __future__ import annotations
22

3+
import dataclasses
4+
import warnings
35
from collections import deque
46
from enum import Enum
5-
from typing import List
7+
from typing import List, Optional
68

79
import numpy as np
10+
import requests
811
import torch
912
import torch.distributed as dist
1013

14+
from sglang.srt.utils import get_ip
15+
1116

1217
class DisaggregationMode(Enum):
1318
NULL = "null"
@@ -119,3 +124,41 @@ def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
119124
def kv_to_page_num(num_kv_indices: int, page_size: int):
120125
# ceil(num_kv_indices / page_size)
121126
return (num_kv_indices + page_size - 1) // page_size
127+
128+
129+
@dataclasses.dataclass
130+
class PDRegistryRequest:
131+
"""A request to register a machine itself to the LB."""
132+
133+
mode: str
134+
registry_url: str
135+
bootstrap_port: Optional[int] = None
136+
137+
def __post_init__(self):
138+
if self.mode == "prefill" and self.bootstrap_port is None:
139+
raise ValueError("Bootstrap port must be set in PREFILL mode.")
140+
elif self.mode == "decode" and self.bootstrap_port is not None:
141+
raise ValueError("Bootstrap port must not be set in DECODE mode.")
142+
elif self.mode not in ["prefill", "decode"]:
143+
raise ValueError(
144+
f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
145+
)
146+
147+
148+
def register_disaggregation_server(
149+
mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
150+
):
151+
boostrap_port = bootstrap_port if mode == "prefill" else None
152+
registry_request = PDRegistryRequest(
153+
mode=mode,
154+
registry_url=f"http://{get_ip()}:{server_port}",
155+
bootstrap_port=boostrap_port,
156+
)
157+
res = requests.post(
158+
f"{pdlb_url}/register",
159+
json=dataclasses.asdict(registry_request),
160+
)
161+
if res.status_code != 200:
162+
warnings.warn(
163+
f"Failed to register disaggregation server: {res.status_code} {res.text}"
164+
)

python/sglang/srt/entrypoints/http_server.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@
4242
from fastapi.middleware.cors import CORSMiddleware
4343
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
4444

45-
from sglang.srt.disaggregation.utils import FakeBootstrapHost
45+
from sglang.srt.disaggregation.utils import (
46+
FakeBootstrapHost,
47+
register_disaggregation_server,
48+
)
4649
from sglang.srt.entrypoints.engine import _launch_subprocesses
4750
from sglang.srt.function_call_parser import FunctionCallParser
4851
from sglang.srt.managers.io_struct import (
@@ -871,5 +874,13 @@ def _wait_and_warmup(
871874
if server_args.debug_tensor_dump_input_file:
872875
kill_process_tree(os.getpid())
873876

877+
if server_args.pdlb_url is not None:
878+
register_disaggregation_server(
879+
server_args.disaggregation_mode,
880+
server_args.port,
881+
server_args.disaggregation_bootstrap_port,
882+
server_args.pdlb_url,
883+
)
884+
874885
if launch_callback is not None:
875886
launch_callback()

python/sglang/srt/managers/scheduler.py

+4
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,10 @@ def handle_generate_request(
925925
)
926926
custom_logit_processor = None
927927

928+
if recv_req.bootstrap_port is None:
929+
# Use default bootstrap port
930+
recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port
931+
928932
req = Req(
929933
recv_req.rid,
930934
recv_req.input_text,

python/sglang/srt/server_args.py

+7
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ class ServerArgs:
198198
disaggregation_bootstrap_port: int = 8998
199199
disaggregation_transfer_backend: str = "mooncake"
200200
disaggregation_ib_device: Optional[str] = None
201+
pdlb_url: Optional[str] = None
201202

202203
def __post_init__(self):
203204
# Expert parallelism
@@ -1254,6 +1255,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
12541255
"or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
12551256
"Default is None, which triggers automatic device detection when mooncake backend is enabled.",
12561257
)
1258+
parser.add_argument(
1259+
"--pdlb-url",
1260+
type=str,
1261+
default=None,
1262+
help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
1263+
)
12571264

12581265
@classmethod
12591266
def from_cli_args(cls, args: argparse.Namespace):

0 commit comments

Comments
 (0)