diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index ca77f29..7dc6f8a 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -33,9 +33,16 @@ jobs: - name: Install package run: | python setup.py install + - name: Set environment variables + run: | + echo "STANZA_RESOURCES_DIR=$HOME/data" >> $GITHUB_ENV + echo "E2E_HOME=$HOME/data/nl/coref" >> $GITHUB_ENV - name: Download models run: | - python -m e2edutch.download + echo "stanza home: $STANZA_RESOURCES_DIR" + echo "e2e home: $E2E_HOME" + python -c 'import stanza; stanza.download("nl")' + python -m e2edutch.download -v - name: Test with pytest run: | pytest --cov=./e2edutch --cov-report=xml diff --git a/README.md b/README.md index 635f777..c567c9c 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,25 @@ pip install . The `setup_all` script downloads the word vector files to the `data` directories. It also builds the application-specific tensorflow kernels. +## Quick start - Stanza + +e2edutch can be used as part of a Stanza pipeline. + +Coreferences are added similarly to Stanza's entities: + * a ___Document___ has an attribute ___clusters___ that is a List of coreference clusters; + * a coreference cluster is a List of Stanza ___Spans___. + +``` +import stanza +import e2edutch.stanza + +nlp = stanza.Pipeline(lang='nl', processors='tokenize,coref') + +doc = nlp('Dit is een test document. 
Dit document bevat coreferenties.') +print ([[span.text for span in cluster] for cluster in doc.clusters]) +``` + + ## Quick start A pretrained model is available to download: ``` diff --git a/e2edutch/coref_model.py b/e2edutch/coref_model.py index 68b14a4..865a0e5 100644 --- a/e2edutch/coref_model.py +++ b/e2edutch/coref_model.py @@ -18,21 +18,41 @@ logger = logging.getLogger('e2edutch') +# When running as a Stanza Processor, we will be instantiating the CorefModel +# for each call to the pipeline, making interactive use impossible +# To speed things up, cache the slowest part here. +e2e_cached_embedding = None +e2e_cached_bert_model = None +e2e_cached_bert_tokenizer = None + class CorefModel(object): def __init__(self, config): + # these are variables in the outer scope + global e2e_cached_embedding + global e2e_cached_bert_tokenizer + global e2e_cached_bert_model + self.config = config + logger.info("Loading context embeddings..") self.context_embeddings = util.EmbeddingDictionary( - config["context_embeddings"], config['datapath']) + config["context_embeddings"], config['datapath'], + maybe_cache=e2e_cached_embedding) + + # cache the embeddings for reuse + e2e_cached_embedding = self.context_embeddings + logger.info("Loading head embeddings..") self.head_embeddings = util.EmbeddingDictionary( config["context_embeddings"], config['datapath'], - maybe_cache=self.context_embeddings) + maybe_cache=e2e_cached_embedding) + self.char_embedding_size = config["char_embedding_size"] self.char_dict = util.load_char_dict( os.path.join(config['datapath'], config["char_vocab_path"])) + self.max_span_width = config["max_span_width"] self.genres = {g: i for i, g in enumerate(config["genres"])} if config["lm_path"]: @@ -40,10 +60,21 @@ def __init__(self, config): self.config["lm_path"]), "r") else: self.lm_file = None + if config["lm_model_name"]: logger.info("Loading BERT model...") - self.bert_tokenizer, self.bert_model = bert.load_bert( - self.config["lm_model_name"]) + if 
e2e_cached_bert_model and e2e_cached_bert_tokenizer: + # reuse cached version + self.bert_tokenizer = e2e_cached_bert_tokenizer + self.bert_model = e2e_cached_bert_model + else: + # load the model... + self.bert_tokenizer, self.bert_model = bert.load_bert( + self.config["lm_model_name"]) + + # ...and cache for next time + e2e_cached_bert_tokenizer = self.bert_tokenizer + e2e_cached_bert_model = self.bert_model else: self.bert_tokenizer = None self.bert_model = None @@ -53,15 +84,19 @@ def __init__(self, config): input_props = [] input_props.append((tf.string, [None, None])) # Tokens. + # Context embeddings. input_props.append( (tf.float32, [None, None, self.context_embeddings.size])) + # Head embeddings. input_props.append( (tf.float32, [None, None, self.head_embeddings.size])) + # LM embeddings. input_props.append( (tf.float32, [None, None, self.lm_size, self.lm_layers])) + # Character indices. input_props.append((tf.int32, [None, None, None])) input_props.append((tf.int32, [None])) # Text lengths.. 
@@ -123,7 +158,8 @@ def restore(self, session): v for v in tf.global_variables() if "module/" not in v.name] saver = tf.train.Saver(vars_to_restore) checkpoint_path = os.path.join( - self.config["log_dir"], "model.max.ckpt") + self.config['log_root'], self.config['log_dir'], "model.max.ckpt") + logger.info("Restoring coref model from {}".format(checkpoint_path)) session.run(tf.global_variables_initializer()) saver.restore(session, checkpoint_path) diff --git a/e2edutch/download.py b/e2edutch/download.py index bb81d92..6cf5c96 100644 --- a/e2edutch/download.py +++ b/e2edutch/download.py @@ -1,11 +1,14 @@ -import os import requests import gzip import zipfile +import argparse +import logging from tqdm import tqdm from pathlib import Path from e2edutch import util +logger = logging.getLogger() + def download_file(url, path): """ @@ -28,11 +31,13 @@ def download_file(url, path): def download_data(config={}): - data_dir = Path(util.get_data_dir(config)) # Create the data directory if it doesn't exist yet + data_dir = Path(config['datapath']) + logger.info('Downloading to {}'.format(data_dir)) data_dir.mkdir(parents=True, exist_ok=True) # Download word vectors + logger.info('Download word vectors') url = "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nl.300.vec.gz" fname = data_dir / 'fasttext.300.vec' fname_gz = data_dir / 'fasttext.300.vec.gz' @@ -46,29 +51,54 @@ def download_data(config={}): fout.write(line) # Remove gz file fname_gz.unlink() + else: + logger.info('Word vectors file already exists') # Download e2e dutch model_ + logger.info('Download e2e model') url = "https://surfdrive.surf.nl/files/index.php/s/UnZMyDrBEFunmQZ/download" fname_zip = data_dir / 'model.zip' log_dir_name = data_dir / 'final' - if not fname_zip.exists() and not log_dir_name.exists(): + model_file = log_dir_name / 'model.max.ckpt.index' + if not fname_zip.exists() and not model_file.exists(): download_file(url, fname_zip) - if not log_dir_name.exists(): + if not 
def get_parser():
    """Build the command-line parser for the download script.

    Returns:
        argparse.ArgumentParser: parser accepting ``-d/--datapath`` (target
        directory, default ``None``) and ``-v/--verbose`` (boolean flag).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--datapath', default=None)
    parser.add_argument('-v', '--verbose', action='store_true')
    return parser


