@@ -18,8 +18,8 @@
     ChatPromptTemplate,
     PromptTemplate,
 )
-from langchain_core.runnables import RunnableConfig
 from langchain_core.retrievers import BaseRetriever
+from langchain_core.runnables import ConfigurableField, RunnableConfig
 from langchain_fireworks import ChatFireworks
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_openai import ChatOpenAI
@@ -28,7 +28,6 @@
 from backend.constants import WEAVIATE_DOCS_INDEX_NAME
 from backend.ingest import get_embeddings_model
 
 
-
 RESPONSE_TEMPLATE = """\
 You are an expert programmer and problem-solver, tasked with answering any question \
 about Langchain.
@@ -111,49 +110,53 @@ class AgentState(TypedDict):
     messages: Annotated[list[BaseMessage], add_messages]
 
 
-def get_model(model_name: str) -> LanguageModelLike:
-    gpt_3_5 = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0, streaming=True)
-    claude_3_haiku = ChatAnthropic(
-        model="claude-3-haiku-20240307",
-        temperature=0,
-        max_tokens=4096,
-        anthropic_api_key=os.environ.get("ANTHROPIC_API_KEY", "not_provided"),
-    )
-    fireworks_mixtral = ChatFireworks(
-        model="accounts/fireworks/models/mixtral-8x7b-instruct",
-        temperature=0,
-        max_tokens=16384,
-        fireworks_api_key=os.environ.get("FIREWORKS_API_KEY", "not_provided"),
-    )
-    gemini_pro = ChatGoogleGenerativeAI(
-        model="gemini-pro",
-        temperature=0,
-        max_output_tokens=16384,
-        convert_system_message_to_human=True,
-        google_api_key=os.environ.get("GOOGLE_API_KEY", "not_provided"),
-    )
-    cohere_command = ChatCohere(
-        model="command",
-        temperature=0,
-        cohere_api_key=os.environ.get("COHERE_API_KEY", "not_provided"),
-    )
-    llm: LanguageModelLike = {
-        OPENAI_MODEL_KEY: gpt_3_5,
+gpt_3_5 = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0, streaming=True)
+claude_3_haiku = ChatAnthropic(
+    model="claude-3-haiku-20240307",
+    temperature=0,
+    max_tokens=4096,
+    anthropic_api_key=os.environ.get("ANTHROPIC_API_KEY", "not_provided"),
+)
+fireworks_mixtral = ChatFireworks(
+    model="accounts/fireworks/models/mixtral-8x7b-instruct",
+    temperature=0,
+    max_tokens=16384,
+    fireworks_api_key=os.environ.get("FIREWORKS_API_KEY", "not_provided"),
+)
+gemini_pro = ChatGoogleGenerativeAI(
+    model="gemini-pro",
+    temperature=0,
+    max_output_tokens=16384,
+    convert_system_message_to_human=True,
+    google_api_key=os.environ.get("GOOGLE_API_KEY", "not_provided"),
+)
+cohere_command = ChatCohere(
+    model="command",
+    temperature=0,
+    cohere_api_key=os.environ.get("COHERE_API_KEY", "not_provided"),
+)
+llm = gpt_3_5.configurable_alternatives(
+    # This gives the field an id; when configuring the final runnable,
+    # that id is used to select which alternative runs.
+    ConfigurableField(id="model_name"),
+    default_key=OPENAI_MODEL_KEY,
+    **{
         ANTHROPIC_MODEL_KEY: claude_3_haiku,
         FIREWORKS_MIXTRAL_MODEL_KEY: fireworks_mixtral,
         GOOGLE_MODEL_KEY: gemini_pro,
         COHERE_MODEL_KEY: cohere_command,
-    }[model_name]
-    llm = llm.with_fallbacks(
-        [gpt_3_5, claude_3_haiku, fireworks_mixtral, gemini_pro, cohere_command]
-    )
-    return llm
+    },
+).with_fallbacks(
+    [gpt_3_5, claude_3_haiku, fireworks_mixtral, gemini_pro, cohere_command]
+)
 
 
 def get_retriever() -> BaseRetriever:
     weaviate_client = weaviate.Client(
         url=os.environ["WEAVIATE_URL"],
-        auth_client_secret=weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY", "not_provided")),
+        auth_client_secret=weaviate.AuthApiKey(
+            api_key=os.environ.get("WEAVIATE_API_KEY", "not_provided")
+        ),
     )
     weaviate_client = Weaviate(
         client=weaviate_client,
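
The module-level `llm` built with `configurable_alternatives` replaces the old `get_model(model_name)` lookup: callers now choose a model per request through the runnable's standard `configurable` config instead of threading `model_name` through every graph node. A minimal sketch of how a caller might select an alternative, reusing the key constants above (the prompts and variable names here are illustrative, not part of this diff):

# Pick Claude for a single request; with no config, default_key (OpenAI) is used.
response = llm.invoke(
    "What is LangChain?",
    config={"configurable": {"model_name": ANTHROPIC_MODEL_KEY}},
)

# Or bake the choice in once and reuse the pre-configured runnable.
gemini_llm = llm.with_config(configurable={"model_name": GOOGLE_MODEL_KEY})
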
@@ -182,10 +185,9 @@ def retrieve_documents(state: AgentState) -> AgentState:
     return {"query": query, "documents": relevant_documents, "messages": []}
 
 
-def retrieve_documents_with_chat_history(state: AgentState, config: RunnableConfig) -> AgentState:
-    model_name = config.get("configurable", {}).get("model_name", OPENAI_MODEL_KEY)
+def retrieve_documents_with_chat_history(state: AgentState) -> AgentState:
     retriever = get_retriever()
-    model = get_model(model_name).with_config(tags=["nostream"])
+    model = llm.with_config(tags=["nostream"])
 
     CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(REPHRASE_TEMPLATE)
     condense_question_chain = (
@@ -245,14 +247,12 @@ def synthesize_response(
     }
 
 
-def synthesize_response_default(state: AgentState, config: RunnableConfig) -> AgentState:
-    model_name = config.get("configurable", {}).get("model_name", OPENAI_MODEL_KEY)
-    model = get_model(model_name)
-    return synthesize_response(state, model, RESPONSE_TEMPLATE)
+def synthesize_response_default(state: AgentState) -> AgentState:
+    return synthesize_response(state, llm, RESPONSE_TEMPLATE)
 
 
 def synthesize_response_cohere(state: AgentState) -> AgentState:
-    model = get_model(COHERE_MODEL_KEY).bind(documents=state["documents"])
+    model = llm.bind(documents=state["documents"])
     return synthesize_response(state, model, COHERE_RESPONSE_TEMPLATE)
 
 
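
Because the configurable `llm` is wrapped in `with_fallbacks`, a failure in whichever model was selected falls through the fallback list in order, so every node that uses `llm` inherits the same resilience. A self-contained sketch of those semantics (the function and variable names are illustrative, not part of this diff):

from langchain_core.runnables import RunnableLambda

def primary(_):
    # Stand-in for a model call that raises, e.g. during a provider outage.
    raise RuntimeError("provider outage")

resilient = RunnableLambda(primary).with_fallbacks(
    [RunnableLambda(lambda _: "answered by fallback")]
)
print(resilient.invoke("hi"))  # -> "answered by fallback"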