Help with lifted vmap #1399
-
I've recently created a module that simulates wave propagation. It is written as below:

```python
from typing import Any, Iterable, Tuple
from jax import numpy as jnp, random
from flax import linen as nn


class WERNN(nn.Module):
    v0: jnp.array
    dz: float
    dx: float
    dt: float

    def makesimulation(self, source: Any, nt: int):
        '''
        Returns acoustic wave simulation results given a source and a number of time
        steps to be calculated (nt).
        '''
        return jnp.empty((self.v0.shape[0], nt))  # just for simplification

    @nn.compact
    def __call__(self, source: Any, nt: int):
        self.variable('variables', 'something', lambda: 0)
        return self.makesimulation(source, nt)
```

This module returns a 2D array. As I need to use the results of many simulations stored in a 3D array, I thought of doing a lifted vmap on the WERNN module. Here is my trial for this simplified model:

```python
class BatchWERNN(nn.Module):
    v0: jnp.array
    dz: float
    dx: float
    dt: float

    @nn.compact
    def __call__(self, sources: Tuple[Any], nt: int):
        WERNNVmap = nn.vmap(WERNN, in_axes=0, out_axes=0,
                            variable_axes={'params': None, 'variables': None},
                            split_rngs={'params': False, 'variables': False})
        return WERNNVmap(v0=self.v0, dz=self.dz, dx=self.dx, dt=self.dt)(sources, nt)
```

In my example I still cannot understand what is getting in my way though. Do you? I really appreciate your attention. Thank you for trying to help me!
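For context, `nn.vmap` follows `jax.vmap`'s `in_axes` convention: an integer selects the axis of that argument to map over, while `None` broadcasts the argument unchanged into every call. A minimal plain `jax.vmap` sketch (the function `f` and the values here are illustrative, not from the post):

```python
import jax
from jax import numpy as jnp

def f(x, n):
    # x is one element of the mapped batch; n is the same scalar in every call
    return x * n

xs = jnp.arange(3.0)                          # mapped over axis 0
out = jax.vmap(f, in_axes=(0, None))(xs, 2.0)
print(out)                                    # [0. 2. 4.]
```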
Replies: 1 comment 3 replies
-
I think you need to broadcast `nt` because you reuse the same value in each call. So `in_axes=(0, None)`.