Skip to content

Commit f143eaf

Browse files
authored
Dynamic cast constraint (#5082)
* dynamic cast constraint * cleanup HistogramWordStringKernel
1 parent a075792 commit f143eaf

File tree

5 files changed

+118
-117
lines changed

5 files changed

+118
-117
lines changed

src/shogun/base/AnyParameter.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <list>
1515
#include <memory>
1616
#include <string_view>
17+
#include <optional>
1718

1819
namespace shogun
1920
{
@@ -168,7 +169,7 @@ namespace shogun
168169
}
169170
AnyParameter(
170171
Any&& value, const AnyParameterProperties& properties,
171-
std::function<std::string(Any)> constrain_function)
172+
std::function<std::optional<std::string>(Any)> constrain_function)
172173
: m_value(std::move(value)), m_properties(properties),
173174
m_constrain_function(std::move(constrain_function))
174175
{
@@ -218,7 +219,7 @@ namespace shogun
218219
return m_init_function;
219220
}
220221

221-
const std::function<std::string(Any)>& get_constrain_function() const
222+
const std::function<std::optional<std::string>(Any)>& get_constrain_function() const
222223
noexcept
223224
{
224225
return m_constrain_function;
@@ -272,7 +273,7 @@ namespace shogun
272273
Any m_value;
273274
AnyParameterProperties m_properties;
274275
std::shared_ptr<params::AutoInit> m_init_function;
275-
std::function<std::string(Any)> m_constrain_function;
276+
std::function<std::optional<std::string>(Any)> m_constrain_function;
276277
std::vector<std::function<void()>> m_callback_functions;
277278
};
278279
} // namespace shogun

src/shogun/base/SGObject.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -921,10 +921,8 @@ class SGObject: public std::enable_shared_from_this<SGObject>
921921
BaseTag(name), AnyParameter(
922922
make_any_ref(value), properties,
923923
[constrain_function](const auto& val) {
924-
std::string result;
925924
auto casted_val = any_cast<T1>(val);
926-
constrain_function.run(casted_val, result);
927-
return result;
925+
return constrain_function.check(casted_val);
928926
}));
929927
register_parameter_visitor<T1>();
930928
}
@@ -1131,12 +1129,12 @@ class SGObject: public std::enable_shared_from_this<SGObject>
11311129

11321130
if (pprop.has_property(ParameterProperties::CONSTRAIN))
11331131
{
1134-
auto msg = param.get_constrain_function()(make_any(value));
1135-
if (!msg.empty())
1132+
const auto& val = param.get_constrain_function()(make_any(value));
1133+
if (val)
11361134
{
11371135
require(!do_checks,
11381136
"{}::{} cannot be updated because it must be: {}!",
1139-
get_name(), _tag.name().c_str(), msg.c_str());
1137+
get_name(), _tag.name().c_str(), *val);
11401138
}
11411139
}
11421140
if constexpr (std::is_same_v<T, Any>)

src/shogun/base/constraint.h

Lines changed: 86 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,16 @@
77
#ifndef __CONSTRAINT_H__
88
#define __CONSTRAINT_H__
99

10+
#include <shogun/io/SGIO.h>
11+
#include <shogun/util/traits.h>
12+
1013
#include <string>
1114
#include <tuple>
1215

