Skip to content

Commit 6f5d5e8

Browse files
committed
change: Redis vectorstore -> Qdrant vectorstore
1 parent 5b2d56f commit 6f5d5e8

19 files changed

+820
-310
lines changed

.env-sample

+5-1
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,8 @@ GOOGLE_TRANSLATE_API_KEY ="OPTIONAL_FOR_TRANSTLATION"
3434
GOOGLE_TRANSLATE_OAUTH_ID="OPTIONAL_FOR_TRANSLATION"
3535
GOOGLE_TRANSLATE_OAUTH_SECRET="OPTIONAL_FOR_TRANSLATION"
3636
RAPIDAPI_KEY="OPTIONAL_FOR_TRANSLATION"
37-
CUSTOM_TRANSLATE_URL="OPTIONAL_FOR_TRANSLATION"
37+
CUSTOM_TRANSLATE_URL="OPTIONAL_FOR_TRANSLATION"
38+
SUMMARIZE_FOR_CHAT=True
39+
SUMMARIZATION_THRESHOLD=512
40+
EMBEDDING_TOKEN_CHUNK_SIZE=512
41+
EMBEDDING_TOKEN_CHUNK_OVERLAP=128

app/common/config.py

+36-23
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from os import environ
66
from pathlib import Path
77
from re import Pattern, compile
8+
from typing import Optional
89
from aiohttp import ClientTimeout
910
from dotenv import load_dotenv
1011
from urllib import parse
@@ -35,7 +36,6 @@ def __call__(cls, *args, **kwargs):
3536
MAX_API_KEY: int = 3
3637
MAX_API_WHITELIST: int = 10
3738
BASE_DIR: Path = Path(__file__).parents[2]
38-
EMBEDDING_VECTOR_DIMENSION: int = 1536
3939

4040
# MySQL Variables
4141
MYSQL_ROOT_PASSWORD: str = environ["MYSQL_ROOT_PASSWORD"]
@@ -57,24 +57,29 @@ def __call__(cls, *args, **kwargs):
5757

5858

5959
# Optional Service Variables
60+
EMBEDDING_VECTOR_DIMENSION: int = 1536
61+
EMBEDDING_TOKEN_CHUNK_SIZE: int = int(environ.get("EMBEDDING_TOKEN_CHUNK_SIZE", 512))
62+
EMBEDDING_TOKEN_CHUNK_OVERLAP: int = int(environ.get("EMBEDDING_TOKEN_CHUNK_OVERLAP", 128))
63+
SUMMARIZE_FOR_CHAT: bool = environ.get("SUMMARIZE_FOR_CHAT", "True").lower() == "true"
64+
SUMMARIZATION_THRESHOLD: int = int(environ.get("SUMMARIZATION_THRESHOLD", 512))
6065
DEFAULT_LLM_MODEL: str = environ.get("DEFAULT_LLM_MODEL", "gpt_3_5_turbo")
61-
OPENAI_API_KEY: str | None = environ.get("OPENAI_API_KEY")
62-
RAPID_API_KEY: str | None = environ.get("RAPID_API_KEY")
63-
GOOGLE_TRANSLATE_API_KEY: str | None = environ.get("GOOGLE_TRANSLATE_API_KEY")
64-
PAPAGO_CLIENT_ID: str | None = environ.get("PAPAGO_CLIENT_ID")
65-
PAPAGO_CLIENT_SECRET: str | None = environ.get("PAPAGO_CLIENT_SECRET")
66-
CUSTOM_TRANSLATE_URL: str | None = environ.get("CUSTOM_TRANSLATE_URL")
67-
AWS_ACCESS_KEY: str | None = environ.get("AWS_ACCESS_KEY")
68-
AWS_SECRET_KEY: str | None = environ.get("AWS_SECRET_KEY")
69-
AWS_AUTHORIZED_EMAIL: str | None = environ.get("AWS_AUTHORIZED_EMAIL")
70-
SAMPLE_JWT_TOKEN: str | None = environ.get("SAMPLE_JWT_TOKEN")
71-
SAMPLE_ACCESS_KEY: str | None = environ.get("SAMPLE_ACCESS_KEY")
72-
SAMPLE_SECRET_KEY: str | None = environ.get("SAMPLE_SECRET_KEY")
73-
KAKAO_RESTAPI_TOKEN: str | None = environ.get("KAKAO_RESTAPI_TOKEN")
74-
WEATHERBIT_API_KEY: str | None = environ.get("WEATHERBIT_API_KEY")
75-
KAKAO_IMAGE_URL: str | None = (
76-
"http://k.kakaocdn.net/dn/wwWjr/btrYVhCnZDF/2bgXDJth2LyIajIjILhLK0/kakaolink40_original.png"
77-
)
66+
OPENAI_API_KEY: Optional[str] = environ.get("OPENAI_API_KEY")
67+
RAPID_API_KEY: Optional[str] = environ.get("RAPID_API_KEY")
68+
GOOGLE_TRANSLATE_API_KEY: Optional[str] = environ.get("GOOGLE_TRANSLATE_API_KEY")
69+
PAPAGO_CLIENT_ID: Optional[str] = environ.get("PAPAGO_CLIENT_ID")
70+
PAPAGO_CLIENT_SECRET: Optional[str] = environ.get("PAPAGO_CLIENT_SECRET")
71+
CUSTOM_TRANSLATE_URL: Optional[str] = environ.get("CUSTOM_TRANSLATE_URL")
72+
AWS_ACCESS_KEY: Optional[str] = environ.get("AWS_ACCESS_KEY")
73+
AWS_SECRET_KEY: Optional[str] = environ.get("AWS_SECRET_KEY")
74+
AWS_AUTHORIZED_EMAIL: Optional[str] = environ.get("AWS_AUTHORIZED_EMAIL")
75+
SAMPLE_JWT_TOKEN: Optional[str] = environ.get("SAMPLE_JWT_TOKEN")
76+
SAMPLE_ACCESS_KEY: Optional[str] = environ.get("SAMPLE_ACCESS_KEY")
77+
SAMPLE_SECRET_KEY: Optional[str] = environ.get("SAMPLE_SECRET_KEY")
78+
KAKAO_RESTAPI_TOKEN: Optional[str] = environ.get("KAKAO_RESTAPI_TOKEN")
79+
WEATHERBIT_API_KEY: Optional[str] = environ.get("WEATHERBIT_API_KEY")
80+
KAKAO_IMAGE_URL: Optional[
81+
str
82+
] = "http://k.kakaocdn.net/dn/wwWjr/btrYVhCnZDF/2bgXDJth2LyIajIjILhLK0/kakaolink40_original.png"
7883

