Commit

Initializes repository
syaffers committed Jan 17, 2019
0 parents · commit 1f24a64
Showing 5 changed files with 527 additions and 0 deletions.
121 changes: 121 additions & 0 deletions .gitignore
@@ -0,0 +1,121 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# VS Code
.vscode

# Trained model
model.pt
131 changes: 131 additions & 0 deletions datasets.py
@@ -0,0 +1,131 @@
import os
import torch
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset

class TESNamesDataset(Dataset):
    ''' The Elder Scrolls Names dataset class.

    The Elder Scrolls Names dataset is a dataset of first names of all 10
    major races in Tamriel. The dataset contains male and female first names
    of different lengths and characters, organized into folders of races and
    files of genders.
    '''
    def __init__(self, data_root, charset, max_length):
        ''' Initializes the Elder Scrolls Names dataset.

        The initialization appends a terminating character, \0, and therefore
        the passed charset argument should not contain the terminating
        character.

        Parameters
        ----------
        data_root: str
            Absolute path to the root folder of the dataset.
        charset: str
            String of all characters expected to be present in the names.
        max_length: int
            The maximum number of characters in a name, used for
            zero-padding or truncation of names when preprocessing.
        '''
        self.data_root = data_root
        self.charset = charset + '\0'
        self.max_length = max_length
        self.race_codec = LabelEncoder()
        self.gender_codec = LabelEncoder()
        self.char_codec = LabelEncoder()
        self.samples = []
        self._init_dataset()

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        race, gender, name = self.samples[idx]
        return self.one_hot_sample(race, gender, name)

    def _init_dataset(self):
        ''' Dataset initialization subroutine.

        Goes through all the folders in the root directory of the dataset,
        reads all the files in the subfolders, and appends tuples of
        race, gender, and name to the `self.samples` list.
        The label encoders for the races, genders, and characters are also
        fitted here.
        '''
        races = set()
        genders = set()

        for race in os.listdir(self.data_root):
            race_folder = os.path.join(self.data_root, race)
            races.add(race)

            for gender in os.listdir(race_folder):
                gender_filepath = os.path.join(race_folder, gender)
                genders.add(gender)

                with open(gender_filepath, 'r') as gender_file:
                    for name in gender_file.read().splitlines():
                        # Pad short names with '\0' up to max_length;
                        # truncate long ones, keeping the last slot for '\0'.
                        if len(name) > self.max_length:
                            name = name[:self.max_length - 1] + '\0'
                        else:
                            name = name + '\0' * (self.max_length - len(name))
                        self.samples.append((race, gender, name))

        self.race_codec.fit(list(races))
        self.gender_codec.fit(list(genders))
        self.char_codec.fit(list(self.charset))

    def to_one_hot(self, codec, values):
        ''' Encodes a list of nominal values into a one-hot tensor.

        Parameters
        ----------
        codec: sklearn.preprocessing.LabelEncoder
            Scikit-learn label encoder for the list of values.
        values: list of str
            List of values to be converted into numbers.
        '''
        values_idx = codec.transform(values)
        # Indexing an identity matrix by label index yields one-hot rows.
        return torch.eye(len(codec.classes_))[values_idx]

    def one_hot_sample(self, race, gender, name):
        ''' Converts a single sample into its one-hot counterpart.

        Calls the `to_one_hot` function for each of the values in a sample:
        race, gender, and name. The race and gender get converted into a
        1xR tensor and a 1xG tensor, respectively, where R is the number of
        races in the dataset and G is the number of genders in the dataset.
        The name gets converted into an MxC tensor, where M is the maximum
        length of the names (`self.max_length`) and C is the length of the
        character set (after adding the terminating character, \0).

        Parameters
        ----------
        race: str
            The race of the sample.
        gender: str
            The gender of the sample.
        name: str
            The name of the sample.
        '''
        t_race = self.to_one_hot(self.race_codec, [race])
        t_gender = self.to_one_hot(self.gender_codec, [gender])
        t_name = self.to_one_hot(self.char_codec, list(name))
        return t_race, t_gender, t_name


if __name__ == '__main__':
    import string
    from torch.utils.data import DataLoader

    data_root = '/home/syafiq/Data/tes-names/'
    charset = string.ascii_letters + '\'- '
    max_length = 30
    dataset = TESNamesDataset(data_root, charset, max_length)
    print(dataset[100])

    dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
    print(next(iter(dataloader)))
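A quick shape check may help when reading the dataset code above. The snippet below is a sketch, not part of the commit: it assumes the 10 race folders and 2 gender files described in the class docstring, plus the 55-character charset (with the appended terminator) used in the __main__ block.

import string
from torch.utils.data import DataLoader
from datasets import TESNamesDataset

charset = string.ascii_letters + '\'- '  # 55 characters; '\0' is appended
dataset = TESNamesDataset('/home/syafiq/Data/tes-names/', charset, 30)

# one_hot_sample returns one tensor per field.
t_race, t_gender, t_name = dataset[100]
print(t_race.shape)    # torch.Size([1, 10])  - one-hot over 10 races
print(t_gender.shape)  # torch.Size([1, 2])   - one-hot over 2 genders
print(t_name.shape)    # torch.Size([30, 56]) - max_length x len(charset + '\0')

# The default collate function stacks samples on a new batch dimension.
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
b_race, b_gender, b_name = next(iter(dataloader))
print(b_name.shape)    # torch.Size([10, 30, 56])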
100 changes: 100 additions & 0 deletions generator.py
@@ -0,0 +1,100 @@
from collections import deque
import string
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import TESNamesDataset
from models import TESLSTM


def generate(race, gender, char, dataset, model, device):
    ''' Generates a "novel" name given the parameters.

    Given the desired race, gender, and initial character, the trained model
    produces a new name by predicting which letter should come next and
    feeding the predicted letter back as input to the model until it reaches
    the maximum length or the terminating character is predicted.

    Parameters
    ----------
    race: str
        Desired race for the new name.
    gender: str
        Desired gender for the new name.
    char: str
        Starting character of the new name.
    dataset: torch.utils.data.Dataset
        The dataset of Elder Scrolls names.
    model: models.TESLSTM
        The trained model used for prediction.
    device: torch.device
        The device on which to execute.
    '''
    name = char
    model.eval()

    t_race, t_gender, t_char = dataset.one_hot_sample(race, gender, char)
    t_hidden, t_cell = model.init_hidden(1)

    t_race = t_race.view(1, 1, -1).to(device)
    t_gender = t_gender.view(1, 1, -1).to(device)
    t_char = t_char.view(1, 1, -1).to(device)
    t_hidden = t_hidden.to(device)
    t_cell = t_cell.to(device)

    for _ in range(dataset.max_length):
        t_char, t_hidden, t_cell = \
            model(t_race, t_gender, t_char, t_hidden, t_cell)

        # Pick the most likely next character and decode it back to a string.
        char_idx = t_char.argmax(dim=1).item()
        new_char = dataset.char_codec.inverse_transform([char_idx])[0]

        if new_char == '\0':
            break
        else:
            name += new_char
            t_char = dataset.to_one_hot(dataset.char_codec, [new_char])
            t_char = t_char.view(1, 1, -1).to(device)

    return name


if __name__ == '__main__':
    data_root = '/home/syafiq/Data/tes-names/'
    charset = string.ascii_letters + '\'- '
    max_length = 30

    # Prepare GPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Prepare dataset.
    dataset = TESNamesDataset(data_root, charset, max_length)

    input_size = (
        len(dataset.race_codec.classes_) +
        len(dataset.gender_codec.classes_) +
        len(dataset.char_codec.classes_)
    )
    hidden_size = 128
    output_size = len(dataset.char_codec.classes_)

    # Prepare model.
    model = TESLSTM(input_size, hidden_size, output_size)
    model.load_state_dict(torch.load('model.pt'))
    model = model.to(device)

    new_names = []

    # Predict a name for all combinations.
    for race in dataset.race_codec.classes_:
        for gender in dataset.gender_codec.classes_:
            for letter in string.ascii_uppercase:
                name = generate(race, gender, letter, dataset, model, device)
                print(race, gender, name)
                new_names.append(name)

    # See how many generated names are copied from the dataset, if any.
    sample_names = [name.replace('\0', '') for _, _, name in dataset.samples]
    intersection_set = set(new_names).intersection(set(sample_names))
    print('%% of similar names: %.2f%%'
          % (100 * len(intersection_set) / len(new_names)))
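models.py is among the five files in this commit (it is imported by generator.py) but is not rendered in this capture. A minimal sketch consistent with how generator.py calls TESLSTM might look like the following: a constructor taking input_size, hidden_size, and output_size; an init_hidden method returning a hidden/cell pair; and a forward pass taking the three one-hot tensors plus the recurrent state. This is a hypothetical reconstruction, not the committed implementation.

import torch
import torch.nn as nn

class TESLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Race, gender, and character one-hots are concatenated into a
        # single input feature vector of length input_size.
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, t_race, t_gender, t_char, t_hidden, t_cell):
        # Inputs arrive as (seq=1, batch=1, features); concatenate along
        # the feature dimension.
        t_input = torch.cat([t_race, t_gender, t_char], dim=2)
        t_output, (t_hidden, t_cell) = self.lstm(t_input, (t_hidden, t_cell))
        # Drop the sequence dimension so the caller's argmax(dim=1) indexes
        # the character scores directly: shape (1, output_size).
        t_output = self.fc(t_output.squeeze(0))
        return t_output, t_hidden, t_cell

    def init_hidden(self, batch_size):
        return (torch.zeros(1, batch_size, self.hidden_size),
                torch.zeros(1, batch_size, self.hidden_size))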