Skip to content

Commit 1849e81

Browse files
committed
feat(langgraph): add typedNode utility
1 parent 2763e52 commit 1849e81

File tree

3 files changed

+151
-10
lines changed

3 files changed

+151
-10
lines changed

libs/langgraph/src/graph/graph.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,11 +167,11 @@ export type NodeSpec<RunInput, RunOutput> = {
167167
ends?: string[];
168168
};
169169

170-
export type AddNodeOptions = {
170+
export type AddNodeOptions<Nodes extends string = string> = {
171171
metadata?: Record<string, unknown>;
172172
// eslint-disable-next-line @typescript-eslint/no-explicit-any
173173
subgraphs?: Pregel<any, any>[];
174-
ends?: string[];
174+
ends?: Nodes[];
175175
};
176176

177177
export class Graph<

libs/langgraph/src/graph/state.ts

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,12 @@ export type StateGraphNodeSpec<RunInput, RunOutput> = NodeSpec<
8282
retryPolicy?: RetryPolicy;
8383
};
8484

85-
export type StateGraphAddNodeOptions = {
85+
export type StateGraphAddNodeOptions<Nodes extends string = string> = {
8686
retryPolicy?: RetryPolicy;
8787
// TODO: Fix generic typing for annotations
8888
// eslint-disable-next-line @typescript-eslint/no-explicit-any
8989
input?: AnnotationRoot<any> | AnyZodObject;
90-
} & AddNodeOptions;
90+
} & AddNodeOptions<Nodes>;
9191

9292
export type StateGraphArgsWithStateSchema<
9393
SD extends StateDefinition,
@@ -237,7 +237,8 @@ export class StateGraph<
237237
fields: SD extends StateDefinition
238238
? StateGraphArgsWithInputOutputSchemas<SD, ToStateDefinition<O>>
239239
: never,
240-
configSchema?: C | AnnotationRoot<ToStateDefinition<C>>
240+
configSchema?: C | AnnotationRoot<ToStateDefinition<C>>,
241+
options?: { nodes?: N[] }
241242
);
242243

243244
constructor(
@@ -252,14 +253,16 @@ export class StateGraph<
252253
ToStateDefinition<O>
253254
>
254255
: StateGraphArgs<S>,
255-
configSchema?: C | AnnotationRoot<ToStateDefinition<C>>
256+
configSchema?: C | AnnotationRoot<ToStateDefinition<C>>,
257+
options?: { nodes?: N[] }
256258
);
257259

258260
constructor(
259261
fields: SD extends AnyZodObject
260262
? SD | ZodStateGraphArgsWithStateSchema<SD, I, O>
261263
: never,
262-
configSchema?: C | AnnotationRoot<ToStateDefinition<C>>
264+
configSchema?: C | AnnotationRoot<ToStateDefinition<C>>,
265+
options?: { nodes?: N[] }
263266
);
264267

