Skip to content

Commit dfb038a

Browse files
fix(checkpoint-mongodb): apply filters correctly in list method
fixes #581
1 parent 06b546a commit dfb038a

File tree

12 files changed

+563
-59
lines changed

12 files changed

+563
-59
lines changed

libs/checkpoint-mongodb/package.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
"@langchain/scripts": ">=0.1.3 <0.2.0",
4444
"@swc/core": "^1.3.90",
4545
"@swc/jest": "^0.2.29",
46+
"@testcontainers/mongodb": "^10.13.2",
4647
"@tsconfig/recommended": "^1.0.3",
4748
"@types/better-sqlite3": "^7.6.9",
4849
"@types/uuid": "^10",
@@ -62,6 +63,7 @@
6263
"prettier": "^2.8.3",
6364
"release-it": "^17.6.0",
6465
"rollup": "^4.23.0",
66+
"testcontainers": "^10.13.2",
6567
"ts-jest": "^29.1.0",
6668
"tsx": "^4.7.0",
6769
"typescript": "^4.9.5 || ^5.4.5"

libs/checkpoint-mongodb/src/index.ts

Lines changed: 146 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,21 @@ import {
99
type PendingWrite,
1010
type CheckpointMetadata,
1111
CheckpointPendingWrite,
12+
validCheckpointMetadataKeys,
1213
} from "@langchain/langgraph-checkpoint";
14+
import { applyMigrations, needsMigration } from "./migrations/index.js";
15+
16+
export * from "./migrations/index.js";
17+
18+
// increment this whenever the structure of the database changes in a way that would require a migration
19+
const CURRENT_SCHEMA_VERSION = 1;
1320

1421
export type MongoDBSaverParams = {
1522
client: MongoClient;
1623
dbName?: string;
1724
checkpointCollectionName?: string;
1825
checkpointWritesCollectionName?: string;
26+
schemaVersionCollectionName?: string;
1927
};
2028

2129
/**
@@ -26,16 +34,21 @@ export class MongoDBSaver extends BaseCheckpointSaver {
2634

2735
protected db: MongoDatabase;
2836

37+
private setupPromise: Promise<void> | undefined;
38+
2939
checkpointCollectionName = "checkpoints";
3040

3141
checkpointWritesCollectionName = "checkpoint_writes";
3242

43+
schemaVersionCollectionName = "schema_version";
44+
3345
constructor(
3446
{
3547
client,
3648
dbName,
3749
checkpointCollectionName,
3850
checkpointWritesCollectionName,
51+
schemaVersionCollectionName,
3952
}: MongoDBSaverParams,
4053
serde?: SerializerProtocol
4154
) {
@@ -46,6 +59,118 @@ export class MongoDBSaver extends BaseCheckpointSaver {
4659
checkpointCollectionName ?? this.checkpointCollectionName;
4760
this.checkpointWritesCollectionName =
4861
checkpointWritesCollectionName ?? this.checkpointWritesCollectionName;
62+
this.schemaVersionCollectionName =
63+
schemaVersionCollectionName ?? this.schemaVersionCollectionName;
64+
}
65+
66+
/**
 * Runs async setup tasks if they haven't been run yet.
 *
 * The in-flight or successful setup is cached so the schema check runs once
 * and concurrent callers share it. A failed setup is not cached: the cache is
 * cleared on rejection so a later call (for example after `migrate()` has
 * been run) performs the check again instead of replaying the stale error.
 *
 * @returns A promise that resolves once the schema version has been verified.
 * @throws Whatever `initializeSchemaVersion()` rejects with.
 */
async setup(): Promise<void> {
  if (!this.setupPromise) {
    this.setupPromise = this.initializeSchemaVersion().catch(
      (error: unknown) => {
        // drop the cached rejection so the next call retries
        this.setupPromise = undefined;
        throw error;
      }
    );
  }
  return this.setupPromise;
}
76+
77+
/**
 * Checks whether the checkpoint and checkpoint-writes collections are both
 * free of documents. The two collections are queried concurrently.
 *
 * @returns `true` only when neither collection contains a document.
 */
