1
1
import re
2
2
import uuid
3
+ from typing import cast
3
4
4
- from langchain_core .messages import AIMessage , ToolCall , ToolMessage
5
+ from langchain_core .messages import AIMessage , BaseMessage , ToolCall , ToolMessage
5
6
from langchain_core .tools import BaseTool , InjectedToolCallId , tool
6
7
from langgraph .prebuilt import InjectedState
7
- from langgraph .types import Command
8
+ from langgraph .types import Command , Send
8
9
from typing_extensions import Annotated
9
10
10
11
WHITESPACE_RE = re .compile (r"\s+" )
@@ -15,6 +16,43 @@ def _normalize_agent_name(agent_name: str) -> str:
15
16
return WHITESPACE_RE .sub ("_" , agent_name .strip ()).lower ()
16
17
17
18
19
+ def _remove_non_handoff_tool_calls (
20
+ messages : list [BaseMessage ], handoff_tool_name : str
21
+ ) -> list [BaseMessage ]:
22
+ """Remove tool calls that are not meant for the agent."""
23
+ last_ai_message = cast (AIMessage , messages [- 1 ])
24
+ # if the supervisor is calling multiple agents/tools in parallel,
25
+ # we need to remove tool calls that are not meant for this agent
26
+ # to ensure that the resulting message history is valid
27
+ if len (last_ai_message .tool_calls ) > 1 and any (
28
+ tool_call ["name" ] == handoff_tool_name for tool_call in last_ai_message .tool_calls
29
+ ):
30
+ content = last_ai_message .content
31
+ if isinstance (content , list ) and len (content ) > 1 and isinstance (content [0 ], dict ):
32
+ content = [
33
+ content_block
34
+ for content_block in content
35
+ if (
36
+ content_block ["type" ] == "tool_use"
37
+ and content_block ["name" ] == handoff_tool_name
38
+ )
39
+ or content_block ["type" ] != "tool_use"
40
+ ]
41
+
42
+ last_ai_message = AIMessage (
43
+ content = content ,
44
+ tool_calls = [
45
+ tool_call
46
+ for tool_call in last_ai_message .tool_calls
47
+ if tool_call ["name" ] == handoff_tool_name
48
+ ],
49
+ name = last_ai_message .name ,
50
+ id = str (uuid .uuid4 ()),
51
+ )
52
+
53
+ return messages [:- 1 ] + [last_ai_message ]
54
+
55
+
18
56
def create_handoff_tool (* , agent_name : str ) -> BaseTool :
19
57
"""Create a tool that can handoff control to the requested agent.
20
58
@@ -39,10 +77,17 @@ def handoff_to_agent(
39
77
name = tool_name ,
40
78
tool_call_id = tool_call_id ,
41
79
)
80
+ handoff_messages = _remove_non_handoff_tool_calls (state ["messages" ], tool_name ) + [
81
+ tool_message
82
+ ]
42
83
return Command (
43
- goto = agent_name ,
44
84
graph = Command .PARENT ,
45
- update = {"messages" : state ["messages" ] + [tool_message ]},
85
+ # NOTE: we are using Send here to allow the ToolNode in langgraph.prebuilt
86
+ # to handle parallel handoffs by combining all Send commands into a single command
87
+ goto = [Send (agent_name , {"messages" : handoff_messages })],
88
+ # we also propagate the update to make sure the handoff messages are applied
89
+ # to the parent graph's state
90
+ update = {"messages" : handoff_messages },
46
91
)
47
92
48
93
return handoff_to_agent
0 commit comments