265268
constructor(
@@ -277,7 +280,8 @@ export class StateGraph<
277280
>
278281
| StateGraphArgsWithInputOutputSchemas<SD, ToStateDefinition<O>>
279282
: StateGraphArgs<S>,
280-
configSchema?: C | AnnotationRoot<ToStateDefinition<C>>
283+
configSchema?: C | AnnotationRoot<ToStateDefinition<C>>,
284+
_options?: { nodes?: N[] }
281285
) {
282286
super();
283287

@@ -434,7 +438,12 @@ export class StateGraph<
434438
isMultipleNodes(args) // eslint-disable-line no-nested-ternary
435439
? Array.isArray(args[0])
436440
? args[0]
437-
: Object.entries(args[0])
441+
: Object.entries(args[0]).map(([key, action]) => [
442+
key,
443+
action,
444+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
445+
(action as any)[Symbol.for("langgraph.state.node")] ?? undefined,
446+
])
438447
: [[args[0], args[1], args[2]]]
439448
) as [
440449
K,
@@ -574,7 +583,12 @@ export class StateGraph<
574583
): StateGraph<SD, S, U, N | K, I, O, C> {
575584
const parsedNodes = Array.isArray(nodes)
576585
? nodes
577-
: (Object.entries(nodes) as [K, NodeAction<S, U, C>][]);
586+
: (Object.entries(nodes).map(([key, action]) => [
587+
key,
588+
action,
589+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
590+
(action as any)[Symbol.for("langgraph.state.node")] ?? undefined,
591+
]) as [K, NodeAction<S, U, C>, StateGraphAddNodeOptions | undefined][]);
578592

579593
if (parsedNodes.length === 0) {
580594
throw new Error("Sequence requires at least one node.");
@@ -1091,3 +1105,41 @@ function _getControlBranch() {
10911105
path: CONTROL_BRANCH_PATH,
10921106
});
10931107
}
1108+
1109+
type TypedNodeAction<SD extends StateDefinition, Nodes extends string> = (
1110+
state: StateType<SD>,
1111+
config: LangGraphRunnableConfig
1112+
) => UpdateType<SD> | Command<unknown, UpdateType<SD>, Nodes>;
1113+
1114+
export function typedNode<SD extends SDZod, Nodes extends string>(
1115+
_state: SD extends StateDefinition ? AnnotationRoot<SD> : never,
1116+
_options?: { nodes?: Nodes[] }
1117+
): (
1118+
func: TypedNodeAction<ToStateDefinition<SD>, Nodes>,
1119+
options?: StateGraphAddNodeOptions<Nodes>
1120+
) => TypedNodeAction<ToStateDefinition<SD>, Nodes>;
1121+
1122+
export function typedNode<SD extends SDZod, Nodes extends string>(
1123+
_state: SD extends AnyZodObject ? SD : never,
1124+
_options?: { nodes?: Nodes[] }
1125+
): (
1126+
func: TypedNodeAction<ToStateDefinition<SD>, Nodes>,
1127+
options?: StateGraphAddNodeOptions<Nodes>
1128+
) => TypedNodeAction<ToStateDefinition<SD>, Nodes>;
1129+
1130+
export function typedNode<SD extends SDZod, Nodes extends string>(
1131+
_state: SD extends AnyZodObject
1132+
? SD
1133+
: SD extends StateDefinition
1134+
? AnnotationRoot<SD>
1135+
: never,
1136+
_options?: { nodes?: Nodes[] }
1137+
) {
1138+
return (
1139+
func: TypedNodeAction<ToStateDefinition<SD>, Nodes>,
1140+
options?: StateGraphAddNodeOptions<Nodes>
1141+
) => {
1142+
Object.assign(func, { [Symbol.for("langgraph.state.node")]: options });
1143+
return func;
1144+
};
1145+
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import { z } from "zod";
2+
import { Command } from "../constants.js";
3+
import { Annotation } from "../graph/annotation.js";
4+
import {
5+
MessagesAnnotation,
6+
MessagesZodState,
7+
} from "../graph/messages_annotation.js";
8+
import { StateGraph, typedNode } from "../graph/state.js";
9+
import { _AnyIdHumanMessage } from "./utils.js";
10+
11+
it("Annotation.Root", async () => {
12+
const StateAnnotation = Annotation.Root({
13+
messages: MessagesAnnotation.spec.messages,
14+
foo: Annotation<string>,
15+
});
16+
17+
const node = typedNode(StateAnnotation, {
18+
nodes: ["nodeA", "nodeB", "nodeC"],
19+
});
20+
21+
const nodeA = node(
22+
(state) => {
23+
const goto = state.foo === "foo" ? "nodeB" : "nodeC";
24+
return new Command({
25+
update: { messages: [{ type: "user", content: "a" }], foo: "a" },
26+
goto,
27+
});
28+
},
29+
{ ends: ["nodeB", "nodeC"] }
30+
);
31+
32+
const nodeB = node(() => {
33+
return new Command({
34+
goto: "nodeC",
35+
update: { foo: "123" },
36+
});
37+
});
38+
const nodeC = node((state) => ({ foo: `${state.foo}|c` }));
39+
40+
const graph = new StateGraph(StateAnnotation)
41+
.addNode({ nodeA, nodeB, nodeC })
42+
.addEdge("__start__", "nodeA")
43+
.compile();
44+
45+
expect(await graph.invoke({ foo: "foo" })).toEqual({
46+
messages: [new _AnyIdHumanMessage("a")],
47+
foo: "123|c",
48+
});
49+
});
50+
51+
it("Zod", async () => {
52+
const StateAnnotation = MessagesZodState.extend({
53+
foo: z.string(),
54+
});
55+
56+
const node = typedNode(StateAnnotation, {
57+
nodes: ["nodeA", "nodeB", "nodeC"],
58+
});
59+
60+
const nodeA = node(
61+
(state) => {
62+
const goto = state.foo === "foo" ? "nodeB" : "nodeC";
63+
return new Command({
64+
update: { messages: [{ type: "user", content: "a" }], foo: "a" },
65+
goto,
66+
});
67+
},
68+
{ ends: ["nodeB", "nodeC"] }
69+
);
70+
71+
const nodeB = node(() => {
72+
return new Command({
73+
goto: "nodeC",
74+
update: { foo: "123" },
75+
});
76+
});
77+
78+
const nodeC = node((state) => ({ foo: `${state.foo}|c` }));
79+
80+
const graph = new StateGraph(StateAnnotation)
81+
.addNode({ nodeA, nodeB, nodeC })
82+
.addEdge("__start__", "nodeA")
83+
.compile();
84+
85+
expect(await graph.invoke({ foo: "foo" })).toEqual({
86+
messages: [new _AnyIdHumanMessage("a")],
87+
foo: "123|c",
88+
});
89+
});

0 commit comments

Comments
 (0)