Source code for lmflow.utils.envs

"""
ref: https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_device.py
"""

import logging
import os
from typing import Any

import torch

# Module-level logger; get_torch_device() uses it to warn when the requested
# device namespace is missing from ``torch``.
logger = logging.getLogger(__name__)

# Explicit public API of this module (what ``from ... import *`` exports).
__all__ = [
    "get_device_name",
    "get_torch_device",
    "is_accelerate_env",
    "require_cuda_for_gpu_mode",
    "set_cuda_device",
]


def is_accelerate_env() -> bool:
    """Return True if any environment variable *name* starts with ``ACCELERATE_``.

    Only the variable names are inspected; their values are ignored, so a
    variable such as ``ACCELERATE_FOO=0`` still makes this return ``True``.

    Returns:
        ``True`` when at least one key of ``os.environ`` begins with the
        ``ACCELERATE_`` prefix, otherwise ``False``.
    """
    # NOTE(review): presumably these variables are exported by an Accelerate
    # launcher; this check relies purely on the name prefix.
    return any(key.startswith("ACCELERATE_") for key in os.environ)
def require_cuda_for_gpu_mode() -> None:
    """Ensure CUDA is usable before anything runs in GPU mode.

    Returns normally (``None``) when ``torch.cuda.is_available()`` is true.

    Raises:
        RuntimeError: If ``torch.cuda.is_available()`` reports that CUDA
            cannot be used; the message tells the user how to proceed.
    """
    if torch.cuda.is_available():
        return  # CUDA present: GPU execution may proceed

    message = (
        "CUDA is not available on this machine, but GPU execution was requested. "
        "Install a CUDA-enabled PyTorch build and run on a GPU, or use CPU-compatible "
        "settings where the pipeline supports them."
    )
    raise RuntimeError(message)
def set_cuda_device(local_rank: int) -> None:
    """Make CUDA device ``local_rank`` the current device for this process.

    Args:
        local_rank: Index of the CUDA device to select; forwarded unchanged to
            ``torch.cuda.set_device``.

    Raises:
        RuntimeError: Propagated from ``require_cuda_for_gpu_mode`` when CUDA
            is not available.
    """
    # Check availability first so callers get the explicit RuntimeError from
    # require_cuda_for_gpu_mode rather than whatever set_device would raise.
    require_cuda_for_gpu_mode()
    cuda_backend = torch.cuda
    cuda_backend.set_device(local_rank)
def get_device_name() -> str:
    """Return the device type this machine should run on.

    Returns:
        ``"cuda"`` when ``torch.cuda.is_available()`` is true, otherwise
        ``"cpu"``.
    """
    return "cuda" if torch.cuda.is_available() else "cpu"
def get_torch_device() -> Any:
    """Return the ``torch.<device_name>`` namespace for the current device.

    The device name is obtained from ``get_device_name()``.

    Returns:
        The attribute of ``torch`` with that name (e.g. ``torch.cuda``). If
        ``torch`` has no attribute with that name, a warning is logged and
        ``torch.cuda`` is returned as a fallback.
    """
    device_name = get_device_name()
    try:
        return getattr(torch, device_name)
    except AttributeError:
        # Pass the value as a lazy %-style argument instead of an f-string, so
        # formatting only happens if the record is actually emitted; the
        # rendered text is unchanged.
        logger.warning(
            "Device namespace '%s' not found in torch, try to load torch.cuda.",
            device_name,
        )
        return torch.cuda