Skip to content

Commit 715befb

Browse files
committed
remove singledispatch for better static typing
1 parent 60b1625 commit 715befb

File tree

1 file changed

+18
-31
lines changed

1 file changed

+18
-31
lines changed

bioimageio/core/digest_spec.py

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import importlib.util
4-
from functools import singledispatch
54
from itertools import chain
65
from typing import (
76
Any,
@@ -20,7 +19,7 @@
2019
from numpy.typing import NDArray
2120
from typing_extensions import Unpack, assert_never
2221

23-
from bioimageio.spec._internal.io import HashKwargs, download
22+
from bioimageio.spec._internal.io_utils import HashKwargs, download
2423
from bioimageio.spec.common import FileSource
2524
from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
2625
from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile
@@ -44,44 +43,32 @@
4443
from .tensor import Tensor
4544

4645

47-
@singledispatch
48-
def import_callable(node: type, /) -> Callable[..., Any]:
46+
def import_callable(
47+
node: Union[CallableFromDepencency, ArchitectureFromLibraryDescr],
48+
/,
49+
**kwargs: Unpack[HashKwargs],
50+
) -> Callable[..., Any]:
4951
"""import a callable (e.g. a torch.nn.Module) from a spec node describing it"""
50-
raise TypeError(type(node))
51-
52-
53-
@import_callable.register
54-
def _(node: CallableFromDepencency, **kwargs: Unpack[HashKwargs]) -> Callable[..., Any]:
55-
module = importlib.import_module(node.module_name)
56-
c = getattr(module, str(node.callable_name))
57-
if not callable(c):
58-
raise ValueError(f"{node} (imported: {c}) is not callable")
59-
60-
return c
52+
if isinstance(node, CallableFromDepencency):
53+
module = importlib.import_module(node.module_name)
54+
c = getattr(module, str(node.callable_name))
55+
elif isinstance(node, ArchitectureFromLibraryDescr):
56+
module = importlib.import_module(node.import_from)
57+
c = getattr(module, str(node.callable))
58+
elif isinstance(node, CallableFromFile):
59+
c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs)
60+
elif isinstance(node, ArchitectureFromFileDescr):
61+
c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256)
6162

63+
else:
64+
assert_never(node)
6265

63-
@import_callable.register
64-
def _(
65-
node: ArchitectureFromLibraryDescr, **kwargs: Unpack[HashKwargs]
66-
) -> Callable[..., Any]:
67-
module = importlib.import_module(node.import_from)
68-
c = getattr(module, str(node.callable))
6966
if not callable(c):
7067
raise ValueError(f"{node} (imported: {c}) is not callable")
7168

7269
return c
7370

7471

75-
@import_callable.register
76-
def _(node: CallableFromFile, **kwargs: Unpack[HashKwargs]):
77-
return _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs)
78-
79-
80-
@import_callable.register
81-
def _(node: ArchitectureFromFileDescr, **kwargs: Unpack[HashKwargs]):
82-
return _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256)
83-
84-
8572
def _import_from_file_impl(
8673
source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
8774
):

0 commit comments

Comments
 (0)