Skip to content

Commit 39ceaa7

Browse files
author
Alexandre Hoffmann
committed
feurst commit
1 parent d0f0425 commit 39ceaa7

16 files changed

+306
-338
lines changed

include/xtensor/containers/xfixed.hpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -325,8 +325,7 @@ namespace xt
325325
explicit xfixed_container(const inner_shape_type& shape, layout_type l = L);
326326
explicit xfixed_container(const inner_shape_type& shape, value_type v, layout_type l = L);
327327

328-
// remove this enable_if when removing the other value_type constructor
329-
template <class IX = std::integral_constant<std::size_t, N>, class EN = std::enable_if_t<IX::value != 0, int>>
328+
template <class IX = std::integral_constant<std::size_t, N>> requires (IX::value != 0)
330329
xfixed_container(nested_initializer_list_t<value_type, N> t);
331330

332331
~xfixed_container() = default;
@@ -639,7 +638,7 @@ namespace xt
639638
* Note: for clang < 3.8 this is an initializer_list and the size is not checked at compile-or runtime.
640639
*/
641640
template <class ET, class S, layout_type L, bool SH, class Tag>
642-
template <class IX, class EN>
641+
template <class IX> requires (IX::value != 0)
643642
inline xfixed_container<ET, S, L, SH, Tag>::xfixed_container(nested_initializer_list_t<value_type, N> t)
644643
{
645644
XTENSOR_ASSERT_MSG(

include/xtensor/containers/xscalar.hpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,10 @@ namespace xt
316316

317317
template <class E>
318318
using is_xscalar = detail::is_xscalar_impl<E>;
319-
319+
320+
template<class E>
321+
concept is_xscalar_concept = is_xscalar<std::decay_t<E>>::value;
322+
320323
namespace detail
321324
{
322325
template <class... E>

include/xtensor/containers/xstorage.hpp

+8-16
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,6 @@
2525

2626
namespace xt
2727
{
28-
29-
namespace detail
30-
{
31-
template <class It>
32-
using require_input_iter = typename std::enable_if<
33-
std::is_convertible<typename std::iterator_traits<It>::iterator_category, std::input_iterator_tag>::value>::type;
34-
}
35-
3628
template <class C>
3729
struct is_contiguous_container : std::true_type
3830
{
@@ -64,7 +56,7 @@ namespace xt
6456
explicit uvector(size_type count, const allocator_type& alloc = allocator_type());
6557
uvector(size_type count, const_reference value, const allocator_type& alloc = allocator_type());
6658

67-
template <class InputIt, class = detail::require_input_iter<InputIt>>
59+
template <input_iterator_concept InputIt>
6860
uvector(InputIt first, InputIt last, const allocator_type& alloc = allocator_type());
6961

7062
uvector(std::initializer_list<T> init, const allocator_type& alloc = allocator_type());
@@ -277,7 +269,7 @@ namespace xt
277269
}
278270

279271
template <class T, class A>
280-
template <class InputIt, class>
272+
template <input_iterator_concept InputIt>
281273
inline uvector<T, A>::uvector(InputIt first, InputIt last, const allocator_type& alloc)
282274
: m_allocator(alloc)
283275
, p_begin(nullptr)
@@ -675,18 +667,18 @@ namespace xt
675667

676668
svector(const std::vector<T>& vec);
677669

678-
template <class IT, class = detail::require_input_iter<IT>>
670+
template <input_iterator_concept IT>
679671
svector(IT begin, IT end, const allocator_type& alloc = allocator_type());
680672

681-
template <std::size_t N2, bool I2, class = std::enable_if_t<N != N2, void>>
673+
template <std::size_t N2, bool I2> requires (N != N2)
682674
explicit svector(const svector<T, N2, A, I2>& rhs);
683675

684676
svector& operator=(const svector& rhs);
685677
svector& operator=(svector&& rhs) noexcept(std::is_nothrow_move_assignable<value_type>::value);
686678
svector& operator=(const std::vector<T>& rhs);
687679
svector& operator=(std::initializer_list<T> il);
688680

689-
template <std::size_t N2, bool I2, class = std::enable_if_t<N != N2, void>>
681+
template <std::size_t N2, bool I2> requires (N != N2)
690682
svector& operator=(const svector<T, N2, A, I2>& rhs);
691683

692684
svector(const svector& other);
@@ -809,15 +801,15 @@ namespace xt
809801
}
810802

811803
template <class T, std::size_t N, class A, bool Init>
812-
template <class IT, class>
804+
template <input_iterator_concept IT>
813805
inline svector<T, N, A, Init>::svector(IT begin, IT end, const allocator_type& alloc)
814806
: m_allocator(alloc)
815807
{
816808
assign(begin, end);
817809
}
818810

819811
template <class T, std::size_t N, class A, bool Init>
820-
template <std::size_t N2, bool I2, class>
812+
template <std::size_t N2, bool I2> requires (N != N2)
821813
inline svector<T, N, A, Init>::svector(const svector<T, N2, A, I2>& rhs)
822814
: m_allocator(rhs.get_allocator())
823815
{
@@ -876,7 +868,7 @@ namespace xt
876868
}
877869

878870
template <class T, std::size_t N, class A, bool Init>
879-
template <std::size_t N2, bool I2, class>
871+
template <std::size_t N2, bool I2> requires (N != N2)
880872
inline svector<T, N, A, Init>& svector<T, N, A, Init>::operator=(const svector<T, N2, A, I2>& rhs)
881873
{
882874
m_allocator = std::allocator_traits<allocator_type>::select_on_container_copy_construction(

include/xtensor/core/xexpression.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ namespace xt
178178

179179
template <class E>
180180
using is_xexpression = is_crtp_base_of<xexpression, E>;
181+
182+
template <class E>
183+
concept xexpression_concept = is_xexpression<E>::value;
181184

182185
template <class E, class R = void>
183186
using enable_xexpression = typename std::enable_if<is_xexpression<E>::value, R>::type;

include/xtensor/core/xshape.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,8 @@ namespace xt
508508
};
509509
}
510510

511+
template<typename T> concept fixed_shape_container_concept = detail::is_fixed<typename std::decay_t<T>::shape_type>::value;
512+
511513
template <class... S>
512514
struct promote_shape
513515
{

include/xtensor/generators/xbuilder.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -833,7 +833,7 @@ namespace xt
833833
return detail::make_xgenerator(detail::concatenate_impl<CT...>(std::move(t), axis), shape);
834834
}
835835

836-
template <std::size_t axis, class... CT, typename = std::enable_if_t<detail::all_fixed_shapes<CT...>::value>>
836+
template <std::size_t axis, fixed_shape_container_concept... CT>
837837
inline auto concatenate(std::tuple<CT...>&& t)
838838
{
839839
using shape_type = detail::concat_fixed_shape_t<axis, typename std::decay_t<CT>::shape_type...>;

include/xtensor/generators/xgenerator.hpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ namespace xt
5959

6060
template <class F, class R, class S>
6161
class xgenerator;
62+
63+
template<typename T> concept xgenerator_concept = is_specialization_of<xgenerator, std::decay_t<T>>::value;
6264

6365
template <class C, class R, class S>
6466
struct xiterable_inner_types<xgenerator<C, R, S>>
@@ -80,10 +82,8 @@ namespace xt
8082
* overlapping_memory_checker_traits *
8183
*************************************/
8284

83-
template <class E>
84-
struct overlapping_memory_checker_traits<
85-
E,
86-
std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xgenerator, E>::value>>
85+
template <xgenerator_concept E> requires without_memory_address_concept<E>
86+
struct overlapping_memory_checker_traits<E>
8787
{
8888
static bool check_overlap(const E&, const memory_range&)
8989
{
@@ -165,7 +165,7 @@ namespace xt
165165
template <class O>
166166
const_stepper stepper_end(const O& shape, layout_type) const noexcept;
167167

168-
template <class E, class FE = F, class = std::enable_if_t<has_assign_to<E, FE>::value>>
168+
template <class E, class FE = F> requires has_assign_to_v<E, FE>
169169
void assign_to(xexpression<E>& e) const noexcept;
170170

171171
const functor_type& functor() const noexcept;
@@ -371,7 +371,7 @@ namespace xt
371371
}
372372

373373
template <class F, class R, class S>
374-
template <class E, class, class>
374+
template <class E, class FE> requires has_assign_to_v<E, FE>
375375
inline void xgenerator<F, R, S>::assign_to(xexpression<E>& e) const noexcept
376376
{
377377
e.derived_cast().resize(m_shape);

include/xtensor/generators/xrandom.hpp

+9-10
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "../generators/xgenerator.hpp"
3030
#include "../views/xindex_view.hpp"
3131
#include "../views/xview.hpp"
32+
#include "../misc/xtl_concepts.hpp"
3233

3334
namespace xt
3435
{
@@ -175,13 +176,11 @@ namespace xt
175176
template <class T, class E = random::default_engine_type>
176177
void shuffle(xexpression<T>& e, E& engine = random::get_default_random_engine());
177178

178-
template <class T, class E = random::default_engine_type>
179-
std::enable_if_t<xtl::is_integral<T>::value, xtensor<T, 1>>
180-
permutation(T e, E& engine = random::get_default_random_engine());
179+
template <xtl::integral_concept T, class E = random::default_engine_type>
180+
xtensor<T, 1> permutation(T e, E& engine = random::get_default_random_engine());
181181

182-
template <class T, class E = random::default_engine_type>
183-
std::enable_if_t<is_xexpression<std::decay_t<T>>::value, std::decay_t<T>>
184-
permutation(T&& e, E& engine = random::get_default_random_engine());
182+
template <xexpression_concept T, class E = random::default_engine_type>
183+
std::decay_t<T> permutation(T&& e, E& engine = random::get_default_random_engine());
185184

186185
template <class T, class E = random::default_engine_type>
187186
xtensor<typename T::value_type, 1> choice(
@@ -835,17 +834,17 @@ namespace xt
835834
*
836835
* @return randomly permuted copy of container or arange.
837836
*/
838-
template <class T, class E>
839-
std::enable_if_t<xtl::is_integral<T>::value, xtensor<T, 1>> permutation(T e, E& engine)
837+
template <xtl::integral_concept T, class E>
838+
xtensor<T, 1> permutation(T e, E& engine)
840839
{
841840
xt::xtensor<T, 1> res = xt::arange<T>(e);
842841
shuffle(res, engine);
843842
return res;
844843
}
845844

846845
/// @cond DOXYGEN_INCLUDE_SFINAE
847-
template <class T, class E>
848-
std::enable_if_t<is_xexpression<std::decay_t<T>>::value, std::decay_t<T>> permutation(T&& e, E& engine)
846+
template <xexpression_concept T, class E>
847+
std::decay_t<T> permutation(T&& e, E& engine)
849848
{
850849
using copy_type = std::decay_t<T>;
851850
copy_type res = e;

include/xtensor/misc/xfft.hpp

+24-38
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#include <xtl/xcomplex.hpp>
77

8+
#include "./xtl_concepts.hpp"
89
#include "../containers/xarray.hpp"
910
#include "../core/xmath.hpp"
1011
#include "../core/xnoalias.hpp"
@@ -19,9 +20,7 @@ namespace xt
1920
{
2021
namespace detail
2122
{
22-
template <
23-
class E,
24-
typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
23+
template <xtl::complex_concept E>
2524
inline auto radix2(E&& e)
2625
{
2726
using namespace xt::placeholders;
@@ -119,19 +118,19 @@ namespace xt
119118
}
120119
} // namespace detail
121120

122-
/**
121+
/**
123122
* @brief 1D FFT of an Nd array along a specified axis
124123
* @param e an Nd expression to be transformed to the fourier domain
125124
* @param axis the axis along which to perform the 1D FFT
126125
* @return a transformed xarray of the specified precision
127126
*/
128-
template <
129-
class E,
130-
typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
127+
template<class E>
131128
inline auto fft(E&& e, std::ptrdiff_t axis = -1)
132129
{
133-
using value_type = typename std::decay_t<E>::value_type;
134-
using precision = typename value_type::value_type;
130+
using value_type = typename std::decay<E>::type::value_type;
131+
if constexpr (xtl::is_complex<typename std::decay<E>::type::value_type>::value)
132+
{
133+
using precision = typename value_type::value_type;
135134
const auto saxis = xt::normalize_axis(e.dimension(), axis);
136135
const size_t N = e.shape(saxis);
137136
const bool powerOfTwo = !(N == 0) && !(N & (N - 1));
@@ -150,29 +149,19 @@ namespace xt
150149
}
151150
}
152151
return out;
153-
}
154-
155-
/**
156-
* @brief 1D FFT of an Nd array along a specified axis
157-
* @param e an Nd expression to be transformed to the fourier domain
158-
* @param axis the axis along which to perform the 1D FFT
159-
* @return a transformed xarray of the specified precision
160-
*/
161-
template <
162-
class E,
163-
typename std::enable_if<!xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
164-
inline auto fft(E&& e, std::ptrdiff_t axis = -1)
165-
{
166-
using value_type = typename std::decay<E>::type::value_type;
167-
return fft(xt::cast<std::complex<value_type>>(e), axis);
168-
}
169-
170-
template <
171-
class E,
172-
typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
173-
auto ifft(E&& e, std::ptrdiff_t axis = -1)
152+
}
153+
else
154+
{
155+
return fft(xt::cast<std::complex<value_type>>(e), axis);
156+
}
157+
}
158+
159+
template <class E>
160+
inline auto ifft(E&& e, std::ptrdiff_t axis = -1)
174161
{
175-
// check the length of the data on that axis
162+
if constexpr (xtl::is_complex<typename std::decay<E>::type::value_type>::value)
163+
{
164+
// check the length of the data on that axis
176165
const std::size_t n = e.shape(axis);
177166
if (n == 0)
178167
{
@@ -182,15 +171,12 @@ namespace xt
182171
auto fft_res = xt::fft::fft(complex_args, axis);
183172
fft_res = xt::conj(fft_res);
184173
return fft_res;
185-
}
186-
187-
template <
188-
class E,
189-
typename std::enable_if<!xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
190-
inline auto ifft(E&& e, std::ptrdiff_t axis = -1)
191-
{
174+
}
175+
else
176+
{
192177
using value_type = typename std::decay<E>::type::value_type;
193178
return ifft(xt::cast<std::complex<value_type>>(e), axis);
179+
}
194180
}
195181

196182
/*

include/xtensor/misc/xmanipulation.hpp

+12-10
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "../utils/xutils.hpp"
2323
#include "../views/xrepeat.hpp"
2424
#include "../views/xstrided_view.hpp"
25+
#include "xtl_concepts.hpp"
2526

2627
namespace xt
2728
{
@@ -64,7 +65,7 @@ namespace xt
6465
template <class E>
6566
auto squeeze(E&& e);
6667

67-
template <class E, class S, class Tag = check_policy::none, std::enable_if_t<!xtl::is_integral<S>::value, int> = 0>
68+
template <class E, xtl::non_integral_concept S, class Tag = check_policy::none>
6869
auto squeeze(E&& e, S&& axis, Tag check_policy = Tag());
6970

7071
template <class E>
@@ -210,20 +211,21 @@ namespace xt
210211
return transpose_impl(std::forward<E>(e), std::forward<S>(permutation), check_policy::none());
211212
}
212213

213-
template <class E, class S, class X, std::enable_if_t<has_data_interface<std::decay_t<E>>::value>* = nullptr>
214-
inline void compute_transposed_strides(E&& e, const S&, X& strides)
215-
{
216-
std::copy(e.strides().crbegin(), e.strides().crend(), strides.begin());
217-
}
218-
219-
template <class E, class S, class X, std::enable_if_t<!has_data_interface<std::decay_t<E>>::value>* = nullptr>
220-
inline void compute_transposed_strides(E&&, const S& shape, X& strides)
214+
template <class E, class S, class X>
215+
inline void compute_transposed_strides(E&& e, const S& shape, X& strides)
221216
{
217+
if constexpr (has_data_interface<std::decay_t<E>>::value)
218+
{
219+
std::copy(e.strides().crbegin(), e.strides().crend(), strides.begin());
220+
}
221+
else
222+
{
222223
// In the case where E does not have a data interface, the transposition
223224
// makes use of a flat storage adaptor that has layout XTENSOR_DEFAULT_TRAVERSAL
224225
// which should be the one inverted.
225226
layout_type l = transpose_layout(XTENSOR_DEFAULT_TRAVERSAL);
226227
compute_strides(shape, l, strides);
228+
}
227229
}
228230
}
229231

@@ -588,7 +590,7 @@ namespace xt
588590
* @param check_policy select check_policy. With check_policy::full(), selecting an axis
589591
* which is greater than one will throw a runtime_error.
590592
*/
591-
template <class E, class S, class Tag, std::enable_if_t<!xtl::is_integral<S>::value, int>>
593+
template <class E, xtl::non_integral_concept S, class Tag>
592594
inline auto squeeze(E&& e, S&& axis, Tag check_policy)
593595
{
594596
return detail::squeeze_impl(std::forward<E>(e), std::forward<S>(axis), check_policy);

0 commit comments

Comments
 (0)