7984
"""
8085
400 Bad Request
@@ -113,6 +118,10 @@ class Config(metaclass=SingletonMetaClass):
113118
redis_port: int = REDIS_PORT
114119
redis_database: int = REDIS_DATABASE
115120
redis_password: str = REDIS_PASSWORD
121+
qdrant_host: str = "vectorstore"
122+
qdrant_port: int = 6333
123+
qdrant_grpc_port: int = 6334
124+
shared_vectorestore_name: str = "SharedCollection"
116125
trusted_hosts: list[str] = field(default_factory=lambda: ["*"])
117126
allowed_sites: list[str] = field(default_factory=lambda: ["*"])
118127

@@ -121,6 +130,7 @@ def __post_init__(self):
121130
self.port = 8001
122131
self.mysql_host = "localhost"
123132
self.redis_host = "localhost"
133+
self.qdrant_host = "localhost"
124134
self.mysql_root_url = self.database_url_format.format(
125135
dialect="mysql",
126136
driver="pymysql",
@@ -149,7 +159,7 @@ def __post_init__(self):
149159

150160
@staticmethod
151161
def get(
152-
option: str | None = None,
162+
option: Optional[str] = None,
153163
) -> LocalConfig | ProdConfig | TestConfig:
154164
if environ.get("PYTEST_RUNNING") is not None:
155165
return TestConfig()
@@ -202,15 +212,16 @@ class TestConfig(Config):
202212
mysql_database: str = MYSQL_TEST_DATABASE
203213
mysql_host: str = "localhost"
204214
redis_host: str = "localhost"
215+
qdrant_host: str = "localhost"
205216
port: int = 8001
206217

207218

208219
@dataclass
209220
class LoggingConfig:
210221
logger_level: int = logging.DEBUG
211222
console_log_level: int = logging.INFO
212-
file_log_level: int | None = logging.DEBUG
213-
file_log_name: str | None = "./logs/debug.log"
223+
file_log_level: Optional[int] = logging.DEBUG
224+
file_log_name: Optional[str] = "./logs/debug.log"
214225
logging_format: str = "[%(asctime)s] %(name)s:%(levelname)s - %(message)s"
215226

216227

@@ -223,8 +234,10 @@ class ChatConfig:
223234
api_regex_pattern: Pattern = compile(r"data:\s*({.+?})\n\n")
224235
extra_token_margin: int = 512 # number of tokens to remove when tokens exceed token limit
225236
continue_message: str = "...[CONTINUED]" # message to append when tokens exceed token limit
226-
summarize_for_chat: bool = True # whether to summarize chat messages
227-
summarization_threshold: int = 512 # token threshold for summarization. if message tokens exceed this, summarize
237+
summarize_for_chat: bool = SUMMARIZE_FOR_CHAT # whether to summarize chat messages
238+
summarization_threshold: int = (
239+
SUMMARIZATION_THRESHOLD # token threshold for summarization. if message tokens exceed this, summarize
240+
)
228241
summarization_openai_model: str = "gpt-3.5-turbo"
229242
summarization_token_limit: int = 2048 # token limit for summarization
230243
summarization_token_overlap: int = 100 # number of tokens to overlap between chunks

app/database/connection.py

+25-26
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from collections.abc import Iterable
33
from typing import Any, AsyncGenerator, Callable, Optional, Type
44

5-
from langchain.embeddings.base import Embeddings
6-
from redis.asyncio import Redis as AsyncRedisType
5+
from qdrant_client import QdrantClient
6+
from redis.asyncio import Redis, from_url
77
from sqlalchemy import Delete, Result, ScalarResult, Select, TextClause, Update, create_engine, text
88
from sqlalchemy.engine.base import Connection, Engine
99
from sqlalchemy.ext.asyncio import (
@@ -16,9 +16,9 @@
1616
from sqlalchemy_utils import create_database, database_exists
1717

1818
from app.common.config import Config, SingletonMetaClass, logging_config
19-
from app.shared import Shared
2019
from app.errors.api_exceptions import Responses_500
21-
from app.utils.langchain.redis_vectorstore import Redis as RedisVectorStore
20+
from app.shared import Shared
21+
from app.utils.langchain.qdrant_vectorstore import Qdrant
2222
from app.utils.logger import CustomLogger
2323

2424
from . import Base, DeclarativeMeta
@@ -338,50 +338,49 @@ async def scalars__one_or_none(
338338
return (await self.run_in_session(self._scalars)(session, stmt=stmt)).one_or_none()
339339

340340

341-
class RedisFactory(metaclass=SingletonMetaClass):
341+
class CacheFactory(metaclass=SingletonMetaClass):
342342
def __init__(self):
343-
self._vectorstore: RedisVectorStore | None = None
343+
self._vectorstore: Optional[Qdrant] = None
344344
self.is_test_mode: bool = False
345345
self.is_initiated: bool = False
346346

347347
def start(
348348
self,
349349
config: Config,
350-
content_key: str = "content",
351-
metadata_key: str = "metadata",
352-
vector_key: str = "content_vector",
353350
) -> None:
354351
if self.is_initiated:
355352
return
356353
self.is_test_mode = True if config.test_mode else False
357-
embeddings: Embeddings = Shared().openai_embeddings
358-
self._vectorstore = RedisVectorStore( # type: ignore
359-
redis_url=config.redis_url,
360-
embedding_function=embeddings.embed_query,
361-
content_key=content_key,
362-
metadata_key=metadata_key,
363-
vector_key=vector_key,
364-
is_async=True,
354+
self._redis = from_url(url=config.redis_url)
355+
self._vectorstore = Qdrant(
356+
client=QdrantClient(
357+
host=config.qdrant_host,
358+
port=config.qdrant_port,
359+
grpc_port=config.qdrant_grpc_port,
360+
prefer_grpc=True,
361+
),
362+
collection_name=config.shared_vectorestore_name,
363+
embeddings=Shared().openai_embeddings,
365364
)
366365
self.is_initiated = True
367366

368367
async def close(self) -> None:
369-
if self._vectorstore is not None:
370-
assert isinstance(self._vectorstore.client, AsyncRedisType)
371-
await self._vectorstore.client.close()
368+
if self._redis is not None:
369+
assert isinstance(self._redis, Redis)
370+
await self._redis.close()
372371
self.is_initiated = False
373372

374373
@property
375-
def redis(self) -> AsyncRedisType:
374+
def redis(self) -> Redis:
376375
try:
377-
assert self._vectorstore is not None
378-
assert isinstance(self._vectorstore.client, AsyncRedisType)
376+
assert self._redis is not None
377+
assert isinstance(self._redis, Redis)
379378
except AssertionError:
380379
raise Responses_500.cache_not_initialized
381-
return self._vectorstore.client
380+
return self._redis
382381

383382
@property
384-
def vectorstore(self) -> RedisVectorStore:
383+
def vectorstore(self) -> Qdrant:
385384
try:
386385
assert self._vectorstore is not None
387386
except AssertionError:
@@ -390,4 +389,4 @@ def vectorstore(self) -> RedisVectorStore:
390389

391390

392391
db: SQLAlchemy = SQLAlchemy()
393-
cache: RedisFactory = RedisFactory()
392+
cache: CacheFactory = CacheFactory()

app/models/llms.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,10 @@ class LLMModels(Enum):
190190
token_margin=8,
191191
tokenizer=LlamaTokenizer("timdettmers/guanaco-65b-merged"), # timdettmers/guanaco-13b
192192
model_path="./llama_models/ggml/guanaco-13B.ggmlv3.q5_1.bin",
193+
description=DESCRIPTION_TMPL2,
193194
user_chat_roles=UserChatRoles(
194-
user="Instruction",
195-
ai="Response",
195+
user="Human",
196+
ai="Assistant",
196197
system="System",
197198
),
198199
)

app/utils/chat/buffer.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
from dataclasses import dataclass, field
3-
from typing import Any, Awaitable, Callable
3+
from typing import Any, Awaitable, Callable, Optional
44

55
from fastapi import WebSocket
66

@@ -45,6 +45,7 @@ class BufferedUserContext:
4545
queue: asyncio.Queue = field(default_factory=asyncio.Queue)
4646
done: asyncio.Event = field(default_factory=asyncio.Event)
4747
task_list: list[asyncio.Task[Any]] = field(default_factory=list) # =
48+
last_user_message: Optional[str] = None
4849
_sorted_ctxts: ContextList = field(init=False)
4950
_current_ctxt: UserChatContext = field(init=False)
5051

app/utils/chat/chat_commands.py

+25-13
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
from fastapi import WebSocket
1010
from fastapi.concurrency import run_in_threadpool
1111

12-
from app.common.constants import CODEX_PROMPT, QUERY_TMPL1, REDEX_PROMPT, CHAT_TURN_TMPL1
12+
from app.common.config import config
13+
from app.common.constants import CHAT_TURN_TMPL1, CODEX_PROMPT, QUERY_TMPL1, REDEX_PROMPT
1314
from app.database.schemas.auth import UserStatus
1415
from app.errors.api_exceptions import InternalServerError
1516
from app.models.chat_models import ChatRoles, LLMModels, MessageHistory, UserChatContext
1617
from app.shared import Shared
18+
from app.utils.api.translate import Translator
1719
from app.utils.chat.buffer import BufferedUserContext
1820
from app.utils.chat.cache_manager import CacheManager
1921
from app.utils.chat.message_handler import MessageHandler
@@ -580,19 +582,27 @@ async def query(query: str, /, buffer: BufferedUserContext, **kwargs) -> Tuple[s
580582
return query, ResponseType.REPEAT_COMMAND
581583

582584
k: int = 3
585+
if kwargs.get("translate", False):
586+
query = await Translator.translate(text=query, src_lang="ko", trg_lang="en")
587+
await SendToWebsocket.message(
588+
websocket=buffer.websocket,
589+
msg=f"## 번역된 질문\n\n{query}\n\n## 생성된 답변\n\n",
590+
chat_room_id=buffer.current_chat_room_id,
591+
finish=False,
592+
model_name=buffer.current_user_chat_context.llm_model.value.name,
593+
)
583594
found_text_and_score: list[
584-
list[Tuple[Document, float]]
585-
] = await VectorStoreManager.asimilarity_search_multiple_index_with_score(
586-
queries=[query], index_names=[buffer.user_id, ""], k=k
595+
Tuple[Document, float]
596+
] = await VectorStoreManager.asimilarity_search_multiple_collections_with_score(
597+
query=query, collection_names=[buffer.user_id, config.shared_vectorestore_name], k=k
587598
) # lower score is the better!
588-
print(found_text_and_score)
589599

590-
if len(found_text_and_score[0]) > 0:
591-
found_text: str = "\n\n".join([document.page_content for document, _ in found_text_and_score[0]])
600+
if len(found_text_and_score) > 0:
601+
found_text: str = "\n\n".join([document.page_content for document, _ in found_text_and_score])
592602
context_and_query: str = QUERY_TMPL1.format(question=query, context=found_text)
593603
await MessageHandler.user(
594604
msg=context_and_query,
595-
translate=kwargs.get("translate", False),
605+
translate=False,
596606
buffer=buffer,
597607
)
598608
await MessageHandler.ai(
@@ -621,15 +631,15 @@ async def query(query: str, /, buffer: BufferedUserContext, **kwargs) -> Tuple[s
621631
async def embed(text_to_embed: str, /, buffer: BufferedUserContext) -> str:
622632
"""Embed the text and save its vectors in the Qdrant vectorstore.\n
623633
/embed <text_to_embed>"""
624-
await VectorStoreManager.create_documents(text=text_to_embed, index_name=buffer.user_id)
634+
await VectorStoreManager.create_documents(text=text_to_embed, collection_name=buffer.user_id)
625635
return "Embedding successful!"
626636

627637
@staticmethod
628638
@CommandResponse.send_message_and_stop
629639
async def share(text_to_embed: str, /) -> str:
630640
"""Embed the text and save its vectors in the redis vectorstore. This index is shared for everyone.\n
631641
/share <text_to_embed>"""
632-
await VectorStoreManager.create_documents(text=text_to_embed, index_name="")
642+
await VectorStoreManager.create_documents(text=text_to_embed, collection_name=config.shared_vectorestore_name)
633643
return "Embedding successful! This data will be shared for everyone."
634644

635645
@staticmethod
@@ -638,10 +648,12 @@ async def drop(buffer: BufferedUserContext) -> str:
638648
"""Drop the collection from the Qdrant vectorstore.\n
639649
/drop"""
640650
dropped_index: list[str] = []
641-
if await VectorStoreManager.drop_index(index_name=buffer.user_id):
651+
if await VectorStoreManager.delete_collection(collection_name=buffer.user_id):
642652
dropped_index.append(buffer.user_id)
643-
if buffer.user.status is UserStatus.admin and await VectorStoreManager.drop_index(index_name=""):
644-
dropped_index.append("shared")
653+
if buffer.user.status is UserStatus.admin and await VectorStoreManager.delete_collection(
654+
collection_name=config.shared_vectorestore_name,
655+
):
656+
dropped_index.append(config.shared_vectorestore_name)
645657
if not dropped_index:
646658
return "No index dropped."
647659
return f"Index dropped: {', '.join(dropped_index)}"

app/utils/chat/llama_cpp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def get_generator() -> Iterator[Any]:
155155
else:
156156
real_max_tokens = max_tokens
157157
if real_max_tokens <= 0:
158-
raise ChatLengthException()
158+
raise ChatLengthException(msg=content_buffer)
159159
return llm_client.create_completion( # type: ignore
160160
prompt=prompt,
161161
suffix=llm.suffix,

app/utils/chat/stream_manager.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ async def harvest_done_tasks(buffer: BufferedUserContext) -> None:
8787
api_logger.exception(f"Some error occurred while running update tasks: {result}")
8888
except Exception as e:
8989
api_logger.exception(f"Unexpected error occurred while running update tasks: {e}")
90-
91-
buffer.task_list = [task for task in buffer.task_list if task not in harvested_tasks]
90+
finally:
91+
buffer.task_list = [task for task in buffer.task_list if task not in harvested_tasks]
9292

9393

9494
class ChatStreamManager:
@@ -173,7 +173,7 @@ async def _websocket_receiver(buffer: BufferedUserContext) -> None:
173173
elif received_bytes is not None:
174174
await buffer.queue.put(
175175
await VectorStoreManager.embed_file_to_vectorstore(
176-
file=received_bytes, filename=filename, index_name=buffer.current_user_chat_context.user_id
176+
file=received_bytes, filename=filename, collection_name=buffer.current_user_chat_context.user_id
177177
)
178178
)
179179

@@ -213,6 +213,7 @@ async def _websocket_sender(cls, buffer: BufferedUserContext) -> None:
213213
buffer=buffer,
214214
)
215215
else:
216+
buffer.last_user_message = item.msg
216217
await MessageHandler.user(
217218
msg=item.msg,
218219
translate=item.translate,

0 commit comments

Comments
 (0)