-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 1f24a64
Showing
5 changed files
with
527 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
share/python-wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.nox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# IPython | ||
profile_default/ | ||
ipython_config.py | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# celery beat schedule file | ||
celerybeat-schedule | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
.dmypy.json | ||
dmypy.json | ||
|
||
# Pyre type checker | ||
.pyre/ | ||
|
||
# VS Code | ||
.vscode | ||
|
||
# Trained model | ||
model.pt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
import os | ||
import torch | ||
from sklearn.preprocessing import LabelEncoder | ||
from torch.utils.data import Dataset | ||
|
||
class TESNamesDataset(Dataset):
    r'''The Elder Scrolls Names dataset class.

    The Elder Scrolls Names dataset is a dataset of first names of all 10
    major races in Tamriel. The dataset contains male and female first names
    of different lengths and characters organized into folders of races and
    files of genders.
    '''
    def __init__(self, data_root, charset, max_length):
        r'''Initializes the Elder Scrolls dataset.

        The initialization appends a terminating character, \0, and therefore
        the passed charset argument should not contain the terminating
        character.

        Parameters
        ----------
        data_root: str
            Absolute path to the root folder of the dataset.
        charset: str
            String of all characters expected to be present in the names.
        max_length: int
            The maximum number of characters in a name to be used for
            zero-padding or truncation of names when preprocessing.
        '''
        self.data_root = data_root
        # The terminator \0 is part of the effective character set.
        self.charset = charset + '\0'
        self.max_length = max_length
        self.race_codec = LabelEncoder()
        self.gender_codec = LabelEncoder()
        self.char_codec = LabelEncoder()
        self.samples = []
        self._init_dataset()

    def __len__(self):
        '''Returns the number of (race, gender, name) samples.'''
        return len(self.samples)

    def __getitem__(self, idx):
        '''Returns the one-hot encoded (race, gender, name) tuple at `idx`.'''
        race, gender, name = self.samples[idx]
        return self.one_hot_sample(race, gender, name)

    def _init_dataset(self):
        r'''Dataset initialization subroutine.

        Goes through all the folders in the root directory of the dataset
        and reads all the files in the subfolders and appends tuples of
        race, gender and name into the `self.samples` list.
        The label encoders for the races, genders and characters are also
        initialized here.
        '''
        races = set()
        genders = set()

        for race in os.listdir(self.data_root):
            race_folder = os.path.join(self.data_root, race)
            # Robustness: skip stray files at the root; only folders are races.
            if not os.path.isdir(race_folder):
                continue
            races.add(race)

            for gender in os.listdir(race_folder):
                gender_filepath = os.path.join(race_folder, gender)
                genders.add(gender)

                with open(gender_filepath, 'r') as gender_file:
                    for name in gender_file.read().splitlines():
                        # Bug fix: use >= (was >) so a name of exactly
                        # max_length characters is still truncated to make
                        # room for the terminating \0; previously such names
                        # ended up with no terminator at all.
                        if len(name) >= self.max_length:
                            name = name[:self.max_length - 1] + '\0'
                        else:
                            name = name + '\0' * (self.max_length - len(name))
                        self.samples.append((race, gender, name))

        self.race_codec.fit(list(races))
        self.gender_codec.fit(list(genders))
        self.char_codec.fit(list(self.charset))

    def to_one_hot(self, codec, values):
        '''Encodes a list of nominal values into a one-hot tensor.

        Parameters
        ----------
        codec: sklearn.preprocessing.LabelEncoder
            Scikit-learn label encoder for the list of values.
        values: list of str
            List of values to be converted into numbers.

        Returns
        -------
        torch.Tensor
            A tensor of shape (len(values), len(codec.classes_)).
        '''
        values_idx = codec.transform(values)
        # Row i of the identity matrix is the one-hot vector for class i.
        return torch.eye(len(codec.classes_))[values_idx]

    def one_hot_sample(self, race, gender, name):
        r'''Converts a single sample into its one-hot counterpart.

        Calls the `to_one_hot` function for each of the values in a sample:
        race, gender, and name. The race and gender get converted into
        a 1xR tensor and 1xG tensor, respectively, where R is the number of
        races in the dataset and G is the number of genders in the dataset.
        The name gets converted into an MxC tensor where M is the length of
        the (padded) name and C is the length of the character set (after
        adding the terminating, \0, character).

        Parameters
        ----------
        race: str
            The race of the sample.
        gender: str
            The gender of the sample.
        name: str
            The name of the sample.

        Returns
        -------
        tuple of torch.Tensor
            The (race, gender, name) one-hot tensors.
        '''
        t_race = self.to_one_hot(self.race_codec, [race])
        t_gender = self.to_one_hot(self.gender_codec, [gender])
        t_name = self.to_one_hot(self.char_codec, list(name))
        return t_race, t_gender, t_name
