I'm trying to use `flax.core.FrozenDict` as a lookup table in JAX, but it doesn't seem to work because JAX arrays can't be hashed. Am I doing something wrong here, or is this just a limitation of dict-like objects in JAX?
My goal is to use a dictionary as a lookup table in JAX, so that each query costs O(1) regardless of the batch size N (a binary search costs O(log N) per query). My current (working) method is to use `jnp.searchsorted(values, queries)`, but I was wondering whether using `FrozenDict` (or something similar) is possible in JAX.
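One way to get dictionary-like, expected O(1) lookups that still work under `jax.jit` and `jax.vmap` is to store the table as fixed-size arrays instead of a Python dict. The sketch below swaps in a standard technique, an open-addressing hash table with linear probing. It is not a JAX or Flax API. The table is built once on the host with NumPy, and lookups are done in JAX with `jax.lax.while_loop`. The names `build_table`, `lookup` and `EMPTY` are hypothetical. It also assumes that all keys are non-negative (so `-1` can mark empty slots) and that keys are spread out enough for the identity hash `k % capacity` to work well.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # keys up to 2**32 need int64

EMPTY = -1  # hypothetical sentinel for unused slots; assumes keys are non-negative


def build_table(keys, capacity):
    """Build a linear-probing hash table on the host (hypothetical helper).

    Returns (table_keys, table_vals) as JAX arrays, where table_vals[s] is the
    position of table_keys[s] in `keys`. This pure-Python loop is fine for a
    sketch but would be far too slow to build a table with N ~ 1e8 entries.
    """
    table_keys = np.full(capacity, EMPTY, dtype=np.int64)
    table_vals = np.full(capacity, -1, dtype=np.int64)
    for idx, k in enumerate(np.asarray(keys).tolist()):
        slot = k % capacity  # identity hash (assumption: keys are well spread)
        while table_keys[slot] != EMPTY and table_keys[slot] != k:
            slot = (slot + 1) % capacity  # linear probing: try the next slot
        if table_keys[slot] == EMPTY:  # keep the first index for duplicate keys
            table_keys[slot] = k
            table_vals[slot] = idx
    return jnp.asarray(table_keys), jnp.asarray(table_vals)


@jax.jit
def lookup(table_keys, table_vals, queries):
    """Return the stored index for each query, or -1 if it is not in the table."""
    capacity = table_keys.shape[0]  # a static Python int under jit

    def lookup_one(q):
        def cond(state):
            slot, steps = state
            k = table_keys[slot]
            # keep probing until we hit the key, an empty slot, or a full cycle
            return (k != q) & (k != EMPTY) & (steps < capacity)

        def body(state):
            slot, steps = state
            return (slot + 1) % capacity, steps + 1

        init = (q % capacity, jnp.zeros((), dtype=q.dtype))
        slot, _ = jax.lax.while_loop(cond, body, init)
        return jnp.where(table_keys[slot] == q, table_vals[slot], -1)

    return jax.vmap(lookup_one)(queries)


# Usage, mirroring the setup of the question
rng = np.random.default_rng(0)
values = np.sort(rng.integers(0, 2**32, size=1_000))
table_keys, table_vals = build_table(values, capacity=2 * values.size)
new_values = jnp.asarray(rng.permutation(values))
idxs = lookup(table_keys, table_vals, new_values)
print("Hash-table idxs:", idxs)
```

At a load factor of 0.5, each query needs only a few probes on average. Under `vmap`, however, the batched `while_loop` keeps running until the slowest query in the batch has finished. So the batch pays for the longest probe chain, not the average one, and the table costs two `int64` arrays of size `capacity` in memory. Whether this beats `jnp.searchsorted` in practice would need to be measured.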
I've written a minimal reproducible example below, which highlights my current issue and where this dictionary-based approach fails. Is there a way to get around this issue?
```python
import jax
jax.config.update('jax_enable_x64', True)
from jax import numpy as jnp
from jax import random as jr
from flax.core import FrozenDict
from functools import partial
import operator

key = jr.key(42)
N = 1_000  # batch size

values = jr.randint(key, shape=(N,), minval=0, maxval=int(2**32))
values = jnp.sort(values)
indices = jnp.arange(values.shape[0])

# map each value to its index in the sorted `values` array
val2idx_mapping = FrozenDict({int(val): idx for idx, val in zip(indices, values)})

new_values = jr.permutation(key, values)  # get 'new' values with different order

searchsorted_idxs = jnp.searchsorted(values, new_values, side='left')  # works (but can be quite slow for very large N >> 1e8)
print('Searchsorted idxs: ', searchsorted_idxs)

@partial(jax.vmap, in_axes=0)
def lookup(x: jnp.ndarray) -> int:
    # inside vmap, x is a BatchTracer with no concrete value, so it cannot be hashed as a dict key
    return operator.getitem(val2idx_mapping, x.astype(int))

idxs = lookup(new_values)
print('Lookup idxs: ', idxs)  # unhashable type: 'BatchTracer' error
```
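The dictionary lookup fails because inside `jax.vmap` (and `jax.jit`), `x` is a tracer. A tracer is an abstract placeholder that carries a shape and a dtype but no concrete value, so there is nothing to hash. Python containers can therefore only be indexed with concrete values, outside of tracing. If the table really has to be a Python dict, one escape hatch is `jax.pure_callback`, which runs a host-side Python function on concrete NumPy arrays from within traced code. The sketch below uses a plain dict rather than `FrozenDict`. It calls the callback once on the whole batch instead of inside `vmap`, and the names `val2idx`, `host_lookup` and `dict_lookup` are illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

rng = np.random.default_rng(42)
values = np.sort(rng.integers(0, 2**32, size=1_000))

# Plain Python dict built eagerly from concrete values: value -> first sorted index
val2idx = {}
for i, v in enumerate(values.tolist()):
    val2idx.setdefault(v, i)


def host_lookup(queries):
    # Runs on the host with concrete array inputs, so hashing each key works.
    return np.asarray([val2idx[int(q)] for q in queries], dtype=np.int32)


@jax.jit
def dict_lookup(queries):
    # Declare the callback's output shape/dtype so JAX can trace around it.
    out_type = jax.ShapeDtypeStruct(queries.shape, jnp.int32)
    return jax.pure_callback(host_lookup, out_type, queries)


new_values = jnp.asarray(rng.permutation(values))
idxs = dict_lookup(new_values)
print("Dict lookup idxs:", idxs)
```

Each individual dict access is O(1), but the callback runs a Python loop on the host and needs a device-to-host round trip. For large N it will likely be slower than `jnp.searchsorted` running on an accelerator, so this is mainly useful when the table cannot be expressed as arrays.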