Skip to content

feat(prebuilt): add preModelHook and postModelHook #1212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 172 additions & 48 deletions libs/langgraph/src/prebuilt/react_agent_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
RunnableToolLike,
RunnableSequence,
RunnableBinding,
type RunnableLike,
} from "@langchain/core/runnables";
import { DynamicTool, StructuredToolInterface } from "@langchain/core/tools";
import {
Expand All @@ -30,7 +31,7 @@ import { z } from "zod";

import {
StateGraph,
CompiledStateGraph,
type CompiledStateGraph,
AnnotationRoot,
} from "../graph/index.js";
import { MessagesAnnotation } from "../graph/messages_annotation.js";
Expand Down Expand Up @@ -327,6 +328,27 @@ export const createReactAgentAnnotation = <
structuredResponse: Annotation<T>,
});

/**
 * Widens a `StateGraph` type so that its node-name type parameter also
 * admits the extra node names `K`, while preserving every other inferred
 * type parameter unchanged. Used to type a base workflow that may later
 * gain optional nodes (e.g. "pre_model_hook", "post_model_hook",
 * "generate_structured_response", "agent") without `addNode` widening
 * errors. Resolves to `never` if `Graph` is not a `StateGraph`.
 */
type WithStateGraphNodes<K extends string, Graph> = Graph extends StateGraph<
  infer SD,
  infer S,
  infer U,
  infer N, // node names — the only parameter being widened (to N | K)
  infer I,
  infer O,
  infer C
>
  ? StateGraph<SD, S, U, N | K, I, O, C>
  : never;

/**
 * Extra state channel available when a `preModelHook` node is present:
 * `llmInputMessages` holds the messages to send to the LLM in place of
 * the full `messages` history (e.g. after trimming or summarization).
 * Accumulates via the standard messages reducer; defaults to an empty
 * list, in which case the full `messages` history is used.
 */
const PreHookAnnotation = Annotation.Root({
  llmInputMessages: Annotation<BaseMessage[], Messages>({
    reducer: messagesStateReducer,
    default: () => [],
  }),
});

// Companion type alias (same name as the value) so the annotation can be
// used in type positions, e.g. `PreHookAnnotation["State"]`.
type PreHookAnnotation = typeof PreHookAnnotation;

export type CreateReactAgentParams<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
A extends AnnotationRoot<any> = AnnotationRoot<any>,
Expand Down Expand Up @@ -407,6 +429,26 @@ export type CreateReactAgentParams<
Example: `"How can I help you"` -> `"<name>agent_name</name><content>How can I help you?</content>"`
*/
includeAgentName?: "inline" | undefined;

/**
* An optional node to add before the `agent` node (i.e., the node that calls the LLM).
* Useful for managing long message histories (e.g., message trimming, summarization, etc.).
*/
preModelHook?: RunnableLike<
A["State"] & PreHookAnnotation["State"],
A["Update"] & PreHookAnnotation["Update"],
LangGraphRunnableConfig
>;

/**
* An optional node to add after the `agent` node (i.e., the node that calls the LLM).
* Useful for implementing human-in-the-loop, guardrails, validation, or other post-processing.
*/
postModelHook?: RunnableLike<
A["State"],
A["Update"],
LangGraphRunnableConfig
>;
};

