#!/usr/bin/env python
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
import logging
from dataclasses import Field, fields, make_dataclass
from pathlib import Path
from typing import Optional, Union
from lmflow.utils.versioning import get_python_version
[docs]
# Module-level logger named after this module; used by the helpers in this file.
logger = logging.getLogger(__name__)
[docs]
def _is_unset_arg_value(value) -> bool:
    """Return True for values that cannot be expressed as a CLI argument."""
    if value is None:
        return True
    # Empty strings/containers carry no value to pass on the command line.
    # Numeric zero and ``False`` are real values and must NOT be dropped.
    return isinstance(value, (str, list, tuple, dict, set)) and not value


def _to_shell_arg_value(value):
    """Convert a field value to the representation used on the command line."""
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, (list, tuple)):
        # Elements may be non-strings (e.g. a list of ints).
        return ",".join(str(item) for item in value)
    return value


def make_shell_args_from_dataclass(
    dataclass_objects: list,
    format: str = "subprocess",
    skip_default: bool = True,
    ignored_args_list: Optional[list[str]] = None,
) -> Union[str, list[str]]:
    """Return a string or a list of strings that can be used as shell arguments.

    Parameters
    ----------
    dataclass_objects : list
        A list of dataclass objects. When the same field name appears in
        several objects, the value from the later object wins (a warning is
        logged if the values differ).
    format : str, optional
        Return format, can be "shell" or "subprocess", by default "subprocess".
        "shell" returns a single string whose values are shell-quoted;
        "subprocess" returns a flat ``["--key", "value", ...]`` list.
    skip_default : bool, optional
        Whether to skip attributes with default values, by default True.
    ignored_args_list : list[str], optional
        Field names that must never be emitted, by default None.

    Returns
    -------
    Union[str, list[str]]

    Raises
    ------
    ValueError
        If ``format`` is neither "shell" nor "subprocess".
    """
    # Function-scope import: only the "shell" format needs quoting.
    import shlex

    assert isinstance(dataclass_objects, list), "dataclass_objects should be a list of dataclass objects."
    if format not in ("shell", "subprocess"):
        # Fail before doing any work (and before logging any warnings).
        raise ValueError(f"Unknown format: {format}")

    ignored_args = frozenset(ignored_args_list or ())
    all_args = {}
    for dataclass_object in dataclass_objects:
        declared_fields = dataclass_object.__dataclass_fields__
        for k, v in dataclass_object.__dict__.items():
            if k in ignored_args:
                continue
            if k not in declared_fields:
                # skip attributes that were added dynamically
                continue
            if _is_unset_arg_value(v):
                # skip None / empty values; keep 0 and False
                continue
            if skip_default and declared_fields[k].default == v:
                continue

            # Normalise BEFORE comparing so that equal Path/list values coming
            # from two objects are recognised as equal, and so that an
            # overriding value is stored in the same form as a first-seen one.
            normalized = _to_shell_arg_value(v)
            if k in all_args and all_args[k] != normalized:
                logger.warning(f"Found different values for the same key: {k}, using value: {v} instead.")
            all_args[k] = normalized

    if format == "shell":
        # Quote values so spaces or shell metacharacters cannot split or
        # alter the resulting command line.
        return " ".join(f"--{k} {shlex.quote(str(v))}" for k, v in all_args.items())

    final_res = []
    for k, v in all_args.items():
        final_res.extend([f"--{k}", str(v)])
    return final_res
[docs]
def create_copied_dataclass(
    original_dataclass, field_prefix: str, class_prefix: str, new_default: Optional[dict] = None
):
    """Create a copied dataclass with new field names and default values.

    Parameters
    ----------
    original_dataclass : dataclass
        The dataclass (class or instance) whose fields are copied.
    field_prefix : str
        The prefix to add to the **field** names of the copied dataclass.
    class_prefix : str
        The prefix to add to the **class** name of the copied dataclass.
    new_default : dict, optional
        The new default values for the copied dataclass, keyed by the
        **prefixed** field name. When None, the default values of the
        original dataclass are used.

    Returns
    -------
    dataclass
        A new dataclass named ``f"{class_prefix}{original_dataclass.__name__}"``.
    """
    # Function-scope import: these names are not imported at module level.
    # ``field()`` is the public factory; instantiating ``Field`` directly is
    # unsupported and its constructor signature differs between Python versions.
    from dataclasses import MISSING
    from dataclasses import field as make_field

    overrides = new_default or {}
    new_fields = []
    for original_field in fields(original_dataclass):
        new_name = f"{field_prefix}{original_field.name}"
        if new_name in overrides:
            # An explicit default replaces whatever the original had; a field
            # may not carry both a default and a default_factory.
            default, default_factory = overrides[new_name], MISSING
        else:
            default, default_factory = original_field.default, original_field.default_factory

        new_fields.append(
            (
                new_name,
                original_field.type,
                make_field(
                    default=default,
                    default_factory=default_factory,
                    init=original_field.init,
                    repr=original_field.repr,
                    hash=original_field.hash,
                    compare=original_field.compare,
                    metadata=original_field.metadata,
                ),
            )
        )

    return make_dataclass(f"{class_prefix}{original_dataclass.__name__}", new_fields)
[docs]
def remove_dataclass_attr_prefix(data_instance, prefix: str) -> dict:
    """Remove the prefix from the attribute names of a dataclass instance.

    Parameters
    ----------
    data_instance : dataclass
        The dataclass instance whose fields are read.
    prefix : str
        The prefix to remove from the attribute names of the dataclass instance.
        Field names that do not start with ``prefix`` are kept unchanged.

    Returns
    -------
    dict
        Mapping from the un-prefixed field name to the field's current value.
    """
    # ``removeprefix`` strips only when the name really starts with ``prefix``;
    # blind slicing would truncate unrelated field names.
    return {
        field.name.removeprefix(prefix): getattr(data_instance, field.name)
        for field in fields(data_instance)
    }
[docs]
def add_dataclass_attr_prefix(data_instance, prefix: str) -> dict:
    """Build a dict of a dataclass instance's fields with prefixed names.

    Parameters
    ----------
    data_instance : dataclass
        The dataclass instance whose fields are read.
    prefix : str
        The prefix to add to the attribute names of the dataclass instance.

    Returns
    -------
    dict
        Mapping from ``prefix + field name`` to the field's current value.
    """
    return {
        f"{prefix}{spec.name}": getattr(data_instance, spec.name)
        for spec in fields(data_instance)
    }
[docs]
def print_banner(message: str):
    """Log ``message`` at INFO level, framed above and below by a ``#`` border.

    Parameters
    ----------
    message : str
        The text to display inside the banner.
    """
    length = len(message) + 8
    border = "#" * length
    # Centre the message between the two edge characters so the middle row is
    # exactly as wide as the border rows.
    inner_width = length - 2
    logger.info(border)
    logger.info(f"#{message:^{inner_width}}#")
    logger.info(border)