Skip to content

Commit edc67a3

Browse files
authored
Merge pull request #2782 from spectre-ns/fft
Pure xtensor FFT implementation
2 parents d9c3782 + db9913b commit edc67a3

File tree

8 files changed

+349
-0
lines changed

8 files changed

+349
-0
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,4 @@ __pycache__
6262

6363
# Generated files
6464
*.pc
65+
.vscode/settings.json

CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ set(XTENSOR_HEADERS
140140
${XTENSOR_INCLUDE_DIR}/xtensor/xfixed.hpp
141141
${XTENSOR_INCLUDE_DIR}/xtensor/xfunction.hpp
142142
${XTENSOR_INCLUDE_DIR}/xtensor/xfunctor_view.hpp
143+
${XTENSOR_INCLUDE_DIR}/xtensor/xfft.hpp
143144
${XTENSOR_INCLUDE_DIR}/xtensor/xgenerator.hpp
144145
${XTENSOR_INCLUDE_DIR}/xtensor/xhistogram.hpp
145146
${XTENSOR_INCLUDE_DIR}/xtensor/xindex_view.hpp

docs/source/api/container_index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,4 @@ xexpression API is actually implemented in ``xstrided_container`` and ``xcontain
3333
xindex_view
3434
xfunctor_view
3535
xrepeat
36+
xfft

docs/source/xfft.rst

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
.. Copyright (c) 2016, Johan Mabille, Sylvain Corlay and Wolf Vollprecht
2+
Distributed under the terms of the BSD 3-Clause License.
3+
The full license is in the file LICENSE, distributed with this software.
4+
xfft
5+
====
6+
7+
Defined in ``xtensor/xfft.hpp``
8+
9+
.. doxygenfunction:: xt::fft::convolve
10+
:project: xtensor
12+
13+
.. doxygenfunction:: xt::fft::fft
14+
:project: xtensor
15+
16+
.. doxygenfunction:: xt::fft::ifft
17+
:project: xtensor

include/xtensor/xfft.hpp

+241
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
#ifdef XTENSOR_USE_TBB
2+
#include <oneapi/tbb.h>
3+
#endif
4+
#include <stdexcept>
5+
6+
#include <xtl/xcomplex.hpp>
7+
8+
#include <xtensor/xarray.hpp>
9+
#include <xtensor/xaxis_slice_iterator.hpp>
10+
#include <xtensor/xbuilder.hpp>
11+
#include <xtensor/xcomplex.hpp>
12+
#include <xtensor/xmath.hpp>
13+
#include <xtensor/xnoalias.hpp>
14+
#include <xtensor/xview.hpp>
15+
16+
namespace xt
17+
{
18+
namespace fft
19+
{
20+
namespace detail
21+
{
22+
template <
23+
class E,
24+
typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
25+
inline auto radix2(E&& e)
26+
{
27+
using namespace xt::placeholders;
28+
using namespace std::complex_literals;
29+
using value_type = typename std::decay_t<E>::value_type;
30+
using precision = typename value_type::value_type;
31+
auto N = e.size();
32+
const bool powerOfTwo = !(N == 0) && !(N & (N - 1));
33+
// check for power of 2
34+
if (!powerOfTwo || N == 0)
35+
{
36+
// TODO: Replace implementation with dft
37+
XTENSOR_THROW(std::runtime_error, "FFT Implementation requires power of 2");
38+
}
39+
auto pi = xt::numeric_constants<precision>::PI;
40+
xt::xtensor<value_type, 1> ev = e;
41+
if (N <= 1)
42+
{
43+
return ev;
44+
}
45+
else
46+
{
47+
#ifdef XTENSOR_USE_TBB
48+
xt::xtensor<value_type, 1> even;
49+
xt::xtensor<value_type, 1> odd;
50+
oneapi::tbb::parallel_invoke(
51+
[&]
52+
{
53+
even = radix2(xt::view(ev, xt::range(0, _, 2)));
54+
},
55+
[&]
56+
{
57+
odd = radix2(xt::view(ev, xt::range(1, _, 2)));
58+
}
59+
);
60+
#else
61+
auto even = radix2(xt::view(ev, xt::range(0, _, 2)));
62+
auto odd = radix2(xt::view(ev, xt::range(1, _, 2)));
63+
#endif
64+
65+
auto range = xt::arange<double>(N / 2);
66+
auto exp = xt::exp(static_cast<value_type>(-2i) * pi * range / N);
67+
auto t = exp * odd;
68+
auto first_half = even + t;
69+
auto second_half = even - t;
70+
// TODO: should be a call to stack if performance was improved
71+
auto spectrum = xt::xtensor<value_type, 1>::from_shape({N});
72+
xt::view(spectrum, xt::range(0, N / 2)) = first_half;
73+
xt::view(spectrum, xt::range(N / 2, N)) = second_half;
74+
return spectrum;
75+
}
76+
}
77+
78+
template <typename E>
79+
auto transform_bluestein(E&& data)
80+
{
81+
using value_type = typename std::decay_t<E>::value_type;
82+
using precision = typename value_type::value_type;
83+
84+
// Find a power-of-2 convolution length m such that m >= n * 2 + 1
85+
const std::size_t n = data.size();
86+
size_t m = std::ceil(std::log2(n * 2 + 1));
87+
m = std::pow(2, m);
88+
89+
// Trignometric table
90+
auto exp_table = xt::xtensor<std::complex<precision>, 1>::from_shape({n});
91+
xt::xtensor<std::size_t, 1> i = xt::pow(xt::linspace<std::size_t>(0, n - 1, n), 2);
92+
i %= (n * 2);
93+
94+
auto angles = xt::eval(precision{3.141592653589793238463} * i / n);
95+
auto j = std::complex<precision>(0, 1);
96+
exp_table = xt::exp(-angles * j);
97+
98+
// Temporary vectors and preprocessing
99+
auto av = xt::empty<std::complex<precision>>({m});
100+
xt::view(av, xt::range(0, n)) = data * exp_table;
101+
102+
103+
auto bv = xt::empty<std::complex<precision>>({m});
104+
xt::view(bv, xt::range(0, n)) = ::xt::conj(exp_table);
105+
xt::view(bv, xt::range(-n + 1, xt::placeholders::_)) = xt::view(
106+
::xt::conj(xt::flip(exp_table)),
107+
xt::range(xt::placeholders::_, -1)
108+
);
109+
110+
// Convolution
111+
auto xv = radix2(av);
112+
auto yv = radix2(bv);
113+
auto spectrum_k = xv * yv;
114+
auto complex_args = xt::conj(spectrum_k);
115+
auto fft_res = radix2(complex_args);
116+
auto cv = xt::conj(fft_res) / m;
117+
118+
return xt::eval(xt::view(cv, xt::range(0, n)) * exp_table);
119+
}
120+
} // namespace detail
121+
122+
/**
123+
* @brief 1D FFT of an Nd array along a specified axis
124+
* @param e an Nd expression to be transformed to the fourier domain
125+
* @param axis the axis along which to perform the 1D FFT
126+
* @return a transformed xarray of the specified precision
127+
*/
128+
template <
129+
class E,
130+
typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
131+
inline auto fft(E&& e, std::ptrdiff_t axis = -1)
132+
{
133+
using value_type = typename std::decay_t<E>::value_type;
134+
using precision = typename value_type::value_type;
135+
const auto saxis = xt::normalize_axis(e.dimension(), axis);
136+
const size_t N = e.shape(saxis);
137+
const bool powerOfTwo = !(N == 0) && !(N & (N - 1));
138+
xt::xarray<std::complex<precision>> out = xt::eval(e);
139+
auto begin = xt::axis_slice_begin(out, saxis);
140+
auto end = xt::axis_slice_end(out, saxis);
141+
for (auto iter = begin; iter != end; iter++)
142+
{
143+
if (powerOfTwo)
144+
{
145+
xt::noalias(*iter) = detail::radix2(*iter);
146+
}
147+
else
148+
{
149+
xt::noalias(*iter) = detail::transform_bluestein(*iter);
150+
}
151+
}
152+
return out;
153+
}
154+
155+
/**
156+
* @brief 1D FFT of an Nd array along a specified axis
157+
* @param e an Nd expression to be transformed to the fourier domain
158+
* @param axis the axis along which to perform the 1D FFT
159+
* @return a transformed xarray of the specified precision
160+
*/
161+
template <
162+
class E,
163+
typename std::enable_if<!xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
164+
inline auto fft(E&& e, std::ptrdiff_t axis = -1)
165+
{
166+
using value_type = typename std::decay<E>::type::value_type;
167+
return fft(xt::cast<std::complex<value_type>>(e), axis);
168+
}
169+
170+
template <
171+
class E,
172+
typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
173+
auto ifft(E&& e, std::ptrdiff_t axis = -1)
174+
{
175+
// check the length of the data on that axis
176+
const std::size_t n = e.shape(axis);
177+
if (n == 0)
178+
{
179+
XTENSOR_THROW(std::runtime_error, "Cannot take the iFFT along an empty dimention");
180+
}
181+
auto complex_args = xt::conj(e);
182+
auto fft_res = xt::fft::fft(complex_args, axis);
183+
fft_res = xt::conj(fft_res);
184+
return fft_res;
185+
}
186+
187+
template <
188+
class E,
189+
typename std::enable_if<!xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
190+
inline auto ifft(E&& e, std::ptrdiff_t axis = -1)
191+
{
192+
using value_type = typename std::decay<E>::type::value_type;
193+
return ifft(xt::cast<std::complex<value_type>>(e), axis);
194+
}
195+
196+
/*
197+
* @brief performs a circular fft convolution xvec and yvec must
198+
* be the same shape.
199+
* @param xvec first array of the convolution
200+
* @param yvec second array of the convolution
201+
* @param axis axis along which to perform the convolution
202+
*/
203+
template <typename E1, typename E2>
204+
auto convolve(E1&& xvec, E2&& yvec, std::ptrdiff_t axis = -1)
205+
{
206+
// we could broadcast but that could get complicated???
207+
if (xvec.dimension() != yvec.dimension())
208+
{
209+
XTENSOR_THROW(std::runtime_error, "Mismatched dimentions");
210+
}
211+
212+
auto saxis = xt::normalize_axis(xvec.dimension(), axis);
213+
if (xvec.shape(saxis) != yvec.shape(saxis))
214+
{
215+
XTENSOR_THROW(std::runtime_error, "Mismatched lengths along slice axis");
216+
}
217+
218+
const std::size_t n = xvec.shape(saxis);
219+
220+
auto xv = fft(xvec, axis);
221+
auto yv = fft(yvec, axis);
222+
223+
auto begin_x = xt::axis_slice_begin(xv, saxis);
224+
auto end_x = xt::axis_slice_end(xv, saxis);
225+
auto iter_y = xt::axis_slice_begin(yv, saxis);
226+
227+
for (auto iter = begin_x; iter != end_x; iter++)
228+
{
229+
(*iter) = (*iter_y++) * (*iter);
230+
}
231+
232+
auto outvec = ifft(xv, axis);
233+
234+
// Scaling (because this FFT implementation omits it)
235+
outvec = outvec / n;
236+
237+
return outvec;
238+
}
239+
240+
}
241+
} // namespace xt::fft

include/xtensor/xmath.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ namespace xt
338338
XTENSOR_UNARY_MATH_FUNCTOR(isfinite);
339339
XTENSOR_UNARY_MATH_FUNCTOR(isinf);
340340
XTENSOR_UNARY_MATH_FUNCTOR(isnan);
341+
XTENSOR_UNARY_MATH_FUNCTOR(conj);
341342
}
342343

