From 519947ec909a9bbf9ec3f878035f682448437577 Mon Sep 17 00:00:00 2001 From: Takeshi Yoshimura Date: Fri, 24 May 2024 03:45:00 +0000 Subject: [PATCH 1/3] use fastsafetensors Signed-off-by: Takeshi Yoshimura --- Dockerfile | 6 ++++++ .../inference_engine/tgis_native.py | 18 +++++++++++++++--- server/text_generation_server/utils/layers.py | 11 +++++++---- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/Dockerfile b/Dockerfile index 624f1bd5..3f19fc1d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -319,6 +319,12 @@ COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/loca # Install launcher COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher +# Install cufile.so, libnuma.so, and fastsafetensors +RUN dnf config-manager \ + --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \ + && dnf install -y libcufile-12-1 numactl-libs +RUN pip install -v fastsafetensors==0.1.0 --no-cache-dir + ENV PORT=3000 \ GRPC_PORT=8033 \ HOME=/home/tgis diff --git a/server/text_generation_server/inference_engine/tgis_native.py b/server/text_generation_server/inference_engine/tgis_native.py index 86291815..d5a09d37 100644 --- a/server/text_generation_server/inference_engine/tgis_native.py +++ b/server/text_generation_server/inference_engine/tgis_native.py @@ -14,6 +14,8 @@ from text_generation_server.utils.dist import initialize_torch_distributed from text_generation_server.utils.hub import local_weight_files +from fastsafetensors.connectors.tgis_weights import Weights as FastWeights + NONTP_FLASH_TYPES = ["RefinedWeb", "RefinedWebModel", "gpt_neox", "gpt_bigcode", "llama", "falcon"] TP_NONFLASH_TYPES = ["bloom", "t5", "gpt_neox"] TP_FLASH_TYPES = NONTP_FLASH_TYPES # All flash types currently support TP @@ -123,9 +125,19 @@ def __init__( if not filenames: raise ValueError("No safetensors weights found - required for tgis_native engine") - weights = Weights( - filenames, device=self.device, dtype=dtype, process_group=self.process_group, aliases=aliases - ) + use_fst = os.getenv("USE_FST") + if use_fst is not None and use_fst == "1": + nogds = os.getenv("FST_NOGDS") # disable GDS if FST_NOGDS==1 + max_threads = int(os.getenv("FST_THREADS", "16")) # number of copy threads at host CPU + bbuf_size_kb = int(os.getenv("FST_BBUF_SIZE_KB", "163840")) # size of bounce buffer at host memory for FST_NOGDS==1 + nogds = nogds is not None and nogds == "1" + weights = FastWeights( + filenames, device=self.device, dtype=dtype, pg=self.process_group, aliases=aliases, nogds=nogds, max_copier_threads=max_threads, bbuf_size_kb_total=bbuf_size_kb, + ) + else: + weights = Weights( + filenames, device=self.device, dtype=dtype, process_group=self.process_group, aliases=aliases + ) if quantize == "gptq": weights._set_gptq_params(model_config, model_path) diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 312f4d5d..5081874f 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -305,11 +305,14 @@ def __init__(self, linear, process_group): @classmethod def load(cls, config, prefix: str, weights, bias: bool): weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) - if bias and weights.process_group.rank() == 0: - # Rank is only on the first rank process - bias = weights.get_tensor(f"{prefix}.bias") + if hasattr(weights, "push_tensor"): + bias = 
weights.push_tensor(f"{prefix}.bias", 0) if bias else None else: - bias = None + if bias and weights.process_group.rank() == 0: + # Rank is only on the first rank process + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None return cls( get_linear(weight, bias, config.quantize), process_group=weights.process_group, From ff0ebd5b574771cafac3b6dea932000ba57565dd Mon Sep 17 00:00:00 2001 From: Takeshi Yoshimura Date: Mon, 27 May 2024 07:16:48 +0000 Subject: [PATCH 2/3] add fastweight.py Signed-off-by: Takeshi Yoshimura --- Dockerfile | 4 +- .../inference_engine/tgis_native.py | 4 +- .../text_generation_server/utils/__init__.py | 1 + .../utils/fastweight.py | 83 +++++++++++++++++++ 4 files changed, 87 insertions(+), 5 deletions(-) create mode 100644 server/text_generation_server/utils/fastweight.py diff --git a/Dockerfile b/Dockerfile index 3f19fc1d..91869851 100644 --- a/Dockerfile +++ b/Dockerfile @@ -323,7 +323,7 @@ COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/ RUN dnf config-manager \ --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \ && dnf install -y libcufile-12-1 numactl-libs -RUN pip install -v fastsafetensors==0.1.0 --no-cache-dir +RUN pip install -v fastsafetensors==0.1.1 --no-cache-dir ENV PORT=3000 \ GRPC_PORT=8033 \ @@ -333,7 +333,7 @@ ENV PORT=3000 \ RUN chmod -R g+rwx ${HOME} # Temporary for dev -RUN chmod -R g+w ${SITE_PACKAGES}/text_generation_server /usr/src /usr/local/bin +RUN chmod -R g+w ${SITE_PACKAGES}/text_generation_server /usr/src /usr/local/bin ${SITE_PACKAGES}/fastsafetensors # Run as non-root user by default USER tgis diff --git a/server/text_generation_server/inference_engine/tgis_native.py b/server/text_generation_server/inference_engine/tgis_native.py index d5a09d37..11a43146 100644 --- a/server/text_generation_server/inference_engine/tgis_native.py +++ b/server/text_generation_server/inference_engine/tgis_native.py @@ -8,14 +8,12 @@ from transformers.models.auto.auto_factory import _BaseAutoModelClass from text_generation_server.models import FLASH_ATTENTION, PAGED_ATTENTION -from text_generation_server.utils import Weights +from text_generation_server.utils import Weights, FastWeights from text_generation_server.inference_engine import BaseInferenceEngine from text_generation_server.utils.dist import initialize_torch_distributed from text_generation_server.utils.hub import local_weight_files -from fastsafetensors.connectors.tgis_weights import Weights as FastWeights - NONTP_FLASH_TYPES = ["RefinedWeb", "RefinedWebModel", "gpt_neox", "gpt_bigcode", "llama", "falcon"] TP_NONFLASH_TYPES = ["bloom", "t5", "gpt_neox"] TP_FLASH_TYPES = NONTP_FLASH_TYPES # All flash types currently support TP diff --git a/server/text_generation_server/utils/__init__.py b/server/text_generation_server/utils/__init__.py index a4f4d28d..6b48e1c4 100644 --- a/server/text_generation_server/utils/__init__.py +++ b/server/text_generation_server/utils/__init__.py @@ -7,6 +7,7 @@ RANK, ) from text_generation_server.utils.weights import Weights +from text_generation_server.utils.fastweight import FastWeights from text_generation_server.utils.hub import ( get_model_path, local_weight_files, diff --git a/server/text_generation_server/utils/fastweight.py b/server/text_generation_server/utils/fastweight.py new file mode 100644 index 00000000..bdf2573a --- /dev/null +++ b/server/text_generation_server/utils/fastweight.py @@ -0,0 +1,83 @@ +# Copyright 2024 IBM Inc. 
All rights reserved +# SPDX-License-Identifier: Apache-2.0 + +import os +import torch +import torch.distributed as dist +from typing import List, Tuple, Optional, Dict +from fastsafetensors.loader import SafeTensorsFileLoader + +class FastWeights: + def __init__(self, filenames:List[str], + device: torch.device, + dtype: torch.dtype, + pg: dist.ProcessGroup, + debug_log: bool=False, + aliases: Optional[Dict[str, List[str]]] = None, + prefix: Optional[str] = None, + nogds: bool = False, + max_copier_threads: int = 16, # should be same as the number of physical CPUs on a node + bbuf_size_kb_total = 160 * 1024, # should be same as L2 cache size + ): + self._loader = SafeTensorsFileLoader(pg, device, bbuf_size_kb=bbuf_size_kb_total//pg.size(), max_threads=max_copier_threads, nogds=nogds, debug_log=debug_log) + rank_filenames: Dict[str, List[str]] = {rank: [] for rank in range(0, pg.size())} + max_copy_block_size = 1 + total_size = 0 + for idx, filename in enumerate(sorted(filenames, key=lambda x: os.path.basename(x))): + rank_filenames[idx % pg.size()].append(filename) + s = os.stat(filename) + total_size += s.st_size + if max_copy_block_size < s.st_size: + max_copy_block_size = s.st_size + self._loader.add_filenames(rank_filenames) + if len(filenames) < max_copier_threads: + max_copy_block_size = total_size // pg.size() // max_copier_threads + if max_copy_block_size % bbuf_size_kb_total*1024 > 0: + max_copy_block_size = max_copy_block_size - max_copy_block_size % (bbuf_size_kb_total*1024) + (bbuf_size_kb_total*1024) + self._fb = self._loader.copy_files_to_device(dtype, max_copy_block_size=max_copy_block_size) + self.device = device + self.dtype = dtype + if aliases is None: + aliases = {} + self.prefix = prefix + self.aliases = aliases + self.process_group = pg + + def close(self): + self._fb.close() + self._loader.close() + torch.cuda.empty_cache() + + def _get_alias(self, tensor_name: str)->str: + if self._fb.get_filename(tensor_name) is None: + for alias in self.aliases[tensor_name]: + if self._fb.get_filename(alias) is not None: + return alias + raise RuntimeError(f"weight {tensor_name} does not exist") + return tensor_name + + def get_shape(self, tensor_name: str)->torch.Size: + return self._fb.get_shape(self._get_alias(tensor_name)) + + def get_tensor(self, tensor_name: str)->torch.Tensor: + return self._fb.get_tensor(self._get_alias(tensor_name), device=self.device, dtype=self.dtype) + + def push_tensor(self, tensor_name: str, dst_rank: int)->torch.Tensor: + return self._fb.push_tensor(self._get_alias(tensor_name), dst_rank, device=self.device, dtype=self.dtype) + + def get_partial_sharded(self, tensor_name: str, dim: int)->torch.Tensor: + return self._fb.get_sharded(self._get_alias(tensor_name), dim, device=self.device, dtype=self.dtype) + + def get_sharded(self, tensor_name: str, dim: int=1)->torch.Tensor: + return self._fb.get_sharded(self._get_alias(tensor_name), dim, device=self.device, dtype=self.dtype) + + def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int)->torch.Tensor: + if quantize in ["gptq", "awq"]: + raise NotImplementedError("Quantization is not supported yet") + tensor_names = [self._get_alias(f"{prefix}.weight") for prefix in prefixes] + return self._fb.get_multi_cols(tensor_names, dim, device=self.device, dtype=self.dtype) + + def get_multi_weights_row(self, prefix: str, quantize: str)->torch.Tensor: + if quantize in ["gptq", "awq"]: + raise NotImplementedError("Quantization is not supported yet") + return 
self._fb.get_sharded(self._get_alias(f"{prefix}.weight"), 1, device=self.device, dtype=self.dtype) \ No newline at end of file From 1f295eb44d48ea22717e57db05965f8c7b7cb401 Mon Sep 17 00:00:00 2001 From: Takeshi Yoshimura Date: Mon, 27 May 2024 08:30:45 +0000 Subject: [PATCH 3/3] Add get_config and auto-configuration Signed-off-by: Takeshi Yoshimura --- .../inference_engine/tgis_native.py | 6 +-- .../utils/fastweight.py | 45 ++++++++++++++++--- 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/server/text_generation_server/inference_engine/tgis_native.py b/server/text_generation_server/inference_engine/tgis_native.py index 11a43146..fbd0ab02 100644 --- a/server/text_generation_server/inference_engine/tgis_native.py +++ b/server/text_generation_server/inference_engine/tgis_native.py @@ -125,12 +125,8 @@ def __init__( use_fst = os.getenv("USE_FST") if use_fst is not None and use_fst == "1": - nogds = os.getenv("FST_NOGDS") # disable GDS if FST_NOGDS==1 - max_threads = int(os.getenv("FST_THREADS", "16")) # number of copy threads at host CPU - bbuf_size_kb = int(os.getenv("FST_BBUF_SIZE_KB", "163840")) # size of bounce buffer at host memory for FST_NOGDS==1 - nogds = nogds is not None and nogds == "1" weights = FastWeights( - filenames, device=self.device, dtype=dtype, pg=self.process_group, aliases=aliases, nogds=nogds, max_copier_threads=max_threads, bbuf_size_kb_total=bbuf_size_kb, + filenames, device=self.device, dtype=dtype, pg=self.process_group, aliases=aliases, ) else: weights = Weights( diff --git a/server/text_generation_server/utils/fastweight.py b/server/text_generation_server/utils/fastweight.py index bdf2573a..0686a1c9 100644 --- a/server/text_generation_server/utils/fastweight.py +++ b/server/text_generation_server/utils/fastweight.py @@ -2,10 +2,42 @@ # SPDX-License-Identifier: Apache-2.0 import os +import glob import torch import torch.distributed as dist -from typing import List, Tuple, Optional, Dict -from fastsafetensors.loader import SafeTensorsFileLoader +from typing import List, Optional, Dict, Tuple + +def get_config(device_index: int) -> Tuple[bool, int, int]: + auto_config = os.getenv("FST_CONFIG", "auto") + nogds = os.getenv("FST_NOGDS") # disable GDS if FST_NOGDS==1 + nogds = nogds is not None and nogds == "1" + max_copier_threads = int(os.getenv("FST_THREADS", "16")) # number of copy threads at host CPU + bbuf_size_kb_total = int(os.getenv("FST_BBUF_SIZE_KB", "163840")) # size of bounce buffer at host memory for FST_NOGDS==1 + if auto_config == "auto": + nogds = not os.path.exists("/run/udev") # udev directory is required for GDS + from fastsafetensors.common import get_device_numa_node + node = get_device_numa_node(device_index) + total_l2_size = 0 + phys_cpus = {} + failed = False + for cpudir in glob.glob(f"/sys/devices/system/node/node{node}/cpu[0-9]*"): + try: + with open(f"{cpudir}/cache/index2/size") as f: # L2 cache size for a cpu + size_str = f.read().strip() + if size_str[-1] != "K": + raise Exception(f"cannot parse {cpudir}/cache/index2/size") + total_l2_size += int(size_str[:-1]) + with open(f"{cpudir}/topology/core_id") as f: # physical core ID + phys_cpus[f.read().strip()] = True + except Exception as e: + failed = True + print(f"Failed to auto-configure fastsafetensors. 
reason: {e}") + break + if not failed and total_l2_size > 0: + bbuf_size_kb_total = total_l2_size + if not failed and len(phys_cpus) > 0: + max_copier_threads = len(phys_cpus) + return (nogds, max_copier_threads, bbuf_size_kb_total) class FastWeights: def __init__(self, filenames:List[str], @@ -15,10 +47,9 @@ def __init__(self, filenames:List[str], debug_log: bool=False, aliases: Optional[Dict[str, List[str]]] = None, prefix: Optional[str] = None, - nogds: bool = False, - max_copier_threads: int = 16, # should be same as the number of physical CPUs on a node - bbuf_size_kb_total = 160 * 1024, # should be same as L2 cache size ): + from fastsafetensors.loader import SafeTensorsFileLoader + (nogds, max_copier_threads, bbuf_size_kb_total) = get_config(device.index) self._loader = SafeTensorsFileLoader(pg, device, bbuf_size_kb=bbuf_size_kb_total//pg.size(), max_threads=max_copier_threads, nogds=nogds, debug_log=debug_log) rank_filenames: Dict[str, List[str]] = {rank: [] for rank in range(0, pg.size())} max_copy_block_size = 1 @@ -34,6 +65,10 @@ def __init__(self, filenames:List[str], max_copy_block_size = total_size // pg.size() // max_copier_threads if max_copy_block_size % bbuf_size_kb_total*1024 > 0: max_copy_block_size = max_copy_block_size - max_copy_block_size % (bbuf_size_kb_total*1024) + (bbuf_size_kb_total*1024) + msg = f"Fastsafetensors configuration: GDS={not nogds}, maximum number of file copy threads={max_copier_threads}, copy block size={max_copy_block_size}B" + if nogds: + msg += f", total bounce buffer size={bbuf_size_kb_total * 1024}B" + print(msg) self._fb = self._loader.copy_files_to_device(dtype, max_copy_block_size=max_copy_block_size) self.device = device self.dtype = dtype
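
Usage notes for this series (illustrative sketches, not part of the patches):

The fastsafetensors path is opt-in and driven entirely by environment variables read in tgis_native.py and fastweight.py: USE_FST enables it, FST_CONFIG selects auto-tuning (patch 3, default "auto"), and FST_NOGDS, FST_THREADS and FST_BBUF_SIZE_KB are the manual knobs. GDS additionally needs libcufile and a udev-visible host, which is why the Dockerfile installs libcufile-12-1 and the auto-configuration checks /run/udev. A minimal sketch of setting the knobs before the engine starts; the value "manual" is only a placeholder for "anything other than auto":

    import os

    os.environ["USE_FST"] = "1"                # opt in to the FastWeights loader
    os.environ["FST_CONFIG"] = "manual"        # placeholder: any value other than "auto" skips sysfs probing
    os.environ["FST_NOGDS"] = "1"              # 1 = host bounce buffers instead of GPUDirect Storage
    os.environ["FST_THREADS"] = "16"           # file-copy threads on the host CPU
    os.environ["FST_BBUF_SIZE_KB"] = "163840"  # total bounce-buffer size in KiB (used when FST_NOGDS=1)

    # The tgis_native engine then builds FastWeights instead of Weights
    # when it sees USE_FST=1.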
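
The layers.py change selects the bias-loading path by capability rather than by type: FastWeights exposes push_tensor() while the original Weights class does not, so hasattr() keeps the call site unchanged. A standalone restatement of that dispatch; the helper name load_row_bias is illustrative only:

    def load_row_bias(weights, prefix: str, want_bias: bool):
        if hasattr(weights, "push_tensor"):
            # fastsafetensors path (patch 1): ask the loader to deliver the
            # bias tensor to rank 0.
            return weights.push_tensor(f"{prefix}.bias", 0) if want_bias else None
        if want_bias and weights.process_group.rank() == 0:
            # Original path: only the first rank reads and keeps the bias.
            return weights.get_tensor(f"{prefix}.bias")
        return None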
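
FastWeights (patch 2) wraps fastsafetensors' SafeTensorsFileLoader behind the interface the models already use (get_tensor, get_sharded, get_multi_weights_col/row), spreading the *.safetensors files round-robin across the ranks of the process group. A single-rank sketch, assuming fastsafetensors 0.1.1 as pinned in the Dockerfile and a CUDA-capable host; the file and tensor names are placeholders:

    import torch
    import torch.distributed as dist
    from text_generation_server.utils import FastWeights

    # In TGIS the process group comes from initialize_torch_distributed();
    # here a one-rank NCCL group stands in for it.
    dist.init_process_group("nccl", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

    weights = FastWeights(
        ["/models/example/model-00001-of-00002.safetensors",   # placeholder shard paths
         "/models/example/model-00002-of-00002.safetensors"],
        device=torch.device("cuda:0"),
        dtype=torch.float16,
        pg=dist.group.WORLD,
    )
    emb = weights.get_tensor("model.embed_tokens.weight")                 # full tensor on this device
    up = weights.get_sharded("model.layers.0.mlp.up_proj.weight", dim=0)  # this rank's shard
    weights.close()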
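
When there are fewer files than copy threads, the FastWeights constructor derives max_copy_block_size from total_size // pg.size() // max_copier_threads and then rounds it up to a whole number of bounce buffers before calling copy_files_to_device(). A small helper, not part of the patches, showing that round-up arithmetic with assumed example numbers:

    def round_up(value: int, unit: int) -> int:
        # Round value up to the next multiple of unit.
        if value % unit > 0:
            value = value - value % unit + unit
        return value

    bbuf_size_kb_total = 160 * 1024            # 160 MiB bounce buffer, the FST_BBUF_SIZE_KB default
    unit = bbuf_size_kb_total * 1024           # same budget in bytes
    block = 14 * 1024**3 // 2 // 16            # e.g. 14 GiB of weights, 2 ranks, 16 copy threads
    print(round_up(block, unit))               # 503316480 B, i.e. 3 bounce buffers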
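
With FST_CONFIG=auto (the default after patch 3), get_config() disables GDS when /run/udev is missing, sizes the bounce buffer from the combined L2 cache of the CPUs on the GPU's NUMA node, and sets the copy-thread count to the number of physical cores there. A standalone sketch of the same sysfs probing for node 0, with the patch's error handling omitted; in the patch the NUMA node is resolved via fastsafetensors' get_device_numa_node():

    import glob

    node = 0  # NUMA node of the target GPU (placeholder)
    total_l2_kb = 0
    physical_cores = set()
    for cpudir in glob.glob(f"/sys/devices/system/node/node{node}/cpu[0-9]*"):
        with open(f"{cpudir}/cache/index2/size") as f:   # per-CPU L2 size, e.g. "2048K"
            size = f.read().strip()
            if size.endswith("K"):
                total_l2_kb += int(size[:-1])
        with open(f"{cpudir}/topology/core_id") as f:    # hyperthreads share a core_id
            physical_cores.add(f.read().strip())

    print(f"bounce buffer: {total_l2_kb} KiB, copy threads: {len(physical_cores)}")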