/**
Expand Down Expand Up @@ -483,6 +525,8 @@ export function createReactAgent<
interruptAfter,
store,
responseFormat,
preModelHook,
postModelHook,
name,
includeAgentName,
} = params;
Expand Down Expand Up @@ -537,18 +581,15 @@ export function createReactAgent<
.map((tool) => tool.name)
);

const shouldContinue = (state: AgentState<StructuredResponseFormat>) => {
const { messages } = state;
const lastMessage = messages[messages.length - 1];
if (
isAIMessage(lastMessage) &&
(!lastMessage.tool_calls || lastMessage.tool_calls.length === 0)
) {
return responseFormat != null ? "generate_structured_response" : END;
} else {
return "continue";
function getModelInputState(
state: AgentState<StructuredResponseFormat> & PreHookAnnotation["State"]
): Omit<AgentState<StructuredResponseFormat>, "llmInputMessages"> {
const { messages, llmInputMessages, ...rest } = state;
if (llmInputMessages != null && llmInputMessages.length > 0) {
return { messages: llmInputMessages, ...rest };
}
};
return { messages, ...rest };
}

const generateStructuredResponse = async (
state: AgentState<StructuredResponseFormat>,
Expand Down Expand Up @@ -583,66 +624,149 @@ export function createReactAgent<
};

const callModel = async (
state: AgentState<StructuredResponseFormat>,
state: AgentState<StructuredResponseFormat> & PreHookAnnotation["State"],
config?: RunnableConfig
) => {
// NOTE: we're dynamically creating the model runnable here
// to ensure that we can validate ConfigurableModel properly
const modelRunnable = await getModelRunnable(llm);
// TODO: Auto-promote streaming.
const response = (await modelRunnable.invoke(state, config)) as BaseMessage;
const response = (await modelRunnable.invoke(
getModelInputState(state),
config
)) as BaseMessage;
// add agent name to the AIMessage
// TODO: figure out if we can avoid mutating the message directly
response.name = name;
response.lc_kwargs.name = name;
return { messages: [response] };
};

const workflow = new StateGraph(
stateSchema ?? createReactAgentAnnotation<StructuredResponseFormat>()
)
.addNode("agent", callModel)
.addNode("tools", toolNode)
.addEdge(START, "agent");
const schema =
stateSchema ?? createReactAgentAnnotation<StructuredResponseFormat>();

const workflow = new StateGraph(schema).addNode("tools", toolNode);

const allNodeWorkflows = workflow as WithStateGraphNodes<
| "pre_model_hook"
| "post_model_hook"
| "generate_structured_response"
| "agent",
typeof workflow
>;

const conditionalMap = <T extends string>(map: Record<string, T | null>) => {
return Object.fromEntries(
Object.entries(map).filter(([_, v]) => v != null) as [string, T][]
);
};

let entrypoint: "agent" | "pre_model_hook" = "agent";
let inputSchema: AnnotationRoot<(typeof schema)["spec"]> | undefined;
if (preModelHook != null) {
allNodeWorkflows
.addNode("pre_model_hook", preModelHook)
.addEdge("pre_model_hook", "agent");
entrypoint = "pre_model_hook";

inputSchema = Annotation.Root({
...schema.spec,
...PreHookAnnotation.spec,
});
} else {
entrypoint = "agent";
}

allNodeWorkflows
.addNode("agent", callModel, { input: inputSchema })
.addEdge(START, entrypoint);

if (postModelHook != null) {
allNodeWorkflows
.addNode("post_model_hook", postModelHook)
.addEdge("agent", "post_model_hook")
.addConditionalEdges(
"post_model_hook",
(state) => {
const { messages } = state;
const lastMessage = messages[messages.length - 1];

if (isAIMessage(lastMessage) && lastMessage.tool_calls?.length) {
return "tools";
}

if (isToolMessage(lastMessage)) return entrypoint;
if (responseFormat != null) return "generate_structured_response";
return END;
},
conditionalMap({
tools: "tools",
[entrypoint]: entrypoint,
generate_structured_response:
responseFormat != null ? "generate_structured_response" : null,
[END]: responseFormat != null ? null : END,
})
);
}

if (responseFormat !== undefined) {
workflow
.addNode("generate_structured_response", generateStructuredResponse)
.addEdge("generate_structured_response", END)
.addConditionalEdges("agent", shouldContinue, {
continue: "tools",
[END]: END,
generate_structured_response: "generate_structured_response",
});
} else {
workflow.addConditionalEdges("agent", shouldContinue, {
continue: "tools",
[END]: END,
});
.addEdge("generate_structured_response", END);
}

const routeToolResponses = (state: AgentState<StructuredResponseFormat>) => {
// Check the last consecutive tool calls
for (let i = state.messages.length - 1; i >= 0; i -= 1) {
const message = state.messages[i];
if (!isToolMessage(message)) {
break;
}
// Check if this tool is configured to return directly
if (message.name !== undefined && shouldReturnDirect.has(message.name)) {
return END;
}
}
return "agent";
};
if (postModelHook == null) {
allNodeWorkflows.addConditionalEdges(
"agent",
(state) => {
const { messages } = state;
const lastMessage = messages[messages.length - 1];

// if there's no function call, we finish
if (!isAIMessage(lastMessage) || !lastMessage.tool_calls?.length) {
if (responseFormat != null) return "generate_structured_response";
return END;
}

// there are function calls, we continue
return "tools";
},
conditionalMap({
tools: "tools",
generate_structured_response:
responseFormat != null ? "generate_structured_response" : null,
[END]: responseFormat != null ? null : END,
})
);
}

if (shouldReturnDirect.size > 0) {
workflow.addConditionalEdges("tools", routeToolResponses, ["agent", END]);
allNodeWorkflows.addConditionalEdges(
"tools",
(state) => {
// Check the last consecutive tool calls
for (let i = state.messages.length - 1; i >= 0; i -= 1) {
const message = state.messages[i];
if (!isToolMessage(message)) break;

// Check if this tool is configured to return directly
if (
message.name !== undefined &&
shouldReturnDirect.has(message.name)
) {
return END;
}
}

return entrypoint;
},
conditionalMap({ [entrypoint]: entrypoint, [END]: END })
);
} else {
workflow.addEdge("tools", "agent");
allNodeWorkflows.addEdge("tools", entrypoint);
}

return workflow.compile({
return allNodeWorkflows.compile({
checkpointer: checkpointer ?? checkpointSaver,
interruptBefore,
interruptAfter,
Expand Down
4 changes: 2 additions & 2 deletions libs/langgraph/src/tests/diagrams.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ test("prebuilt agent", async () => {
expect(mermaid).toEqual(`%%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD;
\t__start__([<p>__start__</p>]):::first
\tagent(agent)
\ttools(tools)
\tagent(agent)
\t__end__([<p>__end__</p>]):::last
\t__start__ --> agent;
\ttools --> agent;
\tagent -. &nbsp;continue&nbsp; .-> tools;
\tagent -.-> tools;
\tagent -.-> __end__;
\tclassDef default fill:#f2f0ff,line-height:1.2;
\tclassDef first fill-opacity:0;
Expand Down
Loading