@@ -72,8 +72,8 @@ def dummy_func(*args, **kwargs):
 _is_cuda = is_cuda()
 
 if _is_hip:
-    from aiter import ActivationType
-    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
+    from aiter import ActivationType, QuantType
+    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
 
 if not _is_cuda:
@@ -484,7 +484,7 @@ def create_weights(
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = (
                 torch.uint32
-                if get_bool_env_var("USE_INT4_WEIGHT")
+                if get_bool_env_var("SGLANG_INT4_WEIGHT")
                 else torch.float8_e4m3fn
             )
         tp_size = get_tensor_model_parallel_world_size()
@@ -511,7 +511,7 @@ def create_weights(
                     )
 
         # WEIGHTS
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             # INT4 MoE weight - INT32 packed
             w13_weight = torch.nn.Parameter(
                 torch.empty(
@@ -585,7 +585,7 @@ def create_weights(
585
585
586
586
if (
587
587
_is_hip
588
- ): # and get_bool_env_var("CK_MOE "): TODO: add check back after triton kernel
588
+ ): # and get_bool_env_var("SGLANG_AITER_MOE "): TODO: add check back after triton kernel
589
589
# ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
590
590
w13_weight_scale1 = torch .nn .Parameter (
591
591
torch .ones (num_experts , 2 * intermediate_size , dtype = torch .float32 ),
@@ -612,7 +612,7 @@ def create_weights(
             set_weight_attrs(w13_weight_scale, extra_weight_attrs)
             set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             extra_weight_attrs.update(
                 {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
             )
@@ -644,7 +644,7 @@ def create_weights(
             layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             self.process_weights_hip_int4(layer)
             return
 
@@ -675,7 +675,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 )
                 layer.w2_input_scale = None
 
-                if get_bool_env_var("CK_MOE"):
+                if get_bool_env_var("SGLANG_AITER_MOE"):
                     # Pre-shuffle weights
                     layer.w13_weight.data = shuffle_weight(
                         layer.w13_weight.contiguous(), (16, 16)
@@ -798,17 +798,15 @@ def process_weights_after_loading(self, layer: Module) -> None:
         return
 
     def process_weights_hip_int4(self, layer: Module):
-        # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
+        # TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added
         # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
         # Weight Permutation
         layer.w13_weight = torch.nn.Parameter(
-            # permute_weight(layer.w13_weight.data),
             shuffle_weight(layer.w13_weight.data, (16, 16)),
             requires_grad=False,
         )
         torch.cuda.empty_cache()
         layer.w2_weight = torch.nn.Parameter(
-            # permute_weight(layer.w2_weight.data),
             shuffle_weight(layer.w2_weight.data, (16, 16)),
             requires_grad=False,
         )
@@ -847,23 +845,21 @@ def process_weights_hip_scale_padding(self, layer: Module):
             padding_size,  # Avoid circular import
         )
 
-        if get_bool_env_var("CK_MOE"):
+        if get_bool_env_var("SGLANG_AITER_MOE"):
             layer.w13_weight = torch.nn.Parameter(
-                # permute_weight(layer.w13_weight.data),
                 shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                # permute_weight(layer.w2_weight.data),
                 shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
-            # ROCm (CK_MOE): using column-wise scaling
+            # ROCm (SGLANG_AITER_MOE): using column-wise scaling
             layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
             layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
-        elif get_bool_env_var("MOE_PADDING"):
+        elif get_bool_env_var("SGLANG_MOE_PADDING"):
             # If ROCm, apply weight padding (min. Mem channel contention) only if set
             layer.w13_weight = torch.nn.Parameter(
                 F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
@@ -912,15 +908,16 @@ def apply(
         )
 
         if _is_hip:
-            if get_bool_env_var("USE_INT4_WEIGHT"):
-                # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
+            if get_bool_env_var("SGLANG_INT4_WEIGHT"):
+                # TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE")
                 assert not no_combine, f"{no_combine=} is not supported."
-                return ck_moe_2stages_win4(
+                return ck_moe_2stages(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
+                    QuantType.per_Token,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
                     activation=(
@@ -930,13 +927,13 @@ def apply(
                     ),
                 )
 
-            if get_bool_env_var("CK_MOE"):
+            if get_bool_env_var("SGLANG_AITER_MOE"):
                 assert not no_combine, f"{no_combine=} is not supported."
                 if self.block_quant:
-                    # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                    # TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
                     assert (
                         activation == "silu"
-                    ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
+                    ), f"SGLANG_AITER_MOE: FP8 bloack_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
                     return asm_moe(
                         x,
                         layer.w13_weight,
@@ -955,6 +952,7 @@ def apply(
                         layer.w2_weight,
                         topk_weights,
                         topk_ids,
+                        QuantType.per_Token,
                         layer.w13_weight_scale1,
                         layer.w2_weight_scale1,
                         activation=(
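
For readers tracking the flag rename in this diff (USE_INT4_WEIGHT → SGLANG_INT4_WEIGHT, CK_MOE → SGLANG_AITER_MOE, MOE_PADDING → SGLANG_MOE_PADDING), here is a minimal, hypothetical sketch of the ROCm dispatch these flags select. Only the flag names and the branch order come from the hunks above; `get_bool_env_var` below is a simplified stand-in with assumed "1"/"true" semantics, and `pick_moe_path` is an illustrative helper, not part of the PR.

```python
# Illustrative sketch only (not part of the diff above).
import os


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Assumed semantics: "1"/"true" (case-insensitive) enable the flag.
    return os.getenv(name, default).strip().lower() in ("1", "true")


def pick_moe_path(is_hip: bool) -> str:
    # Hypothetical summary of the branches touched in this diff.
    if is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
        # INT4-FP8 path: ck_moe_2stages(..., QuantType.per_Token, ...)
        return "aiter ck_moe_2stages (INT4 weights, per-token quant)"
    if is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
        # FP8 path: asm_moe for block-quant, ck_moe_2stages otherwise
        return "aiter asm_moe / ck_moe_2stages (FP8)"
    if is_hip and get_bool_env_var("SGLANG_MOE_PADDING"):
        # Weight padding to reduce memory-channel contention
        # (see process_weights_hip_scale_padding above)
        return "triton fused_moe with padded weights"
    return "triton fused_moe (default)"


if __name__ == "__main__":
    os.environ["SGLANG_AITER_MOE"] = "1"
    print(pick_moe_path(is_hip=True))  # -> aiter asm_moe / ck_moe_2stages (FP8)
```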