use fastsafetensors #99

Open · wants to merge 3 commits into base: main
8 changes: 7 additions & 1 deletion Dockerfile
@@ -319,6 +319,12 @@ COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/loca
 # Install launcher
 COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher

+# Install cufile.so, libnuma.so, and fastsafetensors
+RUN dnf config-manager \
+    --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \
+    && dnf install -y libcufile-12-1 numactl-libs
+RUN pip install -v fastsafetensors==0.1.1 --no-cache-dir
+
 ENV PORT=3000 \
     GRPC_PORT=8033 \
     HOME=/home/tgis
@@ -327,7 +333,7 @@ ENV PORT=3000 \
 RUN chmod -R g+rwx ${HOME}

 # Temporary for dev
-RUN chmod -R g+w ${SITE_PACKAGES}/text_generation_server /usr/src /usr/local/bin
+RUN chmod -R g+w ${SITE_PACKAGES}/text_generation_server /usr/src /usr/local/bin ${SITE_PACKAGES}/fastsafetensors

 # Run as non-root user by default
 USER tgis
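The two new RUN layers above install the GPUDirect Storage user-space library (libcufile), libnuma, and the fastsafetensors 0.1.1 wheel. As a quick sanity check, not part of this PR, the snippet below (standard library only) confirms the pieces are visible inside the built image:

# Sanity-check sketch (not part of the PR): confirm the packages installed above
# are visible inside the image; uses only the Python standard library.
import ctypes.util
import importlib.metadata

print("fastsafetensors:", importlib.metadata.version("fastsafetensors"))  # expect 0.1.1
print("libcufile:", ctypes.util.find_library("cufile"))  # None means the GDS library is not on the loader path
print("libnuma:", ctypes.util.find_library("numa"))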
14 changes: 10 additions & 4 deletions server/text_generation_server/inference_engine/tgis_native.py
@@ -8,7 +8,7 @@
 from transformers.models.auto.auto_factory import _BaseAutoModelClass

 from text_generation_server.models import FLASH_ATTENTION, PAGED_ATTENTION
-from text_generation_server.utils import Weights
+from text_generation_server.utils import Weights, FastWeights

 from text_generation_server.inference_engine import BaseInferenceEngine
 from text_generation_server.utils.dist import initialize_torch_distributed
@@ -123,9 +123,15 @@ def __init__(
         if not filenames:
             raise ValueError("No safetensors weights found - required for tgis_native engine")

-        weights = Weights(
-            filenames, device=self.device, dtype=dtype, process_group=self.process_group, aliases=aliases
-        )
+        use_fst = os.getenv("USE_FST")
+        if use_fst is not None and use_fst == "1":
+            weights = FastWeights(
+                filenames, device=self.device, dtype=dtype, pg=self.process_group, aliases=aliases,
+            )
+        else:
+            weights = Weights(
+                filenames, device=self.device, dtype=dtype, process_group=self.process_group, aliases=aliases
+            )

         if quantize == "gptq":
             weights._set_gptq_params(model_config, model_path)
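The new branch above is gated on the USE_FST environment variable; the FastWeights loader it selects (added below) reads the FST_* variables for tuning. A hedged sketch of opting in, with the variable names taken from this PR and the numeric values purely illustrative:

# Illustrative only: opt in to the fastsafetensors loader and override its auto-tuning.
# Variable names come from this PR; the numeric values are placeholders.
import os

os.environ["USE_FST"] = "1"               # tgis_native picks FastWeights instead of Weights
os.environ["FST_CONFIG"] = "manual"       # any value other than "auto" keeps the settings below from being overridden
os.environ["FST_NOGDS"] = "1"             # "1" forces the host bounce-buffer path instead of GDS
os.environ["FST_THREADS"] = "8"           # host-side copy threads
os.environ["FST_BBUF_SIZE_KB"] = "65536"  # total bounce-buffer size in KiB (used when GDS is off)
# ...then construct the tgis_native engine as usual; the branch above does the selection.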
1 change: 1 addition & 0 deletions server/text_generation_server/utils/__init__.py
@@ -7,6 +7,7 @@
     RANK,
 )
 from text_generation_server.utils.weights import Weights
+from text_generation_server.utils.fastweight import FastWeights
 from text_generation_server.utils.hub import (
     get_model_path,
     local_weight_files,
118 changes: 118 additions & 0 deletions server/text_generation_server/utils/fastweight.py
@@ -0,0 +1,118 @@
# Copyright 2024 IBM Inc. All rights reserved
# SPDX-License-Identifier: Apache-2.0

import os
import glob
import torch
import torch.distributed as dist
from typing import List, Optional, Dict, Tuple

def get_config(device_index: int) -> Tuple[bool, int, int]:
    auto_config = os.getenv("FST_CONFIG", "auto")
    nogds = os.getenv("FST_NOGDS")  # disable GDS if FST_NOGDS==1
    nogds = nogds is not None and nogds == "1"
    max_copier_threads = int(os.getenv("FST_THREADS", "16"))  # number of copy threads at host CPU
    bbuf_size_kb_total = int(os.getenv("FST_BBUF_SIZE_KB", "163840"))  # size of bounce buffer at host memory for FST_NOGDS==1
    if auto_config == "auto":
        nogds = not os.path.exists("/run/udev")  # udev directory is required for GDS
        from fastsafetensors.common import get_device_numa_node
        node = get_device_numa_node(device_index)
        total_l2_size = 0
        phys_cpus = {}
        failed = False
        for cpudir in glob.glob(f"/sys/devices/system/node/node{node}/cpu[0-9]*"):
            try:
                with open(f"{cpudir}/cache/index2/size") as f:  # L2 cache size for a cpu
                    size_str = f.read().strip()
                    if size_str[-1] != "K":
                        raise Exception(f"cannot parse {cpudir}/cache/index2/size")
                    total_l2_size += int(size_str[:-1])
                with open(f"{cpudir}/topology/core_id") as f:  # physical core ID
                    phys_cpus[f.read().strip()] = True
            except Exception as e:
                failed = True
                print(f"Failed to auto-configure fastsafetensors. reason: {e}")
                break
        if not failed and total_l2_size > 0:
            bbuf_size_kb_total = total_l2_size
        if not failed and len(phys_cpus) > 0:
            max_copier_threads = len(phys_cpus)
    return (nogds, max_copier_threads, bbuf_size_kb_total)

class FastWeights:
    def __init__(self, filenames: List[str],
                 device: torch.device,
                 dtype: torch.dtype,
                 pg: dist.ProcessGroup,
                 debug_log: bool = False,
                 aliases: Optional[Dict[str, List[str]]] = None,
                 prefix: Optional[str] = None,
                 ):
        from fastsafetensors.loader import SafeTensorsFileLoader
        (nogds, max_copier_threads, bbuf_size_kb_total) = get_config(device.index)
        self._loader = SafeTensorsFileLoader(pg, device, bbuf_size_kb=bbuf_size_kb_total // pg.size(), max_threads=max_copier_threads, nogds=nogds, debug_log=debug_log)
        rank_filenames: Dict[int, List[str]] = {rank: [] for rank in range(0, pg.size())}
        max_copy_block_size = 1
        total_size = 0
        for idx, filename in enumerate(sorted(filenames, key=lambda x: os.path.basename(x))):
            rank_filenames[idx % pg.size()].append(filename)
            s = os.stat(filename)
            total_size += s.st_size
            if max_copy_block_size < s.st_size:
                max_copy_block_size = s.st_size
        self._loader.add_filenames(rank_filenames)
        if len(filenames) < max_copier_threads:
            max_copy_block_size = total_size // pg.size() // max_copier_threads
            # round the block size up to a multiple of the bounce buffer size (in bytes)
            if max_copy_block_size % (bbuf_size_kb_total * 1024) > 0:
                max_copy_block_size = max_copy_block_size - max_copy_block_size % (bbuf_size_kb_total * 1024) + (bbuf_size_kb_total * 1024)
        msg = f"Fastsafetensors configuration: GDS={not nogds}, maximum number of file copy threads={max_copier_threads}, copy block size={max_copy_block_size}B"
        if nogds:
            msg += f", total bounce buffer size={bbuf_size_kb_total * 1024}B"
        print(msg)
        self._fb = self._loader.copy_files_to_device(dtype, max_copy_block_size=max_copy_block_size)
        self.device = device
        self.dtype = dtype
        if aliases is None:
            aliases = {}
        self.prefix = prefix
        self.aliases = aliases
        self.process_group = pg

    def close(self):
        self._fb.close()
        self._loader.close()
        torch.cuda.empty_cache()

    def _get_alias(self, tensor_name: str) -> str:
        if self._fb.get_filename(tensor_name) is None:
            for alias in self.aliases[tensor_name]:
                if self._fb.get_filename(alias) is not None:
                    return alias
            raise RuntimeError(f"weight {tensor_name} does not exist")
        return tensor_name

    def get_shape(self, tensor_name: str) -> torch.Size:
        return self._fb.get_shape(self._get_alias(tensor_name))

    def get_tensor(self, tensor_name: str) -> torch.Tensor:
        return self._fb.get_tensor(self._get_alias(tensor_name), device=self.device, dtype=self.dtype)

    def push_tensor(self, tensor_name: str, dst_rank: int) -> torch.Tensor:
        return self._fb.push_tensor(self._get_alias(tensor_name), dst_rank, device=self.device, dtype=self.dtype)

    def get_partial_sharded(self, tensor_name: str, dim: int) -> torch.Tensor:
        return self._fb.get_sharded(self._get_alias(tensor_name), dim, device=self.device, dtype=self.dtype)

    def get_sharded(self, tensor_name: str, dim: int = 1) -> torch.Tensor:
        return self._fb.get_sharded(self._get_alias(tensor_name), dim, device=self.device, dtype=self.dtype)

    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int) -> torch.Tensor:
        if quantize in ["gptq", "awq"]:
            raise NotImplementedError("Quantization is not supported yet")
        tensor_names = [self._get_alias(f"{prefix}.weight") for prefix in prefixes]
        return self._fb.get_multi_cols(tensor_names, dim, device=self.device, dtype=self.dtype)

    def get_multi_weights_row(self, prefix: str, quantize: str) -> torch.Tensor:
        if quantize in ["gptq", "awq"]:
            raise NotImplementedError("Quantization is not supported yet")
        return self._fb.get_sharded(self._get_alias(f"{prefix}.weight"), 1, device=self.device, dtype=self.dtype)
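A minimal usage sketch of the class above, not part of the PR: it assumes a single-rank process group, one visible CUDA device, and placeholder model path and tensor names, and it only calls methods defined in this file.

# Usage sketch (assumptions: single-rank gloo group, one CUDA device,
# *.safetensors files on local disk; the path and tensor names are placeholders).
import glob
import torch
import torch.distributed as dist
from text_generation_server.utils import FastWeights

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501", world_size=1, rank=0)
pg = dist.group.WORLD  # default process group; tgis_native passes self.process_group here

filenames = sorted(glob.glob("/models/my-model/*.safetensors"))  # placeholder path
weights = FastWeights(filenames, device=torch.device("cuda:0"), dtype=torch.float16, pg=pg)

emb = weights.get_tensor("model.embed_tokens.weight")  # placeholder tensor name
o_proj = weights.get_multi_weights_row("model.layers.0.self_attn.o_proj", quantize=None)
weights.close()  # frees the device-side file buffers and the loader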
11 changes: 7 additions & 4 deletions server/text_generation_server/utils/layers.py
@@ -305,11 +305,14 @@ def __init__(self, linear, process_group):
     @classmethod
     def load(cls, config, prefix: str, weights, bias: bool):
         weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
-        if bias and weights.process_group.rank() == 0:
-            # Rank is only on the first rank process
-            bias = weights.get_tensor(f"{prefix}.bias")
+        if hasattr(weights, "push_tensor"):
+            bias = weights.push_tensor(f"{prefix}.bias", 0) if bias else None
         else:
-            bias = None
+            if bias and weights.process_group.rank() == 0:
+                # Rank is only on the first rank process
+                bias = weights.get_tensor(f"{prefix}.bias")
+            else:
+                bias = None
         return cls(
             get_linear(weight, bias, config.quantize),
             process_group=weights.process_group,
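The reshuffled bias handling above reflects how the two loaders place tensors: with Weights every rank maps every safetensors file, so rank 0 can read the bias directly, while FastWeights partitions the files across ranks, so the rank that holds the bias has to hand it to rank 0, which is what push_tensor does. The sketch below restates that control flow outside the class; the helper name is hypothetical, not code from this PR.

# Paraphrase of the bias branch in TensorParallelRowLinear.load above;
# load_row_bias is a hypothetical helper name, not part of this PR.
def load_row_bias(weights, prefix: str, bias: bool):
    if hasattr(weights, "push_tensor"):
        # fastsafetensors path: the rank owning the file that holds the bias
        # sends it to dst_rank=0; rank 0 then passes it to get_linear.
        return weights.push_tensor(f"{prefix}.bias", 0) if bias else None
    if bias and weights.process_group.rank() == 0:
        # classic path: every rank mapped every file, so rank 0 reads it directly.
        return weights.get_tensor(f"{prefix}.bias")
    return None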