|
29 | 29 | from rai_bench.tool_calling_agent_bench.mocked_tools import (
|
30 | 30 | MockGetObjectPositionsTool,
|
31 | 31 | MockGetROS2ImageTool,
|
| 32 | + MockGetROS2MessageInterfaceTool, |
32 | 33 | MockGetROS2TopicsNamesAndTypesTool,
|
33 | 34 | MockMoveToPointTool,
|
| 35 | + MockPublishROS2MessageTool, |
34 | 36 | MockReceiveROS2MessageTool,
|
35 | 37 | )
|
36 | 38 |
|
@@ -1512,3 +1514,157 @@ def _matches_sequence(
|
1512 | 1514 | any(call["name"] == e["name"] and call["args"] == e["args"] for call in it)
|
1513 | 1515 | for e in expected_tool_calls_seq
|
1514 | 1516 | )
|
| 1517 | + |
| 1518 | + |
| 1519 | +class PublishROS2CustomMessageTask(ROS2ToolCallingAgentTask): |
| 1520 | + complexity = "easy" |
| 1521 | + |
| 1522 | + def __init__(self, logger: loggers_type | None = None) -> None: |
| 1523 | + super().__init__(logger=logger) |
| 1524 | + self.expected_tools: List[BaseTool] = [ |
| 1525 | + MockGetROS2TopicsNamesAndTypesTool( |
| 1526 | + mock_topics_names_and_types=[ |
| 1527 | + "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", |
| 1528 | + "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n", |
| 1529 | + "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n", |
| 1530 | + "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", |
| 1531 | + "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", |
| 1532 | + "topic: /color_camera_info\ntype: sensor_msgs/msg/CameraInfo\n", |
| 1533 | + "topic: /color_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n", |
| 1534 | + "topic: /depth_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n", |
| 1535 | + "topic: /depth_image5\ntype: sensor_msgs/msg/Image\n", |
| 1536 | + "topic: /to_human\ntype: rai_interfaces/msg/HRIMessage\n", |
| 1537 | + ] |
| 1538 | + ), |
| 1539 | + MockGetROS2MessageInterfaceTool( |
| 1540 | + mock_interfaces={ |
| 1541 | + "moveit_msgs/msg/AttachedCollisionObject": ( |
| 1542 | + "ros2 interface show moveit_msgs/msg/AttachedCollisionObject:\n" |
| 1543 | + "std_msgs/Header header\n" |
| 1544 | + "string link_name\n" |
| 1545 | + "moveit_msgs/msg/CollisionObject object\n" |
| 1546 | + "string[] touch_links\n" |
| 1547 | + ), |
| 1548 | + "sensor_msgs/msg/Image": ( |
| 1549 | + "ros2 interface show sensor_msgs/msg/Image:\n" |
| 1550 | + "std_msgs/Header header\n" |
| 1551 | + "uint32 height\n" |
| 1552 | + "uint32 width\n" |
| 1553 | + "string encoding\n" |
| 1554 | + "uint8 is_bigendian\n" |
| 1555 | + "uint32 step\n" |
| 1556 | + "uint8[] data\n" |
| 1557 | + ), |
| 1558 | + "rosgraph_msgs/msg/Clock": ( |
| 1559 | + "ros2 interface show rosgraph_msgs/msg/Clock:\n" |
| 1560 | + "builtin_interfaces/Time clock\n" |
| 1561 | + ), |
| 1562 | + "moveit_msgs/msg/CollisionObject": ( |
| 1563 | + "ros2 interface show moveit_msgs/msg/CollisionObject:\n" |
| 1564 | + "std_msgs/Header header\n" |
| 1565 | + "string id\n" |
| 1566 | + "shape_msgs/SolidPrimitive[] primitives\n" |
| 1567 | + "geometry_msgs/Pose[] primitive_poses\n" |
| 1568 | + "shape_msgs/Mesh[] meshes\n" |
| 1569 | + "geometry_msgs/Pose[] mesh_poses\n" |
| 1570 | + "shape_msgs/Plane[] planes\n" |
| 1571 | + "geometry_msgs/Pose[] plane_poses\n" |
| 1572 | + "uint8 operation\n" |
| 1573 | + ), |
| 1574 | + "sensor_msgs/msg/CameraInfo": ( |
| 1575 | + "ros2 interface show sensor_msgs/msg/CameraInfo:\n" |
| 1576 | + "std_msgs/Header header\n" |
| 1577 | + "uint32 height\n" |
| 1578 | + "uint32 width\n" |
| 1579 | + "string distortion_model\n" |
| 1580 | + "float64[] D\n" |
| 1581 | + "float64[9] K\n" |
| 1582 | + "float64[9] R\n" |
| 1583 | + "float64[12] P\n" |
| 1584 | + "uint32 binning_x\n" |
| 1585 | + "uint32 binning_y\n" |
| 1586 | + "sensor_msgs/RegionOfInterest roi\n" |
| 1587 | + ), |
| 1588 | + "rai_interfaces/msg/HRIMessage": ( |
| 1589 | + "ros2 interface show rai_interfaces/msg/HRIMessage:\n" |
| 1590 | + "std_msgs/Header header\n" |
| 1591 | + "string text\n" |
| 1592 | + "sensor_msgs/Image[] images\n" |
| 1593 | + "rai_interfaces/AudioMessage[] audios\n" |
| 1594 | + ), |
| 1595 | + "rai_interfaces/msg/AudioMessage": ( |
| 1596 | + "ros2 interface show rai_interfaces/msg/AudioMessage:\n" |
| 1597 | + "int16[] audio\n" |
| 1598 | + "uint16 sample_rate\n" |
| 1599 | + "uint16 channels\n" |
| 1600 | + ), |
| 1601 | + } |
| 1602 | + ), |
| 1603 | + MockPublishROS2MessageTool( |
| 1604 | + expected_topic="/to_human", |
| 1605 | + expected_message={ |
| 1606 | + "header": {"stamp": {"sec": 0, "nanosec": 0}, "frame_id": ""}, |
| 1607 | + "text": "Hello!", |
| 1608 | + "images": [], |
| 1609 | + "audios": [], |
| 1610 | + }, |
| 1611 | + expected_message_type="rai_interfaces/msg/HRIMessage", |
| 1612 | + ), |
| 1613 | + ] |
| 1614 | + |
| 1615 | + def get_system_prompt(self) -> str: |
| 1616 | + return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT |
| 1617 | + |
| 1618 | + def get_prompt(self) -> str: |
| 1619 | + return "Publish message 'Hello!' to the topic /to_human" |
| 1620 | + |
| 1621 | + def verify_tool_calls(self, response: dict[str, Any]): |
| 1622 | + """It is expected that the agent will request: |
| 1623 | + 1. The tool that retrieves the ROS2 topics names and types to recognize what type of message to_human topic has |
| 1624 | + 2. The tool that retrieves interfaces to check HRIMessage type |
| 1625 | + 3. The tool to publish message with proper topic, message type and content |
| 1626 | +
|
| 1627 | + Parameters |
| 1628 | + ---------- |
| 1629 | + response : dict[str, Any] |
| 1630 | + The response from the agent |
| 1631 | + """ |
| 1632 | + messages = response["messages"] |
| 1633 | + ai_messages: Sequence[AIMessage] = [ |
| 1634 | + message for message in messages if isinstance(message, AIMessage) |
| 1635 | + ] |
| 1636 | + |
| 1637 | + if len(ai_messages) == 4: |
| 1638 | + self.log_error( |
| 1639 | + msg=f"Expected exactly 4 AI messages, but got {len(ai_messages)}." |
| 1640 | + ) |
| 1641 | + if ai_messages: |
| 1642 | + if not self._is_ai_message_requesting_get_ros2_topics_and_types( |
| 1643 | + ai_messages[0] |
| 1644 | + ): |
| 1645 | + self.log_error( |
| 1646 | + msg="First AI message did not request ROS2 topics and types correctly." |
| 1647 | + ) |
| 1648 | + if self._check_tool_calls_num_in_ai_message(ai_messages[1], expected_num=1): |
| 1649 | + self._check_tool_call( |
| 1650 | + tool_call=ai_messages[1].tool_calls[0], |
| 1651 | + expected_name="get_ros2_message_interface", |
| 1652 | + expected_args={"msg_type": "rai_interfaces/msg/HRIMessage"}, |
| 1653 | + ) |
| 1654 | + if self._check_tool_calls_num_in_ai_message(ai_messages[2], expected_num=1): |
| 1655 | + self._check_tool_call( |
| 1656 | + tool_call=ai_messages[2].tool_calls[0], |
| 1657 | + expected_name="publish_ros2_message", |
| 1658 | + expected_args={ |
| 1659 | + "topic": "/to_human", |
| 1660 | + "message": { |
| 1661 | + "header": {"stamp": {"sec": 0, "nanosec": 0}, "frame_id": ""}, |
| 1662 | + "text": "Hello!", |
| 1663 | + "images": [], |
| 1664 | + "audios": [], |
| 1665 | + }, |
| 1666 | + "message_type": "rai_interfaces/msg/HRIMessage", |
| 1667 | + }, |
| 1668 | + ) |
| 1669 | + if not self.result.errors: |
| 1670 | + self.result.success = True |
0 commit comments