def main(args=None):
    """Download the e2e-Dutch data files, driven by command-line arguments.

    Args:
        args: optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, so existing no-argument callers
            keep working.
    """
    parsed = get_parser().parse_args(args)
    if parsed.verbose:
        logging.basicConfig(level=logging.INFO)

    if parsed.datapath is None:
        # No explicit target directory: take the datapath from the default
        # configuration built by util.initialize_from_env.
        config = util.initialize_from_env(model_name='final')
    else:
        config = {'datapath': parsed.datapath}
    download_data(config)
import bisect
import itertools
import logging
import os
from pathlib import Path

import stanza
import tensorflow.compat.v1 as tf
from stanza.models.common.doc import Document, Span
from stanza.pipeline.processor import Processor, register_processor

from e2edutch import coref_model as cm
from e2edutch import util
from e2edutch.download import download_data
from e2edutch.predict import Predictor

# NOTE: configuring level and handler at import time is kept from the
# original module so that existing users see the same log output.
logger = logging.getLogger('e2edutch')
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())


# Add a `clusters` property to documents as a List of List of Span:
# clusters is a List of cluster, a cluster is a List of Span.
def clusterGetter(self):
    """Return the document's coreference clusters, or ``[]`` if none were set.

    A fresh empty list is returned for documents without clusters, so no
    mutable default is shared between documents.
    """
    clusters = getattr(self, '_clusters', None)
    return [] if clusters is None else clusters