private async isDatabaseEmpty(): Promise<boolean> {
  const collectionNames = [
    this.checkpointCollectionName,
    this.checkpointWritesCollectionName,
  ];

  const emptyFlags = await Promise.all(
    collectionNames.map(async (name) => {
      // a limit of 1 lets the count stop as soon as any document is found
      const found = await this.db
        .collection(name)
        .countDocuments({}, { limit: 1 });
      return found === 0;
    })
  );

  return !emptyFlags.includes(false);
}
91+
92+
/**
 * Verifies that the database schema matches CURRENT_SCHEMA_VERSION.
 *
 * - When both checkpoint collections are empty, a schema version document is
 *   inserted if none exists yet, and no further checks are made.
 * - Otherwise the database must not need a migration and the stored version
 *   must equal CURRENT_SCHEMA_VERSION.
 *
 * @throws Error when a migration is required, when the version document is
 *   missing or has no version, or when the stored version is newer or older
 *   than CURRENT_SCHEMA_VERSION.
 */
private async initializeSchemaVersion(): Promise<void> {
  const schemaVersionCollection = this.db.collection(
    this.schemaVersionCollectionName
  );

  // empty database, no migrations needed - just set the schema version and move on
  if (await this.isDatabaseEmpty()) {
    const existingVersionDoc = await schemaVersionCollection.findOne({});
    if (!existingVersionDoc) {
      await schemaVersionCollection.insertOne({
        version: CURRENT_SCHEMA_VERSION,
      });
    }
    return;
  }

  // non-empty database, check if migrations are needed
  const dbNeedsMigration = await needsMigration({
    client: this.client,
    dbName: this.db.databaseName,
    checkpointCollectionName: this.checkpointCollectionName,
    checkpointWritesCollectionName: this.checkpointWritesCollectionName,
    schemaVersionCollectionName: this.schemaVersionCollectionName,
    serializer: this.serde,
    currentSchemaVersion: CURRENT_SCHEMA_VERSION,
  });

  if (dbNeedsMigration) {
    throw new Error(
      `Database needs migration. Call the migrate() method to migrate the database.`
    );
  }

  // expected to exist when no migration is needed; guarded explicitly so a
  // missing document surfaces as a descriptive error rather than a TypeError
  const versionDoc = await schemaVersionCollection.findOne({});

  if (!versionDoc || versionDoc.version == null) {
    throw new Error(
      `BUG: Database schema version is corrupt. Manual intervention required.`
    );
  }

  if (versionDoc.version > CURRENT_SCHEMA_VERSION) {
    throw new Error(
      `Database created with newer version of checkpoint-mongodb. This version supports schema version ` +
        `${CURRENT_SCHEMA_VERSION} but the database was created with schema version ${versionDoc.version}.`
    );
  }

  if (versionDoc.version < CURRENT_SCHEMA_VERSION) {
    throw new Error(
      `BUG: Schema version ${versionDoc.version} is outdated (should be >= ${CURRENT_SCHEMA_VERSION}), but no ` +
        `migration wants to execute.`
    );
  }
}
151+
152+
/**
 * Applies pending migrations when `needsMigration` reports that the database
 * requires them; otherwise does nothing.
 *
 * After migrations have been applied, any cached setup result is discarded so
 * the next `setup()` call re-validates the schema version instead of
 * returning an outcome computed before the migration.
 */
async migrate(): Promise<void> {
  // both helpers take the same parameters, so build them once
  const migrationParams = {
    client: this.client,
    dbName: this.db.databaseName,
    checkpointCollectionName: this.checkpointCollectionName,
    checkpointWritesCollectionName: this.checkpointWritesCollectionName,
    schemaVersionCollectionName: this.schemaVersionCollectionName,
    serializer: this.serde,
    currentSchemaVersion: CURRENT_SCHEMA_VERSION,
  };

  if (await needsMigration(migrationParams)) {
    await applyMigrations(migrationParams);
    // a setup attempted before migrating may have cached a failure
    this.setupPromise = undefined;
  }
}
50175

