-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathdebug.py
52 lines (38 loc) · 1.44 KB
/
debug.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
def any_nan(tensor: torch.Tensor) -> bool:
    """Tell whether at least one element of ``tensor`` is NaN.

    Args:
        tensor (torch.Tensor): tensor to inspect

    Returns:
        bool: ``True`` when some element is NaN, ``False`` otherwise
            (an empty tensor therefore yields ``False``).
    """
    # Element-wise NaN mask, then collapse to a single Python bool.
    nan_mask = tensor.isnan()
    return bool(nan_mask.any())
def print_min_max(name, tensor):
    """Print a one-line summary of a tensor: min, max, NaN presence and shape.

    Args:
        name (str): label printed at the start of the line
        tensor (torch.Tensor): the tensor to summarize

    Empty tensors are supported. ``tensor.min()`` / ``tensor.max()`` raise on a
    zero-element tensor when no ``dim`` is given, so ``n/a`` is printed for
    min and max in that case instead of crashing the debug helper.
    """
    if tensor.numel() == 0:
        # A full reduction over zero elements has no defined min/max and raises.
        lo = hi = "n/a"
    else:
        lo, hi = tensor.min(), tensor.max()
    print(
        f"{name} | min {lo} | max {hi} | hasnan {any_nan(tensor)} | shape {tensor.shape}"
    )
def assert_allclose(tensor, value, tol=1e-5, message=""):
    """Assert that every element of ``tensor`` lies strictly within ``tol`` of ``value``.

    Args:
        tensor (torch.Tensor): the tensor to check
        value: reference value(s), i.e. anything that can be subtracted from
            ``tensor`` (a scalar, or a tensor broadcastable against it)
        tol (float, optional): Defaults to 1e-5. Absolute tolerance; the
            comparison is strict (``|tensor - value| < tol``), so NaN
            deviations always fail.
        message (str, optional): Defaults to "". Message attached to the
            ``AssertionError`` on failure.

    Raises:
        AssertionError: if any element is out of tolerance. Because this is a
            plain ``assert``, the check is skipped under ``python -O``.
    """
    deviation = (tensor - value).abs()
    close_enough = deviation < tol
    assert close_enough.all(), message
def assert_proba_distribution(probabilities, tol=1e-5):
    """Check that the tensor is a probability distribution.

    The check is global. The sum is taken over every element of the tensor,
    so a batch of distributions (e.g. one per row) only passes if the whole
    tensor sums to 1.

    Args:
        probabilities (torch.Tensor): the distribution
        tol (float, optional): Defaults to 1e-5. Absolute tolerance on the
            deviation of the total sum from 1 (strict comparison).

    Raises:
        AssertionError: if the total deviates from 1 by ``tol`` or more, or if
            any entry is negative. Skipped under ``python -O`` since it is a
            plain ``assert``.
    """
    # Sum once; reused by both the check and the failure message.
    total = probabilities.sum()
    assert (total - 1.0).abs() < tol and (probabilities >= 0).all(), (
        "tensor was expected to be a probability distribution "
        "(sum={}, negatives={})".format(total, (probabilities < 0).any())
    )