Skip to content

Commit cf9ffbc

Browse files
Gilvigsterkr
Gil
authored andcommitted
AnyParameterProperties refactor (shogun-toolbox#4412)
* combined parameter flags in a single mask * added mask_attribute to keep track of parameter availabilities * updated constructors and getters * added some documentation * changed properties default * typesafe bitmasking * refactored code to use enum class and bitmask operators
1 parent e2f4f99 commit cf9ffbc

File tree

3 files changed

+190
-20
lines changed

3 files changed

+190
-20
lines changed

src/shogun/base/AnyParameter.h

Lines changed: 73 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
1+
/*
2+
* This software is distributed under BSD 3-clause license (see LICENSE file).
3+
*
4+
* Authors: Heiko Strathmann, Gil Hoben
5+
*/
6+
17
#ifndef __ANYPARAMETER_H__
28
#define __ANYPARAMETER_H__
39

410
#include <shogun/lib/any.h>
11+
#include <shogun/lib/bitmask_operators.h>
512

613
#include <string>
714

@@ -22,48 +29,101 @@ namespace shogun
2229
GRADIENT_AVAILABLE = 1
2330
};
2431

32+
/** parameter properties */
33+
enum class ParameterProperties
34+
{
35+
HYPER = 1u << 0,
36+
GRADIENT = 1u << 1,
37+
MODEL = 1u << 2
38+
};
39+
40+
enableEnumClassBitmask(ParameterProperties);
41+
42+
/** @brief Class AnyParameterProperties keeps track of of parameter meta
43+
* information, such as properties and descriptions The parameter properties
44+
* can be either true or false. These properties describe if a parameter is
45+
* for example a hyperparameter or if it has a gradient.
46+
*/
2547
class AnyParameterProperties
2648
{
2749
public:
50+
/** Default constructor where all parameter properties are false
51+
*/
2852
AnyParameterProperties()
29-
: m_description(), m_model_selection(MS_NOT_AVAILABLE),
30-
m_gradient(GRADIENT_NOT_AVAILABLE)
53+
: m_description("No description given"),
54+
m_attribute_mask(ParameterProperties())
3155
{
3256
}
57+
/** Constructor
58+
* @param description parameter description
59+
* @param hyperparameter set to true for parameters that determine
60+
* how training is performed, e.g. regularisation parameters
61+
* @param gradient set to true for parameters required for gradient
62+
* updates
63+
* @param model set to true for parameters used in inference, e.g.
64+
* weights and bias
65+
* */
3366
AnyParameterProperties(
3467
std::string description,
35-
EModelSelectionAvailability model_selection = MS_NOT_AVAILABLE,
36-
EGradientAvailability gradient = GRADIENT_NOT_AVAILABLE)
37-
: m_description(description), m_model_selection(model_selection),
68+
EModelSelectionAvailability hyperparameter = MS_NOT_AVAILABLE,
69+
EGradientAvailability gradient = GRADIENT_NOT_AVAILABLE,
70+
bool model = false)
71+
: m_description(description), m_model_selection(hyperparameter),
3872
m_gradient(gradient)
3973
{
74+
m_attribute_mask = ParameterProperties();
75+
if (hyperparameter)
76+
m_attribute_mask |= ParameterProperties::HYPER;
77+
if (gradient)
78+
m_attribute_mask |= ParameterProperties::GRADIENT;
79+
if (model)
80+
m_attribute_mask |= ParameterProperties::MODEL;
4081
}
82+
/** Mask constructor
83+
* @param description parameter description
84+
* @param attribute_mask mask encoding parameter properties
85+
* */
86+
AnyParameterProperties(
87+
std::string description, ParameterProperties attribute_mask)
88+
: m_description(description)
89+
{
90+
m_attribute_mask = attribute_mask;
91+
}
92+
/** Copy contructor */
4193
AnyParameterProperties(const AnyParameterProperties& other)
4294
: m_description(other.m_description),
4395
m_model_selection(other.m_model_selection),
44-
m_gradient(other.m_gradient)
96+
m_gradient(other.m_gradient),
97+
m_attribute_mask(other.m_attribute_mask)
4598
{
4699
}
47-
48-
std::string get_description() const
100+
const std::string& get_description() const
49101
{
50102
return m_description;
51103
}
52-
53104
EModelSelectionAvailability get_model_selection() const
54105
{
55-
return m_model_selection;
106+
return static_cast<EModelSelectionAvailability>(
107+
static_cast<int32_t>(
108+
m_attribute_mask & ParameterProperties::HYPER) > 0);
56109
}
57-
58110
EGradientAvailability get_gradient() const
59111
{
60-
return m_gradient;
112+
return static_cast<EGradientAvailability>(
113+
static_cast<int32_t>(
114+
m_attribute_mask & ParameterProperties::GRADIENT) > 0);
115+
}
116+
bool get_model() const
117+
{
118+
return static_cast<bool>(
119+
m_attribute_mask & ParameterProperties::MODEL);
61120
}
62121