|
||
|
||
if __name__ == '__main__':
    # Smoke test: build the dataset, inspect one sample and one batch.
    import string
    from torch.utils.data import DataLoader

    data_root = '/home/syafiq/Data/tes-names/'
    charset = string.ascii_letters + '\'- '
    max_length = 30

    dataset = TESNamesDataset(data_root, charset, max_length)
    print(dataset[100])

    dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
    batch = iter(dataloader)
    print(next(batch))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
from collections import deque | ||
import string | ||
from torch.utils.data import DataLoader | ||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
from datasets import TESNamesDataset | ||
from models import TESLSTM | ||
|
||
|
||
def generate(race, gender, char, dataset, model, device):
    r'''Generates a "novel" name given the parameters.

    Given the desired race, gender and initial character, the trained model
    will produce a new name by predicting what letter should come next and
    feeding the predicted letter as an input to the model until it reaches
    the maximum length or the terminating character is predicted.

    Parameters
    ----------
    race: str
        Desired race for new name.
    gender: str
        Desired gender for new name.
    char: str
        Starting character of the new name.
    dataset: torch.utils.data.Dataset
        The dataset of Elder Scrolls names.
    model: models.TESLSTM
        The trained model used for prediction.
    device: torch.device
        The device on which to execute.

    Returns
    -------
    str
        The generated name, beginning with `char`.
    '''
    name = char
    model.eval()

    t_race, t_gender, t_char = dataset.one_hot_sample(race, gender, char)
    t_hidden, t_cell = model.init_hidden(1)

    # Add batch and sequence dimensions: (1, 1, num_classes).
    t_race = t_race.view(1, 1, -1).to(device)
    t_gender = t_gender.view(1, 1, -1).to(device)
    t_char = t_char.view(1, 1, -1).to(device)
    t_hidden = t_hidden.to(device)
    t_cell = t_cell.to(device)

    # Bug fix: run inference under no_grad so autograd does not build an
    # (unused) graph for every generated character, wasting memory.
    with torch.no_grad():
        # NOTE(review): the result can reach max_length + 1 characters
        # (initial char plus max_length predictions) — confirm this matches
        # the training-time padding convention.
        for _ in range(dataset.max_length):
            t_char, t_hidden, t_cell = \
                model(t_race, t_gender, t_char, t_hidden, t_cell)

            # Greedy decoding: take the most likely next character.
            char_idx = t_char.argmax(dim=1).item()
            new_char = dataset.char_codec.inverse_transform([char_idx])[0]

            if new_char == '\0':
                # Terminator predicted — the name is complete.
                break
            name += new_char
            # Feed the predicted character back in as the next input.
            t_char = dataset.to_one_hot(dataset.char_codec, [new_char])
            t_char = t_char.view(1, 1, -1).to(device)

    return name
|
||
|
||
if __name__ == '__main__':
    data_root = '/home/syafiq/Data/tes-names/'
    charset = string.ascii_letters + '\'- '
    max_length = 30

    # Prepare GPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Prepare dataset.
    dataset = TESNamesDataset(data_root, charset, max_length)

    # The model consumes the concatenated one-hot encodings of race, gender
    # and character, and predicts a distribution over characters.
    input_size = (
        len(dataset.race_codec.classes_) +
        len(dataset.gender_codec.classes_) +
        len(dataset.char_codec.classes_)
    )
    hidden_size = 128
    output_size = len(dataset.char_codec.classes_)

    # Prepare model.
    model = TESLSTM(input_size, hidden_size, output_size)
    model.load_state_dict(torch.load('model.pt'))
    model = model.to(device)

    new_names = []

    # Predict a name for every (race, gender, initial letter) combination.
    for race in dataset.race_codec.classes_:
        for gender in dataset.gender_codec.classes_:
            for letter in string.ascii_uppercase:
                name = generate(race, gender, letter, dataset, model, device)
                print(race, gender, name)
                new_names.append(name)

    # See how many names are copied from the dataset, if any.
    sample_names = [name.replace('\0', '') for _, _, name in dataset.samples]
    intersection_set = set(new_names).intersection(set(sample_names))
    # Bug fix: multiply by 100 so the value printed with a '%' sign is an
    # actual percentage, not a raw fraction.
    # NOTE(review): the denominator is the dataset size; if the intent is
    # "share of *generated* names that are copies", it should be
    # len(new_names) instead — confirm.
    print('%% of similar names: %.2f%%'
          % (100.0 * len(intersection_set) / len(dataset)))
Oops, something went wrong.