Skip to content

Commit 95d43aa

Browse files
authored
fix(langgraph): fix default input Zod schema not obtaining langgraph metadata (#1232)
2 parents e0b65cc + ac3d270 commit 95d43aa

File tree

5 files changed

+93
-8
lines changed

5 files changed

+93
-8
lines changed

libs/langgraph/src/graph/messages_annotation.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ export const MessagesZodState = z.object({
9393
schema: z.custom<Messages>(),
9494
fn: messagesStateReducer,
9595
},
96+
jsonSchemaExtra: {
97+
langgraph_type: "messages",
98+
},
9699
default: () => [],
97100
}),
98101
});

libs/langgraph/src/graph/state.ts

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ type NodeAction<S, U, C extends SDZod> = RunnableLike<
127127
LangGraphRunnableConfig<StateType<ToStateDefinition<C>>>
128128
>;
129129

130+
const PartialStateSchema = Symbol.for("langgraph.state.partial");
131+
type PartialStateSchema = typeof PartialStateSchema;
132+
130133
/**
131134
* A graph whose nodes communicate by reading and writing to a shared state.
132135
* Each node takes a defined `State` as input and returns a `Partial<State>`.
@@ -213,7 +216,7 @@ export class StateGraph<
213216
_inputDefinition: I;
214217

215218
/** @internal */
216-
_inputRuntimeDefinition: AnyZodObject | undefined;
219+
_inputRuntimeDefinition: AnyZodObject | PartialStateSchema | undefined;
217220

218221
/** @internal */
219222
_outputDefinition: O;
@@ -292,7 +295,7 @@ export class StateGraph<
292295
this._schemaRuntimeDefinition = fields.state;
293296

294297
this._inputDefinition = inputDef as I;
295-
this._inputRuntimeDefinition = fields.input ?? fields.state.partial();
298+
this._inputRuntimeDefinition = fields.input ?? PartialStateSchema;
296299

297300
this._outputDefinition = outputDef as O;
298301
this._outputRuntimeDefinition = fields.output ?? fields.state;
@@ -303,7 +306,7 @@ export class StateGraph<
303306
this._schemaRuntimeDefinition = fields;
304307

305308
this._inputDefinition = stateDef as I;
306-
this._inputRuntimeDefinition = fields.partial();
309+
this._inputRuntimeDefinition = PartialStateSchema;
307310

