Skip to content

Commit fa0f3e8

Browse files
author
Flax Authors
committed
Merge pull request #4634 from IvyZX:parentattr
PiperOrigin-RevId: 738944800
2 parents 1e48380 + 18faf54 commit fa0f3e8

File tree

2 files changed

+44
-26
lines changed

2 files changed

+44
-26
lines changed

flax/nnx/bridge/module.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -143,37 +143,34 @@ def _module_meta_call(cls: type[M], *args, **kwargs) -> M:
143143

144144
name = None
145145
if parent_ctx is not None:
146-
want_name = 'name' in kwargs and 'name' in inspect.get_annotations(cls)
147-
if not want_name and not parent_ctx.in_compact and 'name' in kwargs:
148-
raise ValueError(
149-
f"'name' can only be set in @compact functions. If in setup(), "
150-
"use parent's `self.<attr_name> to set the submodule name.")
146+
if 'parent' in kwargs:
147+
parent = kwargs.pop('parent')
148+
if parent_ctx.in_compact and parent is not None:
149+
raise ValueError(
150+
f"'parent' can only be set to None, got {type(parent).__name__}"
151+
)
152+
else:
153+
parent = parent_ctx.module
151154

152-
if parent_ctx.in_compact:
153-
if 'parent' in kwargs:
154-
parent = kwargs.pop('parent')
155-
if parent is not None:
156-
raise ValueError(
157-
f"'parent' can only be set to None, got {type(parent).__name__}"
158-
)
159-
else:
160-
if 'name' in kwargs:
161-
name = kwargs['name'] if want_name else kwargs.pop('name')
162-
if not isinstance(name, str):
163-
raise ValueError(f"'name' must be a 'str', got {type(name).__name__}")
164-
else:
165-
name = _auto_submodule_name(parent_ctx, cls)
166-
parent = parent_ctx.module
155+
if 'name' in kwargs:
156+
name = kwargs['name']
157+
if not 'name' in inspect.get_annotations(cls):
158+
kwargs.pop('name')
159+
if not isinstance(name, str):
160+
raise ValueError(f"'name' must be a 'str', got {type(name).__name__}")
161+
elif parent_ctx.in_compact:
162+
name = _auto_submodule_name(parent_ctx, cls)
167163

168164
module = nnx_module.ModuleMeta.__call__(cls, *args, **kwargs)
169165
module.scope = None
170166
module.attr_priorities = {}
171167

172-
# compact behavior
173168
if parent is not None:
174169
assert parent.scope is not None
175-
assert name is not None
176-
setattr(parent, name, module)
170+
# compact, or setup if `name` exists
171+
if name is not None:
172+
setattr(parent, name, module)
173+
parent.set_attr_priority(name, AttrPriority.INIT_PARENT)
177174

178175
return module # type: ignore
179176

@@ -182,9 +179,10 @@ def _module_meta_call(cls: type[M], *args, **kwargs) -> M:
182179
ModuleMeta.__call__ = _module_meta_call # type: ignore
183180

184181
class AttrPriority(enum.IntEnum):
185-
HIGH = 1
186-
DEFAULT = 2
187-
LOW = 3
182+
HIGH = 0
183+
INIT_PARENT = 20
184+
DEFAULT = 50
185+
LOW = 100
188186

189187

190188
class PriorityStr(str):

tests/nnx/bridge/module_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,26 @@ def __call__(self, x):
521521
params = model.init(jax.random.key(0), x)['params']
522522
self.assertSameElements(['zzz'], params.keys())
523523

524+
def test_linen_layer_naming(self):
525+
class Dense(bridge.Module):
526+
dout: int
527+
@bridge.compact
528+
def __call__(self, x):
529+
return x @ self.param('w', lambda _: jnp.ones((x.shape[-1], self.dout)))
530+
531+
class MLP(bridge.Module):
532+
nlayers: int
533+
def setup(self):
534+
self.layers = [Dense(4, name=f'layer_{i}') for i in range(self.nlayers)]
535+
def __call__(self, x):
536+
for layer in self.layers:
537+
x = layer(x)
538+
return x
539+
540+
model = MLP(nlayers=3)
541+
x = jnp.ones((2, 4))
542+
params = model.init(jax.random.key(0), x)['params']
543+
self.assertSameElements([f'layer_{i}' for i in range(3)], params.keys())
524544

525545

526546
if __name__ == '__main__':

0 commit comments

Comments
 (0)