-
Hey @PVarnai, we don't have a "sliding_window_scan". It sounds cool, though, and you can mimic it by first creating the windows and then scanning over them. There is some existing JAX code for creating such windows; copying it here for convenience (credit to @erdmann):

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import jit, vmap


@partial(jit, static_argnums=(1,))
def moving_window(a, size: int):
    # Start index of every full window of length `size`.
    starts = jnp.arange(len(a) - size + 1)
    # Slice one window per start index; vmap stacks them along a new leading axis.
    return vmap(lambda start: jax.lax.dynamic_slice(a, (start,), (size,)))(starts)


a = jnp.arange(10)
print(moving_window(a, 4))
```
You can generalize it for |
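
For the "create the windows, then scan over them" part, a minimal sketch could look like the following. It assumes a 1-D input and uses plain `jax.lax.scan`; the `step` function and what it computes (a running sum of window means in the carry, each window's sum as output) are made up purely for illustration, so swap in your own logic.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(1,))
def moving_window(a, size: int):
    # Same helper as in the comment above, repeated so this block runs on its own.
    starts = jnp.arange(len(a) - size + 1)
    return jax.vmap(lambda start: jax.lax.dynamic_slice(a, (start,), (size,)))(starts)


def step(carry, window):
    # Hypothetical per-window computation: accumulate the window means in the
    # carry and emit each window's sum as the per-step output.
    new_carry = carry + window.mean()
    return new_carry, window.sum()


x = jnp.arange(10.0)
windows = moving_window(x, 4)  # shape (7, 4): one row per window
final_carry, window_sums = jax.lax.scan(step, jnp.zeros(()), windows)
print(window_sums)
```

Note that this materializes all windows up front, so memory grows with the number of windows times the window size.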
-
Hi!
I found some great examples of scanning over a sequence of data one by one, also using the lifted `linen.scan` version, which is convenient with the Flax training loop. But what if I want to scan a rolling window over a sequence of data? I couldn't really find anything, and I'm surprised this is not a usual operation (or it is and I'm just missing something). For example, if I had a tensor `x` of shape `(T, dims...)`, instead of getting the slices `x[0, dims...]`, `x[1, dims...]`, ... to operate on within each iteration of the scan, I would want to get `x[0:window, dims...]`, `x[1:window+1, dims...]`, and so on.

Thanks for any help!
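
To make the slicing pattern in the question concrete, here is a rough sketch that scans over window start indices and slices `x[start:start + window, ...]` inside the scan body, so the windows are never stored all at once. It uses plain `jax.lax.scan` rather than `linen.scan`; the window length, the example array, and the mean-over-window computation are placeholders, not anything from an existing API.

```python
import jax
import jax.numpy as jnp

window = 3  # hypothetical window length
x = jnp.arange(12.0).reshape(6, 2)  # shape (T, dims...) with T=6, dims=(2,)


def step(carry, start):
    # Equivalent to x[start:start + window, ...]; the slice size must be static.
    chunk = jax.lax.dynamic_slice_in_dim(x, start, window, axis=0)
    # Placeholder computation: mean over the window axis.
    return carry, chunk.mean(axis=0)


# One start index per full window: 0, 1, ..., T - window.
starts = jnp.arange(x.shape[0] - window + 1)
_, window_means = jax.lax.scan(step, None, starts)
print(window_means.shape)
```

Here the leading axis of `window_means` indexes the windows, and the trailing axes keep the original `dims...`.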