Skip to content

Commit 5594719

Browse files
committed
fix: Fix HaikuIvyModule by handling a backward incompatible release of jax which removed the .device() method on jax Arrays in favour of a .device attribute
1 parent d3efff2 commit 5594719

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

ivy/stateful/module.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections import OrderedDict
55
import os
66
import copy
7+
from packaging import version
78
import dill
89
from typing import Optional, Tuple, Dict
910

@@ -687,7 +688,11 @@ def _build(self, params_hk, *args, **kwargs):
687688
param_iterator = self._hk_params.cont_to_iterator()
688689
_, param0 = next(param_iterator, ["_", 0])
689690
if hasattr(param0, "device"):
690-
self._device = ivy.as_ivy_dev(param0.device())
691+
import jax
692+
if version.parse(jax.__version__) >= version.parse("0.4.31"):
693+
self._device = ivy.as_ivy_dev(param0.device)
694+
else:
695+
self._device = ivy.as_ivy_dev(param0.device())
691696
else:
692697
self._device = ivy.as_ivy_dev("cpu")
693698
ivy.previous_backend()

0 commit comments

Comments
 (0)