-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluation.py
71 lines (59 loc) · 2.29 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
This module provides a function that computes the bits per character (BPC)
of an RNN-based model on a string of characters, using cross-entropy loss.
Functions:
- compute_bpc(model: object, string: str) -> float:
Given an RNN-based model and a string of characters, computes the model's
bits per character (BPC) on that string using cross-entropy loss.
Only full windows of CHUNK_LEN + 1 characters are scored, so the input
string must be at least that long.
Usage:
from evaluation import compute_bpc
# Example usage with a trained model and a held-out text
model = MyRNNModel()
input_string = held_out_text  # at least CHUNK_LEN + 1 characters
bpc_result = compute_bpc(model, input_string)
print(f"Bits per character for the given string: {bpc_result}")
"""
import math
import torch
from utils import char_tensor, CHUNK_LEN
def compute_bpc(model: object, string: str) -> float:
    """
    Compute the mean bits per character (BPC) of ``model`` on ``string``.

    The string is walked in steps of ``CHUNK_LEN``, taking windows of
    ``CHUNK_LEN + 1`` characters. In each window the model is fed the
    characters one at a time and scored (cross-entropy) on predicting the
    next character. The recurrent state is reset at the start of every
    window. Windows that do not yield exactly ``CHUNK_LEN`` targets (the
    trailing partial window) are skipped.

    Note: this function does not switch the model to eval mode. Callers
    should do so themselves if the model uses dropout or similar layers.

    Args:
        model: RNN-based model. It must provide ``init_hidden()`` returning
            a ``(hidden, cell)`` pair and be callable as
            ``model(inp, hidden, cell) -> (output, (hidden, cell))``.
        string: text to evaluate.

    Returns:
        Mean BPC over all scored windows.

    Raises:
        ValueError: if ``string`` is too short to form a single full window
            of ``CHUNK_LEN + 1`` characters.
    """
    criterion = torch.nn.CrossEntropyLoss()
    bpc_losses = []
    # Evaluation only: disable autograd once for the whole pass instead of
    # re-entering the context on every character step.
    with torch.no_grad():
        for start in range(0, len(string) - 1, CHUNK_LEN):
            chunk = string[start:start + CHUNK_LEN + 1]
            # Shape (1, len): add a leading batch dimension.
            inp: torch.Tensor = char_tensor(chunk[:-1]).unsqueeze(0)
            target: torch.Tensor = char_tensor(chunk[1:]).unsqueeze(0)
            # Skip windows that do not have exactly CHUNK_LEN targets. The
            # inner loop below indexes CHUNK_LEN positions. The length is
            # compared against CHUNK_LEN rather than a hard-coded constant.
            if target.size(1) != CHUNK_LEN:
                continue

            hidden, cell = model.init_hidden()
            loss = 0
            for c in range(CHUNK_LEN):
                output, (hidden, cell) = model(inp[:, c], hidden, cell)
                loss += criterion(output, target[:, c].view(1))
            # .item() extracts the Python float from the accumulated tensor.
            mean_nats = loss.item() / CHUNK_LEN
            # CrossEntropyLoss is in nats (natural log); divide by ln 2 for bits.
            bpc = mean_nats / math.log(2)
            bpc_losses.append(bpc)

            if len(bpc_losses) % 1500 == 0:
                print(f"Number of iterations run for BPC calc : {len(bpc_losses)}")

    if not bpc_losses:
        raise ValueError(
            f"string has {len(string)} characters; at least {CHUNK_LEN + 1} "
            "are needed to score one full window"
        )

    avg_bpc = sum(bpc_losses) / len(bpc_losses)
    # Report the averaged value that is returned, not the last window's BPC.
    print(f"Total number of iterations : {len(bpc_losses)} BPC : {avg_bpc}")
    return avg_bpc