diff --git a/src/rai_bench/rai_bench/examples/tool_calling_agent_bench_navigation_tasks.py b/src/rai_bench/rai_bench/examples/tool_calling_agent_bench_navigation_tasks.py
new file mode 100644
index 000000000..e1baf4ad1
--- /dev/null
+++ b/src/rai_bench/rai_bench/examples/tool_calling_agent_bench_navigation_tasks.py
@@ -0,0 +1,28 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Imports kept commented out alongside the commented-out task list below to
+# avoid unused-import lint errors; uncomment together with the tasks.
+# from typing import Sequence
+
+# from rai_bench.tool_calling_agent_bench.agent_tasks_interfaces import (
+#     ToolCallingAgentTask,
+# )
+
+# from rai_bench.tool_calling_agent_bench.ros2_agent_tasks import (
+# NavigateToPointTask,
+# )
+
+# tasks: Sequence[ToolCallingAgentTask] = [
+# NavigateToPointTask(),
+# # SpinAroundTask()
+# ]
diff --git a/src/rai_bench/rai_bench/examples/tool_calling_agent_bench_tasks.py b/src/rai_bench/rai_bench/examples/tool_calling_agent_bench_tasks.py
index 07dd05dae..201358e53 100644
--- a/src/rai_bench/rai_bench/examples/tool_calling_agent_bench_tasks.py
+++ b/src/rai_bench/rai_bench/examples/tool_calling_agent_bench_tasks.py
@@ -18,85 +18,126 @@
ToolCallingAgentTask,
)
from rai_bench.tool_calling_agent_bench.ros2_agent_tasks import (
- GetAllROS2RGBCamerasTask,
- GetObjectPositionsTask,
- GetROS2DepthCameraTask,
- GetROS2MessageTask,
- GetROS2RGBCameraTask,
- GetROS2TopicsTask,
- GetROS2TopicsTask2,
- GrabExistingObjectTask,
- GrabNotExistingObjectTask,
- MoveExistingObjectFrontTask,
- MoveExistingObjectLeftTask,
- MoveToPointTask,
- SwapObjectsTask,
+    PublishROS2HRIMessageTask3ExtraCalls,
+    PublishROS2HRIMessageTask1ExtraCall,
+    PublishROS2HRIMessageTask0ExtraCalls,
+    PublishROS2AudioMessageTask3ExtraCalls,
+    PublishROS2AudioMessageTask1ExtraCall,
+    PublishROS2AudioMessageTask0ExtraCalls,
+ PublishROS2DetectionArrayTask3ExtraCalls,
+ PublishROS2DetectionArrayTask1ExtraCall,
+ PublishROS2DetectionArrayTask0ExtraCalls,
+ CallROS2ManipulatorMoveToServiceTask3ExtraCalls,
+ CallROS2ManipulatorMoveToServiceTask1ExtraCall,
+ CallROS2ManipulatorMoveToServiceTask0ExtraCalls,
+ CallGroundedSAMSegmentTask3ExtraCalls,
+ CallGroundedSAMSegmentTask1ExtraCall,
+ CallGroundedSAMSegmentTask0ExtraCalls,
+ CallGroundingDinoClassifyTask3ExtraCalls,
+ CallGroundingDinoClassifyTask1ExtraCall,
+ CallGroundingDinoClassifyTask0ExtraCalls,
+ CallGetLogDigestTask3ExtraCalls,
+ CallGetLogDigestTask1ExtraCall,
+ CallGetLogDigestTask0ExtraCalls,
+ CallVectorStoreRetrievalTask3ExtraCalls,
+ CallVectorStoreRetrievalTask1ExtraCall,
+ CallVectorStoreRetrievalTask0ExtraCalls,
+ CallWhatISeeTask3ExtraCalls,
+ CallWhatISeeTask1ExtraCall,
+ CallWhatISeeTask0ExtraCalls,
)
tasks: Sequence[ToolCallingAgentTask] = [
- GetROS2RGBCameraTask(),
- GetROS2TopicsTask(),
- GetROS2DepthCameraTask(),
- GetAllROS2RGBCamerasTask(),
- GetROS2TopicsTask2(),
- GetROS2MessageTask(),
- MoveToPointTask(args={"x": 1.0, "y": 2.0, "z": 3.0, "task": "grab"}),
- MoveToPointTask(args={"x": 1.2, "y": 2.3, "z": 3.4, "task": "drop"}),
- GetObjectPositionsTask(
- objects={
- "carrot": [{"x": 1.0, "y": 2.0, "z": 3.0}],
- "apple": [{"x": 4.0, "y": 5.0, "z": 6.0}],
- "banana": [
- {"x": 7.0, "y": 8.0, "z": 9.0},
- {"x": 10.0, "y": 11.0, "z": 12.0},
- ],
- },
- ),
- GrabExistingObjectTask(
- object_to_grab="banana",
- objects={
- "banana": [{"x": 7.0, "y": 8.0, "z": 9.0}],
- "apple": [
- {"x": 4.0, "y": 5.0, "z": 6.0},
- {"x": 10.0, "y": 11.0, "z": 12.0},
- ],
- },
- ),
- GrabNotExistingObjectTask(
- object_to_grab="apple",
- objects={
- "banana": [{"x": 7.0, "y": 8.0, "z": 9.0}],
- "cube": [
- {"x": 4.0, "y": 5.0, "z": 6.0},
- {"x": 10.0, "y": 11.0, "z": 12.0},
- ],
- },
- ),
- MoveExistingObjectLeftTask(
- object_to_grab="banana",
- objects={
- "banana": [{"x": 7.0, "y": 8.0, "z": 9.0}],
- "apple": [
- {"x": 4.0, "y": 5.0, "z": 6.0},
- {"x": 10.0, "y": 11.0, "z": 12.0},
- ],
- },
- ),
- MoveExistingObjectFrontTask(
- object_to_grab="banana",
- objects={
- "banana": [{"x": 7.0, "y": 8.0, "z": 9.0}],
- "apple": [
- {"x": 4.0, "y": 5.0, "z": 6.0},
- {"x": 10.0, "y": 11.0, "z": 12.0},
- ],
- },
- ),
- SwapObjectsTask(
- objects={
- "banana": [{"x": 1.0, "y": 2.0, "z": 3.0}],
- "apple": [{"x": 4.0, "y": 5.0, "z": 6.0}],
- },
- objects_to_swap=["banana", "apple"],
- ),
+ PublishROS2HRIMessageTask3ExtraCalls(),
+ PublishROS2HRIMessageTask1ExtraCall(),
+ PublishROS2HRIMessageTask0ExtraCalls(),
+ PublishROS2AudioMessageTask3ExtraCalls(),
+ PublishROS2AudioMessageTask1ExtraCall(),
+ PublishROS2AudioMessageTask0ExtraCalls(),
+ PublishROS2DetectionArrayTask3ExtraCalls(),
+ PublishROS2DetectionArrayTask1ExtraCall(),
+ PublishROS2DetectionArrayTask0ExtraCalls(),
+ CallROS2ManipulatorMoveToServiceTask3ExtraCalls(),
+ CallROS2ManipulatorMoveToServiceTask1ExtraCall(),
+ CallROS2ManipulatorMoveToServiceTask0ExtraCalls(),
+ CallGroundedSAMSegmentTask3ExtraCalls(),
+ CallGroundedSAMSegmentTask1ExtraCall(),
+ CallGroundedSAMSegmentTask0ExtraCalls(),
+ CallGroundingDinoClassifyTask3ExtraCalls(),
+ CallGroundingDinoClassifyTask1ExtraCall(),
+ CallGroundingDinoClassifyTask0ExtraCalls(),
+ CallGetLogDigestTask3ExtraCalls(),
+ CallGetLogDigestTask1ExtraCall(),
+ CallGetLogDigestTask0ExtraCalls(),
+ CallVectorStoreRetrievalTask3ExtraCalls(),
+ CallVectorStoreRetrievalTask1ExtraCall(),
+ CallVectorStoreRetrievalTask0ExtraCalls(),
+ CallWhatISeeTask3ExtraCalls(),
+ CallWhatISeeTask1ExtraCall(),
+ CallWhatISeeTask0ExtraCalls(),
+ # GetROS2RGBCameraTask(),
+ # GetROS2TopicsTask(),
+ # GetROS2DepthCameraTask(),
+ # GetAllROS2RGBCamerasTask(),
+ # GetROS2TopicsTask2(),
+ # GetROS2MessageTask(),
+ # MoveToPointTask(args={"x": 1.0, "y": 2.0, "z": 3.0, "task": "grab"}),
+ # MoveToPointTask(args={"x": 1.2, "y": 2.3, "z": 3.4, "task": "drop"}),
+ # GetObjectPositionsTask(
+ # objects={
+ # "carrot": [{"x": 1.0, "y": 2.0, "z": 3.0}],
+ # "apple": [{"x": 4.0, "y": 5.0, "z": 6.0}],
+ # "banana": [
+ # {"x": 7.0, "y": 8.0, "z": 9.0},
+ # {"x": 10.0, "y": 11.0, "z": 12.0},
+ # ],
+ # },
+ # ),
+ # GrabExistingObjectTask(
+ # object_to_grab="banana",
+ # objects={
+ # "banana": [{"x": 7.0, "y": 8.0, "z": 9.0}],
+ # "apple": [
+ # {"x": 4.0, "y": 5.0, "z": 6.0},
+ # {"x": 10.0, "y": 11.0, "z": 12.0},
+ # ],
+ # },
+ # ),
+ # GrabNotExistingObjectTask(
+ # object_to_grab="apple",
+ # objects={
+ # "banana": [{"x": 7.0, "y": 8.0, "z": 9.0}],
+ # "cube": [
+ # {"x": 4.0, "y": 5.0, "z": 6.0},
+ # {"x": 10.0, "y": 11.0, "z": 12.0},
+ # ],
+ # },
+ # ),
+ # MoveExistingObjectLeftTask(
+ # object_to_grab="banana",
+ # objects={
+ # "banana": [{"x": 7.0, "y": 8.0, "z": 9.0}],
+ # "apple": [
+ # {"x": 4.0, "y": 5.0, "z": 6.0},
+ # {"x": 10.0, "y": 11.0, "z": 12.0},
+ # ],
+ # },
+ # ),
+ # MoveExistingObjectFrontTask(
+ # object_to_grab="banana",
+ # objects={
+ # "banana": [{"x": 7.0, "y": 8.0, "z": 9.0}],
+ # "apple": [
+ # {"x": 4.0, "y": 5.0, "z": 6.0},
+ # {"x": 10.0, "y": 11.0, "z": 12.0},
+ # ],
+ # },
+ # ),
+ # SwapObjectsTask(
+ # objects={
+ # "banana": [{"x": 1.0, "y": 2.0, "z": 3.0}],
+ # "apple": [{"x": 4.0, "y": 5.0, "z": 6.0}],
+ # },
+ # objects_to_swap=["banana", "apple"],
+ # ),
]
diff --git a/src/rai_bench/rai_bench/examples/tool_calling_agent_test_bench.py b/src/rai_bench/rai_bench/examples/tool_calling_agent_test_bench.py
index f34587490..bce90d9ae 100644
--- a/src/rai_bench/rai_bench/examples/tool_calling_agent_test_bench.py
+++ b/src/rai_bench/rai_bench/examples/tool_calling_agent_test_bench.py
@@ -23,6 +23,8 @@
)
from rai_bench.examples.tool_calling_agent_bench_tasks import tasks
+
+# from rai_bench.examples.tool_calling_agent_bench_navigation_tasks import tasks
from rai_bench.tool_calling_agent_bench.agent_bench import ToolCallingAgentBenchmark
if __name__ == "__main__":
@@ -60,7 +62,7 @@
tasks=tasks, logger=bench_logger, results_filename=results_filename
)
- model_type = "simple_model"
+ model_type = "complex_model"
model_config = get_llm_model_config_and_vendor(model_type=model_type)[0]
model_name = getattr(model_config, model_type)
diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/__init__.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/__init__.py
new file mode 100644
index 000000000..0f0767afa
--- /dev/null
+++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .action_base_model import ActionBaseModel
+from .navigate_to_pose import NavigateToPoseAction
+from .spin import SpinAction
+
+__all__ = [
+ "ActionBaseModel",
+ "NavigateToPoseAction",
+ "SpinAction",
+]
diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/action_base_model.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/action_base_model.py
new file mode 100644
index 000000000..fef8444db
--- /dev/null
+++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/action_base_model.py
@@ -0,0 +1,25 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+from pydantic import BaseModel
+
+
+class ActionBaseModel(BaseModel):
+ action_name: str
+ action_type: str
+ goal: Any
+ result: Any
+ feedback: Any
diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/navigate_to_pose.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/navigate_to_pose.py
new file mode 100644
index 000000000..20cfe8ae7
--- /dev/null
+++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/navigate_to_pose.py
@@ -0,0 +1,80 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from pydantic import BaseModel
+
+from rai_bench.tool_calling_agent_bench.actions.action_base_model import ActionBaseModel
+
+
+class Time(BaseModel):
+ sec: Optional[int] = 0
+ nanosec: Optional[int] = 0
+
+
+class Header(BaseModel):
+ stamp: Optional[Time] = Time()
+ frame_id: str
+
+
+class Position(BaseModel):
+ x: float
+ y: float
+ z: float
+
+
+class Orientation(BaseModel):
+ x: Optional[float] = 0.0
+ y: Optional[float] = 0.0
+ z: Optional[float] = 0.0
+ w: Optional[float] = 1.0
+
+
+class Pose(BaseModel):
+ position: Position
+ orientation: Optional[Orientation] = Orientation()
+
+
+class PoseStamped(BaseModel):
+ header: Header
+ pose: Pose
+
+
+class Goal(BaseModel):
+ pose: PoseStamped
+ behavior_tree: Optional[str] = ""
+
+
+class Result(BaseModel):
+ result: dict
+
+
+class Feedback(BaseModel):
+ current_pose: PoseStamped
+ navigation_time: Time
+ estimated_time_remaining: Time
+ number_of_recoveries: int
+ distance_remaining: float
+
+
+class NavigateToPoseAction(ActionBaseModel):
+ action_name: str = "/navigate_to_pose"
+ action_type: str = "nav2_msgs/action/NavigateToPose"
+ goal: Goal
+ result: Result
+ feedback: Feedback
+
+
+# NOTE: this model is re-exported from the actions package __init__.
diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/spin.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/spin.py
new file mode 100644
index 000000000..030a4601e
--- /dev/null
+++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/actions/spin.py
@@ -0,0 +1,46 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from pydantic import BaseModel
+
+from rai_bench.tool_calling_agent_bench.actions.action_base_model import ActionBaseModel
+
+
+class Time(BaseModel):
+ sec: Optional[int] = 0
+ nanosec: Optional[int] = 0
+
+
+class Goal(BaseModel):
+ target_yaw: Optional[float] = 0.0
+ time_allowance: Optional[Time] = Time()
+
+
+class Result(BaseModel):
+ result: dict
+
+
+class Feedback(BaseModel):
+ angle_turned: Optional[float] = 0.0
+ remaining_yaw: Optional[float] = 0.0
+
+
+class SpinAction(ActionBaseModel):
+ action_name: str = "/spin"
+ action_type: str = "nav2_msgs/action/Spin"
+ goal: Goal
+ result: Result
+ feedback: Feedback
diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/agent_bench.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/agent_bench.py
index 9c59e1bf1..b8b22d581 100644
--- a/src/rai_bench/rai_bench/tool_calling_agent_bench/agent_bench.py
+++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/agent_bench.py
@@ -36,8 +36,9 @@
class TaskResult(BaseModel):
- task_prompt: str = Field(..., description="The task prompt.")
- system_prompt: str = Field(..., description="The system prompt.")
+ # task_prompt: str = Field(..., description="The task prompt.")
+ # system_prompt: str = Field(..., description="The system prompt.")
+ task: str = Field(..., description="task name")
complexity: str = Field(..., description="Complexity of the task.")
model_name: str = Field(..., description="Name of the LLM.")
success: bool = Field(
@@ -184,8 +185,9 @@ def run_next(self, agent: CompiledStateGraph, model_name: str) -> None:
)
task_result = TaskResult(
- task_prompt=task.get_prompt(),
- system_prompt=task.get_system_prompt(),
+ # task_prompt=task.get_prompt(),
+ # system_prompt=task.get_system_prompt(),
+ task=task.__class__.__name__,
complexity=task.complexity,
model_name=model_name,
success=result.success,
diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/agent_tasks_interfaces.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/agent_tasks_interfaces.py
index 611711d6d..7c05cc20f 100644
--- a/src/rai_bench/rai_bench/tool_calling_agent_bench/agent_tasks_interfaces.py
+++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/agent_tasks_interfaces.py
@@ -14,15 +14,735 @@
import logging
from abc import ABC, abstractmethod
-from typing import Any, List, Literal
+from typing import Any, Dict, List, Literal, Sequence, Type
from langchain_core.messages import AIMessage, ToolCall
from langchain_core.runnables.config import DEFAULT_RECURSION_LIMIT
from langchain_core.tools import BaseTool
from pydantic import BaseModel
+from rai_bench.tool_calling_agent_bench.messages.base import Clock
+from rai_bench.tool_calling_agent_bench.messages.services import (
+ ManipulatorMoveToRequest,
+ RAIGroundedSamRequest,
+ RAIGroundingDinoRequest,
+ StringListRequest,
+ VectorStoreRetrievalRequest,
+ WhatISeeRequest,
+)
+from rai_bench.tool_calling_agent_bench.messages.topics import (
+ AudioMessage,
+ CameraInfo,
+ HRIMessage,
+ Image,
+ RAIDetectionArray,
+)
+from rai_bench.tool_calling_agent_bench.mocked_tools import (
+ MockCallROS2ServiceTool,
+ MockCancelROS2ActionTool,
+ MockGetROS2ActionIDsTool,
+ MockGetROS2ActionsNamesAndTypesTool,
+ MockGetROS2ImageTool,
+ MockGetROS2MessageInterfaceTool,
+ MockGetROS2ServicesNamesAndTypesTool,
+ MockGetROS2TopicsNamesAndTypesTool,
+ MockMoveToPointTool,
+ MockPublishROS2MessageTool,
+ MockStartROS2ActionTool,
+)
+
loggers_type = logging.Logger
+# dict of interfaces where keys are interfaces types and values are output
+# of GetROS2MessageInterfaceTool which are same as ros2 interface show outputs
+# the dict contains custom as well as couple other common interfaces
+MOCK_INTERFACES: Dict[str, str] = {
+ "sensor_msgs/msg/CameraInfo": """
+# This message defines meta information for a camera. It should be in a
+# camera namespace on topic "camera_info" and accompanied by up to five
+# image topics named:
+#
+# image_raw - raw data from the camera driver, possibly Bayer encoded
+# image - monochrome, distorted
+# image_color - color, distorted
+# image_rect - monochrome, rectified
+# image_rect_color - color, rectified
+#
+# The image_pipeline contains packages (image_proc, stereo_image_proc)
+# for producing the four processed image topics from image_raw and
+# camera_info. The meaning of the camera parameters are described in
+# detail at http://www.ros.org/wiki/image_pipeline/CameraInfo.
+#
+# The image_geometry package provides a user-friendly interface to
+# common operations using this meta information. If you want to, e.g.,
+# project a 3d point into image coordinates, we strongly recommend
+# using image_geometry.
+#
+# If the camera is uncalibrated, the matrices D, K, R, P should be left
+# zeroed out. In particular, clients may assume that K[0] == 0.0
+# indicates an uncalibrated camera.
+
+#######################################################################
+# Image acquisition info #
+#######################################################################
+
+# Time of image acquisition, camera coordinate frame ID
+std_msgs/Header header # Header timestamp should be acquisition time of image
+ builtin_interfaces/Time stamp
+ int32 sec
+ uint32 nanosec
+ string frame_id
+ # Header frame_id should be optical frame of camera
+ # origin of frame should be optical center of camera
+ # +x should point to the right in the image
+ # +y should point down in the image
+ # +z should point into the plane of the image
+
+
+#######################################################################
+# Calibration Parameters #
+#######################################################################
+# These are fixed during camera calibration. Their values will be the #
+# same in all messages until the camera is recalibrated. Note that #
+# self-calibrating systems may "recalibrate" frequently. #
+# #
+# The internal parameters can be used to warp a raw (distorted) image #
+# to: #
+# 1. An undistorted image (requires D and K) #
+# 2. A rectified image (requires D, K, R) #
+# The projection matrix P projects 3D points into the rectified image.#
+#######################################################################
+
+# The image dimensions with which the camera was calibrated.
+# Normally this will be the full camera resolution in pixels.
+uint32 height
+uint32 width
+
+# The distortion model used. Supported models are listed in
+# sensor_msgs/distortion_models.hpp. For most cameras, "plumb_bob" - a
+# simple model of radial and tangential distortion - is sufficent.
+string distortion_model
+
+# The distortion parameters, size depending on the distortion model.
+# For "plumb_bob", the 5 parameters are: (k1, k2, t1, t2, k3).
+float64[] d
+
+# Intrinsic camera matrix for the raw (distorted) images.
+# [fx 0 cx]
+# K = [ 0 fy cy]
+# [ 0 0 1]
+# Projects 3D points in the camera coordinate frame to 2D pixel
+# coordinates using the focal lengths (fx, fy) and principal point
+# (cx, cy).
+float64[9] k # 3x3 row-major matrix
+
+# Rectification matrix (stereo cameras only)
+# A rotation matrix aligning the camera coordinate system to the ideal
+# stereo image plane so that epipolar lines in both stereo images are
+# parallel.
+float64[9] r # 3x3 row-major matrix
+
+# Projection/camera matrix
+# [fx' 0 cx' Tx]
+# P = [ 0 fy' cy' Ty]
+# [ 0 0 1 0]
+# By convention, this matrix specifies the intrinsic (camera) matrix
+# of the processed (rectified) image. That is, the left 3x3 portion
+# is the normal camera intrinsic matrix for the rectified image.
+# It projects 3D points in the camera coordinate frame to 2D pixel
+# coordinates using the focal lengths (fx', fy') and principal point
+# (cx', cy') - these may differ from the values in K.
+# For monocular cameras, Tx = Ty = 0. Normally, monocular cameras will
+# also have R = the identity and P[1:3,1:3] = K.
+# For a stereo pair, the fourth column [Tx Ty 0]' is related to the
+# position of the optical center of the second camera in the first
+# camera's frame. We assume Tz = 0 so both cameras are in the same
+# stereo image plane. The first camera always has Tx = Ty = 0. For
+# the right (second) camera of a horizontal stereo pair, Ty = 0 and
+# Tx = -fx' * B, where B is the baseline between the cameras.
+# Given a 3D point [X Y Z]', the projection (x, y) of the point onto
+# the rectified image is given by:
+# [u v w]' = P * [X Y Z 1]'
+# x = u / w
+# y = v / w
+# This holds for both images of a stereo pair.
+float64[12] p # 3x4 row-major matrix
+
+
+#######################################################################
+# Operational Parameters #
+#######################################################################
+# These define the image region actually captured by the camera #
+# driver. Although they affect the geometry of the output image, they #
+# may be changed freely without recalibrating the camera. #
+#######################################################################
+
+# Binning refers here to any camera setting which combines rectangular
+# neighborhoods of pixels into larger "super-pixels." It reduces the
+# resolution of the output image to
+# (width / binning_x) x (height / binning_y).
+# The default values binning_x = binning_y = 0 is considered the same
+# as binning_x = binning_y = 1 (no subsampling).
+uint32 binning_x
+uint32 binning_y
+
+# Region of interest (subwindow of full camera resolution), given in
+# full resolution (unbinned) image coordinates. A particular ROI
+# always denotes the same window of pixels on the camera sensor,
+# regardless of binning settings.
+# The default setting of roi (all values 0) is considered the same as
+# full resolution (roi.width = width, roi.height = height).
+RegionOfInterest roi
+ #
+ uint32 x_offset #
+ # (0 if the ROI includes the left edge of the image)
+ uint32 y_offset #
+ # (0 if the ROI includes the top edge of the image)
+ uint32 height #
+ uint32 width #
+ bool do_rectify
+""",
+ "sensor_msgs/msg/Image": """
+# This message contains an uncompressed image
+# (0, 0) is at top-left corner of image
+
+std_msgs/Header header # Header timestamp should be acquisition time of image
+ builtin_interfaces/Time stamp
+ int32 sec
+ uint32 nanosec
+ string frame_id
+ # Header frame_id should be optical frame of camera
+ # origin of frame should be optical center of cameara
+ # +x should point to the right in the image
+ # +y should point down in the image
+ # +z should point into to plane of the image
+ # If the frame_id here and the frame_id of the CameraInfo
+ # message associated with the image conflict
+ # the behavior is undefined
+
+uint32 height # image height, that is, number of rows
+uint32 width # image width, that is, number of columns
+
+# The legal values for encoding are in file src/image_encodings.cpp
+# If you want to standardize a new string format, join
+# ros-users@lists.ros.org and send an email proposing a new encoding.
+
+string encoding # Encoding of pixels -- channel meaning, ordering, size
+ # taken from the list of strings in include/sensor_msgs/image_encodings.hpp
+
+uint8 is_bigendian # is this data bigendian?
+uint32 step # Full row length in bytes
+uint8[] data # actual matrix data, size is (step * rows)
+""",
+ "rosgraph_msgs/msg/Clock": """
+# This message communicates the current time.
+#
+# For more information, see https://design.ros2.org/articles/clock_and_time.html.
+builtin_interfaces/Time clock
+ int32 sec
+ uint32 nanosec
+""",
+ "rai_interfaces/msg/HRIMessage": """
+#
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+std_msgs/Header header
+ builtin_interfaces/Time stamp
+ int32 sec
+ uint32 nanosec
+ string frame_id
+string text
+sensor_msgs/Image[] images
+ std_msgs/Header header #
+ builtin_interfaces/Time stamp
+ int32 sec
+ uint32 nanosec
+ string frame_id
+ # Header frame_id should be optical frame of camera
+ # origin of frame should be optical center of cameara
+ # +x should point to the right in the image
+ # +y should point down in the image
+ # +z should point into to plane of the image
+ # If the frame_id here and the frame_id of the CameraInfo
+ # message associated with the image conflict
+ # the behavior is undefined
+ uint32 height #
+ uint32 width #
+ string encoding #
+ # taken from the list of strings in include/sensor_msgs/image_encodings.hpp
+ uint8 is_bigendian #
+ uint32 step #
+ uint8[] data #
+rai_interfaces/AudioMessage[] audios
+ #
+ #
+ #
+ #
+ #
+ int16[] audio
+ uint16 sample_rate
+ uint16 channels
+string communication_id
+int64 seq_no
+bool seq_end
+""",
+ "rai_interfaces/msg/AudioMessage": """
+#
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+int16[] audio
+uint16 sample_rate
+uint16 channels
+""",
+ "rai_interfaces/msg/RAIDetectionArray": """
+#
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# A list of 2D detections, for a multi-object 2D detector.
+std_msgs/Header header
+ builtin_interfaces/Time stamp
+ int32 sec
+ uint32 nanosec
+ string frame_id
+
+# A list of the detected proposals. A multi-proposal detector might generate
+# this list with many candidate detections generated from a single input.
+vision_msgs/Detection2D[] detections
+ #
+ std_msgs/Header header
+ builtin_interfaces/Time stamp
+ int32 sec
+ uint32 nanosec
+ string frame_id
+ ObjectHypothesisWithPose[] results
+ ObjectHypothesis hypothesis
+ string class_id
+ float64 score
+ geometry_msgs/PoseWithCovariance pose
+ Pose pose
+ Point position
+ float64 x
+ float64 y
+ float64 z
+ Quaternion orientation
+ float64 x 0
+ float64 y 0
+ float64 z 0
+ float64 w 1
+ float64[36] covariance
+ BoundingBox2D bbox
+ vision_msgs/Pose2D center
+ vision_msgs/Point2D position
+ float64 x
+ float64 y
+ float64 theta
+ float64 size_x
+ float64 size_y
+ string id
+# a list of classes being detected
+string[] detection_classes
+""",
+ "rai_interfaces/srv/ManipulatorMoveTo": """
+#
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# A simplified approach with binary states for the gripper
+bool initial_gripper_state
+bool final_gripper_state
+geometry_msgs/PoseStamped target_pose
+ std_msgs/Header header
+ builtin_interfaces/Time stamp
+ int32 sec
+ uint32 nanosec
+ string frame_id
+ Pose pose
+ Point position
+ float64 x
+ float64 y
+ float64 z
+ Quaternion orientation
+ float64 x 0
+ float64 y 0
+ float64 z 0
+ float64 w 1
+---
+bool success
+""",
+ "rai_interfaces/srv/RAIGroundedSam": """
+#
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+RAIDetectionArray detections
+ #
+ #
+ #
+ #
+ #
+ std_msgs/Header header
+ builtin_interfaces/Time stamp
+ int32 sec
+ uint32 nanosec
+ string frame_id
+ vision_msgs/Detection2D[] detections
+ #
+ std_msgs/Header header
+ builtin_interfaces/Time stamp
+ int32 sec
+ uint32 nanosec
+ string frame_id
+ ObjectHypothesisWithPose[] results
+ ObjectHypothesis hypothesis
+ string class_id
+ float64 score
+ geometry_msgs/PoseWithCovariance pose
+ Pose pose
+ Point position
+ float64 x
+ float64 y
+ float64 z
+ Quaternion orientation
+ float64 x 0
+ float64 y 0
+ float64 z 0
+ float64 w 1
+ float64[36] covariance
+ BoundingBox2D bbox
+ vision_msgs/Pose2D center
+ vision_msgs/Point2D position
+ float64 x
+ float64 y
+ float64 theta
+ float64 size_x
+ float64 size_y
+ string id
+ string[] detection_classes
+sensor_msgs/Image source_img
+ std_msgs/Header header #
+ builtin_interfaces/Time stamp
+ int32 sec
+ uint32 nanosec
+ string frame_id
+ # Header frame_id should be optical frame of camera
+ # origin of frame should be optical center of cameara
+ # +x should point to the right in the image
+ # +y should point down in the image
+ # +z should point into to plane of the image
+ # If the frame_id here and the frame_id of the CameraInfo
+ # message associated with the image conflict
+ # the behavior is undefined
+ uint32 height #
+ uint32 width #
+ string encoding #
+ # taken from the list of strings in include/sensor_msgs/image_encodings.hpp
+ uint8 is_bigendian #
+ uint32 step #
+ uint8[] data #
+---
+sensor_msgs/Image[] masks
+ std_msgs/Header header #
+ builtin_interfaces/Time stamp
+ int32 sec
+ uint32 nanosec
+ string frame_id
+ # Header frame_id should be optical frame of camera
+ # origin of frame should be optical center of cameara
+ # +x should point to the right in the image
+ # +y should point down in the image
+ # +z should point into to plane of the image
+ # If the frame_id here and the frame_id of the CameraInfo
+ # message associated with the image conflict
+ # the behavior is undefined
+ uint32 height #
+ uint32 width #
+ string encoding #
+ # taken from the list of strings in include/sensor_msgs/image_encodings.hpp
+ uint8 is_bigendian #
+ uint32 step #
+ uint8[] data #
+""",
+ "rai_interfaces/srv/RAIGroundingDino": """
+#
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+string classes
+float64 box_threshold
+float64 text_threshold
+sensor_msgs/Image source_img
+ std_msgs/Header header #
+ builtin_interfaces/Time stamp
+ int32 sec
+ uint32 nanosec
+ string frame_id
+ # Header frame_id should be optical frame of camera
+ # origin of frame should be optical center of cameara
+ # +x should point to the right in the image
+ # +y should point down in the image
+ # +z should point into to plane of the image
+ # If the frame_id here and the frame_id of the CameraInfo
+ # message associated with the image conflict
+ # the behavior is undefined
+ uint32 height #
+ uint32 width #
+ string encoding #
+ # taken from the list of strings in include/sensor_msgs/image_encodings.hpp
+ uint8 is_bigendian #
+ uint32 step #
+ uint8[] data #
+---
+RAIDetectionArray detections
+ #
+ #
+ #
+ #
+ #
+ std_msgs/Header header
+ builtin_interfaces/Time stamp
+ int32 sec
+ uint32 nanosec
+ string frame_id
+ vision_msgs/Detection2D[] detections
+ #
+ std_msgs/Header header
+ builtin_interfaces/Time stamp
+ int32 sec
+ uint32 nanosec
+ string frame_id
+ ObjectHypothesisWithPose[] results
+ ObjectHypothesis hypothesis
+ string class_id
+ float64 score
+ geometry_msgs/PoseWithCovariance pose
+ Pose pose
+ Point position
+ float64 x
+ float64 y
+ float64 z
+ Quaternion orientation
+ float64 x 0
+ float64 y 0
+ float64 z 0
+ float64 w 1
+ float64[36] covariance
+ BoundingBox2D bbox
+ vision_msgs/Pose2D center
+ vision_msgs/Point2D position
+ float64 x
+ float64 y
+ float64 theta
+ float64 size_x
+ float64 size_y
+ string id
+ string[] detection_classes
+""",
+ "rai_interfaces/srv/StringList": """
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Request - empty
+---
+# Response
+bool success
+string[] string_list
+""",
+ "rai_interfaces/srv/VectorStoreRetrieval": """
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Request
+string query
+
+---
+# Response
+bool success
+string message
+string[] documents
+float32[] scores
+""",
+ "rai_interfaces/srv/WhatISee": """z
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Request (empty)
+
+---
+# Response, timed with image timestamp
+string[] observations
+string perception_source
+sensor_msgs/Image image
+ std_msgs/Header header #
+ builtin_interfaces/Time stamp
+ int32 sec
+ uint32 nanosec
+ string frame_id
+ # Header frame_id should be optical frame of camera
+ # origin of frame should be optical center of cameara
+ # +x should point to the right in the image
+ # +y should point down in the image
+ # +z should point into to plane of the image
+ # If the frame_id here and the frame_id of the CameraInfo
+ # message associated with the image conflict
+ # the behavior is undefined
+ uint32 height #
+ uint32 width #
+ string encoding #
+ # taken from the list of strings in include/sensor_msgs/image_encodings.hpp
+ uint8 is_bigendian #
+ uint32 step #
+ uint8[] data #
+geometry_msgs/Pose pose
+ Point position
+ float64 x
+ float64 y
+ float64 z
+ Quaternion orientation
+ float64 x 0
+ float64 y 0
+ float64 z 0
+ float64 w 1
+""",
+ "rai_interfaces/action/Task": """
+# Goal
+string task
+string description
+string priority
+
+---
+# Result
+bool success
+string report
+
+---
+# Feedback
+string current_status
+""",
+ "/load_map": """
+string filename
+---
+bool success
+""",
+ "/query_planner_interface": """
+---
+
+# The planning instances that could be used in the benchmark
+PlannerInterfaceDescription[] planner_interfaces
+ string name
+ string pipeline_id
+ string[] planner_ids
+
+""",
+}
+
class Result(BaseModel):
success: bool = False
@@ -89,6 +809,167 @@ def verify_tool_calls(self, response: dict[str, Any]):
"""
pass
+ def _check_topic_tool_call_field(
+ self,
+ tool_call: ToolCall,
+ expected_name: str,
+ expected_topic: str,
+ expected_message_type: str,
+ field_path: str,
+ expected_value: Any,
+ ) -> bool:
+ """
+ Verifies a tool call for a topic publishing operation.
+
+ Parameters
+ ----------
+ tool_call : ToolCall
+ The tool call dictionary containing keys such as "name" and "args".
+ expected_name : str
+ The expected tool call name (e.g., "publish_ros2_message").
+ expected_topic : str
+ The expected topic name in the tool call's arguments.
+ expected_message_type : str
+ The expected message type (e.g., "rai_interfaces/msg/HRIMessage").
+ field_path : str
+ Dot-separated path to the field inside the message (e.g., "header.frame_id").
+ expected_value : Any
+ The expected value at the given field path.
+
+ Returns
+ -------
+ bool
+ True if all conditions are met; False otherwise.
+ """
+ # Check tool call name.
+ if tool_call.get("name") != expected_name:
+ self.log_error(
+ f"Expected tool call name '{expected_name}', but got '{tool_call.get('name')}'."
+ )
+ return False
+
+ args = tool_call.get("args", {})
+
+ # Check topic.
+ if args.get("topic") != expected_topic:
+ self.log_error(
+ f"Expected topic '{expected_topic}', but got '{args.get('topic')}'."
+ )
+ return False
+
+ # Check message type.
+ if args.get("message_type") != expected_message_type:
+ self.log_error(
+ f"Expected message type '{expected_message_type}', but got '{args.get('message_type')}'."
+ )
+ return False
+
+ # Traverse the message field.
+ message = args.get("message")
+ if message is None:
+ self.log_error("Tool call does not contain a 'message' argument.")
+ return False
+
+ keys = field_path.split(".")
+ value: Any = message
+ for key in keys:
+ if isinstance(value, dict) and key in value:
+ value = value[key]
+ else:
+ self.log_error(f"Field path '{field_path}' not found in the message.")
+ return False
+
+ if value != expected_value:
+ self.log_error(
+ f"Expected value for field '{field_path}' is '{expected_value}', but got '{value}'."
+ )
+ return False
+
+ return True
+
+ def _check_service_tool_call_field(
+ self,
+ tool_call: ToolCall,
+ expected_name: str,
+ expected_service: str,
+ expected_service_type: str,
+ field_path: str,
+ expected_value: Any,
+ ) -> bool:
+ """
+ Verifies a tool call for a service call.
+
+ Parameters
+ ----------
+ tool_call : ToolCall
+ The tool call dictionary containing keys such as "name" and "args".
+ expected_name : str
+ The expected tool call name (e.g., "call_ros2_service").
+ expected_service : str
+ The expected service name in the tool call's arguments.
+ expected_message_type : str
+ The expected message type.
+ field_path : str
+ Dot-separated path to the field inside the message.
+ expected_value : Any
+ The expected value at the given field path.
+
+ Returns
+ -------
+ bool
+ True if all conditions are met; False otherwise.
+ """
+ if tool_call.get("name") != expected_name:
+ self.log_error(
+ f"Expected tool call name '{expected_name}', but got '{tool_call.get('name')}'."
+ )
+ return False
+
+ args = tool_call.get("args", {})
+
+ # Check service.
+ if args.get("service_name") != expected_service:
+ self.log_error(
+ f"Expected service '{expected_service}', but got '{args.get('service')}'."
+ )
+ return False
+
+ # Check message type.
+ if args.get("service_type") != expected_service_type:
+ self.log_error(
+ f"Expected message type '{expected_service_type}', but got '{args.get('service_type')}'."
+ )
+ return False
+
+ service_args = args.get("service_args")
+ if service_args is None:
+ self.log_error("Tool call does not contain a 'service_args' argument.")
+ return False
+
+ if field_path == "":
+ if service_args == {}:
+ return True
+ else:
+ self.log_error(f"Expected empty service_args, but got: {service_args}")
+ return False
+
+ keys = field_path.split(".")
+ value: Any = service_args
+ for key in keys:
+ if isinstance(value, dict) and key in value:
+ value = value[key]
+ else:
+ self.log_error(f"Field path '{field_path}' not found in the message.")
+ return False
+
+ if value != expected_value:
+ self.log_error(
+ f"Expected value for field '{field_path}' is '{expected_value}', but got '{value}'."
+ )
+ return False
+
+ return True
+
def _check_tool_call(
self,
tool_call: ToolCall,
@@ -274,3 +1155,438 @@ def _is_ai_message_requesting_get_ros2_topics_and_types(
):
return False
return True
+
+ def _is_ai_message_requesting_get_ros2_services_and_types(
+ self, ai_message: AIMessage
+ ) -> bool:
+ """Helper method to check if the given AIMessage is calling the exactly one tool that gets ROS2 service names and types correctly.
+
+ Parameters
+ ----------
+ ai_message : AIMessage
+ The AIMessage to check
+
+ Returns
+ -------
+ bool
+ True if the ai_message is requesting get_ros2_service_names_and_types correctly, False otherwise
+ """
+ if not self._check_tool_calls_num_in_ai_message(ai_message, expected_num=1):
+ return False
+
+ tool_call: ToolCall = ai_message.tool_calls[0]
+ if not self._check_tool_call(
+ tool_call=tool_call,
+ expected_name="get_ros2_services_names_and_types",
+ expected_args={},
+ ):
+ return False
+ return True
+
+ def _is_ai_message_requesting_get_ros2_actions_and_types(
+ self, ai_message: AIMessage
+ ) -> bool:
+ """Helper method to check if the given AIMessage is calling the exactly one tool that gets ROS2 actions names and types correctly.
+
+ Parameters
+ ----------
+ ai_message : AIMessage
+ The AIMessage to check
+
+ Returns
+ -------
+ bool
+ True if the ai_message is requesting get_ros2_actions_names_and_types correctly, False otherwise
+ """
+ if not self._check_tool_calls_num_in_ai_message(ai_message, expected_num=1):
+ return False
+
+ tool_call: ToolCall = ai_message.tool_calls[0]
+ if not self._check_tool_call(
+ tool_call=tool_call,
+ expected_name="get_ros2_actions_names_and_types",
+ expected_args={},
+ ):
+ return False
+ return True
+
+ def get_tool_calls(self, response: dict[str, Any]) -> list[ToolCall]:
+ """Extracts all tool calls from the response, flattened across all AI messages."""
+ tool_calls: List[ToolCall] = []
+ for msg in response["messages"]:
+ if isinstance(msg, AIMessage):
+ tool_calls.extend(msg.tool_calls)
+ return tool_calls
+
+
# Mocked service discovery data: service name -> service type.
SERVICES_AND_TYPES = {
    # sample interfaces
    # "/load_map": "moveit_msgs/srv/LoadMap",
    # "/query_planner_interface": "moveit_msgs/srv/QueryPlannerInterfaces",
    # custom interfaces
    "/manipulator_move_to": "rai_interfaces/srv/ManipulatorMoveTo",
    "/grounded_sam_segment": "rai_interfaces/srv/RAIGroundedSam",
    "/grounding_dino_classify": "rai_interfaces/srv/RAIGroundingDino",
    "/get_log_digest": "rai_interfaces/srv/StringList",
    "/rai_whoami_documentation_service": "rai_interfaces/srv/VectorStoreRetrieval",
    "/rai/whatisee/get": "rai_interfaces/srv/WhatISee",
}

# Service type -> pydantic model used to validate the service request payload.
SERVICE_MODELS: Dict[str, Type[BaseModel]] = {
    "rai_interfaces/srv/ManipulatorMoveTo": ManipulatorMoveToRequest,
    "rai_interfaces/srv/RAIGroundedSam": RAIGroundedSamRequest,
    "rai_interfaces/srv/RAIGroundingDino": RAIGroundingDinoRequest,
    "rai_interfaces/srv/StringList": StringListRequest,
    "rai_interfaces/srv/VectorStoreRetrieval": VectorStoreRetrievalRequest,
    "rai_interfaces/srv/WhatISee": WhatISeeRequest,
}

# Mocked topic discovery data: topic name -> message type.
TOPICS_AND_TYPES: Dict[str, str] = {
    # sample topics
    "/camera_image_color": "sensor_msgs/msg/Image",
    "/camera_image_depth": "sensor_msgs/msg/Image",
    "/clock": "rosgraph_msgs/msg/Clock",
    "/color_camera_info": "sensor_msgs/msg/CameraInfo",
    "/color_camera_info5": "sensor_msgs/msg/CameraInfo",
    "/depth_camera_info5": "sensor_msgs/msg/CameraInfo",
    "/depth_image5": "sensor_msgs/msg/Image",
    # custom topics
    "/to_human": "rai_interfaces/msg/HRIMessage",
    "/send_audio": "rai_interfaces/msg/AudioMessage",
    "/send_detections": "rai_interfaces/msg/RAIDetectionArray",
}
# Pre-rendered "topic: <name>\ntype: <type>\n" strings fed to the mock discovery tool.
TOPIC_STRINGS = [
    f"topic: {topic}\ntype: {msg_type}\n"
    for topic, msg_type in TOPICS_AND_TYPES.items()
]
# Message type -> pydantic model used to validate published message payloads.
TOPIC_MODELS: Dict[str, Type[BaseModel]] = {
    "sensor_msgs/msg/CameraInfo": CameraInfo,
    "sensor_msgs/msg/Image": Image,
    "rosgraph_msgs/msg/Clock": Clock,
    "rai_interfaces/msg/HRIMessage": HRIMessage,
    "rai_interfaces/msg/AudioMessage": AudioMessage,
    "rai_interfaces/msg/RAIDetectionArray": RAIDetectionArray,
}

# Topics exposed to the mock image tool.
# NOTE(review): despite the name, this includes non-image topics (clock,
# camera info, collision objects) — presumably deliberate so the agent must
# pick the correct topic; confirm before renaming or pruning.
IMAGE_TOPICS: Dict[str, str] = {
    "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject",
    "/camera_image_color": "sensor_msgs/msg/Image",
    "/camera_image_depth": "sensor_msgs/msg/Image",
    "/clock": "rosgraph_msgs/msg/Clock",
    "/collision_object": "moveit_msgs/msg/CollisionObject",
    "/color_camera_info": "sensor_msgs/msg/CameraInfo",
    "/color_camera_info5": "sensor_msgs/msg/CameraInfo",
    "/depth_camera_info5": "sensor_msgs/msg/CameraInfo",
}

# Pre-rendered "service: <name>\ntype: <type>\n" strings fed to the mock discovery tool.
SERVICE_STRINGS = [
    f"service: {service}\ntype: {msg_type}\n"
    for service, msg_type in SERVICES_AND_TYPES.items()
]
+
+
class CustomInterfacesTopicTask(ROS2ToolCallingAgentTask, ABC):
    """Base class for tasks expecting the agent to publish to a custom-interface topic.

    Subclasses specify the target topic via :attr:`expected_topic` and the
    expected message content via :meth:`verify_message_tool_call`; this base
    class wires up the mocked ROS2 tools and drives the verification flow.
    """

    def __init__(self, logger: loggers_type | None = None) -> None:
        super().__init__(logger=logger)

        # Full mocked toolbox: topic/service discovery, interface inspection,
        # publishing, service calls, plus unrelated tools the agent must ignore.
        # NOTE(review): MockGetROS2MessageInterfaceTool is registered twice —
        # presumably unintentional; confirm before deduplicating.
        self.expected_tools: List[BaseTool] = [
            MockGetROS2TopicsNamesAndTypesTool(
                mock_topics_names_and_types=TOPIC_STRINGS
            ),
            MockGetROS2MessageInterfaceTool(mock_interfaces=MOCK_INTERFACES),
            MockPublishROS2MessageTool(
                available_topics=list(TOPICS_AND_TYPES.keys()),
                available_message_types=list(TOPICS_AND_TYPES.values()),
                available_topic_models=TOPIC_MODELS,
            ),
            MockCancelROS2ActionTool(),
            MockGetROS2ActionIDsTool(),
            MockMoveToPointTool(manipulator_frame="base_link"),
            MockGetROS2ImageTool(available_topics=list(IMAGE_TOPICS.keys())),
            MockGetROS2ServicesNamesAndTypesTool(
                mock_service_names_and_types=SERVICE_STRINGS
            ),
            MockGetROS2MessageInterfaceTool(mock_interfaces=MOCK_INTERFACES),
            MockCallROS2ServiceTool(
                available_services=list(SERVICES_AND_TYPES.keys()),
                available_service_types=list(SERVICES_AND_TYPES.values()),
                available_service_models=SERVICE_MODELS,
            ),
        ]

    @property
    @abstractmethod
    def expected_topic(self) -> str:
        """Topic the agent is expected to publish to (e.g. ``/to_human``)."""
        pass

    @property
    def expected_message_type(self) -> str:
        """Message type of :attr:`expected_topic`, looked up in TOPICS_AND_TYPES."""
        return TOPICS_AND_TYPES[self.expected_topic]

    @property
    def extra_calls(self) -> int:
        """How many tool calls beyond the core sequence are tolerated."""
        return 0

    def verify_list_and_get_interface_tool_calls(
        self, tool_calls: List[ToolCall]
    ) -> tuple[bool, list[ToolCall]]:
        """
        Verifies tool calls in this required order:
        1. get_ros2_topics_and_types
        2. get_ros2_message_interface (with correct msg_type)

        Returns
        -------
        tuple[bool, list[ToolCall]]
            Success flag and the remaining tool calls (to be used in
            `verify_message_tool_call`).
        """
        # Core sequence: list topics, get interface, publish.
        expected_core_calls = 3
        max_allowed = expected_core_calls + self.extra_calls
        if len(tool_calls) > max_allowed:
            self.log_error(
                f"Too many tool calls. Expected at most {max_allowed}, got {len(tool_calls)}."
            )
            return False, []

        stage = 0  # 0: expect topics, 1: expect interface
        for idx, call in enumerate(tool_calls):
            if stage == 0 and call["name"] == "get_ros2_topics_names_and_types":
                stage = 1
                continue

            if stage == 1 and call["name"] == "get_ros2_message_interface":
                if call["args"].get("msg_type") == self.expected_message_type:
                    # Order satisfied; hand the remaining calls to the caller.
                    return True, tool_calls[idx + 1 :]

        self.log_error("Required tool calls not found in order: topics → interface")
        return False, []

    @abstractmethod
    def verify_message_tool_call(self, tool_calls: List[ToolCall]) -> bool:
        """
        Search the remaining tool calls for the expected publish/service tool call.
        """
        pass

    def verify_tool_calls(self, response: dict[str, Any]):
        """
        Validates the full sequence of AI tool calls with support for extras and ordering.

        Steps:
        1. Get topics
        2. Get message interface
        3. Call publish/service with expected content
        """
        messages = response["messages"]
        ai_messages: Sequence[AIMessage] = [
            msg for msg in messages if isinstance(msg, AIMessage)
        ]
        self.logger.debug(f"AI messages: {ai_messages}")
        tool_calls = self.get_tool_calls(response)

        # TODO(review): strict ordered verification is currently disabled;
        # only the final publish/service call is checked.
        # success, remaining_tool_calls = self.verify_list_and_get_interface_tool_calls(
        #     tool_calls
        # )
        # if success and self.verify_message_tool_call(remaining_tool_calls):
        #     self.result.success = True
        if self.verify_message_tool_call(tool_calls):
            self.result.success = True
+
+
class CustomInterfacesServiceTask(ROS2ToolCallingAgentTask, ABC):
    """Base class for tasks expecting the agent to call a custom-interface service.

    Subclasses specify the target service via :attr:`expected_service` and the
    expected request content via :meth:`verify_message_tool_call`.
    """

    def __init__(self, logger: loggers_type | None = None) -> None:
        super().__init__(logger=logger)
        # Full mocked toolbox (same set as the topic tasks) so the agent must
        # choose the right tool among many.
        # NOTE(review): MockGetROS2MessageInterfaceTool is registered twice —
        # presumably unintentional; confirm before deduplicating.
        self.expected_tools: List[BaseTool] = [
            MockGetROS2TopicsNamesAndTypesTool(
                mock_topics_names_and_types=TOPIC_STRINGS
            ),
            MockGetROS2MessageInterfaceTool(mock_interfaces=MOCK_INTERFACES),
            MockPublishROS2MessageTool(
                available_topics=list(TOPICS_AND_TYPES.keys()),
                available_message_types=list(TOPICS_AND_TYPES.values()),
                available_topic_models=TOPIC_MODELS,
            ),
            MockCancelROS2ActionTool(),
            MockGetROS2ActionIDsTool(),
            MockMoveToPointTool(manipulator_frame="base_link"),
            MockGetROS2ImageTool(available_topics=list(IMAGE_TOPICS.keys())),
            MockGetROS2ServicesNamesAndTypesTool(
                mock_service_names_and_types=SERVICE_STRINGS
            ),
            MockGetROS2MessageInterfaceTool(mock_interfaces=MOCK_INTERFACES),
            MockCallROS2ServiceTool(
                available_services=list(SERVICES_AND_TYPES.keys()),
                available_service_types=list(SERVICES_AND_TYPES.values()),
                available_service_models=SERVICE_MODELS,
            ),
        ]

    @property
    @abstractmethod
    def expected_service(self) -> str:
        """Service the agent is expected to call (e.g. ``/get_log_digest``)."""
        pass

    @property
    def expected_service_type(self) -> str:
        """Service type of :attr:`expected_service`, looked up in SERVICES_AND_TYPES."""
        return SERVICES_AND_TYPES[self.expected_service]

    @property
    def extra_calls(self) -> int:
        """How many tool calls beyond the core sequence are tolerated."""
        return 0

    def verify_list_and_get_interface_tool_calls(
        self, tool_calls: List[ToolCall]
    ) -> tuple[bool, list[ToolCall]]:
        """
        Verifies tool calls in this required order:
        1. get_ros2_services_names_and_types
        2. get_ros2_message_interface (with correct msg_type)

        Returns
        -------
        tuple[bool, list[ToolCall]]
            Success flag and remaining tool calls for message verification.
        """
        # Core sequence: list services, get interface, call service.
        expected_core_calls = 3
        max_allowed = expected_core_calls + self.extra_calls
        if len(tool_calls) > max_allowed:
            self.log_error(
                f"Too many tool calls. Expected at most {max_allowed}, got {len(tool_calls)}."
            )
            return False, []

        stage = 0  # 0: expect service list, 1: expect interface
        for idx, call in enumerate(tool_calls):
            if stage == 0 and call["name"] == "get_ros2_services_names_and_types":
                stage = 1
                continue

            if stage == 1 and call["name"] == "get_ros2_message_interface":
                if call["args"].get("msg_type") == self.expected_service_type:
                    # Order satisfied; hand the remaining calls to the caller.
                    return True, tool_calls[idx + 1 :]

        self.log_error("Required tool calls not found in order: services → interface")
        return False, []

    @abstractmethod
    def verify_message_tool_call(self, tool_calls: List[ToolCall]) -> bool:
        """Search the remaining tool calls for the expected service call."""
        pass

    def verify_tool_calls(self, response: dict[str, Any]):
        """
        Full tool call sequence verification:
        1. Get services
        2. Get message interface
        3. Call service with expected values
        """
        messages = response["messages"]
        ai_messages: Sequence[AIMessage] = [
            msg for msg in messages if isinstance(msg, AIMessage)
        ]
        self.logger.debug(f"AI messages: {ai_messages}")
        tool_calls = self.get_tool_calls(response)

        # TODO(review): strict ordered verification is currently disabled;
        # only the final service call is checked. Re-enable via:
        # success, remaining_tool_calls = self.verify_list_and_get_interface_tool_calls(
        #     tool_calls
        # )
        # if success and self.verify_message_tool_call(remaining_tool_calls): ...
        if self.verify_message_tool_call(tool_calls):
            self.result.success = True
+
+
class CustomInterfacesActionTask(ROS2ToolCallingAgentTask, ABC):
    """Base class for tasks expecting the agent to start a custom-interface ROS2 action."""

    ACTIONS_AND_TYPES = {
        # custom actions
        "/perform_task": "rai_interfaces/action/Task",
        # some sample actions
        # "/execute_trajectory": "moveit_msgs/action/ExecuteTrajectory",
        # "/move_action": "moveit_msgs/action/MoveGroup",
        # "/follow_joint_trajectory": "control_msgs/action/FollowJointTrajectory",
        # "/gripper_cmd": "control_msgs/action/GripperCommand",
    }

    # Pre-rendered "action: <name>\ntype: <type>\n" strings fed to the mock discovery tool.
    action_strings = [
        f"action: {action}\ntype: {msg_type}\n"
        for action, msg_type in ACTIONS_AND_TYPES.items()
    ]

    def __init__(self, logger: loggers_type | None = None) -> None:
        super().__init__(logger=logger)
        self.expected_tools: List[BaseTool] = [
            MockGetROS2ActionsNamesAndTypesTool(
                mock_actions_names_and_types=self.action_strings
            ),
            MockGetROS2MessageInterfaceTool(mock_interfaces=MOCK_INTERFACES),
            MockStartROS2ActionTool(
                available_actions=list(self.ACTIONS_AND_TYPES.keys()),
                available_action_types=list(self.ACTIONS_AND_TYPES.values()),
            ),
        ]

    @property
    @abstractmethod
    def expected_action(self) -> str:
        """Action name the agent is expected to start (e.g. ``/perform_task``)."""
        pass

    @property
    @abstractmethod
    def expected_message(self) -> BaseModel:
        """Goal payload the agent is expected to send when starting the action."""
        pass

    @property
    def expected_action_type(self) -> str:
        """Action type of :attr:`expected_action`, looked up in ACTIONS_AND_TYPES."""
        return self.ACTIONS_AND_TYPES[self.expected_action]

    def verify_tool_calls(self, response: dict[str, Any]):
        """It is expected that the agent will request:
        1. The tool that retrieves the action names and types, to find the expected action's type
        2. The tool that retrieves message interfaces, to inspect the action's goal definition
        3. The tool that starts the action with the proper name, type and goal content

        Parameters
        ----------
        response : dict[str, Any]
            The response from the agent
        """
        messages = response["messages"]
        ai_messages: Sequence[AIMessage] = [
            message for message in messages if isinstance(message, AIMessage)
        ]
        self.logger.debug(ai_messages)
        # 3 tool-calling messages plus the final answer.
        if len(ai_messages) != 4:
            self.log_error(
                msg=f"Expected exactly 4 AI messages, but got {len(ai_messages)}."
            )
        if ai_messages:
            if not self._is_ai_message_requesting_get_ros2_actions_and_types(
                ai_messages[0]
            ):
                # BUG FIX: the message previously said "topics and types"
                # although this task verifies the action discovery call.
                self.log_error(
                    msg="First AI message did not request ROS2 actions and types correctly."
                )
        if len(ai_messages) > 1:
            if self._check_tool_calls_num_in_ai_message(ai_messages[1], expected_num=1):
                self._check_tool_call(
                    tool_call=ai_messages[1].tool_calls[0],
                    expected_name="get_ros2_message_interface",
                    expected_args={"msg_type": self.expected_action_type},
                )

        if len(ai_messages) > 2:
            if self._check_tool_calls_num_in_ai_message(ai_messages[2], expected_num=1):
                self._check_tool_call(
                    tool_call=ai_messages[2].tool_calls[0],
                    expected_name="start_ros2_action",
                    expected_args={
                        "action_name": self.expected_action,
                        "action_args": self.expected_message.model_dump(),
                        "action_type": self.expected_action_type,
                    },
                )
        # Success only when no verification step logged an error.
        if not self.result.errors:
            self.result.success = True
diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/__init__.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/__init__.py
new file mode 100644
index 000000000..97ceef6f0
--- /dev/null
+++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/actions.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/actions.py
new file mode 100644
index 000000000..9fd286938
--- /dev/null
+++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/actions.py
@@ -0,0 +1,40 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from pydantic import BaseModel
+
+
class TaskGoal(BaseModel):
    """Goal section of the `rai_interfaces/action/Task` action (task, description, priority)."""

    # Defaults are empty strings so a partially-specified goal still validates.
    task: Optional[str] = ""
    description: Optional[str] = ""
    priority: Optional[str] = ""
+
+
class TaskResult(BaseModel):
    """Result section of the `rai_interfaces/action/Task` action."""

    success: Optional[bool] = False
    report: Optional[str] = ""
+
+
class TaskFeedback(BaseModel):
    """Feedback section of the `rai_interfaces/action/Task` action."""

    current_status: Optional[str] = ""
+
+
class LoadMapRequest(BaseModel):
    """Request of the `/load_map` service: map file to load."""

    filename: Optional[str] = ""
+
+
class LoadMapResponse(BaseModel):
    """Response of the `/load_map` service."""

    success: Optional[bool] = False
diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/base.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/base.py
new file mode 100644
index 000000000..ccfa21f21
--- /dev/null
+++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/base.py
@@ -0,0 +1,101 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional
+
+from pydantic import BaseModel
+
+
+# TODO (jm) redundant with action models, remove action models later
class Time(BaseModel):
    """`builtin_interfaces/Time`: timestamp as seconds + nanoseconds."""

    sec: Optional[int] = 0
    nanosec: Optional[int] = 0
+
+
class Header(BaseModel):
    """`std_msgs/Header`: timestamp plus coordinate frame id."""

    stamp: Optional[Time] = Time()
    frame_id: Optional[str] = ""
+
+
class RegionOfInterest(BaseModel):
    """Rectangular image sub-region (offset, size, rectify flag).

    Presumably mirrors `sensor_msgs/RegionOfInterest` — confirm against the
    interface definition.
    """

    x_offset: Optional[int] = 0
    y_offset: Optional[int] = 0
    height: Optional[int] = 0
    width: Optional[int] = 0
    do_rectify: Optional[bool] = False
+
+
class Position(BaseModel):
    """3D point (`geometry_msgs/Point`-style x, y, z)."""

    x: Optional[float] = 0.0
    y: Optional[float] = 0.0
    z: Optional[float] = 0.0
+
+
class Orientation(BaseModel):
    """Quaternion orientation; defaults (0, 0, 0, 1) encode the identity rotation."""

    x: Optional[float] = 0.0
    y: Optional[float] = 0.0
    z: Optional[float] = 0.0
    w: Optional[float] = 1.0
+
+
class Pose(BaseModel):
    """Position + orientation (`geometry_msgs/Pose`-style)."""

    # NOTE: pydantic copies field defaults per instance, so sharing default
    # model instances here is safe — verify for the pydantic version in use.
    position: Optional[Position] = Position()
    orientation: Optional[Orientation] = Orientation()
+
+
class PoseStamped(BaseModel):
    """A Pose together with a Header (timestamp + frame)."""

    header: Optional[Header] = Header()
    pose: Optional[Pose] = Pose()
+
+
class Clock(BaseModel):
    """`rosgraph_msgs/Clock`: current simulation/system time."""

    clock: Optional[Time] = Time()
+
+
class ObjectHypothesis(BaseModel):
    """`vision_msgs/ObjectHypothesis`: class label with confidence score."""

    class_id: Optional[str] = ""
    score: Optional[float] = 0.0
+
+
class PoseWithCovariance(BaseModel):
    """Pose plus its 6x6 covariance, flattened to 36 floats (row-major)."""

    pose: Optional[Pose] = Pose()
    covariance: Optional[List[float]] = [0.0] * 36
+
+
class ObjectHypothesisWithPose(BaseModel):
    """A class hypothesis paired with the estimated pose of the object."""

    hypothesis: Optional[ObjectHypothesis] = ObjectHypothesis()
    pose: Optional[PoseWithCovariance] = PoseWithCovariance()
+
+
class Point2D(BaseModel):
    """2D point (`vision_msgs/Point2D`-style x, y)."""

    x: Optional[float] = 0.0
    y: Optional[float] = 0.0
+
+
class Pose2D(BaseModel):
    """2D pose: point position plus rotation angle theta (`vision_msgs/Pose2D`-style)."""

    position: Optional[Point2D] = Point2D()
    theta: Optional[float] = 0.0
+
+
class BoundingBox2D(BaseModel):
    """2D bounding box given by its center pose and extents."""

    center: Optional[Pose2D] = Pose2D()
    size_x: Optional[float] = 0.0
    size_y: Optional[float] = 0.0
+
+
class Detection2D(BaseModel):
    """`vision_msgs/Detection2D`: hypotheses plus bounding box for one detection."""

    header: Optional[Header] = Header()
    results: Optional[List[ObjectHypothesisWithPose]] = []
    bbox: Optional[BoundingBox2D] = BoundingBox2D()
    id: Optional[str] = ""
diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/services.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/services.py
new file mode 100644
index 000000000..c60cdbcf2
--- /dev/null
+++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/services.py
@@ -0,0 +1,91 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional
+
+from pydantic import BaseModel
+
+from rai_bench.tool_calling_agent_bench.messages.base import Pose, PoseStamped
+from rai_bench.tool_calling_agent_bench.messages.topics import Image, RAIDetectionArray
+
+
+class ManipulatorMoveToRequest(BaseModel):
+ initial_gripper_state: Optional[bool] = False
+ final_gripper_state: Optional[bool] = False
+ target_pose: Optional[PoseStamped] = PoseStamped()
+
+
+class ManipulatorMoveToResponse(BaseModel):
+ success: Optional[bool] = False
+
+
+class RAIGroundedSamRequest(BaseModel):
+ detections: Optional[RAIDetectionArray] = RAIDetectionArray()
+ source_img: Optional[Image] = Image()
+
+
+class RAIGroundedSamResponse(BaseModel):
+ masks: Optional[List[Image]] = []
+
+
+class RAIGroundingDinoRequest(BaseModel):
+ classes: Optional[str] = ""
+ box_threshold: Optional[float] = 0.0
+ text_threshold: Optional[float] = 0.0
+ source_img: Optional[Image] = Image()
+
+
+class RAIGroundingDinoResponse(BaseModel):
+ detections: Optional[RAIDetectionArray] = RAIDetectionArray()
+
+
+class StringListRequest(BaseModel):
+ pass
+
+
+class StringListResponse(BaseModel):
+ success: Optional[bool] = False
+ string_list: Optional[List[str]] = []
+
+
+class VectorStoreRetrievalRequest(BaseModel):
+ query: Optional[str] = ""
+
+
+class VectorStoreRetrievalResponse(BaseModel):
+ success: Optional[bool] = False
+ message: Optional[str] = ""
+ documents: Optional[List[str]] = []
+ scores: Optional[List[float]] = []
+
+
+class WhatISeeRequest(BaseModel):
+ pass
+
+
+class WhatISeeResponse(BaseModel):
+ observations: Optional[List[str]] = []
+ perception_source: Optional[str] = ""
+ image: Optional[Image] = Image()
+ pose: Optional[Pose] = Pose()
+
+
+class PlannerInterfaceDescription(BaseModel):
+ name: Optional[str] = ""
+ pipeline_id: Optional[str] = ""
+ planner_ids: Optional[List[str]] = []
+
+
+class QueryPlannerInterfaceResponse(BaseModel):
+ planner_interfaces: Optional[List[PlannerInterfaceDescription]] = []
diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/topics.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/topics.py
new file mode 100644
index 000000000..1a68d6d77
--- /dev/null
+++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/messages/topics.py
@@ -0,0 +1,69 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional
+
+from pydantic import BaseModel
+
+from rai_bench.tool_calling_agent_bench.messages.base import (
+ Detection2D,
+ Header,
+ RegionOfInterest,
+)
+
+
+class CameraInfo(BaseModel):
+ header: Optional[Header] = Header()
+ height: Optional[int] = 0
+ width: Optional[int] = 0
+ distortion_model: Optional[str] = ""
+ d: Optional[List[float]] = []
+ k: Optional[List[float]] = [0.0] * 9
+ r: Optional[List[float]] = [0.0] * 9
+ p: Optional[List[float]] = [0.0] * 12
+ binning_x: Optional[int] = 0
+ binning_y: Optional[int] = 0
+ roi: Optional[RegionOfInterest] = RegionOfInterest()
+
+
+class Image(BaseModel):
+ header: Optional[Header] = Header()
+ height: Optional[int] = 0
+ width: Optional[int] = 0
+ encoding: Optional[str] = ""
+ is_bigendian: Optional[int] = 0
+ step: Optional[int] = 0
+ data: Optional[List[int]] = []
+
+
+class AudioMessage(BaseModel):
+ audio: Optional[List[int]] = []
+ sample_rate: Optional[int] = 0
+ channels: Optional[int] = 0
+
+
+class HRIMessage(BaseModel):
+ header: Optional[Header] = Header()
+ text: Optional[str] = ""
+ images: Optional[List[Image]] = []
+ audios: Optional[List[AudioMessage]] = []
+ communication_id: Optional[str] = ""
+ seq_no: Optional[int] = 0
+ seq_end: Optional[bool] = False
+
+
+class RAIDetectionArray(BaseModel):
+ header: Optional[Header] = Header()
+ detections: Optional[List[Detection2D]] = []
+ detection_classes: Optional[List[str]] = []
diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/mocked_tools.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/mocked_tools.py
index cdaad998b..ca8baace0 100644
--- a/src/rai_bench/rai_bench/tool_calling_agent_bench/mocked_tools.py
+++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/mocked_tools.py
@@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Tuple
+import uuid
+from threading import Lock
+from typing import Any, Dict, List, Tuple, Type
from unittest.mock import MagicMock
import numpy as np
import numpy.typing as npt
+from pydantic import BaseModel, ValidationError
from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.communication.ros2.messages import ROS2ARIMessage
from rai.messages import MultimodalArtifact, preprocess_image
@@ -26,11 +29,23 @@
MoveToPointTool,
)
from rai.tools.ros2 import (
+ CallROS2ServiceTool,
+ CancelROS2ActionTool,
+ GetROS2ActionFeedbackTool,
+ GetROS2ActionIDsTool,
+ GetROS2ActionResultTool,
+ GetROS2ActionsNamesAndTypesTool,
GetROS2ImageTool,
+ GetROS2MessageInterfaceTool,
+ GetROS2ServicesNamesAndTypesTool,
GetROS2TopicsNamesAndTypesTool,
+ PublishROS2MessageTool,
ReceiveROS2MessageTool,
+ StartROS2ActionTool,
)
+from rai_bench.tool_calling_agent_bench.actions.action_base_model import ActionBaseModel
+
class MockGetROS2TopicsNamesAndTypesTool(GetROS2TopicsNamesAndTypesTool):
connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector)
@@ -49,7 +64,7 @@ def _run(self) -> str:
class MockGetROS2ImageTool(GetROS2ImageTool):
connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector)
- expected_topics: List[str]
+ available_topics: List[str]
def _run(
self, topic: str, timeout_sec: float = 1.0
@@ -73,7 +88,7 @@ def _run(
ValueError
If the passed topic is not correct.
"""
- if topic not in self.expected_topics:
+ if topic not in self.available_topics:
raise ValueError(
f"Topic {topic} is not available within {timeout_sec} seconds. Check if the topic exists."
)
@@ -98,7 +113,7 @@ def generate_mock_image() -> npt.NDArray[np.uint8]:
class MockReceiveROS2MessageTool(ReceiveROS2MessageTool):
connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector)
- expected_topics: List[str]
+ available_topics: List[str]
def _run(self, topic: str) -> str:
"""Method that returns a mock message if the passed topic is correct.
@@ -118,7 +133,7 @@ def _run(self, topic: str) -> str:
ValueError
If the passed topic is not correct.
"""
- if topic not in self.expected_topics:
+ if topic not in self.available_topics:
raise ValueError(
f"Topic {topic} is not available within 1.0 seconds. Check if the topic exists."
)
@@ -181,3 +196,217 @@ def _run(self, object_name: str) -> str:
return f"No {object_name}s detected."
else:
return f"Centroids of detected {object_name}s in manipulator frame: {expected_positions} Sizes of the detected objects are unknown."
+
+
+class MockPublishROS2MessageTool(PublishROS2MessageTool):
+ connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector)
+ available_topics: List[str]
+ available_message_types: List[str]
+ available_topic_models: Dict[str, Type[BaseModel]]
+
+ def _run(self, topic: str, message: Dict[str, Any], message_type: str) -> str:
+ """
+        Mocked method that simulates publishing to a topic and returns a status string.
+
+ Parameters
+ ----------
+ topic : str
+ The name of the topic to which the message is published.
+ message : Dict[str, Any]
+ The content of the message as a dictionary.
+ message_type : str
+ The type of the message being published.
+
+ """
+ if topic not in self.available_topics:
+ raise ValueError(
+ f"Topic {topic} is not available within 1.0 seconds. Check if the topic exists."
+ )
+ if message_type not in self.available_message_types:
+ raise TypeError(
+ "Expected message one of message types: {}, got {}".format(
+ self.available_message_types, message_type
+ )
+ )
+
+ model = self.available_topic_models[message_type]
+ try:
+ model.model_validate(message)
+ except ValidationError as e:
+ raise ValueError(f"Failed to populate fields: {e}")
+
+ return "Message published successfully"
+
+
+class MockGetROS2MessageInterfaceTool(GetROS2MessageInterfaceTool):
+ connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector)
+ mock_interfaces: Dict[str, str]
+
+ def _run(self, msg_type: str) -> str:
+ """
+ Mocked method that returns the interface definition for a given ROS2 message type.
+
+ Parameters
+ ----------
+ msg_type : str
+ The ROS2 message type for which to retrieve the interface definition.
+
+ Returns
+ -------
+ str
+ The mocked output of 'ros2 interface show' for the specified message type.
+ """
+ if msg_type in self.mock_interfaces:
+ return self.mock_interfaces[msg_type]
+ else:
+ raise ImportError(f"Module {msg_type} not found.")
+
+
+class MockCallROS2ServiceTool(CallROS2ServiceTool):
+ connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector)
+ available_services: List[str]
+ available_service_types: List[str]
+ available_service_models: Dict[str, Type[BaseModel]]
+
+ def _run(
+ self,
+ service_name: str,
+ service_type: str,
+ service_args: Dict[str, Any],
+ ) -> str:
+ if service_name not in self.available_services:
+ raise ValueError(
+ f"Service {service_name} is not available within 1.0 seconds. Check if the service exists."
+ )
+ if service_type not in self.available_service_types:
+ raise TypeError(
+ "Expected one of service types: {}, got {}".format(
+ self.available_service_types, service_type
+ )
+ )
+ if service_type in self.available_service_models:
+ model = self.available_service_models[service_type]
+ try:
+ model.model_validate(service_args)
+ except ValidationError as e:
+ raise ValueError(f"Failed to populate fields: {e}")
+ response = ROS2ARIMessage(payload={"response": "success"})
+ return str(
+ {
+ "payload": response.payload,
+ "metadata": response.metadata,
+ }
+ )
+ else:
+ raise KeyError(
+ f"Model for service type {service_type} not included in models"
+ )
+
+
+class MockGetROS2ServicesNamesAndTypesTool(GetROS2ServicesNamesAndTypesTool):
+ connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector)
+ mock_service_names_and_types: list[str]
+
+ def _run(self) -> str:
+        """Mocked method that returns the mock service names and types instead of fetching from ROS2.
+
+ Returns
+ -------
+ str
+            Mocked output of 'get_ros2_services_names_and_types' tool.
+ """
+ return "\n".join(self.mock_service_names_and_types)
+
+
+class MockGetROS2ActionsNamesAndTypesTool(GetROS2ActionsNamesAndTypesTool):
+ connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector)
+ mock_actions_names_and_types: list[str]
+
+ def _run(self) -> str:
+        """Mocked method that returns the mock action names and types instead of fetching from ROS2.
+
+ Returns
+ -------
+ str
+            Mocked output of 'get_ros2_actions_names_and_types' tool.
+ """
+ return "\n".join(self.mock_actions_names_and_types)
+
+
+class MockStartROS2ActionTool(StartROS2ActionTool):
+ connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector)
+ available_actions: List[str] = []
+ available_action_types: List[str] = []
+ available_action_models: List[type[ActionBaseModel]]
+
+ def _run(
+ self, action_name: str, action_type: str, action_args: Dict[str, Any]
+ ) -> str:
+ if action_name not in self.available_actions:
+ raise ValueError(
+ f"Action {action_name} is not available within 1.0 seconds. Check if the action exists."
+ )
+ if action_type not in self.available_action_types:
+ raise TypeError(
+ f"Expected one of action types: {self.available_action_types}, got {action_type}"
+ )
+ for action_model in self.available_action_models:
+ if (
+ action_model.model_fields["action_name"].default == action_name
+ and action_model.model_fields["action_type"].default == action_type
+ ):
+ goal = action_model.__annotations__["goal"]
+ goal.model_validate(action_args)
+ action_id = str(uuid.uuid4())
+ response = action_id
+ self.internal_action_id_mapping[response] = action_id
+ return "Action started with ID: " + response
+
+
+class MockCancelROS2ActionTool(CancelROS2ActionTool):
+ connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector)
+ available_action_ids: List[str] = []
+
+ def _run(self, action_id: str) -> str:
+ if action_id not in self.available_action_ids:
+ raise ValueError(f"Action {action_id} is not available for cancellation.")
+ return f"Action {action_id} cancelled"
+
+
+class MockGetROS2ActionFeedbackTool(GetROS2ActionFeedbackTool):
+ connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector)
+ available_feedbacks: Dict[str, List[Any]] = {}
+ internal_action_id_mapping: Dict[str, str] = {}
+ action_feedbacks_store_lock: Lock = Lock()
+
+ def _run(self, action_id: str) -> str:
+ if action_id not in self.internal_action_id_mapping:
+ raise KeyError(f"Action ID {action_id} not found in internal mapping.")
+ external_id = self.internal_action_id_mapping[action_id]
+ with self.action_feedbacks_store_lock:
+ feedbacks = self.available_feedbacks.get(external_id, [])
+ self.available_feedbacks[external_id] = []
+ return str(feedbacks)
+
+
+class MockGetROS2ActionResultTool(GetROS2ActionResultTool):
+ available_results: Dict[str, Any] = {}
+ internal_action_id_mapping: Dict[str, str] = {}
+ action_results_store_lock: Lock = Lock()
+
+ def _run(self, action_id: str) -> str:
+ if action_id not in self.internal_action_id_mapping:
+ raise KeyError(f"Action ID {action_id} not found in internal mapping.")
+ external_id = self.internal_action_id_mapping[action_id]
+ with self.action_results_store_lock:
+ if external_id not in self.available_results:
+ raise ValueError(f"No result available for action {action_id}")
+ result = self.available_results[external_id]
+ return str(result)
+
+
+class MockGetROS2ActionIDsTool(GetROS2ActionIDsTool):
+ internal_action_id_mapping: Dict[str, str] = {}
+
+ def _run(self) -> str:
+ return str(list(self.internal_action_id_mapping.keys()))
diff --git a/src/rai_bench/rai_bench/tool_calling_agent_bench/ros2_agent_tasks.py b/src/rai_bench/rai_bench/tool_calling_agent_bench/ros2_agent_tasks.py
index 794a98fcb..1451ee2d4 100644
--- a/src/rai_bench/rai_bench/tool_calling_agent_bench/ros2_agent_tasks.py
+++ b/src/rai_bench/rai_bench/tool_calling_agent_bench/ros2_agent_tasks.py
@@ -24,8 +24,26 @@
from rai.tools.ros.manipulation import MoveToPointToolInput
from rai_bench.tool_calling_agent_bench.agent_tasks_interfaces import (
+ CustomInterfacesServiceTask,
+ CustomInterfacesTopicTask,
ROS2ToolCallingAgentTask,
)
+from rai_bench.tool_calling_agent_bench.messages.base import (
+ BoundingBox2D,
+ Detection2D,
+ Header,
+ Orientation,
+ Point2D,
+ Pose,
+ Pose2D,
+ PoseStamped,
+ Position,
+ Time,
+)
+from rai_bench.tool_calling_agent_bench.messages.topics import (
+ Image,
+ RAIDetectionArray,
+)
from rai_bench.tool_calling_agent_bench.mocked_tools import (
MockGetObjectPositionsTool,
MockGetROS2ImageTool,
@@ -38,8 +56,15 @@
PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT = """You are a ROS 2 expert helping a user with their ROS 2 questions. You have access to various tools that allow you to query the ROS 2 system.
- Be proactive and use the tools to answer questions.
- """
+Be proactive and use the tools to answer questions.
+
+Example of tool calls:
+- get_ros2_message_interface, args: {'msg_type': 'geometry_msgs/msg/Twist'}
+- publish_ros2_message, args: {'topic': '/cmd_vel', 'message_type': 'geometry_msgs/msg/Twist', 'message': {linear: {x: 0.5, y: 0.0, z: 0.0}, angular: {x: 0.0, y: 0.0, z: 1.0}}}
+
+- get_ros2_message_interface, args: {'msg_type': 'turtlesim/srv/TeleportAbsolute'}
+- publish_ros2_message, args: {'topic': '/turtle1/teleport_absolute', 'message_type': 'turtlesim/srv/TeleportAbsolute', 'message': {x: 5.0, y: 2.0, theta: 1.57}}
+"""
class TaskParametrizationError(Exception):
@@ -180,23 +205,31 @@ def verify_tool_calls(self, response: dict[str, Any]):
class GetROS2RGBCameraTask(ROS2ToolCallingAgentTask):
complexity = "easy"
+ topics_and_types: Dict[str, str] = {
+ "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject",
+ "/camera_image_color": "sensor_msgs/msg/Image",
+ "/camera_image_depth": "sensor_msgs/msg/Image",
+ "/clock": "rosgraph_msgs/msg/Clock",
+ "/collision_object": "moveit_msgs/msg/CollisionObject",
+ "/color_camera_info": "sensor_msgs/msg/CameraInfo",
+ "/color_camera_info5": "sensor_msgs/msg/CameraInfo",
+ "/depth_camera_info5": "sensor_msgs/msg/CameraInfo",
+ "/depth_image5": "sensor_msgs/msg/Image",
+ }
+
def __init__(self, logger: loggers_type | None = None) -> None:
super().__init__(logger=logger)
+
+ topic_strings = [
+ f"topic: {topic}\ntype: {msg_type}\n"
+ for topic, msg_type in self.topics_and_types.items()
+ ]
+
self.expected_tools: List[BaseTool] = [
MockGetROS2TopicsNamesAndTypesTool(
- mock_topics_names_and_types=[
- "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n",
- "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n",
- "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n",
- "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n",
- "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n",
- "topic: /color_camera_info\ntype: sensor_msgs/msg/CameraInfo\n",
- "topic: /color_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n",
- "topic: /depth_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n",
- "topic: /depth_image5\ntype: sensor_msgs/msg/Image\n",
- ]
+ mock_topics_names_and_types=topic_strings
),
- MockGetROS2ImageTool(expected_topics=["/camera_image_color"]),
+ MockGetROS2ImageTool(available_topics=list(self.topics_and_types.keys())),
]
def get_system_prompt(self) -> str:
@@ -247,22 +280,28 @@ def verify_tool_calls(self, response: dict[str, Any]):
class GetROS2DepthCameraTask(ROS2ToolCallingAgentTask):
complexity = "easy"
+ topics_and_types: Dict[str, str] = {
+ "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject",
+ "/camera_image_color": "sensor_msgs/msg/Image",
+ "/camera_image_depth": "sensor_msgs/msg/Image",
+ "/clock": "rosgraph_msgs/msg/Clock",
+ "/collision_object": "moveit_msgs/msg/CollisionObject",
+ "/color_camera_info": "sensor_msgs/msg/CameraInfo",
+ "/color_camera_info5": "sensor_msgs/msg/CameraInfo",
+ "/depth_camera_info5": "sensor_msgs/msg/CameraInfo",
+ }
+
def __init__(self, logger: loggers_type | None = None) -> None:
super().__init__(logger=logger)
+ topic_strings = [
+ f"topic: {topic}\ntype: {msg_type}\n"
+ for topic, msg_type in self.topics_and_types.items()
+ ]
self.expected_tools: List[BaseTool] = [
MockGetROS2TopicsNamesAndTypesTool(
- mock_topics_names_and_types=[
- "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n",
- "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n",
- "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n",
- "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n",
- "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n",
- "topic: /color_camera_info\ntype: sensor_msgs/msg/CameraInfo\n",
- "topic: /color_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n",
- "topic: /depth_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n",
- ]
+ mock_topics_names_and_types=topic_strings
),
- MockGetROS2ImageTool(expected_topics=["/camera_image_depth"]),
+ MockGetROS2ImageTool(available_topics=list(self.topics_and_types.keys())),
]
def get_system_prompt(self) -> str:
@@ -274,7 +313,7 @@ def get_prompt(self) -> str:
def verify_tool_calls(self, response: dict[str, Any]):
"""It is expected that the agent will request:
1. The tool that retrieves the ROS2 topic names and types to identify the depth image topic.
- 2. The tool that retrieves the RGB image from the /camera_image_depth topic
+ 2. The tool that retrieves the depth image from the /camera_image_depth topic
Parameters
----------
@@ -314,26 +353,30 @@ def verify_tool_calls(self, response: dict[str, Any]):
class GetAllROS2RGBCamerasTask(ROS2ToolCallingAgentTask):
complexity = "easy"
+ topics_and_types: Dict[str, str] = {
+ "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject",
+ "/camera_image_color": "sensor_msgs/msg/Image",
+ "/camera_image_depth": "sensor_msgs/msg/Image",
+ "/clock": "rosgraph_msgs/msg/Clock",
+ "/collision_object": "moveit_msgs/msg/CollisionObject",
+ "/color_camera_info": "sensor_msgs/msg/CameraInfo",
+ "/color_camera_info5": "sensor_msgs/msg/CameraInfo",
+ "/color_image5": "sensor_msgs/msg/Image",
+ "/depth_camera_info5": "sensor_msgs/msg/CameraInfo",
+ "/depth_image5": "sensor_msgs/msg/Image",
+ }
+
def __init__(self, logger: loggers_type | None = None) -> None:
super().__init__(logger=logger)
+ topic_strings = [
+ f"topic: {topic}\ntype: {msg_type}\n"
+ for topic, msg_type in self.topics_and_types.items()
+ ]
self.expected_tools: List[BaseTool] = [
MockGetROS2TopicsNamesAndTypesTool(
- mock_topics_names_and_types=[
- "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n",
- "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n",
- "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n",
- "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n",
- "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n",
- "topic: /color_camera_info\ntype: sensor_msgs/msg/CameraInfo\n",
- "topic: /color_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n",
- "topic: /color_image5\ntype: sensor_msgs/msg/Image\n",
- "topic: /depth_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n",
- "topic: /depth_image5\ntype: sensor_msgs/msg/Image\n",
- ]
- ),
- MockGetROS2ImageTool(
- expected_topics=["/camera_image_color", "/color_image5"]
+ mock_topics_names_and_types=topic_strings
),
+ MockGetROS2ImageTool(available_topics=list(self.topics_and_types.keys())),
]
def get_prompt(self) -> str:
@@ -382,7 +425,6 @@ def verify_tool_calls(self, response: dict[str, Any]):
"optional_args": {"timeout_sec": None},
},
]
-
self._check_multiple_tool_calls(
message=ai_messages[1], expected_tool_calls=expected_tool_calls
)
@@ -393,25 +435,29 @@ def verify_tool_calls(self, response: dict[str, Any]):
class GetAllROS2DepthCamerasTask(ROS2ToolCallingAgentTask):
complexity = "easy"
+ topics_and_types: Dict[str, str] = {
+ "/camera_image_color": "sensor_msgs/msg/Image",
+ "/camera_image_depth": "sensor_msgs/msg/Image",
+ "/clock": "rosgraph_msgs/msg/Clock",
+ "/collision_object": "moveit_msgs/msg/CollisionObject",
+ "/color_camera_info": "sensor_msgs/msg/CameraInfo",
+ "/color_camera_info5": "sensor_msgs/msg/CameraInfo",
+ "/color_image5": "sensor_msgs/msg/Image",
+ "/depth_camera_info5": "sensor_msgs/msg/CameraInfo",
+ "/depth_image5": "sensor_msgs/msg/Image",
+ }
+
def __init__(self, logger: loggers_type | None = None) -> None:
super().__init__(logger=logger)
+ topic_strings = [
+ f"topic: {topic}\ntype: {msg_type}\n"
+ for topic, msg_type in self.topics_and_types.items()
+ ]
self.expected_tools: List[BaseTool] = [
MockGetROS2TopicsNamesAndTypesTool(
- mock_topics_names_and_types=[
- "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n",
- "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n",
- "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n",
- "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n",
- "topic: /color_camera_info\ntype: sensor_msgs/msg/CameraInfo\n",
- "topic: /color_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n",
- "topic: /color_image5\ntype: sensor_msgs/msg/Image\n",
- "topic: /depth_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n",
- "topic: /depth_image5\ntype: sensor_msgs/msg/Image\n",
- ]
- ),
- MockGetROS2ImageTool(
- expected_topics=["/camera_image_depth", "/depth_image5"]
+ mock_topics_names_and_types=topic_strings
),
+ MockGetROS2ImageTool(available_topics=list(self.topics_and_types.keys())),
]
def get_prompt(self) -> str:
@@ -460,7 +506,6 @@ def verify_tool_calls(self, response: dict[str, Any]):
"optional_args": {"timeout_sec": None},
},
]
-
self._check_multiple_tool_calls(
message=ai_messages[1], expected_tool_calls=expected_tool_calls
)
@@ -471,23 +516,33 @@ def verify_tool_calls(self, response: dict[str, Any]):
class GetROS2MessageTask(ROS2ToolCallingAgentTask):
complexity = "easy"
+ topics_and_types: Dict[str, str] = {
+ "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject",
+ "/camera_image_color": "sensor_msgs/msg/Image",
+ "/camera_image_depth": "sensor_msgs/msg/Image",
+ "/clock": "rosgraph_msgs/msg/Clock",
+ "/collision_object": "moveit_msgs/msg/CollisionObject",
+ "/color_camera_info": "sensor_msgs/msg/CameraInfo",
+ "/color_camera_info5": "sensor_msgs/msg/CameraInfo",
+ "/depth_camera_info5": "sensor_msgs/msg/CameraInfo",
+ "/depth_image5": "sensor_msgs/msg/Image",
+ }
+
def __init__(self, logger: loggers_type | None = None) -> None:
super().__init__(logger=logger)
+
+ topic_strings = [
+ f"topic: {topic}\ntype: {msg_type}\n"
+ for topic, msg_type in self.topics_and_types.items()
+ ]
+
self.expected_tools: List[BaseTool] = [
MockGetROS2TopicsNamesAndTypesTool(
- mock_topics_names_and_types=[
- "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n",
- "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n",
- "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n",
- "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n",
- "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n",
- "topic: /color_camera_info\ntype: sensor_msgs/msg/CameraInfo\n",
- "topic: /color_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n",
- "topic: /depth_camera_info5\ntype: sensor_msgs/msg/CameraInfo\n",
- "topic: /depth_image5\ntype: sensor_msgs/msg/Image\n",
- ]
+ mock_topics_names_and_types=topic_strings
+ ),
+ MockReceiveROS2MessageTool(
+ available_topics=list(self.topics_and_types.keys())
),
- MockReceiveROS2MessageTool(expected_topics=["/camera_image_color"]),
]
def get_system_prompt(self) -> str:
@@ -498,7 +553,7 @@ def get_prompt(self) -> str:
def verify_tool_calls(self, response: dict[str, Any]):
"""It is expected that the agent will request:
- 1. The tool that retrieves the ROS2 topics names and types to recognize the RGB image topic
+ 1. The tool that retrieves the ROS2 topics names and types to recognize the RGB image topic.
2. The tool that retrieves the RGB image from the /camera_image_color topic
Parameters
@@ -515,6 +570,7 @@ def verify_tool_calls(self, response: dict[str, Any]):
self.log_error(
msg=f"Expected at least 3 AI messages, but got {len(ai_messages)}."
)
+
if ai_messages:
if not self._is_ai_message_requesting_get_ros2_topics_and_types(
ai_messages[0]
@@ -522,6 +578,7 @@ def verify_tool_calls(self, response: dict[str, Any]):
self.log_error(
msg="First AI message did not request ROS2 topics and types correctly."
)
+
if len(ai_messages) > 1:
if self._check_tool_calls_num_in_ai_message(ai_messages[1], expected_num=1):
self._check_tool_call(
@@ -529,6 +586,7 @@ def verify_tool_calls(self, response: dict[str, Any]):
expected_name="receive_ros2_message",
expected_args={"topic": "/camera_image_color"},
)
+
if not self.result.errors:
self.result.success = True
@@ -536,20 +594,30 @@ def verify_tool_calls(self, response: dict[str, Any]):
class GetRobotDescriptionTask(ROS2ToolCallingAgentTask):
complexity = "easy"
+ topics_and_types: Dict[str, str] = {
+ "/pointcloud": "sensor_msgs/msg/PointCloud2",
+ "/robot_description": "std_msgs/msg/String",
+ "/rosout": "rcl_interfaces/msg/Log",
+ "/tf": "tf2_msgs/msg/TFMessage",
+ "/tf_static": "tf2_msgs/msg/TFMessage",
+ "/trajectory_execution_event": "std_msgs/msg/String",
+ }
+
def __init__(self, logger: loggers_type | None = None) -> None:
super().__init__(logger=logger)
+
+ topic_strings = [
+ f"topic: {topic}\ntype: {msg_type}\n"
+ for topic, msg_type in self.topics_and_types.items()
+ ]
+
self.expected_tools: List[BaseTool] = [
MockGetROS2TopicsNamesAndTypesTool(
- mock_topics_names_and_types=[
- "topic: /pointcloud\ntype: sensor_msgs/msg/PointCloud2\n",
- "topic: /robot_description\ntype: std_msgs/msg/String\n",
- "topic: /rosout\ntype: rcl_interfaces/msg/Log\n",
- "topic: /tf\ntype: tf2_msgs/msg/TFMessage\n",
- "topic: /tf_static\ntype: tf2_msgs/msg/TFMessage\n",
- "topic: /trajectory_execution_event\ntype: std_msgs/msg/String\n",
- ]
+ mock_topics_names_and_types=topic_strings
+ ),
+ MockReceiveROS2MessageTool(
+ available_topics=list(self.topics_and_types.keys())
),
- MockReceiveROS2MessageTool(expected_topics=["/robot_description"]),
]
def get_system_prompt(self) -> str:
@@ -591,6 +659,7 @@ def verify_tool_calls(self, response: dict[str, Any]):
expected_name="receive_ros2_message",
expected_args={"topic": "/robot_description"},
)
+
if not self.result.errors:
self.result.success = True
@@ -598,20 +667,30 @@ def verify_tool_calls(self, response: dict[str, Any]):
class GetPointcloudTask(ROS2ToolCallingAgentTask):
complexity = "easy"
+ topics_and_types: Dict[str, str] = {
+ "/pointcloud": "sensor_msgs/msg/PointCloud2",
+ "/robot_description": "std_msgs/msg/String",
+ "/rosout": "rcl_interfaces/msg/Log",
+ "/tf": "tf2_msgs/msg/TFMessage",
+ "/tf_static": "tf2_msgs/msg/TFMessage",
+ "/trajectory_execution_event": "std_msgs/msg/String",
+ }
+
def __init__(self, logger: loggers_type | None = None) -> None:
super().__init__(logger=logger)
+
+ topic_strings = [
+ f"topic: {topic}\ntype: {msg_type}\n"
+ for topic, msg_type in self.topics_and_types.items()
+ ]
+
self.expected_tools: List[BaseTool] = [
MockGetROS2TopicsNamesAndTypesTool(
- mock_topics_names_and_types=[
- "topic: /pointcloud\ntype: sensor_msgs/msg/PointCloud2\n",
- "topic: /robot_description\ntype: std_msgs/msg/String\n",
- "topic: /rosout\ntype: rcl_interfaces/msg/Log\n",
- "topic: /tf\ntype: tf2_msgs/msg/TFMessage\n",
- "topic: /tf_static\ntype: tf2_msgs/msg/TFMessage\n",
- "topic: /trajectory_execution_event\ntype: std_msgs/msg/String\n",
- ]
+ mock_topics_names_and_types=topic_strings
+ ),
+ MockReceiveROS2MessageTool(
+ available_topics=list(self.topics_and_types.keys())
),
- MockReceiveROS2MessageTool(expected_topics=["/pointcloud"]),
]
def get_system_prompt(self) -> str:
@@ -639,6 +718,7 @@ def verify_tool_calls(self, response: dict[str, Any]):
self.log_error(
msg=f"Expected at least 3 AI messages, but got {len(ai_messages)}."
)
+
if ai_messages:
if not self._is_ai_message_requesting_get_ros2_topics_and_types(
ai_messages[0]
@@ -646,6 +726,7 @@ def verify_tool_calls(self, response: dict[str, Any]):
self.log_error(
msg="First AI message did not request ROS2 topics and types correctly."
)
+
if len(ai_messages) > 1:
if self._check_tool_calls_num_in_ai_message(ai_messages[1], expected_num=1):
self._check_tool_call(
@@ -653,6 +734,7 @@ def verify_tool_calls(self, response: dict[str, Any]):
expected_name="receive_ros2_message",
expected_args={"topic": "/pointcloud"},
)
+
if not self.result.errors:
self.result.success = True
@@ -660,21 +742,30 @@ def verify_tool_calls(self, response: dict[str, Any]):
class MoveToPointTask(ROS2ToolCallingAgentTask):
complexity = "easy"
+ topics_and_types: Dict[str, str] = {
+ "/pointcloud": "sensor_msgs/msg/PointCloud2",
+ "/robot_description": "std_msgs/msg/String",
+ "/rosout": "rcl_interfaces/msg/Log",
+ "/tf": "tf2_msgs/msg/TFMessage",
+ }
+
def __init__(
self, args: Dict[str, Any], logger: loggers_type | None = None
) -> None:
super().__init__(logger=logger)
+
+ topic_strings = [
+ f"topic: {topic}\ntype: {msg_type}\n"
+ for topic, msg_type in self.topics_and_types.items()
+ ]
+
self.expected_tools: List[BaseTool] = [
MockGetROS2TopicsNamesAndTypesTool(
- mock_topics_names_and_types=[
- "topic: /pointcloud\ntype: sensor_msgs/msg/PointCloud2\n",
- "topic: /robot_description\ntype: std_msgs/msg/String\n",
- "topic: /rosout\ntype: rcl_interfaces/msg/Log\n",
- "topic: /tf\ntype: tf2_msgs/msg/TFMessage\n",
- ]
+ mock_topics_names_and_types=topic_strings
),
MockMoveToPointTool(manipulator_frame="base_link"),
]
+
self.args = MoveToPointToolInput(**args)
def get_system_prompt(self) -> str:
@@ -725,6 +816,13 @@ def verify_tool_calls(self, response: dict[str, Any]):
class GetObjectPositionsTask(ROS2ToolCallingAgentTask):
complexity = "easy"
+ topics_and_types: Dict[str, str] = {
+ "/pointcloud": "sensor_msgs/msg/PointCloud2",
+ "/robot_description": "std_msgs/msg/String",
+ "/rosout": "rcl_interfaces/msg/Log",
+ "/tf": "tf2_msgs/msg/TFMessage",
+ }
+
def __init__(
self,
objects: Dict[str, List[dict[str, float]]],
@@ -747,17 +845,19 @@ def __init__(
}
"""
super().__init__(logger=logger)
+
+ topic_strings = [
+ f"topic: {topic}\ntype: {msg_type}\n"
+ for topic, msg_type in self.topics_and_types.items()
+ ]
+
self.expected_tools: List[BaseTool] = [
MockGetROS2TopicsNamesAndTypesTool(
- mock_topics_names_and_types=[
- "topic: /pointcloud\ntype: sensor_msgs/msg/PointCloud2\n",
- "topic: /robot_description\ntype: std_msgs/msg/String\n",
- "topic: /rosout\ntype: rcl_interfaces/msg/Log\n",
- "topic: /tf\ntype: tf2_msgs/msg/TFMessage\n",
- ]
+ mock_topics_names_and_types=topic_strings
),
MockGetObjectPositionsTool(mock_objects=objects),
]
+
self.objects = objects
def get_system_prompt(self) -> str:
@@ -809,6 +909,7 @@ def verify_tool_calls(self, response: dict[str, Any]):
for object_type in self.objects
],
)
+
if not self.result.errors:
self.result.success = True
@@ -835,6 +936,14 @@ class GrabExistingObjectTask(ROS2ToolCallingAgentTask):
}
object_to_grab = "cube"
"""
+ topics_and_types: Dict[str, str] = {
+ "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject",
+ "/camera_image_color": "sensor_msgs/msg/Image",
+ "/camera_image_depth": "sensor_msgs/msg/Image",
+ "/clock": "rosgraph_msgs/msg/Clock",
+ "/collision_object": "moveit_msgs/msg/CollisionObject",
+ "/color_camera_info": "sensor_msgs/msg/CameraInfo",
+ }
def __init__(
self,
@@ -843,16 +952,15 @@ def __init__(
logger: loggers_type | None = None,
) -> None:
super().__init__(logger=logger)
+
+ topic_strings = [
+ f"topic: {topic}\ntype: {msg_type}\n"
+ for topic, msg_type in self.topics_and_types.items()
+ ]
+
self.expected_tools: List[BaseTool] = [
MockGetROS2TopicsNamesAndTypesTool(
- mock_topics_names_and_types=[
- "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n",
- "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n",
- "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n",
- "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n",
- "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n",
- "topic: /color_camera_info\ntype: sensor_msgs/msg/CameraInfo\n",
- ]
+ mock_topics_names_and_types=topic_strings
),
MockGetObjectPositionsTool(
target_frame="panda_link0",
@@ -864,6 +972,7 @@ def __init__(
),
MockMoveToPointTool(manipulator_frame="panda_link0"),
]
+
self.objects = objects
self.object_to_grab = object_to_grab
self._verify_args()
@@ -905,6 +1014,7 @@ def verify_tool_calls(self, response: Dict[str, Any]):
ai_messages: Sequence[AIMessage] = [
message for message in messages if isinstance(message, AIMessage)
]
+
expected_num_ai_messages = 3
if len(ai_messages) != expected_num_ai_messages:
self.log_error(
@@ -930,6 +1040,7 @@ def verify_tool_calls(self, response: Dict[str, Any]):
expected_name="move_to_point",
expected_args=obj_to_grab,
)
+
if not self.result.errors:
self.result.success = True
@@ -957,6 +1068,15 @@ class GrabNotExistingObjectTask(ROS2ToolCallingAgentTask):
complexity = "medium"
+ topics_and_types: Dict[str, str] = {
+ "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject",
+ "/camera_image_color": "sensor_msgs/msg/Image",
+ "/camera_image_depth": "sensor_msgs/msg/Image",
+ "/clock": "rosgraph_msgs/msg/Clock",
+ "/collision_object": "moveit_msgs/msg/CollisionObject",
+ "/color_camera_info": "sensor_msgs/msg/CameraInfo",
+ }
+
def __init__(
self,
objects: Dict[str, List[dict[str, float]]],
@@ -964,16 +1084,15 @@ def __init__(
logger: loggers_type | None = None,
) -> None:
super().__init__(logger=logger)
+
+ topic_strings = [
+ f"topic: {topic}\ntype: {msg_type}\n"
+ for topic, msg_type in self.topics_and_types.items()
+ ]
+
self.expected_tools: List[BaseTool] = [
MockGetROS2TopicsNamesAndTypesTool(
- mock_topics_names_and_types=[
- "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n",
- "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n",
- "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n",
- "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n",
- "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n",
- "topic: /color_camera_info\ntype: sensor_msgs/msg/CameraInfo\n",
- ]
+ mock_topics_names_and_types=topic_strings
),
MockGetObjectPositionsTool(
target_frame="panda_link0",
@@ -985,6 +1104,7 @@ def __init__(
),
MockMoveToPointTool(manipulator_frame="panda_link0"),
]
+
self.objects = objects
self.object_to_grab = object_to_grab
self._verify_args()
@@ -1020,11 +1140,13 @@ def verify_tool_calls(self, response: Dict[str, Any]):
ai_messages: Sequence[AIMessage] = [
message for message in messages if isinstance(message, AIMessage)
]
+
expected_num_ai_messages = 2
if len(ai_messages) != expected_num_ai_messages:
self.log_error(
msg=f"Expected {expected_num_ai_messages} AI messages, but got {len(ai_messages)}."
)
+
if ai_messages:
if self._check_tool_calls_num_in_ai_message(ai_messages[0], expected_num=1):
self._check_tool_call(
@@ -1060,6 +1182,14 @@ class MoveExistingObjectLeftTask(ROS2ToolCallingAgentTask):
complexity = "medium"
+ topics_and_types: Dict[str, str] = {
+ "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject",
+ "/camera_image_color": "sensor_msgs/msg/Image",
+ "/camera_image_depth": "sensor_msgs/msg/Image",
+ "/clock": "rosgraph_msgs/msg/Clock",
+ "/collision_object": "moveit_msgs/msg/CollisionObject",
+ }
+
def __init__(
self,
objects: Dict[str, List[dict[str, float]]],
@@ -1067,15 +1197,15 @@ def __init__(
logger: loggers_type | None = None,
) -> None:
super().__init__(logger=logger)
+
+ topic_strings = [
+ f"topic: {topic}\ntype: {msg_type}\n"
+ for topic, msg_type in self.topics_and_types.items()
+ ]
+
self.expected_tools: List[BaseTool] = [
MockGetROS2TopicsNamesAndTypesTool(
- mock_topics_names_and_types=[
- "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n",
- "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n",
- "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n",
- "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n",
- "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n",
- ]
+ mock_topics_names_and_types=topic_strings
),
MockGetObjectPositionsTool(
target_frame="panda_link0",
@@ -1087,6 +1217,7 @@ def __init__(
),
MockMoveToPointTool(manipulator_frame="panda_link0"),
]
+
self.objects = objects
self.object_to_grab = object_to_grab
self._verify_args()
@@ -1130,11 +1261,13 @@ def verify_tool_calls(self, response: Dict[str, Any]):
ai_messages: Sequence[AIMessage] = [
message for message in messages if isinstance(message, AIMessage)
]
+
expected_num_ai_messages = 4
if len(ai_messages) != expected_num_ai_messages:
self.log_error(
msg=f"Expected {expected_num_ai_messages} AI messages, but got {len(ai_messages)}."
)
+
if ai_messages:
if self._check_tool_calls_num_in_ai_message(ai_messages[0], expected_num=1):
self._check_tool_call(
@@ -1187,6 +1320,12 @@ class MoveExistingObjectFrontTask(ROS2ToolCallingAgentTask):
complexity = "medium"
+ topics_and_types: Dict[str, str] = {
+ "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject",
+ "/camera_image_color": "sensor_msgs/msg/Image",
+ "/camera_image_depth": "sensor_msgs/msg/Image",
+ }
+
def __init__(
self,
objects: Dict[str, List[dict[str, float]]],
@@ -1194,13 +1333,15 @@ def __init__(
logger: loggers_type | None = None,
) -> None:
super().__init__(logger=logger)
+
+ topic_strings = [
+ f"topic: {topic}\ntype: {msg_type}\n"
+ for topic, msg_type in self.topics_and_types.items()
+ ]
+
self.expected_tools: List[BaseTool] = [
MockGetROS2TopicsNamesAndTypesTool(
- mock_topics_names_and_types=[
- "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n",
- "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n",
- "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n",
- ]
+ mock_topics_names_and_types=topic_strings
),
MockGetObjectPositionsTool(
target_frame="panda_link0",
@@ -1212,6 +1353,7 @@ def __init__(
),
MockMoveToPointTool(manipulator_frame="panda_link0"),
]
+
self.objects = objects
self.object_to_grab = object_to_grab
self._verify_args()
@@ -1255,6 +1397,7 @@ def verify_tool_calls(self, response: Dict[str, Any]):
ai_messages: Sequence[AIMessage] = [
message for message in messages if isinstance(message, AIMessage)
]
+
expected_num_ai_messages = 4
if len(ai_messages) != expected_num_ai_messages:
self.log_error(
@@ -1323,6 +1466,12 @@ class SwapObjectsTask(ROS2ToolCallingAgentTask):
complexity = "hard"
+ topics_and_types: Dict[str, str] = {
+ "/attached_collision_object": "moveit_msgs/msg/AttachedCollisionObject",
+ "/camera_image_color": "sensor_msgs/msg/Image",
+ "/camera_image_depth": "sensor_msgs/msg/Image",
+ }
+
def __init__(
self,
objects: Dict[str, List[Dict[str, float]]],
@@ -1330,13 +1479,15 @@ def __init__(
logger: loggers_type | None = None,
) -> None:
super().__init__(logger=logger)
+
+ topic_strings = [
+ f"topic: {topic}\ntype: {msg_type}\n"
+ for topic, msg_type in self.topics_and_types.items()
+ ]
+
self.expected_tools: List[BaseTool] = [
MockGetROS2TopicsNamesAndTypesTool(
- mock_topics_names_and_types=[
- "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n",
- "topic: /camera_image_color\ntype: sensor_msgs/msg/Image\n",
- "topic: /camera_image_depth\ntype: sensor_msgs/msg/Image\n",
- ]
+ mock_topics_names_and_types=topic_strings
),
MockGetObjectPositionsTool(
target_frame="panda_link0",
@@ -1348,6 +1499,7 @@ def __init__(
),
MockMoveToPointTool(manipulator_frame="panda_link0"),
]
+
self.objects = objects
self.objects_to_swap = objects_to_swap
self._verify_args()
@@ -1392,7 +1544,6 @@ def verify_tool_calls(self, response: Dict[str, Any]):
response : Dict[str, Any]
The response from the agent
"""
-
messages = response["messages"]
ai_messages: Sequence[AIMessage] = [
message for message in messages if isinstance(message, AIMessage)
@@ -1512,3 +1663,1125 @@ def _matches_sequence(
any(call["name"] == e["name"] and call["args"] == e["args"] for call in it)
for e in expected_tool_calls_seq
)
+
+
+class PublishROS2HRIMessageTask3ExtraCalls(CustomInterfacesTopicTask):
+ complexity = "easy"
+ expected_text = "Hello!"
+
+ def get_system_prompt(self) -> str:
+ return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT
+
+ @property
+ def expected_topic(self) -> str:
+ return "/to_human"
+
+ @property
+ def extra_calls(self) -> int:
+ return 3
+
+ def verify_message_tool_call(self, tool_calls: list[ToolCall]) -> bool:
+ for call in tool_calls:
+ if self._check_topic_tool_call_field(
+ tool_call=call,
+ expected_name="publish_ros2_message",
+ expected_topic=self.expected_topic,
+ expected_message_type=self.expected_message_type,
+ field_path="text",
+ expected_value=self.expected_text,
+ ):
+ return True
+
+ self.log_error(f"No valid call to {self.expected_topic} found.")
+ return False
+
+ def get_prompt(self) -> str:
+ return (
+ f"You need to publish a message to the topic '{self.expected_topic}' with the text value: '{self.expected_text}'.\n"
+ "Before publishing, follow these steps:\n"
+ "1. Use the tool to retrieve the available ROS2 topics and their message types.\n"
+ f"2. Find the message type for the topic '{self.expected_topic}'.\n"
+ "3. Retrieve the full message interface definition for that type.\n"
+ "4. Construct the message filling only the fields you are instructed to. Rest of the fields will have default values.\n"
+ f"5. Publish the message to '{self.expected_topic}' using the correct message type and interface.\n"
+ )
+
+
+class PublishROS2HRIMessageTask1ExtraCall(PublishROS2HRIMessageTask3ExtraCalls):
+ complexity = "medium"
+
+ @property
+ def extra_calls(self) -> int:
+ return 1
+
+
+class PublishROS2HRIMessageTask0ExtraCalls(PublishROS2HRIMessageTask3ExtraCalls):
+ complexity = "hard"
+
+ @property
+ def extra_calls(self) -> int:
+ return 0
+
+
+class PublishROS2AudioMessageTask3ExtraCalls(CustomInterfacesTopicTask):
+ complexity = "easy"
+ expected_audio: List[int] = [123, 456, 789]
+ expected_sample_rate: int = 44100
+ expected_channels: int = 2
+
+ def get_system_prompt(self) -> str:
+ return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT
+
+ @property
+ def expected_topic(self) -> str:
+ return "/send_audio"
+
+ @property
+ def extra_calls(self) -> int:
+ return 3
+
+ def verify_message_tool_call(self, tool_calls: list[ToolCall]) -> bool:
+ for call in tool_calls:
+ if (
+ self._check_topic_tool_call_field(
+ tool_call=call,
+ expected_name="publish_ros2_message",
+ expected_topic=self.expected_topic,
+ expected_message_type=self.expected_message_type,
+ field_path="audio",
+ expected_value=self.expected_audio,
+ )
+ and self._check_topic_tool_call_field(
+ tool_call=call,
+ expected_name="publish_ros2_message",
+ expected_topic=self.expected_topic,
+ expected_message_type=self.expected_message_type,
+ field_path="sample_rate",
+ expected_value=self.expected_sample_rate,
+ )
+ and self._check_topic_tool_call_field(
+ tool_call=call,
+ expected_name="publish_ros2_message",
+ expected_topic=self.expected_topic,
+ expected_message_type=self.expected_message_type,
+ field_path="channels",
+ expected_value=self.expected_channels,
+ )
+ ):
+ return True
+
+ self.log_error(f"No valid call to {self.expected_topic} found.")
+ return False
+
+ def get_prompt(self) -> str:
+ return (
+ f"You need to publish a message to the topic '{self.expected_topic}' with audio samples {self.expected_audio}, "
+ f"sample rate {self.expected_sample_rate}, and {self.expected_channels} channels.\n"
+ "Before publishing, follow these steps:\n"
+ "1. Use the tool to retrieve the available ROS2 topics and their message types.\n"
+ f"2. Find the message type for the topic '{self.expected_topic}'.\n"
+ "3. Retrieve the full message interface definition for that type.\n"
+ "4. Construct the message filling only the fields you are instructed to. Rest of the fields will have default values.\n"
+ f"5. Publish the message to '{self.expected_topic}' using the correct message type and interface.\n"
+ )
+
+
+class PublishROS2AudioMessageTask1ExtraCall(PublishROS2AudioMessageTask3ExtraCalls):
+ complexity = "medium"
+
+ @property
+ def extra_calls(self) -> int:
+ return 1
+
+
+class PublishROS2AudioMessageTask0ExtraCalls(PublishROS2AudioMessageTask3ExtraCalls):
+ complexity = "hard"
+
+ @property
+ def extra_calls(self) -> int:
+ return 0
+
+
+class PublishROS2DetectionArrayTask3ExtraCalls(CustomInterfacesTopicTask):
+ complexity = "easy"
+
+ expected_detection_classes: List[str] = ["person", "car"]
+ expected_detections: List[Detection2D] = [
+ Detection2D(
+ bbox=BoundingBox2D(
+ center=Pose2D(position=Point2D(x=320.0, y=240.0), theta=0.0),
+ size_x=50.0,
+ size_y=50.0,
+ )
+ )
+ ]
+
+ def get_system_prompt(self) -> str:
+ return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT
+
+ @property
+ def expected_topic(self) -> str:
+ return "/send_detections"
+
+ @property
+ def expected_message(self) -> RAIDetectionArray:
+ return RAIDetectionArray(
+ detections=self.expected_detections,
+ detection_classes=self.expected_detection_classes,
+ )
+
+ @property
+ def extra_calls(self) -> int:
+ return 3
+
+ def verify_message_tool_call(self, tool_calls: list[ToolCall]) -> bool:
+ for call in tool_calls:
+ if self._check_topic_tool_call_field(
+ tool_call=call,
+ expected_name="publish_ros2_message",
+ expected_topic=self.expected_topic,
+ expected_message_type=self.expected_message_type,
+ field_path="detection_classes",
+ expected_value=self.expected_detection_classes,
+ ):
+ return True
+
+ self.log_error(f"No valid call to {self.expected_topic} found.")
+ return False
+
+ def get_prompt(self) -> str:
+ return (
+ f"You need to publish a detection message to the topic '{self.expected_topic}' with one detection:\n"
+ f"{self.expected_detections[0].model_dump()} and detection classes {self.expected_detection_classes}.\n"
+ "Before publishing, follow these steps:\n"
+ "1. Use the tool to retrieve the available ROS2 topics and their message types.\n"
+ f"2. Find the message type for the topic '{self.expected_topic}'.\n"
+ "3. Retrieve the full message interface definition for that type.\n"
+ "4. Construct the message filling only the fields you are instructed to. Rest of the fields will have default values.\n"
+ f"5. Publish the message to '{self.expected_topic}' using the correct message type and interface.\n"
+ )
+
+
+class PublishROS2DetectionArrayTask1ExtraCall(PublishROS2DetectionArrayTask3ExtraCalls):
+ complexity = "medium"
+
+ @property
+ def extra_calls(self) -> int:
+ return 1
+
+
+class PublishROS2DetectionArrayTask0ExtraCalls(
+ PublishROS2DetectionArrayTask3ExtraCalls
+):
+ complexity = "hard"
+
+ @property
+ def extra_calls(self) -> int:
+ return 0
+
+
+class CallROS2ManipulatorMoveToServiceTask3ExtraCalls(CustomInterfacesServiceTask):
+ complexity = "easy"
+
+ expected_initial_gripper_state = True
+ expected_final_gripper_state = False
+ expected_target_pose: PoseStamped = PoseStamped(
+ pose=Pose(
+ position=Position(x=1.0, y=2.0, z=3.0),
+ orientation=Orientation(x=0.0, y=0.0, z=0.0, w=1.0),
+ )
+ )
+
+ def get_system_prompt(self) -> str:
+ return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT
+
+ @property
+ def expected_service(self) -> str:
+ return "/manipulator_move_to"
+
+ @property
+ def extra_calls(self) -> int:
+ return 3
+
+ def verify_message_tool_call(self, tool_calls: list[ToolCall]) -> bool:
+ for call in tool_calls:
+ if (
+ self._check_service_tool_call_field(
+ tool_call=call,
+ expected_name="call_ros2_service",
+ expected_service=self.expected_service,
+ expected_service_type=self.expected_service_type,
+ field_path="initial_gripper_state",
+ expected_value=self.expected_initial_gripper_state,
+ )
+ and self._check_service_tool_call_field(
+ tool_call=call,
+ expected_name="call_ros2_service",
+ expected_service=self.expected_service,
+ expected_service_type=self.expected_service_type,
+ field_path="final_gripper_state",
+ expected_value=self.expected_final_gripper_state,
+ )
+ and self._check_service_tool_call_field(
+ tool_call=call,
+ expected_name="call_ros2_service",
+ expected_service=self.expected_service,
+ expected_service_type=self.expected_service_type,
+ field_path="target_pose",
+ expected_value=self.expected_target_pose.model_dump(),
+ )
+ ):
+ return True
+
+ self.log_error(f"No valid call to {self.expected_service} found.")
+ return False
+
+ def get_prompt(self) -> str:
+ return (
+ f"You need to call the service '{self.expected_service}' with a target_pose: "
+ f"{self.expected_target_pose.model_dump()} and gripper states (initial: {self.expected_initial_gripper_state}, final: {self.expected_final_gripper_state}).\n"
+ "Before calling, follow these steps:\n"
+ "1. Use the tool to retrieve the available ROS2 services and their types.\n"
+ f"2. Find the service type for '{self.expected_service}'.\n"
+ "3. Retrieve the full message interface definition for that service.\n"
+ "4. Construct the request message filling only the fields you are instructed to. Rest of the fields will have default values.\n"
+ f"5. Call the service '{self.expected_service}' using the correct message type and interface.\n"
+ )
+
+
+class CallROS2ManipulatorMoveToServiceTask1ExtraCall(
+ CallROS2ManipulatorMoveToServiceTask3ExtraCalls
+):
+ complexity = "medium"
+
+ @property
+ def extra_calls(self) -> int:
+ return 1
+
+
+class CallROS2ManipulatorMoveToServiceTask0ExtraCalls(
+ CallROS2ManipulatorMoveToServiceTask3ExtraCalls
+):
+ complexity = "hard"
+
+ @property
+ def extra_calls(self) -> int:
+ return 0
+
+
+class CallGroundedSAMSegmentTask3ExtraCalls(CustomInterfacesServiceTask):
+ complexity = "easy"
+
+ expected_detections: RAIDetectionArray = RAIDetectionArray(
+ header=Header(stamp=Time(sec=0, nanosec=0), frame_id="camera_frame"),
+ detections=[],
+ )
+
+ def get_system_prompt(self) -> str:
+ return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT
+
+ @property
+ def expected_service(self) -> str:
+ return "/grounded_sam_segment"
+
+ @property
+ def extra_calls(self) -> int:
+ return 3
+
+ def verify_message_tool_call(self, tool_calls: list[ToolCall]) -> bool:
+ for call in tool_calls:
+ if self._check_service_tool_call_field(
+ tool_call=call,
+ expected_name="call_ros2_service",
+ expected_service=self.expected_service,
+ expected_service_type=self.expected_service_type,
+ field_path="detections",
+ expected_value=self.expected_detections.model_dump(),
+ ):
+ return True
+ self.log_error(f"No valid call to {self.expected_service} found.")
+ return False
+
+ def get_prompt(self) -> str:
+ return (
+ f"You need to call the service '{self.expected_service}' with detections: {self.expected_detections.model_dump()}\n"
+ "Before calling, follow these steps:\n"
+ "1. Use the tool to retrieve the available ROS2 services and their types.\n"
+ f"2. Find the service type for '{self.expected_service}'.\n"
+ "3. Retrieve the full message interface definition for that service.\n"
+ "4. Construct the request message filling only the fields you are instructed to. Rest of the fields will have default values.\n"
+ f"5. Call the service '{self.expected_service}' using the correct message type and interface.\n"
+ )
+
+
+class CallGroundedSAMSegmentTask1ExtraCall(CallGroundedSAMSegmentTask3ExtraCalls):
+ complexity = "medium"
+
+ @property
+ def extra_calls(self) -> int:
+ return 1
+
+
+class CallGroundedSAMSegmentTask0ExtraCalls(CallGroundedSAMSegmentTask3ExtraCalls):
+ complexity = "hard"
+
+ @property
+ def extra_calls(self) -> int:
+ return 0
+
+
+class CallGroundingDinoClassifyTask3ExtraCalls(CustomInterfacesServiceTask):
+ complexity = "easy"
+
+ expected_classes: str = "bottle, book, chair"
+ expected_box_threshold: float = 0.4
+ expected_text_threshold: float = 0.25
+
+ def get_system_prompt(self) -> str:
+ return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT
+
+ @property
+ def extra_calls(self) -> int:
+ return 3
+
+ @property
+ def expected_service(self) -> str:
+ return "/grounding_dino_classify"
+
+ def verify_message_tool_call(self, tool_calls: list[ToolCall]) -> bool:
+ for call in tool_calls:
+ if (
+ self._check_service_tool_call_field(
+ call,
+ "call_ros2_service",
+ self.expected_service,
+ self.expected_service_type,
+ "classes",
+ self.expected_classes,
+ )
+ and self._check_service_tool_call_field(
+ call,
+ "call_ros2_service",
+ self.expected_service,
+ self.expected_service_type,
+ "box_threshold",
+ self.expected_box_threshold,
+ )
+ and self._check_service_tool_call_field(
+ call,
+ "call_ros2_service",
+ self.expected_service,
+ self.expected_service_type,
+ "text_threshold",
+ self.expected_text_threshold,
+ )
+ ):
+ return True
+
+ self.log_error(f"No valid call to {self.expected_service} found.")
+ return False
+
+ def get_prompt(self) -> str:
+ return (
+ f"You need to call the service '{self.expected_service}' with classes: '{self.expected_classes}', "
+ f"box_threshold: {self.expected_box_threshold}, text_threshold: {self.expected_text_threshold}, "
+ "Before calling, follow these steps:\n"
+ "1. Use the tool to retrieve the available ROS2 services and their types.\n"
+ f"2. Find the service type for '{self.expected_service}'.\n"
+ "3. Retrieve the full message interface definition for that service.\n"
+ "4. Construct the request message filling only the fields you are instructed to. Rest of the fields will have default values.\n"
+ f"5. Call the service '{self.expected_service}' using the correct message type and interface.\n"
+ )
+
+
+class CallGroundingDinoClassifyTask1ExtraCall(CallGroundingDinoClassifyTask3ExtraCalls):
+ complexity = "medium"
+
+ @property
+ def extra_calls(self) -> int:
+ return 1
+
+
+class CallGroundingDinoClassifyTask0ExtraCalls(
+ CallGroundingDinoClassifyTask3ExtraCalls
+):
+ complexity = "hard"
+
+ @property
+ def extra_calls(self) -> int:
+ return 0
+
+
+class CallGetLogDigestTask3ExtraCalls(CustomInterfacesServiceTask):
+ complexity = "easy"
+
+ def get_system_prompt(self) -> str:
+ return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT
+
+ @property
+ def expected_service(self) -> str:
+ return "/get_log_digest"
+
+ @property
+ def extra_calls(self) -> int:
+ return 3
+
+ def verify_message_tool_call(self, tool_calls: list[ToolCall]) -> bool:
+ for call in tool_calls:
+ if self._check_service_tool_call_field(
+ call,
+ "call_ros2_service",
+ self.expected_service,
+ self.expected_service_type,
+ field_path="", # empty request
+ expected_value="",
+ ):
+ return True
+
+ self.log_error(f"No valid call to {self.expected_service} found.")
+ return False
+
+ def get_prompt(self) -> str:
+ return (
+ f"You need to call the service '{self.expected_service}' with an empty request.\n"
+ "Before calling, follow these steps:\n"
+ "1. Use the tool to retrieve the available ROS2 services and their types.\n"
+ f"2. Find the service type for '{self.expected_service}'.\n"
+ "3. Retrieve the full message interface definition for that service.\n"
+ "4. Construct the message filling only the fields you are instructed to. Rest of the fields will have default values.\n"
+ f"5. Call the service '{self.expected_service}' using the correct message type and interface.\n"
+ )
+
+
+class CallGetLogDigestTask1ExtraCall(CallGetLogDigestTask3ExtraCalls):
+ complexity = "medium"
+
+ @property
+ def extra_calls(self) -> int:
+ return 1
+
+
+class CallGetLogDigestTask0ExtraCalls(CallGetLogDigestTask3ExtraCalls):
+ complexity = "hard"
+
+ @property
+ def extra_calls(self) -> int:
+ return 0
+
+
+class CallVectorStoreRetrievalTask3ExtraCalls(CustomInterfacesServiceTask):
+ complexity = "easy"
+
+ expected_query: str = "What is the purpose of this robot?"
+
+ def get_system_prompt(self) -> str:
+ return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT
+
+ @property
+ def expected_service(self) -> str:
+ return "/rai_whoami_documentation_service"
+
+ @property
+ def extra_calls(self) -> int:
+ return 3
+
+ def verify_message_tool_call(self, tool_calls: list[ToolCall]) -> bool:
+ for call in tool_calls:
+ if self._check_service_tool_call_field(
+ call,
+ "call_ros2_service",
+ self.expected_service,
+ self.expected_service_type,
+ "query",
+ self.expected_query,
+ ):
+ return True
+
+ self.log_error(f"No valid call to {self.expected_service} found.")
+ return False
+
+ def get_prompt(self) -> str:
+ return (
+ f"You need to call the service '{self.expected_service}' with the query: '{self.expected_query}'.\n"
+ "Before calling, follow these steps:\n"
+ "1. Use the tool to retrieve the available ROS2 services and their types.\n"
+ f"2. Find the service type for '{self.expected_service}'.\n"
+ "3. Retrieve the full message interface definition for that service.\n"
+ "4. Construct the request message filling only the fields you are instructed to. Rest of the fields will have default values.\n"
+ f"5. Call the service '{self.expected_service}' using the correct message type and interface.\n"
+ )
+
+
+class CallVectorStoreRetrievalTask1ExtraCall(CallVectorStoreRetrievalTask3ExtraCalls):
+ complexity = "medium"
+
+ @property
+ def extra_calls(self) -> int:
+ return 1
+
+
+class CallVectorStoreRetrievalTask0ExtraCalls(CallVectorStoreRetrievalTask3ExtraCalls):
+ complexity = "hard"
+
+ @property
+ def extra_calls(self) -> int:
+ return 0
+
+
+class CallWhatISeeTask3ExtraCalls(CustomInterfacesServiceTask):
+ complexity = "easy"
+
+ expected_observations: List[str] = ["table", "cup", "notebook"]
+ expected_perception_source: str = "front_camera"
+
+ expected_image: Image = Image(
+ header=Header(frame_id="camera_frame"),
+ height=480,
+ width=640,
+ )
+
+ expected_pose: Pose = Pose(
+ position=Position(x=1.0, y=2.0, z=0.5),
+ orientation=Orientation(x=0.0, y=0.0, z=0.0, w=1.0),
+ )
+
+ def get_system_prompt(self) -> str:
+ return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT
+
+ @property
+ def expected_service(self) -> str:
+ return "/rai/whatisee/get"
+
+ @property
+ def extra_calls(self) -> int:
+ return 3
+
+ def verify_message_tool_call(self, tool_calls: list[ToolCall]) -> bool:
+ for call in tool_calls:
+ if self._check_service_tool_call_field(
+ call,
+ "call_ros2_service",
+ self.expected_service,
+ self.expected_service_type,
+ field_path="", # empty request
+ expected_value="",
+ ):
+ return True
+
+ self.log_error(f"No valid call to {self.expected_service} found.")
+ return False
+
+ def get_prompt(self) -> str:
+ return (
+ f"You need to call the service '{self.expected_service}' with an empty request.\n"
+ "Before calling, follow these steps:\n"
+ "1. Use the tool to retrieve the available ROS2 services and their types.\n"
+ f"2. Find the service type for '{self.expected_service}'.\n"
+ "3. Retrieve the full message interface definition for that service.\n"
+ "4. Construct the request message filling only the fields you are instructed to. Rest of the fields will have default values.\n"
+ f"5. Call the service '{self.expected_service}' using the correct message type and interface.\n"
+ )
+
+
+class CallWhatISeeTask1ExtraCall(CallWhatISeeTask3ExtraCalls):
+ complexity = "medium"
+
+ @property
+ def extra_calls(self) -> int:
+ return 1
+
+
+class CallWhatISeeTask0ExtraCalls(CallWhatISeeTask3ExtraCalls):
+ complexity = "hard"
+
+ @property
+ def extra_calls(self) -> int:
+ return 0
+
+
+# class CallROS2CustomActionTask(CustomInterfacesActionTask):
+# complexity = "easy"
+
+# expected_task = "Where are you?"
+# expected_description = ""
+# expected_priority = "10"
+
+# def get_system_prompt(self) -> str:
+# return PROACTIVE_ROS2_EXPERT_SYSTEM_PROMPT
+
+# @property
+# def expected_action(self) -> str:
+# return "/perform_task"
+
+# @property
+# def expected_message(self) -> Dict[str, Any]:
+# expected = DEFAULT_MESSAGES[self.expected_action_type].copy()
+# expected["goal"]["task"] = self.expected_task
+# expected["goal"]["description"] = self.expected_description
+# expected["goal"]["priority"] = self.expected_priority
+# return expected
+
+
+ROBOT_NAVIGATION_SYSTEM_PROMPT = """You are an autonomous robot connected to ros2 environment. Your main goal is to fulfill the user's requests.
+ Do not make assumptions about the environment you are currently in.
+ You can use ros2 topics, services and actions to operate.
+
+ As a first step check transforms by getting 1 message from /tf topic
+    use /cmd_vel topic very carefully. Obstacle detection works only with the nav2 stack, so be careful when it is not used.
+    be patient with running ros2 actions. Usually they take some time to run.
+ Always check your transform before and after you perform ros2 actions, so that you can verify if it worked.
+
+ Navigation tips:
+    - it's good to start finding objects by rotating, then navigating to some diverse location with occasional rotations. Remember to detect objects frequently.
+ - for driving forward/backward or to some coordinates, ros2 actions are better.
+    - for driving for some specific time or in a specific manner (like shapes or turns) it is good to use the /cmd_vel topic
+ - you are currently unable to read map or point-cloud, so please avoid subscribing to such topics.
+ - if you are asked to drive towards some object, it's good to:
+ 1. check the camera image and verify if objects can be seen
+ 2. if only driving forward is required, do it
+ 3. if obstacle avoidance might be required, use ros2 actions navigate_*, but first check your current position, then very accurately estimate the goal pose.
+ - it is good to verify, using the given information, that the robot is not stuck
+ - navigation actions sometimes fail. Their output can be read from rosout. You can also tell if they partially worked by checking the robot position and rotation.
+ - before using any ros2 interfaces, always make sure to check you are using the right interface
+ - processing camera image takes 5-10s. Take into account that if the robot is moving, the information can be outdated. Handle it by good planning of your movements.
+ - you are encouraged to use wait tool in between checking the status of actions
+ - to find some object navigate around and check the surrounding area
+ - when the goal is accomplished please make sure to cancel running actions
+ - when you reach the navigation goal - double check if you reached it by checking the current position
+ - if you detect collision, please stop operation
+
+ - you will be given your camera image description. Based on this information you can reason about positions of objects.
+ - be careful and avoid obstacles
+
+ Here are the corners of your environment:
+ (-2.76, 9.04, 0.0),
+ (4.62, 9.07, 0.0),
+ (-2.79, -3.83, 0.0),
+ (4.59, -3.81, 0.0)
+
+ This is location of places:
+ Kitchen:
+ (2.06, -0.23, 0.0),
+ (2.07, -1.43, 0.0),
+ (-2.44, -0.38, 0.0),
+ (-2.56, -1.47, 0.0)
+
+ Living room:
+ (-2.49, 1.87, 0.0),
+ (-2.50, 5.49, 0.0),
+ (0.79, 5.73, 0.0),
+ (0.92, 1.01, 0.0)
+
+ Before starting anything, make sure to load available topics, services and actions.
+ """
+NAVIGATION_SERVICES_AND_TYPES: Dict[str, str] = {
+ "/assisted_teleop/_action/cancel_goal": "action_msgs/srv/CancelGoal",
+ "/assisted_teleop/_action/get_result": "nav2_msgs/action/AssistedTeleop_GetResult",
+ "/assisted_teleop/_action/send_goal": "nav2_msgs/action/AssistedTeleop_SendGoal",
+ "/backup/_action/cancel_goal": "action_msgs/srv/CancelGoal",
+ "/backup/_action/get_result": "nav2_msgs/action/BackUp_GetResult",
+ "/backup/_action/send_goal": "nav2_msgs/action/BackUp_SendGoal",
+ "/behavior_server/change_state": "lifecycle_msgs/srv/ChangeState",
+ "/behavior_server/describe_parameters": "rcl_interfaces/srv/DescribeParameters",
+ "/behavior_server/get_available_states": "lifecycle_msgs/srv/GetAvailableStates",
+ "/behavior_server/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions",
+ "/behavior_server/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes",
+ "/behavior_server/get_parameters": "rcl_interfaces/srv/GetParameters",
+ "/behavior_server/get_state": "lifecycle_msgs/srv/GetState",
+ "/behavior_server/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions",
+ "/behavior_server/list_parameters": "rcl_interfaces/srv/ListParameters",
+ "/behavior_server/set_parameters": "rcl_interfaces/srv/SetParameters",
+ "/behavior_server/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically",
+ "/bt_navigator/change_state": "lifecycle_msgs/srv/ChangeState",
+ "/bt_navigator/describe_parameters": "rcl_interfaces/srv/DescribeParameters",
+ "/bt_navigator/get_available_states": "lifecycle_msgs/srv/GetAvailableStates",
+ "/bt_navigator/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions",
+ "/bt_navigator/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes",
+ "/bt_navigator/get_parameters": "rcl_interfaces/srv/GetParameters",
+ "/bt_navigator/get_state": "lifecycle_msgs/srv/GetState",
+ "/bt_navigator/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions",
+ "/bt_navigator/list_parameters": "rcl_interfaces/srv/ListParameters",
+ "/bt_navigator/set_parameters": "rcl_interfaces/srv/SetParameters",
+ "/bt_navigator/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically",
+ "/bt_navigator_navigate_through_poses_rclcpp_node/describe_parameters": "rcl_interfaces/srv/DescribeParameters",
+ "/bt_navigator_navigate_through_poses_rclcpp_node/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes",
+ "/bt_navigator_navigate_through_poses_rclcpp_node/get_parameters": "rcl_interfaces/srv/GetParameters",
+ "/bt_navigator_navigate_through_poses_rclcpp_node/list_parameters": "rcl_interfaces/srv/ListParameters",
+ "/bt_navigator_navigate_through_poses_rclcpp_node/set_parameters": "rcl_interfaces/srv/SetParameters",
+ "/bt_navigator_navigate_through_poses_rclcpp_node/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically",
+ "/bt_navigator_navigate_to_pose_rclcpp_node/describe_parameters": "rcl_interfaces/srv/DescribeParameters",
+ "/bt_navigator_navigate_to_pose_rclcpp_node/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes",
+ "/bt_navigator_navigate_to_pose_rclcpp_node/get_parameters": "rcl_interfaces/srv/GetParameters",
+ "/bt_navigator_navigate_to_pose_rclcpp_node/list_parameters": "rcl_interfaces/srv/ListParameters",
+ "/bt_navigator_navigate_to_pose_rclcpp_node/set_parameters": "rcl_interfaces/srv/SetParameters",
+ "/bt_navigator_navigate_to_pose_rclcpp_node/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically",
+ "/compute_path_through_poses/_action/cancel_goal": "action_msgs/srv/CancelGoal",
+ "/compute_path_through_poses/_action/get_result": "nav2_msgs/action/ComputePathThroughPoses_GetResult",
+ "/compute_path_through_poses/_action/send_goal": "nav2_msgs/action/ComputePathThroughPoses_SendGoal",
+ "/compute_path_to_pose/_action/cancel_goal": "action_msgs/srv/CancelGoal",
+ "/compute_path_to_pose/_action/get_result": "nav2_msgs/action/ComputePathToPose_GetResult",
+ "/compute_path_to_pose/_action/send_goal": "nav2_msgs/action/ComputePathToPose_SendGoal",
+ "/controller_server/change_state": "lifecycle_msgs/srv/ChangeState",
+ "/controller_server/describe_parameters": "rcl_interfaces/srv/DescribeParameters",
+ "/controller_server/get_available_states": "lifecycle_msgs/srv/GetAvailableStates",
+ "/controller_server/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions",
+ "/controller_server/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes",
+ "/controller_server/get_parameters": "rcl_interfaces/srv/GetParameters",
+ "/controller_server/get_state": "lifecycle_msgs/srv/GetState",
+ "/controller_server/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions",
+ "/controller_server/list_parameters": "rcl_interfaces/srv/ListParameters",
+ "/controller_server/set_parameters": "rcl_interfaces/srv/SetParameters",
+ "/controller_server/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically",
+ "/drive_on_heading/_action/cancel_goal": "action_msgs/srv/CancelGoal",
+ "/drive_on_heading/_action/get_result": "nav2_msgs/action/DriveOnHeading_GetResult",
+ "/drive_on_heading/_action/send_goal": "nav2_msgs/action/DriveOnHeading_SendGoal",
+ "/follow_path/_action/cancel_goal": "action_msgs/srv/CancelGoal",
+ "/follow_path/_action/get_result": "nav2_msgs/action/FollowPath_GetResult",
+ "/follow_path/_action/send_goal": "nav2_msgs/action/FollowPath_SendGoal",
+ "/follow_waypoints/_action/cancel_goal": "action_msgs/srv/CancelGoal",
+ "/follow_waypoints/_action/get_result": "nav2_msgs/action/FollowWaypoints_GetResult",
+ "/follow_waypoints/_action/send_goal": "nav2_msgs/action/FollowWaypoints_SendGoal",
+ "/global_costmap/clear_around_global_costmap": "nav2_msgs/srv/ClearCostmapAroundRobot",
+ "/global_costmap/clear_entirely_global_costmap": "nav2_msgs/srv/ClearEntireCostmap",
+ "/global_costmap/clear_except_global_costmap": "nav2_msgs/srv/ClearCostmapExceptRegion",
+ "/global_costmap/get_costmap": "nav2_msgs/srv/GetCostmap",
+ "/global_costmap/global_costmap/change_state": "lifecycle_msgs/srv/ChangeState",
+ "/global_costmap/global_costmap/describe_parameters": "rcl_interfaces/srv/DescribeParameters",
+ "/global_costmap/global_costmap/get_available_states": "lifecycle_msgs/srv/GetAvailableStates",
+ "/global_costmap/global_costmap/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions",
+ "/global_costmap/global_costmap/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes",
+ "/global_costmap/global_costmap/get_parameters": "rcl_interfaces/srv/GetParameters",
+ "/global_costmap/global_costmap/get_state": "lifecycle_msgs/srv/GetState",
+ "/global_costmap/global_costmap/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions",
+ "/global_costmap/global_costmap/list_parameters": "rcl_interfaces/srv/ListParameters",
+ "/global_costmap/global_costmap/set_parameters": "rcl_interfaces/srv/SetParameters",
+ "/global_costmap/global_costmap/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically",
+ "/grounded_sam/describe_parameters": "rcl_interfaces/srv/DescribeParameters",
+ "/grounded_sam/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes",
+ "/grounded_sam/get_parameters": "rcl_interfaces/srv/GetParameters",
+ "/grounded_sam/list_parameters": "rcl_interfaces/srv/ListParameters",
+ "/grounded_sam/set_parameters": "rcl_interfaces/srv/SetParameters",
+ "/grounded_sam/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically",
+ "/grounded_sam_segment": "rai_interfaces/srv/RAIGroundedSam",
+ "/grounding_dino/describe_parameters": "rcl_interfaces/srv/DescribeParameters",
+ "/grounding_dino/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes",
+ "/grounding_dino/get_parameters": "rcl_interfaces/srv/GetParameters",
+ "/grounding_dino/list_parameters": "rcl_interfaces/srv/ListParameters",
+ "/grounding_dino/set_parameters": "rcl_interfaces/srv/SetParameters",
+ "/grounding_dino/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically",
+ "/grounding_dino_classify": "rai_interfaces/srv/RAIGroundingDino",
+ "/is_path_valid": "nav2_msgs/srv/IsPathValid",
+ "/launch_ros_138640/describe_parameters": "rcl_interfaces/srv/DescribeParameters",
+ "/launch_ros_138640/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes",
+ "/launch_ros_138640/get_parameters": "rcl_interfaces/srv/GetParameters",
+ "/launch_ros_138640/list_parameters": "rcl_interfaces/srv/ListParameters",
+ "/launch_ros_138640/set_parameters": "rcl_interfaces/srv/SetParameters",
+ "/launch_ros_138640/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically",
+ "/lifecycle_manager_navigation/describe_parameters": "rcl_interfaces/srv/DescribeParameters",
+ "/lifecycle_manager_navigation/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes",
+ "/lifecycle_manager_navigation/get_parameters": "rcl_interfaces/srv/GetParameters",
+ "/lifecycle_manager_navigation/is_active": "std_srvs/srv/Trigger",
+ "/lifecycle_manager_navigation/list_parameters": "rcl_interfaces/srv/ListParameters",
+ "/lifecycle_manager_navigation/manage_nodes": "nav2_msgs/srv/ManageLifecycleNodes",
+ "/lifecycle_manager_navigation/set_parameters": "rcl_interfaces/srv/SetParameters",
+ "/lifecycle_manager_navigation/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically",
+ "/lifecycle_manager_slam/describe_parameters": "rcl_interfaces/srv/DescribeParameters",
+ "/lifecycle_manager_slam/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes",
+ "/lifecycle_manager_slam/get_parameters": "rcl_interfaces/srv/GetParameters",
+ "/lifecycle_manager_slam/is_active": "std_srvs/srv/Trigger",
+ "/lifecycle_manager_slam/list_parameters": "rcl_interfaces/srv/ListParameters",
+ "/lifecycle_manager_slam/manage_nodes": "nav2_msgs/srv/ManageLifecycleNodes",
+ "/lifecycle_manager_slam/set_parameters": "rcl_interfaces/srv/SetParameters",
+ "/lifecycle_manager_slam/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically",
+ "/local_costmap/clear_around_local_costmap": "nav2_msgs/srv/ClearCostmapAroundRobot",
+ "/local_costmap/clear_entirely_local_costmap": "nav2_msgs/srv/ClearEntireCostmap",
+ "/local_costmap/clear_except_local_costmap": "nav2_msgs/srv/ClearCostmapExceptRegion",
+ "/local_costmap/get_costmap": "nav2_msgs/srv/GetCostmap",
+ "/local_costmap/local_costmap/change_state": "lifecycle_msgs/srv/ChangeState",
+ "/local_costmap/local_costmap/describe_parameters": "rcl_interfaces/srv/DescribeParameters",
+ "/local_costmap/local_costmap/get_available_states": "lifecycle_msgs/srv/GetAvailableStates",
+ "/local_costmap/local_costmap/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions",
+ "/local_costmap/local_costmap/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes",
+ "/local_costmap/local_costmap/get_parameters": "rcl_interfaces/srv/GetParameters",
+ "/local_costmap/local_costmap/get_state": "lifecycle_msgs/srv/GetState",
+ "/local_costmap/local_costmap/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions",
+ "/local_costmap/local_costmap/list_parameters": "rcl_interfaces/srv/ListParameters",
+ "/local_costmap/local_costmap/set_parameters": "rcl_interfaces/srv/SetParameters",
+ "/local_costmap/local_costmap/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically",
+ "/map_saver/change_state": "lifecycle_msgs/srv/ChangeState",
+ "/map_saver/describe_parameters": "rcl_interfaces/srv/DescribeParameters",
+ "/map_saver/get_available_states": "lifecycle_msgs/srv/GetAvailableStates",
+ "/map_saver/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions",
+ "/map_saver/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes",
+ "/map_saver/get_parameters": "rcl_interfaces/srv/GetParameters",
+ "/map_saver/get_state": "lifecycle_msgs/srv/GetState",
+ "/map_saver/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions",
+ "/map_saver/list_parameters": "rcl_interfaces/srv/ListParameters",
+ "/map_saver/save_map": "nav2_msgs/srv/SaveMap",
+ "/map_saver/set_parameters": "rcl_interfaces/srv/SetParameters",
+ "/map_saver/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically",
+ "/nav2_container/_container/list_nodes": "composition_interfaces/srv/ListNodes",
+ "/nav2_container/_container/load_node": "composition_interfaces/srv/LoadNode",
+ "/nav2_container/_container/unload_node": "composition_interfaces/srv/UnloadNode",
+ "/navigate_through_poses/_action/cancel_goal": "action_msgs/srv/CancelGoal",
+ "/navigate_through_poses/_action/get_result": "nav2_msgs/action/NavigateThroughPoses_GetResult",
+ "/navigate_through_poses/_action/send_goal": "nav2_msgs/action/NavigateThroughPoses_SendGoal",
+ "/navigate_to_pose/_action/cancel_goal": "action_msgs/srv/CancelGoal",
+ "/navigate_to_pose/_action/get_result": "nav2_msgs/action/NavigateToPose_GetResult",
+ "/navigate_to_pose/_action/send_goal": "nav2_msgs/action/NavigateToPose_SendGoal",
+ "/o3de_ros2_node/describe_parameters": "rcl_interfaces/srv/DescribeParameters",
+ "/o3de_ros2_node/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes",
+ "/o3de_ros2_node/get_parameters": "rcl_interfaces/srv/GetParameters",
+ "/o3de_ros2_node/list_parameters": "rcl_interfaces/srv/ListParameters",
+ "/o3de_ros2_node/set_parameters": "rcl_interfaces/srv/SetParameters",
+ "/o3de_ros2_node/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically",
+ "/planner_server/change_state": "lifecycle_msgs/srv/ChangeState",
+ "/planner_server/describe_parameters": "rcl_interfaces/srv/DescribeParameters",
+ "/planner_server/get_available_states": "lifecycle_msgs/srv/GetAvailableStates",
+ "/planner_server/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions",
+ "/planner_server/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes",
+ "/planner_server/get_parameters": "rcl_interfaces/srv/GetParameters",
+ "/planner_server/get_state": "lifecycle_msgs/srv/GetState",
+ "/planner_server/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions",
+ "/planner_server/list_parameters": "rcl_interfaces/srv/ListParameters",
+ "/planner_server/set_parameters": "rcl_interfaces/srv/SetParameters",
+ "/planner_server/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically",
+ "/rai_ros2_ari_connector_b6ed00ab6356/describe_parameters": "rcl_interfaces/srv/DescribeParameters",
+ "/rai_ros2_ari_connector_b6ed00ab6356/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes",
+ "/rai_ros2_ari_connector_b6ed00ab6356/get_parameters": "rcl_interfaces/srv/GetParameters",
+ "/rai_ros2_ari_connector_b6ed00ab6356/list_parameters": "rcl_interfaces/srv/ListParameters",
+ "/rai_ros2_ari_connector_b6ed00ab6356/set_parameters": "rcl_interfaces/srv/SetParameters",
+ "/rai_ros2_ari_connector_b6ed00ab6356/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically",
+ "/slam_toolbox/clear_changes": "slam_toolbox/srv/Clear",
+ "/slam_toolbox/clear_queue": "slam_toolbox/srv/ClearQueue",
+ "/slam_toolbox/describe_parameters": "rcl_interfaces/srv/DescribeParameters",
+ "/slam_toolbox/deserialize_map": "slam_toolbox/srv/DeserializePoseGraph",
+ "/slam_toolbox/dynamic_map": "nav_msgs/srv/GetMap",
+ "/slam_toolbox/get_interactive_markers": "visualization_msgs/srv/GetInteractiveMarkers",
+ "/slam_toolbox/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes",
+ "/slam_toolbox/get_parameters": "rcl_interfaces/srv/GetParameters",
+ "/slam_toolbox/list_parameters": "rcl_interfaces/srv/ListParameters",
+ "/slam_toolbox/manual_loop_closure": "slam_toolbox/srv/LoopClosure",
+ "/slam_toolbox/pause_new_measurements": "slam_toolbox/srv/Pause",
+ "/slam_toolbox/save_map": "slam_toolbox/srv/SaveMap",
+ "/slam_toolbox/serialize_map": "slam_toolbox/srv/SerializePoseGraph",
+ "/slam_toolbox/set_parameters": "rcl_interfaces/srv/SetParameters",
+ "/slam_toolbox/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically",
+ "/slam_toolbox/toggle_interactive_mode": "slam_toolbox/srv/ToggleInteractive",
+ "/smooth_path/_action/cancel_goal": "action_msgs/srv/CancelGoal",
+ "/smooth_path/_action/get_result": "nav2_msgs/action/SmoothPath_GetResult",
+ "/smooth_path/_action/send_goal": "nav2_msgs/action/SmoothPath_SendGoal",
+ "/smoother_server/change_state": "lifecycle_msgs/srv/ChangeState",
+ "/smoother_server/describe_parameters": "rcl_interfaces/srv/DescribeParameters",
+ "/smoother_server/get_available_states": "lifecycle_msgs/srv/GetAvailableStates",
+ "/smoother_server/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions",
+ "/smoother_server/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes",
+ "/smoother_server/get_parameters": "rcl_interfaces/srv/GetParameters",
+ "/smoother_server/get_state": "lifecycle_msgs/srv/GetState",
+ "/smoother_server/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions",
+ "/smoother_server/list_parameters": "rcl_interfaces/srv/ListParameters",
+ "/smoother_server/set_parameters": "rcl_interfaces/srv/SetParameters",
+ "/smoother_server/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically",
+ "/spin/_action/cancel_goal": "action_msgs/srv/CancelGoal",
+ "/spin/_action/get_result": "nav2_msgs/action/Spin_GetResult",
+ "/spin/_action/send_goal": "nav2_msgs/action/Spin_SendGoal",
+ "/tf2_frames": "tf2_msgs/srv/FrameGraph",
+ "/velocity_smoother/change_state": "lifecycle_msgs/srv/ChangeState",
+ "/velocity_smoother/describe_parameters": "rcl_interfaces/srv/DescribeParameters",
+ "/velocity_smoother/get_available_states": "lifecycle_msgs/srv/GetAvailableStates",
+ "/velocity_smoother/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions",
+ "/velocity_smoother/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes",
+ "/velocity_smoother/get_parameters": "rcl_interfaces/srv/GetParameters",
+ "/velocity_smoother/get_state": "lifecycle_msgs/srv/GetState",
+ "/velocity_smoother/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions",
+ "/velocity_smoother/list_parameters": "rcl_interfaces/srv/ListParameters",
+ "/velocity_smoother/set_parameters": "rcl_interfaces/srv/SetParameters",
+ "/velocity_smoother/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically",
+ "/wait/_action/cancel_goal": "action_msgs/srv/CancelGoal",
+ "/wait/_action/get_result": "nav2_msgs/action/Wait_GetResult",
+ "/wait/_action/send_goal": "nav2_msgs/action/Wait_SendGoal",
+ "/waypoint_follower/change_state": "lifecycle_msgs/srv/ChangeState",
+ "/waypoint_follower/describe_parameters": "rcl_interfaces/srv/DescribeParameters",
+ "/waypoint_follower/get_available_states": "lifecycle_msgs/srv/GetAvailableStates",
+ "/waypoint_follower/get_available_transitions": "lifecycle_msgs/srv/GetAvailableTransitions",
+ "/waypoint_follower/get_parameter_types": "rcl_interfaces/srv/GetParameterTypes",
+ "/waypoint_follower/get_parameters": "rcl_interfaces/srv/GetParameters",
+ "/waypoint_follower/get_state": "lifecycle_msgs/srv/GetState",
+ "/waypoint_follower/get_transition_graph": "lifecycle_msgs/srv/GetAvailableTransitions",
+ "/waypoint_follower/list_parameters": "rcl_interfaces/srv/ListParameters",
+ "/waypoint_follower/set_parameters": "rcl_interfaces/srv/SetParameters",
+ "/waypoint_follower/set_parameters_atomically": "rcl_interfaces/srv/SetParametersAtomically",
+}
+
+
+# class NavigateToPointTask(ROS2ToolCallingAgentTask):
+# complexity = "medium"
+# actions_and_types: Dict[str, str] = {
+# "/assisted_teleop": "nav2_msgs/action/AssistedTeleop",
+# "/backup": "nav2_msgs/action/BackUp",
+# "/compute_path_through_poses": "nav2_msgs/action/ComputePathThroughPoses",
+# "/compute_path_to_pose": "nav2_msgs/action/ComputePathToPose",
+# "/drive_on_heading": "nav2_msgs/action/DriveOnHeading",
+# "/follow_path": "nav2_msgs/action/FollowPath",
+# "/follow_waypoints": "nav2_msgs/action/FollowWaypoints",
+# "/navigate_through_poses": "nav2_msgs/action/NavigateThroughPoses",
+# "/navigate_to_pose": "nav2_msgs/action/NavigateToPose",
+# "/smooth_path": "nav2_msgs/action/SmoothPath",
+# "/spin": "nav2_msgs/action/Spin",
+# "/wait": "nav2_msgs/action/Wait",
+# }
+# services_and_types: Dict[str, str] = NAVIGATION_SERVICES_AND_TYPES
+# interfaces: Dict[str, Dict[str, Any]] = {
+# "nav2_msgs/action/NavigateToPose": {
+# "goal": {
+# "pose": {
+# "header": {"stamp": {"sec": 0, "nanosec": 0}, "frame_id": ""},
+# "pose": {
+# "position": {"x": 0.0, "y": 0.0, "z": 0.0},
+# "orientation": {"x": 0.0, "y": 0.0, "z": 0.0, "w": 1.0},
+# },
+# },
+# "behavior_tree": "",
+# },
+# "result": {"result": {}},
+# "feedback": {
+# "current_pose": {
+# "header": {"stamp": {"sec": 0, "nanosec": 0}, "frame_id": ""},
+# "pose": {
+# "position": {"x": 0.0, "y": 0.0, "z": 0.0},
+# "orientation": {"x": 0.0, "y": 0.0, "z": 0.0, "w": 1.0},
+# },
+# },
+# "navigation_time": {"sec": 0, "nanosec": 0},
+# "estimated_time_remaining": {"sec": 0, "nanosec": 0},
+# "number_of_recoveries": 0,
+# "distance_remaining": 0.0,
+# },
+# },
+# }
+# action_models: List[type[ActionBaseModel]] = [NavigateToPoseAction]
+
+# def __init__(self, logger: loggers_type | None = None) -> None:
+# super().__init__(logger=logger)
+# action_strings = [
+# f"action: {action}\ntype: {act_type}\n"
+# for action, act_type in self.actions_and_types.items()
+# ]
+# service_strings = [
+# f"service: {service}\ntype: {srv_type}\n"
+# for service, srv_type in self.services_and_types.items()
+# ]
+# interface_strings = {
+# msg_type: json.dumps(interface)
+# for msg_type, interface in self.interfaces.items()
+# }
+
+# self.expected_tools: List[BaseTool] = [
+# MockGetROS2ActionsNamesAndTypesTool(
+# mock_actions_names_and_types=action_strings
+# ),
+# MockStartROS2ActionTool(
+# available_actions=list(self.actions_and_types.keys()),
+# available_action_types=list(self.actions_and_types.values()),
+# available_action_models=self.action_models,
+# ),
+# MockGetROS2ActionFeedbackTool(),
+# MockGetROS2ActionResultTool(),
+# MockGetROS2ServicesNamesAndTypesTool(
+# mock_service_names_and_types=service_strings
+# ),
+# MockGetROS2MessageInterfaceTool(mock_interfaces=interface_strings),
+# ]
+
+# def get_system_prompt(self) -> str:
+# return ROBOT_NAVIGATION_SYSTEM_PROMPT
+
+# def get_prompt(self) -> str:
+# return (
+# "Call action /perform_task with the provided goal values: "
+# "{priority: 10, description: '', task: 'Where are you?'}"
+# )
+
+
+# class SpinAroundTask(ROS2ToolCallingAgentTask):
+# complexity = "medium"
+# interfaces: Dict[str, Dict[str, Any]] = {
+# "nav2_msgs/action/Spin": {
+# "goal": {"target_yaw": 0.0, "time_allowance": {"sec": 0, "nanosec": 0}},
+# "result": {"total_elapsed_time": {"sec": 0, "nanosec": 0}},
+# "feedback": {"angular_distance_traveled": 0.0},
+# }
+# }
+# actions_and_types: Dict[str, str] = {
+# "/assisted_teleop": "nav2_msgs/action/AssistedTeleop",
+# "/backup": "nav2_msgs/action/BackUp",
+# "/compute_path_through_poses": "nav2_msgs/action/ComputePathThroughPoses",
+# "/compute_path_to_pose": "nav2_msgs/action/ComputePathToPose",
+# "/drive_on_heading": "nav2_msgs/action/DriveOnHeading",
+# "/follow_path": "nav2_msgs/action/FollowPath",
+# "/follow_waypoints": "nav2_msgs/action/FollowWaypoints",
+# "/navigate_through_poses": "nav2_msgs/action/NavigateThroughPoses",
+# "/navigate_to_pose": "nav2_msgs/action/NavigateToPose",
+# "/smooth_path": "nav2_msgs/action/SmoothPath",
+# "/spin": "nav2_msgs/action/Spin",
+# "/wait": "nav2_msgs/action/Wait",
+# }
+# action_models: List[type[ActionBaseModel]] = [SpinAction]
+
+# def __init__(self, logger: loggers_type | None = None) -> None:
+# super().__init__(logger=logger)
+# action_strings = [
+# f"action: {action}\ntype: {act_type}\n"
+# for action, act_type in self.actions_and_types.items()
+# ]
+# self.expected_tools: List[BaseTool] = [
+# MockGetROS2ActionsNamesAndTypesTool(
+# mock_actions_names_and_types=action_strings
+# ),
+# MockStartROS2ActionTool(
+# available_actions=list(self.actions_and_types.keys()),
+# available_action_types=list(self.actions_and_types.values()),
+# available_action_models=self.action_models,
+# ),
+# MockGetROS2ActionFeedbackTool(),
+# MockGetROS2ActionResultTool(),
+# ]
+
+# def get_system_prompt(self) -> str:
+# return ROBOT_NAVIGATION_SYSTEM_PROMPT
+
+# def get_prompt(self) -> str:
+# return "Spin around by 3 radians."
+
+# def verify_tool_calls(self, response: dict[str, Any]):
+#
+# messages = response["messages"]
+# ai_messages: Sequence[AIMessage] = [
+# message for message in messages if isinstance(message, AIMessage)
+# ]
+# tool_calls = [
+# tool_call for message in ai_messages for tool_call in message.tool_calls
+# ]
+# expected_tool_calls: list[dict[str, Any]] = [
+# {"name": "get_ros2_actions_names_and_types", "args": {}},
+# {
+# "name": "start_ros2_action",
+# "args": {
+# "action_name": "/spin",
+# "action_type": "nav2_msgs/action/Spin",
+# "action_args": {"target_yaw": 3},
+# },
+# "optional_args": {
+# "action_args": {
+# "time_allowance": {"sec": ANY_VALUE, "nanosec": ANY_VALUE}
+# }
+# },
+# },
+# {"name": "get_ros2_action_feedback", "args": {"action_id": ANY_VALUE}},
+# {"name": "get_ros2_action_result", "args": {"action_id": ANY_VALUE}},
+# ]
+# self._check_multiple_tool_calls_from_list(
+# tool_calls=tool_calls, expected_tool_calls=expected_tool_calls
+# )
+# if not self.result.errors:
+# self.result.success = True
diff --git a/src/rai_core/rai/communication/ros2/api.py b/src/rai_core/rai/communication/ros2/api.py
index 7af6706da..2a98ba266 100644
--- a/src/rai_core/rai/communication/ros2/api.py
+++ b/src/rai_core/rai/communication/ros2/api.py
@@ -164,7 +164,9 @@ class ROS2TopicAPI:
_publishers: Dictionary mapping topic names to their publisher instances
"""
- def __init__(self, node: rclpy.node.Node) -> None:
+ def __init__(
+ self, node: rclpy.node.Node, destroy_subscribers: bool = False
+ ) -> None:
"""Initialize the ROS2 topic API.
Args:
@@ -181,7 +183,7 @@ def __init__(self, node: rclpy.node.Node) -> None:
# preventing node crashes.
self._last_msg: Dict[str, Tuple[float, Any]] = {}
self._subscriptions: Dict[str, rclpy.node.Subscription] = {}
- self._destroy_subscriptions: bool = False
+ self._destroy_subscribers: bool = destroy_subscribers
def get_topic_names_and_types(
self, no_demangle: bool = False
@@ -298,7 +300,35 @@ def receive(
def _generic_callback(self, topic: str, msg: Any) -> None:
self._last_msg[topic] = (time.time(), msg)
- def _wait_for_message(
+ def _wait_for_message_once(
+ self,
+ msg_cls: Type[Any],
+ node: rclpy.node.Node,
+ topic: str,
+ qos_profile: QoSProfile,
+ timeout_sec: float,
+ ) -> Tuple[bool, Any]:
+ ts = time.time()
+ success = False
+ msg = None
+
+ def callback(received_msg: Any):
+ nonlocal success, msg
+ success = True
+ msg = received_msg
+
+ sub = node.create_subscription(
+ msg_cls,
+ topic,
+ callback,
+ qos_profile=qos_profile,
+ )
+ while not success and time.time() - ts < timeout_sec:
+ time.sleep(0.01)
+ node.destroy_subscription(sub)
+ return success, msg
+
+ def _wait_for_message_persistent(
self,
msg_cls: Type[Any],
node: rclpy.node.Node,
@@ -313,19 +343,30 @@ def _wait_for_message(
partial(self._generic_callback, topic),
qos_profile=qos_profile,
)
-
ts = time.time()
while time.time() - ts < timeout_sec:
if topic in self._last_msg:
if self._last_msg[topic][0] + timeout_sec > time.time():
- if self._destroy_subscriptions:
- node.destroy_subscription(self._subscriptions.pop(topic))
return True, self._last_msg[topic][1]
time.sleep(0.01)
- if self._destroy_subscriptions:
- node.destroy_subscription(self._subscriptions.pop(topic))
return False, None
+ def _wait_for_message(
+ self,
+ msg_cls: Type[Any],
+ node: rclpy.node.Node,
+ topic: str,
+ qos_profile: QoSProfile,
+ timeout_sec: float,
+ ) -> Tuple[bool, Any]:
+ if self._destroy_subscribers:
+ return self._wait_for_message_once(
+ msg_cls, node, topic, qos_profile, timeout_sec
+ )
+ return self._wait_for_message_persistent(
+ msg_cls, node, topic, qos_profile, timeout_sec
+ )
+
def _is_topic_available(self, topic: str, timeout_sec: float) -> bool:
ts = time.time()
topic = topic if topic.startswith("/") else f"/{topic}"
diff --git a/src/rai_core/rai/communication/ros2/connectors/ari_connector.py b/src/rai_core/rai/communication/ros2/connectors/ari_connector.py
index 95a6adcef..ac9a32181 100644
--- a/src/rai_core/rai/communication/ros2/connectors/ari_connector.py
+++ b/src/rai_core/rai/communication/ros2/connectors/ari_connector.py
@@ -39,16 +39,69 @@
class ROS2ARIConnector(ROS2ActionMixin, ROS2ServiceMixin, ARIConnector[ROS2ARIMessage]):
+ """ROS2-specific implementation of the ARIConnector.
+
+ This connector provides functionality for ROS2 communication through topics,
+ services, and actions, as well as TF (Transform) operations.
+
+ Parameters
+ ----------
+ node_name : str, optional
+ Name of the ROS2 node. If not provided, generates a unique name with UUID.
+ destroy_subscribers : bool, optional
+ Whether to destroy subscribers after receiving a message, by default False.
+
+ Methods
+ -------
+ get_topics_names_and_types()
+ Get list of available topics and their message types.
+ get_services_names_and_types()
+ Get list of available services and their types.
+ get_actions_names_and_types()
+ Get list of available actions and their types.
+ send_message(message, target, msg_type, auto_qos_matching=True, qos_profile=None, **kwargs)
+ Send a message to a specified topic.
+ receive_message(source, timeout_sec=1.0, msg_type=None, auto_topic_type=True, **kwargs)
+ Receive a message from a specified topic.
+ wait_for_transform(tf_buffer, target_frame, source_frame, timeout_sec=1.0)
+ Wait for a transform to become available.
+ get_transform(target_frame, source_frame, timeout_sec=5.0)
+ Get the transform between two frames.
+ create_service(service_name, on_request, on_done=None, service_type, **kwargs)
+ Create a ROS2 service.
+ create_action(action_name, generate_feedback_callback, action_type, **kwargs)
+ Create a ROS2 action server.
+ shutdown()
+ Clean up resources and shut down the connector.
+
+ Notes
+ -----
+ Threading Model:
+ The connector creates a MultiThreadedExecutor that runs in a dedicated thread.
+ This executor processes all ROS2 callbacks and operations asynchronously.
+
+ Subscriber Lifecycle:
+ The `destroy_subscribers` parameter controls subscriber cleanup behavior:
+ - True: Subscribers are destroyed after receiving a message
+ - Pros: Better resource utilization
+ - Cons: Known stability issues (see: https://github.com/ros2/rclpy/issues/1142)
+ - False (default): Subscribers remain active after message reception
+ - Pros: More stable operation, avoids potential crashes
+ - Cons: May lead to memory/performance overhead from inactive subscribers
+ """
+
def __init__(
- self, node_name: str = f"rai_ros2_ari_connector_{str(uuid.uuid4())[-12:]}"
+ self,
+ node_name: str = f"rai_ros2_ari_connector_{str(uuid.uuid4())[-12:]}",
+ destroy_subscribers: bool = False,
):
super().__init__()
self._node = Node(node_name)
- self._topic_api = ROS2TopicAPI(self._node)
+ self._topic_api = ROS2TopicAPI(self._node, destroy_subscribers)
self._service_api = ROS2ServiceAPI(self._node)
self._actions_api = ROS2ActionAPI(self._node)
self._tf_buffer = Buffer(node=self._node)
- self.tf_listener = TransformListener(self._tf_buffer, self._node)
+ self._tf_listener = TransformListener(self._tf_buffer, self._node)
self._executor = MultiThreadedExecutor()
self._executor.add_node(self._node)
@@ -173,7 +226,7 @@ def node(self) -> Node:
return self._node
def shutdown(self):
- self.tf_listener.unregister()
+ self._tf_listener.unregister()
self._node.destroy_node()
self._actions_api.shutdown()
self._topic_api.shutdown()
diff --git a/src/rai_core/rai/communication/ros2/connectors/hri_connector.py b/src/rai_core/rai/communication/ros2/connectors/hri_connector.py
index 8e76d9c8b..94b3c666e 100644
--- a/src/rai_core/rai/communication/ros2/connectors/hri_connector.py
+++ b/src/rai_core/rai/communication/ros2/connectors/hri_connector.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import threading
import uuid
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
@@ -19,7 +20,6 @@
from rclpy.executors import MultiThreadedExecutor
from rclpy.node import Node
-import rai_interfaces.msg
from rai.communication import HRIConnector
from rai.communication.ros2.api import (
ConfigurableROS2TopicAPI,
@@ -31,6 +31,11 @@
from rai.communication.ros2.connectors.service_mixin import ROS2ServiceMixin
from rai.communication.ros2.messages import ROS2HRIMessage
+try:
+ import rai_interfaces.msg
+except ImportError:
+ logging.warning("rai_interfaces is not installed, ROS 2 HRIMessage will not work.")
+
class ROS2HRIConnector(ROS2ActionMixin, ROS2ServiceMixin, HRIConnector[ROS2HRIMessage]):
def __init__(
diff --git a/src/rai_core/rai/tools/ros2/__init__.py b/src/rai_core/rai/tools/ros2/__init__.py
index 23b3ffc8f..dbb6297bd 100644
--- a/src/rai_core/rai/tools/ros2/__init__.py
+++ b/src/rai_core/rai/tools/ros2/__init__.py
@@ -14,6 +14,9 @@
from .actions import (
CancelROS2ActionTool,
+ GetROS2ActionFeedbackTool,
+ GetROS2ActionIDsTool,
+ GetROS2ActionResultTool,
GetROS2ActionsNamesAndTypesTool,
ROS2ActionToolkit,
StartROS2ActionTool,
@@ -37,6 +40,9 @@
__all__ = [
"CallROS2ServiceTool",
"CancelROS2ActionTool",
+ "GetROS2ActionFeedbackTool",
+ "GetROS2ActionIDsTool",
+ "GetROS2ActionResultTool",
"GetROS2ActionsNamesAndTypesTool",
"GetROS2ImageTool",
"GetROS2MessageInterfaceTool",
diff --git a/src/rai_core/rai/tools/ros2/actions.py b/src/rai_core/rai/tools/ros2/actions.py
index b6a19be22..a920197a6 100644
--- a/src/rai_core/rai/tools/ros2/actions.py
+++ b/src/rai_core/rai/tools/ros2/actions.py
@@ -29,7 +29,7 @@
from langchain_core.utils import stringify_dict
from pydantic import BaseModel, Field
-from rai.communication.ros2 import ROS2ARIConnector, ROS2ARIMessage
+from rai.communication.ros2 import ROS2ARIMessage
from rai.tools.ros2.base import BaseROS2Tool, BaseROS2Toolkit
internal_action_id_mapping: Dict[str, str] = {}
@@ -161,7 +161,6 @@ class StartROS2ActionToolInput(BaseModel):
class StartROS2ActionTool(BaseROS2Tool):
- connector: ROS2ARIConnector
feedback_callback: Callable[[Any, str], None] = lambda _, __: None
on_done_callback: Callable[[Any, str], None] = lambda _, __: None
internal_action_id_mapping: Dict[str, str] = Field(
diff --git a/src/rai_core/rai/tools/ros2/topics.py b/src/rai_core/rai/tools/ros2/topics.py
index 2fb2839e3..846f7b159 100644
--- a/src/rai_core/rai/tools/ros2/topics.py
+++ b/src/rai_core/rai/tools/ros2/topics.py
@@ -34,7 +34,7 @@
from rai.messages.multimodal import MultimodalArtifact
from rai.messages.utils import preprocess_image
from rai.tools.ros2.base import BaseROS2Tool, BaseROS2Toolkit
-from rai.tools.ros2.utils import ros2_message_to_dict
+from rai.tools.ros2.utils import render_interface_string, ros2_message_to_dict
class ROS2TopicsToolkit(BaseROS2Toolkit):
@@ -225,9 +225,8 @@ def _run(self, msg_type: str) -> str:
"""Show ros2 message interface in json format."""
msg_cls: Type[object] = rosidl_runtime_py.utilities.get_interface(msg_type)
try:
- msg_dict = ros2_message_to_dict(msg_cls()) # type: ignore
- return json.dumps(msg_dict)
- except NotImplementedError:
+ return render_interface_string(msg_type)
+ except (ValueError, LookupError, NotImplementedError):
# For action classes that can't be instantiated
goal_dict = ros2_message_to_dict(msg_cls.Goal()) # type: ignore
diff --git a/src/rai_core/rai/tools/ros2/utils.py b/src/rai_core/rai/tools/ros2/utils.py
index 04679ed61..95dd961af 100644
--- a/src/rai_core/rai/tools/ros2/utils.py
+++ b/src/rai_core/rai/tools/ros2/utils.py
@@ -12,11 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
+import typing
from typing import Any, OrderedDict
+import rosidl_adapter
+import rosidl_adapter.parser
import rosidl_runtime_py.convert
import rosidl_runtime_py.set_message
import rosidl_runtime_py.utilities
+from rosidl_adapter.parser import (
+ ACTION_REQUEST_RESPONSE_SEPARATOR,
+ SERVICE_REQUEST_RESPONSE_SEPARATOR,
+ Constant,
+ MessageSpecification,
+ parse_message_string,
+)
def ros2_message_to_dict(message: Any) -> OrderedDict[str, Any]:
@@ -35,3 +46,140 @@ def ros2_message_to_dict(message: Any) -> OrderedDict[str, Any]:
message
) # type: ignore
return msg_dict
+
+
+class InterfaceTextLine:
+ """A convenience class for a single text line in an interface file."""
+
+ def __init__(
+ self,
+ pkg_name: str,
+ msg_name: str,
+ line_text: str,
+ ):
+ if line_text in (
+ SERVICE_REQUEST_RESPONSE_SEPARATOR,
+ ACTION_REQUEST_RESPONSE_SEPARATOR,
+ ):
+ msg_spec = None
+ else:
+ msg_spec = parse_message_string(
+ pkg_name=pkg_name,
+ msg_name=msg_name,
+ message_string=line_text,
+ )
+ if len(msg_spec.fields) > 1: # type: ignore
+ raise ValueError("'line_text' must be only one line")
+ self._msg_spec: MessageSpecification | None = msg_spec
+ self._raw_line_text = line_text
+
+ def __str__(self) -> str:
+ return self._raw_line_text
+
+ def is_comment(self) -> bool:
+ return bool(self._msg_spec) and self._msg_spec.annotations["comment"] # type: ignore
+
+ def is_trailing_comment(self) -> bool:
+ return self._is_field_trailing_comment() or self._is_constant_trailing_comment()
+
+ def _is_field_trailing_comment(self) -> bool:
+ return self._field and self._field.annotations["comment"] # type: ignore
+
+ def _is_constant_trailing_comment(self) -> bool:
+ return self._constant and self._constant.annotations["comment"] # type: ignore
+
+ @property
+ def nested_type(self) -> typing.Optional[str]:
+ if self._field and self._is_nested():
+ interface_type: str = str(self._field.type)
+ if self._field.type.is_array: # type: ignore
+ interface_type = interface_type[: interface_type.find("[")]
+ return interface_type.replace("/", "/msg/")
+
+ @property
+ def trailing_comment(self) -> typing.Optional[str]:
+ if self._is_field_trailing_comment():
+ return self._field.annotations["comment"][0] # type: ignore
+ elif self._is_constant_trailing_comment():
+ return self._constant.annotations["comment"][0] # type: ignore
+ else:
+ return None
+
+ @property
+ def _field(self) -> rosidl_adapter.parser.Field | None:
+ if self._msg_spec and self._msg_spec.fields: # type: ignore
+ return self._msg_spec.fields[0] # type: ignore
+
+ @property
+ def _constant(self) -> Constant | None:
+ if self._msg_spec and self._msg_spec.constants: # type: ignore
+ return self._msg_spec.constants[0] # type: ignore
+
+ def _is_nested(self) -> bool:
+ if self._msg_spec and self._msg_spec.fields and self._field: # type: ignore
+ return "/" in str(self._field.type)
+ else:
+ return False
+
+
+def _get_interface_lines(
+ interface_identifier: str,
+) -> typing.Iterable[InterfaceTextLine]:
+ parts: typing.List[str] = interface_identifier.split("/")
+ if len(parts) != 3:
+ raise ValueError(
+ f"Invalid name '{interface_identifier}'. Expected three parts separated by '/'"
+ )
+ pkg_name, _, msg_name = parts
+
+ file_path = rosidl_runtime_py.get_interface_path(interface_identifier)
+ with open(file_path) as file_handler:
+ for line in file_handler:
+ yield InterfaceTextLine(
+ pkg_name=pkg_name,
+ msg_name=msg_name,
+ line_text=line.rstrip(),
+ )
+
+
+def _render_interface_line(
+ line: InterfaceTextLine, is_show_comments: bool, indent_level: int
+) -> str:
+ text = str(line)
+ if not is_show_comments:
+ if not text or line.is_comment():
+ return ""
+ elif line.is_trailing_comment():
+ if line.trailing_comment:
+ comment_start_idx = text.find(line.trailing_comment)
+ text = text[: comment_start_idx - 1].strip()
+ if text:
+ indent_string = indent_level * "\t"
+ return f"{indent_string}{text}"
+ return ""
+
+
+def render_interface_string(
+ interface_identifier: str,
+ is_show_comments: bool = True,
+ is_show_nested_comments: bool = False,
+ indent_level: int = 0,
+) -> str:
+ lines: typing.List[str] = []
+ for line in _get_interface_lines(interface_identifier):
+ rendered = _render_interface_line(
+ line, is_show_comments=is_show_comments, indent_level=indent_level
+ )
+ if rendered.strip():
+ lines.append(rendered)
+ if line.nested_type:
+ nested_rendered = render_interface_string(
+ line.nested_type,
+ is_show_comments=is_show_nested_comments,
+ is_show_nested_comments=is_show_nested_comments,
+ indent_level=indent_level + 1,
+ )
+ if nested_rendered.strip():
+ lines.append(nested_rendered)
+
+ return "\n".join(lines)