# Source code for lmflow.utils.test_utils
from collections import OrderedDict
from typing import Dict, List, Tuple, Union, Optional, TYPE_CHECKING
import torch
import torch.nn as nn
import numpy as np
# NOTE: a stray Sphinx "[docs]" link marker was here; it is not code.
def compare_model(
    model_ref: "nn.Module",
    model_trained: "nn.Module",
    module_trained: Optional[List[str]] = None,
    *,
    rtol: float = 1e-4,
    atol: float = 1e-5,
) -> None:
    """Assert that exactly the expected state-dict entries differ between two models.

    Both models' ``state_dict()`` must have the same set of keys. When
    ``module_trained`` is given, every entry whose name contains one of its
    substrings must differ from the reference (``torch.allclose`` is False),
    and every other entry must match the reference (``torch.allclose`` is True).

    Args:
        model_ref: Reference model.
        model_trained: Model compared against ``model_ref``.
        module_trained: Substrings of state-dict names that are expected to
            have changed. If ``None``, only the key sets are compared. An empty
            list means every entry must be unchanged.
        rtol: Relative tolerance passed to ``torch.allclose``.
        atol: Absolute tolerance passed to ``torch.allclose``.

    Raises:
        AssertionError: If the key sets differ, an expected-trained entry is
            unchanged, or an entry outside ``module_trained`` has changed.
    """
    state_dict_ref = model_ref.state_dict()
    state_dict_trained = model_trained.state_dict()

    ref_keys = set(state_dict_ref)
    trained_keys = set(state_dict_trained)
    assert ref_keys == trained_keys, (
        f"state_dict keys differ: only in ref={sorted(ref_keys - trained_keys)}, "
        f"only in trained={sorted(trained_keys - ref_keys)}"
    )

    # Without a list of trained modules there is no per-entry expectation to check.
    if module_trained is None:
        return

    for name, ref_tensor in state_dict_ref.items():
        unchanged = torch.allclose(
            ref_tensor, state_dict_trained[name], rtol=rtol, atol=atol
        )
        # Substring match: e.g. "lora" selects every entry whose name contains it.
        if any(module in name for module in module_trained):
            assert not unchanged, f"{name!r} was expected to change but did not"
        else:
            assert unchanged, f"{name!r} was expected to stay unchanged but changed"