-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
64 lines (53 loc) · 2.14 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from math import gamma
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.tri as tri
import torch
def plot_pair(img, prediction, yerr=None):
    """Show an image beside a bar chart of its per-class prediction values.

    Args:
        img: Image passed unchanged to ``Axes.imshow``.
        prediction: 1-D sequence of per-class values; one bar is drawn per
            entry (previously exactly three entries were required).
        yerr: Optional error-bar sizes forwarded to ``Axes.bar``.

    Returns:
        The newly created matplotlib ``Figure``.
    """
    # Mosaic layout: panel "a" (image) spans three columns, panel "b"
    # (bar chart) spans one.
    fig, axes = plt.subplot_mosaic(
        """
        aaab
        aaab
        aaab
        """
    )
    axes["a"].imshow(img)
    axes["a"].axis('off')
    # Derive bar positions from the input instead of hard-coding 3 so that
    # predictions with any number of classes can be plotted; the colour
    # list is reused for the first three bars as before.
    positions = range(len(prediction))
    axes["b"].bar(positions, prediction, yerr=yerr, color=["yellow", "green", "red"])
    axes["b"].set_ylim(0, 1)
    return fig
def one_hot_embedding(labels, num_classes=3):
    """Convert class indices to one-hot encoded float vectors.

    Args:
        labels: Class index or indices (int, sequence of ints, or integer
            tensor) used to select rows of an identity matrix.
        num_classes: Length of each one-hot vector.

    Returns:
        A float tensor holding one row of length ``num_classes`` per label
        (a 1-D tensor when ``labels`` is a single int).
    """
    # Create the identity matrix on the same device as a tensor of labels:
    # PyTorch rejects indexing a CPU tensor with indices living on another
    # device (e.g. CUDA), which previously made this helper CPU-only.
    device = labels.device if torch.is_tensor(labels) else None
    return torch.eye(num_classes, device=device)[labels]
# Vertices of the equilateral reference triangle with unit side length.
corners = np.array([[0, 0], [1, 0], [0.5, 0.75**0.5]])
# Area of the reference triangle: 0.5 * base (1) * height (sqrt(0.75)).
# xy2bc divides by this constant; it was previously never defined.
AREA = 0.5 * 1 * 0.75**0.5
triangle = tri.Triangulation(corners[:, 0], corners[:, 1])
# For each corner of the triangle, the pair of other corners
pairs = [corners[np.roll(range(3), -i)[1:]] for i in range(3)]


def tri_area(xy, pair):
    """Return the area of the triangle formed by point ``xy`` and ``pair``.

    Args:
        xy: ``(x, y)`` point in the plane.
        pair: 2x2 array holding the other two vertices, one per row.

    Returns:
        The (non-negative) triangle area.
    """
    (ax, ay), (bx, by) = pair - xy
    # Explicit 2-D cross product: np.cross on 2-element vectors is
    # deprecated since NumPy 2.0. abs() matches the former norm of a scalar.
    return 0.5 * abs(ax * by - ay * bx)
def xy2bc(xy, tol=1.e-4):
    """Convert 2D Cartesian coordinates to barycentric coordinates.

    Coordinate ``i`` is the area of the triangle formed by ``xy`` and the
    two corners other than corner ``i``, divided by the area of the whole
    reference triangle ``corners``.

    Args:
        xy: ``(x, y)`` point in the plane of ``corners``.
        tol: Every coordinate is clipped into ``[tol, 1 - tol]``.

    Returns:
        A length-3 NumPy array of barycentric coordinates.
    """
    # Derive the full triangle's area from ``corners`` so this function does
    # not depend on a separate module constant (``AREA`` was never defined,
    # making every call raise NameError).
    total_area = tri_area(corners[0], corners[1:])
    coords = np.array([tri_area(xy, p) for p in pairs]) / total_area
    # Presumably clipped so that densities such as x ** (alpha - 1) stay
    # finite on the triangle's edges when alpha < 1 — confirm with callers.
    return np.clip(coords, tol, 1.0 - tol)
class Dirichlet:
    """Dirichlet distribution parameterised by concentrations ``alpha``.

    Only the probability density function is provided.
    """

    def __init__(self, alpha):
        """Store ``alpha`` and precompute the normalising constant.

        Args:
            alpha: Sequence of strictly positive concentration parameters.

        Raises:
            ValueError: If any entry of ``alpha`` is not strictly positive.
        """
        from math import exp, lgamma

        self._alpha = np.array(alpha)
        if np.any(self._alpha <= 0):
            raise ValueError("alpha must be strictly positive")
        # Gamma(sum(alpha)) / prod(Gamma(alpha_i)), evaluated in log space:
        # math.gamma overflows once sum(alpha) exceeds ~171 even when the
        # ratio itself is perfectly representable as a float.
        log_coef = lgamma(np.sum(self._alpha)) - sum(lgamma(a) for a in self._alpha)
        self._coef = exp(log_coef)

    def pdf(self, x):
        """Return the density at point ``x``.

        Args:
            x: Point on the simplex, one coordinate per entry of ``alpha``.

        Returns:
            ``coef * prod(x_i ** (alpha_i - 1))``.
        """
        return self._coef * np.prod(
            [xx ** (aa - 1) for xx, aa in zip(x, self._alpha)]
        )
def draw_pdf_contours(dist, nlevels=200, subdiv=8, **kwargs):
    """Draw filled contours of ``dist``'s density over the reference triangle.

    The module-level ``triangle`` is uniformly refined ``subdiv`` times.
    ``dist.pdf`` is then evaluated at the barycentric coordinates (via
    ``xy2bc``) of every refined vertex, and the values are plotted with
    ``plt.tricontourf`` on the current pyplot axes.

    Args:
        dist: Object exposing ``pdf(coords)`` for a barycentric triple.
        nlevels: Number of contour levels passed to ``tricontourf``.
        subdiv: Subdivision level for ``refine_triangulation``.
        **kwargs: Extra keyword arguments forwarded to ``tricontourf``.
    """
    fine_mesh = tri.UniformTriRefiner(triangle).refine_triangulation(subdiv=subdiv)
    densities = [dist.pdf(xy2bc(vertex)) for vertex in zip(fine_mesh.x, fine_mesh.y)]
    plt.tricontourf(fine_mesh, densities, nlevels, cmap='jet', **kwargs)
    # Equal aspect ratio, then fix the view to the triangle's bounding box
    # (height sqrt(0.75)) and hide the axes.
    top = 0.75**0.5
    plt.axis('equal')
    plt.xlim(0, 1)
    plt.ylim(0, top)
    plt.axis('off')