
Commit 9b42933

llama-cpp API server added

1 parent 5393c85, commit 9b42933
38 files changed: +1413 -394 lines

.dockerignore (+2)

@@ -0,0 +1,2 @@
+.env
+llama_models/*

.gitignore (+4 -1)

@@ -4,5 +4,8 @@ PRIVATE_*
 venv/
 *.pyc
 *.log
-*.bin
+llama_models/ggml/*
+llama_models/gptq/*
+!llama_models/ggml/llama_cpp_models_here.txt
+!llama_models/gptq/gptq_models_here.txt
 deprecated_*

app/common/app_settings.py (+89 -3)

@@ -1,3 +1,11 @@
+from concurrent.futures import ProcessPoolExecutor
+from concurrent.futures.process import BrokenProcessPool
+from threading import Event
+from threading import Thread
+from time import sleep
+from urllib import parse
+
+import requests
 from fastapi import Depends, FastAPI
 from fastapi.staticfiles import StaticFiles
 from starlette.middleware import Middleware
@@ -15,14 +23,65 @@
 from app.dependencies import USER_DEPENDENCY, api_service_dependency
 from app.middlewares.token_validator import access_control
 from app.middlewares.trusted_hosts import TrustedHostMiddleware
-from app.routers import auth, index, services, users, user_services, websocket
+from app.routers import auth, index, services, user_services, users, websocket
 from app.shared import Shared
 from app.utils.chat.managers.cache import CacheManager
 from app.utils.js_initializer import js_url_initializer
 from app.utils.logger import api_logger
 from app.viewmodels.admin import ApiKeyAdminView, UserAdminView
 
 
+def check_health(url: str) -> bool:
+    try:
+        schema = parse.urlparse(url).scheme
+        netloc = parse.urlparse(url).netloc
+        if requests.get(f"{schema}://{netloc}/health").status_code != 200:
+            return False
+        return True
+    except Exception:
+        return False
+
+
+def start_llama_cpp_server():
+    from app.start_llama_cpp_server import run
+
+    api_logger.critical("Starting Llama CPP server")
+    try:
+        Shared().process_pool_executor.submit(
+            run,
+            terminate_event=Shared().process_terminate_signal,
+        )
+    except BrokenProcessPool as e:
+        api_logger.exception(f"Broken Llama CPP server: {e}")
+        Shared().process_pool_executor.shutdown(wait=False)
+        Shared().process_pool_executor = ProcessPoolExecutor()
+        start_llama_cpp_server()
+    except Exception as e:
+        api_logger.exception(f"Failed to start Llama CPP server: {e}")
+
+
+def shutdown_llama_cpp_server():
+    api_logger.critical("Shutting down Llama CPP server")
+    Shared().process_terminate_signal.set()
+
+
+def monitor_llama_cpp_server(config: Config, terminate_signal: Event) -> None:
+    while not terminate_signal.is_set():
+        sleep(0.5)
+        if config.llama_cpp_api_url:
+            if not check_health(config.llama_cpp_api_url):
+                if config.is_llama_cpp_booting or terminate_signal.is_set():
+                    continue
+                api_logger.error("Llama CPP server is not available")
+                config.llama_cpp_available = False
+                config.is_llama_cpp_booting = True
+                start_llama_cpp_server()
+            else:
+                config.is_llama_cpp_booting = False
+                config.llama_cpp_available = True
+    shutdown_llama_cpp_server()
+
+
 def create_app(config: Config) -> FastAPI:
     # Initialize app & db & js
     new_app = FastAPI(
@@ -132,11 +191,38 @@ async def startup():
         except ImportError:
             api_logger.critical("uvloop not installed!")
 
+        if config.llama_cpp_api_url:
+            # Start Llama CPP server monitoring
+            api_logger.critical("Llama CPP server monitoring started!")
+            Shared().thread = Thread(
+                target=monitor_llama_cpp_server,
+                args=(config, Shared().thread_terminate_signal),
+            )
+            Shared().thread.start()
+
     @new_app.on_event("shutdown")
     async def shutdown():
         # await CacheManager.delete_user(f"testaccount@{HOST_MAIN}")
-        Shared().process_manager.shutdown()
-        Shared().process_pool_executor.shutdown()
+        Shared().thread_terminate_signal.set()
+        Shared().process_terminate_signal.set()
+
+        process_manager = Shared()._process_manager
+        if process_manager is not None:
+            process_manager.shutdown()
+
+        process_pool_executor = Shared()._process_pool_executor
+        if process_pool_executor is not None:
+            process_pool_executor.shutdown(wait=False)
+
+        process = Shared()._process
+        if process is not None:
+            process.terminate()
+            process.join()
+
+        thread = Shared()._thread
+        if thread is not None:
+            thread.join()
+
         await db.close()
         await cache.close()
         api_logger.critical("DB & CACHE connection closed!")

app/common/config.py (+12 -2)

@@ -6,9 +6,11 @@
 from pathlib import Path
 from re import Pattern, compile
 from typing import Optional
+from urllib import parse
+
+import requests
 from aiohttp import ClientTimeout
 from dotenv import load_dotenv
-from urllib import parse
 
 load_dotenv()
 
@@ -141,8 +143,11 @@ class Config(metaclass=SingletonMetaClass):
     shared_vectorestore_name: str = QDRANT_COLLECTION
     trusted_hosts: list[str] = field(default_factory=lambda: ["*"])
     allowed_sites: list[str] = field(default_factory=lambda: ["*"])
+    llama_cpp_api_url: Optional[str] = "http://localhost:8002/v1/completions"
 
     def __post_init__(self):
+        self.llama_cpp_available: bool = self.llama_cpp_api_url is not None
+        self.is_llama_cpp_booting: bool = False
         if not DOCKER_MODE:
             self.port = 8001
             self.mysql_host = "localhost"
@@ -248,7 +253,12 @@ class ChatConfig:
     timeout: ClientTimeout = ClientTimeout(sock_connect=30.0, sock_read=20.0)
     read_timeout: float = 30.0  # wait for this time before timeout
     wait_for_reconnect: float = 3.0  # wait for this time before reconnecting
-    api_regex_pattern: Pattern = compile(r"data:\s*({.+?})\n\n")
+    api_regex_pattern_openai: Pattern = compile(
+        r"data:\s*({.+?})\n\n"
+    )  # regex pattern to extract json from openai api response
+    api_regex_pattern_llama_cpp: Pattern = compile(
+        r"data:\s*({.+?})\r\n\r\n"
+    )  # regex pattern to extract json from llama cpp api response
     extra_token_margin: int = (
         512  # number of tokens to remove when tokens exceed token limit
    )
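The two stream-parsing patterns added above differ only in the event delimiter: the OpenAI pattern expects each data: chunk to end with \n\n, while the llama-cpp pattern expects \r\n\r\n. A small self-contained sketch; the JSON payloads are made up for illustration and only the delimiters matter:

import re

api_regex_pattern_openai = re.compile(r"data:\s*({.+?})\n\n")
api_regex_pattern_llama_cpp = re.compile(r"data:\s*({.+?})\r\n\r\n")

# Fabricated SSE chunks; only the trailing delimiters are significant here.
openai_chunk = 'data: {"choices": [{"text": "Hello"}]}\n\n'
llama_cpp_chunk = 'data: {"choices": [{"text": "Hello"}]}\r\n\r\n'

print(api_regex_pattern_openai.search(openai_chunk).group(1))        # extracted JSON object
print(api_regex_pattern_llama_cpp.search(llama_cpp_chunk).group(1))  # extracted JSON object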

app/common/constants.py (+2 -2)

@@ -11,7 +11,7 @@ class QueryTemplates:
             "\n---\n"
             "{context}"
             "\n---\n"
-            "Answer the question in detail: {question}\n"
+            "Answer the question in as much detail as possible: {question}\n"
         ),
         input_variables=["context", "question"],
         template_format="f-string",
@@ -23,7 +23,7 @@ class QueryTemplates:
             "{context}"
             "\n---\n"
             "Given the context information and not prior knowledge, "
-            "answer the question in detail: {question}\n"
+            "answer the question in as much detail as possible: {question}\n"
         ),
         input_variables=["context", "question"],
         template_format="f-string",

app/contents/llama_api.png (added, binary, 7.29 KB)

app/database/schemas/auth.py (+11 -16)

@@ -11,23 +11,12 @@
     Mapped,
     mapped_column,
 )
+
+from app.viewmodels.status import ApiKeyStatus, UserStatus
 from .. import Base
 from . import Mixin
 
 
-class UserStatus(str, enum.Enum):
-    admin = "admin"
-    active = "active"
-    deleted = "deleted"
-    blocked = "blocked"
-
-
-class ApiKeyStatus(str, enum.Enum):
-    active = "active"
-    stopped = "stopped"
-    deleted = "deleted"
-
-
 class Users(Base, Mixin):
     __tablename__ = "users"
     status: Mapped[str] = mapped_column(Enum(UserStatus), default=UserStatus.active)
@@ -37,7 +26,9 @@ class Users(Base, Mixin):
     phone_number: Mapped[str | None] = mapped_column(String(length=20))
     profile_img: Mapped[str | None] = mapped_column(String(length=100))
     marketing_agree: Mapped[bool] = mapped_column(Boolean, default=True)
-    api_keys: Mapped["ApiKeys"] = relationship(back_populates="users", cascade="all, delete-orphan", lazy=True)
+    api_keys: Mapped["ApiKeys"] = relationship(
+        back_populates="users", cascade="all, delete-orphan", lazy=True
+    )
     # chat_rooms: Mapped["ChatRooms"] = relationship(back_populates="users", cascade="all, delete-orphan", lazy=True)
     # chat_messages: Mapped["ChatMessages"] = relationship(
@@ -56,12 +47,16 @@ class ApiKeys(Base, Mixin):
     is_whitelisted: Mapped[bool] = mapped_column(default=False)
     user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"))
     users: Mapped["Users"] = relationship(back_populates="api_keys")
-    whitelists: Mapped["ApiWhiteLists"] = relationship(backref="api_keys", cascade="all, delete-orphan")
+    whitelists: Mapped["ApiWhiteLists"] = relationship(
+        backref="api_keys", cascade="all, delete-orphan"
+    )
 
 
 class ApiWhiteLists(Base, Mixin):
     __tablename__ = "api_whitelists"
-    api_key_id: Mapped[int] = mapped_column(Integer, ForeignKey("api_keys.id", ondelete="CASCADE"))
+    api_key_id: Mapped[int] = mapped_column(
+        Integer, ForeignKey("api_keys.id", ondelete="CASCADE")
+    )
     ip_address: Mapped[str] = mapped_column(String(length=64))
 
 
app/models/base_models.py (+4 -4)

@@ -6,8 +6,8 @@
 from pydantic import Field
 from pydantic.main import BaseModel
 
-from app.database.schemas.auth import UserStatus
 from app.utils.date_utils import UTC
+from app.viewmodels.status import UserStatus
 
 JSON_TYPES = Union[int, float, str, bool, dict, list, None]
 
@@ -135,7 +135,7 @@ class Config:
         orm_mode = True
 
 
-class OpenAIChatMessage(BaseModel):
+class APIChatMessage(BaseModel):
     role: str
     content: str
 
@@ -146,10 +146,10 @@ class Config:
 class MessageHistory(BaseModel):
     role: str
     content: str
-    tokens: int
-    actual_role: str
+    tokens: int = 0
     timestamp: int = Field(default_factory=UTC.timestamp)
     uuid: str = Field(default_factory=lambda: uuid4().hex)
+    actual_role: Optional[str] = None
     model_name: Optional[str] = None
     summarized: Optional[str] = None
     summarized_tokens: Optional[int] = None
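With tokens defaulting to 0 and actual_role now optional, a MessageHistory entry can be built from just a role and content. A quick illustration with arbitrary values, assuming the model is importable as shown in the diff:

from app.models.base_models import MessageHistory

# tokens, actual_role, timestamp and uuid all have defaults after this change.
entry = MessageHistory(role="user", content="How do I run the llama-cpp server?")
print(entry.tokens)       # 0
print(entry.actual_role)  # None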

app/models/llm_tokenizers.py (+22 -11)

@@ -1,13 +1,12 @@
 from abc import ABC, abstractmethod
-from typing import Type
+from typing import TYPE_CHECKING, Type
 
-from tiktoken import Encoding, encoding_for_model
-from transformers.models.llama import LlamaTokenizer as _LlamaTokenizer
-
-from app.utils.chat.llama_cpp import LlamaTokenizerAdapter
 
 from app.utils.logger import api_logger
-from app.shared import Shared
+
+if TYPE_CHECKING:
+    from tiktoken import Encoding
+    from app.utils.chat.text_generations._llama_cpp import LlamaTokenizerAdapter
 
 
 class BaseTokenizer(ABC):
@@ -39,11 +38,16 @@ def split_text_on_tokens(
             chunk_ids = input_ids[start_idx:cur_idx]
         return splits
 
+    def get_chunk_of(self, text: str, tokens: int) -> str:
+        """Split incoming text and return chunks."""
+        input_ids = self.encode(text)
+        return self.decode(input_ids[: min(tokens, len(input_ids))])
+
 
 class OpenAITokenizer(BaseTokenizer):
     def __init__(self, model_name: str):
         self.model_name = model_name
-        self._tokenizer: Encoding | None = None
+        self._tokenizer: "Encoding" | None = None
 
     def encode(self, message: str, /) -> list[int]:
         return self.tokenizer.encode(message)
@@ -59,7 +63,9 @@ def vocab_size(self) -> int:
         return self.tokenizer.n_vocab
 
     @property
-    def tokenizer(self) -> Encoding:
+    def tokenizer(self) -> "Encoding":
+        from tiktoken import encoding_for_model
+
         if self._tokenizer is None:
             print("Loading tokenizer: ", self.model_name)
             self._tokenizer = encoding_for_model(self.model_name)
@@ -69,7 +75,7 @@ def tokenizer(self) -> Encoding:
 class LlamaTokenizer(BaseTokenizer):
     def __init__(self, model_name: str):
         self.model_name = model_name
-        self._tokenizer: Encoding | None = None
+        self._tokenizer: "Encoding" | None = None
 
     def encode(self, message: str, /) -> list[int]:
         return self.tokenizer.encode(message)
@@ -85,7 +91,9 @@ def vocab_size(self) -> int:
         return self.tokenizer.n_vocab
 
     @property
-    def tokenizer(self) -> Encoding:
+    def tokenizer(self) -> "Encoding":
+        from transformers.models.llama import LlamaTokenizer as _LlamaTokenizer
+
         if self._tokenizer is None:
             split_str = self.model_name.split("/")
 
@@ -118,6 +126,7 @@ def __init__(self, llama_cpp_model_name: str):
     def encode(self, message: str, /) -> list[int]:
         from app.models.llms import LLMModels
         from app.models.llms import LlamaCppModel
+        from app.shared import Shared
 
         llama_cpp_model = LLMModels.find_model_by_name(self.llama_cpp_model_name)
         assert isinstance(llama_cpp_model, LlamaCppModel), type(llama_cpp_model)
@@ -135,5 +144,7 @@ def tokens_of(self, message: str) -> int:
         return len(self.encode(message))
 
     @property
-    def tokenizer(self) -> Type[LlamaTokenizerAdapter]:
+    def tokenizer(self) -> Type["LlamaTokenizerAdapter"]:
+        from app.utils.chat.text_generations._llama_cpp import LlamaTokenizerAdapter
+
         return LlamaTokenizerAdapter
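The new BaseTokenizer.get_chunk_of truncates text to a token budget by encoding, slicing, and decoding. A usage sketch; the model name is an arbitrary example rather than something taken from this commit, and any model tiktoken recognizes would do:

from app.models.llm_tokenizers import OpenAITokenizer

tokenizer = OpenAITokenizer("gpt-3.5-turbo")  # example model name, assumed for illustration
long_text = "some long prompt " * 200
chunk = tokenizer.get_chunk_of(long_text, tokens=128)
print(len(tokenizer.encode(chunk)))  # at most 128 tokens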