308311
this._outputDefinition = stateDef as O;
309312
this._outputRuntimeDefinition = fields;
@@ -944,7 +947,11 @@ export class CompiledStateGraph<
944947
protected async _validateInput(
945948
input: UpdateType<ToStateDefinition<I>>
946949
): Promise<UpdateType<ToStateDefinition<I>>> {
947-
const inputSchema = this.builder._inputRuntimeDefinition;
950+
let inputSchema = this.builder._inputRuntimeDefinition;
951+
if (inputSchema === PartialStateSchema) {
952+
inputSchema = this.builder._schemaRuntimeDefinition?.partial();
953+
}
954+
948955
if (isCommand(input)) {
949956
const parsedInput = input;
950957
if (input.update && isAnyZodObject(inputSchema))

libs/langgraph/src/graph/zod/schema.ts

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ import { getMeta } from "./state.js";
55
const TYPE_CACHE: Record<string, WeakMap<z.AnyZodObject, z.AnyZodObject>> = {};
66
const DESCRIPTION_PREFIX = "lg:";
77

8+
const PartialStateSchema = Symbol.for("langgraph.state.partial");
9+
type PartialStateSchema = typeof PartialStateSchema;
10+
811
function applyPlugin(
912
schema: z.AnyZodObject,
1013
actions: {
@@ -61,7 +64,7 @@ function applyPlugin(
6164
interface GraphWithZodLike {
6265
builder: {
6366
_schemaRuntimeDefinition: z.AnyZodObject | undefined;
64-
_inputRuntimeDefinition: z.AnyZodObject | undefined;
67+
_inputRuntimeDefinition: z.AnyZodObject | PartialStateSchema | undefined;
6568
_outputRuntimeDefinition: z.AnyZodObject | undefined;
6669
_configRuntimeSchema: z.AnyZodObject | undefined;
6770
};
@@ -124,7 +127,7 @@ export function getStateTypeSchema(graph: unknown): JsonSchema | undefined {
124127
if (!isGraphWithZodLike(graph)) return undefined;
125128
const schemaDef = graph.builder._schemaRuntimeDefinition;
126129
if (!schemaDef) return undefined;
127-
return toJsonSchema(schemaDef);
130+
return toJsonSchema(applyPlugin(schemaDef, { jsonSchemaExtra: true }));
128131
}
129132

130133
/**
@@ -153,7 +156,12 @@ export function getUpdateTypeSchema(graph: unknown): JsonSchema | undefined {
153156
*/
154157
export function getInputTypeSchema(graph: unknown): JsonSchema | undefined {
155158
if (!isGraphWithZodLike(graph)) return undefined;
156-
const schemaDef = graph.builder._inputRuntimeDefinition;
159+
let schemaDef = graph.builder._inputRuntimeDefinition;
160+
if (schemaDef === PartialStateSchema) {
161+
// No need to pass `.partial()` here, that's being done by `applyPlugin`
162+
schemaDef = graph.builder._schemaRuntimeDefinition;
163+
}
164+
157165
if (!schemaDef) return undefined;
158166
return toJsonSchema(
159167
applyPlugin(schemaDef, {

libs/langgraph/src/graph/zod/state.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ const META_MAP = new WeakMap<z.ZodType, Meta<any, any>>();
99
export interface Meta<ValueType, UpdateType = ValueType> {
1010
jsonSchemaExtra?: {
1111
langgraph_nodes?: string[];
12-
langgraph_type?: "prompt";
12+
langgraph_type?: "prompt" | "messages";
1313

1414
[key: string]: unknown;
1515
};

libs/langgraph/src/tests/zod_state.test.ts

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@ import { describe, it, expect } from "vitest";
22
import { z } from "zod";
33
import { StateGraph } from "../graph/state.js";
44
import { END, START } from "../constants.js";
5+
import { _AnyIdAIMessage, _AnyIdHumanMessage } from "./utils.js";
6+
import {
7+
getOutputTypeSchema,
8+
getInputTypeSchema,
9+
getUpdateTypeSchema,
10+
getStateTypeSchema,
11+
} from "../graph/zod/schema.js";
12+
import { MessagesZodState } from "../graph/messages_annotation.js";
513

614
describe("StateGraph with Zod schemas", () => {
715
it("should accept Zod schema as input in addNode", async () => {
@@ -77,4 +85,63 @@ describe("StateGraph with Zod schemas", () => {
7785
count: 1,
7886
});
7987
});
88+
89+
it("should accept Zod messages schema & return tagged JSON schema", async () => {
90+
const schema = MessagesZodState.extend({ count: z.number() });
91+
92+
const graph = new StateGraph(schema)
93+
.addNode("agent", () => ({
94+
messages: [{ type: "ai", content: "agent" }],
95+
}))
96+
.addNode("tool", () => ({
97+
messages: [{ type: "ai", content: "tool" }],
98+
}))
99+
.addEdge("__start__", "agent")
100+
.addEdge("agent", "tool")
101+
.compile();
102+
103+
expect(
104+
await graph.invoke({
105+
messages: [{ type: "human", content: "hello" }],
106+
})
107+
).toMatchObject({
108+
messages: [
109+
new _AnyIdHumanMessage("hello"),
110+
new _AnyIdAIMessage("agent"),
111+
new _AnyIdAIMessage("tool"),
112+
],
113+
});
114+
115+
expect.soft(getStateTypeSchema(graph)).toMatchObject({
116+
$schema: "http://json-schema.org/draft-07/schema#",
117+
properties: {
118+
messages: { langgraph_type: "messages" },
119+
count: { type: "number" },
120+
},
121+
});
122+
123+
expect.soft(getUpdateTypeSchema(graph)).toMatchObject({
124+
$schema: "http://json-schema.org/draft-07/schema#",
125+
properties: {
126+
messages: { langgraph_type: "messages" },
127+
count: { type: "number" },
128+
},
129+
});
130+
131+
expect.soft(getInputTypeSchema(graph)).toMatchObject({
132+
$schema: "http://json-schema.org/draft-07/schema#",
133+
properties: {
134+
messages: { langgraph_type: "messages" },
135+
count: { type: "number" },
136+
},
137+
});
138+
139+
expect.soft(getOutputTypeSchema(graph)).toMatchObject({
140+
$schema: "http://json-schema.org/draft-07/schema#",
141+
properties: {
142+
messages: { langgraph_type: "messages" },
143+
count: { type: "number" },
144+
},
145+
});
146+
});
80147
});

0 commit comments

Comments
 (0)