8
8
import sys
9
9
import types
10
10
import unittest
11
- from typing import (
12
- Any ,
13
- Callable ,
14
- Dict ,
15
- List ,
16
- Optional ,
17
- overload ,
18
- Sequence ,
19
- TypeVar ,
20
- Union ,
21
- )
11
+ from collections .abc import Sequence
12
+ from typing import Any , Callable , Optional , overload , TypeVar , Union
22
13
from typing_extensions import ParamSpec
23
14
from unittest .mock import patch
24
15
@@ -83,7 +74,7 @@ def extract_graph_backend(_gm, *args, **kwargs): # type: ignore[no-untyped-def]
83
74
84
75
def collect_results (
85
76
model : torch .nn .Module , prediction : Any , loss : Any , example_inputs : Any
86
- ) -> List [Any ]:
77
+ ) -> list [Any ]:
87
78
results = []
88
79
results .append (prediction )
89
80
results .append (loss )
@@ -140,7 +131,7 @@ def reduce_to_scalar_loss(out: torch.Tensor) -> torch.Tensor:
140
131
141
132
@overload
142
133
def reduce_to_scalar_loss (
143
- out : Union [List [Any ], tuple [Any , ...], Dict [Any , Any ]]
134
+ out : Union [list [Any ], tuple [Any , ...], dict [Any , Any ]]
144
135
) -> float :
145
136
...
146
137
@@ -186,7 +177,7 @@ def debug_insert_nops(
186
177
) -> Optional [GuardedCode ]:
187
178
"""used to debug jump updates"""
188
179
189
- def insert_nops (instructions : List [Any ], code_options : Any ) -> None :
180
+ def insert_nops (instructions : list [Any ], code_options : Any ) -> None :
190
181
instructions .insert (0 , create_instruction ("NOP" ))
191
182
instructions .insert (0 , create_instruction ("NOP" ))
192
183
@@ -222,7 +213,7 @@ def __init__(self) -> None:
222
213
self .op_count = 0
223
214
224
215
def __call__ (
225
- self , gm : torch .fx .GraphModule , example_inputs : List [torch .Tensor ]
216
+ self , gm : torch .fx .GraphModule , example_inputs : list [torch .Tensor ]
226
217
) -> Callable [..., Any ]:
227
218
self .frame_count += 1
228
219
for node in gm .graph .nodes :
@@ -240,10 +231,10 @@ def __init__(self, backend: str) -> None:
240
231
self .frame_count = 0
241
232
self .op_count = 0
242
233
self .backend = backend
243
- self .graphs : List [torch .fx .GraphModule ] = []
234
+ self .graphs : list [torch .fx .GraphModule ] = []
244
235
245
236
def __call__ (
246
- self , gm : torch .fx .GraphModule , example_inputs : List [torch .Tensor ]
237
+ self , gm : torch .fx .GraphModule , example_inputs : list [torch .Tensor ]
247
238
) -> Callable [..., Any ]:
248
239
from .backends .registry import lookup_backend
249
240
@@ -264,34 +255,34 @@ def clear(self) -> None:
264
255
# we can assert on
265
256
class EagerAndRecordGraphs :
266
257
def __init__ (self ) -> None :
267
- self .graphs : List [torch .fx .GraphModule ] = []
258
+ self .graphs : list [torch .fx .GraphModule ] = []
268
259
269
260
def __call__ (
270
- self , gm : torch .fx .GraphModule , example_inputs : List [torch .Tensor ]
261
+ self , gm : torch .fx .GraphModule , example_inputs : list [torch .Tensor ]
271
262
) -> Callable [..., Any ]:
272
263
self .graphs .append (gm )
273
264
return gm .forward
274
265
275
266
276
267
class AotEagerAndRecordGraphs :
277
268
def __init__ (self ) -> None :
278
- self .graphs : List [torch .fx .GraphModule ] = []
279
- self .fw_graphs : List [torch .fx .GraphModule ] = []
280
- self .bw_graphs : List [torch .fx .GraphModule ] = []
269
+ self .graphs : list [torch .fx .GraphModule ] = []
270
+ self .fw_graphs : list [torch .fx .GraphModule ] = []
271
+ self .bw_graphs : list [torch .fx .GraphModule ] = []
281
272
282
273
def __call__ (
283
- self , gm : torch .fx .GraphModule , example_inputs : List [torch .Tensor ]
274
+ self , gm : torch .fx .GraphModule , example_inputs : list [torch .Tensor ]
284
275
) -> Callable [..., Any ]:
285
276
self .graphs .append (gm )
286
277
287
278
def fw_compiler (
288
- gm : torch .fx .GraphModule , example_inputs : List [torch .Tensor ]
279
+ gm : torch .fx .GraphModule , example_inputs : list [torch .Tensor ]
289
280
) -> Callable [..., Any ]:
290
281
self .fw_graphs .append (gm )
291
282
return gm .forward
292
283
293
284
def bw_compiler (
294
- gm : torch .fx .GraphModule , example_inputs : List [torch .Tensor ]
285
+ gm : torch .fx .GraphModule , example_inputs : list [torch .Tensor ]
295
286
) -> Callable [..., Any ]:
296
287
self .bw_graphs .append (gm )
297
288
return gm .forward
@@ -360,7 +351,7 @@ def standard_test(
360
351
361
352
362
353
def dummy_fx_compile (
363
- gm : fx .GraphModule , example_inputs : List [torch .Tensor ]
354
+ gm : fx .GraphModule , example_inputs : list [torch .Tensor ]
364
355
) -> Callable [..., Any ]:
365
356
return gm .forward
366
357
0 commit comments