I'm new to this very cool project and am looking to enforce a broadcastability pattern: a dimension in one tensor should either match the corresponding dimension in another tensor or be broadcastable to it (i.e. equal 1).
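The rule being asked for is the per-dimension part of standard broadcasting: a size is compatible with a target size if the two are equal or if the size is 1. Below is a minimal pure-Python sketch of that check. It assumes shapes are plain tuples of ints and that only `x`'s "bar" dimension may be 1, as described above. The helper names `dim_broadcastable` and `check_bar` are hypothetical and not part of torchtyping.

```python
from typing import Sequence


def dim_broadcastable(size: int, target: int) -> bool:
    """Return True if a dimension of `size` matches `target`
    or can be broadcast to it (i.e. `size` equals 1)."""
    return size == target or size == 1


def check_bar(x_shape: Sequence[int], y_shape: Sequence[int]) -> None:
    """Hypothetical helper: require x's last ("bar") dimension to match
    y's only ("bar") dimension, or to equal 1."""
    if len(x_shape) < 2 or len(y_shape) != 1:
        raise ValueError("expected x of shape (..., foo, bar) and y of shape (bar,)")
    if not dim_broadcastable(x_shape[-1], y_shape[0]):
        raise ValueError(
            f"x's 'bar' dimension ({x_shape[-1]}) must equal y's ({y_shape[0]}) or be 1"
        )


check_bar((4, 3, 5), (5,))  # ok: sizes match
check_bar((4, 3, 1), (5,))  # ok: size 1 broadcasts to 5
```

With PyTorch tensors, the same helper could be called as `check_bar(x.shape, y.shape)` inside the function body, because `torch.Size` is a subclass of `tuple`. This is a manual runtime check, not a torchtyping annotation.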
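```python
import torch
import typeguard
import torchtyping

# torchtyping's dimension checks only run under typeguard after this patch
torchtyping.patch_typeguard()


@typeguard.typechecked
def mwe(
    x: torchtyping.TensorType[
        ...,
        "foo",
        "bar",  # How do we make this "match bar from y, or equal 1"?
    ],
    y: torchtyping.TensorType["bar"],
) -> torchtyping.TensorType[..., "foo", "bar"]:
    return x * y
```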
Am I missing an existing way to do this in torchtyping out of the box? Or would this need an extension?