Skip to content

Commit ea172d6

Browse files
committed
feat: publish message and get interfaces tools mock, test for publish
1 parent 0b2b2e5 commit ea172d6

File tree

3 files changed

+215
-1
lines changed

3 files changed

+215
-1
lines changed

src/rai_bench/rai_bench/examples/tool_calling_agent_bench_tasks.py

+2
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,12 @@
3030
MoveExistingObjectFrontTask,
3131
MoveExistingObjectLeftTask,
3232
MoveToPointTask,
33+
PublishROS2CustomMessageTask,
3334
SwapObjectsTask,
3435
)
3536

3637
tasks: Sequence[ToolCallingAgentTask] = [
38+
PublishROS2CustomMessageTask(),
3739
GetROS2RGBCameraTask(),
3840
GetROS2TopicsTask(),
3941
GetROS2DepthCameraTask(),

src/rai_bench/rai_bench/tool_calling_agent_bench/mocked_tools.py

+57-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import List, Tuple
15+
from typing import Any, Dict, List, Tuple
1616
from unittest.mock import MagicMock
1717

1818
import numpy as np
@@ -27,7 +27,9 @@
2727
)
2828
from rai.tools.ros2 import (
2929
GetROS2ImageTool,
30+
GetROS2MessageInterfaceTool,
3031
GetROS2TopicsNamesAndTypesTool,
32+
PublishROS2MessageTool,
3133
ReceiveROS2MessageTool,
3234
)
3335

@@ -181,3 +183,57 @@ def _run(self, object_name: str) -> str:
181183
return f"No {object_name}s detected."
182184
else:
183185
return f"Centroids of detected {object_name}s in manipulator frame: {expected_positions} Sizes of the detected objects are unknown."
186+
187+
188+
class MockPublishROS2MessageTool(PublishROS2MessageTool):
189+
connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector)
190+
expected_topic: str
191+
expected_message: Dict[str, Any]
192+
expected_message_type: str
193+
194+
def _run(self, topic: str, message: Dict[str, Any], message_type: str) -> str:
195+
"""
196+
Mocked method that simulates publihing to a topic and return a status string.
197+
198+
Parameters
199+
----------
200+
topic : str
201+
The name of the topic to which the message is published.
202+
message : Dict[str, Any]
203+
The content of the message as a dictionary.
204+
message_type : str
205+
The type of the message being published.
206+
207+
"""
208+
if (
209+
self.expected_topic == topic
210+
and self.expected_message == message
211+
and self.expected_message_type == message_type
212+
):
213+
return "Message published successfully"
214+
else:
215+
return "Failed to publish message"
216+
217+
218+
class MockGetROS2MessageInterfaceTool(GetROS2MessageInterfaceTool):
219+
connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector)
220+
mock_interfaces: Dict[str, str]
221+
222+
def _run(self, msg_type: str) -> str:
223+
"""
224+
Mocked method that returns the interface definition for a given ROS2 message type.
225+
226+
Parameters
227+
----------
228+
msg_type : str
229+
The ROS2 message type for which to retrieve the interface definition.
230+
231+
Returns
232+
-------
233+
str
234+
The mocked output of 'ros2 interface show' for the specified message type.
235+
"""
236+
if msg_type in self.mock_interfaces:
237+
return self.mock_interfaces[msg_type]
238+
else:
239+
return f"Interface for {msg_type} not found."

src/rai_bench/rai_bench/tool_calling_agent_bench/ros2_agent_tasks.py

+156
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,10 @@
2929
from rai_bench.tool_calling_agent_bench.mocked_tools import (
3030
MockGetObjectPositionsTool,
3131
MockGetROS2ImageTool,
32+
MockGetROS2MessageInterfaceTool,
3233
MockGetROS2TopicsNamesAndTypesTool,
3334
MockMoveToPointTool,
35+
MockPublishROS2MessageTool,
3436
MockReceiveROS2MessageTool,
3537
)
3638

