Skip to content

Commit b115ad0

Browse files
committed
feat(langgraph): introduce strict typing for .stream() methods
Supersedes #512
1 parent e0b65cc commit b115ad0

File tree

10 files changed

+321
-23
lines changed

10 files changed

+321
-23
lines changed

libs/langgraph/.eslintrc.cjs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ module.exports = {
3939
"import/extensions": [2, "ignorePackages"],
4040
"import/no-extraneous-dependencies": [
4141
"error",
42-
{ devDependencies: ["**/*.test.ts"] },
42+
{ devDependencies: ["**/*.test.ts", "**/*.test-d.ts"] },
4343
],
4444
"import/no-unresolved": 0,
4545
"import/prefer-default-export": 0,

libs/langgraph/src/func/index.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,10 @@ export interface EntrypointFunction {
146146
},
147147
Record<string, unknown>,
148148
InputT,
149-
EntrypointReturnT<OutputT>
149+
EntrypointReturnT<OutputT>,
150+
// Because the update type is an return type union of tasks + entrypoint,
151+
// thus we can't type it properly.
152+
any // eslint-disable-line @typescript-eslint/no-explicit-any
150153
>;
151154

152155
/**

libs/langgraph/src/pregel/index.ts

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ import {
7272
SingleChannelSubscriptionOptions,
7373
MultipleChannelSubscriptionOptions,
7474
GetStateOptions,
75+
type StreamOutputMap,
7576
} from "./types.js";
7677
import {
7778
GraphRecursionError,
@@ -118,10 +119,6 @@ import {
118119
type WriteValue = Runnable | RunnableFunc<unknown, unknown> | unknown;
119120
type StreamEventsOptions = Parameters<Runnable["streamEvents"]>[2];
120121

121-
function isString(value: unknown): value is string {
122-
return typeof value === "string";
123-
}
124-
125122
/**
126123
* Utility class for working with channels in the Pregel system.
127124
* Provides static methods for subscribing to channels and writing to them.
@@ -201,7 +198,7 @@ export class Channel {
201198

202199
let channelMappingOrArray: string[] | Record<string, string>;
203200

204-
if (isString(channels)) {
201+
if (typeof channels === "string") {
205202
if (key) {
206203
channelMappingOrArray = { [key]: channels };
207204
} else {
@@ -349,7 +346,9 @@ export class Pregel<
349346
// eslint-disable-next-line @typescript-eslint/no-explicit-any
350347
ConfigurableFieldType extends Record<string, any> = StrRecord<string, any>,
351348
InputType = PregelInputType,
352-
OutputType = PregelOutputType
349+
OutputType = PregelOutputType,
350+
StreamUpdatesType = InputType,
351+
StreamValuesType = OutputType
353352
>
354353
extends Runnable<
355354
InputType | Command | null,
@@ -1645,10 +1644,31 @@ export class Pregel<
16451644
* @param options - Configuration options for streaming
16461645
* @returns An async iterable stream of graph state updates
16471646
*/
1648-
override async stream(
1647+
// @ts-expect-error Return type of `stream()` differs from `invoke()`, which is expected.
1648+
override async stream<
1649+
TStreamMode extends StreamMode | StreamMode[] | undefined,
1650+
TSubgraphs extends boolean
1651+
>(
16491652
input: InputType | Command | null,
1650-
options?: Partial<PregelOptions<Nodes, Channels, ConfigurableFieldType>>
1651-
): Promise<IterableReadableStream<PregelOutputType>> {
1653+
options?: Partial<
1654+
PregelOptions<
1655+
Nodes,
1656+
Channels,
1657+
ConfigurableFieldType,
1658+
TStreamMode,
1659+
TSubgraphs
1660+
>
1661+
>
1662+
): Promise<
1663+
IterableReadableStream<
1664+
StreamOutputMap<
1665+
TStreamMode,
1666+
TSubgraphs,
1667+
StreamUpdatesType,
1668+
StreamValuesType
1669+
>
1670+
>
1671+
> {
16521672
// The ensureConfig method called internally defaults recursionLimit to 25 if not
16531673
// passed directly in `options`.
16541674
// There is currently no way in _streamIterator to determine whether this was
@@ -1665,7 +1685,14 @@ export class Pregel<
16651685
};
16661686

16671687
return new IterableReadableStreamWithAbortSignal(
1668-
await super.stream(input, config),
1688+
(await super.stream(input, config)) as IterableReadableStream<
1689+
StreamOutputMap<
1690+
TStreamMode,
1691+
TSubgraphs,
1692+
StreamUpdatesType,
1693+
StreamValuesType
1694+
>
1695+
>,
16691696
abortController
16701697
);
16711698
}
@@ -2025,7 +2052,7 @@ export class Pregel<
20252052
chunks.push(chunk);
20262053
}
20272054
if (streamMode === "values") {
2028-
return chunks[chunks.length - 1];
2055+
return chunks[chunks.length - 1] as OutputType;
20292056
}
20302057
return chunks as OutputType;
20312058
}

libs/langgraph/src/pregel/types.ts

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import type {
99
} from "@langchain/langgraph-checkpoint";
1010
import { Graph as DrawableGraph } from "@langchain/core/runnables/graph";
1111
import { IterableReadableStream } from "@langchain/core/utils/stream";
12+
import type { BaseMessage } from "@langchain/core/messages";
1213
import type { BaseChannel } from "../channels/base.js";
1314
import type { PregelNode } from "./read.js";
1415
import { RetryPolicy } from "./utils/index.js";
@@ -27,6 +28,68 @@ export type PregelInputType = any;
2728
// eslint-disable-next-line @typescript-eslint/no-explicit-any
2829
export type PregelOutputType = any;
2930

