Skip to content

Commit

Permalink
Merge pull request #26 from Filter-Bubble/processor
Browse files Browse the repository at this point in the history
Stanza Processor
  • Loading branch information
Dafne van Kuppevelt authored Jan 28, 2021
2 parents cc3c019 + 38eb53c commit 072717f
Show file tree
Hide file tree
Showing 10 changed files with 283 additions and 30 deletions.
9 changes: 8 additions & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,16 @@ jobs:
- name: Install package
run: |
python setup.py install
- name: Set environment variables
run: |
echo "STANZA_RESOURCES_DIR=$HOME/data" >> $GITHUB_ENV
echo "E2E_HOME=$HOME/data/nl/coref" >> $GITHUB_ENV
- name: Download models
run: |
python -m e2edutch.download
echo "stanza home: $STANZA_RESOURCES_DIR"
echo "e2e home: $E2E_HOME"
python -c 'import stanza; stanza.download("nl")'
python -m e2edutch.download -v
- name: Test with pytest
run: |
pytest --cov=./e2edutch --cov-report=xml
Expand Down
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,25 @@ pip install .

The `setup_all` script downloads the word vector files to the `data` directories. It also builds the application-specific tensorflow kernels.

## Quick start - Stanza

e2edutch can be used as part of a Stanza pipeline.

Coreferences are added similarly to Stanza's entities:
* a ___Document___ has an attribute ___clusters___ that is a List of coreference clusters;
* a coreference cluster is a List of Stanza ___Spans___.

```
import stanza
import e2edutch.stanza
nlp = stanza.Pipeline(lang='nl', processors='tokenize,coref')
doc = nlp('Dit is een test document. Dit document bevat coreferenties.')
print ([[span.text for span in cluster] for cluster in doc.clusters])
```


