From e536db6bc8efd188172c7614878f8bd5669c8f02 Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Fri, 23 May 2025 20:38:27 +0200 Subject: [PATCH 1/6] feat(prebuilt): add `preModelHook` and `postModelHook` --- .../src/prebuilt/react_agent_executor.ts | 224 ++++++++++++++---- libs/langgraph/src/tests/prebuilt.test.ts | 218 +++++++++++++++++ 2 files changed, 397 insertions(+), 45 deletions(-) diff --git a/libs/langgraph/src/prebuilt/react_agent_executor.ts b/libs/langgraph/src/prebuilt/react_agent_executor.ts index b6315c536..45805ede4 100644 --- a/libs/langgraph/src/prebuilt/react_agent_executor.ts +++ b/libs/langgraph/src/prebuilt/react_agent_executor.ts @@ -19,6 +19,7 @@ import { RunnableToolLike, RunnableSequence, RunnableBinding, + type RunnableLike, } from "@langchain/core/runnables"; import { DynamicTool, StructuredToolInterface } from "@langchain/core/tools"; import { @@ -30,7 +31,7 @@ import { z } from "zod"; import { StateGraph, - CompiledStateGraph, + type CompiledStateGraph, AnnotationRoot, } from "../graph/index.js"; import { MessagesAnnotation } from "../graph/messages_annotation.js"; @@ -327,6 +328,27 @@ export const createReactAgentAnnotation = < structuredResponse: Annotation, }); +type WithStateGraphNodes = Graph extends StateGraph< + infer SD, + infer S, + infer U, + infer N, + infer I, + infer O, + infer C +> + ? StateGraph + : never; + +const PreHookAnnotation = Annotation.Root({ + llmInputMessages: Annotation({ + reducer: messagesStateReducer, + default: () => [], + }), +}); + +type PreHookAnnotation = typeof PreHookAnnotation; + export type CreateReactAgentParams< // eslint-disable-next-line @typescript-eslint/no-explicit-any A extends AnnotationRoot = AnnotationRoot, @@ -407,6 +429,26 @@ export type CreateReactAgentParams< Example: `"How can I help you"` -> `"agent_nameHow can I help you?"` */ includeAgentName?: "inline" | undefined; + + /** + * An optional node to add before the `agent` node (i.e., the node that calls the LLM). + * Useful for managing long message histories (e.g., message trimming, summarization, etc.). + */ + preModelHook?: RunnableLike< + A["State"] & PreHookAnnotation["State"], + A["Update"] & PreHookAnnotation["Update"], + LangGraphRunnableConfig + >; + + /** + * An optional node to add after the `agent` node (i.e., the node that calls the LLM). + * Useful for implementing human-in-the-loop, guardrails, validation, or other post-processing. + */ + postModelHook?: RunnableLike< + A["State"], + A["Update"], + LangGraphRunnableConfig + >; }; /** @@ -483,6 +525,8 @@ export function createReactAgent< interruptAfter, store, responseFormat, + preModelHook, + postModelHook, name, includeAgentName, } = params; @@ -537,18 +581,15 @@ export function createReactAgent< .map((tool) => tool.name) ); - const shouldContinue = (state: AgentState) => { - const { messages } = state; - const lastMessage = messages[messages.length - 1]; - if ( - isAIMessage(lastMessage) && - (!lastMessage.tool_calls || lastMessage.tool_calls.length === 0) - ) { - return responseFormat != null ? "generate_structured_response" : END; - } else { - return "continue"; + function getModelInputState( + state: AgentState & PreHookAnnotation["State"] + ): Omit, "llmInputMessages"> { + const { messages, llmInputMessages, ...rest } = state; + if (llmInputMessages != null && llmInputMessages.length > 0) { + return { messages: llmInputMessages, ...rest }; } - }; + return { messages, ...rest }; + } const generateStructuredResponse = async ( state: AgentState, @@ -583,14 +624,17 @@ export function createReactAgent< }; const callModel = async ( - state: AgentState, + state: AgentState & PreHookAnnotation["State"], config?: RunnableConfig ) => { // NOTE: we're dynamically creating the model runnable here // to ensure that we can validate ConfigurableModel properly const modelRunnable = await getModelRunnable(llm); // TODO: Auto-promote streaming. - const response = (await modelRunnable.invoke(state, config)) as BaseMessage; + const response = (await modelRunnable.invoke( + getModelInputState(state), + config + )) as BaseMessage; // add agent name to the AIMessage // TODO: figure out if we can avoid mutating the message directly response.name = name; @@ -598,51 +642,141 @@ export function createReactAgent< return { messages: [response] }; }; - const workflow = new StateGraph( - stateSchema ?? createReactAgentAnnotation() - ) - .addNode("agent", callModel) - .addNode("tools", toolNode) - .addEdge(START, "agent"); + const schema = + stateSchema ?? createReactAgentAnnotation(); + + const workflow = new StateGraph(schema).addNode("tools", toolNode); + + const allNodeWorkflows = workflow as WithStateGraphNodes< + | "pre_model_hook" + | "post_model_hook" + | "generate_structured_response" + | "agent", + typeof workflow + >; + + let entrypoint: "agent" | "pre_model_hook" = "agent"; + let inputSchema: AnnotationRoot<(typeof schema)["spec"]> | undefined; + if (preModelHook != null) { + allNodeWorkflows + .addNode("pre_model_hook", preModelHook) + .addEdge("pre_model_hook", "agent"); + entrypoint = "pre_model_hook"; + + inputSchema = new AnnotationRoot({ + ...schema.spec, + ...PreHookAnnotation.spec, + }); + } else { + entrypoint = "agent"; + } + + allNodeWorkflows + .addNode("agent", callModel, { input: inputSchema }) + .addEdge(START, entrypoint); + + if (postModelHook != null) { + allNodeWorkflows + .addNode("post_model_hook", postModelHook) + .addEdge("agent", "post_model_hook") + .addConditionalEdges( + "post_model_hook", + (state) => { + const { messages } = state; + const lastMessage = messages[messages.length - 1]; + + if (isAIMessage(lastMessage) && lastMessage.tool_calls?.length) { + return "tools" as const; + } + + if (isToolMessage(lastMessage)) { + return "entrypoint" as const; + } + + if (responseFormat != null) { + return "generate_structured_response" as const; + } + + return END; + }, + { + tools: "tools", + entrypoint, + generate_structured_response: + responseFormat != null ? "generate_structured_response" : END, + [END]: END, + } + ); + } if (responseFormat !== undefined) { workflow .addNode("generate_structured_response", generateStructuredResponse) - .addEdge("generate_structured_response", END) - .addConditionalEdges("agent", shouldContinue, { - continue: "tools", - [END]: END, - generate_structured_response: "generate_structured_response", - }); - } else { - workflow.addConditionalEdges("agent", shouldContinue, { - continue: "tools", - [END]: END, - }); + .addEdge("generate_structured_response", END); } - const routeToolResponses = (state: AgentState) => { - // Check the last consecutive tool calls - for (let i = state.messages.length - 1; i >= 0; i -= 1) { - const message = state.messages[i]; - if (!isToolMessage(message)) { - break; - } - // Check if this tool is configured to return directly - if (message.name !== undefined && shouldReturnDirect.has(message.name)) { + allNodeWorkflows.addConditionalEdges( + "agent", + (state) => { + const { messages } = state; + const lastMessage = messages[messages.length - 1]; + + // if there's no function call, we finish + if (!isAIMessage(lastMessage) || !lastMessage.tool_calls?.length) { + if (postModelHook != null) { + return "post_model_hook" as const; + } + + if (responseFormat != null) { + return "generate_structured_response" as const; + } + return END; } + + // there are function calls, we continue + if (postModelHook != null) { + return "post_model_hook" as const; + } + + return "tools" as const; + }, + { + tools: "tools", + post_model_hook: postModelHook != null ? "post_model_hook" : END, + generate_structured_response: + responseFormat != null ? "generate_structured_response" : END, + [END]: END, } - return "agent"; - }; + ); if (shouldReturnDirect.size > 0) { - workflow.addConditionalEdges("tools", routeToolResponses, ["agent", END]); + allNodeWorkflows.addConditionalEdges( + "tools", + (state) => { + // Check the last consecutive tool calls + for (let i = state.messages.length - 1; i >= 0; i -= 1) { + const message = state.messages[i]; + if (!isToolMessage(message)) break; + + // Check if this tool is configured to return directly + if ( + message.name !== undefined && + shouldReturnDirect.has(message.name) + ) { + return END; + } + } + + return "agent" as const; + }, + ["agent", END] + ); } else { - workflow.addEdge("tools", "agent"); + allNodeWorkflows.addEdge("tools", "agent"); } - return workflow.compile({ + return allNodeWorkflows.compile({ checkpointer: checkpointer ?? checkpointSaver, interruptBefore, interruptAfter, diff --git a/libs/langgraph/src/tests/prebuilt.test.ts b/libs/langgraph/src/tests/prebuilt.test.ts index 8e4428e97..49a50a2bb 100644 --- a/libs/langgraph/src/tests/prebuilt.test.ts +++ b/libs/langgraph/src/tests/prebuilt.test.ts @@ -37,6 +37,7 @@ import { interrupt, MemorySaver, messagesStateReducer, + REMOVE_ALL_MESSAGES, Send, StateGraph, } from "../index.js"; @@ -44,6 +45,7 @@ import { MessagesAnnotation, MessagesZodState, } from "../graph/messages_annotation.js"; +import { gatherIterator } from "../utils.js"; // Tracing slows down the tests beforeAll(() => { @@ -1049,6 +1051,222 @@ describe("createReactAgent with ToolNode", () => { }); }); +describe("createReactAgent with hooks", () => { + it("preModelHook", async () => { + const llm = new FakeToolCallingChatModel({ + responses: [new AIMessage({ id: "0", content: "Hello!" })], + }); + const llmSpy = vi.spyOn(llm, "_generate"); + + // Test `llm_input_messages` + let agent = createReactAgent({ + llm, + tools: [], + preModelHook: () => ({ + llmInputMessages: [ + new HumanMessage({ id: "human", content: "pre-hook" }), + ], + }), + }); + + expect("pre_model_hook" in agent.nodes).toBe(true); + expect(await agent.invoke({ messages: [new HumanMessage("hi?")] })).toEqual( + { + messages: [ + new _AnyIdHumanMessage("hi?"), + new AIMessage({ id: "0", content: "Hello!" }), + ], + } + ); + + expect(llmSpy).toHaveBeenCalledWith( + [new HumanMessage({ id: "human", content: "pre-hook" })], + expect.anything(), + undefined + ); + + // Test `messages` + agent = createReactAgent({ + llm, + tools: [], + preModelHook: () => ({ + messages: [ + new RemoveMessage({ id: REMOVE_ALL_MESSAGES }), + new HumanMessage("Hello!"), + ], + }), + }); + + expect("pre_model_hook" in agent.nodes).toBe(true); + expect(await agent.invoke({ messages: [new HumanMessage("hi?")] })).toEqual( + { + messages: [ + new _AnyIdHumanMessage("Hello!"), + new AIMessage({ id: "0", content: "Hello!" }), + ], + } + ); + }); + + it("postModelHook", async () => { + const FlagAnnotation = Annotation.Root({ + ...MessagesAnnotation.spec, + flag: Annotation, + }); + + const llm = new FakeToolCallingChatModel({ + responses: [new AIMessage({ id: "1", content: "hi?" })], + }); + + const agent = createReactAgent({ + llm, + tools: [], + postModelHook: () => ({ flag: true }), + stateSchema: FlagAnnotation, + }); + + expect("post_model_hook" in agent.nodes).toBe(true); + expect( + await agent.invoke({ + messages: [new HumanMessage("hi?")], + flag: false, + }) + ).toMatchObject({ flag: true }); + + expect( + await gatherIterator( + agent.stream({ + messages: [new HumanMessage("hi?")], + flag: false, + }) + ) + ).toMatchObject([ + { + agent: { + messages: [new AIMessage({ id: "1", content: "hi?" })], + }, + }, + { post_model_hook: { flag: true } }, + ]); + }); + + it("postModelHook + structured response", async () => { + const weatherResponseSchema = z.object({ + temperature: z.number().describe("The temperature in fahrenheit"), + }); + + const FlagAnnotation = Annotation.Root({ + ...MessagesAnnotation.spec, + flag: Annotation, + structuredResponse: Annotation>, + }); + + const llm = new FakeToolCallingChatModel({ + responses: [ + new AIMessage({ + id: "1", + content: "What's the weather?", + tool_calls: [ + { + name: "get_weather", + args: {}, + id: "1", + type: "tool_call", + }, + ], + }), + new AIMessage({ id: "3", content: "The weather is nice" }), + ], + structuredResponse: { temperature: 75 }, + }); + + const getWeather = tool(async () => "The weather is sunny and 75°F.", { + name: "get_weather", + description: "Get the weather", + schema: z.object({}), + }); + + const agent = createReactAgent({ + llm, + tools: [getWeather], + responseFormat: weatherResponseSchema, + postModelHook: () => ({ flag: true }), + stateSchema: FlagAnnotation, + }); + + expect("post_model_hook" in agent.nodes).toBe(true); + expect("generate_structured_response" in agent.nodes).toBe(true); + + const response = await agent.invoke({ + messages: [new HumanMessage({ id: "0", content: "What's the weather?" })], + flag: false, + }); + + expect(response).toMatchObject({ + flag: true, + structuredResponse: { temperature: 75 }, + }); + + expect( + await gatherIterator( + agent.stream({ + messages: [ + new HumanMessage({ id: "0", content: "What's the weather?" }), + ], + flag: false, + }) + ) + ).toEqual([ + { + agent: { + messages: [ + new AIMessage({ + content: "What's the weather?", + id: "1", + tool_calls: [ + { + name: "get_weather", + args: {}, + id: "1", + type: "tool_call", + }, + ], + }), + ], + }, + }, + { post_model_hook: { flag: true } }, + { + tools: { + messages: [ + new _AnyIdToolMessage({ + content: "The weather is sunny and 75°F.", + name: "get_weather", + tool_call_id: "1", + }), + ], + }, + }, + { + agent: { + messages: [ + new AIMessage({ + content: "The weather is nice", + id: "3", + }), + ], + }, + }, + { post_model_hook: { flag: true } }, + { + generate_structured_response: { + structuredResponse: { temperature: 75 }, + }, + }, + ]); + }); +}); + describe("ToolNode", () => { it("Should support graceful error handling", async () => { const toolNode = new ToolNode([new SearchAPI()]); From 38a97bfb0e69b422a7076e0cc210f142427b950b Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Fri, 23 May 2025 23:16:29 +0200 Subject: [PATCH 2/6] Fix tests --- .../src/prebuilt/react_agent_executor.ts | 24 ++++++++++++------- libs/langgraph/src/tests/diagrams.test.ts | 2 +- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/libs/langgraph/src/prebuilt/react_agent_executor.ts b/libs/langgraph/src/prebuilt/react_agent_executor.ts index 45805ede4..99edcbfd5 100644 --- a/libs/langgraph/src/prebuilt/react_agent_executor.ts +++ b/libs/langgraph/src/prebuilt/react_agent_executor.ts @@ -655,6 +655,12 @@ export function createReactAgent< typeof workflow >; + const conditionalMap = (map: Record) => { + return Object.fromEntries( + Object.entries(map).filter(([_, v]) => v != null) as [string, T][] + ); + }; + let entrypoint: "agent" | "pre_model_hook" = "agent"; let inputSchema: AnnotationRoot<(typeof schema)["spec"]> | undefined; if (preModelHook != null) { @@ -699,13 +705,13 @@ export function createReactAgent< return END; }, - { + conditionalMap({ tools: "tools", entrypoint, generate_structured_response: - responseFormat != null ? "generate_structured_response" : END, + responseFormat != null ? "generate_structured_response" : null, [END]: END, - } + }) ); } @@ -739,15 +745,15 @@ export function createReactAgent< return "post_model_hook" as const; } - return "tools" as const; + return "continue" as const; }, - { - tools: "tools", - post_model_hook: postModelHook != null ? "post_model_hook" : END, + conditionalMap({ + continue: "tools", + post_model_hook: postModelHook != null ? "post_model_hook" : null, generate_structured_response: - responseFormat != null ? "generate_structured_response" : END, + responseFormat != null ? "generate_structured_response" : null, [END]: END, - } + }) ); if (shouldReturnDirect.size > 0) { diff --git a/libs/langgraph/src/tests/diagrams.test.ts b/libs/langgraph/src/tests/diagrams.test.ts index a713b1f48..fd2acbbfd 100644 --- a/libs/langgraph/src/tests/diagrams.test.ts +++ b/libs/langgraph/src/tests/diagrams.test.ts @@ -16,8 +16,8 @@ test("prebuilt agent", async () => { expect(mermaid).toEqual(`%%{init: {'flowchart': {'curve': 'linear'}}}%% graph TD; \t__start__([

