-
Notifications
You must be signed in to change notification settings - Fork 25
/
utils.py
48 lines (39 loc) · 1.29 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import numpy as np
from PIL import Image
def pdist_torch(emb1, emb2):
    """Compute the pairwise Euclidean distance matrix between two embedding sets.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, evaluated with
    torch ops on whatever device the input tensors already live on.

    Args:
        emb1: 2-D tensor of shape (m, d).
        emb2: 2-D tensor of shape (n, d).

    Returns:
        Tensor of shape (m, n) whose entry (i, j) is the distance between
        ``emb1[i]`` and ``emb2[j]``. Squared distances are clamped to at least
        1e-12 before the square root, so identical rows give 1e-6, not 0.
    """
    m, n = emb1.shape[0], emb2.shape[0]
    # (m, 1) -> (m, n): squared norm of each emb1 row, repeated across columns.
    emb1_pow = torch.pow(emb1, 2).sum(dim=1, keepdim=True).expand(m, n)
    # (n, 1) -> (n, m) -> (m, n): squared norm of each emb2 row, repeated across rows.
    emb2_pow = torch.pow(emb2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    # The sum materialises a fresh tensor, so the in-place addmm_ below is safe.
    dist_mtx = emb1_pow + emb2_pow
    # dist_mtx = 1 * dist_mtx + (-2) * (emb1 @ emb2^T). Keyword beta/alpha
    # replace the deprecated positional addmm_(beta, alpha, mat1, mat2) overload.
    dist_mtx = dist_mtx.addmm_(emb1, emb2.t(), beta=1, alpha=-2)
    # Rounding can leave slightly negative squared distances; clamping also
    # presumably avoids sqrt(0), whose gradient is infinite.
    dist_mtx = dist_mtx.clamp(min=1e-12).sqrt()
    return dist_mtx
def pdist_np(emb1, emb2):
    """Compute the pairwise Euclidean distance matrix between two embedding sets.

    NumPy (CPU) version using the expansion
    ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b.

    Args:
        emb1: array-like of shape (m, d); converted with ``np.asarray``.
        emb2: array-like of shape (n, d); converted with ``np.asarray``.

    Returns:
        ndarray of shape (m, n) whose entry (i, j) is the distance between
        ``emb1[i]`` and ``emb2[j]``. Squared distances are clamped to at least
        1e-12 before the square root, so identical rows give 1e-6, not 0.
    """
    # Accept lists/tuples as well as ndarrays (asarray is a no-op on ndarrays).
    emb1 = np.asarray(emb1)
    emb2 = np.asarray(emb2)
    # Column vector (m, 1) and row vector (1, n) broadcast together to (m, n).
    emb1_pow = np.square(emb1).sum(axis=1)[..., np.newaxis]
    emb2_pow = np.square(emb2).sum(axis=1)[np.newaxis, ...]
    dist_mtx = -2 * np.matmul(emb1, emb2.T) + emb1_pow + emb2_pow
    # Cancellation in the expansion can produce tiny negatives; clamp before sqrt.
    dist_mtx = np.sqrt(dist_mtx.clip(min=1e-12))
    return dist_mtx
if __name__ == "__main__":
    # Quick demo: run both implementations on the same inputs and compare.
    a = np.arange(4 * 128).reshape(4, 128)
    b = np.arange(10, 10 + 5 * 128).reshape(5, 128)
    r1 = pdist_np(a, b)
    print(r1.shape)
    print(r1)
    # torch.tensor with an explicit dtype replaces the legacy torch.Tensor(...)
    # constructor; float32 matches what torch.Tensor yields under the default dtype.
    a_t = torch.tensor(a, dtype=torch.float32)
    b_t = torch.tensor(b, dtype=torch.float32)
    r2 = pdist_torch(a_t, b_t)
    print(r2.shape)
    print(r2)
    # Largest disagreement; nonzero only from float32 vs float64 rounding.
    print(np.abs(r1 - r2.numpy()).max())