Skip to content

Commit a16b79a

Browse files
committed
Adds test of invoke and interrupts via graph interface. Checks issues/128 is not an issue.
1 parent 59a08e8 commit a16b79a

File tree

3 files changed

+256
-0
lines changed

3 files changed

+256
-0
lines changed

libs/langgraph-checkpoint-mongodb/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ dev = [
2121
"langchain-core>=0.3.55",
2222
"langchain-ollama>=0.2.2",
2323
"langchain-openai>=0.2.14",
24+
"langgraph>=0.3.23",
2425
"langgraph-checkpoint>=2.0.9",
2526
"pytest-asyncio>=0.21.1",
2627
"pytest>=7.2.1",
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""Follows https://langchain-ai.github.io/langgraph/how-tos/human_in_the_loop/time-travel"""
2+
3+
import os
4+
from typing import TypedDict
5+
6+
import pytest
7+
8+
from langgraph.checkpoint.base import BaseCheckpointSaver
9+
from langgraph.checkpoint.memory import InMemorySaver
10+
from langgraph.checkpoint.mongodb import MongoDBSaver
11+
from langgraph.graph import END, StateGraph
12+
from langgraph.graph.graph import CompiledGraph
13+
14+
# --- Configuration ---
15+
MONGODB_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017")
16+
DB_NAME = os.environ.get("DB_NAME", "langgraph-test")
17+
CHECKPOINT_CLXN_NAME = "interrupts_checkpoints"
18+
WRITES_CLXN_NAME = "interrupts_writes"
19+
20+
21+
@pytest.fixture(scope="function")
22+
def checkpointer_memory():
23+
yield InMemorySaver()
24+
25+
26+
@pytest.fixture(scope="function")
27+
def checkpointer_mongodb():
28+
with MongoDBSaver.from_conn_string(
29+
MONGODB_URI,
30+
db_name=DB_NAME,
31+
checkpoint_collection_name=CHECKPOINT_CLXN_NAME,
32+
writes_collection_name=WRITES_CLXN_NAME,
33+
) as checkpointer:
34+
checkpointer.checkpoint_collection.delete_many({})
35+
checkpointer.writes_collection.delete_many({})
36+
yield checkpointer
37+
checkpointer.checkpoint_collection.drop()
38+
checkpointer.writes_collection.drop()
39+
40+
41+
ALL_CHECKPOINTERS_SYNC = [
42+
"checkpointer_memory",
43+
"checkpointer_mongodb",
44+
]
45+
46+
47+
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
48+
def test(request: pytest.FixtureRequest, checkpointer_name: str):
49+
checkpointer: BaseCheckpointSaver = request.getfixturevalue(checkpointer_name)
50+
assert isinstance(checkpointer, BaseCheckpointSaver)
51+
52+
# --- State Definition ---
53+
class State(TypedDict):
54+
value: int
55+
step: int
56+
57+
# --- Node Definitions ---
58+
def node_inc(state: State) -> State:
59+
"""Increments value and step by 1"""
60+
current_step = state.get("step", 0)
61+
return {"value": state["value"] + 1, "step": current_step + 1}
62+
63+
def node_double(state: State) -> State:
64+
"""Doubles value and increments step by 1"""
65+
current_step = state.get("step", 0)
66+
return {"value": state["value"] * 2, "step": current_step + 1}
67+
68+
# --- Graph Construction ---
69+
builder = StateGraph(State)
70+
builder.add_node("increment", node_inc)
71+
builder.add_node("double", node_double)
72+
builder.set_entry_point("increment")
73+
builder.add_edge("increment", "double")
74+
builder.add_edge("double", END)
75+
76+
# --- Compile Graph (with Interruption) ---
77+
# Using sync for simplicity in this demo
78+
graph: CompiledGraph = builder.compile(
79+
checkpointer=checkpointer, interrupt_after=["increment"]
80+
)
81+
82+
# --- Configure ---
83+
config = {"configurable": {"thread_id": "thread_#1"}}
84+
initial_input = {"value": 10, "step": 0}
85+
86+
# --- 1st invoke, with Interruption
87+
interrupted_state = graph.invoke(initial_input, config=config)
88+
assert interrupted_state == {"value": 10 + 1, "step": 1}
89+
state_history = list(graph.get_state_history(config))
90+
assert len(state_history) == 3
91+
# The states are returned in reverse chronological order.
92+
assert state_history[0].next == ("double",)
93+
94+
# --- 2nd invoke, with input=None, and original config ==> continues from point of interruption
95+
final_state = graph.invoke(None, config=config)
96+
assert final_state == {"value": (10 + 1) * 2, "step": 2}
97+
state_history = list(graph.get_state_history(config))
98+
assert len(state_history) == 4
99+
assert state_history[0].next == ()
100+
assert state_history[-1].next == ("__start__",)
101+
102+
# --- 3rd invoke, but with an input ===> the CompiledGraph is restarted.
103+
new_input = {"value": 100, "step": -100}
104+
third_state = graph.invoke(new_input, config=config)
105+
assert third_state == {"value": 101, "step": -99}
106+
107+
# The entire state history is preserved however
108+
state_history = list(graph.get_state_history(config))
109+
assert len(state_history) == 7
110+
assert state_history[0].next == ("double",)
111+
assert state_history[2].next == ("__start__",)
112+
113+
# --- Upstate state and continue from interrupt
114+
updated_state = {"value": 1000, "step": 1000}
115+
updated_config = graph.update_state(config, updated_state)
116+
final_state = graph.invoke(input=None, config=updated_config)
117+
assert final_state == {"value": 2000, "step": 1001}

0 commit comments

Comments
 (0)