Skip to content

Commit d9320cf

Browse files
authored
Enable subclassing of RoutingCost from python (#384)
1 parent 820436e commit d9320cf

File tree

2 files changed

+89
-8
lines changed

2 files changed

+89
-8
lines changed

lanelet2_examples/scripts/tutorial.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from lanelet2.core import (AllWayStop, AttributeMap, BasicPoint2d,
77
BoundingBox2d, Lanelet, LaneletMap,
88
LaneletWithStopLine, LineString3d, Point2d, Point3d,
9-
RightOfWay, TrafficLight, getId)
9+
RightOfWay, TrafficLight, getId, createMapFromLanelets)
1010
from lanelet2.projection import (UtmProjector, MercatorProjector,
1111
LocalCartesianProjector, GeocentricProjector)
1212

@@ -150,7 +150,7 @@ def part3lanelet_map():
150150
assert len(map.pointLayer.search(searchBox)) > 1
151151

152152
# you can also create a map from a list of primitives (replace Lanelets by the other types)
153-
mapBulk = lanelet2.core.createMapFromLanelets([get_a_lanelet()])
153+
mapBulk = createMapFromLanelets([get_a_lanelet()])
154154
assert len(mapBulk.laneletLayer) == 1
155155

156156

@@ -237,6 +237,17 @@ def part6routing():
237237
# for more complex queries, you can use the forEachSuccessor function and pass it a function object
238238
assert hasPathFromTo(graph, lanelet, toLanelet)
239239

240+
# it is also possible to create custom routing costs to influence the routing.
241+
# Note that this will be much slower than the costs implemented in C++, but it
242+
# is useful for prototyping and testing.
243+
class ConstantCost(lanelet2.routing.RoutingCost):
244+
def getCostSucceeding(self, rules, from_lanelet, to_lanelet):
245+
return 1
246+
247+
def getCostLaneChange(self, rules, from_lanelet, to_lanelet):
248+
return 1
249+
graph = lanelet2.routing.RoutingGraph(map, traffic_rules, [ConstantCost()])
250+
240251

241252
def hasPathFromTo(graph: lanelet2.routing.RoutingGraph, start: lanelet2.core.Lanelet, target: lanelet2.core.Lanelet):
242253
class TargetFound(BaseException):

lanelet2_python/python_api/routing.cpp

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,54 @@ Optional<T> objectToOptional(const object& o) {
5252
return o == object() ? Optional<T>{} : Optional<T>{extract<T>(o)()};
5353
}
5454

