Skip to content

Commit c5675dd

Browse files
authored
[Relax][PyTorch] Support linspace op for ExportedProgram importer (#17889)
* Update base_fx_graph_translator.py * Update exported_program_translator.py * Update test_frontend_from_exported_program.py * Update test_frontend_from_exported_program.py * Update test_frontend_from_exported_program.py * fix lint * Update test_frontend_from_exported_program.py
1 parent 67297c4 commit c5675dd

File tree

3 files changed

+45
-0
lines changed

3 files changed

+45
-0
lines changed

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

+23
Original file line numberDiff line numberDiff line change
@@ -1465,6 +1465,29 @@ def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
14651465
self.env[node.args[0]] = output
14661466
return output
14671467

1468+
def _linspace(self, node: fx.Node) -> relax.Var:
1469+
args = self.retrieve_args(node)
1470+
start = args[0]
1471+
stop = args[1]
1472+
step = args[2]
1473+
1474+
if step != 1:
1475+
step = (stop - start) / (step - 1)
1476+
stop = stop + (step / 2)
1477+
else:
1478+
stop = start + step
1479+
1480+
if len(args) <= 3 or args[3] is None:
1481+
import torch
1482+
1483+
dtype = self._convert_data_type(str(torch.get_default_dtype()))
1484+
else:
1485+
dtype = self._convert_data_type(args[3])
1486+
1487+
return self.block_builder.emit(
1488+
relax.op.arange(start=start, end=stop, step=step, dtype=dtype)
1489+
)
1490+
14681491
def _masked_fill(self, node: fx.Node) -> relax.Var:
14691492
x = self.env[node.args[0]]
14701493
mask = self.env[node.args[1]]

python/tvm/relax/frontend/torch/exported_program_translator.py

+1
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,7 @@ def create_convert_map(
475475
"full_like.default": self._full_like,
476476
"index_select.default": self._index_select,
477477
"lift_fresh_copy.default": self._to_copy,
478+
"linspace.default": self._linspace,
478479
"masked_fill.Scalar": self._masked_fill,
479480
"new_ones.default": self._new_ones,
480481
"one_hot.default": self._one_hot,

tests/python/relax/test_frontend_from_exported_program.py

+21
Original file line numberDiff line numberDiff line change
@@ -4817,5 +4817,26 @@ def main(
48174817
verify_model(Eye2(), example_args2, {}, Expected2)
48184818

48194819

4820+
def test_linspace():
4821+
class Linspace(Module):
4822+
def forward(self, input):
4823+
return torch.linspace(0, 1, steps=9, dtype=torch.float32)
4824+
4825+
@tvm.script.ir_module
4826+
class Expected:
4827+
@R.function
4828+
def main(
4829+
input: R.Tensor((9, 9), dtype="float32")
4830+
) -> R.Tuple(R.Tensor((9,), dtype="float32")):
4831+
with R.dataflow():
4832+
lv: R.Tensor((9,), dtype="float32") = R.arange(0, 1.0625, 0.125, dtype="float32")
4833+
gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv,)
4834+
R.output(gv)
4835+
return gv
4836+
4837+
example_args = (torch.randn(9, 9, dtype=torch.float32),)
4838+
verify_model(Linspace(), example_args, {}, Expected)
4839+
4840+
48204841
if __name__ == "__main__":
48214842
tvm.testing.main()

0 commit comments

Comments
 (0)