Skip to content

Commit a6390be

Browse files
committed
Fix test_bench
1 parent b2b4158 commit a6390be

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

test_bench.py

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -42,11 +42,10 @@ def pytest_generate_tests(metafunc):
4242

4343
if metafunc.cls and metafunc.cls.__name__ == "TestBenchNetwork":
4444
paths = _list_model_paths()
45-
model_names = [os.path.basename(path) for path in paths]
4645
metafunc.parametrize(
47-
"model_name",
48-
model_names,
49-
ids=model_names,
46+
"model_path",
47+
paths,
48+
ids=[os.path.basename(path) for path in paths],
5049
scope="class",
5150
)
5251

@@ -62,13 +61,14 @@ def pytest_generate_tests(metafunc):
6261
)
6362
class TestBenchNetwork:
6463

65-
def test_train(self, model_name, device, compiler, benchmark):
64+
def test_train(self, model_path, device, benchmark):
6665
try:
66+
model_name = os.path.basename(model_path)
6767
if skip_by_metadata(
6868
test="train",
6969
device=device,
7070
extra_args=[],
71-
metadata=get_metadata_from_yaml(model_name),
71+
metadata=get_metadata_from_yaml(model_path),
7272
):
7373
raise NotImplementedError("Test skipped by its metadata.")
7474
# TODO: skipping quantized tests for now due to BC-breaking changes for prepare
@@ -91,13 +91,14 @@ def test_train(self, model_name, device, compiler, benchmark):
9191
except NotImplementedError:
9292
print(f"Test train on {device} is not implemented, skipping...")
9393

94-
def test_eval(self, model_name, device, compiler, benchmark, pytestconfig):
94+
def test_eval(self, model_path, device, benchmark, pytestconfig):
9595
try:
96+
model_name = os.path.basename(model_path)
9697
if skip_by_metadata(
9798
test="eval",
9899
device=device,
99100
extra_args=[],
100-
metadata=get_metadata_from_yaml(model_name),
101+
metadata=get_metadata_from_yaml(model_path),
101102
):
102103
raise NotImplementedError("Test skipped by its metadata.")
103104
# TODO: skipping quantized tests for now due to BC-breaking changes for prepare
@@ -110,16 +111,15 @@ def test_eval(self, model_name, device, compiler, benchmark, pytestconfig):
110111

111112
task.make_model_instance(test="eval", device=device)
112113

113-
with task.no_grad(disable_nograd=pytestconfig.getoption("disable_nograd")):
114-
benchmark(task.invoke)
115-
benchmark.extra_info["machine_state"] = get_machine_state()
116-
benchmark.extra_info["batch_size"] = task.get_model_attribute(
117-
"batch_size"
118-
)
119-
benchmark.extra_info["precision"] = task.get_model_attribute(
120-
"dargs", "precision"
121-
)
122-
benchmark.extra_info["test"] = "eval"
114+
benchmark(task.invoke)
115+
benchmark.extra_info["machine_state"] = get_machine_state()
116+
benchmark.extra_info["batch_size"] = task.get_model_attribute(
117+
"batch_size"
118+
)
119+
benchmark.extra_info["precision"] = task.get_model_attribute(
120+
"dargs", "precision"
121+
)
122+
benchmark.extra_info["test"] = "eval"
123123

124124
except NotImplementedError:
125125
print(f"Test eval on {device} is not implemented, skipping...")

0 commit comments

Comments (0)