Skip to content

Commit 4b403f4

Browse files
committed
lint
1 parent 60a3315 commit 4b403f4

File tree

2 files changed

+43
-32
lines changed

2 files changed

+43
-32
lines changed

backend/graph.py

Lines changed: 43 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,31 @@
22
from typing import Annotated, Literal, Sequence, TypedDict
33

44
import weaviate
5+
from langchain_anthropic import ChatAnthropic
6+
from langchain_cohere import ChatCohere
7+
from langchain_community.vectorstores import Weaviate
58
from langchain_core.documents import Document
69
from langchain_core.language_models import LanguageModelLike
7-
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, convert_to_messages
10+
from langchain_core.messages import (
11+
AIMessage,
12+
BaseMessage,
13+
HumanMessage,
14+
convert_to_messages,
15+
)
816
from langchain_core.output_parsers import StrOutputParser
917
from langchain_core.prompts import (
1018
ChatPromptTemplate,
1119
MessagesPlaceholder,
1220
PromptTemplate,
1321
)
1422
from langchain_core.retrievers import BaseRetriever
15-
from langchain_openai import ChatOpenAI
16-
from langchain_cohere import ChatCohere
17-
from langchain_anthropic import ChatAnthropic
1823
from langchain_fireworks import ChatFireworks
1924
from langchain_google_genai import ChatGoogleGenerativeAI
20-
from langchain_community.vectorstores import Weaviate
21-
from langgraph.graph import StateGraph, END, add_messages
25+
from langchain_openai import ChatOpenAI
26+
from langgraph.graph import END, StateGraph, add_messages
2227

23-
from backend.ingest import get_embeddings_model
2428
from backend.constants import WEAVIATE_DOCS_INDEX_NAME
25-
29+
from backend.ingest import get_embeddings_model
2630

2731
WEAVIATE_URL = os.environ["WEAVIATE_URL"]
2832
WEAVIATE_API_KEY = os.environ["WEAVIATE_API_KEY"]
@@ -143,7 +147,9 @@ def get_model(model_name: str) -> LanguageModelLike:
143147
GOOGLE_MODEL_KEY: gemini_pro,
144148
COHERE_MODEL_KEY: cohere_command,
145149
}[model_name]
146-
llm = llm.with_fallbacks([gpt_3_5, claude_3_haiku, fireworks_mixtral, gemini_pro, cohere_command])
150+
llm = llm.with_fallbacks(
151+
[gpt_3_5, claude_3_haiku, fireworks_mixtral, gemini_pro, cohere_command]
152+
)
147153
return llm
148154

149155

@@ -176,11 +182,7 @@ def retrieve_documents(state: AgentState):
176182
messages = convert_to_messages(state["messages"])
177183
query = messages[-1].content
178184
relevant_documents = retriever.get_relevant_documents(query)
179-
return {
180-
"query": query,
181-
"documents": relevant_documents,
182-
"messages": []
183-
}
185+
return {"query": query, "documents": relevant_documents, "messages": []}
184186

185187

186188
def retrieve_documents_with_chat_history(state: AgentState, config):
@@ -189,22 +191,24 @@ def retrieve_documents_with_chat_history(state: AgentState, config):
189191
model = get_model(model_name).with_config(tags=["nostream"])
190192

191193
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(REPHRASE_TEMPLATE)
192-
condense_question_chain = (CONDENSE_QUESTION_PROMPT | model | StrOutputParser()).with_config(
194+
condense_question_chain = (
195+
CONDENSE_QUESTION_PROMPT | model | StrOutputParser()
196+
).with_config(
193197
run_name="CondenseQuestion",
194198
)
195199

196200
messages = convert_to_messages(state["messages"])
197201
query = messages[-1].content
198202
retriever_with_condensed_question = condense_question_chain | retriever
199-
relevant_documents = retriever_with_condensed_question.invoke({"question": query, "chat_history": get_chat_history(messages)})
200-
return {
201-
"query": query,
202-
"documents": relevant_documents,
203-
"messages": []
204-
}
203+
relevant_documents = retriever_with_condensed_question.invoke(
204+
{"question": query, "chat_history": get_chat_history(messages)}
205+
)
206+
return {"query": query, "documents": relevant_documents, "messages": []}
205207

206208

207-
def route_to_retriever(state: AgentState) -> Literal["retriever", "retriever_with_chat_history"]:
209+
def route_to_retriever(
210+
state: AgentState,
211+
) -> Literal["retriever", "retriever_with_chat_history"]:
208212
if len(state["messages"]) == 1:
209213
return "retriever"
210214
else:
@@ -219,7 +223,9 @@ def get_chat_history(messages: Sequence[BaseMessage]) -> Sequence[BaseMessage]:
219223
return chat_history
220224

221225

222-
def synthesize_response(state: AgentState, model: LanguageModelLike, prompt_template: str):
226+
def synthesize_response(
227+
state: AgentState, model: LanguageModelLike, prompt_template: str
228+
):
223229
prompt = ChatPromptTemplate.from_messages(
224230
[
225231
("system", prompt_template),
@@ -228,11 +234,13 @@ def synthesize_response(state: AgentState, model: LanguageModelLike, prompt_temp
228234
]
229235
)
230236
response_synthesizer = prompt | model
231-
synthesized_response = response_synthesizer.invoke({
232-
"question": state["query"],
233-
"context": format_docs(state["documents"]),
234-
"chat_history": get_chat_history(convert_to_messages(state["messages"]))
235-
})
237+
synthesized_response = response_synthesizer.invoke(
238+
{
239+
"question": state["query"],
240+
"context": format_docs(state["documents"]),
241+
"chat_history": get_chat_history(convert_to_messages(state["messages"])),
242+
}
243+
)
236244
return {
237245
**state,
238246
"messages": [synthesized_response],
@@ -250,7 +258,9 @@ def synthesize_response_cohere(state: AgentState):
250258
return synthesize_response(state, model, COHERE_RESPONSE_TEMPLATE)
251259

252260

253-
def route_to_response_synthesizer(state: AgentState, config) -> Literal["response_synthesizer", "response_synthesizer_cohere"]:
261+
def route_to_response_synthesizer(
262+
state: AgentState, config
263+
) -> Literal["response_synthesizer", "response_synthesizer_cohere"]:
254264
model_name = config.get("configurable", {}).get("model_name", OPENAI_MODEL_KEY)
255265
if model_name == COHERE_MODEL_KEY:
256266
return "response_synthesizer_cohere"
@@ -271,10 +281,12 @@ def route_to_response_synthesizer(state: AgentState, config) -> Literal["respons
271281

272282
# connect retrievers and response synthesizers
273283
workflow.add_conditional_edges("retriever", route_to_response_synthesizer)
274-
workflow.add_conditional_edges("retriever_with_chat_history", route_to_response_synthesizer)
284+
workflow.add_conditional_edges(
285+
"retriever_with_chat_history", route_to_response_synthesizer
286+
)
275287

276288
# connect synthesizers to terminal node
277289
workflow.add_edge("response_synthesizer", END)
278290
workflow.add_edge("response_synthesizer_cohere", END)
279291

280-
graph = workflow.compile()
292+
graph = workflow.compile()

backend/main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import langsmith
88
from fastapi import FastAPI
99
from fastapi.middleware.cors import CORSMiddleware
10-
from langserve import add_routes
1110
from langsmith import Client
1211
from pydantic import BaseModel
1312

0 commit comments

Comments
 (0)