@@ -42,11 +42,10 @@ def pytest_generate_tests(metafunc):
 
     if metafunc.cls and metafunc.cls.__name__ == "TestBenchNetwork":
         paths = _list_model_paths()
-        model_names = [os.path.basename(path) for path in paths]
         metafunc.parametrize(
-            "model_name",
-            model_names,
-            ids=model_names,
+            "model_path",
+            paths,
+            ids=[os.path.basename(path) for path in paths],
             scope="class",
         )
 
@@ -62,13 +61,14 @@ def pytest_generate_tests(metafunc):
 )
 class TestBenchNetwork:
 
-    def test_train(self, model_name, device, compiler, benchmark):
+    def test_train(self, model_path, device, benchmark):
         try:
+            model_name = os.path.basename(model_path)
             if skip_by_metadata(
                 test="train",
                 device=device,
                 extra_args=[],
-                metadata=get_metadata_from_yaml(model_name),
+                metadata=get_metadata_from_yaml(model_path),
             ):
                 raise NotImplementedError("Test skipped by its metadata.")
             # TODO: skipping quantized tests for now due to BC-breaking changes for prepare
@@ -91,13 +91,14 @@ def test_train(self, model_name, device, compiler, benchmark):
         except NotImplementedError:
             print(f"Test train on {device} is not implemented, skipping...")
 
-    def test_eval(self, model_name, device, compiler, benchmark, pytestconfig):
+    def test_eval(self, model_path, device, benchmark, pytestconfig):
         try:
+            model_name = os.path.basename(model_path)
             if skip_by_metadata(
                 test="eval",
                 device=device,
                 extra_args=[],
-                metadata=get_metadata_from_yaml(model_name),
+                metadata=get_metadata_from_yaml(model_path),
             ):
                 raise NotImplementedError("Test skipped by its metadata.")
             # TODO: skipping quantized tests for now due to BC-breaking changes for prepare
@@ -110,16 +111,15 @@ def test_eval(self, model_name, device, compiler, benchmark, pytestconfig):
 
             task.make_model_instance(test="eval", device=device)
 
-            with task.no_grad(disable_nograd=pytestconfig.getoption("disable_nograd")):
-                benchmark(task.invoke)
-                benchmark.extra_info["machine_state"] = get_machine_state()
-                benchmark.extra_info["batch_size"] = task.get_model_attribute(
-                    "batch_size"
-                )
-                benchmark.extra_info["precision"] = task.get_model_attribute(
-                    "dargs", "precision"
-                )
-                benchmark.extra_info["test"] = "eval"
+            benchmark(task.invoke)
+            benchmark.extra_info["machine_state"] = get_machine_state()
+            benchmark.extra_info["batch_size"] = task.get_model_attribute(
+                "batch_size"
+            )
+            benchmark.extra_info["precision"] = task.get_model_attribute(
+                "dargs", "precision"
+            )
+            benchmark.extra_info["test"] = "eval"
 
         except NotImplementedError:
             print(f"Test eval on {device} is not implemented, skipping...")
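For reference, the parametrization pattern introduced in the first hunk can be exercised on its own. The sketch below is a minimal, self-contained approximation rather than the benchmark's real code: _list_model_paths() here is a hypothetical stand-in that returns dummy directories, and the test body is reduced to a trivial assertion.

# Sketch: parametrize tests by full model path while keeping short, readable test IDs.
import os

def _list_model_paths():
    # Hypothetical stand-in for the real helper: absolute paths to model directories.
    return ["/tmp/models/alexnet", "/tmp/models/resnet50"]

def pytest_generate_tests(metafunc):
    if metafunc.cls and metafunc.cls.__name__ == "TestBenchNetwork":
        paths = _list_model_paths()
        metafunc.parametrize(
            "model_path",                                    # tests receive the full path
            paths,
            ids=[os.path.basename(path) for path in paths],  # IDs stay as short names
            scope="class",
        )

class TestBenchNetwork:
    def test_train(self, model_path):
        # The short name is recovered inside the test, mirroring the change above.
        model_name = os.path.basename(model_path)
        assert model_name in ("alexnet", "resnet50")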