Skip to content

Commit 78dc2ac

Browse files
author
Vadym Barda
authored
add support for parallel handoffs (#85)
* filter out tool calls not meant for individual agents * update `langgraph-prebuilt` to handle combining multiple `Send` + Command.PARENT commands
1 parent f96e88e commit 78dc2ac

File tree

4 files changed

+56
-35
lines changed

4 files changed

+56
-35
lines changed

langgraph_supervisor/handoff.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import re
22
import uuid
3+
from typing import cast
34

4-
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
5+
from langchain_core.messages import AIMessage, BaseMessage, ToolCall, ToolMessage
56
from langchain_core.tools import BaseTool, InjectedToolCallId, tool
67
from langgraph.prebuilt import InjectedState
7-
from langgraph.types import Command
8+
from langgraph.types import Command, Send
89
from typing_extensions import Annotated
910

1011
WHITESPACE_RE = re.compile(r"\s+")
@@ -15,6 +16,43 @@ def _normalize_agent_name(agent_name: str) -> str:
1516
return WHITESPACE_RE.sub("_", agent_name.strip()).lower()
1617

1718

19+
def _remove_non_handoff_tool_calls(
20+
messages: list[BaseMessage], handoff_tool_name: str
21+
) -> list[BaseMessage]:
22+
"""Remove tool calls that are not meant for the agent."""
23+
last_ai_message = cast(AIMessage, messages[-1])
24+
# if the supervisor is calling multiple agents/tools in parallel,
25+
# we need to remove tool calls that are not meant for this agent
26+
# to ensure that the resulting message history is valid
27+
if len(last_ai_message.tool_calls) > 1 and any(
28+
tool_call["name"] == handoff_tool_name for tool_call in last_ai_message.tool_calls
29+
):
30+
content = last_ai_message.content
31+
if isinstance(content, list) and len(content) > 1 and isinstance(content[0], dict):
32+
content = [
33+
content_block
34+
for content_block in content
35+
if (
36+
content_block["type"] == "tool_use"
37+
and content_block["name"] == handoff_tool_name
38+
)
39+
or content_block["type"] != "tool_use"
40+
]
41+
42+
last_ai_message = AIMessage(
43+
content=content,
44+
tool_calls=[
45+
tool_call
46+
for tool_call in last_ai_message.tool_calls
47+
if tool_call["name"] == handoff_tool_name
48+
],
49+
name=last_ai_message.name,
50+
id=str(uuid.uuid4()),
51+
)
52+
53+
return messages[:-1] + [last_ai_message]
54+
55+
1856
def create_handoff_tool(*, agent_name: str) -> BaseTool:
1957
"""Create a tool that can handoff control to the requested agent.
2058
@@ -39,10 +77,17 @@ def handoff_to_agent(
3977
name=tool_name,
4078
tool_call_id=tool_call_id,
4179
)
80+
handoff_messages = _remove_non_handoff_tool_calls(state["messages"], tool_name) + [
81+
tool_message
82+
]
4283
return Command(
43-
goto=agent_name,
4484
graph=Command.PARENT,
45-
update={"messages": state["messages"] + [tool_message]},
85+
# NOTE: we are using Send here to allow the ToolNode in langgraph.prebuilt
86+
# to handle parallel handoffs by combining all Send commands into a single command
87+
goto=[Send(agent_name, {"messages": handoff_messages})],
88+
# we also propagate the update to make sure the handoff messages are applied
89+
# to the parent graph's state
90+
update={"messages": handoff_messages},
4691
)
4792

4893
return handoff_to_agent

langgraph_supervisor/supervisor.py

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
import inspect
21
from typing import Any, Callable, Literal, Optional, Type, Union
32

4-
from langchain_core.language_models import BaseChatModel, LanguageModelLike
3+
from langchain_core.language_models import LanguageModelLike
54
from langchain_core.tools import BaseTool
65
from langgraph.graph import END, START, StateGraph
76
from langgraph.prebuilt.chat_agent_executor import (
@@ -28,25 +27,6 @@
2827
"""
2928

3029

31-
MODELS_NO_PARALLEL_TOOL_CALLS = {"o3-mini"}
32-
33-
34-
def _supports_disable_parallel_tool_calls(model: LanguageModelLike) -> bool:
35-
if not isinstance(model, BaseChatModel):
36-
return False
37-
38-
if hasattr(model, "model_name") and model.model_name in MODELS_NO_PARALLEL_TOOL_CALLS:
39-
return False
40-
41-
if not hasattr(model, "bind_tools"):
42-
return False
43-
44-
if "parallel_tool_calls" not in inspect.signature(model.bind_tools).parameters:
45-
return False
46-
47-
return True
48-
49-
5030
def _make_call_agent(
5131
agent: Pregel,
5232
output_mode: OutputMode,
@@ -168,11 +148,7 @@ def create_supervisor(
168148

169149
handoff_tools = [create_handoff_tool(agent_name=agent.name) for agent in agents]
170150
all_tools = (tools or []) + handoff_tools
171-
172-
if _supports_disable_parallel_tool_calls(model):
173-
model = model.bind_tools(all_tools, parallel_tool_calls=False)
174-
else:
175-
model = model.bind_tools(all_tools)
151+
model = model.bind_tools(all_tools)
176152

177153
if include_agent_name:
178154
model = with_agent_name(model, include_agent_name)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ readme = "README.md"
1313
requires-python = ">=3.10"
1414
dependencies = [
1515
"langgraph>=0.3.5,<0.4.0",
16-
"langgraph-prebuilt>=0.1.2,<0.2.0"
16+
"langgraph-prebuilt>=0.1.6,<0.2.0"
1717
]
1818

1919
[dependency-groups]

uv.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)