Skip to content

Commit 893e936

Browse files
committed
vmap np.pad over the batch
1 parent 1e48380 commit 893e936

File tree

2 files changed

+10
-14
lines changed

2 files changed

+10
-14
lines changed

flax/linen/linear.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from collections.abc import Iterable, Sequence
1818
from typing import Any, Protocol
19+
import functools
1920

2021
from flax.core import meta
2122
from flax.linen import initializers
@@ -588,14 +589,11 @@ def maybe_broadcast(
588589
kernel_size_dilated = [
589590
(k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation)
590591
]
591-
zero_pad: list[tuple[int, int]] = [(0, 0)]
592-
pads = (
593-
zero_pad
594-
+ [((k - 1) // 2, k // 2) for k in kernel_size_dilated]
595-
+ [(0, 0)]
596-
)
592+
pads = [((k - 1) // 2, k // 2) for k in kernel_size_dilated] + [(0, 0)]
597593
padding_mode = {'CIRCULAR': 'wrap', 'REFLECT': 'reflect'}[padding_lax]
598-
inputs = jnp.pad(inputs, pads, mode=padding_mode)
594+
inputs = jax.vmap(
595+
functools.partial(jnp.pad, pad_width=pads, mode=padding_mode)
596+
)(inputs)
599597
padding_lax = 'VALID'
600598
elif padding_lax == 'CAUSAL':
601599
if len(kernel_size) != 1:

flax/nnx/nn/linear.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16+
import functools
1617
import typing as tp
1718

1819
import jax
@@ -751,14 +752,11 @@ def maybe_broadcast(
751752
kernel_size_dilated = [
752753
(k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation)
753754
]
754-
zero_pad: tp.List[tuple[int, int]] = [(0, 0)]
755-
pads = (
756-
zero_pad
757-
+ [((k - 1) // 2, k // 2) for k in kernel_size_dilated]
758-
+ [(0, 0)]
759-
)
755+
pads = [((k - 1) // 2, k // 2) for k in kernel_size_dilated] + [(0, 0)]
760756
padding_mode = {'CIRCULAR': 'wrap', 'REFLECT': 'reflect'}[padding_lax]
761-
inputs = jnp.pad(inputs, pads, mode=padding_mode)
757+
inputs = jax.vmap(
758+
functools.partial(jnp.pad, pad_width=pads, mode=padding_mode)
759+
)(inputs)
762760
padding_lax = 'VALID'
763761
elif padding_lax == 'CAUSAL':
764762
if len(kernel_size) != 1:

0 commit comments

Comments
 (0)