## Quick start
A pretrained model is available to download:
```
Expand Down
46 changes: 41 additions & 5 deletions e2edutch/coref_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,63 @@

logger = logging.getLogger('e2edutch')

# When running as a Stanza Processor, we will be instantiating the CorefModel
# for each call to the pipeline, making interactive use impossible
# To speed things up, cache the slowest part here.
e2e_cached_embedding = None
e2e_cached_bert_model = None
e2e_cached_bert_tokenizer = None


class CorefModel(object):
def __init__(self, config):
# these are variables in the outer scope
global e2e_cached_embedding
global e2e_cached_bert_tokenizer
global e2e_cached_bert_model

self.config = config

logger.info("Loading context embeddings..")
self.context_embeddings = util.EmbeddingDictionary(
config["context_embeddings"], config['datapath'])
config["context_embeddings"], config['datapath'],
maybe_cache=e2e_cached_embedding)

# cache the embeddings for reuse
e2e_cached_embedding = self.context_embeddings

logger.info("Loading head embeddings..")
self.head_embeddings = util.EmbeddingDictionary(
config["context_embeddings"], config['datapath'],
maybe_cache=self.context_embeddings)
maybe_cache=e2e_cached_embedding)

self.char_embedding_size = config["char_embedding_size"]
self.char_dict = util.load_char_dict(
os.path.join(config['datapath'],
config["char_vocab_path"]))

self.max_span_width = config["max_span_width"]
self.genres = {g: i for i, g in enumerate(config["genres"])}
if config["lm_path"]:
self.lm_file = h5py.File(os.path.join(config['datapath'],
self.config["lm_path"]), "r")
else:
self.lm_file = None

if config["lm_model_name"]:
logger.info("Loading BERT model...")
self.bert_tokenizer, self.bert_model = bert.load_bert(
self.config["lm_model_name"])
if e2e_cached_bert_model and e2e_cached_bert_tokenizer:
# reuse cached version
self.bert_tokenizer = e2e_cached_bert_tokenizer
self.bert_model = e2e_cached_bert_model
else:
# load the model...
self.bert_tokenizer, self.bert_model = bert.load_bert(
self.config["lm_model_name"])

# ...and cache for next time
e2e_cached_bert_tokenizer = self.bert_tokenizer
e2e_cached_bert_model = self.bert_model
else:
self.bert_tokenizer = None
self.bert_model = None
Expand All @@ -53,15 +84,19 @@ def __init__(self, config):

input_props = []
input_props.append((tf.string, [None, None])) # Tokens.

# Context embeddings.
input_props.append(
(tf.float32, [None, None, self.context_embeddings.size]))

# Head embeddings.
input_props.append(
(tf.float32, [None, None, self.head_embeddings.size]))

# LM embeddings.
input_props.append(
(tf.float32, [None, None, self.lm_size, self.lm_layers]))

# Character indices.
input_props.append((tf.int32, [None, None, None]))
input_props.append((tf.int32, [None])) # Text lengths..
Expand Down Expand Up @@ -123,7 +158,8 @@ def restore(self, session):
v for v in tf.global_variables() if "module/" not in v.name]
saver = tf.train.Saver(vars_to_restore)
checkpoint_path = os.path.join(
self.config["log_dir"], "model.max.ckpt")
self.config['log_root'], self.config['log_dir'], "model.max.ckpt")

logger.info("Restoring coref model from {}".format(checkpoint_path))
session.run(tf.global_variables_initializer())
saver.restore(session, checkpoint_path)
Expand Down
40 changes: 35 additions & 5 deletions e2edutch/download.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import os
import requests
import gzip
import zipfile
import argparse
import logging
from tqdm import tqdm
from pathlib import Path
from e2edutch import util

logger = logging.getLogger()


def download_file(url, path):
"""
Expand All @@ -28,11 +31,13 @@ def download_file(url, path):


def download_data(config={}):
data_dir = Path(util.get_data_dir(config))
# Create the data directory if it doesn't exist yet
data_dir = Path(config['datapath'])
logger.info('Downloading to {}'.format(data_dir))
data_dir.mkdir(parents=True, exist_ok=True)

# Download word vectors
logger.info('Download word vectors')
url = "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nl.300.vec.gz"
fname = data_dir / 'fasttext.300.vec'
fname_gz = data_dir / 'fasttext.300.vec.gz'
Expand All @@ -46,29 +51,54 @@ def download_data(config={}):
fout.write(line)
# Remove gz file
fname_gz.unlink()
else:
logger.info('Word vectors file already exists')

# Download e2e dutch model_
logger.info('Download e2e model')
url = "https://surfdrive.surf.nl/files/index.php/s/UnZMyDrBEFunmQZ/download"
fname_zip = data_dir / 'model.zip'
log_dir_name = data_dir / 'final'
if not fname_zip.exists() and not log_dir_name.exists():
model_file = log_dir_name / 'model.max.ckpt.index'
if not fname_zip.exists() and not model_file.exists():
download_file(url, fname_zip)
if not log_dir_name.exists():
if not model_file.exists():
with zipfile.ZipFile(fname_zip, 'r') as zfile:
zfile.extractall(data_dir)
Path(data_dir / 'logs' / 'final').rename(log_dir_name)
Path(data_dir, 'logs').rmdir()
else:
logger.info('E2e model file already exists')

# Download char_dict
logger.info('Download char dict')
url = "https://github.com/Filter-Bubble/e2e-Dutch/raw/v0.2.0/data/char_vocab.dutch.txt"
fname = data_dir / 'char_vocab.dutch.txt'
if not fname.exists():
download_file(url, fname)
else:
logger.info('Char dict file already exists')


def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--datapath', default=None)
parser.add_argument('-v', '--verbose', action='store_true')
return parser


def main():
parser = get_parser()
args = parser.parse_args()
if args.verbose:
# logger.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
# To do: argparse for config file
download_data()
if args.datapath is None:
config = util.initialize_from_env(model_name='final')
else:
config = {'datapath': args.datapath}
download_data(config)


if __name__ == "__main__":
Expand Down
17 changes: 13 additions & 4 deletions e2edutch/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,18 @@


class Predictor(object):
def __init__(self, model_name='final', cfg_file=None, verbose=False):
def __init__(self, model_name='final', config=None, verbose=False):
if verbose:
logger.setLevel(logging.INFO)
self.config = util.initialize_from_env(model_name, cfg_file)

if config:
self.config = config
else:
# if no configuration is provided, try to get a default config.
self.config = util.initialize_from_env(model_name=model_name)

self.session = tf.compat.v1.Session()

try:
self.model = cm.CorefModel(self.config)
self.model.restore(self.session)
Expand Down Expand Up @@ -93,7 +100,6 @@ def main(args=None):
args = parser.parse_args()
if args.verbose:
logger.setLevel(logging.INFO)
# config = util.initialize_from_env(args.config, args.cfg_file)

# Input file in .jsonlines format or .conll.
input_filename = args.input_filename
Expand All @@ -120,7 +126,10 @@ def main(args=None):
docs = [util.create_example(text)]

output_file = args.output_file
predictor = Predictor(args.config, args.cfg_file)

config = util.initialize_from_env(cfg_file=args.cfg_file, model_cfg_file=args.config)
predictor = Predictor(config=config)

sentences = {}
predictions = {}
for example_num, example in enumerate(docs):
Expand Down
Loading

0 comments on commit 072717f

Please sign in to comment.