343344
#undef XTENSOR_UNARY_MATH_FUNCTOR

test/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ set(XTENSOR_TESTS
186186
test_xdynamic_view.cpp
187187
test_xfunctor_adaptor.cpp
188188
test_xfixed.cpp
189+
test_xfft.cpp
189190
test_xhistogram.cpp
190191
test_xpad.cpp
191192
test_xindex_view.cpp

test/test_xfft.cpp

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#include "xtensor/xarray.hpp"
2+
#include "xtensor/xfft.hpp"
3+
4+
#include "test_common_macros.hpp"
5+
6+
namespace xt
7+
{
8+
// A sine with `bin` periods over a power-of-two number of samples must show
// its amplitude in spectrum bin `bin` after scaling by samples / 2.
TEST(xfft, fft_power_2)
{
    size_t bin = 2;
    size_t samples = 8192;
    size_t amplitude = 10;

    auto t = xt::linspace<float>(0, static_cast<float>(samples - 1), samples);
    xt::xarray<float> signal = amplitude
                               * xt::sin(2 * xt::numeric_constants<float>::PI * t * bin / samples);

    auto spectrum = xt::fft::fft(signal) / (samples / 2);
    REQUIRE(amplitude == doctest::Approx(std::abs(spectrum(bin))).epsilon(.0001));
}
18+
19+
// Same sine check as the forward test, run through ifft on 8 samples.
TEST(xfft, ifft_power_2)
{
    size_t bin = 2;
    size_t samples = 8;
    size_t amplitude = 10;

    auto t = xt::linspace<float>(0, static_cast<float>(samples - 1), samples);
    xt::xarray<float> signal = amplitude
                               * xt::sin(2 * xt::numeric_constants<float>::PI * t * bin / samples);

    auto spectrum = xt::fft::ifft(signal) / (samples / 2);
    REQUIRE(amplitude == doctest::Approx(std::abs(spectrum(bin))).epsilon(.0001));
}
29+
30+
// Circular convolution of two length-4 (power-of-two) vectors, checked by
// magnitude against precomputed values.
TEST(xfft, convolve_power_2)
{
    xt::xarray<float> lhs = {1.0, 1.0, 1.0, 5.0};
    xt::xarray<float> rhs = {5.0, 1.0, 1.0, 1.0};
    xt::xarray<float> expected = {12, 12, 12, 28};

    auto result = xt::fft::convolve(lhs, rhs);

    for (size_t idx = 0; idx < lhs.size(); ++idx)
    {
        REQUIRE(expected(idx) == doctest::Approx(std::abs(result(idx))).epsilon(.0001));
    }
}
43+
44+
// Non power-of-two length (10) transformed along axis 0 of a 2D array:
// each column holds the same sine, so bin `bin` of every column must show
// the amplitude.
TEST(xfft, fft_n_0_axis)
{
    size_t bin = 2;
    size_t samples = 10;
    size_t amplitude = 1;
    size_t rows = 10;

    auto t = xt::linspace<float>(0, samples - 1, samples) * xt::ones<float>({rows, samples});
    xt::xarray<float> signal = amplitude
                               * xt::sin(2 * xt::numeric_constants<float>::PI * t * bin / samples);
    // Put the sample axis first so the transform runs along axis 0.
    signal = xt::transpose(signal);

    auto spectrum = xt::fft::fft(signal, 0) / (samples / 2.0);
    REQUIRE(amplitude == doctest::Approx(std::abs(spectrum(bin, 0))).epsilon(.0001));
    REQUIRE(amplitude == doctest::Approx(std::abs(spectrum(bin, 1))).epsilon(.0001));
}
57+
58+
// Non power-of-two length (15) transformed along the default (last) axis of
// a 2D array: bin `bin` of every row must show the amplitude.
TEST(xfft, fft_n_1_axis)
{
    size_t bin = 2;
    size_t samples = 15;
    size_t amplitude = 1;
    size_t rows = 2;

    auto t = xt::linspace<float>(0, samples - 1, samples) * xt::ones<float>({rows, samples});
    xt::xarray<float> signal = amplitude
                               * xt::sin(2 * xt::numeric_constants<float>::PI * t * bin / samples);

    auto spectrum = xt::fft::fft(signal) / (samples / 2.0);
    REQUIRE(amplitude == doctest::Approx(std::abs(spectrum(0, bin))).epsilon(.0001));
    REQUIRE(amplitude == doctest::Approx(std::abs(spectrum(1, bin))).epsilon(.0001));
}
70+
71+
// Circular convolution of two length-5 (non power-of-two) vectors, checked by
// magnitude against precomputed values.
TEST(xfft, convolve_n)
{
    xt::xarray<float> lhs = {1.0, 1.0, 1.0, 5.0, 1.0};
    xt::xarray<float> rhs = {5.0, 1.0, 1.0, 1.0, 1.0};
    xt::xarray<size_t> expected = {13, 13, 13, 29, 13};

    auto result = xt::fft::convolve(lhs, rhs);

    xt::xarray<float> magnitudes = xt::abs(result);

    for (size_t idx = 0; idx < magnitudes.size(); ++idx)
    {
        REQUIRE(expected(idx) == doctest::Approx(magnitudes(idx)).epsilon(.0001));
    }
}
86+
}

0 commit comments

Comments
 (0)