def clusterSetter(self, value):
    """Store *value* as the document's clusters if it is a list.

    Any other type is reported through the module logger and the previously
    stored clusters are left untouched.
    """
    if isinstance(value, list):
        self._clusters = value
    else:
        logger.error('Clusters must be a List')


# The default used to be the string '[]', which made `doc.clusters` a str for
# documents that never went through the coref processor.
Document.add_property('clusters', default=None,
                      getter=clusterGetter, setter=clusterSetter)


@register_processor('coref')
class CorefProcessor(Processor):
    """Stanza processor that adds e2e-Dutch coreference clusters to a Document."""
    _requires = set(['tokenize'])
    _provides = set(['coref'])

    def __init__(self, config, pipeline, use_gpu):
        """Prepare the e2e-Dutch configuration and data files.

        Args:
            config: Stanza processor config; only ``config['model_path']`` is
                read here.
            pipeline: the owning Stanza pipeline (unused).
            use_gpu: Stanza's GPU setting (currently ignored).
        """
        # TODO: make e2edutch follow Stanza's GPU setting. For now the GPU
        # environment variable read by util.initialize_from_env decides.
        self.e2econfig = util.initialize_from_env(model_name='final')

        # Override datapath and log_root: store the e2e data with the Stanza
        # resources, i.e. in the directory that contains config['model_path'].
        resources_dir = Path(config['model_path']).parent
        self.e2econfig['datapath'] = resources_dir
        self.e2econfig['log_root'] = resources_dir

        # Download data files if not present
        download_data(self.e2econfig)

        # Start and stop a session to cache all models
        predictor = Predictor(config=self.e2econfig)
        predictor.end_session()

    def _set_up_model(self, *args):
        """Do nothing: a Predictor is created for every call to process()."""
        # NOTE(review): presumably a hook of the Processor base class - confirm.
        logger.debug('_set_up_model')

    def process(self, doc):
        """Predict coreference clusters for *doc* and store them on it.

        Args:
            doc: Stanza Document whose sentences have been tokenized.

        Returns:
            The same document; ``doc.clusters`` is set to a list of clusters,
            each cluster being a list of ``Span`` objects of type ``'COREF'``.

        Raises:
            IndexError: if a predicted mention starts beyond the last word of
                the document.
        """
        # Build the example argument for predict: one list of word strings
        # per sentence, plus document identifiers.
        example = {
            'sentences': [[word.text for word in sentence.words]
                          for sentence in doc.sentences],
            'doc_id': 'document_from_stanza',  # TODO check what this should be
            'doc_key': 'undocumented',  # TODO check what this should be
        }

        predictor = Predictor(config=self.e2econfig)
        try:
            predicted_clusters = predictor.predict(example)
        finally:
            # Always release the session, also when prediction fails.
            predictor.end_session()

        # sentence_ends[i] is the document-level index one past the last word
        # of sentence i; computed once instead of rescanning per mention.
        sentence_ends = list(itertools.accumulate(
            len(sentence.words) for sentence in doc.sentences))

        clusters = []
        for predicted_cluster in predicted_clusters:
            cluster = []
            # start and end are inclusive word indices counted over the whole
            # document, so they are shifted to sentence-local indices below.
            for start, end in predicted_cluster:
                # Index of the sentence that contains the word at `start`
                sentence_id = bisect.bisect_right(sentence_ends, start)
                sentence = doc.sentences[sentence_id]
                offset = sentence_ends[sentence_id - 1] if sentence_id else 0

                words = sentence.words[start - offset:end - offset + 1]
                cluster.append(Span(
                    tokens=[word.parent for word in words],
                    doc=doc,
                    type='COREF',
                    sent=sentence,
                ))
            clusters.append(cluster)

        doc.clusters = clusters
        return doc
import stanza
import e2edutch.stanza
import tensorflow.compat.v1 as tf


def test_processor():
    """The coref processor runs in a Stanza pipeline and yields a list of clusters."""
    nlp = stanza.Pipeline(lang='nl', processors='tokenize,coref')
    text = 'Dit is een tekst.'
    doc = nlp(text)
    # The processor assigns a (possibly empty) list of clusters to the
    # document, and every cluster is itself a list of spans.
    assert isinstance(doc.clusters, list)
    assert all(isinstance(cluster, list) for cluster in doc.clusters)