|
# -*- coding: utf-8 -*-
from __future__ import print_function

import glob
import json
import logging
import os
import re
import socket
import subprocess
import sys
import time

import psutil

try:
    from urlparse import urlparse  # Python 2
except ImportError:
    from urllib.parse import urlparse  # Python 3

import requests


class StanfordCoreNLP(object):
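    """Python wrapper around a Stanford CoreNLP server.

    It either starts a local server from an unzipped CoreNLP release
    directory, or talks to an already running server over HTTP.

    A minimal usage sketch (the path and port below are assumptions,
    not fixed by this module):

        # Local mode: point at an unzipped CoreNLP release directory.
        nlp = StanfordCoreNLP(r'/path/to/stanford-corenlp-full')

        # Remote mode: connect to a server that is already running.
        nlp = StanfordCoreNLP('http://localhost', port=9000)
    """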
    def __init__(self, path_or_host, port=None, memory='4g', lang='en', timeout=1500, quiet=True,
                 logging_level=logging.WARNING, max_retries=5):
        self.path_or_host = path_or_host
        self.port = port
        self.memory = memory
        self.lang = lang
        self.timeout = timeout
        self.quiet = quiet
        self.logging_level = logging_level

        logging.basicConfig(level=self.logging_level)

        # Validate lang and memory before doing anything else
        self._check_args()

        if path_or_host.startswith('http'):
            # Remote mode: the port must be known to build the server URL
            if port is None:
                raise ValueError('Port must be set when connecting to an existing server.')
            self.url = path_or_host + ':' + str(port)
            logging.info('Using an existing server {}'.format(self.url))
        else:

            # Check that a Java runtime is on the PATH
            try:
                if subprocess.call(['java', '-version'], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) != 0:
                    raise RuntimeError('Java not found.')
            except OSError:
                raise RuntimeError('Java not found.')

            # Check that the CoreNLP directory exists
            if not os.path.isdir(self.path_or_host):
                raise IOError(str(self.path_or_host) + ' is not a directory.')
            directory = os.path.normpath(self.path_or_host) + os.sep
            self.class_path_dir = directory

            # Check that the language-specific models jar exists
            switcher = {
                'en': 'stanford-corenlp-[0-9].[0-9].[0-9]-models.jar',
                'zh': 'stanford-chinese-corenlp-[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]-models.jar',
                'ar': 'stanford-arabic-corenlp-[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]-models.jar',
                'fr': 'stanford-french-corenlp-[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]-models.jar',
                'de': 'stanford-german-corenlp-[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]-models.jar',
                'es': 'stanford-spanish-corenlp-[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]-models.jar'
            }
            jars = {
                'en': 'stanford-corenlp-x.x.x-models.jar',
                'zh': 'stanford-chinese-corenlp-yyyy-MM-dd-models.jar',
                'ar': 'stanford-arabic-corenlp-yyyy-MM-dd-models.jar',
                'fr': 'stanford-french-corenlp-yyyy-MM-dd-models.jar',
                'de': 'stanford-german-corenlp-yyyy-MM-dd-models.jar',
                'es': 'stanford-spanish-corenlp-yyyy-MM-dd-models.jar'
            }
            if not glob.glob(directory + switcher.get(self.lang)):
                raise IOError(jars.get(self.lang) + ' was not found. Download it and place it in '
                              + directory + ' first.')

            # If no port was given, pick the first free one from 9000 up
            if self.port is None:
                used_ports = {conn.laddr[1] for conn in psutil.net_connections() if conn.laddr}
                for port_candidate in range(9000, 65535):
                    if port_candidate not in used_ports:
                        self.port = port_candidate
                        break

            # Check that the chosen port is not already in use
            if self.port in [conn.laddr[1] for conn in psutil.net_connections() if conn.laddr]:
                raise IOError('Port ' + str(self.port) + ' is already in use.')

            # Start the native server
            logging.info('Initializing native server...')
            cmd = "java"
            java_args = "-Xmx{}".format(self.memory)
            java_class = "edu.stanford.nlp.pipeline.StanfordCoreNLPServer"
            class_path = '"{}*"'.format(directory)

            args = [cmd, java_args, '-cp', class_path, java_class, '-port', str(self.port)]

            args = ' '.join(args)

            logging.info(args)

            # Discard the server's output unless quiet=False
            with open(os.devnull, 'w') as null_file:
                out_file = None
                if self.quiet:
                    out_file = null_file

                self.p = subprocess.Popen(args, shell=True, stdout=out_file, stderr=subprocess.STDOUT)
                logging.info('Server shell PID: {}'.format(self.p.pid))

            self.url = 'http://localhost:' + str(self.port)

            # Wait until the server accepts connections
            host_name = urlparse(self.url).hostname
            time.sleep(1)  # OSX, not tested
            trial = 1
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                while sock.connect_ex((host_name, self.port)):
                    if trial > max_retries:
                        raise ValueError('CoreNLP server is not available.')
                    logging.info('Waiting until the server is available.')
                    trial += 1
                    time.sleep(1)
            finally:
                sock.close()
            logging.info('The server is available.')

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        logging.info('Cleanup...')
        if hasattr(self, 'p'):
            try:
                parent = psutil.Process(self.p.pid)
            except psutil.NoSuchProcess:
                logging.info('No process: {}'.format(self.p.pid))
                return

            # Only kill processes we actually started: the cmdline must
            # reference our class path directory
            if self.class_path_dir not in ' '.join(parent.cmdline()):
                logging.info('Process not in: {}'.format(parent.cmdline()))
                return

            # Kill the JVM children first, then the shell that spawned them
            children = parent.children(recursive=True)
            for process in children:
                logging.info('Killing pid: {}, cmdline: {}'.format(process.pid, process.cmdline()))
                process.kill()

            logging.info('Killing shell pid: {}, cmdline: {}'.format(parent.pid, parent.cmdline()))
            parent.kill()

    def annotate(self, text, properties=None):
        if sys.version_info.major >= 3:
            text = text.encode('utf-8')

        # Only send a properties param when one was given; otherwise the
        # server falls back to its defaults
        params = {'properties': str(properties)} if properties else {}
        r = requests.post(self.url, params=params, data=text,
                          headers={'Connection': 'close'})
        return r.text

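    # A usage sketch for annotate (assumes an `nlp` instance as in the
    # class docstring; the annotator names are standard CoreNLP ones):
    #
    #   props = {'annotators': 'tokenize,ssplit,pos', 'outputFormat': 'json'}
    #   tagged = json.loads(nlp.annotate('Hello world.', properties=props))
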
    def tregex(self, sentence, pattern):
        tregex_url = self.url + '/tregex'
        r_dict = self._request(tregex_url, "tokenize,ssplit,depparse,parse", sentence, pattern=pattern)
        return r_dict

    def tokensregex(self, sentence, pattern):
        tokensregex_url = self.url + '/tokensregex'
        r_dict = self._request(tokensregex_url, "tokenize,ssplit,depparse", sentence, pattern=pattern)
        return r_dict

    def semgrex(self, sentence, pattern):
        semgrex_url = self.url + '/semgrex'
        r_dict = self._request(semgrex_url, "tokenize,ssplit,depparse", sentence, pattern=pattern)
        return r_dict

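    # A pattern-query sketch (assumes the server is running): match every
    # noun phrase in the constituency parse with Tregex.
    #
    #   nlp.tregex('The quick brown fox jumped over the lazy dog.', 'NP')
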
    def word_tokenize(self, sentence, span=False):
        r_dict = self._request(self.url, 'ssplit,tokenize', sentence)
        tokens = [token['originalText'] for s in r_dict['sentences'] for token in s['tokens']]

        # Optionally return character-offset spans alongside the tokens
        if span:
            spans = [(token['characterOffsetBegin'], token['characterOffsetEnd'])
                     for s in r_dict['sentences'] for token in s['tokens']]
            return tokens, spans
        else:
            return tokens

    def pos_tag(self, sentence):
        r_dict = self._request(self.url, 'pos', sentence)
        words = []
        tags = []
        for s in r_dict['sentences']:
            for token in s['tokens']:
                words.append(token['originalText'])
                tags.append(token['pos'])
        return list(zip(words, tags))

    def ner(self, sentence):
        r_dict = self._request(self.url, 'ner', sentence)
        words = []
        ner_tags = []
        for s in r_dict['sentences']:
            for token in s['tokens']:
                words.append(token['originalText'])
                ner_tags.append(token['ner'])
        return list(zip(words, ner_tags))

    def parse(self, sentence):
        r_dict = self._request(self.url, 'pos,parse', sentence)
        return [s['parse'] for s in r_dict['sentences']]

    def dependency_parse(self, sentence):
        # Returns (relation, governor, dependent) triples over
        # basicDependencies; token indices are 1-based, 0 is the ROOT
        r_dict = self._request(self.url, 'depparse', sentence)
        return [(dep['dep'], dep['governor'], dep['dependent'])
                for s in r_dict['sentences'] for dep in s['basicDependencies']]

    def coref(self, text):
        r_dict = self._request(self.url, 'coref', text)

        corefs = []
        for k, mentions in r_dict['corefs'].items():
            simplified_mentions = []
            for m in mentions:
                simplified_mentions.append((m['sentNum'], m['startIndex'], m['endIndex'], m['text']))
            corefs.append(simplified_mentions)
        return corefs

    def switch_language(self, language="en"):
        self._check_language(language)
        self.lang = language

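    # A sketch: nlp.switch_language('zh') makes subsequent requests use
    # pipelineLanguage=zh (assumes the Chinese models jar is on the class path).
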
    def _request(self, url, annotators=None, data=None, *args, **kwargs):
        # Internal helper shared by the public methods above
        if sys.version_info.major >= 3:
            data = data.encode('utf-8')

        properties = {'annotators': annotators, 'outputFormat': 'json'}
        params = {'properties': str(properties), 'pipelineLanguage': self.lang}
        if 'pattern' in kwargs:
            params = {"pattern": kwargs['pattern'], 'properties': str(properties), 'pipelineLanguage': self.lang}

        logging.info(params)
        r = requests.post(url, params=params, data=data, headers={'Connection': 'close'})
        r_dict = json.loads(r.text)

        return r_dict

    def request(self, annotators=None, data=None, *args, **kwargs):
        if sys.version_info.major >= 3:
            data = data.encode('utf-8')

        properties = {'annotators': annotators, 'outputFormat': 'json'}
        # Ask the server for the 3 best parses from the Chinese PCFG model
        params = {'properties': str(properties), 'pipelineLanguage': self.lang,
                  'parse.model': 'edu/stanford/nlp/models/lexparser/chinesePCFG.ser.gz',
                  'parse.kbest': 3}
        if 'pattern' in kwargs:
            params = {"pattern": kwargs['pattern'], 'properties': str(properties), 'pipelineLanguage': self.lang}

        logging.info(params)
        r = requests.post(self.url, params=params, data=data, headers={'Connection': 'close'})
        r_dict = json.loads(r.text)

        return r_dict

    def _check_args(self):
        self._check_language(self.lang)
        if not re.match(r'^\d+g$', self.memory):
            raise ValueError('memory=' + self.memory + ' not supported. Use 4g, 6g, 8g, etc.')

    def _check_language(self, lang):
        if lang not in ['en', 'zh', 'ar', 'fr', 'de', 'es']:
            raise ValueError('lang=' + lang + ' not supported. Use English(en), Chinese(zh), Arabic(ar), '
                                              'French(fr), German(de), Spanish(es).')
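

# A minimal demo sketch, not part of the wrapper itself: it assumes an
# unzipped CoreNLP release at ./stanford-corenlp-full (adjust this path
# to your local install) with the English models jar placed inside it.
if __name__ == '__main__':
    with StanfordCoreNLP(r'./stanford-corenlp-full', memory='4g', lang='en') as nlp:
        sentence = 'Stanford CoreNLP provides a set of human language technology tools.'
        print(nlp.word_tokenize(sentence))
        print(nlp.pos_tag(sentence))
        print(nlp.ner(sentence))
        print(nlp.parse(sentence))
        print(nlp.dependency_parse(sentence))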