diff --git a/backend/retrieval_graph/graph.py b/backend/retrieval_graph/graph.py
index 3ace8208..b65e4425 100644
--- a/backend/retrieval_graph/graph.py
+++ b/backend/retrieval_graph/graph.py
@@ -38,13 +38,16 @@ async def analyze_and_route_query(
         return {"router": state.router}
 
     configuration = AgentConfiguration.from_runnable_config(config)
-    model = load_chat_model(configuration.query_model)
+    structured_output_kwargs = (
+        {"method": "function_calling"} if "openai" in configuration.query_model else {}
+    )
+    model = load_chat_model(configuration.query_model).with_structured_output(
+        Router, **structured_output_kwargs
+    )
     messages = [
         {"role": "system", "content": configuration.router_system_prompt}
     ] + state.messages
-    response = cast(
-        Router, await model.with_structured_output(Router).ainvoke(messages)
-    )
+    response = cast(Router, await model.ainvoke(messages))
     return {"router": response}
 
 
@@ -140,7 +143,12 @@ class Plan(TypedDict):
         steps: list[str]
 
     configuration = AgentConfiguration.from_runnable_config(config)
-    model = load_chat_model(configuration.query_model).with_structured_output(Plan)
+    structured_output_kwargs = (
+        {"method": "function_calling"} if "openai" in configuration.query_model else {}
+    )
+    model = load_chat_model(configuration.query_model).with_structured_output(
+        Plan, **structured_output_kwargs
+    )
     messages = [
         {"role": "system", "content": configuration.research_plan_system_prompt}
     ] + state.messages
diff --git a/backend/retrieval_graph/researcher_graph/graph.py b/backend/retrieval_graph/researcher_graph/graph.py
index 28ecba5f..799e5b1e 100644
--- a/backend/retrieval_graph/researcher_graph/graph.py
+++ b/backend/retrieval_graph/researcher_graph/graph.py
@@ -37,7 +37,12 @@ class Response(TypedDict):
         queries: list[str]
 
     configuration = AgentConfiguration.from_runnable_config(config)
-    model = load_chat_model(configuration.query_model).with_structured_output(Response)
+    structured_output_kwargs = (
+        {"method": "function_calling"} if "openai" in configuration.query_model else {}
+    )
+    model = load_chat_model(configuration.query_model).with_structured_output(
+        Response, **structured_output_kwargs
+    )
     messages = [
         {"role": "system", "content": configuration.generate_queries_system_prompt},
         {"role": "human", "content": state.question},