31+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
32+
type StreamMessageOutput = [BaseMessage, Record<string, any>];
33+
34+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
35+
type StreamCustomOutput = any;
36+
37+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
38+
type StreamDebugOutput = Record<string, any>;
39+
40+
type DefaultStreamMode = "updates";
41+
42+
export type StreamOutputMap<
43+
TStreamMode extends StreamMode | StreamMode[] | undefined,
44+
TStreamSubgraphs extends boolean,
45+
StreamUpdates,
46+
StreamValues
47+
> = (
48+
undefined extends TStreamMode
49+
? []
50+
: StreamMode | StreamMode[] extends TStreamMode
51+
? TStreamMode extends StreamMode[]
52+
? TStreamMode[number]
53+
: TStreamMode
54+
: TStreamMode extends StreamMode[]
55+
? TStreamMode[number]
56+
: []
57+
) extends infer Multiple extends StreamMode
58+
? [TStreamSubgraphs] extends [true]
59+
? {
60+
values: [string[], "values", StreamValues];
61+
updates: [string[], "updates", Record<string, StreamUpdates>];
62+
messages: [string[], "messages", StreamMessageOutput];
63+
custom: [string[], "custom", StreamCustomOutput];
64+
debug: [string[], "debug", StreamDebugOutput];
65+
}[Multiple]
66+
: {
67+
values: ["values", StreamValues];
68+
updates: ["updates", Record<string, StreamUpdates>];
69+
messages: ["messages", StreamMessageOutput];
70+
custom: ["custom", StreamCustomOutput];
71+
debug: ["debug", StreamDebugOutput];
72+
}[Multiple]
73+
: (
74+
undefined extends TStreamMode ? DefaultStreamMode : TStreamMode
75+
) extends infer Single extends StreamMode
76+
? [TStreamSubgraphs] extends [true]
77+
? {
78+
values: [string[], StreamValues];
79+
updates: [string[], Record<string, StreamUpdates>];
80+
messages: [string[], StreamMessageOutput];
81+
custom: [string[], StreamCustomOutput];
82+
debug: [string[], StreamDebugOutput];
83+
}[Single]
84+
: {
85+
values: StreamValues;
86+
updates: Record<string, StreamUpdates>;
87+
messages: StreamMessageOutput;
88+
custom: StreamCustomOutput;
89+
debug: StreamDebugOutput;
90+
}[Single]
91+
: never;
92+
3093
/**
3194
* Configuration options for executing a Pregel graph.
3295
* These options control how the graph executes, what data is streamed, and how interrupts are handled.
@@ -39,7 +102,12 @@ export interface PregelOptions<
39102
Nodes extends StrRecord<string, PregelNode>,
40103
Channels extends StrRecord<string, BaseChannel | ManagedValueSpec>,
41104
// eslint-disable-next-line @typescript-eslint/no-explicit-any
42-
ConfigurableFieldType extends Record<string, any> = Record<string, any>
105+
ConfigurableFieldType extends Record<string, any> = Record<string, any>,
106+
TStreamMode extends StreamMode | StreamMode[] | undefined =
107+
| StreamMode
108+
| StreamMode[]
109+
| undefined,
110+
TSubgraphs extends boolean = boolean
43111
> extends RunnableConfig<ConfigurableFieldType> {
44112
/**
45113
* Controls what information is streamed during graph execution.
@@ -63,7 +131,7 @@ export interface PregelOptions<
63131
*
64132
* @default ["values"]
65133
*/
66-
streamMode?: StreamMode | StreamMode[];
134+
streamMode?: TStreamMode;
67135

68136
/**
69137
* Specifies which channel keys to retrieve from the checkpoint when resuming execution.
@@ -140,7 +208,7 @@ export interface PregelOptions<
140208
*
141209
* @default false
142210
*/
143-
subgraphs?: boolean;
211+
subgraphs?: TSubgraphs;
144212

145213
/**
146214
* A shared value store that allows you to store and retrieve state across

libs/langgraph/src/tests/prebuilt.int.test.ts

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -208,9 +208,7 @@ describe("createReactAgent", () => {
208208
});
209209

210210
const stream = await reactAgent.stream(
211-
{
212-
messages: [new HumanMessage("What's the weather like in SF?")],
213-
},
211+
{ messages: [new HumanMessage("What's the weather like in SF?")] },
214212
{ configurable: { thread_id: "foo" }, streamMode: "values" }
215213
);
216214
const fullResponse = [];
@@ -226,7 +224,7 @@ describe("createReactAgent", () => {
226224

227225
const lastMessage = endState.messages[endState.messages.length - 1];
228226
expect(lastMessage._getType()).toBe("ai");
229-
expect(lastMessage.content.toLowerCase()).toContain("not too cold");
227+
expect(lastMessage.text.toLowerCase()).toContain("not too cold");
230228
const stream2 = await reactAgent.stream(
231229
{
232230
messages: [new HumanMessage("What about NYC?")],
@@ -245,6 +243,6 @@ describe("createReactAgent", () => {
245243

246244
const lastMessage2 = endState.messages[endState.messages.length - 1];
247245
expect(lastMessage2._getType()).toBe("ai");
248-
expect(lastMessage2.content.toLowerCase()).toContain("not too cold");
246+
expect(lastMessage2.text.toLowerCase()).toContain("not too cold");
249247
});
250248
});

0 commit comments

Comments
 (0)