Spaces:
Running
on
Zero
Running
on
Zero
| import functools | |
| import gc | |
| import torch | |
def _probe(check):
    """Evaluate ``check()`` and return its result; any exception counts as False.

    Used for the accelerator probes below, where the queried attribute may be
    missing or the query may raise on some torch builds/platforms.
    """
    try:
        return check()
    except Exception:
        return False


# Backend availability flags, computed once at import time.
HAS_CUDA = _probe(lambda: torch.cuda.is_available())
HAS_MPS = _probe(lambda: torch.backends.mps.is_available())

# The XPU probe stays inline: importing IPEX binds the module-level name `ipex`,
# and a failed import (IPEX not installed) must also yield HAS_XPU = False.
try:
    # Imported for its side effects only (hence `noqa`); presumably it enables
    # `torch.xpu` support — TODO confirm against IPEX docs.
    import intel_extension_for_pytorch as ipex  # noqa
    HAS_XPU = torch.xpu.is_available()
except Exception:
    HAS_XPU = False
def clean_memory():
    r"""
    Free memory on every accelerator backend detected at import time.

    First runs the Python garbage collector, then calls ``empty_cache()`` on
    ``torch.cuda``, ``torch.xpu`` and ``torch.mps`` (in that order) for each
    backend whose availability flag is set.
    """
    gc.collect()
    backends = (("cuda", HAS_CUDA), ("xpu", HAS_XPU), ("mps", HAS_MPS))
    for name, available in backends:
        # Look the submodule up lazily: `torch.xpu` may not exist when XPU is absent.
        if available:
            getattr(torch, name).empty_cache()
def clean_memory_on_device(device: "torch.device | str"):
    r"""
    Clean memory on the specified device, will be called from training scripts.

    Runs the Python garbage collector, then releases the cached allocator memory
    of the backend ``device`` belongs to (CUDA, XPU or MPS). For any other
    device type (e.g. ``cpu``) only the garbage-collection pass is performed.

    Args:
        device: a ``torch.device`` or a device string such as ``"cuda"`` or
            ``"cuda:0"``. Strings are converted with ``torch.device``.
    """
    # Accept device strings as well; previously a str crashed on `.type`.
    if isinstance(device, str):
        device = torch.device(device)
    gc.collect()
    # device may be "cuda" or "cuda:0", so compare the device *type*, not the device itself
    if device.type == "cuda":
        torch.cuda.empty_cache()
    elif device.type == "xpu":
        torch.xpu.empty_cache()
    elif device.type == "mps":
        torch.mps.empty_cache()
def get_preferred_device() -> torch.device:
    r"""
    Do not call this function from training scripts. Use accelerator.device instead.

    Return the first available backend in priority order CUDA > XPU > MPS,
    falling back to CPU, and print which device was chosen.
    """
    priority = (("cuda", HAS_CUDA), ("xpu", HAS_XPU), ("mps", HAS_MPS))
    backend = next((name for name, available in priority if available), "cpu")
    device = torch.device(backend)
    print(f"get_preferred_device() -> {device}")
    return device
def init_ipex():
    """
    Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
    This function should run right after importing torch and before doing anything else.
    If IPEX is not available, this function does nothing.

    Failures never propagate: both a failed ``ipex_init()`` result and any
    exception raised along the way are reported with ``print``.
    """
    try:
        # Guard clause: nothing to do unless an XPU was detected at import time.
        if not HAS_XPU:
            return
        # Project-local import, deferred so it is only loaded when XPU is present.
        from library.ipex import ipex_init

        ok, reason = ipex_init()
        if not ok:
            print("failed to initialize ipex:", reason)
    except Exception as exc:
        print("failed to initialize ipex:", exc)