@@ -1512,3 +1514,157 @@ def _matches_sequence(
15121514
any(call["name"] == e["name"] and call["args"] == e["args"] for call in it)
15131515
for e in expected_tool_calls_seq
15141516
)
1517+
1518+
1519+
class PublishROS2CustomMessageTask(ROS2ToolCallingAgentTask):
1520+
complexity = "easy"
1521+
1522+
def __init__(self, logger: loggers_type | None = None) -> None:
1523+
super().__init__(logger=logger)
1524+
self.expected_tools: List[BaseTool] = [
1525+
MockGetROS2TopicsNamesAndTypesTool(
1526+
mock_topics_names_and_types=[
1527+
"topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n",
1528+
"topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n",
1529+
"topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n",
1530+
"topic: /clock\ntype: rosgraph_msgs/msg/Clock\n",
1531+
"topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n",
1532+
"topic: /color_camera_info\ntype: sensor_msgs/msg/CameraInfo\n",
1533+
"topic: /color_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n",
1534+
"topic: /depth_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n",
1535+
"topic: /depth_image5\ntype: sensor_msgs/msg/Image\n",
1536+
"topic: /to_human\ntype: rai_interfaces/msg/HRIMessage\n",
1537+
]
1538+
),
1539+
MockGetROS2MessageInterfaceTool(
1540+
mock_interfaces={
1541+
"moveit_msgs/msg/AttachedCollisionObject": (
1542+
"ros2 interface show moveit_msgs/msg/AttachedCollisionObject:\n"
1543+
"std_msgs/Header header\n"
1544+
"string link_name\n"
1545+
"moveit_msgs/msg/CollisionObject object\n"
1546+
"string[] touch_links\n"
1547+
),
1548+
"sensor_msgs/msg/Image": (
1549+
"ros2 interface show sensor_msgs/msg/Image:\n"
1550+
"std_msgs/Header header\n"
1551+
"uint32 height\n"
1552+
"uint32 width\n"
1553+
"string encoding\n"
1554+
"uint8 is_bigendian\n"
1555+
"uint32 step\n"
1556+
"uint8[] data\n"
1557+
),
1558+
"rosgraph_msgs/msg/Clock": (
1559+
"ros2 interface show rosgraph_msgs/msg/Clock:\n"
1560+
"builtin_interfaces/Time clock\n"
1561+
),
1562+
"moveit_msgs/msg/CollisionObject": (
1563+
"ros2 interface show moveit_msgs/msg/CollisionObject:\n"
1564+
"std_msgs/Header header\n"
1565+
"string id\n"
1566+
"shape_msgs/SolidPrimitive[] primitives\n"
1567+
"geometry_msgs/Pose[] primitive_poses\n"
1568+
"shape_msgs/Mesh[] meshes\n"
1569+
"geometry_msgs/Pose[] mesh_poses\n"
1570+
"shape_msgs/Plane[] planes\n"
1571+
"geometry_msgs/Pose[] plane_poses\n"
1572+
"uint8 operation\n"
1573+
),
1574+
"sensor_msgs/msg/CameraInfo": (
1575+
"ros2 interface show sensor_msgs/msg/CameraInfo:\n"
1576+
"std_msgs/Header header\n"
1577+
"uint32 height\n"
1578+
"uint32 width\n"
1579+
"string distortion_model\n"
1580+
"float64[] D\n"
1581+
"float64[9] K\n"
1582+
"float64[9] R\n"
1583+
"float64[12] P\n"
1584+
"uint32 binning_x\n"
1585+
"uint32 binning_y\n"
1586+
"sensor_msgs/RegionOfInterest roi\n"
1587+
),
1588+
"rai_interfaces/msg/HRIMessage": (
1589+
"ros2 interface show rai_interfaces/msg/HRIMessage:\n"
1590+
"std_msgs/Header header\n"
1591+
"string text\n"
1592+
"sensor_msgs/Image[] images\n"
1593+
"rai_interfaces/AudioMessage[] audios\n"
1594+
),
1595+
"rai_interfaces/msg/AudioMessage": (
1596+
"ros2 interface show rai_interfaces/msg/AudioMessage:\n"
1597+
"int16[] audio\n"
1598+
"uint16 sample_rate\n"
1599+
"uint16 channels\n"
1600+
),
1601+
}
1602+
),
1603+
MockPublishROS2MessageTool(
1604+
expected_topic="/to_human",
1605+
expected_message={
1606+
"header": {"stamp": {"sec": 0, "nanosec": 0}, "frame_id": ""},
1607+
"text": "Hello!",
1608+
"images": [],
1609+
"audios": [],
1610+
},
1611+
expected_message_type="rai_interfaces/msg/HRIMessage",
1612+
),
1613+
]
1614+
1615+
def get_system_prompt(self) -> str:
1616+
return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT
1617+
1618+
def get_prompt(self) -> str:
1619+
return "Publish message 'Hello!' to the topic /to_human"
1620+
1621+
def verify_tool_calls(self, response: dict[str, Any]):
1622+
"""It is expected that the agent will request:
1623+
1. The tool that retrieves the ROS2 topics names and types to recognize what type of message to_human topic has
1624+
2. The tool that retrieves interfaces to check HRIMessage type
1625+
3. The tool to publish message with proper topic, message type and content
1626+
1627+
Parameters
1628+
----------
1629+
response : dict[str, Any]
1630+
The response from the agent
1631+
"""
1632+
messages = response["messages"]
1633+
ai_messages: Sequence[AIMessage] = [
1634+
message for message in messages if isinstance(message, AIMessage)
1635+
]
1636+
1637+
if len(ai_messages) == 4:
1638+
self.log_error(
1639+
msg=f"Expected exactly 4 AI messages, but got {len(ai_messages)}."
1640+
)
1641+
if ai_messages:
1642+
if not self._is_ai_message_requesting_get_ros2_topics_and_types(
1643+
ai_messages[0]
1644+
):
1645+
self.log_error(
1646+
msg="First AI message did not request ROS2 topics and types correctly."
1647+
)
1648+
if self._check_tool_calls_num_in_ai_message(ai_messages[1], expected_num=1):
1649+
self._check_tool_call(
1650+
tool_call=ai_messages[1].tool_calls[0],
1651+
expected_name="get_ros2_message_interface",
1652+
expected_args={"msg_type": "rai_interfaces/msg/HRIMessage"},
1653+
)
1654+
if self._check_tool_calls_num_in_ai_message(ai_messages[2], expected_num=1):
1655+
self._check_tool_call(
1656+
tool_call=ai_messages[2].tool_calls[0],
1657+
expected_name="publish_ros2_message",
1658+
expected_args={
1659+
"topic": "/to_human",
1660+
"message": {
1661+
"header": {"stamp": {"sec": 0, "nanosec": 0}, "frame_id": ""},
1662+
"text": "Hello!",
1663+
"images": [],
1664+
"audios": [],
1665+
},
1666+
"message_type": "rai_interfaces/msg/HRIMessage",
1667+
},
1668+
)
1669+
if not self.result.errors:
1670+
self.result.success = True

0 commit comments

Comments
 (0)