1316
namespace shogun
1417
{
18+
class SGObject;
19+
1520
namespace constraint_detail
1621
{
1722
template <typename T, typename... Args, std::size_t... Idx>
@@ -86,18 +91,80 @@ namespace shogun
8691
template <typename T>
8792
struct generic_checker
8893
{
89-
public:
90-
generic_checker(T val) : m_val(val){};
91-
bool operator()(T val) const
94+
generic_checker() = default;
95+
bool operator()(const T& val) const
9296
{
9397
return check(val);
9498
};
9599

96100
virtual std::string error_msg() const = 0;
97101

102+
protected:
103+
virtual bool check(const T& val) const = 0;
104+
};
105+
106+
template <typename T>
107+
struct custom_constraint: generic_checker<T>
108+
{
109+
template <typename Functor>
110+
custom_constraint(Functor&& func): m_func(func)
111+
{
112+
}
113+
114+
std::string error_msg() const override
115+
{
116+
return msg;
117+
}
118+
119+
protected:
120+
bool check(const T& val) const override
121+
{
122+
try
123+
{
124+
m_func(val);
125+
}
126+
catch (const std::exception& e)
127+
{
128+
msg = std::string(e.what());
129+
return false;
130+
}
131+
132+
return true;
133+
}
134+
135+
private:
136+
std::string msg;
137+
std::function<void(const T&)> m_func;
138+
};
139+
140+
141+
template <typename DerivedType>
142+
struct castable: generic_checker<std::shared_ptr<SGObject>>
143+
{
144+
castable(): generic_checker<std::shared_ptr<SGObject>>()
145+
{
146+
}
147+
148+
std::string error_msg() const override
149+
{
150+
return "of type " + demangled_type<DerivedType>();
151+
}
152+
153+
protected:
154+
bool check(const std::shared_ptr<SGObject>& ptr) const override
155+
{
156+
return static_cast<bool>(std::dynamic_pointer_cast<DerivedType>(ptr));
157+
}
158+
};
159+
160+
template <typename T>
161+
struct comparisson_checker: generic_checker<T>
162+
{
163+
comparisson_checker(T val): generic_checker<T>() {
164+
m_val = val;
165+
}
98166
protected:
99167
T m_val;
100-
virtual bool check(T val) const = 0;
101168
};
102169

103170
/**
@@ -106,18 +173,17 @@ namespace shogun
106173
* @tparam T the type of val
107174
*/
108175
template <typename T>
109-
struct less_than : generic_checker<T>
176+
struct less_than : comparisson_checker<T>
110177
{
111-
public:
112-
less_than(T val) : generic_checker<T>(val){};
178+
less_than(T val) : comparisson_checker<T>(val){};
113179

114180
std::string error_msg() const override
115181
{
116182
return "less than " + std::to_string(this->m_val);
117183
}
118184

119185
protected:
120-
bool check(T val) const override
186+
bool check(const T& val) const override
121187
{
122188
return val < this->m_val;
123189
}
@@ -129,18 +195,17 @@ namespace shogun
129195
* @tparam T the type of val
130196
*/
131197
template <typename T>
132-
struct less_than_or_equal : generic_checker<T>
198+
struct less_than_or_equal : comparisson_checker<T>
133199
{
134-
public:
135-
less_than_or_equal(T val) : generic_checker<T>(val){};
200+
less_than_or_equal(T val) : comparisson_checker<T>(val){};
136201

137202
std::string error_msg() const override
138203
{
139204
return "less than " + std::to_string(this->m_val);
140205
}
141206

142207
protected:
143-
bool check(T val) const override
208+
bool check(const T& val) const override
144209
{
145210
return val <= this->m_val;
146211
}
@@ -152,17 +217,16 @@ namespace shogun
152217
* @tparam T the type of val
153218
*/
154219
template <typename T>
155-
struct greater_than : generic_checker<T>
220+
struct greater_than : comparisson_checker<T>
156221
{
157-
public:
158-
greater_than(T val) : generic_checker<T>(val){};
222+
greater_than(T val) : comparisson_checker<T>(val){};
159223
std::string error_msg() const override
160224
{
161225
return "greater than " + std::to_string(this->m_val);
162226
}
163227

164228
protected:
165-
bool check(T val) const override
229+
bool check(const T& val) const override
166230
{
167231
return val > this->m_val;
168232
}
@@ -174,18 +238,17 @@ namespace shogun
174238
* @tparam T the type of val
175239
*/
176240
template <typename T>
177-
struct greater_than_or_equal : generic_checker<T>
241+
struct greater_than_or_equal : comparisson_checker<T>
178242
{
179-
public:
180-
greater_than_or_equal(T val) : generic_checker<T>(val){};
243+
greater_than_or_equal(T val) : comparisson_checker<T>(val){};
181244

182245
std::string error_msg() const override
183246
{
184247
return "less than " + std::to_string(this->m_val);
185248
}
186249

187250
protected:
188-
bool check(T val) const override
251+
bool check(const T& val) const override
189252
{
190253
return val >= this->m_val;
191254
}
@@ -237,14 +300,13 @@ namespace shogun
237300
}
238301

239302
template <typename T>
240-
bool run(T val, std::string& buffer) const
303+
std::optional<std::string> check(const T& val) const
241304
{
242305
if (!constraint_detail::apply(val, m_funcs))
243306
{
244-
buffer = constraint_detail::get_error(m_funcs);
245-
return false;
307+
return constraint_detail::get_error(m_funcs);
246308
}
247-
return true;
309+
return std::nullopt;
248310
}
249311

250312
private:

0 commit comments

Comments
 (0)