
Commit a7e9f85

enable test_layerwise_casting_memory cases on XPU (#11406)
* enable test_layerwise_casting_memory cases on XPU
  Signed-off-by: Yao Matrix <[email protected]>
* fix style
  Signed-off-by: Yao Matrix <[email protected]>
---------
Signed-off-by: Yao Matrix <[email protected]>
1 parent 9ce89e2 commit a7e9f85

File tree: 2 files changed (+22 −9 lines)

src/diffusers/utils/testing_utils.py

+11 lines changed

@@ -1186,6 +1186,13 @@ def _is_torch_fp64_available(device):
     "mps": 0,
     "default": 0,
 }
+BACKEND_SYNCHRONIZE = {
+    "cuda": torch.cuda.synchronize,
+    "xpu": getattr(torch.xpu, "synchronize", None),
+    "cpu": None,
+    "mps": None,
+    "default": None,
+}
 
 
 # This dispatches a defined function according to the accelerator from the function definitions.
@@ -1208,6 +1215,10 @@ def backend_manual_seed(device: str, seed: int):
     return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
 
 
+def backend_synchronize(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
+
+
 def backend_empty_cache(device: str):
     return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
 
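The new `BACKEND_SYNCHRONIZE` table follows the same dispatch pattern as the other `BACKEND_*` mappings in `testing_utils.py`: a helper resolves the callable registered for the active accelerator and invokes it. Below is a minimal, self-contained sketch of that idea; the `_dispatch` helper is a hypothetical simplification of the real `_device_agnostic_dispatch`, shown only to illustrate how `backend_synchronize` behaves per backend.

```python
import torch

# Dispatch table in the style of the BACKEND_* mappings in testing_utils.py.
# Entries are backend-specific synchronize functions, or None where the
# backend has nothing to synchronize.
BACKEND_SYNCHRONIZE = {
    "cuda": torch.cuda.synchronize,
    "xpu": getattr(torch.xpu, "synchronize", None) if hasattr(torch, "xpu") else None,
    "cpu": None,
    "mps": None,
    "default": None,
}


def _dispatch(device: str, dispatch_table: dict, *args, **kwargs):
    # Hypothetical stand-in for diffusers' _device_agnostic_dispatch:
    # resolve the entry for `device` (falling back to "default") and call it,
    # treating None as a no-op.
    fn = dispatch_table.get(device, dispatch_table["default"])
    if fn is None:
        return None
    return fn(*args, **kwargs)


def backend_synchronize(device: str):
    return _dispatch(device, BACKEND_SYNCHRONIZE)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    backend_synchronize(device)  # torch.cuda.synchronize() on CUDA, no-op on CPU
```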

tests/models/test_modeling_common.py

+11 −9 lines changed

@@ -59,6 +59,9 @@
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     backend_empty_cache,
+    backend_max_memory_allocated,
+    backend_reset_peak_memory_stats,
+    backend_synchronize,
     floats_tensor,
     get_python_version,
     is_torch_compile,
@@ -68,7 +71,6 @@
     require_torch_2,
     require_torch_accelerator,
     require_torch_accelerator_with_training,
-    require_torch_gpu,
     require_torch_multi_accelerator,
     run_test_in_subprocess,
     slow,
@@ -341,7 +343,7 @@ def test_weight_overwrite(self):
 
         assert model.config.in_channels == 9
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_keep_modules_in_fp32(self):
         r"""
         A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32 when we load the model in fp16/bf16
@@ -1480,16 +1482,16 @@ def test_layerwise_casting(storage_dtype, compute_dtype):
         test_layerwise_casting(torch.float8_e5m2, torch.float32)
         test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_layerwise_casting_memory(self):
         MB_TOLERANCE = 0.2
         LEAST_COMPUTE_CAPABILITY = 8.0
 
         def reset_memory_stats():
             gc.collect()
-            torch.cuda.synchronize()
-            torch.cuda.empty_cache()
-            torch.cuda.reset_peak_memory_stats()
+            backend_synchronize(torch_device)
+            backend_empty_cache(torch_device)
+            backend_reset_peak_memory_stats(torch_device)
 
         def get_memory_usage(storage_dtype, compute_dtype):
             torch.manual_seed(0)
@@ -1502,7 +1504,7 @@ def get_memory_usage(storage_dtype, compute_dtype):
             reset_memory_stats()
             model(**inputs_dict)
             model_memory_footprint = model.get_memory_footprint()
-            peak_inference_memory_allocated_mb = torch.cuda.max_memory_allocated() / 1024**2
+            peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2
 
             return model_memory_footprint, peak_inference_memory_allocated_mb
 
@@ -1512,7 +1514,7 @@ def get_memory_usage(storage_dtype, compute_dtype):
             torch.float8_e4m3fn, torch.bfloat16
         )
 
-        compute_capability = get_torch_cuda_device_capability()
+        compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
         self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint)
         # NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
         # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
@@ -1527,7 +1529,7 @@ def get_memory_usage(storage_dtype, compute_dtype):
         )
 
     @parameterized.expand([False, True])
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_group_offloading(self, record_stream):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         torch.manual_seed(0)
