3
3
"""
4
4
5
5
import asyncio
6
+ import dataclasses
7
+ import logging
6
8
import random
7
9
import urllib
8
10
from itertools import chain
9
- from typing import List
11
+ from typing import List , Optional
10
12
11
13
import aiohttp
12
14
import orjson
13
15
import uvicorn
14
16
from fastapi import FastAPI , HTTPException
15
17
from fastapi .responses import ORJSONResponse , Response , StreamingResponse
16
18
19
+ from sglang .srt .disaggregation .utils import PDRegistryRequest
17
20
21
+
22
def setup_logger():
    """Return the ``"pdlb"`` logger, configuring it the first time it is requested.

    The logger is set to INFO and given a single stream handler that prefixes
    each record with ``[PDLB (Python)]``, a timestamp and the level name.

    Returns:
        logging.Logger: the configured ``"pdlb"`` logger.
    """
    logger = logging.getLogger("pdlb")
    logger.setLevel(logging.INFO)

    # logging.getLogger hands back the same object for the same name, so
    # attaching a handler on every call would make each record print once per
    # call. Only attach the handler when none is installed yet.
    if logger.handlers:
        return logger

    formatter = logging.Formatter(
        "[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    return logger
36
+
37
+
38
+ logger = setup_logger ()
39
+
40
+
41
@dataclasses.dataclass
class PrefillConfig:
    """Connection details for a single prefill server.

    Instances are built both from the command-line arguments and from
    ``/register`` requests, and are what ``MiniLoadBalancer.select_pair``
    chooses among.
    """

    # URL of the prefill server.
    url: str
    # Bootstrap port paired with ``url``; stays None when no port was supplied.
    bootstrap_port: Optional[int] = None
22
45
23
46
24
47
class MiniLoadBalancer :
@@ -28,6 +51,10 @@ def __init__(self, prefill_configs: List[PrefillConfig], decode_servers: List[st
28
51
self .decode_servers = decode_servers
29
52
30
53
def select_pair (self ):
54
+ # TODO: return some message instead of panic
55
+ assert len (self .prefill_configs ) > 0 , "No prefill servers available"
56
+ assert len (self .decode_servers ) > 0 , "No decode servers available"
57
+
31
58
prefill_config = random .choice (self .prefill_configs )
32
59
decode_server = random .choice (self .decode_servers )
33
60
return prefill_config .url , prefill_config .bootstrap_port , decode_server
@@ -47,7 +74,7 @@ async def generate(
47
74
session .post (f"{ decode_server } /{ endpoint } " , json = modified_request ),
48
75
]
49
76
# Wait for both responses to complete. Prefill should end first.
50
- prefill_response , decode_response = await asyncio .gather (* tasks )
77
+ _ , decode_response = await asyncio .gather (* tasks )
51
78
52
79
return ORJSONResponse (
53
80
content = await decode_response .json (),
@@ -268,6 +295,32 @@ async def get_models():
268
295
raise HTTPException (status_code = 500 , detail = str (e ))
269
296
270
297
298
@app.post("/register")
async def register(obj: PDRegistryRequest):
    """Add a prefill or decode server to the running load balancer.

    Args:
        obj: registration request; ``obj.mode`` selects the pool
            (``"prefill"`` or ``"decode"``), ``obj.registry_url`` is the
            server URL and ``obj.bootstrap_port`` is stored for prefill
            servers only.

    Returns:
        Response: an empty 200 response once the server has been added.

    Raises:
        HTTPException: status 400 when ``obj.mode`` is neither
            ``"prefill"`` nor ``"decode"``.
    """
    if obj.mode == "prefill":
        load_balancer.prefill_configs.append(
            PrefillConfig(obj.registry_url, obj.bootstrap_port)
        )
        logger.info(
            f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
        )
    elif obj.mode == "decode":
        load_balancer.decode_servers.append(obj.registry_url)
        logger.info(f"Registered decode server: {obj.registry_url}")
    else:
        # The comparison above is case-sensitive and lowercase, so the error
        # text must name the values that are actually accepted.
        raise HTTPException(
            status_code=400,
            detail='Invalid mode. Must be either "prefill" or "decode".',
        )

    # Log the pool sizes after every successful registration.
    logger.info(
        f"#Prefill servers: {len(load_balancer.prefill_configs)}, "
        f"#Decode servers: {len(load_balancer.decode_servers)}"
    )

    return Response(status_code=200)
322
+
323
+
271
324
def run (prefill_configs , decode_addrs , host , port ):
272
325
global load_balancer
273
326
load_balancer = MiniLoadBalancer (prefill_configs , decode_addrs )
@@ -279,15 +332,16 @@ def run(prefill_configs, decode_addrs, host, port):
279
332
280
333
parser = argparse .ArgumentParser (description = "Mini Load Balancer Server" )
281
334
parser .add_argument (
282
- "--prefill" , required = True , help = "Comma-separated URLs for prefill servers"
335
+ "--prefill" , type = str , default = [], nargs = "+" , help = " URLs for prefill servers"
283
336
)
284
337
parser .add_argument (
285
- "--prefill-bootstrap-ports" ,
286
- help = "Comma-separated bootstrap ports for prefill servers" ,
287
- default = "8998" ,
338
+ "--decode" , type = str , default = [], nargs = "+" , help = "URLs for decode servers"
288
339
)
289
340
parser .add_argument (
290
- "--decode" , required = True , help = "Comma-separated URLs for decode servers"
341
+ "--prefill-bootstrap-ports" ,
342
+ type = int ,
343
+ nargs = "+" ,
344
+ help = "Bootstrap ports for prefill servers" ,
291
345
)
292
346
parser .add_argument (
293
347
"--host" , default = "0.0.0.0" , help = "Host to bind the server (default: 0.0.0.0)"
@@ -297,22 +351,19 @@ def run(prefill_configs, decode_addrs, host, port):
297
351
)
298
352
args = parser .parse_args ()
299
353
300
- prefill_urls = args .prefill . split ( "," )
301
- bootstrap_ports = [ int ( p ) for p in args . prefill_bootstrap_ports . split ( "," )]
302
-
303
- if len (bootstrap_ports ) == 1 :
304
- bootstrap_ports = bootstrap_ports * len (prefill_urls )
354
+ bootstrap_ports = args .prefill_bootstrap_ports
355
+ if bootstrap_ports is None :
356
+ bootstrap_ports = [ None ] * len ( args . prefill )
357
+ elif len (bootstrap_ports ) == 1 :
358
+ bootstrap_ports = bootstrap_ports * len (args . prefill )
305
359
else :
306
- if len (bootstrap_ports ) != len (prefill_urls ):
360
+ if len (bootstrap_ports ) != len (args . prefill ):
307
361
raise ValueError (
308
362
"Number of prefill URLs must match number of bootstrap ports"
309
363
)
310
- exit (1 )
311
-
312
- prefill_configs = []
313
- for url , port in zip (prefill_urls , bootstrap_ports ):
314
- prefill_configs .append (PrefillConfig (url , port ))
315
364
316
- decode_addrs = args .decode .split ("," )
365
+ prefill_configs = [
366
+ PrefillConfig (url , port ) for url , port in zip (args .prefill , bootstrap_ports )
367
+ ]
317
368
318
- run (prefill_configs , decode_addrs , args .host , args .port )
369
+ run (prefill_configs , args . decode , args .host , args .port )
0 commit comments