File tree Expand file tree Collapse file tree 2 files changed +10
-14
lines changed Expand file tree Collapse file tree 2 files changed +10
-14
lines changed Original file line number Diff line number Diff line change 16
16
17
17
from collections .abc import Iterable , Sequence
18
18
from typing import Any , Protocol
19
+ import functools
19
20
20
21
from flax .core import meta
21
22
from flax .linen import initializers
@@ -588,14 +589,11 @@ def maybe_broadcast(
588
589
kernel_size_dilated = [
589
590
(k - 1 ) * d + 1 for k , d in zip (kernel_size , kernel_dilation )
590
591
]
591
- zero_pad : list [tuple [int , int ]] = [(0 , 0 )]
592
- pads = (
593
- zero_pad
594
- + [((k - 1 ) // 2 , k // 2 ) for k in kernel_size_dilated ]
595
- + [(0 , 0 )]
596
- )
592
+ pads = [((k - 1 ) // 2 , k // 2 ) for k in kernel_size_dilated ] + [(0 , 0 )]
597
593
padding_mode = {'CIRCULAR' : 'wrap' , 'REFLECT' : 'reflect' }[padding_lax ]
598
- inputs = jnp .pad (inputs , pads , mode = padding_mode )
594
+ inputs = jax .vmap (
595
+ functools .partial (jnp .pad , pad_width = pads , mode = padding_mode )
596
+ )(inputs )
599
597
padding_lax = 'VALID'
600
598
elif padding_lax == 'CAUSAL' :
601
599
if len (kernel_size ) != 1 :
Original file line number Diff line number Diff line change 13
13
# limitations under the License.
14
14
from __future__ import annotations
15
15
16
+ import functools
16
17
import typing as tp
17
18
18
19
import jax
@@ -751,14 +752,11 @@ def maybe_broadcast(
751
752
kernel_size_dilated = [
752
753
(k - 1 ) * d + 1 for k , d in zip (kernel_size , kernel_dilation )
753
754
]
754
- zero_pad : tp .List [tuple [int , int ]] = [(0 , 0 )]
755
- pads = (
756
- zero_pad
757
- + [((k - 1 ) // 2 , k // 2 ) for k in kernel_size_dilated ]
758
- + [(0 , 0 )]
759
- )
755
+ pads = [((k - 1 ) // 2 , k // 2 ) for k in kernel_size_dilated ] + [(0 , 0 )]
760
756
padding_mode = {'CIRCULAR' : 'wrap' , 'REFLECT' : 'reflect' }[padding_lax ]
761
- inputs = jnp .pad (inputs , pads , mode = padding_mode )
757
+ inputs = jax .vmap (
758
+ functools .partial (jnp .pad , pad_width = pads , mode = padding_mode )
759
+ )(inputs )
762
760
padding_lax = 'VALID'
763
761
elif padding_lax == 'CAUSAL' :
764
762
if len (kernel_size ) != 1 :
You can’t perform that action at this time.
0 commit comments