How to pass dtype to self.param(...) in flax.linen.Module
#1424
Replies: 1 comment
-
Not sure if it is still the case, but in the past the jax.nn.initializers couldn't produce half-precision types. In Flax, each layer lets you compute intermediates in a specified dtype, but params are stored in float32 by default. This is because parameters in half precision tend to be very difficult to optimize. Instead, you can take the params from an initialized network and cast them. This has some nice benefits, such as the ability to cast only in specific cases. For example, casting before evaluation is always a cheap win, but you could also cast only part of the network for a fine-tuning task.
-
Problem you have encountered:
For a specific layernorm implementation, I'm trying to pass dtype to self.param(...), but it seems that this is not possible. I would like the user to be able to switch between bfloat16, float16 and float32 for the weight. How can one pass a dtype parameter to self.param(...)?
What you expected to happen:
I would like to implement the following layer:
but it doesn't seem to be possible