Source code for lmflow.utils.envs

"""
ref: https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_device.py
"""

import os
import logging

import torch


[docs] logger = logging.getLogger(__name__)
[docs] is_cuda_available = torch.cuda.is_available()
[docs] def is_accelerate_env(): for key, _ in os.environ.items(): if key.startswith("ACCELERATE_"): return True return False
[docs] def get_device_name() -> str: """ Get the device name based on the current machine. """ if is_cuda_available: device = "cuda" else: device = "cpu" return device
[docs] def get_torch_device() -> any: """Return the corresponding torch attribute based on the device type string. Returns: module: The corresponding torch device namespace, or torch.cuda if not found. """ device_name = get_device_name() try: return getattr(torch, device_name) except AttributeError: logger.warning(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.") return torch.cuda