1
1
from fastapi import WebSocket , WebSocketDisconnect
2
-
2
+ from fastapi .concurrency import run_in_threadpool
3
+ from pydantic import ValidationError
3
4
from app .errors .gpt_exceptions import GptOtherException , GptTextGenerationException , GptTooMuchTokenException
5
+ from app .utils .chatgpt .chatgpt_buffer import BufferedUserContext
4
6
from app .utils .chatgpt .chatgpt_cache_manager import ChatGptCacheManager
5
7
from app .utils .chatgpt .chatgpt_commands import (
8
+ ChatGptCommands ,
9
+ command_handler ,
6
10
create_new_chat_room ,
7
11
get_contexts_sorted_from_recent_to_past ,
8
- command_handler ,
9
12
)
13
+ from app .utils .chatgpt .chatgpt_fileloader import read_bytes_to_text
10
14
from app .utils .chatgpt .chatgpt_message_manager import MessageManager
15
+ from app .utils .chatgpt .chatgpt_vectorstore_manager import VectorStoreManager
11
16
from app .utils .chatgpt .chatgpt_websocket_manager import HandleMessage , SendToWebsocket
12
17
from app .utils .logger import api_logger
13
- from app .utils .chatgpt .chatgpt_buffer import BufferedUserContext
14
18
from app .viewmodels .base_models import MessageFromWebsocket , MessageToWebsocket
15
19
from app .viewmodels .gpt_models import GptRoles
16
20
@@ -33,8 +37,20 @@ async def begin_chat(
33
37
34
38
while True : # loop until connection is closed
35
39
try :
36
- # receive message from websocket
37
- received : MessageFromWebsocket = MessageFromWebsocket .parse_raw (await websocket .receive_text ())
40
+ rcvd : dict = await websocket .receive_json ()
41
+ assert isinstance (rcvd , dict )
42
+ if "filename" in rcvd :
43
+ text : str = await run_in_threadpool (
44
+ read_bytes_to_text , await websocket .receive_bytes (), rcvd ["filename" ]
45
+ )
46
+ docs : list [str ] = await VectorStoreManager .create_documents (text )
47
+ await SendToWebsocket .message (
48
+ websocket = websocket ,
49
+ msg = f"Successfully embedded documents. You uploaded file begins with...\n \n ```{ docs [0 ][:50 ]} ```..." ,
50
+ chat_room_id = buffer .current_chat_room_id ,
51
+ )
52
+ continue
53
+ received : MessageFromWebsocket = MessageFromWebsocket (** rcvd )
38
54
39
55
if received .chat_room_id != buffer .current_chat_room_id : # change chat room
40
56
index : int | None = buffer .find_index_of_chatroom (received .chat_room_id )
@@ -76,6 +92,20 @@ async def begin_chat(
76
92
77
93
except WebSocketDisconnect :
78
94
raise WebSocketDisconnect (code = 1000 , reason = "client disconnected" )
95
+ except (AssertionError , ValidationError ):
96
+ await SendToWebsocket .message (
97
+ websocket = websocket ,
98
+ msg = "Invalid message. Message is not in the correct format, maybe frontend - backend version mismatch?" ,
99
+ chat_room_id = buffer .current_chat_room_id ,
100
+ )
101
+ continue
102
+ except ValueError :
103
+ await SendToWebsocket .message (
104
+ websocket = websocket ,
105
+ msg = "Invalid file type." ,
106
+ chat_room_id = buffer .current_chat_room_id ,
107
+ )
108
+ continue
79
109
except GptTextGenerationException :
80
110
await MessageManager .rpop_message_history_safely (
81
111
user_gpt_context = buffer .current_user_gpt_context , role = GptRoles .USER
0 commit comments