@@ -59,6 +59,9 @@
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     backend_empty_cache,
+    backend_max_memory_allocated,
+    backend_reset_peak_memory_stats,
+    backend_synchronize,
     floats_tensor,
     get_python_version,
     is_torch_compile,
@@ -68,7 +71,6 @@
     require_torch_2,
     require_torch_accelerator,
     require_torch_accelerator_with_training,
-    require_torch_gpu,
     require_torch_multi_accelerator,
     run_test_in_subprocess,
     slow,
@@ -341,7 +343,7 @@ def test_weight_overwrite(self):

         assert model.config.in_channels == 9

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_keep_modules_in_fp32(self):
         r"""
         A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32 when we load the model in fp16/bf16
@@ -1480,16 +1482,16 @@ def test_layerwise_casting(storage_dtype, compute_dtype):
         test_layerwise_casting(torch.float8_e5m2, torch.float32)
         test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16)

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_layerwise_casting_memory(self):
         MB_TOLERANCE = 0.2
         LEAST_COMPUTE_CAPABILITY = 8.0

         def reset_memory_stats():
             gc.collect()
-            torch.cuda.synchronize()
-            torch.cuda.empty_cache()
-            torch.cuda.reset_peak_memory_stats()
+            backend_synchronize(torch_device)
+            backend_empty_cache(torch_device)
+            backend_reset_peak_memory_stats(torch_device)

         def get_memory_usage(storage_dtype, compute_dtype):
             torch.manual_seed(0)
@@ -1502,7 +1504,7 @@ def get_memory_usage(storage_dtype, compute_dtype):
             reset_memory_stats()
             model(**inputs_dict)
             model_memory_footprint = model.get_memory_footprint()
-            peak_inference_memory_allocated_mb = torch.cuda.max_memory_allocated() / 1024**2
+            peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2

             return model_memory_footprint, peak_inference_memory_allocated_mb

@@ -1512,7 +1514,7 @@ def get_memory_usage(storage_dtype, compute_dtype):
             torch.float8_e4m3fn, torch.bfloat16
         )

-        compute_capability = get_torch_cuda_device_capability()
+        compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
         self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint)
         # NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
         # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
@@ -1527,7 +1529,7 @@ def get_memory_usage(storage_dtype, compute_dtype):
         )

     @parameterized.expand([False, True])
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_group_offloading(self, record_stream):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         torch.manual_seed(0)
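The backend_* helpers imported above are device-agnostic wrappers from diffusers.utils.testing_utils: each one inspects the active torch_device and calls the matching torch backend API, which is why the diff can drop require_torch_gpu in favour of require_torch_accelerator. A minimal sketch of that dispatch idea, using only public torch.cuda calls (the table and function names below are illustrative assumptions, not the library's actual implementation):

import torch

# Hypothetical dispatch tables mapping a torch_device string to the backend call.
# Entries for other accelerators (e.g. "xpu") would be registered the same way.
_SYNCHRONIZE = {"cuda": torch.cuda.synchronize}
_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache}
_RESET_PEAK_MEMORY_STATS = {"cuda": torch.cuda.reset_peak_memory_stats}
_MAX_MEMORY_ALLOCATED = {"cuda": torch.cuda.max_memory_allocated}

def _call_for_device(table, device):
    # Backends without the API (e.g. "cpu") fall through to a harmless default.
    fn = table.get(device)
    return fn() if fn is not None else 0

def sketch_backend_synchronize(device):
    return _call_for_device(_SYNCHRONIZE, device)

def sketch_backend_empty_cache(device):
    return _call_for_device(_EMPTY_CACHE, device)

def sketch_backend_reset_peak_memory_stats(device):
    return _call_for_device(_RESET_PEAK_MEMORY_STATS, device)

def sketch_backend_max_memory_allocated(device):
    return _call_for_device(_MAX_MEMORY_ALLOCATED, device)

With this shape, the test body only needs the device string, e.g. sketch_backend_max_memory_allocated(torch_device), and runs unchanged on any supported accelerator.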