Skip to content

Commit 8db2816

Browse files
committed
feat: separate parent classes for topic, service and actions tasks
1 parent 19f918f commit 8db2816

File tree

3 files changed

+811
-422
lines changed

3 files changed

+811
-422
lines changed

src/rai_bench/rai_bench/tool_calling_agent_bench/agent_tasks_interfaces.py

+317-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import logging
1616
from abc import ABC, abstractmethod
17-
from typing import Any, List, Literal
17+
from typing import Any, Dict, List, Literal, Sequence
1818

1919
from langchain_core.messages import AIMessage, ToolCall
2020
from langchain_core.runnables.config import DEFAULT_RECURSION_LIMIT
@@ -274,3 +274,319 @@ def _is_ai_message_requesting_get_ros2_topics_and_types(
274274
):
275275
return False
276276
return True
277+
278+
def _is_ai_message_requesting_get_ros2_services_and_types(
    self, ai_message: AIMessage
) -> bool:
    """Check that the AIMessage makes exactly one tool call, requesting ROS2 service names and types.

    The message passes when `_check_tool_calls_num_in_ai_message` accepts it
    for one expected call and `_check_tool_call` accepts that call against the
    name ``get_ros2_services_names_and_types`` with empty arguments.

    Parameters
    ----------
    ai_message : AIMessage
        The AIMessage to check

    Returns
    -------
    bool
        True if both checks pass, False otherwise
    """
    single_call = self._check_tool_calls_num_in_ai_message(
        ai_message, expected_num=1
    )
    if single_call:
        # Only index into tool_calls once the count check has passed.
        requested: ToolCall = ai_message.tool_calls[0]
        if self._check_tool_call(
            tool_call=requested,
            expected_name="get_ros2_services_names_and_types",
            expected_args={},
        ):
            return True
    return False
304+
305+
def _is_ai_message_requesting_get_ros2_actions_and_types(
    self, ai_message: AIMessage
) -> bool:
    """Check that the AIMessage makes exactly one tool call, requesting ROS2 action names and types.

    The message passes when `_check_tool_calls_num_in_ai_message` accepts it
    for one expected call and `_check_tool_call` accepts that call against the
    name ``get_ros2_actions_names_and_types`` with empty arguments.

    Parameters
    ----------
    ai_message : AIMessage
        The AIMessage to check

    Returns
    -------
    bool
        True if both checks pass, False otherwise
    """
    single_call = self._check_tool_calls_num_in_ai_message(
        ai_message, expected_num=1
    )
    if single_call:
        # Only index into tool_calls once the count check has passed.
        requested: ToolCall = ai_message.tool_calls[0]
        if self._check_tool_call(
            tool_call=requested,
            expected_name="get_ros2_actions_names_and_types",
            expected_args={},
        ):
            return True
    return False
331+
332+
333+
class CustomInterfacesTopicTask(ROS2ToolCallingAgentTask, ABC):
    """Parent class for tasks where the agent must publish a message on a topic.

    Subclasses provide `expected_topic` and `expected_message`; the expected
    message type is looked up in `TOPICS_AND_TYPES`.
    """

    # Maps each known topic name to its message type.
    TOPICS_AND_TYPES: Dict[str, str] = {
        # sample topics
        "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject",
        "/camera_image_color": "sensor_msgs/msg/Image",
        "/camera_image_depth": "sensor_msgs/msg/Image",
        "/clock": "rosgraph_msgs/msg/Clock",
        "/collision_object": "moveit_msgs/msg/CollisionObject",
        "/color_camera_info": "sensor_msgs/msg/CameraInfo",
        "/color_camera_info5": "sensor_msgs/msg/CameraInfo",
        "/depth_camera_info5": "sensor_msgs/msg/CameraInfo",
        "/depth_image5": "sensor_msgs/msg/Image",
        # custom topics
        "/to_human": "rai_interfaces/msg/HRIMessage",
        "/send_audio": "rai_interfaces/msg/AudioMessage",
        "/send_detections": "rai_interfaces/msg/RAIDetectionArray",
    }
    # One "topic/type" text entry per item of TOPICS_AND_TYPES.
    topic_strings = [
        f"topic: {name}\ntype: {interface}\n"
        for name, interface in TOPICS_AND_TYPES.items()
    ]

    def __init__(self, logger: loggers_type | None = None) -> None:
        super().__init__(logger=logger)

    @property
    @abstractmethod
    def expected_topic(self) -> str:
        """Name of the topic the agent is expected to publish to."""
        pass

    @property
    @abstractmethod
    def expected_message(self) -> Dict[str, Any]:
        """Content of the message the agent is expected to publish."""
        pass

    @property
    def expected_message_type(self) -> str:
        """Message type registered in `TOPICS_AND_TYPES` for `expected_topic`."""
        return self.TOPICS_AND_TYPES[self.expected_topic]

    def verify_tool_calls(self, response: dict[str, Any]):
        """Verify the sequence of tool calls made by the agent.

        It is expected that the agent will request:
        1. The tool that retrieves the topics names and types, to recognize the message type of the expected topic
        2. The tool that retrieves interfaces, to check the expected message type
        3. The tool to publish message with proper topic, message type and content

        Exactly 4 AI messages are expected; only the first three are inspected
        here. `self.result.success` is set to True when `self.result.errors`
        is empty after all checks.

        Parameters
        ----------
        response : dict[str, Any]
            The response from the agent
        """
        ai_messages: Sequence[AIMessage] = [
            entry for entry in response["messages"] if isinstance(entry, AIMessage)
        ]
        self.logger.debug(ai_messages)

        total = len(ai_messages)
        if total != 4:
            self.log_error(msg=f"Expected exactly 4 AI messages, but got {total}.")

        # Step 1: list topics and their types.
        if total > 0 and not self._is_ai_message_requesting_get_ros2_topics_and_types(
            ai_messages[0]
        ):
            self.log_error(
                msg="First AI message did not request ROS2 topics and types correctly."
            )

        # Step 2: fetch the interface of the expected message type.
        if total > 1 and self._check_tool_calls_num_in_ai_message(
            ai_messages[1], expected_num=1
        ):
            self._check_tool_call(
                tool_call=ai_messages[1].tool_calls[0],
                expected_name="get_ros2_message_interface",
                expected_args={"msg_type": self.expected_message_type},
            )

        # Step 3: publish the expected message.
        if total > 2 and self._check_tool_calls_num_in_ai_message(
            ai_messages[2], expected_num=1
        ):
            self._check_tool_call(
                tool_call=ai_messages[2].tool_calls[0],
                expected_name="publish_ros2_message",
                expected_args={
                    "topic": self.expected_topic,
                    "message": self.expected_message,
                    "message_type": self.expected_message_type,
                },
            )

        if not self.result.errors:
            self.result.success = True
