7
7
#ifndef __CONSTRAINT_H__
8
8
#define __CONSTRAINT_H__
9
9
10
+ #include < shogun/io/SGIO.h>
11
+ #include < shogun/util/traits.h>
12
+
10
13
#include < string>
11
14
#include < tuple>
12
15
13
16
namespace shogun
14
17
{
18
+ class SGObject ;
19
+
15
20
namespace constraint_detail
16
21
{
17
22
template <typename T, typename ... Args, std::size_t ... Idx>
@@ -86,18 +91,80 @@ namespace shogun
86
91
template <typename T>
87
92
struct generic_checker
88
93
{
89
- public:
90
- generic_checker (T val) : m_val(val){};
91
- bool operator ()(T val) const
94
+ generic_checker () = default ;
95
+ bool operator ()(const T& val) const
92
96
{
93
97
return check (val);
94
98
};
95
99
96
100
virtual std::string error_msg () const = 0;
97
101
102
+ protected:
103
+ virtual bool check (const T& val) const = 0;
104
+ };
105
+
106
+ template <typename T>
107
+ struct custom_constraint : generic_checker<T>
108
+ {
109
+ template <typename Functor>
110
+ custom_constraint (Functor&& func): m_func(func)
111
+ {
112
+ }
113
+
114
+ std::string error_msg () const override
115
+ {
116
+ return msg;
117
+ }
118
+
119
+ protected:
120
+ bool check (const T& val) const override
121
+ {
122
+ try
123
+ {
124
+ m_func (val);
125
+ }
126
+ catch (const std::exception& e)
127
+ {
128
+ msg = std::string (e.what ());
129
+ return false ;
130
+ }
131
+
132
+ return true ;
133
+ }
134
+
135
+ private:
136
+ std::string msg;
137
+ std::function<void (const T&)> m_func;
138
+ };
139
+
140
+
141
+ template <typename DerivedType>
142
+ struct castable : generic_checker<std::shared_ptr<SGObject>>
143
+ {
144
+ castable (): generic_checker<std::shared_ptr<SGObject>>()
145
+ {
146
+ }
147
+
148
+ std::string error_msg () const override
149
+ {
150
+ return " of type " + demangled_type<DerivedType>();
151
+ }
152
+
153
+ protected:
154
+ bool check (const std::shared_ptr<SGObject>& ptr) const override
155
+ {
156
+ return static_cast <bool >(std::dynamic_pointer_cast<DerivedType>(ptr));
157
+ }
158
+ };
159
+
160
+ template <typename T>
161
+ struct comparisson_checker : generic_checker<T>
162
+ {
163
+ comparisson_checker (T val): generic_checker<T>() {
164
+ m_val = val;
165
+ }
98
166
protected:
99
167
T m_val;
100
- virtual bool check (T val) const = 0;
101
168
};
102
169
103
170
/* *
@@ -106,18 +173,17 @@ namespace shogun
106
173
* @tparam T the type of val
107
174
*/
108
175
template <typename T>
109
- struct less_than : generic_checker <T>
176
+ struct less_than : comparisson_checker <T>
110
177
{
111
- public:
112
- less_than (T val) : generic_checker<T>(val){};
178
+ less_than (T val) : comparisson_checker<T>(val){};
113
179
114
180
std::string error_msg () const override
115
181
{
116
182
return " less than " + std::to_string (this ->m_val );
117
183
}
118
184
119
185
protected:
120
- bool check (T val) const override
186
+ bool check (const T& val) const override
121
187
{
122
188
return val < this ->m_val ;
123
189
}
@@ -129,18 +195,17 @@ namespace shogun
129
195
* @tparam T the type of val
130
196
*/
131
197
template <typename T>
132
- struct less_than_or_equal : generic_checker <T>
198
+ struct less_than_or_equal : comparisson_checker <T>
133
199
{
134
- public:
135
- less_than_or_equal (T val) : generic_checker<T>(val){};
200
+ less_than_or_equal (T val) : comparisson_checker<T>(val){};
136
201
137
202
std::string error_msg () const override
138
203
{
139
204
return " less than " + std::to_string (this ->m_val );
140
205
}
141
206
142
207
protected:
143
- bool check (T val) const override
208
+ bool check (const T& val) const override
144
209
{
145
210
return val <= this ->m_val ;
146
211
}
@@ -152,17 +217,16 @@ namespace shogun
152
217
* @tparam T the type of val
153
218
*/
154
219
template <typename T>
155
- struct greater_than : generic_checker <T>
220
+ struct greater_than : comparisson_checker <T>
156
221
{
157
- public:
158
- greater_than (T val) : generic_checker<T>(val){};
222
+ greater_than (T val) : comparisson_checker<T>(val){};
159
223
std::string error_msg () const override
160
224
{
161
225
return " greater than " + std::to_string (this ->m_val );
162
226
}
163
227
164
228
protected:
165
- bool check (T val) const override
229
+ bool check (const T& val) const override
166
230
{
167
231
return val > this ->m_val ;
168
232
}
@@ -174,18 +238,17 @@ namespace shogun
174
238
* @tparam T the type of val
175
239
*/
176
240
template <typename T>
177
- struct greater_than_or_equal : generic_checker <T>
241
+ struct greater_than_or_equal : comparisson_checker <T>
178
242
{
179
- public:
180
- greater_than_or_equal (T val) : generic_checker<T>(val){};
243
+ greater_than_or_equal (T val) : comparisson_checker<T>(val){};
181
244
182
245
std::string error_msg () const override
183
246
{
184
247
return " less than " + std::to_string (this ->m_val );
185
248
}
186
249
187
250
protected:
188
- bool check (T val) const override
251
+ bool check (const T& val) const override
189
252
{
190
253
return val >= this ->m_val ;
191
254
}
@@ -237,14 +300,13 @@ namespace shogun
237
300
}
238
301
239
302
template <typename T>
240
- bool run (T val, std::string& buffer ) const
303
+ std::optional<std:: string> check ( const T& val ) const
241
304
{
242
305
if (!constraint_detail::apply (val, m_funcs))
243
306
{
244
- buffer = constraint_detail::get_error (m_funcs);
245
- return false ;
307
+ return constraint_detail::get_error (m_funcs);
246
308
}
247
- return true ;
309
+ return std::nullopt ;
248
310
}
249
311
250
312
private:
0 commit comments