51176
/**
@@ -55,6 +180,8 @@ export class MongoDBSaver extends BaseCheckpointSaver {
55180
* for the given thread ID is retrieved.
56181
*/
57182
async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
183+
await this.setup();
184+
58185
const {
59186
thread_id,
60187
checkpoint_ns = "",
@@ -109,10 +236,7 @@ export class MongoDBSaver extends BaseCheckpointSaver {
109236
config: { configurable: configurableValues },
110237
checkpoint,
111238
pendingWrites,
112-
metadata: (await this.serde.loadsTyped(
113-
doc.type,
114-
doc.metadata.value()
115-
)) as CheckpointMetadata,
239+
metadata: doc.metadata as CheckpointMetadata,
116240
parentConfig:
117241
doc.parent_checkpoint_id != null
118242
? {
@@ -135,6 +259,8 @@ export class MongoDBSaver extends BaseCheckpointSaver {
135259
config: RunnableConfig,
136260
options?: CheckpointListOptions
137261
): AsyncGenerator<CheckpointTuple> {
262+
await this.setup();
263+
138264
const { limit, before, filter } = options ?? {};
139265
const query: Record<string, unknown> = {};
140266

@@ -150,9 +276,16 @@ export class MongoDBSaver extends BaseCheckpointSaver {
150276
}
151277

152278
if (filter) {
153-
Object.entries(filter).forEach(([key, value]) => {
154-
query[`metadata.${key}`] = value;
155-
});
279+
Object.entries(filter)
280+
.filter(
281+
([key, value]) =>
282+
validCheckpointMetadataKeys.includes(
283+
key as keyof CheckpointMetadata
284+
) && value !== undefined
285+
)
286+
.forEach(([key, value]) => {
287+
query[`metadata.${key}`] = value;
288+
});
156289
}
157290

158291
if (before) {
@@ -173,10 +306,7 @@ export class MongoDBSaver extends BaseCheckpointSaver {
173306
doc.type,
174307
doc.checkpoint.value()
175308
)) as Checkpoint;
176-
const metadata = (await this.serde.loadsTyped(
177-
doc.type,
178-
doc.metadata.value()
179-
)) as CheckpointMetadata;
309+
const metadata = doc.metadata as CheckpointMetadata;
180310

181311
yield {
182312
config: {
@@ -210,6 +340,8 @@ export class MongoDBSaver extends BaseCheckpointSaver {
210340
checkpoint: Checkpoint,
211341
metadata: CheckpointMetadata
212342
): Promise<RunnableConfig> {
343+
await this.setup();
344+
213345
const thread_id = config.configurable?.thread_id;
214346
const checkpoint_ns = config.configurable?.checkpoint_ns ?? "";
215347
const checkpoint_id = checkpoint.id;
@@ -220,15 +352,11 @@ export class MongoDBSaver extends BaseCheckpointSaver {
220352
}
221353
const [checkpointType, serializedCheckpoint] =
222354
this.serde.dumpsTyped(checkpoint);
223-
const [metadataType, serializedMetadata] = this.serde.dumpsTyped(metadata);
224-
if (checkpointType !== metadataType) {
225-
throw new Error("Mismatched checkpoint and metadata types.");
226-
}
227355
const doc = {
228356
parent_checkpoint_id: config.configurable?.checkpoint_id,
229357
type: checkpointType,
230358
checkpoint: serializedCheckpoint,
231-
metadata: serializedMetadata,
359+
metadata,
232360
};
233361
const upsertQuery = {
234362
thread_id,
@@ -259,6 +387,8 @@ export class MongoDBSaver extends BaseCheckpointSaver {
259387
writes: PendingWrite[],
260388
taskId: string
261389
): Promise<void> {
390+
await this.setup();
391+
262392
const thread_id = config.configurable?.thread_id;
263393
const checkpoint_ns = config.configurable?.checkpoint_ns;
264394
const checkpoint_id = config.configurable?.checkpoint_id;
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import { Binary, ObjectId, Collection, Document, WithId } from "mongodb";
2+
import { CheckpointMetadata } from "@langchain/langgraph-checkpoint";
3+
import { Migration, MigrationParams } from "./base.js";
4+
5+
// Number of pending document updates accumulated before they are sent in a single bulkWrite call.
const BULK_WRITE_SIZE = 100;
6+
7+
/**
 * Shape of a checkpoint document before this migration: `metadata` is stored
 * as a BSON Binary value.
 */
interface OldCheckpointDocument {
  thread_id: string;
  checkpoint_ns: string | undefined;
  checkpoint_id: string;
  parent_checkpoint_id: string | undefined;
  type: string;
  checkpoint: Binary;
  metadata: Binary;
}
16+
17+
/**
 * Shape of a checkpoint document after this migration: `metadata` is stored
 * as a plain CheckpointMetadata object instead of a BSON Binary value.
 */
interface NewCheckpointDocument {
  thread_id: string;
  checkpoint_ns: string | undefined;
  checkpoint_id: string;
  parent_checkpoint_id: string | undefined;
  type: string;
  checkpoint: Binary;
  metadata: CheckpointMetadata;
}
26+
27+
/**
 * Migration to schema version 1: rewrites checkpoint documents whose
 * `metadata` field is a BSON Binary so that `metadata` holds the deserialized
 * CheckpointMetadata object instead.
 */
export class Migration1ObjectMetadata extends Migration {
  version = 1;

  constructor(params: MigrationParams) {
    super(params);
  }

  /**
   * Walks every document in the checkpoints collection, deserializes Binary
   * metadata with the configured serializer, writes the converted documents
   * back in batches of BULK_WRITE_SIZE, and finally records this migration's
   * version in the schema version collection.
   */
  override async apply() {
    const db = this.client.db(this.dbName);
    const checkpointCollection = db.collection(this.checkpointCollectionName);
    const schemaVersionCollection = db.collection(
      this.schemaVersionCollectionName
    );

    // Fetch all documents from the checkpoints collection
    const cursor = checkpointCollection.find({});

    let updateBatch: {
      id: ObjectId;
      newDoc: NewCheckpointDocument;
    }[] = [];

    for await (const doc of cursor) {
      // skip documents that are already migrated; optional chaining also
      // skips documents with no metadata instead of throwing a TypeError
      if (doc.metadata?._bsontype !== "Binary") {
        continue;
      }

      // separate _id from the payload so the immutable _id is never part of $set
      const { _id, ...oldFields } = doc as WithId<OldCheckpointDocument>;

      const metadata: CheckpointMetadata = await this.serializer.loadsTyped(
        oldFields.type,
        oldFields.metadata.value()
      );

      const newDoc: NewCheckpointDocument = {
        ...oldFields,
        metadata,
      };

      updateBatch.push({ id: _id, newDoc });

      if (updateBatch.length >= BULK_WRITE_SIZE) {
        await this.flushBatch(updateBatch, checkpointCollection);
        updateBatch = [];
      }
    }

    // write whatever is left over from the final partial batch
    if (updateBatch.length > 0) {
      await this.flushBatch(updateBatch, checkpointCollection);
    }

    // Record this migration's schema version
    await schemaVersionCollection.updateOne(
      {},
      { $set: { version: this.version } },
      { upsert: true }
    );
  }

  /**
   * Sends one bulkWrite containing an updateOne per batched document, matched
   * by `_id`.
   *
   * @param updateBatch Documents to rewrite, with the `_id` they are matched on.
   * @param checkpointCollection Collection the updates are written to.
   * @throws Error when called with an empty batch.
   */
  private async flushBatch(
    updateBatch: {
      id: ObjectId;
      newDoc: NewCheckpointDocument;
    }[],
    checkpointCollection: Collection<Document>
  ) {
    if (updateBatch.length === 0) {
      throw new Error("No updates to apply");
    }

    const bulkOps = updateBatch.map(({ id, newDoc }) => ({
      updateOne: {
        filter: { _id: id },
        update: { $set: newDoc },
      },
    }));

    await checkpointCollection.bulkWrite(bulkOps);
  }
}

0 commit comments

Comments
 (0)