diff --git a/src/rai_bench/rai_bench/examples/tool_calling_agent_bench_navigation_tasks.py b/src/rai_bench/rai_bench/examples/tool_calling_agent_bench_navigation_tasks.py new file mode 100644 index 000000000..e1baf4ad1 --- /dev/null +++ b/src/rai_bench/rai_bench/examples/tool_calling_agent_bench_navigation_tasks.py @@ -0,0 +1,28 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Sequence + +from rai_bench.tool_calling_agent_bench.agent_tasks_interfaces import ( + ToolCallingAgentTask, +) + +# from rai_bench.tool_calling_agent_bench.ros2_agent_tasks import ( +# NavigateToPointTask, +# ) + +# tasks: Sequence[ToolCallingAgentTask] = [ +# NavigateToPointTask(), +# # SpinAroundTask() +# ] diff --git a/src/rai_bench/rai_bench/examples/tool_calling_agent_bench_tasks.py b/src/rai_bench/rai_bench/examples/tool_calling_agent_bench_tasks.py index 07dd05dae..201358e53 100644 --- a/src/rai_bench/rai_bench/examples/tool_calling_agent_bench_tasks.py +++ b/src/rai_bench/rai_bench/examples/tool_calling_agent_bench_tasks.py @@ -18,85 +18,126 @@ ToolCallingAgentTask, ) from rai_bench.tool_calling_agent_bench.ros2_agent_tasks import ( - GetAllROS2RGBCamerasTask, - GetObjectPositionsTask, - GetROS2DepthCameraTask, - GetROS2MessageTask, - GetROS2RGBCameraTask, - GetROS2TopicsTask, - GetROS2TopicsTask2, - GrabExistingObjectTask, - GrabNotExistingObjectTask, - MoveExistingObjectFrontTask, - MoveExistingObjectLeftTask, - MoveToPointTask, - 
SwapObjectsTask, + PublishROS2HRIMessageTask3ExtraCalls, + PublishROS2HRIMessageTask0ExtraCalls, + PublishROS2HRIMessageTask1ExtraCall, + PublishROS2AudioMessageTask0ExtraCalls, + PublishROS2AudioMessageTask3ExtraCalls, + PublishROS2AudioMessageTask1ExtraCall, + PublishROS2DetectionArrayTask3ExtraCalls, + PublishROS2DetectionArrayTask1ExtraCall, + PublishROS2DetectionArrayTask0ExtraCalls, + CallROS2ManipulatorMoveToServiceTask3ExtraCalls, + CallROS2ManipulatorMoveToServiceTask1ExtraCall, + CallROS2ManipulatorMoveToServiceTask0ExtraCalls, + CallGroundedSAMSegmentTask3ExtraCalls, + CallGroundedSAMSegmentTask1ExtraCall, + CallGroundedSAMSegmentTask0ExtraCalls, + CallGroundingDinoClassifyTask3ExtraCalls, + CallGroundingDinoClassifyTask1ExtraCall, + CallGroundingDinoClassifyTask0ExtraCalls, + CallGetLogDigestTask3ExtraCalls, + CallGetLogDigestTask1ExtraCall, + CallGetLogDigestTask0ExtraCalls, + CallVectorStoreRetrievalTask3ExtraCalls, + CallVectorStoreRetrievalTask1ExtraCall, + CallVectorStoreRetrievalTask0ExtraCalls, + CallWhatISeeTask3ExtraCalls, + CallWhatISeeTask1ExtraCall, + CallWhatISeeTask0ExtraCalls, ) tasks: Sequence[ToolCallingAgentTask] = [ - GetROS2RGBCameraTask(), - GetROS2TopicsTask(), - GetROS2DepthCameraTask(), - GetAllROS2RGBCamerasTask(), - GetROS2TopicsTask2(), - GetROS2MessageTask(), - MoveToPointTask(args={"x": 1.0, "y": 2.0, "z": 3.0, "task": "grab"}), - MoveToPointTask(args={"x": 1.2, "y": 2.3, "z": 3.4, "task": "drop"}), - GetObjectPositionsTask( - objects={ - "carrot": [{"x": 1.0, "y": 2.0, "z": 3.0}], - "apple": [{"x": 4.0, "y": 5.0, "z": 6.0}], - "banana": [ - {"x": 7.0, "y": 8.0, "z": 9.0}, - {"x": 10.0, "y": 11.0, "z": 12.0}, - ], - }, - ), - GrabExistingObjectTask( - object_to_grab="banana", - objects={ - "banana": [{"x": 7.0, "y": 8.0, "z": 9.0}], - "apple": [ - {"x": 4.0, "y": 5.0, "z": 6.0}, - {"x": 10.0, "y": 11.0, "z": 12.0}, - ], - }, - ), - GrabNotExistingObjectTask( - object_to_grab="apple", - objects={ - "banana": [{"x": 7.0, "y": 
8.0, "z": 9.0}], - "cube": [ - {"x": 4.0, "y": 5.0, "z": 6.0}, - {"x": 10.0, "y": 11.0, "z": 12.0}, - ], - }, - ), - MoveExistingObjectLeftTask( - object_to_grab="banana", - objects={ - "banana": [{"x": 7.0, "y": 8.0, "z": 9.0}], - "apple": [ - {"x": 4.0, "y": 5.0, "z": 6.0}, - {"x": 10.0, "y": 11.0, "z": 12.0}, - ], - }, - ), - MoveExistingObjectFrontTask( - object_to_grab="banana", - objects={ - "banana": [{"x": 7.0, "y": 8.0, "z": 9.0}], - "apple": [ - {"x": 4.0, "y": 5.0, "z": 6.0}, - {"x": 10.0, "y": 11.0, "z": 12.0}, - ], - }, - ), - SwapObjectsTask( - objects={ - "banana": [{"x": 1.0, "y": 2.0, "z": 3.0}], - "apple": [{"x": 4.0, "y": 5.0, "z": 6.0}], - }, - objects_to_swap=["banana", "apple"], - ), + PublishROS2HRIMessageTask3ExtraCalls(), + PublishROS2HRIMessageTask1ExtraCall(), + PublishROS2HRIMessageTask0ExtraCalls(), + PublishROS2AudioMessageTask3ExtraCalls(), + PublishROS2AudioMessageTask1ExtraCall(), + PublishROS2AudioMessageTask0ExtraCalls(), + PublishROS2DetectionArrayTask3ExtraCalls(), + PublishROS2DetectionArrayTask1ExtraCall(), + PublishROS2DetectionArrayTask0ExtraCalls(), + CallROS2ManipulatorMoveToServiceTask3ExtraCalls(), + CallROS2ManipulatorMoveToServiceTask1ExtraCall(), + CallROS2ManipulatorMoveToServiceTask0ExtraCalls(), + CallGroundedSAMSegmentTask3ExtraCalls(), + CallGroundedSAMSegmentTask1ExtraCall(), + CallGroundedSAMSegmentTask0ExtraCalls(), + CallGroundingDinoClassifyTask3ExtraCalls(), + CallGroundingDinoClassifyTask1ExtraCall(), + CallGroundingDinoClassifyTask0ExtraCalls(), + CallGetLogDigestTask3ExtraCalls(), + CallGetLogDigestTask1ExtraCall(), + CallGetLogDigestTask0ExtraCalls(), + CallVectorStoreRetrievalTask3ExtraCalls(), + CallVectorStoreRetrievalTask1ExtraCall(), + CallVectorStoreRetrievalTask0ExtraCalls(), + CallWhatISeeTask3ExtraCalls(), + CallWhatISeeTask1ExtraCall(), + CallWhatISeeTask0ExtraCalls(), + # GetROS2RGBCameraTask(), + # GetROS2TopicsTask(), + # GetROS2DepthCameraTask(), + # GetAllROS2RGBCamerasTask(), + # 
GetROS2TopicsTask2(), + # GetROS2MessageTask(), + # MoveToPointTask(args={"x": 1.0, "y": 2.0, "z": 3.0, "task": "grab"}), + # MoveToPointTask(args={"x": 1.2, "y": 2.3, "z": 3.4, "task": "drop"}), + # GetObjectPositionsTask( + # objects={ + # "carrot": [{"x": 1.0, "y": 2.0, "z": 3.0}], + # "apple": [{"x": 4.0, "y": 5.0, "z": 6.0}], + # "banana": [ + # {"x": 7.0, "y": 8.0, "z": 9.0}, + # {"x": 10.0, "y": 11.0, "z": 12.0}, + # ], + # }, + # ), + # GrabExistingObjectTask( + # object_to_grab="banana", + # objects={ + # "banana": [{"x": 7.0, "y": 8.0, "z": 9.0}], + # "apple": [ + # {"x": 4.0, "y": 5.0, "z": 6.0}, + # {"x": 10.0, "y": 11.0, "z": 12.0}, + # ], + # }, + # ), + # GrabNotExistingObjectTask( + # object_to_grab="apple", + # objects={ + # "banana": [{"x": 7.0, "y": 8.0, "z": 9.0}], + # "cube": [ + # {"x": 4.0, "y": 5.0, "z": 6.0}, + # {"x": 10.0, "y": 11.0, "z": 12.0}, + # ], + # }, + # ), + # MoveExistingObjectLeftTask( + # object_to_grab="banana", + # objects={ + # "banana": [{"x": 7.0, "y": 8.0, "z": 9.0}], + # "apple": [ + # {"x": 4.0, "y": 5.0, "z": 6.0}, + # {"x": 10.0, "y": 11.0, "z": 12.0}, + # ], + # }, + # ), + # MoveExistingObjectFrontTask( + # object_to_grab="banana", + # objects={ + # "banana": [{"x": 7.0, "y": 8.0, "z": 9.0}], + # "apple": [ + # {"x": 4.0, "y": 5.0, "z": 6.0}, + # {"x": 10.0, "y": 11.0, "z": 12.0}, + # ], + # }, + # ), + # SwapObjectsTask( + # objects={ + # "banana": [{"x": 1.0, "y": 2.0, "z": 3.0}], + # "apple": [{"x": 4.0, "y": 5.0, "z": 6.0}], + # }, + # objects_to_swap=["banana", "apple"], + # ), ] diff --git a/src/rai_bench/rai_bench/examples/tool_calling_agent_test_bench.py b/src/rai_bench/rai_bench/examples/tool_calling_agent_test_bench.py index f34587490..bce90d9ae 100644 --- a/src/rai_bench/rai_bench/examples/tool_calling_agent_test_bench.py +++ b/src/rai_bench/rai_bench/examples/tool_calling_agent_test_bench.py @@ -23,6 +23,8 @@ ) from rai_bench.examples.tool_calling_agent_bench_tasks import tasks + +# from 
rai_bench.examples.tool_calling_agent_bench_navigation_tasks import tasks from rai_bench.tool_calling_agent_bench.agent_bench import ToolCallingAgentBenchmark if __name__ == "__main__": @@ -60,7 +62,7 @@ tasks=tasks, logger=bench_logger, results_filename=results_filename ) - model_type = "simple_model" + model_type = "complex_model" model_config = get_llm_model_config_and_vendor(model_type=model_type)[0] model_name = getattr(model_config, model_type) diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/__init__.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/__init__.py new file mode 100644 index 000000000..0f0767afa --- /dev/null +++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/__init__.py @@ -0,0 +1,23 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .action_base_model import ActionBaseModel +from .navigate_to_pose import NavigateToPoseAction +from .spin import SpinAction + +__all__ = [ + "ActionBaseModel", + "NavigateToPoseAction", + "SpinAction", +] diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/action_base_model.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/action_base_model.py new file mode 100644 index 000000000..fef8444db --- /dev/null +++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/action_base_model.py @@ -0,0 +1,25 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +from pydantic import BaseModel + + +class ActionBaseModel(BaseModel): + action_name: str + action_type: str + goal: Any + result: Any + feedback: Any diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/navigate_to_pose.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/navigate_to_pose.py new file mode 100644 index 000000000..20cfe8ae7 --- /dev/null +++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/navigate_to_pose.py @@ -0,0 +1,80 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from pydantic import BaseModel + +from rai_bench.tool_calling_agent_bench.actions.action_base_model import ActionBaseModel + + +class Time(BaseModel): + sec: Optional[int] = 0 + nanosec: Optional[int] = 0 + + +class Header(BaseModel): + stamp: Optional[Time] = Time() + frame_id: str + + +class Position(BaseModel): + x: float + y: float + z: float + + +class Orientation(BaseModel): + x: Optional[float] = 0.0 + y: Optional[float] = 0.0 + z: Optional[float] = 0.0 + w: Optional[float] = 1.0 + + +class Pose(BaseModel): + position: Position + orientation: Optional[Orientation] = Orientation() + + +class PoseStamped(BaseModel): + header: Header + pose: Pose + + +class Goal(BaseModel): + pose: PoseStamped + behavior_tree: Optional[str] = "" + + +class Result(BaseModel): + result: dict + + +class Feedback(BaseModel): + current_pose: PoseStamped + navigation_time: Time + estimated_time_remaining: Time + number_of_recoveries: int + distance_remaining: float + + +class NavigateToPoseAction(ActionBaseModel): + action_name: str = "/navigate_to_pose" + action_type: str = "nav2_msgs/action/NavigateToPose" + goal: Goal + result: Result + feedback: Feedback + + +# TODO (mkotynia): create init for actions diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/spin.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/spin.py new file mode 100644 index 000000000..030a4601e --- /dev/null +++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/spin.py @@ -0,0 +1,46 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed 
under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from pydantic import BaseModel + +from rai_bench.tool_calling_agent_bench.actions.action_base_model import ActionBaseModel + + +class Time(BaseModel): + sec: Optional[int] = 0 + nanosec: Optional[int] = 0 + + +class Goal(BaseModel): + target_yaw: Optional[float] = 0.0 + time_allowance: Optional[Time] = Time() + + +class Result(BaseModel): + result: dict + + +class Feedback(BaseModel): + angle_turned: Optional[float] = 0.0 + remaining_yaw: Optional[float] = 0.0 + + +class SpinAction(ActionBaseModel): + action_name: str = "/spin" + action_type: str = "nav2_msgs/action/Spin" + goal: Goal + result: Result + feedback: Feedback diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/agent_bench.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/agent_bench.py index 9c59e1bf1..b8b22d581 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent_bench/agent_bench.py +++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/agent_bench.py @@ -36,8 +36,9 @@ class TaskResult(BaseModel): - task_prompt: str = Field(..., description="The task prompt.") - system_prompt: str = Field(..., description="The system prompt.") + # task_prompt: str = Field(..., description="The task prompt.") + # system_prompt: str = Field(..., description="The system prompt.") + task: str = Field(..., description="task name") complexity: str = Field(..., description="Complexity of the task.") model_name: str = Field(..., description="Name 
of the LLM.") success: bool = Field( @@ -184,8 +185,9 @@ def run_next(self, agent: CompiledStateGraph, model_name: str) -> None: ) task_result = TaskResult( - task_prompt=task.get_prompt(), - system_prompt=task.get_system_prompt(), + # task_prompt=task.get_prompt(), + # system_prompt=task.get_system_prompt(), + task=task.__class__.__name__, complexity=task.complexity, model_name=model_name, success=result.success, diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/agent_tasks_interfaces.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/agent_tasks_interfaces.py index 611711d6d..7c05cc20f 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent_bench/agent_tasks_interfaces.py +++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/agent_tasks_interfaces.py @@ -14,15 +14,735 @@ import logging from abc import ABC, abstractmethod -from typing import Any, List, Literal +from typing import Any, Dict, List, Literal, Sequence, Type from langchain_core.messages import AIMessage, ToolCall from langchain_core.runnables.config import DEFAULT_RECURSION_LIMIT from langchain_core.tools import BaseTool from pydantic import BaseModel +from rai_bench.tool_calling_agent_bench.messages.base import Clock +from rai_bench.tool_calling_agent_bench.messages.services import ( + ManipulatorMoveToRequest, + RAIGroundedSamRequest, + RAIGroundingDinoRequest, + StringListRequest, + VectorStoreRetrievalRequest, + WhatISeeRequest, +) +from rai_bench.tool_calling_agent_bench.messages.topics import ( + AudioMessage, + CameraInfo, + HRIMessage, + Image, + RAIDetectionArray, +) +from rai_bench.tool_calling_agent_bench.mocked_tools import ( + MockCallROS2ServiceTool, + MockCancelROS2ActionTool, + MockGetROS2ActionIDsTool, + MockGetROS2ActionsNamesAndTypesTool, + MockGetROS2ImageTool, + MockGetROS2MessageInterfaceTool, + MockGetROS2ServicesNamesAndTypesTool, + MockGetROS2TopicsNamesAndTypesTool, + MockMoveToPointTool, + MockPublishROS2MessageTool, + MockStartROS2ActionTool, +) + 
loggers_type = logging.Logger +# dict of interfaces where keys are interfaces types and values are output +# of GetROS2MessageInterfaceTool which are same as ros2 interface show outputs +# the dict contains custom as well as couple other common interfaces +MOCK_INTERFACES: Dict[str, str] = { + "sensor_msgs/msg/CameraInfo": """ +# This message defines meta information for a camera. It should be in a +# camera namespace on topic "camera_info" and accompanied by up to five +# image topics named: +# +# image_raw - raw data from the camera driver, possibly Bayer encoded +# image - monochrome, distorted +# image_color - color, distorted +# image_rect - monochrome, rectified +# image_rect_color - color, rectified +# +# The image_pipeline contains packages (image_proc, stereo_image_proc) +# for producing the four processed image topics from image_raw and +# camera_info. The meaning of the camera parameters are described in +# detail at http://www.ros.org/wiki/image_pipeline/CameraInfo. +# +# The image_geometry package provides a user-friendly interface to +# common operations using this meta information. If you want to, e.g., +# project a 3d point into image coordinates, we strongly recommend +# using image_geometry. +# +# If the camera is uncalibrated, the matrices D, K, R, P should be left +# zeroed out. In particular, clients may assume that K[0] == 0.0 +# indicates an uncalibrated camera. 
+ +####################################################################### +# Image acquisition info # +####################################################################### + +# Time of image acquisition, camera coordinate frame ID +std_msgs/Header header # Header timestamp should be acquisition time of image + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + # Header frame_id should be optical frame of camera + # origin of frame should be optical center of camera + # +x should point to the right in the image + # +y should point down in the image + # +z should point into the plane of the image + + +####################################################################### +# Calibration Parameters # +####################################################################### +# These are fixed during camera calibration. Their values will be the # +# same in all messages until the camera is recalibrated. Note that # +# self-calibrating systems may "recalibrate" frequently. # +# # +# The internal parameters can be used to warp a raw (distorted) image # +# to: # +# 1. An undistorted image (requires D and K) # +# 2. A rectified image (requires D, K, R) # +# The projection matrix P projects 3D points into the rectified image.# +####################################################################### + +# The image dimensions with which the camera was calibrated. +# Normally this will be the full camera resolution in pixels. +uint32 height +uint32 width + +# The distortion model used. Supported models are listed in +# sensor_msgs/distortion_models.hpp. For most cameras, "plumb_bob" - a +# simple model of radial and tangential distortion - is sufficent. +string distortion_model + +# The distortion parameters, size depending on the distortion model. +# For "plumb_bob", the 5 parameters are: (k1, k2, t1, t2, k3). +float64[] d + +# Intrinsic camera matrix for the raw (distorted) images. 
+# [fx 0 cx] +# K = [ 0 fy cy] +# [ 0 0 1] +# Projects 3D points in the camera coordinate frame to 2D pixel +# coordinates using the focal lengths (fx, fy) and principal point +# (cx, cy). +float64[9] k # 3x3 row-major matrix + +# Rectification matrix (stereo cameras only) +# A rotation matrix aligning the camera coordinate system to the ideal +# stereo image plane so that epipolar lines in both stereo images are +# parallel. +float64[9] r # 3x3 row-major matrix + +# Projection/camera matrix +# [fx' 0 cx' Tx] +# P = [ 0 fy' cy' Ty] +# [ 0 0 1 0] +# By convention, this matrix specifies the intrinsic (camera) matrix +# of the processed (rectified) image. That is, the left 3x3 portion +# is the normal camera intrinsic matrix for the rectified image. +# It projects 3D points in the camera coordinate frame to 2D pixel +# coordinates using the focal lengths (fx', fy') and principal point +# (cx', cy') - these may differ from the values in K. +# For monocular cameras, Tx = Ty = 0. Normally, monocular cameras will +# also have R = the identity and P[1:3,1:3] = K. +# For a stereo pair, the fourth column [Tx Ty 0]' is related to the +# position of the optical center of the second camera in the first +# camera's frame. We assume Tz = 0 so both cameras are in the same +# stereo image plane. The first camera always has Tx = Ty = 0. For +# the right (second) camera of a horizontal stereo pair, Ty = 0 and +# Tx = -fx' * B, where B is the baseline between the cameras. +# Given a 3D point [X Y Z]', the projection (x, y) of the point onto +# the rectified image is given by: +# [u v w]' = P * [X Y Z 1]' +# x = u / w +# y = v / w +# This holds for both images of a stereo pair. +float64[12] p # 3x4 row-major matrix + + +####################################################################### +# Operational Parameters # +####################################################################### +# These define the image region actually captured by the camera # +# driver. 
Although they affect the geometry of the output image, they # +# may be changed freely without recalibrating the camera. # +####################################################################### + +# Binning refers here to any camera setting which combines rectangular +# neighborhoods of pixels into larger "super-pixels." It reduces the +# resolution of the output image to +# (width / binning_x) x (height / binning_y). +# The default values binning_x = binning_y = 0 is considered the same +# as binning_x = binning_y = 1 (no subsampling). +uint32 binning_x +uint32 binning_y + +# Region of interest (subwindow of full camera resolution), given in +# full resolution (unbinned) image coordinates. A particular ROI +# always denotes the same window of pixels on the camera sensor, +# regardless of binning settings. +# The default setting of roi (all values 0) is considered the same as +# full resolution (roi.width = width, roi.height = height). +RegionOfInterest roi + # + uint32 x_offset # + # (0 if the ROI includes the left edge of the image) + uint32 y_offset # + # (0 if the ROI includes the top edge of the image) + uint32 height # + uint32 width # + bool do_rectify +""", + "sensor_msgs/msg/Image": """ +# This message contains an uncompressed image +# (0, 0) is at top-left corner of image + +std_msgs/Header header # Header timestamp should be acquisition time of image + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + # Header frame_id should be optical frame of camera + # origin of frame should be optical center of cameara + # +x should point to the right in the image + # +y should point down in the image + # +z should point into to plane of the image + # If the frame_id here and the frame_id of the CameraInfo + # message associated with the image conflict + # the behavior is undefined + +uint32 height # image height, that is, number of rows +uint32 width # image width, that is, number of columns + +# The legal values for encoding are in 
file src/image_encodings.cpp +# If you want to standardize a new string format, join +# ros-users@lists.ros.org and send an email proposing a new encoding. + +string encoding # Encoding of pixels -- channel meaning, ordering, size + # taken from the list of strings in include/sensor_msgs/image_encodings.hpp + +uint8 is_bigendian # is this data bigendian? +uint32 step # Full row length in bytes +uint8[] data # actual matrix data, size is (step * rows) +""", + "rosgraph_msgs/msg/Clock": """ +# This message communicates the current time. +# +# For more information, see https://design.ros2.org/articles/clock_and_time.html. +builtin_interfaces/Time clock + int32 sec + uint32 nanosec +""", + "rai_interfaces/msg/HRIMessage": """ +# +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id +string text +sensor_msgs/Image[] images + std_msgs/Header header # + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + # Header frame_id should be optical frame of camera + # origin of frame should be optical center of cameara + # +x should point to the right in the image + # +y should point down in the image + # +z should point into to plane of the image + # If the frame_id here and the frame_id of the CameraInfo + # message associated with the image conflict + # the behavior is undefined + uint32 height # + uint32 width # + string encoding # + # taken from the list of strings in include/sensor_msgs/image_encodings.hpp + uint8 is_bigendian # + uint32 step # + uint8[] data # +rai_interfaces/AudioMessage[] audios + # + # + # + # + # + int16[] audio + uint16 sample_rate + uint16 channels +string communication_id +int64 seq_no +bool seq_end +""", + "rai_interfaces/msg/AudioMessage": """ +# +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +int16[] audio +uint16 sample_rate +uint16 channels +""", + "rai_interfaces/msg/RAIDetectionArray": """ +# +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# A list of 2D detections, for a multi-object 2D detector. +std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + +# A list of the detected proposals. A multi-proposal detector might generate +# this list with many candidate detections generated from a single input. +vision_msgs/Detection2D[] detections + # + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + ObjectHypothesisWithPose[] results + ObjectHypothesis hypothesis + string class_id + float64 score + geometry_msgs/PoseWithCovariance pose + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + float64[36] covariance + BoundingBox2D bbox + vision_msgs/Pose2D center + vision_msgs/Point2D position + float64 x + float64 y + float64 theta + float64 size_x + float64 size_y + string id +# a list of classes being detected +string[] detection_classes +""", + "rai_interfaces/srv/ManipulatorMoveTo": """ +# +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# A simplified approach with binary states for the gripper +bool initial_gripper_state +bool final_gripper_state +geometry_msgs/PoseStamped target_pose + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 +--- +bool success +""", + "rai_interfaces/srv/RAIGroundedSam": """ +# +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +RAIDetectionArray detections + # + # + # + # + # + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + vision_msgs/Detection2D[] detections + # + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + ObjectHypothesisWithPose[] results + ObjectHypothesis hypothesis + string class_id + float64 score + geometry_msgs/PoseWithCovariance pose + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + float64[36] covariance + BoundingBox2D bbox + vision_msgs/Pose2D center + vision_msgs/Point2D position + float64 x + float64 y + float64 theta + float64 size_x + float64 size_y + string id + string[] detection_classes +sensor_msgs/Image source_img + std_msgs/Header header # + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + # Header frame_id should be optical frame of camera + # origin of frame should be optical center of cameara + # +x should point to the right in the image + # +y should point down in the image + # +z should point into to plane of the image + # If the frame_id here and the frame_id of the CameraInfo + # message associated with the image conflict + # the behavior is undefined + uint32 height # + uint32 width # + string encoding # + # taken from the list of strings in include/sensor_msgs/image_encodings.hpp + uint8 is_bigendian # + uint32 step # + uint8[] data # +--- +sensor_msgs/Image[] masks + std_msgs/Header header # + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + # Header frame_id should be optical frame of camera + # origin of frame should be optical center of cameara + # +x should point to the right in the image + # +y should point down in the image + # +z should point into to plane of the image + # If the frame_id here and the frame_id of the CameraInfo + # message associated with the image conflict 
+ # the behavior is undefined + uint32 height # + uint32 width # + string encoding # + # taken from the list of strings in include/sensor_msgs/image_encodings.hpp + uint8 is_bigendian # + uint32 step # + uint8[] data # +""", + "rai_interfaces/srv/RAIGroundingDino": """ +# +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +string classes +float64 box_threshold +float64 text_threshold +sensor_msgs/Image source_img + std_msgs/Header header # + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + # Header frame_id should be optical frame of camera + # origin of frame should be optical center of cameara + # +x should point to the right in the image + # +y should point down in the image + # +z should point into to plane of the image + # If the frame_id here and the frame_id of the CameraInfo + # message associated with the image conflict + # the behavior is undefined + uint32 height # + uint32 width # + string encoding # + # taken from the list of strings in include/sensor_msgs/image_encodings.hpp + uint8 is_bigendian # + uint32 step # + uint8[] data # +--- +RAIDetectionArray detections + # + # + # + # + # + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + vision_msgs/Detection2D[] detections + # + std_msgs/Header header + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + ObjectHypothesisWithPose[] results + ObjectHypothesis 
hypothesis + string class_id + float64 score + geometry_msgs/PoseWithCovariance pose + Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 + float64[36] covariance + BoundingBox2D bbox + vision_msgs/Pose2D center + vision_msgs/Point2D position + float64 x + float64 y + float64 theta + float64 size_x + float64 size_y + string id + string[] detection_classes +""", + "rai_interfaces/srv/StringList": """ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Request - empty +--- +# Response +bool success +string[] string_list +""", + "rai_interfaces/srv/VectorStoreRetrieval": """ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# Request +string query + +--- +# Response +bool success +string message +string[] documents +float32[] scores +""", + "rai_interfaces/srv/WhatISee": """z +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Request (empty) + +--- +# Response, timed with image timestamp +string[] observations +string perception_source +sensor_msgs/Image image + std_msgs/Header header # + builtin_interfaces/Time stamp + int32 sec + uint32 nanosec + string frame_id + # Header frame_id should be optical frame of camera + # origin of frame should be optical center of cameara + # +x should point to the right in the image + # +y should point down in the image + # +z should point into to plane of the image + # If the frame_id here and the frame_id of the CameraInfo + # message associated with the image conflict + # the behavior is undefined + uint32 height # + uint32 width # + string encoding # + # taken from the list of strings in include/sensor_msgs/image_encodings.hpp + uint8 is_bigendian # + uint32 step # + uint8[] data # +geometry_msgs/Pose pose + Point position + float64 x + float64 y + float64 z + Quaternion orientation + float64 x 0 + float64 y 0 + float64 z 0 + float64 w 1 +""", + "rai_interfaces/action/Task": """ +# Goal +string task +string description +string priority + +--- +# Result +bool success +string report + +--- +# Feedback +string current_status +""", + "/load_map": """ +string filename +--- +bool success +""", + 
"/query_planner_interface": """ +--- + +# The planning instances that could be used in the benchmark +PlannerInterfaceDescription[] planner_interfaces + string name + string pipeline_id + string[] planner_ids + +""", +} + class Result(BaseModel): success: bool = False @@ -89,6 +809,167 @@ def verify_tool_calls(self, response: dict[str, Any]): """ pass + def _check_topic_tool_call_field( + self, + tool_call: ToolCall, + expected_name: str, + expected_topic: str, + expected_message_type: str, + field_path: str, + expected_value: Any, + ) -> bool: + """ + Verifies a tool call for a topic publishing operation. + + Parameters + ---------- + tool_call : ToolCall + The tool call dictionary containing keys such as "name" and "args". + expected_name : str + The expected tool call name (e.g., "publish_ros2_message"). + expected_topic : str + The expected topic name in the tool call's arguments. + expected_message_type : str + The expected message type (e.g., "rai_interfaces/msg/HRIMessage"). + field_path : str + Dot-separated path to the field inside the message (e.g., "header.frame_id"). + expected_value : Any + The expected value at the given field path. + + Returns + ------- + bool + True if all conditions are met; False otherwise. + """ + # Check tool call name. + if tool_call.get("name") != expected_name: + self.log_error( + f"Expected tool call name '{expected_name}', but got '{tool_call.get('name')}'." + ) + return False + + args = tool_call.get("args", {}) + + # Check topic. + if args.get("topic") != expected_topic: + self.log_error( + f"Expected topic '{expected_topic}', but got '{args.get('topic')}'." + ) + return False + + # Check message type. + if args.get("message_type") != expected_message_type: + self.log_error( + f"Expected message type '{expected_message_type}', but got '{args.get('message_type')}'." + ) + return False + + # Traverse the message field. 
+ message = args.get("message") + if message is None: + self.log_error("Tool call does not contain a 'message' argument.") + return False + + keys = field_path.split(".") + value: Any = message + for key in keys: + if isinstance(value, dict) and key in value: + value = value[key] + else: + self.log_error(f"Field path '{field_path}' not found in the message.") + return False + + if value != expected_value: + self.log_error( + f"Expected value for field '{field_path}' is '{expected_value}', but got '{value}'." + ) + return False + + return True + + def _check_service_tool_call_field( + self, + tool_call: ToolCall, + expected_name: str, + expected_service: str, + expected_service_type: str, + field_path: str, + expected_value: Any, + ) -> bool: + """ + Verifies a tool call for a service call. + + Parameters + ---------- + tool_call : ToolCall + The tool call dictionary containing keys such as "name" and "args". + expected_name : str + The expected tool call name (e.g., "call_ros2_service"). + expected_service : str + The expected service name in the tool call's arguments. + expected_message_type : str + The expected message type. + field_path : str + Dot-separated path to the field inside the message. + expected_value : Any + The expected value at the given field path. + + Returns + ------- + bool + True if all conditions are met; False otherwise. + """ + if tool_call.get("name") != expected_name: + self.log_error( + f"Expected tool call name '{expected_name}', but got '{tool_call.get('name')}'." + ) + return False + + args = tool_call.get("args", {}) + + # Check service. + if args.get("service_name") != expected_service: + self.log_error( + f"Expected service '{expected_service}', but got '{args.get('service')}'." + ) + return False + + # Check message type. + if args.get("service_type") != expected_service_type: + self.log_error( + f"Expected message type '{expected_service_type}', but got '{args.get('service_type')}'." 
+ ) + return False + + service_args = args.get("service_args") + if service_args is None: + self.log_error("Tool call does not contain a 'service_args' argument.") + return False + + if field_path == "": + if service_args == {}: + return True + else: + self.log_error(f"Expected empty service_args, but got: {service_args}") + return False + + keys = field_path.split(".") + value: Any = service_args + for key in keys: + if isinstance(value, dict) and key in value: + value = value[key] + else: + self.log_error(f"Field path '{field_path}' not found in the message.") + return False + + if value != expected_value: + self.log_error( + f"Expected value for field '{field_path}' is '{expected_value}', but got '{value}'." + ) + return False + + return True + def _check_tool_call( self, tool_call: ToolCall, @@ -274,3 +1155,438 @@ def _is_ai_message_requesting_get_ros2_topics_and_types( ): return False return True + + def _is_ai_message_requesting_get_ros2_services_and_types( + self, ai_message: AIMessage + ) -> bool: + """Helper method to check if the given AIMessage is calling the exactly one tool that gets ROS2 service names and types correctly. + + Parameters + ---------- + ai_message : AIMessage + The AIMessage to check + + Returns + ------- + bool + True if the ai_message is requesting get_ros2_service_names_and_types correctly, False otherwise + """ + if not self._check_tool_calls_num_in_ai_message(ai_message, expected_num=1): + return False + + tool_call: ToolCall = ai_message.tool_calls[0] + if not self._check_tool_call( + tool_call=tool_call, + expected_name="get_ros2_services_names_and_types", + expected_args={}, + ): + return False + return True + + def _is_ai_message_requesting_get_ros2_actions_and_types( + self, ai_message: AIMessage + ) -> bool: + """Helper method to check if the given AIMessage is calling the exactly one tool that gets ROS2 actions names and types correctly. 
+ + Parameters + ---------- + ai_message : AIMessage + The AIMessage to check + + Returns + ------- + bool + True if the ai_message is requesting get_ros2_actions_names_and_types correctly, False otherwise + """ + if not self._check_tool_calls_num_in_ai_message(ai_message, expected_num=1): + return False + + tool_call: ToolCall = ai_message.tool_calls[0] + if not self._check_tool_call( + tool_call=tool_call, + expected_name="get_ros2_actions_names_and_types", + expected_args={}, + ): + return False + return True + + def get_tool_calls(self, response: dict[str, Any]) -> list[ToolCall]: + """Extracts all tool calls from the response, flattened across all AI messages.""" + tool_calls: List[ToolCall] = [] + for msg in response["messages"]: + if isinstance(msg, AIMessage): + tool_calls.extend(msg.tool_calls) + return tool_calls + + +SERVICES_AND_TYPES = { + # sample interfaces + # "/load_map": "moveit_msgs/srv/LoadMap", + # "/query_planner_interface": "moveit_msgs/srv/QueryPlannerInterfaces", + # custom interfaces + "/manipulator_move_to": "rai_interfaces/srv/ManipulatorMoveTo", + "/grounded_sam_segment": "rai_interfaces/srv/RAIGroundedSam", + "/grounding_dino_classify": "rai_interfaces/srv/RAIGroundingDino", + "/get_log_digest": "rai_interfaces/srv/StringList", + "/rai_whoami_documentation_service": "rai_interfaces/srv/VectorStoreRetrieval", + "/rai/whatisee/get": "rai_interfaces/srv/WhatISee", +} + +SERVICE_MODELS: Dict[str, Type[BaseModel]] = { + "rai_interfaces/srv/ManipulatorMoveTo": ManipulatorMoveToRequest, + "rai_interfaces/srv/RAIGroundedSam": RAIGroundedSamRequest, + "rai_interfaces/srv/RAIGroundingDino": RAIGroundingDinoRequest, + "rai_interfaces/srv/StringList": StringListRequest, + "rai_interfaces/srv/VectorStoreRetrieval": VectorStoreRetrievalRequest, + "rai_interfaces/srv/WhatISee": WhatISeeRequest, +} + +TOPICS_AND_TYPES: Dict[str, str] = { + # sample topics + "/camera_image_color": "sensor_msgs/msg/Image", + "/camera_image_depth": 
"sensor_msgs/msg/Image", + "/clock": "rosgraph_msgs/msg/Clock", + "/color_camera_info": "sensor_msgs/msg/CameraInfo", + "/color_camera_info5": "sensor_msgs/msg/CameraInfo", + "/depth_camera_info5": "sensor_msgs/msg/CameraInfo", + "/depth_image5": "sensor_msgs/msg/Image", + # custom topics + "/to_human": "rai_interfaces/msg/HRIMessage", + "/send_audio": "rai_interfaces/msg/AudioMessage", + "/send_detections": "rai_interfaces/msg/RAIDetectionArray", +} +TOPIC_STRINGS = [ + f"topic: {topic}\ntype: {msg_type}\n" + for topic, msg_type in TOPICS_AND_TYPES.items() +] +TOPIC_MODELS: Dict[str, Type[BaseModel]] = { + "sensor_msgs/msg/CameraInfo": CameraInfo, + "sensor_msgs/msg/Image": Image, + "rosgraph_msgs/msg/Clock": Clock, + "rai_interfaces/msg/HRIMessage": HRIMessage, + "rai_interfaces/msg/AudioMessage": AudioMessage, + "rai_interfaces/msg/RAIDetectionArray": RAIDetectionArray, +} + +IMAGE_TOPICS: Dict[str, str] = { + "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject", + "/camera_image_color": "sensor_msgs/msg/Image", + "/camera_image_depth": "sensor_msgs/msg/Image", + "/clock": "rosgraph_msgs/msg/Clock", + "/collision_object": "moveit_msgs/msg/CollisionObject", + "/color_camera_info": "sensor_msgs/msg/CameraInfo", + "/color_camera_info5": "sensor_msgs/msg/CameraInfo", + "/depth_camera_info5": "sensor_msgs/msg/CameraInfo", +} + +SERVICE_STRINGS = [ + f"service: {service}\ntype: {msg_type}\n" + for service, msg_type in SERVICES_AND_TYPES.items() +] + + +class CustomInterfacesTopicTask(ROS2ToolCallingAgentTask, ABC): + def __init__(self, logger: loggers_type | None = None) -> None: + super().__init__(logger=logger) + + self.expected_tools: List[BaseTool] = [ + MockGetROS2TopicsNamesAndTypesTool( + mock_topics_names_and_types=TOPIC_STRINGS + ), + MockGetROS2MessageInterfaceTool(mock_interfaces=MOCK_INTERFACES), + MockPublishROS2MessageTool( + available_topics=list(TOPICS_AND_TYPES.keys()), + available_message_types=list(TOPICS_AND_TYPES.values()), + 
available_topic_models=TOPIC_MODELS, + ), + MockCancelROS2ActionTool(), + MockGetROS2ActionIDsTool(), + MockMoveToPointTool(manipulator_frame="base_link"), + MockGetROS2ImageTool(available_topics=list(IMAGE_TOPICS.keys())), + MockGetROS2ServicesNamesAndTypesTool( + mock_service_names_and_types=SERVICE_STRINGS + ), + MockGetROS2MessageInterfaceTool(mock_interfaces=MOCK_INTERFACES), + MockCallROS2ServiceTool( + available_services=list(SERVICES_AND_TYPES.keys()), + available_service_types=list(SERVICES_AND_TYPES.values()), + available_service_models=SERVICE_MODELS, + ), + ] + + @property + @abstractmethod + def expected_topic(self) -> str: + pass + + @property + def expected_message_type(self) -> str: + return TOPICS_AND_TYPES[self.expected_topic] + + @property + def extra_calls(self) -> int: + return 0 + + def verify_list_and_get_interface_tool_calls( + self, tool_calls: List[ToolCall] + ) -> tuple[bool, list[ToolCall]]: + """ + Verifies tool calls in this required order: + 1. get_ros2_topics_and_types + 2. get_ros2_message_interface (with correct msg_type) + + Returns + ------- + Tuple[bool, List[AIMessage]] + Success flag and remaining messages (to be used in `verify_message_tool_call`) + """ + + expected_core_calls = 3 + max_allowed = expected_core_calls + self.extra_calls + if len(tool_calls) > max_allowed: + self.log_error( + f"Too many tool calls. Expected at most {max_allowed}, got {len(tool_calls)}." 
+ ) + return False, [] + + stage = 0 # 0: expect topics, 1: expect interface + for idx, call in enumerate(tool_calls): + if stage == 0 and call["name"] == "get_ros2_topics_names_and_types": + stage = 1 + continue + + if stage == 1 and call["name"] == "get_ros2_message_interface": + if call["args"].get("msg_type") == self.expected_message_type: + stage = 2 + return True, tool_calls[idx + 1 :] + + self.log_error("Required tool calls not found in order: topics → interface") + return False, [] + + @abstractmethod + def verify_message_tool_call(self, tool_calls: List[ToolCall]) -> bool: + """ + Search the remaining AI messages for the expected publish/service tool call. + """ + pass + + def verify_tool_calls(self, response: dict[str, Any]): + """ + Validates the full sequence of AI tool calls with support for extras and ordering. + + Steps: + 1. Get topics + 2. Get message interface + 3. Call publish/service with expected content + """ + messages = response["messages"] + ai_messages: Sequence[AIMessage] = [ + msg for msg in messages if isinstance(msg, AIMessage) + ] + self.logger.debug(f"AI messages: {ai_messages}") + tool_calls = self.get_tool_calls(response) + + # success, remaining_tool_calls = self.verify_list_and_get_interface_tool_calls( + # tool_calls + # ) + # if success and self.verify_message_tool_call(remaining_tool_calls): + # self.result.success = True + if self.verify_message_tool_call(tool_calls): + self.result.success = True + + +class CustomInterfacesServiceTask(ROS2ToolCallingAgentTask, ABC): + def __init__(self, logger: loggers_type | None = None) -> None: + super().__init__(logger=logger) + self.expected_tools: List[BaseTool] = [ + MockGetROS2TopicsNamesAndTypesTool( + mock_topics_names_and_types=TOPIC_STRINGS + ), + MockGetROS2MessageInterfaceTool(mock_interfaces=MOCK_INTERFACES), + MockPublishROS2MessageTool( + available_topics=list(TOPICS_AND_TYPES.keys()), + available_message_types=list(TOPICS_AND_TYPES.values()), + 
available_topic_models=TOPIC_MODELS, + ), + MockCancelROS2ActionTool(), + MockGetROS2ActionIDsTool(), + MockMoveToPointTool(manipulator_frame="base_link"), + MockGetROS2ImageTool(available_topics=list(IMAGE_TOPICS.keys())), + MockGetROS2ServicesNamesAndTypesTool( + mock_service_names_and_types=SERVICE_STRINGS + ), + MockGetROS2MessageInterfaceTool(mock_interfaces=MOCK_INTERFACES), + MockCallROS2ServiceTool( + available_services=list(SERVICES_AND_TYPES.keys()), + available_service_types=list(SERVICES_AND_TYPES.values()), + available_service_models=SERVICE_MODELS, + ), + ] + + @property + @abstractmethod + def expected_service(self) -> str: + pass + + @property + def expected_service_type(self) -> str: + return SERVICES_AND_TYPES[self.expected_service] + + @property + def extra_calls(self) -> int: + return 0 + + def verify_list_and_get_interface_tool_calls( + self, tool_calls: List[ToolCall] + ) -> tuple[bool, list[ToolCall]]: + """ + Verifies tool calls in this required order: + 1. get_ros2_services_names_and_types + 2. get_ros2_message_interface (with correct msg_type) + + Returns + ------- + Tuple[bool, List[ToolCall]] + Success flag and remaining tool calls for message verification + """ + expected_core_calls = 3 + max_allowed = expected_core_calls + self.extra_calls + if len(tool_calls) > max_allowed: + self.log_error( + f"Too many tool calls. Expected at most {max_allowed}, got {len(tool_calls)}." 
+ ) + return False, [] + + stage = 0 # 0: expect service list, 1: expect interface + for idx, call in enumerate(tool_calls): + if stage == 0 and call["name"] == "get_ros2_services_names_and_types": + stage = 1 + continue + + if stage == 1 and call["name"] == "get_ros2_message_interface": + if call["args"].get("msg_type") == self.expected_service_type: + stage = 2 + return True, tool_calls[idx + 1 :] + + self.log_error("Required tool calls not found in order: services → interface") + return False, [] + + @abstractmethod + def verify_message_tool_call(self, tool_calls: List[ToolCall]) -> bool: + """Search the remaining tool calls for the expected service call.""" + pass + + def verify_tool_calls(self, response: dict[str, Any]): + """ + Full tool call sequence verification: + 1. Get services + 2. Get message interface + 3. Call service with expected values + """ + messages = response["messages"] + ai_messages: Sequence[AIMessage] = [ + msg for msg in messages if isinstance(msg, AIMessage) + ] + self.logger.debug(f"AI messages: {ai_messages}") + tool_calls = self.get_tool_calls(response) + + # success, remaining_tool_calls = self.verify_list_and_get_interface_tool_calls( + # tool_calls + # ) + # if success and + + if self.verify_message_tool_call(tool_calls): + self.result.success = True + + +class CustomInterfacesActionTask(ROS2ToolCallingAgentTask, ABC): + ACTIONS_AND_TYPES = { + # custom actions + "/perform_task": "rai_interfaces/action/Task", + # some sample actions + # "/execute_trajectory": "moveit_msgs/action/ExecuteTrajectory", + # "/move_action": "moveit_msgs/action/MoveGroup", + # "/follow_joint_trajectory": "control_msgs/action/FollowJointTrajectory", + # "/gripper_cmd": "control_msgs/action/GripperCommand", + } + + action_strings = [ + f"action: {action}\ntype: {msg_type}\n" + for action, msg_type in ACTIONS_AND_TYPES.items() + ] + + def __init__(self, logger: loggers_type | None = None) -> None: + super().__init__(logger=logger) + self.expected_tools: 
List[BaseTool] = [ + MockGetROS2ActionsNamesAndTypesTool( + mock_actions_names_and_types=self.action_strings + ), + MockGetROS2MessageInterfaceTool(mock_interfaces=MOCK_INTERFACES), + MockStartROS2ActionTool( + available_actions=list(self.ACTIONS_AND_TYPES.keys()), + available_action_types=list(self.ACTIONS_AND_TYPES.values()), + ), + ] + + @property + @abstractmethod + def expected_action(self) -> str: + pass + + @property + @abstractmethod + def expected_message(self) -> BaseModel: + pass + + @property + def expected_action_type(self) -> str: + return self.ACTIONS_AND_TYPES[self.expected_action] + + def verify_tool_calls(self, response: dict[str, Any]): + """It is expected that the agent will request: + 1. The tool that retrieves the topics names and types to recognize what type of message to_human topic has + 2. The tool that retrieves interfaces to check HRIMessage type + 3. The tool to publish message with proper topic, message type and content + + Parameters + ---------- + response : dict[str, Any] + The response from the agent + """ + messages = response["messages"] + ai_messages: Sequence[AIMessage] = [ + message for message in messages if isinstance(message, AIMessage) + ] + self.logger.debug(ai_messages) + if len(ai_messages) != 4: + self.log_error( + msg=f"Expected exactly 4 AI messages, but got {len(ai_messages)}." + ) + if ai_messages: + if not self._is_ai_message_requesting_get_ros2_actions_and_types( + ai_messages[0] + ): + self.log_error( + msg="First AI message did not request ROS2 topics and types correctly." 
+ ) + if len(ai_messages) > 1: + if self._check_tool_calls_num_in_ai_message(ai_messages[1], expected_num=1): + self._check_tool_call( + tool_call=ai_messages[1].tool_calls[0], + expected_name="get_ros2_message_interface", + expected_args={"msg_type": self.expected_action_type}, + ) + + if len(ai_messages) > 2: + if self._check_tool_calls_num_in_ai_message(ai_messages[2], expected_num=1): + self._check_tool_call( + tool_call=ai_messages[2].tool_calls[0], + expected_name="start_ros2_action", + expected_args={ + "action_name": self.expected_action, + "action_args": self.expected_message.model_dump(), + "action_type": self.expected_action_type, + }, + ) + if not self.result.errors: + self.result.success = True diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/__init__.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/__init__.py new file mode 100644 index 000000000..97ceef6f0 --- /dev/null +++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/__init__.py @@ -0,0 +1,13 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/actions.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/actions.py new file mode 100644 index 000000000..9fd286938 --- /dev/null +++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/actions.py @@ -0,0 +1,40 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from pydantic import BaseModel + + +class TaskGoal(BaseModel): + task: Optional[str] = "" + description: Optional[str] = "" + priority: Optional[str] = "" + + +class TaskResult(BaseModel): + success: Optional[bool] = False + report: Optional[str] = "" + + +class TaskFeedback(BaseModel): + current_status: Optional[str] = "" + + +class LoadMapRequest(BaseModel): + filename: Optional[str] = "" + + +class LoadMapResponse(BaseModel): + success: Optional[bool] = False diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/base.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/base.py new file mode 100644 index 000000000..ccfa21f21 --- /dev/null +++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/base.py @@ -0,0 +1,101 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +from pydantic import BaseModel + + +# TODO (jm) redundant with action models, remove action models later +class Time(BaseModel): + sec: Optional[int] = 0 + nanosec: Optional[int] = 0 + + +class Header(BaseModel): + stamp: Optional[Time] = Time() + frame_id: Optional[str] = "" + + +class RegionOfInterest(BaseModel): + x_offset: Optional[int] = 0 + y_offset: Optional[int] = 0 + height: Optional[int] = 0 + width: Optional[int] = 0 + do_rectify: Optional[bool] = False + + +class Position(BaseModel): + x: Optional[float] = 0.0 + y: Optional[float] = 0.0 + z: Optional[float] = 0.0 + + +class Orientation(BaseModel): + x: Optional[float] = 0.0 + y: Optional[float] = 0.0 + z: Optional[float] = 0.0 + w: Optional[float] = 1.0 + + +class Pose(BaseModel): + position: Optional[Position] = Position() + orientation: Optional[Orientation] = Orientation() + + +class PoseStamped(BaseModel): + header: Optional[Header] = Header() + pose: Optional[Pose] = Pose() + + +class Clock(BaseModel): + clock: Optional[Time] = Time() + + +class ObjectHypothesis(BaseModel): + class_id: Optional[str] = "" + score: Optional[float] = 0.0 + + +class PoseWithCovariance(BaseModel): + pose: Optional[Pose] = Pose() + covariance: Optional[List[float]] = [0.0] * 36 + + +class ObjectHypothesisWithPose(BaseModel): + hypothesis: Optional[ObjectHypothesis] = ObjectHypothesis() + pose: Optional[PoseWithCovariance] = PoseWithCovariance() + + +class Point2D(BaseModel): + x: Optional[float] = 0.0 + y: Optional[float] = 0.0 + + +class Pose2D(BaseModel): + position: 
Optional[Point2D] = Point2D() + theta: Optional[float] = 0.0 + + +class BoundingBox2D(BaseModel): + center: Optional[Pose2D] = Pose2D() + size_x: Optional[float] = 0.0 + size_y: Optional[float] = 0.0 + + +class Detection2D(BaseModel): + header: Optional[Header] = Header() + results: Optional[List[ObjectHypothesisWithPose]] = [] + bbox: Optional[BoundingBox2D] = BoundingBox2D() + id: Optional[str] = "" diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/services.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/services.py new file mode 100644 index 000000000..c60cdbcf2 --- /dev/null +++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/services.py @@ -0,0 +1,91 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Optional + +from pydantic import BaseModel + +from rai_bench.tool_calling_agent_bench.messages.base import Pose, PoseStamped +from rai_bench.tool_calling_agent_bench.messages.topics import Image, RAIDetectionArray + + +class ManipulatorMoveToRequest(BaseModel): + initial_gripper_state: Optional[bool] = False + final_gripper_state: Optional[bool] = False + target_pose: Optional[PoseStamped] = PoseStamped() + + +class ManipulatorMoveToResponse(BaseModel): + success: Optional[bool] = False + + +class RAIGroundedSamRequest(BaseModel): + detections: Optional[RAIDetectionArray] = RAIDetectionArray() + source_img: Optional[Image] = Image() + + +class RAIGroundedSamResponse(BaseModel): + masks: Optional[List[Image]] = [] + + +class RAIGroundingDinoRequest(BaseModel): + classes: Optional[str] = "" + box_threshold: Optional[float] = 0.0 + text_threshold: Optional[float] = 0.0 + source_img: Optional[Image] = Image() + + +class RAIGroundingDinoResponse(BaseModel): + detections: Optional[RAIDetectionArray] = RAIDetectionArray() + + +class StringListRequest(BaseModel): + pass + + +class StringListResponse(BaseModel): + success: Optional[bool] = False + string_list: Optional[List[str]] = [] + + +class VectorStoreRetrievalRequest(BaseModel): + query: Optional[str] = "" + + +class VectorStoreRetrievalResponse(BaseModel): + success: Optional[bool] = False + message: Optional[str] = "" + documents: Optional[List[str]] = [] + scores: Optional[List[float]] = [] + + +class WhatISeeRequest(BaseModel): + pass + + +class WhatISeeResponse(BaseModel): + observations: Optional[List[str]] = [] + perception_source: Optional[str] = "" + image: Optional[Image] = Image() + pose: Optional[Pose] = Pose() + + +class PlannerInterfaceDescription(BaseModel): + name: Optional[str] = "" + pipeline_id: Optional[str] = "" + planner_ids: Optional[List[str]] = [] + + +class QueryPlannerInterfaceResponse(BaseModel): + planner_interfaces: Optional[List[PlannerInterfaceDescription]] = 
[] diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/topics.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/topics.py new file mode 100644 index 000000000..1a68d6d77 --- /dev/null +++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/topics.py @@ -0,0 +1,69 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +from pydantic import BaseModel + +from rai_bench.tool_calling_agent_bench.messages.base import ( + Detection2D, + Header, + RegionOfInterest, +) + + +class CameraInfo(BaseModel): + header: Optional[Header] = Header() + height: Optional[int] = 0 + width: Optional[int] = 0 + distortion_model: Optional[str] = "" + d: Optional[List[float]] = [] + k: Optional[List[float]] = [0.0] * 9 + r: Optional[List[float]] = [0.0] * 9 + p: Optional[List[float]] = [0.0] * 12 + binning_x: Optional[int] = 0 + binning_y: Optional[int] = 0 + roi: Optional[RegionOfInterest] = RegionOfInterest() + + +class Image(BaseModel): + header: Optional[Header] = Header() + height: Optional[int] = 0 + width: Optional[int] = 0 + encoding: Optional[str] = "" + is_bigendian: Optional[int] = 0 + step: Optional[int] = 0 + data: Optional[List[int]] = [] + + +class AudioMessage(BaseModel): + audio: Optional[List[int]] = [] + sample_rate: Optional[int] = 0 + channels: Optional[int] = 0 + + +class HRIMessage(BaseModel): + header: Optional[Header] = Header() + text: Optional[str] = "" + images: 
Optional[List[Image]] = [] + audios: Optional[List[AudioMessage]] = [] + communication_id: Optional[str] = "" + seq_no: Optional[int] = 0 + seq_end: Optional[bool] = False + + +class RAIDetectionArray(BaseModel): + header: Optional[Header] = Header() + detections: Optional[List[Detection2D]] = [] + detection_classes: Optional[List[str]] = [] diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/mocked_tools.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/mocked_tools.py index cdaad998b..ca8baace0 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent_bench/mocked_tools.py +++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/mocked_tools.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple +import uuid +from threading import Lock +from typing import Any, Dict, List, Tuple, Type from unittest.mock import MagicMock import numpy as np import numpy.typing as npt +from pydantic import BaseModel, ValidationError from rai.communication.ros2.connectors import ROS2ARIConnector from rai.communication.ros2.messages import ROS2ARIMessage from rai.messages import MultimodalArtifact, preprocess_image @@ -26,11 +29,23 @@ MoveToPointTool, ) from rai.tools.ros2 import ( + CallROS2ServiceTool, + CancelROS2ActionTool, + GetROS2ActionFeedbackTool, + GetROS2ActionIDsTool, + GetROS2ActionResultTool, + GetROS2ActionsNamesAndTypesTool, GetROS2ImageTool, + GetROS2MessageInterfaceTool, + GetROS2ServicesNamesAndTypesTool, GetROS2TopicsNamesAndTypesTool, + PublishROS2MessageTool, ReceiveROS2MessageTool, + StartROS2ActionTool, ) +from rai_bench.tool_calling_agent_bench.actions.action_base_model import ActionBaseModel + class MockGetROS2TopicsNamesAndTypesTool(GetROS2TopicsNamesAndTypesTool): connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector) @@ -49,7 +64,7 @@ def _run(self) -> str: class MockGetROS2ImageTool(GetROS2ImageTool): connector: ROS2ARIConnector = 
MagicMock(spec=ROS2ARIConnector) - expected_topics: List[str] + available_topics: List[str] def _run( self, topic: str, timeout_sec: float = 1.0 @@ -73,7 +88,7 @@ def _run( ValueError If the passed topic is not correct. """ - if topic not in self.expected_topics: + if topic not in self.available_topics: raise ValueError( f"Topic {topic} is not available within {timeout_sec} seconds. Check if the topic exists." ) @@ -98,7 +113,7 @@ def generate_mock_image() -> npt.NDArray[np.uint8]: class MockReceiveROS2MessageTool(ReceiveROS2MessageTool): connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector) - expected_topics: List[str] + available_topics: List[str] def _run(self, topic: str) -> str: """Method that returns a mock message if the passed topic is correct. @@ -118,7 +133,7 @@ def _run(self, topic: str) -> str: ValueError If the passed topic is not correct. """ - if topic not in self.expected_topics: + if topic not in self.available_topics: raise ValueError( f"Topic {topic} is not available within 1.0 seconds. Check if the topic exists." ) @@ -181,3 +196,217 @@ def _run(self, object_name: str) -> str: return f"No {object_name}s detected." else: return f"Centroids of detected {object_name}s in manipulator frame: {expected_positions} Sizes of the detected objects are unknown." + + +class MockPublishROS2MessageTool(PublishROS2MessageTool): + connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector) + available_topics: List[str] + available_message_types: List[str] + available_topic_models: Dict[str, Type[BaseModel]] + + def _run(self, topic: str, message: Dict[str, Any], message_type: str) -> str: + """ + Mocked method that simulates publihing to a topic and return a status string. + + Parameters + ---------- + topic : str + The name of the topic to which the message is published. + message : Dict[str, Any] + The content of the message as a dictionary. + message_type : str + The type of the message being published. 
+ + """ + if topic not in self.available_topics: + raise ValueError( + f"Topic {topic} is not available within 1.0 seconds. Check if the topic exists." + ) + if message_type not in self.available_message_types: + raise TypeError( + "Expected message one of message types: {}, got {}".format( + self.available_message_types, message_type + ) + ) + + model = self.available_topic_models[message_type] + try: + model.model_validate(message) + except ValidationError as e: + raise ValueError(f"Failed to populate fields: {e}") + + return "Message published successfully" + + +class MockGetROS2MessageInterfaceTool(GetROS2MessageInterfaceTool): + connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector) + mock_interfaces: Dict[str, str] + + def _run(self, msg_type: str) -> str: + """ + Mocked method that returns the interface definition for a given ROS2 message type. + + Parameters + ---------- + msg_type : str + The ROS2 message type for which to retrieve the interface definition. + + Returns + ------- + str + The mocked output of 'ros2 interface show' for the specified message type. + """ + if msg_type in self.mock_interfaces: + return self.mock_interfaces[msg_type] + else: + raise ImportError(f"Module {msg_type} not found.") + + +class MockCallROS2ServiceTool(CallROS2ServiceTool): + connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector) + available_services: List[str] + available_service_types: List[str] + available_service_models: Dict[str, Type[BaseModel]] + + def _run( + self, + service_name: str, + service_type: str, + service_args: Dict[str, Any], + ) -> str: + if service_name not in self.available_services: + raise ValueError( + f"Service {service_name} is not available within 1.0 seconds. Check if the service exists." 
+ ) + if service_type not in self.available_service_types: + raise TypeError( + "Expected one of service types: {}, got {}".format( + self.available_service_types, service_type + ) + ) + if service_type in self.available_service_models: + model = self.available_service_models[service_type] + try: + model.model_validate(service_args) + except ValidationError as e: + raise ValueError(f"Failed to populate fields: {e}") + response = ROS2ARIMessage(payload={"response": "success"}) + return str( + { + "payload": response.payload, + "metadata": response.metadata, + } + ) + else: + raise KeyError( + f"Model for service type {service_type} not included in models" + ) + + +class MockGetROS2ServicesNamesAndTypesTool(GetROS2ServicesNamesAndTypesTool): + connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector) + mock_service_names_and_types: list[str] + + def _run(self) -> str: + """Mocked method that returns the mock topics and types instead of fetching from ROS2. + + Returns + ------- + str + Mocked output of 'get_ros2_topics_names_and_types' tool. + """ + return "\n".join(self.mock_service_names_and_types) + + +class MockGetROS2ActionsNamesAndTypesTool(GetROS2ActionsNamesAndTypesTool): + connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector) + mock_actions_names_and_types: list[str] + + def _run(self) -> str: + """Mocked method that returns the mock topics and types instead of fetching from ROS2. + + Returns + ------- + str + Mocked output of 'get_ros2_topics_names_and_types' tool. 
+ """ + return "\n".join(self.mock_actions_names_and_types) + + +class MockStartROS2ActionTool(StartROS2ActionTool): + connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector) + available_actions: List[str] = [] + available_action_types: List[str] = [] + available_action_models: List[type[ActionBaseModel]] + + def _run( + self, action_name: str, action_type: str, action_args: Dict[str, Any] + ) -> str: + if action_name not in self.available_actions: + raise ValueError( + f"Action {action_name} is not available within 1.0 seconds. Check if the action exists." + ) + if action_type not in self.available_action_types: + raise TypeError( + f"Expected one of action types: {self.available_action_types}, got {action_type}" + ) + for action_model in self.available_action_models: + if ( + action_model.model_fields["action_name"].default == action_name + and action_model.model_fields["action_type"].default == action_type + ): + goal = action_model.__annotations__["goal"] + goal.model_validate(action_args) + action_id = str(uuid.uuid4()) + response = action_id + self.internal_action_id_mapping[response] = action_id + return "Action started with ID: " + response + + +class MockCancelROS2ActionTool(CancelROS2ActionTool): + connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector) + available_action_ids: List[str] = [] + + def _run(self, action_id: str) -> str: + if action_id not in self.available_action_ids: + raise ValueError(f"Action {action_id} is not available for cancellation.") + return f"Action {action_id} cancelled" + + +class MockGetROS2ActionFeedbackTool(GetROS2ActionFeedbackTool): + connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector) + available_feedbacks: Dict[str, List[Any]] = {} + internal_action_id_mapping: Dict[str, str] = {} + action_feedbacks_store_lock: Lock = Lock() + + def _run(self, action_id: str) -> str: + if action_id not in self.internal_action_id_mapping: + raise KeyError(f"Action ID {action_id} not found in internal mapping.") + 
external_id = self.internal_action_id_mapping[action_id] + with self.action_feedbacks_store_lock: + feedbacks = self.available_feedbacks.get(external_id, []) + self.available_feedbacks[external_id] = [] + return str(feedbacks) + + +class MockGetROS2ActionResultTool(GetROS2ActionResultTool): + available_results: Dict[str, Any] = {} + internal_action_id_mapping: Dict[str, str] = {} + action_results_store_lock: Lock = Lock() + + def _run(self, action_id: str) -> str: + if action_id not in self.internal_action_id_mapping: + raise KeyError(f"Action ID {action_id} not found in internal mapping.") + external_id = self.internal_action_id_mapping[action_id] + with self.action_results_store_lock: + if external_id not in self.available_results: + raise ValueError(f"No result available for action {action_id}") + result = self.available_results[external_id] + return str(result) + + +class MockGetROS2ActionIDsTool(GetROS2ActionIDsTool): + internal_action_id_mapping: Dict[str, str] = {} + + def _run(self) -> str: + return str(list(self.internal_action_id_mapping.keys())) diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/ros2_agent_tasks.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/ros2_agent_tasks.py index 794a98fcb..1451ee2d4 100644 --- a/src/rai_bench/rai_bench/tool_calling_agent_bench/ros2_agent_tasks.py +++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/ros2_agent_tasks.py @@ -24,8 +24,26 @@ from rai.tools.ros.manipulation import MoveToPointToolInput from rai_bench.tool_calling_agent_bench.agent_tasks_interfaces import ( + CustomInterfacesServiceTask, + CustomInterfacesTopicTask, ROS2ToolCallingAgentTask, ) +from rai_bench.tool_calling_agent_bench.messages.base import ( + BoundingBox2D, + Detection2D, + Header, + Orientation, + Point2D, + Pose, + Pose2D, + PoseStamped, + Position, + Time, +) +from rai_bench.tool_calling_agent_bench.messages.topics import ( + Image, + RAIDetectionArray, +) from rai_bench.tool_calling_agent_bench.mocked_tools import ( 
MockGetObjectPositionsTool, MockGetROS2ImageTool, @@ -38,8 +56,15 @@ PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT = """You are a ROS 2 expert helping a user with their ROS 2 questions. You have access to various tools that allow you to query the ROS 2 system. - Be proactive and use the tools to answer questions. - """ +Be proactive and use the tools to answer questions. + +Example of tool calls: +- get_ros2_message_interface, args: {'msg_type': 'geometry_msgs/msg/Twist'} +- publish_ros2_message, args: {'topic': '/cmd_vel', 'message_type': 'geometry_msgs/msg/Twist', 'message': {linear: {x: 0.5, y: 0.0, z: 0.0}, angular: {x: 0.0, y: 0.0, z: 1.0}}} + +- get_ros2_message_interface, args: {'msg_type': 'turtlesim/srv/TeleportAbsolute'} +- publish_ros2_message, args: {'topic': '/turtle1/teleport_absolute', 'message_type': 'turtlesim/srv/TeleportAbsolute', 'message': {x: 5.0, y: 2.0, theta: 1.57}} +""" class TaskParametrizationError(Exception): @@ -180,23 +205,31 @@ def verify_tool_calls(self, response: dict[str, Any]): class GetROS2RGBCameraTask(ROS2ToolCallingAgentTask): complexity = "easy" + topics_and_types: Dict[str, str] = { + "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject", + "/camera_image_color": "sensor_msgs/msg/Image", + "/camera_image_depth": "sensor_msgs/msg/Image", + "/clock": "rosgraph_msgs/msg/Clock", + "/collision_object": "moveit_msgs/msg/CollisionObject", + "/color_camera_info": "sensor_msgs/msg/CameraInfo", + "/color_camera_info5": "sensor_msgs/msg/CameraInfo", + "/depth_camera_info5": "sensor_msgs/msg/CameraInfo", + "/depth_image5": "sensor_msgs/msg/Image", + } + def __init__(self, logger: loggers_type | None = None) -> None: super().__init__(logger=logger) + + topic_strings = [ + f"topic: {topic}\ntype: {msg_type}\n" + for topic, msg_type in self.topics_and_types.items() + ] + self.expected_tools: List[BaseTool] = [ MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /attached_collision_object\ntype: 
moveit_msgs/msg/AttachedCollisionObject\n", - "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n", - "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n", - "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", - "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", - "topic: /color_camera_info\ntype: sensor_msgs/msg/CameraInfo\n", - "topic: /color_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n", - "topic: /depth_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n", - "topic: /depth_image5\ntype: sensor_msgs/msg/Image\n", - ] + mock_topics_names_and_types=topic_strings ), - MockGetROS2ImageTool(expected_topics=["/camera_image_color"]), + MockGetROS2ImageTool(available_topics=list(self.topics_and_types.keys())), ] def get_system_prompt(self) -> str: @@ -247,22 +280,28 @@ def verify_tool_calls(self, response: dict[str, Any]): class GetROS2DepthCameraTask(ROS2ToolCallingAgentTask): complexity = "easy" + topics_and_types: Dict[str, str] = { + "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject", + "/camera_image_color": "sensor_msgs/msg/Image", + "/camera_image_depth": "sensor_msgs/msg/Image", + "/clock": "rosgraph_msgs/msg/Clock", + "/collision_object": "moveit_msgs/msg/CollisionObject", + "/color_camera_info": "sensor_msgs/msg/CameraInfo", + "/color_camera_info5": "sensor_msgs/msg/CameraInfo", + "/depth_camera_info5": "sensor_msgs/msg/CameraInfo", + } + def __init__(self, logger: loggers_type | None = None) -> None: super().__init__(logger=logger) + topic_strings = [ + f"topic: {topic}\ntype: {msg_type}\n" + for topic, msg_type in self.topics_and_types.items() + ] self.expected_tools: List[BaseTool] = [ MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", - "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n", - "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n", - "topic: /clock\ntype: 
rosgraph_msgs/msg/Clock\n", - "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", - "topic: /color_camera_info\ntype: sensor_msgs/msg/CameraInfo\n", - "topic: /color_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n", - "topic: /depth_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n", - ] + mock_topics_names_and_types=topic_strings ), - MockGetROS2ImageTool(expected_topics=["/camera_image_depth"]), + MockGetROS2ImageTool(available_topics=list(self.topics_and_types.keys())), ] def get_system_prompt(self) -> str: @@ -274,7 +313,7 @@ def get_prompt(self) -> str: def verify_tool_calls(self, response: dict[str, Any]): """It is expected that the agent will request: 1. The tool that retrieves the ROS2 topic names and types to identify the depth image topic. - 2. The tool that retrieves the RGB image from the /camera_image_depth topic + 2. The tool that retrieves the depth image from the /camera_image_depth topic Parameters ---------- @@ -314,26 +353,30 @@ def verify_tool_calls(self, response: dict[str, Any]): class GetAllROS2RGBCamerasTask(ROS2ToolCallingAgentTask): complexity = "easy" + topics_and_types: Dict[str, str] = { + "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject", + "/camera_image_color": "sensor_msgs/msg/Image", + "/camera_image_depth": "sensor_msgs/msg/Image", + "/clock": "rosgraph_msgs/msg/Clock", + "/collision_object": "moveit_msgs/msg/CollisionObject", + "/color_camera_info": "sensor_msgs/msg/CameraInfo", + "/color_camera_info5": "sensor_msgs/msg/CameraInfo", + "/color_image5": "sensor_msgs/msg/Image", + "/depth_camera_info5": "sensor_msgs/msg/CameraInfo", + "/depth_image5": "sensor_msgs/msg/Image", + } + def __init__(self, logger: loggers_type | None = None) -> None: super().__init__(logger=logger) + topic_strings = [ + f"topic: {topic}\ntype: {msg_type}\n" + for topic, msg_type in self.topics_and_types.items() + ] self.expected_tools: List[BaseTool] = [ MockGetROS2TopicsNamesAndTypesTool( - 
mock_topics_names_and_types=[ - "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", - "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n", - "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n", - "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", - "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", - "topic: /color_camera_info\ntype: sensor_msgs/msg/CameraInfo\n", - "topic: /color_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n", - "topic: /color_image5\ntype: sensor_msgs/msg/Image\n", - "topic: /depth_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n", - "topic: /depth_image5\ntype: sensor_msgs/msg/Image\n", - ] - ), - MockGetROS2ImageTool( - expected_topics=["/camera_image_color", "/color_image5"] + mock_topics_names_and_types=topic_strings ), + MockGetROS2ImageTool(available_topics=list(self.topics_and_types.keys())), ] def get_prompt(self) -> str: @@ -382,7 +425,6 @@ def verify_tool_calls(self, response: dict[str, Any]): "optional_args": {"timeout_sec": None}, }, ] - self._check_multiple_tool_calls( message=ai_messages[1], expected_tool_calls=expected_tool_calls ) @@ -393,25 +435,29 @@ def verify_tool_calls(self, response: dict[str, Any]): class GetAllROS2DepthCamerasTask(ROS2ToolCallingAgentTask): complexity = "easy" + topics_and_types: Dict[str, str] = { + "/camera_image_color": "sensor_msgs/msg/Image", + "/camera_image_depth": "sensor_msgs/msg/Image", + "/clock": "rosgraph_msgs/msg/Clock", + "/collision_object": "moveit_msgs/msg/CollisionObject", + "/color_camera_info": "sensor_msgs/msg/CameraInfo", + "/color_camera_info5": "sensor_msgs/msg/CameraInfo", + "/color_image5": "sensor_msgs/msg/Image", + "/depth_camera_info5": "sensor_msgs/msg/CameraInfo", + "/depth_image5": "sensor_msgs/msg/Image", + } + def __init__(self, logger: loggers_type | None = None) -> None: super().__init__(logger=logger) + topic_strings = [ + f"topic: {topic}\ntype: {msg_type}\n" + for topic, msg_type in 
self.topics_and_types.items() + ] self.expected_tools: List[BaseTool] = [ MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n", - "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n", - "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", - "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", - "topic: /color_camera_info\ntype: sensor_msgs/msg/CameraInfo\n", - "topic: /color_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n", - "topic: /color_image5\ntype: sensor_msgs/msg/Image\n", - "topic: /depth_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n", - "topic: /depth_image5\ntype: sensor_msgs/msg/Image\n", - ] - ), - MockGetROS2ImageTool( - expected_topics=["/camera_image_depth", "/depth_image5"] + mock_topics_names_and_types=topic_strings ), + MockGetROS2ImageTool(available_topics=list(self.topics_and_types.keys())), ] def get_prompt(self) -> str: @@ -460,7 +506,6 @@ def verify_tool_calls(self, response: dict[str, Any]): "optional_args": {"timeout_sec": None}, }, ] - self._check_multiple_tool_calls( message=ai_messages[1], expected_tool_calls=expected_tool_calls ) @@ -471,23 +516,33 @@ def verify_tool_calls(self, response: dict[str, Any]): class GetROS2MessageTask(ROS2ToolCallingAgentTask): complexity = "easy" + topics_and_types: Dict[str, str] = { + "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject", + "/camera_image_color": "sensor_msgs/msg/Image", + "/camera_image_depth": "sensor_msgs/msg/Image", + "/clock": "rosgraph_msgs/msg/Clock", + "/collision_object": "moveit_msgs/msg/CollisionObject", + "/color_camera_info": "sensor_msgs/msg/CameraInfo", + "/color_camera_info5": "sensor_msgs/msg/CameraInfo", + "/depth_camera_info5": "sensor_msgs/msg/CameraInfo", + "/depth_image5": "sensor_msgs/msg/Image", + } + def __init__(self, logger: loggers_type | None = None) -> None: super().__init__(logger=logger) + + topic_strings = [ + f"topic: {topic}\ntype: 
{msg_type}\n" + for topic, msg_type in self.topics_and_types.items() + ] + self.expected_tools: List[BaseTool] = [ MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", - "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n", - "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n", - "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", - "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", - "topic: /color_camera_info\ntype: sensor_msgs/msg/CameraInfo\n", - "topic: /color_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n", - "topic: /depth_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n", - "topic: /depth_image5\ntype: sensor_msgs/msg/Image\n", - ] + mock_topics_names_and_types=topic_strings + ), + MockReceiveROS2MessageTool( + available_topics=list(self.topics_and_types.keys()) ), - MockReceiveROS2MessageTool(expected_topics=["/camera_image_color"]), ] def get_system_prompt(self) -> str: @@ -498,7 +553,7 @@ def get_prompt(self) -> str: def verify_tool_calls(self, response: dict[str, Any]): """It is expected that the agent will request: - 1. The tool that retrieves the ROS2 topics names and types to recognize the RGB image topic + 1. The tool that retrieves the ROS2 topics names and types to recognize the RGB image topic. 2. The tool that retrieves the RGB image from the /camera_image_color topic Parameters @@ -515,6 +570,7 @@ def verify_tool_calls(self, response: dict[str, Any]): self.log_error( msg=f"Expected at least 3 AI messages, but got {len(ai_messages)}." ) + if ai_messages: if not self._is_ai_message_requesting_get_ros2_topics_and_types( ai_messages[0] @@ -522,6 +578,7 @@ def verify_tool_calls(self, response: dict[str, Any]): self.log_error( msg="First AI message did not request ROS2 topics and types correctly." 
) + if len(ai_messages) > 1: if self._check_tool_calls_num_in_ai_message(ai_messages[1], expected_num=1): self._check_tool_call( @@ -529,6 +586,7 @@ def verify_tool_calls(self, response: dict[str, Any]): expected_name="receive_ros2_message", expected_args={"topic": "/camera_image_color"}, ) + if not self.result.errors: self.result.success = True @@ -536,20 +594,30 @@ def verify_tool_calls(self, response: dict[str, Any]): class GetRobotDescriptionTask(ROS2ToolCallingAgentTask): complexity = "easy" + topics_and_types: Dict[str, str] = { + "/pointcloud": "sensor_msgs/msg/PointCloud2", + "/robot_description": "std_msgs/msg/String", + "/rosout": "rcl_interfaces/msg/Log", + "/tf": "tf2_msgs/msg/TFMessage", + "/tf_static": "tf2_msgs/msg/TFMessage", + "/trajectory_execution_event": "std_msgs/msg/String", + } + def __init__(self, logger: loggers_type | None = None) -> None: super().__init__(logger=logger) + + topic_strings = [ + f"topic: {topic}\ntype: {msg_type}\n" + for topic, msg_type in self.topics_and_types.items() + ] + self.expected_tools: List[BaseTool] = [ MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /pointcloud\ntype: sensor_msgs/msg/PointCloud2\n", - "topic: /robot_description\ntype: std_msgs/msg/String\n", - "topic: /rosout\ntype: rcl_interfaces/msg/Log\n", - "topic: /tf\ntype: tf2_msgs/msg/TFMessage\n", - "topic: /tf_static\ntype: tf2_msgs/msg/TFMessage\n", - "topic: /trajectory_execution_event\ntype: std_msgs/msg/String\n", - ] + mock_topics_names_and_types=topic_strings + ), + MockReceiveROS2MessageTool( + available_topics=list(self.topics_and_types.keys()) ), - MockReceiveROS2MessageTool(expected_topics=["/robot_description"]), ] def get_system_prompt(self) -> str: @@ -591,6 +659,7 @@ def verify_tool_calls(self, response: dict[str, Any]): expected_name="receive_ros2_message", expected_args={"topic": "/robot_description"}, ) + if not self.result.errors: self.result.success = True @@ -598,20 +667,30 @@ def 
verify_tool_calls(self, response: dict[str, Any]): class GetPointcloudTask(ROS2ToolCallingAgentTask): complexity = "easy" + topics_and_types: Dict[str, str] = { + "/pointcloud": "sensor_msgs/msg/PointCloud2", + "/robot_description": "std_msgs/msg/String", + "/rosout": "rcl_interfaces/msg/Log", + "/tf": "tf2_msgs/msg/TFMessage", + "/tf_static": "tf2_msgs/msg/TFMessage", + "/trajectory_execution_event": "std_msgs/msg/String", + } + def __init__(self, logger: loggers_type | None = None) -> None: super().__init__(logger=logger) + + topic_strings = [ + f"topic: {topic}\ntype: {msg_type}\n" + for topic, msg_type in self.topics_and_types.items() + ] + self.expected_tools: List[BaseTool] = [ MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /pointcloud\ntype: sensor_msgs/msg/PointCloud2\n", - "topic: /robot_description\ntype: std_msgs/msg/String\n", - "topic: /rosout\ntype: rcl_interfaces/msg/Log\n", - "topic: /tf\ntype: tf2_msgs/msg/TFMessage\n", - "topic: /tf_static\ntype: tf2_msgs/msg/TFMessage\n", - "topic: /trajectory_execution_event\ntype: std_msgs/msg/String\n", - ] + mock_topics_names_and_types=topic_strings + ), + MockReceiveROS2MessageTool( + available_topics=list(self.topics_and_types.keys()) ), - MockReceiveROS2MessageTool(expected_topics=["/pointcloud"]), ] def get_system_prompt(self) -> str: @@ -639,6 +718,7 @@ def verify_tool_calls(self, response: dict[str, Any]): self.log_error( msg=f"Expected at least 3 AI messages, but got {len(ai_messages)}." ) + if ai_messages: if not self._is_ai_message_requesting_get_ros2_topics_and_types( ai_messages[0] @@ -646,6 +726,7 @@ def verify_tool_calls(self, response: dict[str, Any]): self.log_error( msg="First AI message did not request ROS2 topics and types correctly." 
) + if len(ai_messages) > 1: if self._check_tool_calls_num_in_ai_message(ai_messages[1], expected_num=1): self._check_tool_call( @@ -653,6 +734,7 @@ def verify_tool_calls(self, response: dict[str, Any]): expected_name="receive_ros2_message", expected_args={"topic": "/pointcloud"}, ) + if not self.result.errors: self.result.success = True @@ -660,21 +742,30 @@ def verify_tool_calls(self, response: dict[str, Any]): class MoveToPointTask(ROS2ToolCallingAgentTask): complexity = "easy" + topics_and_types: Dict[str, str] = { + "/pointcloud": "sensor_msgs/msg/PointCloud2", + "/robot_description": "std_msgs/msg/String", + "/rosout": "rcl_interfaces/msg/Log", + "/tf": "tf2_msgs/msg/TFMessage", + } + def __init__( self, args: Dict[str, Any], logger: loggers_type | None = None ) -> None: super().__init__(logger=logger) + + topic_strings = [ + f"topic: {topic}\ntype: {msg_type}\n" + for topic, msg_type in self.topics_and_types.items() + ] + self.expected_tools: List[BaseTool] = [ MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /pointcloud\ntype: sensor_msgs/msg/PointCloud2\n", - "topic: /robot_description\ntype: std_msgs/msg/String\n", - "topic: /rosout\ntype: rcl_interfaces/msg/Log\n", - "topic: /tf\ntype: tf2_msgs/msg/TFMessage\n", - ] + mock_topics_names_and_types=topic_strings ), MockMoveToPointTool(manipulator_frame="base_link"), ] + self.args = MoveToPointToolInput(**args) def get_system_prompt(self) -> str: @@ -725,6 +816,13 @@ def verify_tool_calls(self, response: dict[str, Any]): class GetObjectPositionsTask(ROS2ToolCallingAgentTask): complexity = "easy" + topics_and_types: Dict[str, str] = { + "/pointcloud": "sensor_msgs/msg/PointCloud2", + "/robot_description": "std_msgs/msg/String", + "/rosout": "rcl_interfaces/msg/Log", + "/tf": "tf2_msgs/msg/TFMessage", + } + def __init__( self, objects: Dict[str, List[dict[str, float]]], @@ -747,17 +845,19 @@ def __init__( } """ super().__init__(logger=logger) + + topic_strings = [ + f"topic: 
{topic}\ntype: {msg_type}\n" + for topic, msg_type in self.topics_and_types.items() + ] + self.expected_tools: List[BaseTool] = [ MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /pointcloud\ntype: sensor_msgs/msg/PointCloud2\n", - "topic: /robot_description\ntype: std_msgs/msg/String\n", - "topic: /rosout\ntype: rcl_interfaces/msg/Log\n", - "topic: /tf\ntype: tf2_msgs/msg/TFMessage\n", - ] + mock_topics_names_and_types=topic_strings ), MockGetObjectPositionsTool(mock_objects=objects), ] + self.objects = objects def get_system_prompt(self) -> str: @@ -809,6 +909,7 @@ def verify_tool_calls(self, response: dict[str, Any]): for object_type in self.objects ], ) + if not self.result.errors: self.result.success = True @@ -835,6 +936,14 @@ class GrabExistingObjectTask(ROS2ToolCallingAgentTask): } object_to_grab = "cube" """ + topics_and_types: Dict[str, str] = { + "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject", + "/camera_image_color": "sensor_msgs/msg/Image", + "/camera_image_depth": "sensor_msgs/msg/Image", + "/clock": "rosgraph_msgs/msg/Clock", + "/collision_object": "moveit_msgs/msg/CollisionObject", + "/color_camera_info": "sensor_msgs/msg/CameraInfo", + } def __init__( self, @@ -843,16 +952,15 @@ def __init__( logger: loggers_type | None = None, ) -> None: super().__init__(logger=logger) + + topic_strings = [ + f"topic: {topic}\ntype: {msg_type}\n" + for topic, msg_type in self.topics_and_types.items() + ] + self.expected_tools: List[BaseTool] = [ MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", - "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n", - "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n", - "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", - "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", - "topic: /color_camera_info\ntype: sensor_msgs/msg/CameraInfo\n", - 
] + mock_topics_names_and_types=topic_strings ), MockGetObjectPositionsTool( target_frame="panda_link0", @@ -864,6 +972,7 @@ def __init__( ), MockMoveToPointTool(manipulator_frame="panda_link0"), ] + self.objects = objects self.object_to_grab = object_to_grab self._verify_args() @@ -905,6 +1014,7 @@ def verify_tool_calls(self, response: Dict[str, Any]): ai_messages: Sequence[AIMessage] = [ message for message in messages if isinstance(message, AIMessage) ] + expected_num_ai_messages = 3 if len(ai_messages) != expected_num_ai_messages: self.log_error( @@ -930,6 +1040,7 @@ def verify_tool_calls(self, response: Dict[str, Any]): expected_name="move_to_point", expected_args=obj_to_grab, ) + if not self.result.errors: self.result.success = True @@ -957,6 +1068,15 @@ class GrabNotExistingObjectTask(ROS2ToolCallingAgentTask): complexity = "medium" + topics_and_types: Dict[str, str] = { + "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject", + "/camera_image_color": "sensor_msgs/msg/Image", + "/camera_image_depth": "sensor_msgs/msg/Image", + "/clock": "rosgraph_msgs/msg/Clock", + "/collision_object": "moveit_msgs/msg/CollisionObject", + "/color_camera_info": "sensor_msgs/msg/CameraInfo", + } + def __init__( self, objects: Dict[str, List[dict[str, float]]], @@ -964,16 +1084,15 @@ def __init__( logger: loggers_type | None = None, ) -> None: super().__init__(logger=logger) + + topic_strings = [ + f"topic: {topic}\ntype: {msg_type}\n" + for topic, msg_type in self.topics_and_types.items() + ] + self.expected_tools: List[BaseTool] = [ MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", - "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n", - "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n", - "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", - "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", - "topic: 
/color_camera_info\ntype: sensor_msgs/msg/CameraInfo\n", - ] + mock_topics_names_and_types=topic_strings ), MockGetObjectPositionsTool( target_frame="panda_link0", @@ -985,6 +1104,7 @@ def __init__( ), MockMoveToPointTool(manipulator_frame="panda_link0"), ] + self.objects = objects self.object_to_grab = object_to_grab self._verify_args() @@ -1020,11 +1140,13 @@ def verify_tool_calls(self, response: Dict[str, Any]): ai_messages: Sequence[AIMessage] = [ message for message in messages if isinstance(message, AIMessage) ] + expected_num_ai_messages = 2 if len(ai_messages) != expected_num_ai_messages: self.log_error( msg=f"Expected {expected_num_ai_messages} AI messages, but got {len(ai_messages)}." ) + if ai_messages: if self._check_tool_calls_num_in_ai_message(ai_messages[0], expected_num=1): self._check_tool_call( @@ -1060,6 +1182,14 @@ class MoveExistingObjectLeftTask(ROS2ToolCallingAgentTask): complexity = "medium" + topics_and_types: Dict[str, str] = { + "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject", + "/camera_image_color": "sensor_msgs/msg/Image", + "/camera_image_depth": "sensor_msgs/msg/Image", + "/clock": "rosgraph_msgs/msg/Clock", + "/collision_object": "moveit_msgs/msg/CollisionObject", + } + def __init__( self, objects: Dict[str, List[dict[str, float]]], @@ -1067,15 +1197,15 @@ def __init__( logger: loggers_type | None = None, ) -> None: super().__init__(logger=logger) + + topic_strings = [ + f"topic: {topic}\ntype: {msg_type}\n" + for topic, msg_type in self.topics_and_types.items() + ] + self.expected_tools: List[BaseTool] = [ MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", - "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n", - "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n", - "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", - "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", - ] 
+ mock_topics_names_and_types=topic_strings ), MockGetObjectPositionsTool( target_frame="panda_link0", @@ -1087,6 +1217,7 @@ def __init__( ), MockMoveToPointTool(manipulator_frame="panda_link0"), ] + self.objects = objects self.object_to_grab = object_to_grab self._verify_args() @@ -1130,11 +1261,13 @@ def verify_tool_calls(self, response: Dict[str, Any]): ai_messages: Sequence[AIMessage] = [ message for message in messages if isinstance(message, AIMessage) ] + expected_num_ai_messages = 4 if len(ai_messages) != expected_num_ai_messages: self.log_error( msg=f"Expected {expected_num_ai_messages} AI messages, but got {len(ai_messages)}." ) + if ai_messages: if self._check_tool_calls_num_in_ai_message(ai_messages[0], expected_num=1): self._check_tool_call( @@ -1187,6 +1320,12 @@ class MoveExistingObjectFrontTask(ROS2ToolCallingAgentTask): complexity = "medium" + topics_and_types: Dict[str, str] = { + "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject", + "/camera_image_color": "sensor_msgs/msg/Image", + "/camera_image_depth": "sensor_msgs/msg/Image", + } + def __init__( self, objects: Dict[str, List[dict[str, float]]], @@ -1194,13 +1333,15 @@ def __init__( logger: loggers_type | None = None, ) -> None: super().__init__(logger=logger) + + topic_strings = [ + f"topic: {topic}\ntype: {msg_type}\n" + for topic, msg_type in self.topics_and_types.items() + ] + self.expected_tools: List[BaseTool] = [ MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", - "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n", - "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n", - ] + mock_topics_names_and_types=topic_strings ), MockGetObjectPositionsTool( target_frame="panda_link0", @@ -1212,6 +1353,7 @@ def __init__( ), MockMoveToPointTool(manipulator_frame="panda_link0"), ] + self.objects = objects self.object_to_grab = object_to_grab self._verify_args() 
@@ -1255,6 +1397,7 @@ def verify_tool_calls(self, response: Dict[str, Any]): ai_messages: Sequence[AIMessage] = [ message for message in messages if isinstance(message, AIMessage) ] + expected_num_ai_messages = 4 if len(ai_messages) != expected_num_ai_messages: self.log_error( @@ -1323,6 +1466,12 @@ class SwapObjectsTask(ROS2ToolCallingAgentTask): complexity = "hard" + topics_and_types: Dict[str, str] = { + "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject", + "/camera_image_color": "sensor_msgs/msg/Image", + "/camera_image_depth": "sensor_msgs/msg/Image", + } + def __init__( self, objects: Dict[str, List[Dict[str, float]]], @@ -1330,13 +1479,15 @@ def __init__( logger: loggers_type | None = None, ) -> None: super().__init__(logger=logger) + + topic_strings = [ + f"topic: {topic}\ntype: {msg_type}\n" + for topic, msg_type in self.topics_and_types.items() + ] + self.expected_tools: List[BaseTool] = [ MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", - "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n", - "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n", - ] + mock_topics_names_and_types=topic_strings ), MockGetObjectPositionsTool( target_frame="panda_link0", @@ -1348,6 +1499,7 @@ def __init__( ), MockMoveToPointTool(manipulator_frame="panda_link0"), ] + self.objects = objects self.objects_to_swap = objects_to_swap self._verify_args() @@ -1392,7 +1544,6 @@ def verify_tool_calls(self, response: Dict[str, Any]): response : Dict[str, Any] The response from the agent """ - messages = response["messages"] ai_messages: Sequence[AIMessage] = [ message for message in messages if isinstance(message, AIMessage) @@ -1512,3 +1663,1125 @@ def _matches_sequence( any(call["name"] == e["name"] and call["args"] == e["args"] for call in it) for e in expected_tool_calls_seq ) + + +class 
PublishROS2HRIMessageTask3ExtraCalls(CustomInterfacesTopicTask): + complexity = "easy" + expected_text = "Hello!" + + def get_system_prompt(self) -> str: + return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT + + @property + def expected_topic(self) -> str: + return "/to_human" + + @property + def extra_calls(self) -> int: + return 3 + + def verify_message_tool_call(self, tool_calls: list[ToolCall]) -> bool: + for call in tool_calls: + if self._check_topic_tool_call_field( + tool_call=call, + expected_name="publish_ros2_message", + expected_topic=self.expected_topic, + expected_message_type=self.expected_message_type, + field_path="text", + expected_value=self.expected_text, + ): + return True + + self.log_error(f"No valid call to {self.expected_topic} found.") + return False + + def get_prompt(self) -> str: + return ( + f"You need to publish a message to the topic '{self.expected_topic}' with the text value: '{self.expected_text}'.\n" + "Before publishing, follow these steps:\n" + "1. Use the tool to retrieve the available ROS2 topics and their message types.\n" + f"2. Find the message type for the topic '{self.expected_topic}'.\n" + "3. Retrieve the full message interface definition for that type.\n" + "4. Construct the message filling only the fields you are instructed to. Rest of the fields will have default values.\n" + f"5. 
Publish the message to '{self.expected_topic}' using the correct message type and interface.\n" + ) + + +class PublishROS2HRIMessageTask1ExtraCall(PublishROS2HRIMessageTask3ExtraCalls): + complexity = "medium" + + @property + def extra_calls(self) -> int: + return 1 + + +class PublishROS2HRIMessageTask0ExtraCalls(PublishROS2HRIMessageTask3ExtraCalls): + complexity = "hard" + + @property + def extra_calls(self) -> int: + return 0 + + +class PublishROS2AudioMessageTask3ExtraCalls(CustomInterfacesTopicTask): + complexity = "easy" + expected_audio: List[int] = [123, 456, 789] + expected_sample_rate: int = 44100 + expected_channels: int = 2 + + def get_system_prompt(self) -> str: + return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT + + @property + def expected_topic(self) -> str: + return "/send_audio" + + @property + def extra_calls(self) -> int: + return 3 + + def verify_message_tool_call(self, tool_calls: list[ToolCall]) -> bool: + for call in tool_calls: + if ( + self._check_topic_tool_call_field( + tool_call=call, + expected_name="publish_ros2_message", + expected_topic=self.expected_topic, + expected_message_type=self.expected_message_type, + field_path="audio", + expected_value=self.expected_audio, + ) + and self._check_topic_tool_call_field( + tool_call=call, + expected_name="publish_ros2_message", + expected_topic=self.expected_topic, + expected_message_type=self.expected_message_type, + field_path="sample_rate", + expected_value=self.expected_sample_rate, + ) + and self._check_topic_tool_call_field( + tool_call=call, + expected_name="publish_ros2_message", + expected_topic=self.expected_topic, + expected_message_type=self.expected_message_type, + field_path="channels", + expected_value=self.expected_channels, + ) + ): + return True + + self.log_error(f"No valid call to {self.expected_topic} found.") + return False + + def get_prompt(self) -> str: + return ( + f"You need to publish a message to the topic '{self.expected_topic}' with audio samples {self.expected_audio}, 
" + f"sample rate {self.expected_sample_rate}, and {self.expected_channels} channels.\n" + "Before publishing, follow these steps:\n" + "1. Use the tool to retrieve the available ROS2 topics and their message types.\n" + f"2. Find the message type for the topic '{self.expected_topic}'.\n" + "3. Retrieve the full message interface definition for that type.\n" + "4. Construct the message filling only the fields you are instructed to. Rest of the fields will have default values.\n" + f"5. Publish the message to '{self.expected_topic}' using the correct message type and interface.\n" + ) + + +class PublishROS2AudioMessageTask1ExtraCall(PublishROS2AudioMessageTask3ExtraCalls): + complexity = "medium" + + @property + def extra_calls(self) -> int: + return 1 + + +class PublishROS2AudioMessageTask0ExtraCalls(PublishROS2AudioMessageTask3ExtraCalls): + complexity = "hard" + + @property + def extra_calls(self) -> int: + return 0 + + +class PublishROS2DetectionArrayTask3ExtraCalls(CustomInterfacesTopicTask): + complexity = "easy" + + expected_detection_classes: List[str] = ["person", "car"] + expected_detections: List[Detection2D] = [ + Detection2D( + bbox=BoundingBox2D( + center=Pose2D(position=Point2D(x=320.0, y=240.0), theta=0.0), + size_x=50.0, + size_y=50.0, + ) + ) + ] + + def get_system_prompt(self) -> str: + return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT + + @property + def expected_topic(self) -> str: + return "/send_detections" + + @property + def expected_message(self) -> RAIDetectionArray: + return RAIDetectionArray( + detections=self.expected_detections, + detection_classes=self.expected_detection_classes, + ) + + @property + def extra_calls(self) -> int: + return 3 + + def verify_message_tool_call(self, tool_calls: list[ToolCall]) -> bool: + for call in tool_calls: + if self._check_topic_tool_call_field( + tool_call=call, + expected_name="publish_ros2_message", + expected_topic=self.expected_topic, + expected_message_type=self.expected_message_type, + 
field_path="detection_classes", + expected_value=self.expected_detection_classes, + ): + return True + + self.log_error(f"No valid call to {self.expected_topic} found.") + return False + + def get_prompt(self) -> str: + return ( + f"You need to publish a detection message to the topic '{self.expected_topic}' with one detection:\n" + f"{self.expected_detections[0].model_dump()} and detection classes {self.expected_detection_classes}.\n" + "Before publishing, follow these steps:\n" + "1. Use the tool to retrieve the available ROS2 topics and their message types.\n" + f"2. Find the message type for the topic '{self.expected_topic}'.\n" + "3. Retrieve the full message interface definition for that type.\n" + "4. Construct the message filling only the fields you are instructed to. Rest of the fields will have default values.\n" + f"5. Publish the message to '{self.expected_topic}' using the correct message type and interface.\n" + ) + + +class PublishROS2DetectionArrayTask1ExtraCall(PublishROS2DetectionArrayTask3ExtraCalls): + complexity = "medium" + + @property + def extra_calls(self) -> int: + return 1 + + +class PublishROS2DetectionArrayTask0ExtraCalls( + PublishROS2DetectionArrayTask3ExtraCalls +): + complexity = "hard" + + @property + def extra_calls(self) -> int: + return 0 + + +class CallROS2ManipulatorMoveToServiceTask3ExtraCalls(CustomInterfacesServiceTask): + complexity = "easy" + + expected_initial_gripper_state = True + expected_final_gripper_state = False + expected_target_pose: PoseStamped = PoseStamped( + pose=Pose( + position=Position(x=1.0, y=2.0, z=3.0), + orientation=Orientation(x=0.0, y=0.0, z=0.0, w=1.0), + ) + ) + + def get_system_prompt(self) -> str: + return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT + + @property + def expected_service(self) -> str: + return "/manipulator_move_to" + + @property + def extra_calls(self) -> int: + return 3 + + def verify_message_tool_call(self, tool_calls: list[ToolCall]) -> bool: + for call in tool_calls: + if ( + 
self._check_service_tool_call_field( + tool_call=call, + expected_name="call_ros2_service", + expected_service=self.expected_service, + expected_service_type=self.expected_service_type, + field_path="initial_gripper_state", + expected_value=self.expected_initial_gripper_state, + ) + and self._check_service_tool_call_field( + tool_call=call, + expected_name="call_ros2_service", + expected_service=self.expected_service, + expected_service_type=self.expected_service_type, + field_path="final_gripper_state", + expected_value=self.expected_final_gripper_state, + ) + and self._check_service_tool_call_field( + tool_call=call, + expected_name="call_ros2_service", + expected_service=self.expected_service, + expected_service_type=self.expected_service_type, + field_path="target_pose", + expected_value=self.expected_target_pose.model_dump(), + ) + ): + return True + + self.log_error(f"No valid call to {self.expected_service} found.") + return False + + def get_prompt(self) -> str: + return ( + f"You need to call the service '{self.expected_service}' with a target_pose: " + f"{self.expected_target_pose.model_dump()} and gripper states (initial: {self.expected_initial_gripper_state}, final: {self.expected_final_gripper_state}).\n" + "Before calling, follow these steps:\n" + "1. Use the tool to retrieve the available ROS2 services and their types.\n" + f"2. Find the service type for '{self.expected_service}'.\n" + "3. Retrieve the full message interface definition for that service.\n" + "4. Construct the request message filling only the fields you are instructed to. Rest of the fields will have default values.\n" + f"5. 
Call the service '{self.expected_service}' using the correct message type and interface.\n" + ) + + +class CallROS2ManipulatorMoveToServiceTask1ExtraCall( + CallROS2ManipulatorMoveToServiceTask3ExtraCalls +): + complexity = "medium" + + @property + def extra_calls(self) -> int: + return 1 + + +class CallROS2ManipulatorMoveToServiceTask0ExtraCalls( + CallROS2ManipulatorMoveToServiceTask3ExtraCalls +): + complexity = "hard" + + @property + def extra_calls(self) -> int: + return 0 + + +class CallGroundedSAMSegmentTask3ExtraCalls(CustomInterfacesServiceTask): + complexity = "easy" + + expected_detections: RAIDetectionArray = RAIDetectionArray( + header=Header(stamp=Time(sec=0, nanosec=0), frame_id="camera_frame"), + detections=[], + ) + + def get_system_prompt(self) -> str: + return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT + + @property + def expected_service(self) -> str: + return "/grounded_sam_segment" + + @property + def extra_calls(self) -> int: + return 3 + + def verify_message_tool_call(self, tool_calls: list[ToolCall]) -> bool: + for call in tool_calls: + if self._check_service_tool_call_field( + tool_call=call, + expected_name="call_ros2_service", + expected_service=self.expected_service, + expected_service_type=self.expected_service_type, + field_path="detections", + expected_value=self.expected_detections.model_dump(), + ): + return True + self.log_error(f"No valid call to {self.expected_service} found.") + return False + + def get_prompt(self) -> str: + return ( + f"You need to call the service '{self.expected_service}' with detections: {self.expected_detections.model_dump()}\n" + "Before calling, follow these steps:\n" + "1. Use the tool to retrieve the available ROS2 services and their types.\n" + f"2. Find the service type for '{self.expected_service}'.\n" + "3. Retrieve the full message interface definition for that service.\n" + "4. Construct the request message filling only the fields you are instructed to. 
Rest of the fields will have default values.\n" + f"5. Call the service '{self.expected_service}' using the correct message type and interface.\n" + ) + + +class CallGroundedSAMSegmentTask1ExtraCall(CallGroundedSAMSegmentTask3ExtraCalls): + complexity = "medium" + + @property + def extra_calls(self) -> int: + return 1 + + +class CallGroundedSAMSegmentTask0ExtraCalls(CallGroundedSAMSegmentTask3ExtraCalls): + complexity = "hard" + + @property + def extra_calls(self) -> int: + return 0 + + +class CallGroundingDinoClassifyTask3ExtraCalls(CustomInterfacesServiceTask): + complexity = "easy" + + expected_classes: str = "bottle, book, chair" + expected_box_threshold: float = 0.4 + expected_text_threshold: float = 0.25 + + def get_system_prompt(self) -> str: + return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT + + @property + def extra_calls(self) -> int: + return 3 + + @property + def expected_service(self) -> str: + return "/grounding_dino_classify" + + def verify_message_tool_call(self, tool_calls: list[ToolCall]) -> bool: + for call in tool_calls: + if ( + self._check_service_tool_call_field( + call, + "call_ros2_service", + self.expected_service, + self.expected_service_type, + "classes", + self.expected_classes, + ) + and self._check_service_tool_call_field( + call, + "call_ros2_service", + self.expected_service, + self.expected_service_type, + "box_threshold", + self.expected_box_threshold, + ) + and self._check_service_tool_call_field( + call, + "call_ros2_service", + self.expected_service, + self.expected_service_type, + "text_threshold", + self.expected_text_threshold, + ) + ): + return True + + self.log_error(f"No valid call to {self.expected_service} found.") + return False + + def get_prompt(self) -> str: + return ( + f"You need to call the service '{self.expected_service}' with classes: '{self.expected_classes}', " + f"box_threshold: {self.expected_box_threshold}, text_threshold: {self.expected_text_threshold}, " + "Before calling, follow these steps:\n" + "1. 
Use the tool to retrieve the available ROS2 services and their types.\n" + f"2. Find the service type for '{self.expected_service}'.\n" + "3. Retrieve the full message interface definition for that service.\n" + "4. Construct the request message filling only the fields you are instructed to. Rest of the fields will have default values.\n" + f"5. Call the service '{self.expected_service}' using the correct message type and interface.\n" + ) + + +class CallGroundingDinoClassifyTask1ExtraCall(CallGroundingDinoClassifyTask3ExtraCalls): + complexity = "medium" + + @property + def extra_calls(self) -> int: + return 1 + + +class CallGroundingDinoClassifyTask0ExtraCalls( + CallGroundingDinoClassifyTask3ExtraCalls +): + complexity = "hard" + + @property + def extra_calls(self) -> int: + return 0 + + +class CallGetLogDigestTask3ExtraCalls(CustomInterfacesServiceTask): + complexity = "easy" + + def get_system_prompt(self) -> str: + return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT + + @property + def expected_service(self) -> str: + return "/get_log_digest" + + @property + def extra_calls(self) -> int: + return 3 + + def verify_message_tool_call(self, tool_calls: list[ToolCall]) -> bool: + for call in tool_calls: + if self._check_service_tool_call_field( + call, + "call_ros2_service", + self.expected_service, + self.expected_service_type, + field_path="", # empty request + expected_value="", + ): + return True + + self.log_error(f"No valid call to {self.expected_service} found.") + return False + + def get_prompt(self) -> str: + return ( + f"You need to call the service '{self.expected_service}' with an empty request.\n" + "Before calling, follow these steps:\n" + "1. Use the tool to retrieve the available ROS2 services and their types.\n" + f"2. Find the service type for '{self.expected_service}'.\n" + "3. Retrieve the full message interface definition for that service.\n" + "4. Construct the message filling only the fields you are instructed to. 
Rest of the fields will have default values.\n" + f"5. Call the service '{self.expected_service}' using the correct message type and interface.\n" + ) + + +class CallGetLogDigestTask1ExtraCall(CallGetLogDigestTask3ExtraCalls): + complexity = "medium" + + @property + def extra_calls(self) -> int: + return 1 + + +class CallGetLogDigestTask0ExtraCalls(CallGetLogDigestTask3ExtraCalls): + complexity = "hard" + + @property + def extra_calls(self) -> int: + return 0 + + +class CallVectorStoreRetrievalTask3ExtraCalls(CustomInterfacesServiceTask): + complexity = "easy" + + expected_query: str = "What is the purpose of this robot?" + + def get_system_prompt(self) -> str: + return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT + + @property + def expected_service(self) -> str: + return "/rai_whoami_documentation_service" + + @property + def extra_calls(self) -> int: + return 3 + + def verify_message_tool_call(self, tool_calls: list[ToolCall]) -> bool: + for call in tool_calls: + if self._check_service_tool_call_field( + call, + "call_ros2_service", + self.expected_service, + self.expected_service_type, + "query", + self.expected_query, + ): + return True + + self.log_error(f"No valid call to {self.expected_service} found.") + return False + + def get_prompt(self) -> str: + return ( + f"You need to call the service '{self.expected_service}' with the query: '{self.expected_query}'.\n" + "Before calling, follow these steps:\n" + "1. Use the tool to retrieve the available ROS2 services and their types.\n" + f"2. Find the service type for '{self.expected_service}'.\n" + "3. Retrieve the full message interface definition for that service.\n" + "4. Construct the request message filling only the fields you are instructed to. Rest of the fields will have default values.\n" + f"5. 
Call the service '{self.expected_service}' using the correct message type and interface.\n" + ) + + +class CallVectorStoreRetrievalTask1ExtraCall(CallVectorStoreRetrievalTask3ExtraCalls): + complexity = "medium" + + @property + def extra_calls(self) -> int: + return 1 + + +class CallVectorStoreRetrievalTask0ExtraCalls(CallVectorStoreRetrievalTask3ExtraCalls): + complexity = "hard" + + @property + def extra_calls(self) -> int: + return 0 + + +class CallWhatISeeTask3ExtraCalls(CustomInterfacesServiceTask): + complexity = "easy" + + expected_observations: List[str] = ["table", "cup", "notebook"] + expected_perception_source: str = "front_camera" + + expected_image: Image = Image( + header=Header(frame_id="camera_frame"), + height=480, + width=640, + ) + + expected_pose: Pose = Pose( + position=Position(x=1.0, y=2.0, z=0.5), + orientation=Orientation(x=0.0, y=0.0, z=0.0, w=1.0), + ) + + def get_system_prompt(self) -> str: + return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT + + @property + def expected_service(self) -> str: + return "/rai/whatisee/get" + + @property + def extra_calls(self) -> int: + return 3 + + def verify_message_tool_call(self, tool_calls: list[ToolCall]) -> bool: + for call in tool_calls: + if self._check_service_tool_call_field( + call, + "call_ros2_service", + self.expected_service, + self.expected_service_type, + field_path="", # empty request + expected_value="", + ): + return True + + self.log_error(f"No valid call to {self.expected_service} found.") + return False + + def get_prompt(self) -> str: + return ( + f"You need to call the service '{self.expected_service}' with an empty request.\n" + "Before calling, follow these steps:\n" + "1. Use the tool to retrieve the available ROS2 services and their types.\n" + f"2. Find the service type for '{self.expected_service}'.\n" + "3. Retrieve the full message interface definition for that service.\n" + "4. Construct the request message filling only the fields you are instructed to. 
Rest of the fields will have default values.\n" + f"5. Call the service '{self.expected_service}' using the correct message type and interface.\n" + ) + + +class CallWhatISeeTask1ExtraCall(CallWhatISeeTask3ExtraCalls): + complexity = "medium" + + @property + def extra_calls(self) -> int: + return 1 + + +class CallWhatISeeTask0ExtraCalls(CallWhatISeeTask3ExtraCalls): + complexity = "hard" + + @property + def extra_calls(self) -> int: + return 0 + + +# class CallROS2CustomActionTask(CustomInterfacesActionTask): +# complexity = "easy" + +# expected_task = "Where are you?" +# expected_description = "" +# expected_priority = "10" + +# def get_system_prompt(self) -> str: +# return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT + +# @property +# def expected_action(self) -> str: +# return "/perform_task" + +# @property +# def expected_message(self) -> Dict[str, Any]: +# expected = DEFAULT_MESSAGES[self.expected_action_type].copy() +# expected["goal"]["task"] = self.expected_task +# expected["goal"]["description"] = self.expected_description +# expected["goal"]["priority "] = self.expected_priority +# return expected + + +ROBOT_NAVIGATION_SYSTEM_PROMPT = """You are an autonomous robot connected to ros2 environment. Your main goal is to fulfill the user's requests. + Do not make assumptions about the environment you are currently in. + You can use ros2 topics, services and actions to operate. + + As a first step check transforms by getting 1 message from /tf topic + use /cmd_vel topic very carefully. Obstacle detection works only with nav2 stack, so be careful when it is not used. > + be patient with running ros2 actions. usually the take some time to run. + Always check your transform before and after you perform ros2 actions, so that you can verify if it worked. + + Navigation tips: + - it's good to start finding objects by rotating, then navigating to some diverse location with occasional rotations. Remember to frequency detect objects. 
+ - for driving forward/backward or to some coordinates, ros2 actions are better. + - for driving for some specific time or in specific manner (like shaper or turns) it good to use /cmd_vel topic + - you are currently unable to read map or point-cloud, so please avoid subscribing to such topics. + - if you are asked to drive towards some object, it's good to: + 1. check the camera image and verify if objects can be seen + 2. if only driving forward is required, do it + 3. if obstacle avoidance might be required, use ros2 actions navigate_*, but first check your current position, then very accurately estimate the goal pose. + - it is good to verify using given information if the robot is not stuck + - navigation actions sometimes fail. Their output can be read from rosout. You can also tell if they partially worked by checking the robot position and rotation. + - before using any ros2 interfaces, always make sure to check you are using the right interface + - processing camera image takes 5-10s. Take it into account that if the robot is moving, the information can be outdated. Handle it by good planning of your movements. + - you are encouraged to use wait tool in between checking the status of actions + - to find some object navigate around and check the surrounding area + - when the goal is accomplished please make sure to cancel running actions + - when you reach the navigation goal - double check if you reached it by checking the current position + - if you detect collision, please stop operation + + - you will be given your camera image description. Based on this information you can reason about positions of objects. 
+ - be careful and aboid obstacles + + Here are the corners of your environment: + (-2.76,9.04, 0.0), + (4.62, 9.07, 0.0), + (-2.79, -3.83, 0.0), + (4.59, -3.81, 0.0) + + This is location of places: + Kitchen: + (2.06, -0.23, 0.0), + (2.07, -1.43, 0.0), + (-2.44, -0.38, 0.0), + (-2.56, -1.47, 0.0) + + # Living room: + (-2.49, 1.87, 0.0), + (-2.50, 5.49, 0.0), + (0.79, 5.73, 0.0), + (0.92, 1.01, 0.0) + + Before starting anything, make sure to load available topics, services and actions. + """ +NAVIGATION_SERVICES_AND_TYPES: Dict[str, str] = { + "/assisted_teleop/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/assisted_teleop/_action/get_result": "nav2_msgs/action/AssistedTeleop_GetResult", + "/assisted_teleop/_action/send_goal": "nav2_msgs/action/AssistedTeleop_SendGoal", + "/backup/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/backup/_action/get_result": "nav2_msgs/action/BackUp_GetResult", + "/backup/_action/send_goal": "nav2_msgs/action/BackUp_SendGoal", + "/behavior_server/change_state": "lifecycle_msgs/srv/ChangeState", + "/behavior_server/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/behavior_server/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", + "/behavior_server/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", + "/behavior_server/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/behavior_server/get_parameters": "rcl_interfaces/srv/GetParameters", + "/behavior_server/get_state": "lifecycle_msgs/srv/GetState", + "/behavior_server/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", + "/behavior_server/list_parameters": "rcl_interfaces/srv/ListParameters", + "/behavior_server/set_parameters": "rcl_interfaces/srv/SetParameters", + "/behavior_server/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/bt_navigator/change_state": "lifecycle_msgs/srv/ChangeState", + "/bt_navigator/describe_parameters": 
"rcl_interfaces/srv/DescribeParameters", + "/bt_navigator/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", + "/bt_navigator/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", + "/bt_navigator/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/bt_navigator/get_parameters": "rcl_interfaces/srv/GetParameters", + "/bt_navigator/get_state": "lifecycle_msgs/srv/GetState", + "/bt_navigator/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", + "/bt_navigator/list_parameters": "rcl_interfaces/srv/ListParameters", + "/bt_navigator/set_parameters": "rcl_interfaces/srv/SetParameters", + "/bt_navigator/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/bt_navigator_navigate_through_poses_rclcpp_node/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/bt_navigator_navigate_through_poses_rclcpp_node/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/bt_navigator_navigate_through_poses_rclcpp_node/get_parameters": "rcl_interfaces/srv/GetParameters", + "/bt_navigator_navigate_through_poses_rclcpp_node/list_parameters": "rcl_interfaces/srv/ListParameters", + "/bt_navigator_navigate_through_poses_rclcpp_node/set_parameters": "rcl_interfaces/srv/SetParameters", + "/bt_navigator_navigate_through_poses_rclcpp_node/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/bt_navigator_navigate_to_pose_rclcpp_node/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/bt_navigator_navigate_to_pose_rclcpp_node/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/bt_navigator_navigate_to_pose_rclcpp_node/get_parameters": "rcl_interfaces/srv/GetParameters", + "/bt_navigator_navigate_to_pose_rclcpp_node/list_parameters": "rcl_interfaces/srv/ListParameters", + "/bt_navigator_navigate_to_pose_rclcpp_node/set_parameters": "rcl_interfaces/srv/SetParameters", + 
"/bt_navigator_navigate_to_pose_rclcpp_node/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/compute_path_through_poses/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/compute_path_through_poses/_action/get_result": "nav2_msgs/action/ComputePathThroughPoses_GetResult", + "/compute_path_through_poses/_action/send_goal": "nav2_msgs/action/ComputePathThroughPoses_SendGoal", + "/compute_path_to_pose/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/compute_path_to_pose/_action/get_result": "nav2_msgs/action/ComputePathToPose_GetResult", + "/compute_path_to_pose/_action/send_goal": "nav2_msgs/action/ComputePathToPose_SendGoal", + "/controller_server/change_state": "lifecycle_msgs/srv/ChangeState", + "/controller_server/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/controller_server/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", + "/controller_server/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", + "/controller_server/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/controller_server/get_parameters": "rcl_interfaces/srv/GetParameters", + "/controller_server/get_state": "lifecycle_msgs/srv/GetState", + "/controller_server/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", + "/controller_server/list_parameters": "rcl_interfaces/srv/ListParameters", + "/controller_server/set_parameters": "rcl_interfaces/srv/SetParameters", + "/controller_server/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/drive_on_heading/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/drive_on_heading/_action/get_result": "nav2_msgs/action/DriveOnHeading_GetResult", + "/drive_on_heading/_action/send_goal": "nav2_msgs/action/DriveOnHeading_SendGoal", + "/follow_path/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/follow_path/_action/get_result": "nav2_msgs/action/FollowPath_GetResult", + 
"/follow_path/_action/send_goal": "nav2_msgs/action/FollowPath_SendGoal", + "/follow_waypoints/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/follow_waypoints/_action/get_result": "nav2_msgs/action/FollowWaypoints_GetResult", + "/follow_waypoints/_action/send_goal": "nav2_msgs/action/FollowWaypoints_SendGoal", + "/global_costmap/clear_around_global_costmap": "nav2_msgs/srv/ClearCostmapAroundRobot", + "/global_costmap/clear_entirely_global_costmap": "nav2_msgs/srv/ClearEntireCostmap", + "/global_costmap/clear_except_global_costmap": "nav2_msgs/srv/ClearCostmapExceptRegion", + "/global_costmap/get_costmap": "nav2_msgs/srv/GetCostmap", + "/global_costmap/global_costmap/change_state": "lifecycle_msgs/srv/ChangeState", + "/global_costmap/global_costmap/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/global_costmap/global_costmap/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", + "/global_costmap/global_costmap/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", + "/global_costmap/global_costmap/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/global_costmap/global_costmap/get_parameters": "rcl_interfaces/srv/GetParameters", + "/global_costmap/global_costmap/get_state": "lifecycle_msgs/srv/GetState", + "/global_costmap/global_costmap/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", + "/global_costmap/global_costmap/list_parameters": "rcl_interfaces/srv/ListParameters", + "/global_costmap/global_costmap/set_parameters": "rcl_interfaces/srv/SetParameters", + "/global_costmap/global_costmap/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/grounded_sam/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/grounded_sam/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/grounded_sam/get_parameters": "rcl_interfaces/srv/GetParameters", + "/grounded_sam/list_parameters": "rcl_interfaces/srv/ListParameters", + 
"/grounded_sam/set_parameters": "rcl_interfaces/srv/SetParameters", + "/grounded_sam/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/grounded_sam_segment": "rai_interfaces/srv/RAIGroundedSam", + "/grounding_dino/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/grounding_dino/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/grounding_dino/get_parameters": "rcl_interfaces/srv/GetParameters", + "/grounding_dino/list_parameters": "rcl_interfaces/srv/ListParameters", + "/grounding_dino/set_parameters": "rcl_interfaces/srv/SetParameters", + "/grounding_dino/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/grounding_dino_classify": "rai_interfaces/srv/RAIGroundingDino", + "/is_path_valid": "nav2_msgs/srv/IsPathValid", + "/launch_ros_138640/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/launch_ros_138640/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/launch_ros_138640/get_parameters": "rcl_interfaces/srv/GetParameters", + "/launch_ros_138640/list_parameters": "rcl_interfaces/srv/ListParameters", + "/launch_ros_138640/set_parameters": "rcl_interfaces/srv/SetParameters", + "/launch_ros_138640/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/lifecycle_manager_navigation/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/lifecycle_manager_navigation/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/lifecycle_manager_navigation/get_parameters": "rcl_interfaces/srv/GetParameters", + "/lifecycle_manager_navigation/is_active": "std_srvs/srv/Trigger", + "/lifecycle_manager_navigation/list_parameters": "rcl_interfaces/srv/ListParameters", + "/lifecycle_manager_navigation/manage_nodes": "nav2_msgs/srv/ManageLifecycleNodes", + "/lifecycle_manager_navigation/set_parameters": "rcl_interfaces/srv/SetParameters", + "/lifecycle_manager_navigation/set_parameters_atomically": 
"rcl_interfaces/srv/SetParametersAtomically", + "/lifecycle_manager_slam/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/lifecycle_manager_slam/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/lifecycle_manager_slam/get_parameters": "rcl_interfaces/srv/GetParameters", + "/lifecycle_manager_slam/is_active": "std_srvs/srv/Trigger", + "/lifecycle_manager_slam/list_parameters": "rcl_interfaces/srv/ListParameters", + "/lifecycle_manager_slam/manage_nodes": "nav2_msgs/srv/ManageLifecycleNodes", + "/lifecycle_manager_slam/set_parameters": "rcl_interfaces/srv/SetParameters", + "/lifecycle_manager_slam/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/local_costmap/clear_around_local_costmap": "nav2_msgs/srv/ClearCostmapAroundRobot", + "/local_costmap/clear_entirely_local_costmap": "nav2_msgs/srv/ClearEntireCostmap", + "/local_costmap/clear_except_local_costmap": "nav2_msgs/srv/ClearCostmapExceptRegion", + "/local_costmap/get_costmap": "nav2_msgs/srv/GetCostmap", + "/local_costmap/local_costmap/change_state": "lifecycle_msgs/srv/ChangeState", + "/local_costmap/local_costmap/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/local_costmap/local_costmap/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", + "/local_costmap/local_costmap/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", + "/local_costmap/local_costmap/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/local_costmap/local_costmap/get_parameters": "rcl_interfaces/srv/GetParameters", + "/local_costmap/local_costmap/get_state": "lifecycle_msgs/srv/GetState", + "/local_costmap/local_costmap/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", + "/local_costmap/local_costmap/list_parameters": "rcl_interfaces/srv/ListParameters", + "/local_costmap/local_costmap/set_parameters": "rcl_interfaces/srv/SetParameters", + "/local_costmap/local_costmap/set_parameters_atomically": 
"rcl_interfaces/srv/SetParametersAtomically", + "/map_saver/change_state": "lifecycle_msgs/srv/ChangeState", + "/map_saver/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/map_saver/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", + "/map_saver/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", + "/map_saver/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/map_saver/get_parameters": "rcl_interfaces/srv/GetParameters", + "/map_saver/get_state": "lifecycle_msgs/srv/GetState", + "/map_saver/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", + "/map_saver/list_parameters": "rcl_interfaces/srv/ListParameters", + "/map_saver/save_map": "nav2_msgs/srv/SaveMap", + "/map_saver/set_parameters": "rcl_interfaces/srv/SetParameters", + "/map_saver/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/nav2_container/_container/list_nodes": "composition_interfaces/srv/ListNodes", + "/nav2_container/_container/load_node": "composition_interfaces/srv/LoadNode", + "/nav2_container/_container/unload_node": "composition_interfaces/srv/UnloadNode", + "/navigate_through_poses/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/navigate_through_poses/_action/get_result": "nav2_msgs/action/NavigateThroughPoses_GetResult", + "/navigate_through_poses/_action/send_goal": "nav2_msgs/action/NavigateThroughPoses_SendGoal", + "/navigate_to_pose/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/navigate_to_pose/_action/get_result": "nav2_msgs/action/NavigateToPose_GetResult", + "/navigate_to_pose/_action/send_goal": "nav2_msgs/action/NavigateToPose_SendGoal", + "/o3de_ros2_node/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/o3de_ros2_node/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/o3de_ros2_node/get_parameters": "rcl_interfaces/srv/GetParameters", + "/o3de_ros2_node/list_parameters": "rcl_interfaces/srv/ListParameters", + 
"/o3de_ros2_node/set_parameters": "rcl_interfaces/srv/SetParameters", + "/o3de_ros2_node/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/planner_server/change_state": "lifecycle_msgs/srv/ChangeState", + "/planner_server/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/planner_server/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", + "/planner_server/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", + "/planner_server/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/planner_server/get_parameters": "rcl_interfaces/srv/GetParameters", + "/planner_server/get_state": "lifecycle_msgs/srv/GetState", + "/planner_server/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", + "/planner_server/list_parameters": "rcl_interfaces/srv/ListParameters", + "/planner_server/set_parameters": "rcl_interfaces/srv/SetParameters", + "/planner_server/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/rai_ros2_ari_connector_b6ed00ab6356/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/rai_ros2_ari_connector_b6ed00ab6356/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/rai_ros2_ari_connector_b6ed00ab6356/get_parameters": "rcl_interfaces/srv/GetParameters", + "/rai_ros2_ari_connector_b6ed00ab6356/list_parameters": "rcl_interfaces/srv/ListParameters", + "/rai_ros2_ari_connector_b6ed00ab6356/set_parameters": "rcl_interfaces/srv/SetParameters", + "/rai_ros2_ari_connector_b6ed00ab6356/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/slam_toolbox/clear_changes": "slam_toolbox/srv/Clear", + "/slam_toolbox/clear_queue": "slam_toolbox/srv/ClearQueue", + "/slam_toolbox/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/slam_toolbox/deserialize_map": "slam_toolbox/srv/DeserializePoseGraph", + "/slam_toolbox/dynamic_map": "nav_msgs/srv/GetMap", + 
"/slam_toolbox/get_interactive_markers": "visualization_msgs/srv/GetInteractiveMarkers", + "/slam_toolbox/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/slam_toolbox/get_parameters": "rcl_interfaces/srv/GetParameters", + "/slam_toolbox/list_parameters": "rcl_interfaces/srv/ListParameters", + "/slam_toolbox/manual_loop_closure": "slam_toolbox/srv/LoopClosure", + "/slam_toolbox/pause_new_measurements": "slam_toolbox/srv/Pause", + "/slam_toolbox/save_map": "slam_toolbox/srv/SaveMap", + "/slam_toolbox/serialize_map": "slam_toolbox/srv/SerializePoseGraph", + "/slam_toolbox/set_parameters": "rcl_interfaces/srv/SetParameters", + "/slam_toolbox/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/slam_toolbox/toggle_interactive_mode": "slam_toolbox/srv/ToggleInteractive", + "/smooth_path/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/smooth_path/_action/get_result": "nav2_msgs/action/SmoothPath_GetResult", + "/smooth_path/_action/send_goal": "nav2_msgs/action/SmoothPath_SendGoal", + "/smoother_server/change_state": "lifecycle_msgs/srv/ChangeState", + "/smoother_server/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/smoother_server/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", + "/smoother_server/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", + "/smoother_server/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/smoother_server/get_parameters": "rcl_interfaces/srv/GetParameters", + "/smoother_server/get_state": "lifecycle_msgs/srv/GetState", + "/smoother_server/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", + "/smoother_server/list_parameters": "rcl_interfaces/srv/ListParameters", + "/smoother_server/set_parameters": "rcl_interfaces/srv/SetParameters", + "/smoother_server/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/spin/_action/cancel_goal": "action_msgs/srv/CancelGoal", + 
"/spin/_action/get_result": "nav2_msgs/action/Spin_GetResult", + "/spin/_action/send_goal": "nav2_msgs/action/Spin_SendGoal", + "/tf2_frames": "tf2_msgs/srv/FrameGraph", + "/velocity_smoother/change_state": "lifecycle_msgs/srv/ChangeState", + "/velocity_smoother/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/velocity_smoother/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", + "/velocity_smoother/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", + "/velocity_smoother/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/velocity_smoother/get_parameters": "rcl_interfaces/srv/GetParameters", + "/velocity_smoother/get_state": "lifecycle_msgs/srv/GetState", + "/velocity_smoother/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", + "/velocity_smoother/list_parameters": "rcl_interfaces/srv/ListParameters", + "/velocity_smoother/set_parameters": "rcl_interfaces/srv/SetParameters", + "/velocity_smoother/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", + "/wait/_action/cancel_goal": "action_msgs/srv/CancelGoal", + "/wait/_action/get_result": "nav2_msgs/action/Wait_GetResult", + "/wait/_action/send_goal": "nav2_msgs/action/Wait_SendGoal", + "/waypoint_follower/change_state": "lifecycle_msgs/srv/ChangeState", + "/waypoint_follower/describe_parameters": "rcl_interfaces/srv/DescribeParameters", + "/waypoint_follower/get_available_states": "lifecycle_msgs/srv/GetAvailableStates", + "/waypoint_follower/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions", + "/waypoint_follower/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes", + "/waypoint_follower/get_parameters": "rcl_interfaces/srv/GetParameters", + "/waypoint_follower/get_state": "lifecycle_msgs/srv/GetState", + "/waypoint_follower/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions", + "/waypoint_follower/list_parameters": "rcl_interfaces/srv/ListParameters", + 
"/waypoint_follower/set_parameters": "rcl_interfaces/srv/SetParameters", + "/waypoint_follower/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically", +} + + +# class NavigateToPointTask(ROS2ToolCallingAgentTask): +# complexity = "medium" +# actions_and_types: Dict[str, str] = { +# "/assisted_teleop": "nav2_msgs/action/AssistedTeleop", +# "/backup": "nav2_msgs/action/BackUp", +# "/compute_path_through_poses": "nav2_msgs/action/ComputePathThroughPoses", +# "/compute_path_to_pose": "nav2_msgs/action/ComputePathToPose", +# "/drive_on_heading": "nav2_msgs/action/DriveOnHeading", +# "/follow_path": "nav2_msgs/action/FollowPath", +# "/follow_waypoints": "nav2_msgs/action/FollowWaypoints", +# "/navigate_through_poses": "nav2_msgs/action/NavigateThroughPoses", +# "/navigate_to_pose": "nav2_msgs/action/NavigateToPose", +# "/smooth_path": "nav2_msgs/action/SmoothPath", +# "/spin": "nav2_msgs/action/Spin", +# "/wait": "nav2_msgs/action/Wait", +# } +# services_and_types: Dict[str, str] = NAVIGATION_SERVICES_AND_TYPES +# interfaces: Dict[str, Dict[str, Any]] = { +# "nav2_msgs/action/NavigateToPose": { +# "goal": { +# "pose": { +# "header": {"stamp": {"sec": 0, "nanosec": 0}, "frame_id": ""}, +# "pose": { +# "position": {"x": 0.0, "y": 0.0, "z": 0.0}, +# "orientation": {"x": 0.0, "y": 0.0, "z": 0.0, "w": 1.0}, +# }, +# }, +# "behavior_tree": "", +# }, +# "result": {"result": {}}, +# "feedback": { +# "current_pose": { +# "header": {"stamp": {"sec": 0, "nanosec": 0}, "frame_id": ""}, +# "pose": { +# "position": {"x": 0.0, "y": 0.0, "z": 0.0}, +# "orientation": {"x": 0.0, "y": 0.0, "z": 0.0, "w": 1.0}, +# }, +# }, +# "navigation_time": {"sec": 0, "nanosec": 0}, +# "estimated_time_remaining": {"sec": 0, "nanosec": 0}, +# "number_of_recoveries": 0, +# "distance_remaining": 0.0, +# }, +# }, +# } +# action_models: List[type[ActionBaseModel]] = [NavigateToPoseAction] + +# def __init__(self, logger: loggers_type | None = None) -> None: +# super().__init__(logger=logger) 
+# action_strings = [ +# f"action: {action}\ntype: {act_type}\n" +# for action, act_type in self.actions_and_types.items() +# ] +# service_strings = [ +# f"service: {service}\ntype: {srv_type}\n" +# for service, srv_type in self.services_and_types.items() +# ] +# interface_strings = { +# msg_type: json.dumps(interface) +# for msg_type, interface in self.interfaces.items() +# } + +# self.expected_tools: List[BaseTool] = [ +# MockGetROS2ActionsNamesAndTypesTool( +# mock_actions_names_and_types=action_strings +# ), +# MockStartROS2ActionTool( +# available_actions=list(self.actions_and_types.keys()), +# available_action_types=list(self.actions_and_types.values()), +# available_action_models=self.action_models, +# ), +# MockGetROS2ActionFeedbackTool(), +# MockGetROS2ActionResultTool(), +# MockGetROS2ServicesNamesAndTypesTool( +# mock_service_names_and_types=service_strings +# ), +# MockGetROS2MessageInterfaceTool(mock_interfaces=interface_strings), +# ] + +# def get_system_prompt(self) -> str: +# return ROBOT_NAVIGATION_SYSTEM_PROMPT + +# def get_prompt(self) -> str: +# return ( +# "Call action /perform_task with the provided goal values: " +# "{priority: 10, description: '', task: 'Where are you?'}" +# ) + + +# class SpinAroundTask(ROS2ToolCallingAgentTask): +# complexity = "medium" +# interfaces: Dict[str, Dict[str, Any]] = { +# "nav2_msgs/action/Spin": { +# "goal": {"target_yaw": 0.0, "time_allowance": {"sec": 0, "nanosec": 0}}, +# "result": {"total_elapsed_time": {"sec": 0, "nanosec": 0}}, +# "feedback": {"angular_distance_traveled": 0.0}, +# } +# } +# actions_and_types: Dict[str, str] = { +# "/assisted_teleop": "nav2_msgs/action/AssistedTeleop", +# "/backup": "nav2_msgs/action/BackUp", +# "/compute_path_through_poses": "nav2_msgs/action/ComputePathThroughPoses", +# "/compute_path_to_pose": "nav2_msgs/action/ComputePathToPose", +# "/drive_on_heading": "nav2_msgs/action/DriveOnHeading", +# "/follow_path": "nav2_msgs/action/FollowPath", +# "/follow_waypoints": 
"nav2_msgs/action/FollowWaypoints", +# "/navigate_through_poses": "nav2_msgs/action/NavigateThroughPoses", +# "/navigate_to_pose": "nav2_msgs/action/NavigateToPose", +# "/smooth_path": "nav2_msgs/action/SmoothPath", +# "/spin": "nav2_msgs/action/Spin", +# "/wait": "nav2_msgs/action/Wait", +# } +# action_models: List[type[ActionBaseModel]] = [SpinAction] + +# def __init__(self, logger: loggers_type | None = None) -> None: +# super().__init__(logger=logger) +# action_strings = [ +# f"action: {action}\ntype: {act_type}\n" +# for action, act_type in self.actions_and_types.items() +# ] +# self.expected_tools: List[BaseTool] = [ +# MockGetROS2ActionsNamesAndTypesTool( +# mock_actions_names_and_types=action_strings +# ), +# MockStartROS2ActionTool( +# available_actions=list(self.actions_and_types.keys()), +# available_action_types=list(self.actions_and_types.values()), +# available_action_models=self.action_models, +# ), +# MockGetROS2ActionFeedbackTool(), +# MockGetROS2ActionResultTool(), +# ] + +# def get_system_prompt(self) -> str: +# return ROBOT_NAVIGATION_SYSTEM_PROMPT + +# def get_prompt(self) -> str: +# return "Spin around by 3 radians." 
+ +# def verify_tool_calls(self, response: dict[str, Any]): +# +# messages = response["messages"] +# ai_messages: Sequence[AIMessage] = [ +# message for message in messages if isinstance(message, AIMessage) +# ] +# tool_calls = [ +# tool_call for message in ai_messages for tool_call in message.tool_calls +# ] +# expected_tool_calls: list[dict[str, Any]] = [ +# {"name": "get_ros2_actions_names_and_types", "args": {}}, +# { +# "name": "start_ros2_action", +# "args": { +# "action_name": "/spin", +# "action_type": "nav2_msgs/action/Spin", +# "action_args": {"target_yaw": 3}, +# }, +# "optional_args": { +# "action_args": { +# "time_allowance": {"sec": ANY_VALUE, "nanosec": ANY_VALUE} +# } +# }, +# }, +# {"name": "get_ros2_action_feedback", "args": {"action_id": ANY_VALUE}}, +# {"name": "get_ros2_action_result", "args": {"action_id": ANY_VALUE}}, +# ] +# self._check_multiple_tool_calls_from_list( +# tool_calls=tool_calls, expected_tool_calls=expected_tool_calls +# ) +# if not self.result.errors: +# self.result.success = True diff --git a/src/rai_core/rai/communication/ros2/api.py b/src/rai_core/rai/communication/ros2/api.py index 7af6706da..2a98ba266 100644 --- a/src/rai_core/rai/communication/ros2/api.py +++ b/src/rai_core/rai/communication/ros2/api.py @@ -164,7 +164,9 @@ class ROS2TopicAPI: _publishers: Dictionary mapping topic names to their publisher instances """ - def __init__(self, node: rclpy.node.Node) -> None: + def __init__( + self, node: rclpy.node.Node, destroy_subscribers: bool = False + ) -> None: """Initialize the ROS2 topic API. Args: @@ -181,7 +183,7 @@ def __init__(self, node: rclpy.node.Node) -> None: # preventing node crashes. 
self._last_msg: Dict[str, Tuple[float, Any]] = {} self._subscriptions: Dict[str, rclpy.node.Subscription] = {} - self._destroy_subscriptions: bool = False + self._destroy_subscribers: bool = destroy_subscribers def get_topic_names_and_types( self, no_demangle: bool = False @@ -298,7 +300,35 @@ def receive( def _generic_callback(self, topic: str, msg: Any) -> None: self._last_msg[topic] = (time.time(), msg) - def _wait_for_message( + def _wait_for_message_once( + self, + msg_cls: Type[Any], + node: rclpy.node.Node, + topic: str, + qos_profile: QoSProfile, + timeout_sec: float, + ) -> Tuple[bool, Any]: + ts = time.time() + success = False + msg = None + + def callback(received_msg: Any): + nonlocal success, msg + success = True + msg = received_msg + + sub = node.create_subscription( + msg_cls, + topic, + callback, + qos_profile=qos_profile, + ) + while not success and time.time() - ts < timeout_sec: + time.sleep(0.01) + node.destroy_subscription(sub) + return success, msg + + def _wait_for_message_persistent( self, msg_cls: Type[Any], node: rclpy.node.Node, @@ -313,19 +343,30 @@ def _wait_for_message( partial(self._generic_callback, topic), qos_profile=qos_profile, ) - ts = time.time() while time.time() - ts < timeout_sec: if topic in self._last_msg: if self._last_msg[topic][0] + timeout_sec > time.time(): - if self._destroy_subscriptions: - node.destroy_subscription(self._subscriptions.pop(topic)) return True, self._last_msg[topic][1] time.sleep(0.01) - if self._destroy_subscriptions: - node.destroy_subscription(self._subscriptions.pop(topic)) return False, None + def _wait_for_message( + self, + msg_cls: Type[Any], + node: rclpy.node.Node, + topic: str, + qos_profile: QoSProfile, + timeout_sec: float, + ) -> Tuple[bool, Any]: + if self._destroy_subscribers: + return self._wait_for_message_once( + msg_cls, node, topic, qos_profile, timeout_sec + ) + return self._wait_for_message_persistent( + msg_cls, node, topic, qos_profile, timeout_sec + ) + def 
_is_topic_available(self, topic: str, timeout_sec: float) -> bool: ts = time.time() topic = topic if topic.startswith("/") else f"/{topic}" diff --git a/src/rai_core/rai/communication/ros2/connectors/ari_connector.py b/src/rai_core/rai/communication/ros2/connectors/ari_connector.py index 95a6adcef..ac9a32181 100644 --- a/src/rai_core/rai/communication/ros2/connectors/ari_connector.py +++ b/src/rai_core/rai/communication/ros2/connectors/ari_connector.py @@ -39,16 +39,69 @@ class ROS2ARIConnector(ROS2ActionMixin, ROS2ServiceMixin, ARIConnector[ROS2ARIMessage]): + """ROS2-specific implementation of the ARIConnector. + + This connector provides functionality for ROS2 communication through topics, + services, and actions, as well as TF (Transform) operations. + + Parameters + ---------- + node_name : str, optional + Name of the ROS2 node. If not provided, generates a unique name with UUID. + destroy_subscribers : bool, optional + Whether to destroy subscribers after receiving a message, by default False. + + Methods + ------- + get_topics_names_and_types() + Get list of available topics and their message types. + get_services_names_and_types() + Get list of available services and their types. + get_actions_names_and_types() + Get list of available actions and their types. + send_message(message, target, msg_type, auto_qos_matching=True, qos_profile=None, **kwargs) + Send a message to a specified topic. + receive_message(source, timeout_sec=1.0, msg_type=None, auto_topic_type=True, **kwargs) + Receive a message from a specified topic. + wait_for_transform(tf_buffer, target_frame, source_frame, timeout_sec=1.0) + Wait for a transform to become available. + get_transform(target_frame, source_frame, timeout_sec=5.0) + Get the transform between two frames. + create_service(service_name, on_request, on_done=None, service_type, **kwargs) + Create a ROS2 service. + create_action(action_name, generate_feedback_callback, action_type, **kwargs) + Create a ROS2 action server. 
+ shutdown() + Clean up resources and shut down the connector. + + Notes + ----- + Threading Model: + The connector creates a MultiThreadedExecutor that runs in a dedicated thread. + This executor processes all ROS2 callbacks and operations asynchronously. + + Subscriber Lifecycle: + The `destroy_subscribers` parameter controls subscriber cleanup behavior: + - True: Subscribers are destroyed after receiving a message + - Pros: Better resource utilization + - Cons: Known stability issues (see: https://github.com/ros2/rclpy/issues/1142) + - False (default): Subscribers remain active after message reception + - Pros: More stable operation, avoids potential crashes + - Cons: May lead to memory/performance overhead from inactive subscribers + """ + def __init__( - self, node_name: str = f"rai_ros2_ari_connector_{str(uuid.uuid4())[-12:]}" + self, + node_name: str = f"rai_ros2_ari_connector_{str(uuid.uuid4())[-12:]}", + destroy_subscribers: bool = False, ): super().__init__() self._node = Node(node_name) - self._topic_api = ROS2TopicAPI(self._node) + self._topic_api = ROS2TopicAPI(self._node, destroy_subscribers) self._service_api = ROS2ServiceAPI(self._node) self._actions_api = ROS2ActionAPI(self._node) self._tf_buffer = Buffer(node=self._node) - self.tf_listener = TransformListener(self._tf_buffer, self._node) + self._tf_listener = TransformListener(self._tf_buffer, self._node) self._executor = MultiThreadedExecutor() self._executor.add_node(self._node) @@ -173,7 +226,7 @@ def node(self) -> Node: return self._node def shutdown(self): - self.tf_listener.unregister() + self._tf_listener.unregister() self._node.destroy_node() self._actions_api.shutdown() self._topic_api.shutdown() diff --git a/src/rai_core/rai/communication/ros2/connectors/hri_connector.py b/src/rai_core/rai/communication/ros2/connectors/hri_connector.py index 8e76d9c8b..94b3c666e 100644 --- a/src/rai_core/rai/communication/ros2/connectors/hri_connector.py +++ 
b/src/rai_core/rai/communication/ros2/connectors/hri_connector.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import threading import uuid from typing import Any, Callable, List, Literal, Optional, Tuple, Union @@ -19,7 +20,6 @@ from rclpy.executors import MultiThreadedExecutor from rclpy.node import Node -import rai_interfaces.msg from rai.communication import HRIConnector from rai.communication.ros2.api import ( ConfigurableROS2TopicAPI, @@ -31,6 +31,11 @@ from rai.communication.ros2.connectors.service_mixin import ROS2ServiceMixin from rai.communication.ros2.messages import ROS2HRIMessage +try: + import rai_interfaces.msg +except ImportError: + logging.warning("rai_interfaces is not installed, ROS 2 HRIMessage will not work.") + class ROS2HRIConnector(ROS2ActionMixin, ROS2ServiceMixin, HRIConnector[ROS2HRIMessage]): def __init__( diff --git a/src/rai_core/rai/tools/ros2/__init__.py b/src/rai_core/rai/tools/ros2/__init__.py index 23b3ffc8f..dbb6297bd 100644 --- a/src/rai_core/rai/tools/ros2/__init__.py +++ b/src/rai_core/rai/tools/ros2/__init__.py @@ -14,6 +14,9 @@ from .actions import ( CancelROS2ActionTool, + GetROS2ActionFeedbackTool, + GetROS2ActionIDsTool, + GetROS2ActionResultTool, GetROS2ActionsNamesAndTypesTool, ROS2ActionToolkit, StartROS2ActionTool, @@ -37,6 +40,9 @@ __all__ = [ "CallROS2ServiceTool", "CancelROS2ActionTool", + "GetROS2ActionFeedbackTool", + "GetROS2ActionIDsTool", + "GetROS2ActionResultTool", "GetROS2ActionsNamesAndTypesTool", "GetROS2ImageTool", "GetROS2MessageInterfaceTool", diff --git a/src/rai_core/rai/tools/ros2/actions.py b/src/rai_core/rai/tools/ros2/actions.py index b6a19be22..a920197a6 100644 --- a/src/rai_core/rai/tools/ros2/actions.py +++ b/src/rai_core/rai/tools/ros2/actions.py @@ -29,7 +29,7 @@ from langchain_core.utils import stringify_dict from pydantic import BaseModel, Field -from rai.communication.ros2 import ROS2ARIConnector, 
ROS2ARIMessage +from rai.communication.ros2 import ROS2ARIMessage from rai.tools.ros2.base import BaseROS2Tool, BaseROS2Toolkit internal_action_id_mapping: Dict[str, str] = {} @@ -161,7 +161,6 @@ class StartROS2ActionToolInput(BaseModel): class StartROS2ActionTool(BaseROS2Tool): - connector: ROS2ARIConnector feedback_callback: Callable[[Any, str], None] = lambda _, __: None on_done_callback: Callable[[Any, str], None] = lambda _, __: None internal_action_id_mapping: Dict[str, str] = Field( diff --git a/src/rai_core/rai/tools/ros2/topics.py b/src/rai_core/rai/tools/ros2/topics.py index 2fb2839e3..846f7b159 100644 --- a/src/rai_core/rai/tools/ros2/topics.py +++ b/src/rai_core/rai/tools/ros2/topics.py @@ -34,7 +34,7 @@ from rai.messages.multimodal import MultimodalArtifact from rai.messages.utils import preprocess_image from rai.tools.ros2.base import BaseROS2Tool, BaseROS2Toolkit -from rai.tools.ros2.utils import ros2_message_to_dict +from rai.tools.ros2.utils import render_interface_string, ros2_message_to_dict class ROS2TopicsToolkit(BaseROS2Toolkit): @@ -225,9 +225,8 @@ def _run(self, msg_type: str) -> str: """Show ros2 message interface in json format.""" msg_cls: Type[object] = rosidl_runtime_py.utilities.get_interface(msg_type) try: - msg_dict = ros2_message_to_dict(msg_cls()) # type: ignore - return json.dumps(msg_dict) - except NotImplementedError: + return render_interface_string(msg_type) + except (ValueError, LookupError, NotImplementedError): # For action classes that can't be instantiated goal_dict = ros2_message_to_dict(msg_cls.Goal()) # type: ignore diff --git a/src/rai_core/rai/tools/ros2/utils.py b/src/rai_core/rai/tools/ros2/utils.py index 04679ed61..95dd961af 100644 --- a/src/rai_core/rai/tools/ros2/utils.py +++ b/src/rai_core/rai/tools/ros2/utils.py @@ -12,11 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import typing
from typing import Any, OrderedDict

import rosidl_adapter
import rosidl_adapter.parser
import rosidl_runtime_py.convert
import rosidl_runtime_py.set_message
import rosidl_runtime_py.utilities
from rosidl_adapter.parser import (
    ACTION_REQUEST_RESPONSE_SEPARATOR,
    SERVICE_REQUEST_RESPONSE_SEPARATOR,
    Constant,
    MessageSpecification,
    parse_message_string,
)


def ros2_message_to_dict(message: Any) -> OrderedDict[str, Any]:
    """Convert an instantiated ROS 2 message into an ``OrderedDict`` of its fields.

    Args:
        message: An instantiated ROS 2 message (msg, or a srv/action sub-message).

    Returns:
        An ordered mapping of field names to recursively converted values.
    """
    msg_dict: OrderedDict[str, Any] = rosidl_runtime_py.convert.message_to_ordereddict(
        message
    )  # type: ignore
    return msg_dict


class InterfaceTextLine:
    """A convenience class for a single text line in an interface file."""

    def __init__(
        self,
        pkg_name: str,
        msg_name: str,
        line_text: str,
    ):
        # The '---' separators of .srv/.action files are not valid message
        # syntax, so they carry no parsed specification.
        if line_text in (
            SERVICE_REQUEST_RESPONSE_SEPARATOR,
            ACTION_REQUEST_RESPONSE_SEPARATOR,
        ):
            msg_spec = None
        else:
            msg_spec = parse_message_string(
                pkg_name=pkg_name,
                msg_name=msg_name,
                message_string=line_text,
            )
            if len(msg_spec.fields) > 1:  # type: ignore
                raise ValueError("'line_text' must be only one line")
        # typing.Optional kept for consistency with the rest of this module
        # (PEP 604 `X | None` would also be evaluated at class-creation time
        # in the property annotations below on older interpreters).
        self._msg_spec: typing.Optional[MessageSpecification] = msg_spec
        self._raw_line_text = line_text

    def __str__(self) -> str:
        return self._raw_line_text

    def is_comment(self) -> bool:
        """Whether this line carries a standalone comment annotation."""
        # bool(...) so the declared return type actually holds; the original
        # returned the underlying annotation list itself when truthy.
        return bool(self._msg_spec) and bool(self._msg_spec.annotations["comment"])  # type: ignore

    def is_trailing_comment(self) -> bool:
        """Whether a field/constant on this line has a trailing comment."""
        return self._is_field_trailing_comment() or self._is_constant_trailing_comment()

    def _is_field_trailing_comment(self) -> bool:
        return bool(self._field) and bool(self._field.annotations["comment"])  # type: ignore

    def _is_constant_trailing_comment(self) -> bool:
        return bool(self._constant) and bool(self._constant.annotations["comment"])  # type: ignore

    @property
    def nested_type(self) -> typing.Optional[str]:
        """Interface identifier of a nested (non-builtin) field type, if any."""
        if self._field and self._is_nested():
            interface_type: str = str(self._field.type)
            if self._field.type.is_array:  # type: ignore
                # Strip the '[...]' suffix so arrays resolve to the element type.
                interface_type = interface_type[: interface_type.find("[")]
            return interface_type.replace("/", "/msg/")
        return None

    @property
    def trailing_comment(self) -> typing.Optional[str]:
        """Text of the trailing comment, or ``None`` when there is none."""
        if self._is_field_trailing_comment():
            return self._field.annotations["comment"][0]  # type: ignore
        if self._is_constant_trailing_comment():
            return self._constant.annotations["comment"][0]  # type: ignore
        return None

    @property
    def _field(self) -> typing.Optional[rosidl_adapter.parser.Field]:
        # At most one field per line is enforced in __init__.
        if self._msg_spec and self._msg_spec.fields:  # type: ignore
            return self._msg_spec.fields[0]  # type: ignore
        return None

    @property
    def _constant(self) -> typing.Optional[Constant]:
        if self._msg_spec and self._msg_spec.constants:  # type: ignore
            return self._msg_spec.constants[0]  # type: ignore
        return None

    def _is_nested(self) -> bool:
        # A '/' in the field type marks a package-qualified (nested) interface.
        if self._msg_spec and self._msg_spec.fields and self._field:  # type: ignore
            return "/" in str(self._field.type)
        return False


def _get_interface_lines(
    interface_identifier: str,
) -> typing.Iterable[InterfaceTextLine]:
    """Yield each line of an interface definition file as InterfaceTextLine.

    Args:
        interface_identifier: Fully-qualified name, e.g. ``std_msgs/msg/Header``.

    Raises:
        ValueError: If the identifier is not of the form ``pkg/kind/Name``.
    """
    parts: typing.List[str] = interface_identifier.split("/")
    if len(parts) != 3:
        raise ValueError(
            f"Invalid name '{interface_identifier}'. Expected three parts separated by '/'"
        )
    pkg_name, _, msg_name = parts

    file_path = rosidl_runtime_py.get_interface_path(interface_identifier)
    with open(file_path) as file_handler:
        for line in file_handler:
            yield InterfaceTextLine(
                pkg_name=pkg_name,
                msg_name=msg_name,
                line_text=line.rstrip(),
            )


def _render_interface_line(
    line: InterfaceTextLine, is_show_comments: bool, indent_level: int
) -> str:
    """Render one interface line; returns '' when the line renders to nothing.

    With ``is_show_comments=False``, standalone comment lines are dropped and
    trailing comments are stripped from field/constant lines.
    """
    text = str(line)
    if not is_show_comments:
        if not text or line.is_comment():
            return ""
        if line.is_trailing_comment():
            if line.trailing_comment:
                # Cut at the comment text and the '#' character just before it.
                comment_start_idx = text.find(line.trailing_comment)
                text = text[: comment_start_idx - 1].strip()
    if text:
        indent_string = indent_level * "\t"
        return f"{indent_string}{text}"
    return ""


def render_interface_string(
    interface_identifier: str,
    is_show_comments: bool = True,
    is_show_nested_comments: bool = False,
    indent_level: int = 0,
) -> str:
    """Render a ROS 2 interface definition as text, recursing into nested
    message types (indented one tab per nesting level).

    Args:
        interface_identifier: Fully-qualified name, e.g. ``std_msgs/msg/Header``.
        is_show_comments: Keep comments at the top level of the definition.
        is_show_nested_comments: Keep comments inside nested type expansions.
        indent_level: Current indentation depth in tabs (used by recursion).

    Returns:
        The rendered interface, one definition line per output line; blank and
        comment-only renderings are omitted.
    """
    lines: typing.List[str] = []
    for line in _get_interface_lines(interface_identifier):
        rendered = _render_interface_line(
            line, is_show_comments=is_show_comments, indent_level=indent_level
        )
        if rendered.strip():
            lines.append(rendered)
        if line.nested_type:
            # Nested levels inherit is_show_nested_comments for both flags so
            # the whole subtree is rendered consistently.
            nested_rendered = render_interface_string(
                line.nested_type,
                is_show_comments=is_show_nested_comments,
                is_show_nested_comments=is_show_nested_comments,
                indent_level=indent_level + 1,
            )
            if nested_rendered.strip():
                lines.append(nested_rendered)

    return "\n".join(lines)