__start__

]):::first -\tagent(agent) \ttools(tools) +\tagent(agent) \t__end__([

__end__

]):::last \t__start__ --> agent; \ttools --> agent; From f3631ee357b42cb8692757f3730c9399964df7a8 Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Fri, 23 May 2025 23:22:47 +0200 Subject: [PATCH 3/6] Don't create AnnotationRoot directly --- libs/langgraph/src/prebuilt/react_agent_executor.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langgraph/src/prebuilt/react_agent_executor.ts b/libs/langgraph/src/prebuilt/react_agent_executor.ts index 99edcbfd5..7122eb61c 100644 --- a/libs/langgraph/src/prebuilt/react_agent_executor.ts +++ b/libs/langgraph/src/prebuilt/react_agent_executor.ts @@ -669,7 +669,7 @@ export function createReactAgent< .addEdge("pre_model_hook", "agent"); entrypoint = "pre_model_hook"; - inputSchema = new AnnotationRoot({ + inputSchema = Annotation.Root({ ...schema.spec, ...PreHookAnnotation.spec, }); From fa6ad84ff6f9756bdbd1f1801c4fd46be5af995a Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Fri, 30 May 2025 20:10:26 +0200 Subject: [PATCH 4/6] Add mermaid structure tests for each possible combination --- libs/langgraph/src/tests/prebuilt.test.ts | 250 ++++++++++++++++++++++ libs/langgraph/src/tests/utils.ts | 15 ++ 2 files changed, 265 insertions(+) diff --git a/libs/langgraph/src/tests/prebuilt.test.ts b/libs/langgraph/src/tests/prebuilt.test.ts index 49a50a2bb..bbdf881d8 100644 --- a/libs/langgraph/src/tests/prebuilt.test.ts +++ b/libs/langgraph/src/tests/prebuilt.test.ts @@ -22,6 +22,7 @@ import { _AnyIdToolMessage, FakeConfigurableModel, FakeToolCallingChatModel, + getReadableMermaid, MemorySaverAssertImmutable, } from "./utils.js"; import { ToolNode, createReactAgent } from "../prebuilt/index.js"; @@ -1265,6 +1266,255 @@ describe("createReactAgent with hooks", () => { }, ]); }); + + it.each([ + [ + { + name: "no tools", + graph: createReactAgent({ + llm: new FakeToolCallingChatModel({}), + tools: [], + }), + structure: [ + "__start__ --> agent", + "agent -.-> __end__", + "agent -.[continue].-> tools", + "tools --> agent", + ], + }, + ], + [ + { + name: "tools", + graph: createReactAgent({ + llm: new FakeToolCallingChatModel({}), + tools: [ + tool(() => "The weather is sunny and 75°F.", { + name: "get_weather", + description: "Get the weather", + schema: z.object({}), + }), + ], + }), + structure: [ + "__start__ --> agent", + "agent -.-> __end__", + "agent -.[continue].-> tools", + "tools --> agent", + ], + }, + ], + + [ + { + name: "pre model hook + tools", + graph: createReactAgent({ + llm: new FakeToolCallingChatModel({}), + tools: [ + tool(() => "The weather is sunny and 75°F.", { + name: "get_weather", + description: "Get the weather", + schema: z.object({}), + }), + ], + preModelHook: () => ({ messages: [] }), + }), + structure: [ + "__start__ --> pre_model_hook", + "agent -.-> __end__", + "agent -.[continue].-> tools", + "pre_model_hook --> agent", + "tools --> pre_model_hook", + ], + }, + ], + + [ + { + name: "tools + post model hook", + graph: createReactAgent({ + llm: new FakeToolCallingChatModel({}), + tools: [ + tool(() => "The weather is sunny and 75°F.", { + name: "get_weather", + description: "Get the weather", + schema: z.object({}), + }), + ], + postModelHook: () => ({ flag: true }), + stateSchema: Annotation.Root({ + ...MessagesAnnotation.spec, + flag: Annotation, + }), + }), + structure: [ + "__start__ --> agent", + "agent --> post_model_hook", + "tools --> agent", + "post_model_hook -.-> tools", + "post_model_hook -.[entrypoint].-> agent", + "post_model_hook -.-> __end__", + ], + }, + ], + + [ + { + name: "tools + response format", + graph: createReactAgent({ + llm: new FakeToolCallingChatModel({}), + tools: [ + tool(() => "The weather is sunny and 75°F.", { + name: "get_weather", + description: "Get the weather", + schema: z.object({}), + }), + ], + responseFormat: z.object({ + temperature: z.number().describe("The temperature in fahrenheit"), + }), + }), + structure: [ + "__start__ --> agent", + "generate_structured_response --> __end__", + "tools --> agent", + "agent -.[continue].-> tools", + "agent -.-> generate_structured_response", + ], + }, + ], + + [ + { + name: "pre model hook + tools + response format", + graph: createReactAgent({ + llm: new FakeToolCallingChatModel({}), + tools: [ + tool(() => "The weather is sunny and 75°F.", { + name: "get_weather", + description: "Get the weather", + schema: z.object({}), + }), + ], + preModelHook: () => ({ messages: [] }), + responseFormat: z.object({ + temperature: z.number().describe("The temperature in fahrenheit"), + }), + }), + structure: [ + "__start__ --> pre_model_hook", + "pre_model_hook --> agent", + "generate_structured_response --> __end__", + "tools --> pre_model_hook", + "post_model_hook -.-> tools", + "post_model_hook -.[entrypoint].-> pre_model_hook", + "post_model_hook -.-> generate_structured_response", + ], + }, + ], + + [ + { + name: "tools + post model hook + response format", + graph: createReactAgent({ + llm: new FakeToolCallingChatModel({}), + tools: [ + tool(() => "The weather is sunny and 75°F.", { + name: "get_weather", + description: "Get the weather", + schema: z.object({}), + }), + ], + responseFormat: z.object({ + temperature: z.number().describe("The temperature in fahrenheit"), + }), + postModelHook: () => ({ flag: true }), + stateSchema: Annotation.Root({ + ...MessagesAnnotation.spec, + flag: Annotation, + }), + }), + structure: [ + "__start__ --> agent", + "agent --> post_model_hook", + "generate_structured_response --> __end__", + "tools --> agent", + "post_model_hook -.-> tools", + "post_model_hook -.[entrypoint].-> agent", + "post_model_hook -.-> generate_structured_response", + ], + }, + ], + + [ + { + name: "pre model hook + tools + post model hook", + graph: createReactAgent({ + llm: new FakeToolCallingChatModel({}), + tools: [ + tool(() => "The weather is sunny and 75°F.", { + name: "get_weather", + description: "Get the weather", + schema: z.object({}), + }), + ], + postModelHook: () => ({ flag: true }), + stateSchema: Annotation.Root({ + ...MessagesAnnotation.spec, + flag: Annotation, + }), + }), + structure: [ + "__start__ --> pre_model_hook", + "pre_model_hook --> agent", + "agent --> post_model_hook", + "tools --> pre_model_hook", + "post_model_hook -.-> tools", + "post_model_hook -.[entrypoint].-> pre_model_hook", + "post_model_hook -.-> __end__", + ], + }, + ], + + [ + { + name: "pre model hook + tools + post model hook + response format", + graph: createReactAgent({ + llm: new FakeToolCallingChatModel({}), + tools: [ + tool(() => "The weather is sunny and 75°F.", { + name: "get_weather", + description: "Get the weather", + schema: z.object({}), + }), + ], + responseFormat: z.object({ + temperature: z.number().describe("The temperature in fahrenheit"), + }), + preModelHook: () => ({ messages: [] }), + postModelHook: () => ({ flag: true }), + stateSchema: Annotation.Root({ + ...MessagesAnnotation.spec, + flag: Annotation, + }), + }), + structure: [ + "__start__ --> pre_model_hook", + "pre_model_hook --> agent", + "agent --> post_model_hook", + "generate_structured_response --> __end__", + "tools --> pre_model_hook", + "post_model_hook -.-> tools", + "post_model_hook -.[entrypoint].-> pre_model_hook", + "post_model_hook -.-> generate_structured_response", + ], + }, + ], + ])("graph structure $name", async ({ name, graph, structure }) => { + expect(getReadableMermaid(await graph.getGraphAsync()).sort()).toEqual( + structure.sort() + ); + }); }); describe("ToolNode", () => { diff --git a/libs/langgraph/src/tests/utils.ts b/libs/langgraph/src/tests/utils.ts index 335aee5a3..4f18a6f6a 100644 --- a/libs/langgraph/src/tests/utils.ts +++ b/libs/langgraph/src/tests/utils.ts @@ -4,6 +4,7 @@ import assert from "node:assert"; import { expect, it } from "vitest"; import { v4 as uuidv4 } from "uuid"; import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; +import { Graph as DrawableGraph } from "@langchain/core/runnables/graph"; import { BaseChatModel, BaseChatModelParams, @@ -670,3 +671,17 @@ export async function dumpDebugStream< console.log(`final state: ${JSON.stringify(graphState.values, null, 2)}`); return invokeReturnValue as ReturnType; } + +export function getReadableMermaid(graph: DrawableGraph) { + const mermaid = graph.drawMermaid({ withStyles: false }); + return mermaid + .replace(/\s* (.*) \s*/g, "[$1]") + .split("\n") + .slice(1) + .map((i) => { + const res = i.trim(); + if (res.endsWith(";")) return res.slice(0, -1); + return res; + }) + .filter(Boolean); +} From dd84df01e0b3511305e8c62110401af8b05b0019 Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Fri, 30 May 2025 20:10:37 +0200 Subject: [PATCH 5/6] Lint --- libs/langgraph/src/tests/prebuilt.test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langgraph/src/tests/prebuilt.test.ts b/libs/langgraph/src/tests/prebuilt.test.ts index bbdf881d8..6b684e169 100644 --- a/libs/langgraph/src/tests/prebuilt.test.ts +++ b/libs/langgraph/src/tests/prebuilt.test.ts @@ -1510,7 +1510,7 @@ describe("createReactAgent with hooks", () => { ], }, ], - ])("graph structure $name", async ({ name, graph, structure }) => { + ])("graph structure $name", async ({ graph, structure }) => { expect(getReadableMermaid(await graph.getGraphAsync()).sort()).toEqual( structure.sort() ); From 71dee1cf0ccb6b8f68a5181b17838dbab6bd55b7 Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Fri, 30 May 2025 22:44:48 +0200 Subject: [PATCH 6/6] Fix edges --- .../src/prebuilt/react_agent_executor.ts | 74 ++++++++----------- libs/langgraph/src/tests/diagrams.test.ts | 2 +- libs/langgraph/src/tests/prebuilt.test.ts | 36 ++++----- 3 files changed, 48 insertions(+), 64 deletions(-) diff --git a/libs/langgraph/src/prebuilt/react_agent_executor.ts b/libs/langgraph/src/prebuilt/react_agent_executor.ts index 7122eb61c..f22ad47fd 100644 --- a/libs/langgraph/src/prebuilt/react_agent_executor.ts +++ b/libs/langgraph/src/prebuilt/react_agent_executor.ts @@ -692,25 +692,19 @@ export function createReactAgent< const lastMessage = messages[messages.length - 1]; if (isAIMessage(lastMessage) && lastMessage.tool_calls?.length) { - return "tools" as const; - } - - if (isToolMessage(lastMessage)) { - return "entrypoint" as const; - } - - if (responseFormat != null) { - return "generate_structured_response" as const; + return "tools"; } + if (isToolMessage(lastMessage)) return entrypoint; + if (responseFormat != null) return "generate_structured_response"; return END; }, conditionalMap({ tools: "tools", - entrypoint, + [entrypoint]: entrypoint, generate_structured_response: responseFormat != null ? "generate_structured_response" : null, - [END]: END, + [END]: responseFormat != null ? null : END, }) ); } @@ -721,40 +715,30 @@ export function createReactAgent< .addEdge("generate_structured_response", END); } - allNodeWorkflows.addConditionalEdges( - "agent", - (state) => { - const { messages } = state; - const lastMessage = messages[messages.length - 1]; - - // if there's no function call, we finish - if (!isAIMessage(lastMessage) || !lastMessage.tool_calls?.length) { - if (postModelHook != null) { - return "post_model_hook" as const; - } + if (postModelHook == null) { + allNodeWorkflows.addConditionalEdges( + "agent", + (state) => { + const { messages } = state; + const lastMessage = messages[messages.length - 1]; - if (responseFormat != null) { - return "generate_structured_response" as const; + // if there's no function call, we finish + if (!isAIMessage(lastMessage) || !lastMessage.tool_calls?.length) { + if (responseFormat != null) return "generate_structured_response"; + return END; } - return END; - } - - // there are function calls, we continue - if (postModelHook != null) { - return "post_model_hook" as const; - } - - return "continue" as const; - }, - conditionalMap({ - continue: "tools", - post_model_hook: postModelHook != null ? "post_model_hook" : null, - generate_structured_response: - responseFormat != null ? "generate_structured_response" : null, - [END]: END, - }) - ); + // there are function calls, we continue + return "tools"; + }, + conditionalMap({ + tools: "tools", + generate_structured_response: + responseFormat != null ? "generate_structured_response" : null, + [END]: responseFormat != null ? null : END, + }) + ); + } if (shouldReturnDirect.size > 0) { allNodeWorkflows.addConditionalEdges( @@ -774,12 +758,12 @@ export function createReactAgent< } } - return "agent" as const; + return entrypoint; }, - ["agent", END] + conditionalMap({ [entrypoint]: entrypoint, [END]: END }) ); } else { - allNodeWorkflows.addEdge("tools", "agent"); + allNodeWorkflows.addEdge("tools", entrypoint); } return allNodeWorkflows.compile({ diff --git a/libs/langgraph/src/tests/diagrams.test.ts b/libs/langgraph/src/tests/diagrams.test.ts index fd2acbbfd..730d622f5 100644 --- a/libs/langgraph/src/tests/diagrams.test.ts +++ b/libs/langgraph/src/tests/diagrams.test.ts @@ -21,7 +21,7 @@ graph TD; \t__end__([

__end__

]):::last \t__start__ --> agent; \ttools --> agent; -\tagent -.  continue  .-> tools; +\tagent -.-> tools; \tagent -.-> __end__; \tclassDef default fill:#f2f0ff,line-height:1.2; \tclassDef first fill-opacity:0; diff --git a/libs/langgraph/src/tests/prebuilt.test.ts b/libs/langgraph/src/tests/prebuilt.test.ts index 6b684e169..e56b06f88 100644 --- a/libs/langgraph/src/tests/prebuilt.test.ts +++ b/libs/langgraph/src/tests/prebuilt.test.ts @@ -1278,7 +1278,7 @@ describe("createReactAgent with hooks", () => { structure: [ "__start__ --> agent", "agent -.-> __end__", - "agent -.[continue].-> tools", + "agent -.-> tools", "tools --> agent", ], }, @@ -1299,7 +1299,7 @@ describe("createReactAgent with hooks", () => { structure: [ "__start__ --> agent", "agent -.-> __end__", - "agent -.[continue].-> tools", + "agent -.-> tools", "tools --> agent", ], }, @@ -1307,7 +1307,7 @@ describe("createReactAgent with hooks", () => { [ { - name: "pre model hook + tools", + name: "pre + tools", graph: createReactAgent({ llm: new FakeToolCallingChatModel({}), tools: [ @@ -1322,7 +1322,7 @@ describe("createReactAgent with hooks", () => { structure: [ "__start__ --> pre_model_hook", "agent -.-> __end__", - "agent -.[continue].-> tools", + "agent -.-> tools", "pre_model_hook --> agent", "tools --> pre_model_hook", ], @@ -1331,7 +1331,7 @@ describe("createReactAgent with hooks", () => { [ { - name: "tools + post model hook", + name: "tools + post", graph: createReactAgent({ llm: new FakeToolCallingChatModel({}), tools: [ @@ -1352,7 +1352,7 @@ describe("createReactAgent with hooks", () => { "agent --> post_model_hook", "tools --> agent", "post_model_hook -.-> tools", - "post_model_hook -.[entrypoint].-> agent", + "post_model_hook -.-> agent", "post_model_hook -.-> __end__", ], }, @@ -1378,7 +1378,7 @@ describe("createReactAgent with hooks", () => { "__start__ --> agent", "generate_structured_response --> __end__", "tools --> agent", - "agent -.[continue].-> tools", + "agent -.-> tools", "agent -.-> generate_structured_response", ], }, @@ -1386,7 +1386,7 @@ describe("createReactAgent with hooks", () => { [ { - name: "pre model hook + tools + response format", + name: "pre + tools + response format", graph: createReactAgent({ llm: new FakeToolCallingChatModel({}), tools: [ @@ -1404,18 +1404,17 @@ describe("createReactAgent with hooks", () => { structure: [ "__start__ --> pre_model_hook", "pre_model_hook --> agent", + "agent -.-> tools", + "agent -.-> generate_structured_response", "generate_structured_response --> __end__", "tools --> pre_model_hook", - "post_model_hook -.-> tools", - "post_model_hook -.[entrypoint].-> pre_model_hook", - "post_model_hook -.-> generate_structured_response", ], }, ], [ { - name: "tools + post model hook + response format", + name: "tools + post + response format", graph: createReactAgent({ llm: new FakeToolCallingChatModel({}), tools: [ @@ -1440,7 +1439,7 @@ describe("createReactAgent with hooks", () => { "generate_structured_response --> __end__", "tools --> agent", "post_model_hook -.-> tools", - "post_model_hook -.[entrypoint].-> agent", + "post_model_hook -.-> agent", "post_model_hook -.-> generate_structured_response", ], }, @@ -1448,7 +1447,7 @@ describe("createReactAgent with hooks", () => { [ { - name: "pre model hook + tools + post model hook", + name: "pre + tools + post", graph: createReactAgent({ llm: new FakeToolCallingChatModel({}), tools: [ @@ -1458,6 +1457,7 @@ describe("createReactAgent with hooks", () => { schema: z.object({}), }), ], + preModelHook: () => ({ messages: [] }), postModelHook: () => ({ flag: true }), stateSchema: Annotation.Root({ ...MessagesAnnotation.spec, @@ -1470,7 +1470,7 @@ describe("createReactAgent with hooks", () => { "agent --> post_model_hook", "tools --> pre_model_hook", "post_model_hook -.-> tools", - "post_model_hook -.[entrypoint].-> pre_model_hook", + "post_model_hook -.-> pre_model_hook", "post_model_hook -.-> __end__", ], }, @@ -1478,7 +1478,7 @@ describe("createReactAgent with hooks", () => { [ { - name: "pre model hook + tools + post model hook + response format", + name: "pre + tools + post + response format", graph: createReactAgent({ llm: new FakeToolCallingChatModel({}), tools: [ @@ -1505,12 +1505,12 @@ describe("createReactAgent with hooks", () => { "generate_structured_response --> __end__", "tools --> pre_model_hook", "post_model_hook -.-> tools", - "post_model_hook -.[entrypoint].-> pre_model_hook", + "post_model_hook -.-> pre_model_hook", "post_model_hook -.-> generate_structured_response", ], }, ], - ])("graph structure $name", async ({ graph, structure }) => { + ])("mermaid $name", async ({ graph, structure }) => { expect(getReadableMermaid(await graph.getGraphAsync()).sort()).toEqual( structure.sort() );