Skip to content

Commit 32f68b8

Browse files
INTPYTHON-619 Confirm invoke behaves as expected when invoked after an interrupt. (#141)
1 parent 170264e commit 32f68b8

File tree

4 files changed

+1717
-1458
lines changed

4 files changed

+1717
-1458
lines changed

libs/langgraph-checkpoint-mongodb/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ dev = [
2121
"langchain-core>=0.3.55",
2222
"langchain-ollama>=0.2.2",
2323
"langchain-openai>=0.2.14",
24+
"langgraph>=0.3.23",
2425
"langgraph-checkpoint>=2.0.9",
2526
"pytest-asyncio>=0.21.1",
2627
"pytest>=7.2.1",
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
"""Follows https://langchain-ai.github.io/langgraph/how-tos/human_in_the_loop/time-travel"""
2+
3+
import os
4+
from collections.abc import Generator
5+
from typing import TypedDict
6+
7+
import pytest
8+
from langchain_core.runnables import RunnableConfig
9+
10+
from langgraph.checkpoint.base import BaseCheckpointSaver
11+
from langgraph.checkpoint.memory import InMemorySaver
12+
from langgraph.checkpoint.mongodb import MongoDBSaver
13+
from langgraph.graph import END, StateGraph
14+
from langgraph.graph.graph import CompiledGraph
15+
16+
# --- Configuration ---
17+
MONGODB_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017")
18+
DB_NAME = os.environ.get("DB_NAME", "langgraph-test")
19+
CHECKPOINT_CLXN_NAME = "interrupts_checkpoints"
20+
WRITES_CLXN_NAME = "interrupts_writes"
21+
22+
23+
@pytest.fixture(scope="function")
24+
def checkpointer_memory() -> Generator[InMemorySaver, None, None]:
25+
yield InMemorySaver()
26+
27+
28+
@pytest.fixture(scope="function")
29+
def checkpointer_mongodb() -> Generator[MongoDBSaver, None, None]:
30+
with MongoDBSaver.from_conn_string(
31+
MONGODB_URI,
32+
db_name=DB_NAME,
33+
checkpoint_collection_name=CHECKPOINT_CLXN_NAME,
34+
writes_collection_name=WRITES_CLXN_NAME,
35+
) as checkpointer:
36+
checkpointer.checkpoint_collection.delete_many({})
37+
checkpointer.writes_collection.delete_many({})
38+
yield checkpointer
39+
checkpointer.checkpoint_collection.drop()
40+
checkpointer.writes_collection.drop()
41+
42+
43+
ALL_CHECKPOINTERS_SYNC = [
44+
"checkpointer_memory",
45+
"checkpointer_mongodb",
46+
]
47+
48+
49+
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
50+
def test(request: pytest.FixtureRequest, checkpointer_name: str) -> None:
51+
checkpointer: BaseCheckpointSaver = request.getfixturevalue(checkpointer_name)
52+
assert isinstance(checkpointer, BaseCheckpointSaver)
53+
54+
# --- State Definition ---
55+
class State(TypedDict):
56+
value: int
57+
step: int
58+
59+
# --- Node Definitions ---
60+
def node_inc(state: State) -> State:
61+
"""Increments value and step by 1"""
62+
current_step = state.get("step", 0)
63+
return {"value": state["value"] + 1, "step": current_step + 1}
64+
65+
def node_double(state: State) -> State:
66+
"""Doubles value and increments step by 1"""
67+
current_step = state.get("step", 0)
68+
return {"value": state["value"] * 2, "step": current_step + 1}
69+
70+
# --- Graph Construction ---
71+
builder = StateGraph(State)
72+
builder.add_node("increment", node_inc)
73+
builder.add_node("double", node_double)
74+
builder.set_entry_point("increment")
75+
builder.add_edge("increment", "double")
76+
builder.add_edge("double", END)
77+
78+
# --- Compile Graph (with Interruption) ---
79+
# Using sync for simplicity in this demo
80+
graph: CompiledGraph = builder.compile(
81+
checkpointer=checkpointer, interrupt_after=["increment"]
82+
)
83+
84+
# --- Configure ---
85+
config: RunnableConfig = {"configurable": {"thread_id": "thread_#1"}}
86+
initial_input = {"value": 10, "step": 0}
87+
88+
# --- 1st invoke, with Interruption
89+
interrupted_state = graph.invoke(initial_input, config=config)
90+
assert interrupted_state == {"value": 10 + 1, "step": 1}
91+
state_history = list(graph.get_state_history(config))
92+
assert len(state_history) == 3
93+
# The states are returned in reverse chronological order.
94+
assert state_history[0].next == ("double",)
95+
96+
# --- 2nd invoke, with input=None, and original config ==> continues from point of interruption
97+
final_state = graph.invoke(None, config=config)
98+
assert final_state == {"value": (10 + 1) * 2, "step": 2}
99+
state_history = list(graph.get_state_history(config))
100+
assert len(state_history) == 4
101+
assert state_history[0].next == ()
102+
assert state_history[-1].next == ("__start__",)
103+
104+
# --- 3rd invoke, but with an input ===> the CompiledGraph is restarted.
105+
new_input = {"value": 100, "step": -100}
106+
third_state = graph.invoke(new_input, config=config)
107+
assert third_state == {"value": 101, "step": -99}
108+
109+
# The entire state history is preserved however
110+
state_history = list(graph.get_state_history(config))
111+
assert len(state_history) == 7
112+
assert state_history[0].next == ("double",)
113+
assert state_history[2].next == ("__start__",)
114+
115+
# --- Upstate state and continue from interrupt
116+
updated_state = {"value": 1000, "step": 1000}
117+
updated_config = graph.update_state(config, updated_state)
118+
final_state = graph.invoke(input=None, config=updated_config)
119+
assert final_state == {"value": 2000, "step": 1001}

0 commit comments

Comments
 (0)