-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
120 lines (90 loc) · 3.2 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
This module provides utility functions for text generation using RNN-based models.
Functions:
- load_dataset(path: str) -> str:
Load the text file located at the specified path, transliterate it to ASCII with unidecode, and return its content.
- random_chunk(file_path: str = TRAIN_PATH) -> str:
Generate a random chunk of text from the loaded dataset.
- char_tensor(strings: str) -> torch.autograd.Variable:
Convert a string of characters into a PyTorch tensor.
- random_training_set(file_path: str = TRAIN_PATH) -> Tuple[torch.autograd.Variable, torch.autograd.Variable]:
Generate a random training set consisting of input and target tensors.
- time_since(since: float) -> str:
Return the time elapsed since a given timestamp, formatted as a string such as '2m 5s'.
Constants:
- CHUNK_LEN: int
The length of chunks used for training and generation.
- TRAIN_PATH: str
The path to the training dataset file.
"""
import functools
import math
import random
# import re
import string
import time

import torch
import unidecode
from torch.autograd import Variable
# Length of each training sequence. random_chunk slices CHUNK_LEN + 1
# characters so that the shifted input/target pair each hold CHUNK_LEN.
CHUNK_LEN: int = 200

# Default corpus file read by random_chunk and random_training_set.
TRAIN_PATH: str = './data/dickens_train.txt'
def load_dataset(path, encoding='utf-8'):
    """
    Read a text file and return its content transliterated by unidecode.

    The file is decoded with ``encoding`` and the resulting text is passed
    through ``unidecode.unidecode`` before being returned.

    Args:
        path (str): The path to the text file.
        encoding (str): Codec used to decode the file. Defaults to 'utf-8'
            so the result does not depend on the platform's locale.

    Returns:
        str: The transliterated content of the text file.
    """
    # ``with`` closes the handle deterministically, even if read() raises.
    with open(path, 'r', encoding=encoding) as handle:
        raw_text = handle.read()
    return unidecode.unidecode(raw_text)
@functools.lru_cache(maxsize=8)
def _cached_dataset(file_path):
    """Return ``load_dataset(file_path)``, memoized per path.

    Sampling used to re-read and re-transliterate the entire file on every
    call; the returned string is immutable, so caching it is safe. Note that
    edits made to the file after its first read are not picked up.
    """
    return load_dataset(file_path)


def random_chunk(file_path=TRAIN_PATH, chunk_len=CHUNK_LEN):
    """
    Return a random contiguous slice of ``chunk_len + 1`` characters.

    The extra character lets callers such as random_training_set build an
    input/target pair offset by one position.

    Args:
        file_path (str): The path to the training dataset file. Default is TRAIN_PATH.
        chunk_len (int): Sequence length of the chunk. Default is CHUNK_LEN.

    Returns:
        str: A random chunk of ``chunk_len + 1`` characters.

    Raises:
        ValueError: If the dataset holds fewer than ``chunk_len + 1`` characters.
    """
    text = _cached_dataset(file_path)
    span = chunk_len + 1
    if len(text) < span:
        raise ValueError(
            f"dataset {file_path!r} has {len(text)} characters; "
            f"need at least {span} for a chunk of length {chunk_len}"
        )
    # randint is inclusive, so the last valid start still yields a full span.
    start_index = random.randint(0, len(text) - span)
    return text[start_index:start_index + span]
def char_tensor(strings):
    """
    Encode a string as a 1-D tensor of indices into ``string.printable``.

    Args:
        strings (str): The input string.

    Returns:
        torch.autograd.Variable: A 1-D ``torch.long`` tensor of length
        ``len(strings)`` whose i-th entry is the index of ``strings[i]``
        in ``string.printable``.

    Raises:
        ValueError: If ``strings`` contains a character that is not in
            ``string.printable``.
    """
    # string.printable has no repeated characters, so this mapping agrees
    # with str.index while giving O(1) lookups.
    char_to_index = {ch: i for i, ch in enumerate(string.printable)}
    try:
        indices = [char_to_index[ch] for ch in strings]
    except KeyError as err:
        raise ValueError(
            f"character {err.args[0]!r} is not in string.printable"
        ) from None
    return Variable(torch.tensor(indices, dtype=torch.long))
def random_training_set(file_path=TRAIN_PATH):
    """
    Sample one next-character-prediction example from the dataset.

    A random chunk is drawn; the input is every character but the last and
    the target is every character but the first, so ``target[i]`` is the
    character that follows ``input[i]``.

    Args:
        file_path (str): The path to the training dataset file. Default is TRAIN_PATH.

    Returns:
        Tuple[torch.autograd.Variable, torch.autograd.Variable]: The input and
        target index tensors, each with a leading batch dimension of size 1.
    """
    sample = random_chunk(file_path=file_path)
    input_seq = char_tensor(sample[:-1]).unsqueeze(0)
    target_seq = char_tensor(sample[1:]).unsqueeze(0)
    return input_seq, target_seq
def time_since(since):
    """
    Format the wall-clock time elapsed since ``since`` as minutes and seconds.

    Args:
        since (float): A start timestamp on the ``time.time()`` clock.

    Returns:
        str: The elapsed time as ``'<minutes>m <seconds>s'``, with the
        seconds part truncated to a whole number (e.g. ``'2m 5s'``).
    """
    elapsed = time.time() - since
    minutes = math.floor(elapsed / 60)
    leftover_seconds = elapsed - minutes * 60
    return '%dm %ds' % (minutes, leftover_seconds)