425+
426+
427+
class CustomInterfacesServiceTask(ROS2ToolCallingAgentTask, ABC):
    """Parent class for tasks where the agent must call a service.

    Subclasses provide `expected_service` and `expected_message`; the expected
    service type is looked up in `SERVICES_AND_TYPES`.
    """

    # Maps each known service name to its service type.
    SERVICES_AND_TYPES: Dict[str, str] = {
        # sample interfaces
        "/load_map": "moveit_msgs/srv/LoadMap",
        "/query_planner_interface": "moveit_msgs/srv/QueryPlannerInterfaces",
        # custom interfaces
        "/manipulator_move_to": "rai_interfaces/srv/ManipulatorMoveTo",
        "/grounded_sam_segment": "rai_interfaces/srv/RAIGroundedSam",
        "/grounding_dino_classify": "rai_interfaces/srv/RAIGroundingDino",
        "/get_log_digest": "rai_interfaces/srv/StringList",
        "/rai_whoami_documentation_service": "rai_interfaces/srv/VectorStoreRetrieval",
        # NOTE(review): unlike every other entry this name has no leading "/";
        # left unchanged because subclasses may look it up by this exact key - confirm.
        "rai/whatisee/get": "rai_interfaces/srv/WhatISee",
    }
    # One "service/type" text entry per item of SERVICES_AND_TYPES.
    service_strings = [
        f"service: {service}\ntype: {msg_type}\n"
        for service, msg_type in SERVICES_AND_TYPES.items()
    ]

    def __init__(self, logger: loggers_type | None = None) -> None:
        super().__init__(logger=logger)

    @property
    @abstractmethod
    def expected_service(self) -> str:
        """Name of the service the agent is expected to call."""
        pass

    @property
    @abstractmethod
    def expected_message(self) -> Dict[str, Any]:
        """Content of the request the agent is expected to send."""
        pass

    @property
    def expected_service_type(self) -> str:
        """Service type registered in `SERVICES_AND_TYPES` for `expected_service`."""
        return self.SERVICES_AND_TYPES[self.expected_service]

    def verify_tool_calls(self, response: dict[str, Any]):
        """Verify the sequence of tool calls made by the agent.

        It is expected that the agent will request:
        1. The tool that retrieves the services names and types, to recognize the type of the expected service
        2. The tool that retrieves interfaces, to check the expected service type
        3. The tool to call the service with proper name, type and content

        Exactly 4 AI messages are expected; only the first three are inspected
        here. `self.result.success` is set to True when `self.result.errors`
        is empty after all checks.

        Parameters
        ----------
        response : dict[str, Any]
            The response from the agent
        """
        messages = response["messages"]
        ai_messages: Sequence[AIMessage] = [
            message for message in messages if isinstance(message, AIMessage)
        ]
        self.logger.debug(ai_messages)
        if len(ai_messages) != 4:
            self.log_error(
                msg=f"Expected exactly 4 AI messages, but got {len(ai_messages)}."
            )
        if ai_messages:
            if not self._is_ai_message_requesting_get_ros2_services_and_types(
                ai_messages[0]
            ):
                # This check is about services, so the message says so
                # (it previously said "topics", copied from the topic task).
                self.log_error(
                    msg="First AI message did not request ROS2 services and types correctly."
                )
        if len(ai_messages) > 1:
            if self._check_tool_calls_num_in_ai_message(ai_messages[1], expected_num=1):
                self._check_tool_call(
                    tool_call=ai_messages[1].tool_calls[0],
                    expected_name="get_ros2_message_interface",
                    expected_args={"msg_type": self.expected_service_type},
                )

        if len(ai_messages) > 2:
            if self._check_tool_calls_num_in_ai_message(ai_messages[2], expected_num=1):
                # NOTE(review): these argument keys are the same as the ones used
                # for publish_ros2_message in the topic task; confirm they match
                # the argument schema of the call_ros2_service tool.
                self._check_tool_call(
                    tool_call=ai_messages[2].tool_calls[0],
                    expected_name="call_ros2_service",
                    expected_args={
                        "topic": self.expected_service,
                        "message": self.expected_message,
                        "message_type": self.expected_service_type,
                    },
                )
        if not self.result.errors:
            self.result.success = True
