diff --git a/server/text_generation_server/inference_engine/tgis_native.py b/server/text_generation_server/inference_engine/tgis_native.py index 86291815..620be4b1 100644 --- a/server/text_generation_server/inference_engine/tgis_native.py +++ b/server/text_generation_server/inference_engine/tgis_native.py @@ -18,6 +18,7 @@ TP_NONFLASH_TYPES = ["bloom", "t5", "gpt_neox"] TP_FLASH_TYPES = NONTP_FLASH_TYPES # All flash types currently support TP NONTP_NONFLASH_TYPES = ["bloom", "t5"] +PAGED_TYPES = ["llama", "gpt_bigcode"] class InferenceEngine(BaseInferenceEngine): @@ -52,6 +53,11 @@ def __init__( raise NotImplementedError( f"Flash attention currently only supported by the following model types: {NONTP_FLASH_TYPES}" ) + elif PAGED_ATTENTION: + if model_type not in PAGED_TYPES: + raise NotImplementedError( + f"Paged attention currently only supported by the following model types: {PAGED_TYPES}" + ) elif model_type not in NONTP_NONFLASH_TYPES: raise ValueError("tgis_native engine must be used with FLASH_ATTENTION, num_shards > 1 and/or BLOOM or T5")