From ef7eb8821f14fbf27b3202b17e7f6e1495d66273 Mon Sep 17 00:00:00 2001 From: Annanya Date: Mon, 14 Apr 2025 11:52:47 -0400 Subject: [PATCH] [Refactor] Added a new spec for PagedKVCache --- python/tvm/relax/frontend/nn/spec.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/nn/spec.py b/python/tvm/relax/frontend/nn/spec.py index 54928ce07b80..c3ad630ad650 100644 --- a/python/tvm/relax/frontend/nn/spec.py +++ b/python/tvm/relax/frontend/nn/spec.py @@ -17,6 +17,7 @@ """Compilation specifications, for example, dynamic shape inputs.""" import inspect import typing +from .llm.kv_cache import PagedKVCache as _PagedKVCache if typing.TYPE_CHECKING: from .core import Module as nn_module_class @@ -64,6 +65,14 @@ def __repr__(self) -> str: return "object" +class PagedKVCache(Object): # pylint: disable=too-few-public-methods + """A specialized spec that extends Object for a paged key-value cache.""" + + def __init__(self) -> None: + # Pass the type of this class to the Object constructor. + super().__init__(object_type=_PagedKVCache) + + class Tuple: # pylint: disable=too-few-public-methods """A tuple input or a list input""" @@ -92,7 +101,7 @@ class MethodSpec: param_mode: str # "plain", "packed", "none" effect_mode: str # "plain", "packed", "none" - def __init__( # pylint: disable=too-many-arguments + def __init__( # pylint: disable=too-many-positional-arguments,too-many-arguments self, method: typing.Callable, arg_names: typing.List[str],