510+
511+
512+
class CustomInterfacesActionTask(ROS2ToolCallingAgentTask, ABC):
    """Parent class for tasks where the agent must start an action.

    Subclasses provide `expected_action` and `expected_message`; the expected
    action type is looked up in `ACTIONS_AND_TYPES`.
    """

    # Maps each known action name to its action type.
    ACTIONS_AND_TYPES: Dict[str, str] = {
        # custom actions
        "/perform_task": "rai_interfaces/action/Task",
        # some sample actions
        # "/execute_trajectory": "moveit_msgs/action/ExecuteTrajectory",
        # "/move_action": "moveit_msgs/action/MoveGroup",
        # "/follow_joint_trajectory": "control_msgs/action/FollowJointTrajectory",
        # "/gripper_cmd": "control_msgs/action/GripperCommand",
    }

    # One "action/type" text entry per item of ACTIONS_AND_TYPES.
    action_strings = [
        f"action: {action}\ntype: {msg_type}\n"
        for action, msg_type in ACTIONS_AND_TYPES.items()
    ]

    def __init__(self, logger: loggers_type | None = None) -> None:
        super().__init__(logger=logger)

    @property
    @abstractmethod
    def expected_action(self) -> str:
        """Name of the action the agent is expected to start."""
        pass

    @property
    @abstractmethod
    def expected_message(self) -> Dict[str, Any]:
        """Content of the goal the agent is expected to send."""
        pass

    @property
    def expected_action_type(self) -> str:
        """Action type registered in `ACTIONS_AND_TYPES` for `expected_action`."""
        return self.ACTIONS_AND_TYPES[self.expected_action]

    def verify_tool_calls(self, response: dict[str, Any]):
        """Verify the sequence of tool calls made by the agent.

        It is expected that the agent will request:
        1. The tool that retrieves the actions names and types, to recognize the type of the expected action
        2. The tool that retrieves interfaces, to check the expected action type
        3. The tool to start the action with proper name, type and content

        Exactly 4 AI messages are expected; only the first three are inspected
        here. `self.result.success` is set to True when `self.result.errors`
        is empty after all checks.

        Parameters
        ----------
        response : dict[str, Any]
            The response from the agent
        """
        messages = response["messages"]
        ai_messages: Sequence[AIMessage] = [
            message for message in messages if isinstance(message, AIMessage)
        ]
        self.logger.debug(ai_messages)
        if len(ai_messages) != 4:
            self.log_error(
                msg=f"Expected exactly 4 AI messages, but got {len(ai_messages)}."
            )
        if ai_messages:
            if not self._is_ai_message_requesting_get_ros2_actions_and_types(
                ai_messages[0]
            ):
                # This check is about actions, so the message says so
                # (it previously said "topics", copied from the topic task).
                self.log_error(
                    msg="First AI message did not request ROS2 actions and types correctly."
                )
        if len(ai_messages) > 1:
            if self._check_tool_calls_num_in_ai_message(ai_messages[1], expected_num=1):
                self._check_tool_call(
                    tool_call=ai_messages[1].tool_calls[0],
                    expected_name="get_ros2_message_interface",
                    expected_args={"msg_type": self.expected_action_type},
                )

        if len(ai_messages) > 2:
            if self._check_tool_calls_num_in_ai_message(ai_messages[2], expected_num=1):
                # NOTE(review): these argument keys are the same as the ones used
                # for publish_ros2_message in the topic task; confirm they match
                # the argument schema of the start_ros2_action tool.
                self._check_tool_call(
                    tool_call=ai_messages[2].tool_calls[0],
                    expected_name="start_ros2_action",
                    expected_args={
                        "topic": self.expected_action,
                        "message": self.expected_message,
                        "message_type": self.expected_action_type,
                    },
                )
        if not self.result.errors:
            self.result.success = True

0 commit comments

Comments
 (0)