Skip to content

Commit 5009942

Browse files
authored
genai[patch]: support image parts in tool responses (#921)
1 parent 0d40a4f commit 5009942

File tree

3 files changed

+50
-10
lines changed

3 files changed

+50
-10
lines changed

libs/genai/langchain_google_genai/chat_models.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,23 @@ def _is_lc_content_block(part: dict) -> bool:
247247
return "type" in part
248248

249249

250+
def _is_openai_image_block(block: dict) -> bool:
251+
"""Check if the block contains image data in OpenAI Chat Completions format."""
252+
if block.get("type") == "image_url":
253+
if (
254+
(set(block.keys()) <= {"type", "image_url", "detail"})
255+
and (image_url := block.get("image_url"))
256+
and isinstance(image_url, dict)
257+
):
258+
url = image_url.get("url")
259+
if isinstance(url, str):
260+
return True
261+
else:
262+
return False
263+
264+
return False
265+
266+
250267
def _convert_to_parts(
251268
raw_content: Union[str, Sequence[Union[str, dict]]],
252269
) -> List[Part]:
@@ -334,14 +351,28 @@ def _convert_to_parts(
334351
return parts
335352

336353

337-
def _convert_tool_message_to_part(
354+
def _convert_tool_message_to_parts(
338355
message: ToolMessage | FunctionMessage, name: Optional[str] = None
339-
) -> Part:
356+
) -> list[Part]:
340357
"""Converts a tool or function message to a google part."""
341358
# Legacy agent stores tool name in message.additional_kwargs instead of message.name
342359
name = message.name or name or message.additional_kwargs.get("name")
343360
response: Any
344-
if not isinstance(message.content, str):
361+
parts: list[Part] = []
362+
if isinstance(message.content, list):
363+
media_blocks = []
364+
other_blocks = []
365+
for block in message.content:
366+
if isinstance(block, dict) and (
367+
is_data_content_block(block) or _is_openai_image_block(block)
368+
):
369+
media_blocks.append(block)
370+
else:
371+
other_blocks.append(block)
372+
parts.extend(_convert_to_parts(media_blocks))
373+
response = other_blocks
374+
375+
elif not isinstance(message.content, str):
345376
response = message.content
346377
else:
347378
try:
@@ -356,7 +387,8 @@ def _convert_tool_message_to_part(
356387
),
357388
)
358389
)
359-
return part
390+
parts.append(part)
391+
return parts
360392

361393

362394
def _get_ai_message_tool_messages_parts(
@@ -374,8 +406,10 @@ def _get_ai_message_tool_messages_parts(
374406
break
375407
if message.tool_call_id in tool_calls_ids:
376408
tool_call = tool_calls_ids[message.tool_call_id]
377-
part = _convert_tool_message_to_part(message, name=tool_call.get("name"))
378-
parts.append(part)
409+
message_parts = _convert_tool_message_to_parts(
410+
message, name=tool_call.get("name")
411+
)
412+
parts.extend(message_parts)
379413
# remove the id from the dict, so that we do not iterate over it again
380414
tool_calls_ids.pop(message.tool_call_id)
381415
return parts
@@ -442,7 +476,7 @@ def _parse_chat_history(
442476
system_instruction = None
443477
elif isinstance(message, FunctionMessage):
444478
role = "user"
445-
parts = [_convert_tool_message_to_part(message)]
479+
parts = _convert_tool_message_to_parts(message)
446480
else:
447481
raise ValueError(
448482
f"Unexpected message with type {type(message)} at the position {i}."

libs/genai/tests/integration_tests/test_standard.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ def supports_image_inputs(self) -> bool:
3333
def supports_image_urls(self) -> bool:
3434
return True
3535

36+
@property
37+
def supports_image_tool_message(self) -> bool:
38+
return True
39+
3640
@property
3741
def supports_pdf_inputs(self) -> bool:
3842
return True

libs/genai/tests/unit_tests/test_chat_models.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
from langchain_google_genai.chat_models import (
3232
ChatGoogleGenerativeAI,
33-
_convert_tool_message_to_part,
33+
_convert_tool_message_to_parts,
3434
_parse_chat_history,
3535
_parse_response_candidate,
3636
)
@@ -710,10 +710,12 @@ def test_serialize() -> None:
710710
),
711711
],
712712
)
713-
def test__convert_tool_message_to_part__sets_tool_name(
713+
def test__convert_tool_message_to_parts__sets_tool_name(
714714
tool_message: ToolMessage,
715715
) -> None:
716-
part = _convert_tool_message_to_part(tool_message)
716+
parts = _convert_tool_message_to_parts(tool_message)
717+
assert len(parts) == 1
718+
part = parts[0]
717719
assert part.function_response.name == "tool_name"
718720
assert part.function_response.response == {"output": "test_content"}
719721

0 commit comments

Comments
 (0)