Skip to content

Commit 84d7aee

Browse files
committed
code review
1 parent 4025aae commit 84d7aee

File tree

1 file changed

+44
-44
lines changed

1 file changed

+44
-44
lines changed

backend/graph.py

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
ChatPromptTemplate,
1919
PromptTemplate,
2020
)
21-
from langchain_core.runnables import RunnableConfig
2221
from langchain_core.retrievers import BaseRetriever
22+
from langchain_core.runnables import ConfigurableField, RunnableConfig
2323
from langchain_fireworks import ChatFireworks
2424
from langchain_google_genai import ChatGoogleGenerativeAI
2525
from langchain_openai import ChatOpenAI
@@ -28,7 +28,6 @@
2828
from backend.constants import WEAVIATE_DOCS_INDEX_NAME
2929
from backend.ingest import get_embeddings_model
3030

31-
3231
RESPONSE_TEMPLATE = """\
3332
You are an expert programmer and problem-solver, tasked with answering any question \
3433
about Langchain.
@@ -111,49 +110,53 @@ class AgentState(TypedDict):
111110
messages: Annotated[list[BaseMessage], add_messages]
112111

113112

114-
def get_model(model_name: str) -> LanguageModelLike:
115-
gpt_3_5 = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0, streaming=True)
116-
claude_3_haiku = ChatAnthropic(
117-
model="claude-3-haiku-20240307",
118-
temperature=0,
119-
max_tokens=4096,
120-
anthropic_api_key=os.environ.get("ANTHROPIC_API_KEY", "not_provided"),
121-
)
122-
fireworks_mixtral = ChatFireworks(
123-
model="accounts/fireworks/models/mixtral-8x7b-instruct",
124-
temperature=0,
125-
max_tokens=16384,
126-
fireworks_api_key=os.environ.get("FIREWORKS_API_KEY", "not_provided"),
127-
)
128-
gemini_pro = ChatGoogleGenerativeAI(
129-
model="gemini-pro",
130-
temperature=0,
131-
max_output_tokens=16384,
132-
convert_system_message_to_human=True,
133-
google_api_key=os.environ.get("GOOGLE_API_KEY", "not_provided"),
134-
)
135-
cohere_command = ChatCohere(
136-
model="command",
137-
temperature=0,
138-
cohere_api_key=os.environ.get("COHERE_API_KEY", "not_provided"),
139-
)
140-
llm: LanguageModelLike = {
141-
OPENAI_MODEL_KEY: gpt_3_5,
113+
gpt_3_5 = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0, streaming=True)
114+
claude_3_haiku = ChatAnthropic(
115+
model="claude-3-haiku-20240307",
116+
temperature=0,
117+
max_tokens=4096,
118+
anthropic_api_key=os.environ.get("ANTHROPIC_API_KEY", "not_provided"),
119+
)
120+
fireworks_mixtral = ChatFireworks(
121+
model="accounts/fireworks/models/mixtral-8x7b-instruct",
122+
temperature=0,
123+
max_tokens=16384,
124+
fireworks_api_key=os.environ.get("FIREWORKS_API_KEY", "not_provided"),
125+
)
126+
gemini_pro = ChatGoogleGenerativeAI(
127+
model="gemini-pro",
128+
temperature=0,
129+
max_output_tokens=16384,
130+
convert_system_message_to_human=True,
131+
google_api_key=os.environ.get("GOOGLE_API_KEY", "not_provided"),
132+
)
133+
cohere_command = ChatCohere(
134+
model="command",
135+
temperature=0,
136+
cohere_api_key=os.environ.get("COHERE_API_KEY", "not_provided"),
137+
)
138+
llm = gpt_3_5.configurable_alternatives(
139+
# This gives this field an id
140+
# When configuring the end runnable, we can then use this id to configure this field
141+
ConfigurableField(id="model_name"),
142+
default_key=OPENAI_MODEL_KEY,
143+
**{
142144
ANTHROPIC_MODEL_KEY: claude_3_haiku,
143145
FIREWORKS_MIXTRAL_MODEL_KEY: fireworks_mixtral,
144146
GOOGLE_MODEL_KEY: gemini_pro,
145147
COHERE_MODEL_KEY: cohere_command,
146-
}[model_name]
147-
llm = llm.with_fallbacks(
148-
[gpt_3_5, claude_3_haiku, fireworks_mixtral, gemini_pro, cohere_command]
149-
)
150-
return llm
148+
},
149+
).with_fallbacks(
150+
[gpt_3_5, claude_3_haiku, fireworks_mixtral, gemini_pro, cohere_command]
151+
)
151152

152153

153154
def get_retriever() -> BaseRetriever:
154155
weaviate_client = weaviate.Client(
155156
url=os.environ["WEAVIATE_URL"],
156-
auth_client_secret=weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY", "not_provided")),
157+
auth_client_secret=weaviate.AuthApiKey(
158+
api_key=os.environ.get("WEAVIATE_API_KEY", "not_provided")
159+
),
157160
)
158161
weaviate_client = Weaviate(
159162
client=weaviate_client,
@@ -182,10 +185,9 @@ def retrieve_documents(state: AgentState) -> AgentState:
182185
return {"query": query, "documents": relevant_documents, "messages": []}
183186

184187

185-
def retrieve_documents_with_chat_history(state: AgentState, config: RunnableConfig) -> AgentState:
186-
model_name = config.get("configurable", {}).get("model_name", OPENAI_MODEL_KEY)
188+
def retrieve_documents_with_chat_history(state: AgentState) -> AgentState:
187189
retriever = get_retriever()
188-
model = get_model(model_name).with_config(tags=["nostream"])
190+
model = llm.with_config(tags=["nostream"])
189191

190192
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(REPHRASE_TEMPLATE)
191193
condense_question_chain = (
@@ -245,14 +247,12 @@ def synthesize_response(
245247
}
246248

247249

248-
def synthesize_response_default(state: AgentState, config: RunnableConfig) -> AgentState:
249-
model_name = config.get("configurable", {}).get("model_name", OPENAI_MODEL_KEY)
250-
model = get_model(model_name)
251-
return synthesize_response(state, model, RESPONSE_TEMPLATE)
250+
def synthesize_response_default(state: AgentState) -> AgentState:
251+
return synthesize_response(state, llm, RESPONSE_TEMPLATE)
252252

253253

254254
def synthesize_response_cohere(state: AgentState) -> AgentState:
255-
model = get_model(COHERE_MODEL_KEY).bind(documents=state["documents"])
255+
model = llm.bind(documents=state["documents"])
256256
return synthesize_response(state, model, COHERE_RESPONSE_TEMPLATE)
257257

258258

0 commit comments

Comments
 (0)