@@ -2,27 +2,31 @@
 from typing import Annotated, Literal, Sequence, TypedDict

 import weaviate
+from langchain_anthropic import ChatAnthropic
+from langchain_cohere import ChatCohere
+from langchain_community.vectorstores import Weaviate
 from langchain_core.documents import Document
 from langchain_core.language_models import LanguageModelLike
-from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, convert_to_messages
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    HumanMessage,
+    convert_to_messages,
+)
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import (
     ChatPromptTemplate,
     MessagesPlaceholder,
     PromptTemplate,
 )
 from langchain_core.retrievers import BaseRetriever
-from langchain_openai import ChatOpenAI
-from langchain_cohere import ChatCohere
-from langchain_anthropic import ChatAnthropic
 from langchain_fireworks import ChatFireworks
 from langchain_google_genai import ChatGoogleGenerativeAI
-from langchain_community.vectorstores import Weaviate
-from langgraph.graph import StateGraph, END, add_messages
+from langchain_openai import ChatOpenAI
+from langgraph.graph import END, StateGraph, add_messages

-from backend.ingest import get_embeddings_model
 from backend.constants import WEAVIATE_DOCS_INDEX_NAME
-
+from backend.ingest import get_embeddings_model

 WEAVIATE_URL = os.environ["WEAVIATE_URL"]
 WEAVIATE_API_KEY = os.environ["WEAVIATE_API_KEY"]
@@ -143,7 +147,9 @@ def get_model(model_name: str) -> LanguageModelLike:
         GOOGLE_MODEL_KEY: gemini_pro,
         COHERE_MODEL_KEY: cohere_command,
     }[model_name]
-    llm = llm.with_fallbacks([gpt_3_5, claude_3_haiku, fireworks_mixtral, gemini_pro, cohere_command])
+    llm = llm.with_fallbacks(
+        [gpt_3_5, claude_3_haiku, fireworks_mixtral, gemini_pro, cohere_command]
+    )
     return llm


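Aside from the line-wrapping, the call above is LangChain's standard runnable fallback mechanism: `with_fallbacks` returns a runnable that tries the primary model first and only falls through to the alternates, in order, when an exception is raised. A minimal sketch of the same pattern, with placeholder model instances rather than this repo's configuration:

# Sketch of the fallback pattern used in get_model above.
# The concrete model choices here are illustrative placeholders.
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI

primary = ChatOpenAI(model="gpt-3.5-turbo")
backup = ChatAnthropic(model="claude-3-haiku-20240307")

# If `primary` raises (rate limit, outage, auth error), the same call is
# retried against `backup` before the error is surfaced to the caller.
llm = primary.with_fallbacks([backup])
llm.invoke("Hello!")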
@@ -176,11 +182,7 @@ def retrieve_documents(state: AgentState):
     messages = convert_to_messages(state["messages"])
     query = messages[-1].content
     relevant_documents = retriever.get_relevant_documents(query)
-    return {
-        "query": query,
-        "documents": relevant_documents,
-        "messages": []
-    }
+    return {"query": query, "documents": relevant_documents, "messages": []}


 def retrieve_documents_with_chat_history(state: AgentState, config):
@@ -189,22 +191,24 @@ def retrieve_documents_with_chat_history(state: AgentState, config):
     model = get_model(model_name).with_config(tags=["nostream"])

     CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(REPHRASE_TEMPLATE)
-    condense_question_chain = (CONDENSE_QUESTION_PROMPT | model | StrOutputParser()).with_config(
+    condense_question_chain = (
+        CONDENSE_QUESTION_PROMPT | model | StrOutputParser()
+    ).with_config(
         run_name="CondenseQuestion",
     )

     messages = convert_to_messages(state["messages"])
     query = messages[-1].content
     retriever_with_condensed_question = condense_question_chain | retriever
-    relevant_documents = retriever_with_condensed_question.invoke({"question": query, "chat_history": get_chat_history(messages)})
-    return {
-        "query": query,
-        "documents": relevant_documents,
-        "messages": []
-    }
+    relevant_documents = retriever_with_condensed_question.invoke(
+        {"question": query, "chat_history": get_chat_history(messages)}
+    )
+    return {"query": query, "documents": relevant_documents, "messages": []}


-def route_to_retriever(state: AgentState) -> Literal["retriever", "retriever_with_chat_history"]:
+def route_to_retriever(
+    state: AgentState,
+) -> Literal["retriever", "retriever_with_chat_history"]:
     if len(state["messages"]) == 1:
         return "retriever"
     else:
@@ -219,7 +223,9 @@ def get_chat_history(messages: Sequence[BaseMessage]) -> Sequence[BaseMessage]:
     return chat_history


-def synthesize_response(state: AgentState, model: LanguageModelLike, prompt_template: str):
+def synthesize_response(
+    state: AgentState, model: LanguageModelLike, prompt_template: str
+):
     prompt = ChatPromptTemplate.from_messages(
         [
             ("system", prompt_template),
@@ -228,11 +234,13 @@ def synthesize_response(state: AgentState, model: LanguageModelLike, prompt_temp
         ]
     )
     response_synthesizer = prompt | model
-    synthesized_response = response_synthesizer.invoke({
-        "question": state["query"],
-        "context": format_docs(state["documents"]),
-        "chat_history": get_chat_history(convert_to_messages(state["messages"]))
-    })
+    synthesized_response = response_synthesizer.invoke(
+        {
+            "question": state["query"],
+            "context": format_docs(state["documents"]),
+            "chat_history": get_chat_history(convert_to_messages(state["messages"])),
+        }
+    )
     return {
         **state,
         "messages": [synthesized_response],
@@ -250,7 +258,9 @@ def synthesize_response_cohere(state: AgentState):
     return synthesize_response(state, model, COHERE_RESPONSE_TEMPLATE)


-def route_to_response_synthesizer(state: AgentState, config) -> Literal["response_synthesizer", "response_synthesizer_cohere"]:
+def route_to_response_synthesizer(
+    state: AgentState, config
+) -> Literal["response_synthesizer", "response_synthesizer_cohere"]:
     model_name = config.get("configurable", {}).get("model_name", OPENAI_MODEL_KEY)
     if model_name == COHERE_MODEL_KEY:
         return "response_synthesizer_cohere"
@@ -271,10 +281,12 @@ def route_to_response_synthesizer(state: AgentState, config) -> Literal["respons

 # connect retrievers and response synthesizers
 workflow.add_conditional_edges("retriever", route_to_response_synthesizer)
-workflow.add_conditional_edges("retriever_with_chat_history", route_to_response_synthesizer)
+workflow.add_conditional_edges(
+    "retriever_with_chat_history", route_to_response_synthesizer
+)


 # connect synthesizers to terminal node
 workflow.add_edge("response_synthesizer", END)
 workflow.add_edge("response_synthesizer_cohere", END)
-graph = workflow.compile()
+graph = workflow.compile()
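For reference, a graph compiled this way is exercised through the standard LangGraph runnable interface; a minimal usage sketch, assuming the AgentState message list and the configurable "model_name" key shown in the hunks above (the question text is illustrative):

# Usage sketch: invoke the compiled graph with an initial message and a
# model selection. OPENAI_MODEL_KEY is the module's own constant.
from langchain_core.messages import HumanMessage

result = graph.invoke(
    {"messages": [HumanMessage(content="How do I add chat history?")]},
    config={"configurable": {"model_name": OPENAI_MODEL_KEY}},
)
# A single incoming message routes to the plain retriever node, and the
# synthesizer node appends its response to the message list.
print(result["messages"][-1].content)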