55+
class RoutingCostBaseWrapper : public lanelet::routing::RoutingCost, public wrapper<lanelet::routing::RoutingCost> {
56+
public:
57+
double getCostSucceeding(const traffic_rules::TrafficRules& trafficRules, const ConstLaneletOrArea& from,
58+
const ConstLaneletOrArea& to) const noexcept override {
59+
return this->get_override("getCostSucceeding")(boost::ref(trafficRules), from, to);
60+
}
61+
62+
double getCostLaneChange(const traffic_rules::TrafficRules& trafficRules, const ConstLanelets& from,
63+
const ConstLanelets& to) const noexcept override {
64+
return this->get_override("getCostLaneChange")(boost::ref(trafficRules), from, to);
65+
}
66+
};
67+
68+
template <typename BaseT>
69+
class RoutingCostWrapper : public BaseT, public wrapper<BaseT> {
70+
public:
71+
RoutingCostWrapper(const BaseT& base) : BaseT(base) {}
72+
RoutingCostWrapper(double laneChangeCost, double minLaneChange) : BaseT(laneChangeCost, minLaneChange) {}
73+
74+
double getCostSucceeding(const traffic_rules::TrafficRules& trafficRules, const ConstLaneletOrArea& from,
75+
const ConstLaneletOrArea& to) const noexcept override {
76+
const auto o = this->get_override("getCostSucceeding");
77+
if (o) {
78+
return o(boost::ref(trafficRules), from, to);
79+
}
80+
return BaseT::getCostSucceeding(trafficRules, from, to);
81+
}
82+
83+
double defaultGetCostSucceeding(const traffic_rules::TrafficRules& trafficRules, const ConstLaneletOrArea& from,
84+
const ConstLaneletOrArea& to) const {
85+
return BaseT::getCostSucceeding(trafficRules, from, to);
86+
}
87+
88+
double getCostLaneChange(const traffic_rules::TrafficRules& trafficRules, const ConstLanelets& from,
89+
const ConstLanelets& to) const noexcept override {
90+
const auto o = this->get_override("getCostLaneChange");
91+
if (o) {
92+
return o(boost::ref(trafficRules), from, to);
93+
}
94+
return BaseT::getCostLaneChange(trafficRules, from, to);
95+
}
96+
97+
double defaultGetCostLaneChange(const traffic_rules::TrafficRules& trafficRules, const ConstLanelets& from,
98+
const ConstLanelets& to) const {
99+
return BaseT::getCostLaneChange(trafficRules, from, to);
100+
}
101+
};
102+
55103
BOOST_PYTHON_MODULE(PYTHON_API_MODULE_NAME) { // NOLINT
56104
auto trafficRules = import("lanelet2.traffic_rules");
57105
using namespace lanelet::routing;
@@ -77,16 +125,38 @@ BOOST_PYTHON_MODULE(PYTHON_API_MODULE_NAME) { // NOLINT
77125
implicitly_convertible<LaneletMapPtr, LaneletMapConstPtr>();
78126

79127
// Routing costs
80-
class_<RoutingCost, boost::noncopyable, std::shared_ptr<RoutingCost>>( // NOLINT
81-
"RoutingCost", "Object for calculating routing costs between lanelets", no_init);
128+
class_<RoutingCostBaseWrapper, boost::noncopyable>( // NOLINT
129+
"RoutingCost", "Object for calculating routing costs between lanelets")
130+
.def("getCostSucceeding", pure_virtual(&RoutingCost::getCostSucceeding),
131+
"Get the cost of the transition from one to another lanelet", (arg("trafficRules"), arg("from"), arg("to")))
132+
.def("getCostLaneChange", pure_virtual(&RoutingCost::getCostLaneChange),
133+
"Get the cost of the lane change between two adjacent lanelets",
134+
(arg("trafficRules"), arg("from"), arg("to")));
135+
register_ptr_to_python<std::shared_ptr<RoutingCost>>();
82136

83-
class_<RoutingCostDistance, bases<RoutingCost>, std::shared_ptr<RoutingCostDistance>>( // NOLINT
137+
class_<RoutingCostWrapper<RoutingCostDistance>, bases<RoutingCost>>( // NOLINT
84138
"RoutingCostDistance", "Distance based routing cost calculation object",
85-
init<double, double>((arg("laneChangeCost"), arg("minLaneChangeDistance") = 0)));
139+
init<double, double>((arg("laneChangeCost"), arg("minLaneChangeDistance") = 0)))
140+
.def("getCostSucceeding", &RoutingCostDistance::getCostSucceeding,
141+
&RoutingCostWrapper<RoutingCostDistance>::defaultGetCostSucceeding,
142+
"Get the cost of the transition from one to another lanelet", (arg("trafficRules"), arg("from"), arg("to")))
143+
.def("getCostLaneChange", &RoutingCostDistance::getCostLaneChange,
144+
&RoutingCostWrapper<RoutingCostDistance>::defaultGetCostLaneChange,
145+
"Get the cost of the lane change between two adjacent lanelets",
146+
(arg("trafficRules"), arg("from"), arg("to")));
147+
register_ptr_to_python<std::shared_ptr<RoutingCostDistance>>();
86148

87-
class_<RoutingCostTravelTime, bases<RoutingCost>, std::shared_ptr<RoutingCostTravelTime>>( // NOLINT
149+
class_<RoutingCostWrapper<RoutingCostTravelTime>, bases<RoutingCost>>( // NOLINT
88150
"RoutingCostTravelTime", "Travel time based routing cost calculation object",
89-
init<double, double>((arg("laneChangeCost"), arg("minLaneChangeTime") = 0)));
151+
init<double, double>((arg("laneChangeCost"), arg("minLaneChangeTime") = 0)))
152+
.def("getCostSucceeding", &RoutingCostTravelTime::getCostSucceeding,
153+
&RoutingCostWrapper<RoutingCostTravelTime>::defaultGetCostSucceeding,
154+
"Get the cost of the transition from one to another lanelet", (arg("trafficRules"), arg("from"), arg("to")))
155+
.def("getCostLaneChange", &RoutingCostTravelTime::getCostLaneChange,
156+
&RoutingCostWrapper<RoutingCostTravelTime>::defaultGetCostLaneChange,
157+
"Get the cost of the lane change between two adjacent lanelets",
158+
(arg("trafficRules"), arg("from"), arg("to")));
159+
register_ptr_to_python<std::shared_ptr<RoutingCostTravelTime>>();
90160

91161
auto possPCost = static_cast<LaneletPaths (RoutingGraph::*)(const ConstLanelet&, double, RoutingCostId, bool) const>(
92162
&RoutingGraph::possiblePaths);

0 commit comments

Comments
 (0)