Source code for tensor
from typing import List, Tuple
import torch
[docs]
def tensor_all_zero(tensor: torch.Tensor) -> bool:
    """Report whether every element of ``tensor`` equals zero.

    The input is compared with ``torch.equal`` against a zero tensor
    built by ``torch.zeros_like(tensor)``.

    Args:
        tensor: The tensor to inspect.

    Returns:
        ``True`` if ``tensor`` compares equal to its zero-filled
        counterpart, ``False`` otherwise.
    """
    zero_reference = torch.zeros_like(tensor)
    return torch.equal(tensor, zero_reference)
[docs]
def find_nonzero_intervals(tensor: torch.Tensor, ignore_distance: int = 1) -> List[Tuple[int, int]]:
    """Group the nonzero positions of a 1D tensor into index intervals.

    Consecutive nonzero positions are merged into the same interval as long
    as the gap between their indices does not exceed ``ignore_distance``.
    With the default of ``1`` only directly adjacent nonzero elements are
    merged.

    Args:
        tensor: A 1D tensor.
        ignore_distance: Largest index gap between two neighbouring nonzero
            elements that still keeps them in the same interval. Must be
            greater than 0 and less than ``tensor.numel()``.

    Returns:
        A list of ``(start, end)`` tuples of Python ints, in ascending order,
        where both ends are inclusive indices of nonzero elements. The list
        is empty when ``tensor`` has no nonzero element.

    Raises:
        AssertionError: If ``tensor`` is not 1D or ``ignore_distance`` is
            outside the accepted range.
    """
    # Validation is kept as assertions so callers see the same exception type
    # and messages as before.
    assert tensor.shape == torch.Size([tensor.numel()]), "Input tensor must be 1D"
    assert ignore_distance > 0, "`ignore_distance` must be greater than 0"
    assert ignore_distance < tensor.numel(), "`ignore_distance` must be less than the number of elements in the tensor"

    # `as_tuple=True` always yields a 1D index tensor. The previous
    # `.squeeze()` produced a 0-d tensor when exactly one element was
    # nonzero, which made the slicing below fail.
    nonzero_indices = torch.nonzero(tensor, as_tuple=True)[0]
    if nonzero_indices.numel() == 0:
        # Nothing is nonzero, so there is no interval to report.
        return []

    # Positions `i` where the gap between nonzero_indices[i] and
    # nonzero_indices[i + 1] is too large to bridge: an interval ends at `i`
    # and the next one starts at `i + 1`.
    gaps = nonzero_indices[1:] - nonzero_indices[:-1]
    break_points = torch.where(gaps > ignore_distance)[0]

    # Gather all interval bounds with vectorized indexing instead of calling
    # `.item()` once per break point.
    starts = [nonzero_indices[0].item()] + nonzero_indices[break_points + 1].tolist()
    ends = nonzero_indices[break_points].tolist() + [nonzero_indices[-1].item()]
    return list(zip(starts, ends))