Skip to content

Adding concept to a part of the code #2842

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions include/xtensor/containers/xfixed.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,9 +325,9 @@ namespace xt
explicit xfixed_container(const inner_shape_type& shape, layout_type l = L);
explicit xfixed_container(const inner_shape_type& shape, value_type v, layout_type l = L);

// remove this enable_if when removing the other value_type constructor
template <class IX = std::integral_constant<std::size_t, N>, class EN = std::enable_if_t<IX::value != 0, int>>
xfixed_container(nested_initializer_list_t<value_type, N> t);
template <class IX = std::integral_constant<std::size_t, N>>
xfixed_container(nested_initializer_list_t<value_type, N> t)
requires(IX::value != 0);

~xfixed_container() = default;

Expand Down Expand Up @@ -639,8 +639,9 @@ namespace xt
* Note: for clang < 3.8 this is an initializer_list and the size is not checked at compile-or runtime.
*/
template <class ET, class S, layout_type L, bool SH, class Tag>
template <class IX, class EN>
template <class IX>
inline xfixed_container<ET, S, L, SH, Tag>::xfixed_container(nested_initializer_list_t<value_type, N> t)
requires(IX::value != 0)
{
XTENSOR_ASSERT_MSG(
detail::check_initializer_list_shape<N>::run(t, this->shape()) == true,
Expand Down
3 changes: 3 additions & 0 deletions include/xtensor/containers/xscalar.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,9 @@ namespace xt
template <class E>
using is_xscalar = detail::is_xscalar_impl<E>;

template <class E>
concept xscalar_concept = is_xscalar<std::decay_t<E>>::value;

namespace detail
{
template <class... E>
Expand Down
32 changes: 14 additions & 18 deletions include/xtensor/containers/xstorage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,6 @@

namespace xt
{

namespace detail
{
template <class It>
using require_input_iter = typename std::enable_if<
std::is_convertible<typename std::iterator_traits<It>::iterator_category, std::input_iterator_tag>::value>::type;
}

template <class C>
struct is_contiguous_container : std::true_type
{
Expand Down Expand Up @@ -64,7 +56,7 @@ namespace xt
explicit uvector(size_type count, const allocator_type& alloc = allocator_type());
uvector(size_type count, const_reference value, const allocator_type& alloc = allocator_type());

template <class InputIt, class = detail::require_input_iter<InputIt>>
template <std::input_iterator InputIt>
uvector(InputIt first, InputIt last, const allocator_type& alloc = allocator_type());

uvector(std::initializer_list<T> init, const allocator_type& alloc = allocator_type());
Expand Down Expand Up @@ -277,7 +269,7 @@ namespace xt
}

template <class T, class A>
template <class InputIt, class>
template <std::input_iterator InputIt>
inline uvector<T, A>::uvector(InputIt first, InputIt last, const allocator_type& alloc)
: m_allocator(alloc)
, p_begin(nullptr)
Expand Down Expand Up @@ -675,19 +667,21 @@ namespace xt

svector(const std::vector<T>& vec);

template <class IT, class = detail::require_input_iter<IT>>
template <std::input_iterator IT>
svector(IT begin, IT end, const allocator_type& alloc = allocator_type());

template <std::size_t N2, bool I2, class = std::enable_if_t<N != N2, void>>
explicit svector(const svector<T, N2, A, I2>& rhs);
template <std::size_t N2, bool I2>
explicit svector(const svector<T, N2, A, I2>& rhs)
requires(N != N2);

svector& operator=(const svector& rhs);
svector& operator=(svector&& rhs) noexcept(std::is_nothrow_move_assignable<value_type>::value);
svector& operator=(const std::vector<T>& rhs);
svector& operator=(std::initializer_list<T> il);

template <std::size_t N2, bool I2, class = std::enable_if_t<N != N2, void>>
svector& operator=(const svector<T, N2, A, I2>& rhs);
template <std::size_t N2, bool I2>
svector& operator=(const svector<T, N2, A, I2>& rhs)
requires(N != N2);

svector(const svector& other);
svector(svector&& other) noexcept(std::is_nothrow_move_constructible<value_type>::value);
Expand Down Expand Up @@ -809,16 +803,17 @@ namespace xt
}

template <class T, std::size_t N, class A, bool Init>
template <class IT, class>
template <std::input_iterator IT>
inline svector<T, N, A, Init>::svector(IT begin, IT end, const allocator_type& alloc)
: m_allocator(alloc)
{
assign(begin, end);
}

template <class T, std::size_t N, class A, bool Init>
template <std::size_t N2, bool I2, class>
template <std::size_t N2, bool I2>
inline svector<T, N, A, Init>::svector(const svector<T, N2, A, I2>& rhs)
requires(N != N2)
: m_allocator(rhs.get_allocator())
{
assign(rhs.begin(), rhs.end());
Expand Down Expand Up @@ -876,8 +871,9 @@ namespace xt
}

template <class T, std::size_t N, class A, bool Init>
template <std::size_t N2, bool I2, class>
template <std::size_t N2, bool I2>
inline svector<T, N, A, Init>& svector<T, N, A, Init>::operator=(const svector<T, N2, A, I2>& rhs)
requires(N != N2)
{
m_allocator = std::allocator_traits<allocator_type>::select_on_container_copy_construction(
rhs.get_allocator()
Expand Down
3 changes: 3 additions & 0 deletions include/xtensor/core/xexpression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ namespace xt
template <class E>
using is_xexpression = is_crtp_base_of<xexpression, E>;

template <class E>
concept xexpression_concept = is_xexpression<E>::value;

template <class E, class R = void>
using enable_xexpression = typename std::enable_if<is_xexpression<E>::value, R>::type;

Expand Down
3 changes: 3 additions & 0 deletions include/xtensor/core/xshape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,9 @@ namespace xt
};
}

template <typename T>
concept fixed_shape_container_concept = detail::is_fixed<typename std::decay_t<T>::shape_type>::value;

template <class... S>
struct promote_shape
{
Expand Down
2 changes: 1 addition & 1 deletion include/xtensor/generators/xbuilder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,7 @@ namespace xt
return detail::make_xgenerator(detail::concatenate_impl<CT...>(std::move(t), axis), shape);
}

template <std::size_t axis, class... CT, typename = std::enable_if_t<detail::all_fixed_shapes<CT...>::value>>
template <std::size_t axis, fixed_shape_container_concept... CT>
inline auto concatenate(std::tuple<CT...>&& t)
{
using shape_type = detail::concat_fixed_shape_t<axis, typename std::decay_t<CT>::shape_type...>;
Expand Down
18 changes: 11 additions & 7 deletions include/xtensor/generators/xgenerator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ namespace xt
template <class F, class R, class S>
class xgenerator;

template <typename T>
concept xgenerator_concept = is_specialization_of<xgenerator, std::decay_t<T>>::value;

template <class C, class R, class S>
struct xiterable_inner_types<xgenerator<C, R, S>>
{
Expand All @@ -80,10 +83,9 @@ namespace xt
* overlapping_memory_checker_traits *
*************************************/

template <class E>
struct overlapping_memory_checker_traits<
E,
std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xgenerator, E>::value>>
template <xgenerator_concept E>
requires(without_memory_address_concept<E>)
struct overlapping_memory_checker_traits<E>
{
static bool check_overlap(const E&, const memory_range&)
{
Expand Down Expand Up @@ -165,8 +167,9 @@ namespace xt
template <class O>
const_stepper stepper_end(const O& shape, layout_type) const noexcept;

template <class E, class FE = F, class = std::enable_if_t<has_assign_to<E, FE>::value>>
void assign_to(xexpression<E>& e) const noexcept;
template <class E, class FE = F>
void assign_to(xexpression<E>& e) const noexcept
requires(has_assign_to_v<E, FE>);

const functor_type& functor() const noexcept;

Expand Down Expand Up @@ -371,8 +374,9 @@ namespace xt
}

template <class F, class R, class S>
template <class E, class, class>
template <class E, class FE>
inline void xgenerator<F, R, S>::assign_to(xexpression<E>& e) const noexcept
requires(has_assign_to_v<E, FE>)
{
e.derived_cast().resize(m_shape);
m_f.assign_to(e);
Expand Down
19 changes: 9 additions & 10 deletions include/xtensor/generators/xrandom.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "../core/xtensor_config.hpp"
#include "../generators/xbuilder.hpp"
#include "../generators/xgenerator.hpp"
#include "../misc/xtl_concepts.hpp"
#include "../views/xindex_view.hpp"
#include "../views/xview.hpp"

Expand Down Expand Up @@ -175,13 +176,11 @@ namespace xt
template <class T, class E = random::default_engine_type>
void shuffle(xexpression<T>& e, E& engine = random::get_default_random_engine());

template <class T, class E = random::default_engine_type>
std::enable_if_t<xtl::is_integral<T>::value, xtensor<T, 1>>
permutation(T e, E& engine = random::get_default_random_engine());
template <xtl::integral_concept T, class E = random::default_engine_type>
xtensor<T, 1> permutation(T e, E& engine = random::get_default_random_engine());

template <class T, class E = random::default_engine_type>
std::enable_if_t<is_xexpression<std::decay_t<T>>::value, std::decay_t<T>>
permutation(T&& e, E& engine = random::get_default_random_engine());
template <xexpression_concept T, class E = random::default_engine_type>
std::decay_t<T> permutation(T&& e, E& engine = random::get_default_random_engine());

template <class T, class E = random::default_engine_type>
xtensor<typename T::value_type, 1> choice(
Expand Down Expand Up @@ -835,17 +834,17 @@ namespace xt
*
* @return randomly permuted copy of container or arange.
*/
template <class T, class E>
std::enable_if_t<xtl::is_integral<T>::value, xtensor<T, 1>> permutation(T e, E& engine)
template <xtl::integral_concept T, class E>
xtensor<T, 1> permutation(T e, E& engine)
{
xt::xtensor<T, 1> res = xt::arange<T>(e);
shuffle(res, engine);
return res;
}

/// @cond DOXYGEN_INCLUDE_SFINAE
template <class T, class E>
std::enable_if_t<is_xexpression<std::decay_t<T>>::value, std::decay_t<T>> permutation(T&& e, E& engine)
template <xexpression_concept T, class E>
std::decay_t<T> permutation(T&& e, E& engine)
{
using copy_type = std::decay_t<T>;
copy_type res = e;
Expand Down
100 changes: 43 additions & 57 deletions include/xtensor/misc/xfft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,15 @@
#include "../misc/xcomplex.hpp"
#include "../views/xaxis_slice_iterator.hpp"
#include "../views/xview.hpp"
#include "./xtl_concepts.hpp"

namespace xt
{
namespace fft
{
namespace detail
{
template <
class E,
typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
template <xtl::complex_concept E>
inline auto radix2(E&& e)
{
using namespace xt::placeholders;
Expand Down Expand Up @@ -125,72 +124,59 @@ namespace xt
* @param axis the axis along which to perform the 1D FFT
* @return a transformed xarray of the specified precision
*/
template <
class E,
typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
template <class E>
inline auto fft(E&& e, std::ptrdiff_t axis = -1)
{
using value_type = typename std::decay_t<E>::value_type;
using precision = typename value_type::value_type;
const auto saxis = xt::normalize_axis(e.dimension(), axis);
const size_t N = e.shape(saxis);
const bool powerOfTwo = !(N == 0) && !(N & (N - 1));
xt::xarray<std::complex<precision>> out = xt::eval(e);
auto begin = xt::axis_slice_begin(out, saxis);
auto end = xt::axis_slice_end(out, saxis);
for (auto iter = begin; iter != end; iter++)
using value_type = typename std::decay<E>::type::value_type;
if constexpr (xtl::is_complex<typename std::decay<E>::type::value_type>::value)
{
if (powerOfTwo)
{
xt::noalias(*iter) = detail::radix2(*iter);
}
else
using precision = typename value_type::value_type;
const auto saxis = xt::normalize_axis(e.dimension(), axis);
const size_t N = e.shape(saxis);
const bool powerOfTwo = !(N == 0) && !(N & (N - 1));
xt::xarray<std::complex<precision>> out = xt::eval(e);
auto begin = xt::axis_slice_begin(out, saxis);
auto end = xt::axis_slice_end(out, saxis);
for (auto iter = begin; iter != end; iter++)
{
xt::noalias(*iter) = detail::transform_bluestein(*iter);
if (powerOfTwo)
{
xt::noalias(*iter) = detail::radix2(*iter);
}
else
{
xt::noalias(*iter) = detail::transform_bluestein(*iter);
}
}
return out;
}
return out;
}

/**
* @brief 1D FFT of an Nd array along a specified axis
* @param e an Nd expression to be transformed to the fourier domain
* @param axis the axis along which to perform the 1D FFT
* @return a transformed xarray of the specified precision
*/
template <
class E,
typename std::enable_if<!xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
inline auto fft(E&& e, std::ptrdiff_t axis = -1)
{
using value_type = typename std::decay<E>::type::value_type;
return fft(xt::cast<std::complex<value_type>>(e), axis);
}

template <
class E,
typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
auto ifft(E&& e, std::ptrdiff_t axis = -1)
{
// check the length of the data on that axis
const std::size_t n = e.shape(axis);
if (n == 0)
else
{
XTENSOR_THROW(std::runtime_error, "Cannot take the iFFT along an empty dimention");
return fft(xt::cast<std::complex<value_type>>(e), axis);
}
auto complex_args = xt::conj(e);
auto fft_res = xt::fft::fft(complex_args, axis);
fft_res = xt::conj(fft_res);
return fft_res;
}

template <
class E,
typename std::enable_if<!xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
template <class E>
inline auto ifft(E&& e, std::ptrdiff_t axis = -1)
{
using value_type = typename std::decay<E>::type::value_type;
return ifft(xt::cast<std::complex<value_type>>(e), axis);
if constexpr (xtl::is_complex<typename std::decay<E>::type::value_type>::value)
{
// check the length of the data on that axis
const std::size_t n = e.shape(axis);
if (n == 0)
{
XTENSOR_THROW(std::runtime_error, "Cannot take the iFFT along an empty dimention");
}
auto complex_args = xt::conj(e);
auto fft_res = xt::fft::fft(complex_args, axis);
fft_res = xt::conj(fft_res);
return fft_res;
}
else
{
using value_type = typename std::decay<E>::type::value_type;
return ifft(xt::cast<std::complex<value_type>>(e), axis);
}
}

/*
Expand Down
Loading