63122
private:
64123
std::string m_description;
65124
EModelSelectionAvailability m_model_selection;
66125
EGradientAvailability m_gradient;
126+
ParameterProperties m_attribute_mask;
67127
};
68128

69129
class AnyParameter
@@ -116,6 +176,6 @@ namespace shogun
116176
Any m_value;
117177
AnyParameterProperties m_properties;
118178
};
119-
}
179+
} // namespace shogun
120180

121181
#endif

src/shogun/base/SGObject.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -675,8 +675,7 @@ class CSGObject
675675
template <typename T>
676676
void watch_param(
677677
const std::string& name, T* value,
678-
AnyParameterProperties properties = AnyParameterProperties(
679-
"Unknown parameter", MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE))
678+
AnyParameterProperties properties = AnyParameterProperties())
680679
{
681680
BaseTag tag(name);
682681
create_parameter(tag, AnyParameter(make_any_ref(value), properties));
@@ -693,8 +692,7 @@ class CSGObject
693692
template <typename T, typename S>
694693
void watch_param(
695694
const std::string& name, T** value, S* len,
696-
AnyParameterProperties properties = AnyParameterProperties(
697-
"Unknown parameter", MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE))
695+
AnyParameterProperties properties = AnyParameterProperties())
698696
{
699697
BaseTag tag(name);
700698
create_parameter(
@@ -714,8 +712,7 @@ class CSGObject
714712
template <typename T, typename S>
715713
void watch_param(
716714
const std::string& name, T** value, S* rows, S* cols,
717-
AnyParameterProperties properties = AnyParameterProperties(
718-
"Unknown parameter", MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE))
715+
AnyParameterProperties properties = AnyParameterProperties())
719716
{
720717
BaseTag tag(name);
721718
create_parameter(
@@ -733,7 +730,10 @@ class CSGObject
733730
{
734731
BaseTag tag(name);
735732
AnyParameterProperties properties(
736-
"Dynamic parameter", MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE);
733+
"Dynamic parameter",
734+
ParameterProperties::HYPER |
735+
ParameterProperties::GRADIENT |
736+
ParameterProperties::MODEL);
737737
std::function<T()> bind_method =
738738
std::bind(method, dynamic_cast<const S*>(this));
739739
create_parameter(tag, AnyParameter(make_any(bind_method), properties));

src/shogun/lib/bitmask_operators.h

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
#ifndef JSS_BITMASK_HPP
2+
#define JSS_BITMASK_HPP
3+
4+
// (C) Copyright 2015 Just Software Solutions Ltd
5+
//
6+
// Distributed under the Boost Software License, Version 1.0.
7+
//
8+
// Boost Software License - Version 1.0 - August 17th, 2003
9+
//
10+
// Permission is hereby granted, free of charge, to any person or
11+
// organization obtaining a copy of the software and accompanying
12+
// documentation covered by this license (the "Software") to use,
13+
// reproduce, display, distribute, execute, and transmit the
14+
// Software, and to prepare derivative works of the Software, and
15+
// to permit third-parties to whom the Software is furnished to
16+
// do so, all subject to the following:
17+
//
18+
// The copyright notices in the Software and this entire
19+
// statement, including the above license grant, this restriction
20+
// and the following disclaimer, must be included in all copies
21+
// of the Software, in whole or in part, and all derivative works
22+
// of the Software, unless such copies or derivative works are
23+
// solely in the form of machine-executable object code generated
24+
// by a source language processor.
25+
//
26+
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY
27+
// KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
28+
// WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
29+
// PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE
30+
// COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE
31+
// LIABLE FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN
32+
// CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
33+
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
34+
// THE SOFTWARE.
35+
36+
#include<type_traits>
37+
38+
namespace shogun {
39+
40+
template<typename E>
41+
struct enable_bitmask_operators {
42+
static constexpr bool enable = false;
43+
};
44+
45+
#define enableEnumClassBitmask(T) template<> \
46+
struct enable_bitmask_operators<T> \
47+
{ \
48+
static constexpr bool enable = true; \
49+
}
50+
51+
template<typename E>
52+
typename std::enable_if<enable_bitmask_operators<E>::enable, E>::type
53+
operator|(E lhs, E rhs) {
54+
typedef typename std::underlying_type<E>::type underlying;
55+
return static_cast<E>(
56+
static_cast<underlying>(lhs) | static_cast<underlying>(rhs));
57+
}
58+
59+
template<typename E>
60+
typename std::enable_if<enable_bitmask_operators<E>::enable, E>::type
61+
operator&(E lhs, E rhs) {
62+
typedef typename std::underlying_type<E>::type underlying;
63+
return static_cast<E>(
64+
static_cast<underlying>(lhs) & static_cast<underlying>(rhs));
65+
}
66+
67+
template<typename E>
68+
typename std::enable_if<enable_bitmask_operators<E>::enable, E>::type
69+
operator^(E lhs, E rhs) {
70+
typedef typename std::underlying_type<E>::type underlying;
71+
return static_cast<E>(
72+
static_cast<underlying>(lhs) ^ static_cast<underlying>(rhs));
73+
}
74+
75+
template<typename E>
76+
typename std::enable_if<enable_bitmask_operators<E>::enable, E>::type
77+
operator~(E lhs) {
78+
typedef typename std::underlying_type<E>::type underlying;
79+
return static_cast<E>(
80+
~static_cast<underlying>(lhs));
81+
}
82+
83+
template<typename E>
84+
typename std::enable_if<enable_bitmask_operators<E>::enable, E &>::type
85+
operator|=(E &lhs, E rhs) {
86+
typedef typename std::underlying_type<E>::type underlying;
87+
lhs = static_cast<E>(
88+
static_cast<underlying>(lhs) | static_cast<underlying>(rhs));
89+
return lhs;
90+
}
91+
92+
template<typename E>
93+
typename std::enable_if<enable_bitmask_operators<E>::enable, E &>::type
94+
operator&=(E &lhs, E rhs) {
95+
typedef typename std::underlying_type<E>::type underlying;
96+
lhs = static_cast<E>(
97+
static_cast<underlying>(lhs) & static_cast<underlying>(rhs));
98+
return lhs;
99+
}
100+
101+
template<typename E>
102+
typename std::enable_if<enable_bitmask_operators<E>::enable, E &>::type
103+
operator^=(E &lhs, E rhs) {
104+
typedef typename std::underlying_type<E>::type underlying;
105+
lhs = static_cast<E>(
106+
static_cast<underlying>(lhs) ^ static_cast<underlying>(rhs));
107+
return lhs;
108+
}
109+
}
110+
#endif

0 commit comments

Comments
 (0)