
Commit d364b9b

ROCm: update AITER (sgl-project#5816)
1 parent: 849c83a

File tree (7 files changed, 48 additions, 52 deletions):

.github/workflows/pr-test-amd.yml
3rdparty/amd/tuning/benchmark_moe_rocm.py
docker/Dockerfile.rocm
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
python/sglang/srt/layers/quantization/fp8.py
python/sglang/srt/layers/quantization/fp8_utils.py

.github/workflows/pr-test-amd.yml (+6, -6)

@@ -38,12 +38,12 @@ jobs:
          else
            DEVICE_FLAG="--device /dev/dri"
          fi
-         docker pull lmsysorg/sglang:v0.4.5.post3-rocm630
+         docker pull ghcr.io/saienduri/sglang-aiter-v0.1.1:428
          docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
            -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
            --cap-add=SYS_PTRACE -e HF_TOKEN=${HF_TOKEN} --security-opt seccomp=unconfined \
            -w /sglang-checkout --name ci_sglang \
-           lmsysorg/sglang:v0.4.5.post3-rocm630
+           ghcr.io/saienduri/sglang-aiter-v0.1.1:428

      - name: Install dependencies
        run: |

@@ -82,12 +82,12 @@ jobs:
          else
            DEVICE_FLAG="--device /dev/dri"
          fi
-         docker pull lmsysorg/sglang:v0.4.5.post3-rocm630
+         docker pull ghcr.io/saienduri/sglang-aiter-v0.1.1:428
          docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
            -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
            --cap-add=SYS_PTRACE -e HF_TOKEN=${{ secrets.AMD_HF_TOKEN }} --security-opt seccomp=unconfined \
            -w /sglang-checkout --name ci_sglang \
-           lmsysorg/sglang:v0.4.5.post3-rocm630
+           ghcr.io/saienduri/sglang-aiter-v0.1.1:428

      - name: Install dependencies
        run: |

@@ -120,12 +120,12 @@ jobs:
          else
            DEVICE_FLAG="--device /dev/dri"
          fi
-         docker pull lmsysorg/sglang:v0.4.5.post3-rocm630
+         docker pull ghcr.io/saienduri/sglang-aiter-v0.1.1:428
          docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
            -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
            --cap-add=SYS_PTRACE -e HF_TOKEN=${HF_TOKEN} --security-opt seccomp=unconfined \
            -w /sglang-checkout --name ci_sglang \
-           lmsysorg/sglang:v0.4.5.post3-rocm630
+           ghcr.io/saienduri/sglang-aiter-v0.1.1:428

      - name: Install dependencies
        run: |

3rdparty/amd/tuning/benchmark_moe_rocm.py (+1, -1)

@@ -15,7 +15,7 @@
     get_config_file_name,
 )

-padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
+padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0


 def main(model, tp_size, dtype: str, batches):

docker/Dockerfile.rocm (+2, -2)

@@ -18,7 +18,7 @@ ARG TRITON_COMMIT="improve_fa_decode_3.0.0"


 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
-ARG AITER_COMMIT="testx"
+ARG AITER_COMMIT="v0.1.1"

 RUN git clone ${SGL_REPO} \
     && cd sglang \

@@ -74,7 +74,7 @@ ENV SGLANG_SET_CPU_AFFINITY=1
 ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1
 ENV NCCL_MIN_NCHANNELS=112

-ENV MOE_PADDING=1
+ENV SGLANG_MOE_PADDING=1
 ENV VLLM_FP8_PADDING=1
 ENV VLLM_FP8_ACT_PADDING=1
 ENV VLLM_FP8_WEIGHT_PADDING=1
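Across the files below, this commit also renames the ROCm/AITER environment variables to SGLANG_-prefixed names: MOE_PADDING becomes SGLANG_MOE_PADDING, CK_MOE becomes SGLANG_AITER_MOE, and USE_INT4_WEIGHT becomes SGLANG_INT4_WEIGHT. A minimal Python sketch of the mapping, collected from the hunks in this commit; the dict and helper below are illustrative only and are not part of the SGLang code base:

```python
# Environment-variable renames introduced by this commit (old name -> new name).
ENV_RENAMES = {
    "MOE_PADDING": "SGLANG_MOE_PADDING",      # MoE weight-padding toggle
    "CK_MOE": "SGLANG_AITER_MOE",             # AITER (CK) fused-MoE kernel toggle
    "USE_INT4_WEIGHT": "SGLANG_INT4_WEIGHT",  # INT4-FP8 MoE weight path toggle
}

def migrate_env(env: dict) -> dict:
    """Return a copy of `env` with the pre-commit names rewritten to the new ones."""
    return {ENV_RENAMES.get(key, key): value for key, value in env.items()}

# Example: migrate_env({"CK_MOE": "1"}) -> {"SGLANG_AITER_MOE": "1"}
```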

python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py (+2, -2)

@@ -45,7 +45,7 @@


 logger = logging.getLogger(__name__)
-padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
+padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
 enable_moe_align_block_size_triton = bool(
     int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
 )

@@ -1327,7 +1327,7 @@ def fused_experts_impl(
     if (
         not (use_fp8_w8a8 or use_int8_w8a8)
         or block_shape is not None
-        or (_is_hip and get_bool_env_var("CK_MOE"))
+        or (_is_hip and get_bool_env_var("SGLANG_AITER_MOE"))
     ):
         padded_size = 0
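With these hunks, the Triton MoE weight padding is read from SGLANG_MOE_PADDING and is skipped whenever the AITER path (SGLANG_AITER_MOE) is active on ROCm. A minimal, standalone sketch of that gating logic; resolve_padded_size is a hypothetical helper written for illustration, and the plain os.getenv check only approximates sglang's get_bool_env_var:

```python
import os

def resolve_padded_size(use_fp8_w8a8: bool, use_int8_w8a8: bool,
                        block_shape, is_hip: bool) -> int:
    """Illustrative restatement of the padding gate in fused_moe.py after this commit."""
    padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
    aiter_moe = os.getenv("SGLANG_AITER_MOE", "0").lower() in ("1", "true", "yes")
    if (
        not (use_fp8_w8a8 or use_int8_w8a8)
        or block_shape is not None
        or (is_hip and aiter_moe)
    ):
        return 0  # no padding for unquantized, block-quantized, or AITER paths
    return padding_size
```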

python/sglang/srt/layers/moe/fused_moe_triton/layer.py (+15, -17)

@@ -18,7 +18,7 @@
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import get_bool_env_var, is_hip, permute_weight, set_weight_attrs
+from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs

 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

@@ -30,7 +30,9 @@
 _is_hip = is_hip()

 if _is_hip:
-    from aiter import ck_moe
+    from aiter import ActivationType
+    from aiter.fused_moe_bf16_asm import ck_moe_2stages
+    from aiter.ops.shuffle import shuffle_weight

 logger = logging.getLogger(__name__)

@@ -102,14 +104,14 @@ def create_weights(
         set_weight_attrs(w2_weight, extra_weight_attrs)

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        if _is_hip and get_bool_env_var("CK_MOE"):
+        if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
             layer.w13_weight = torch.nn.Parameter(
-                permute_weight(layer.w13_weight.data),
+                shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                permute_weight(layer.w2_weight.data),
+                shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()

@@ -182,21 +184,17 @@ def forward_cuda(
             routed_scaling_factor=routed_scaling_factor,
         )

-        if _is_hip and get_bool_env_var("CK_MOE"):
+        if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
             assert not no_combine, "unsupported"
-            return ck_moe(
+            return ck_moe_2stages(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
                 topk_weights,
                 topk_ids,
-                None,
-                None,
-                None,
-                None,
-                32,
-                None,
-                activation,
+                activation=(
+                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                ),
             )
         else:
             return fused_experts(

@@ -527,7 +525,7 @@ def weight_loader(
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
             # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
-            if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+            if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
                 loaded_weight = loaded_weight * 2.0

             # this is needed for compressed-tensors only

@@ -569,7 +567,7 @@ def weight_loader(
             quant_method = getattr(param, "quant_method", None)
             if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
                 # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
-                if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+                if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
                     loaded_weight = loaded_weight * 0.5

                 self._load_per_channel_weight_scale(

@@ -592,7 +590,7 @@ def weight_loader(
                 )
             elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                 # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
-                if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+                if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
                     loaded_weight = loaded_weight * 2.0

                 self._load_per_tensor_weight_scale(
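On the unquantized ROCm path, the layer now pre-shuffles the expert weights with aiter's shuffle_weight(..., (16, 16)) at load time and dispatches to ck_moe_2stages with an explicit ActivationType, replacing the old positional-argument ck_moe call. A minimal sketch of the activation-string mapping used by the new call; resolve_aiter_activation is a hypothetical helper, and the import only resolves on a ROCm build with AITER installed:

```python
def resolve_aiter_activation(activation: str):
    """Mirror of the inline expression added in forward_cuda above (requires AITER)."""
    from aiter import ActivationType

    return ActivationType.Silu if activation == "silu" else ActivationType.Gelu
```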

python/sglang/srt/layers/quantization/fp8.py (+20, -22)

@@ -72,8 +72,8 @@ def dummy_func(*args, **kwargs):
 _is_cuda = is_cuda()

 if _is_hip:
-    from aiter import ActivationType
-    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
+    from aiter import ActivationType, QuantType
+    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight

 if not _is_cuda:

@@ -484,7 +484,7 @@ def create_weights(
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = (
                 torch.uint32
-                if get_bool_env_var("USE_INT4_WEIGHT")
+                if get_bool_env_var("SGLANG_INT4_WEIGHT")
                 else torch.float8_e4m3fn
             )
             tp_size = get_tensor_model_parallel_world_size()

@@ -511,7 +511,7 @@ def create_weights(
         )

         # WEIGHTS
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             # INT4 MoE weight - INT32 packed
             w13_weight = torch.nn.Parameter(
                 torch.empty(

@@ -585,7 +585,7 @@ def create_weights(

         if (
             _is_hip
-        ):  # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
+        ):  # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
                 torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),

@@ -612,7 +612,7 @@ def create_weights(
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)

-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             extra_weight_attrs.update(
                 {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
             )

@@ -644,7 +644,7 @@ def create_weights(
             layer.w2_input_scale = None

     def process_weights_after_loading(self, layer: Module) -> None:
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             self.process_weights_hip_int4(layer)
             return

@@ -675,7 +675,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             )
             layer.w2_input_scale = None

-            if get_bool_env_var("CK_MOE"):
+            if get_bool_env_var("SGLANG_AITER_MOE"):
                 # Pre-shuffle weights
                 layer.w13_weight.data = shuffle_weight(
                     layer.w13_weight.contiguous(), (16, 16)

@@ -798,17 +798,15 @@ def process_weights_after_loading(self, layer: Module) -> None:
         return

     def process_weights_hip_int4(self, layer: Module):
-        # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
+        # TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added
         # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
         # Weight Permutation
         layer.w13_weight = torch.nn.Parameter(
-            # permute_weight(layer.w13_weight.data),
             shuffle_weight(layer.w13_weight.data, (16, 16)),
             requires_grad=False,
         )
         torch.cuda.empty_cache()
         layer.w2_weight = torch.nn.Parameter(
-            # permute_weight(layer.w2_weight.data),
             shuffle_weight(layer.w2_weight.data, (16, 16)),
             requires_grad=False,
         )

@@ -847,23 +845,21 @@ def process_weights_hip_scale_padding(self, layer: Module):
             padding_size,  # Avoid circular import
         )

-        if get_bool_env_var("CK_MOE"):
+        if get_bool_env_var("SGLANG_AITER_MOE"):
             layer.w13_weight = torch.nn.Parameter(
-                # permute_weight(layer.w13_weight.data),
                 shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                # permute_weight(layer.w2_weight.data),
                 shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
-            # ROCm (CK_MOE): using column-wise scaling
+            # ROCm (SGLANG_AITER_MOE): using column-wise scaling
             layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
             layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
-        elif get_bool_env_var("MOE_PADDING"):
+        elif get_bool_env_var("SGLANG_MOE_PADDING"):
             # If ROCm, apply weight padding (min. Mem channel contention) only if set
             layer.w13_weight = torch.nn.Parameter(
                 F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),

@@ -912,15 +908,16 @@ def apply(
         )

         if _is_hip:
-            if get_bool_env_var("USE_INT4_WEIGHT"):
-                # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
+            if get_bool_env_var("SGLANG_INT4_WEIGHT"):
+                # TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE")
                 assert not no_combine, f"{no_combine=} is not supported."
-                return ck_moe_2stages_win4(
+                return ck_moe_2stages(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
+                    QuantType.per_Token,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
                     activation=(

@@ -930,13 +927,13 @@ def apply(
                 ),
             )

-            if get_bool_env_var("CK_MOE"):
+            if get_bool_env_var("SGLANG_AITER_MOE"):
                 assert not no_combine, f"{no_combine=} is not supported."
                 if self.block_quant:
-                    # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                    # TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
                     assert (
                         activation == "silu"
-                    ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
+                    ), f"SGLANG_AITER_MOE: FP8 bloack_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
                     return asm_moe(
                         x,
                         layer.w13_weight,

@@ -955,6 +952,7 @@ def apply(
                         layer.w2_weight,
                         topk_weights,
                         topk_ids,
+                        QuantType.per_Token,
                         layer.w13_weight_scale1,
                         layer.w2_weight_scale1,
                         activation=(
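In the FP8 MoE apply path, the dispatch order is unchanged but now runs under the renamed flags, and both ck_moe_2stages calls pass an explicit QuantType.per_Token: SGLANG_INT4_WEIGHT selects the INT4-FP8 two-stage kernel, while SGLANG_AITER_MOE selects asm_moe for block-quantized FP8 (silu only) or ck_moe_2stages otherwise. A minimal, stdlib-only sketch of that selection; select_rocm_moe_kernel is a hypothetical helper that only names which kernel the new code would call, and the env parsing approximates get_bool_env_var:

```python
import os

def _flag(name: str) -> bool:
    """Rough stand-in for sglang.srt.utils.get_bool_env_var."""
    return os.getenv(name, "0").lower() in ("1", "true", "yes")

def select_rocm_moe_kernel(block_quant: bool, activation: str) -> str:
    # Mirrors the branch order in Fp8MoEMethod.apply after this commit.
    if _flag("SGLANG_INT4_WEIGHT"):
        return "ck_moe_2stages (INT4-FP8 weights, QuantType.per_Token)"
    if _flag("SGLANG_AITER_MOE"):
        if block_quant:
            assert activation == "silu", "FP8 block_quant currently supports only silu"
            return "asm_moe (block-quantized FP8)"
        return "ck_moe_2stages (FP8, QuantType.per_Token)"
    return "default Triton fused_experts path"
```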

python/sglang/srt/layers/quantization/fp8_utils.py (+2, -2)

@@ -31,7 +31,7 @@
 _is_hip = is_hip()
 _is_cuda = is_cuda()

-if _is_hip and get_bool_env_var("CK_MOE"):
+if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
     from aiter import gemm_a8w8_blockscale

 if _is_cuda:

@@ -132,7 +132,7 @@ def apply_w8a8_block_fp8_linear(
         output = fp8_blockwise_scaled_mm(
             q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
         )
-    elif _is_hip and get_bool_env_var("CK_MOE"):
+    elif _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=False
         )
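Note that the same renamed flag also gates the block-wise FP8 linear path: the aiter gemm_a8w8_blockscale import and the per-token-group quantization branch in apply_w8a8_block_fp8_linear are now selected by SGLANG_AITER_MOE rather than CK_MOE. A minimal sketch of just that gate; aiter_blockscale_linear_enabled is a hypothetical helper, and the env check again approximates get_bool_env_var:

```python
import os

def aiter_blockscale_linear_enabled(is_hip: bool) -> bool:
    """True when the aiter gemm_a8w8_blockscale branch shown above would be taken."""
    return is_hip and os.getenv("SGLANG_AITER_MOE", "0").lower() in ("1", "true", "yes")
```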
