Skip to content

Commit

Permalink
Rename match to measure
Browse files Browse the repository at this point in the history
  • Loading branch information
jnothman committed Aug 8, 2014
1 parent 6cb084f commit 97c0bed
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 159 deletions.
6 changes: 3 additions & 3 deletions neleval/analyze.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .document import ENC
from .document import Reader
from .document import by_mention
from .evaluate import get_matcher
from .evaluate import get_measure
from collections import Counter
from collections import namedtuple

Expand Down Expand Up @@ -81,10 +81,10 @@ def _data():
def iter_errors(self):
system = list(Reader(open(self.system), group=by_mention))
gold = list(Reader(open(self.gold), group=by_mention))
matcher = get_matcher('strong_mention_match')
measure = get_measure('strong_mention_match')
for g, s in zip(gold, system):
assert g.id == s.id
tp, fp, fn = matcher.get_matches(g.annotations, s.annotations)
tp, fp, fn = measure.get_matches(g.annotations, s.annotations)
for g_m, s_m in tp:
if g_m.kbid == s_m.kbid and not self.with_correct:
#continue # Correct case.
Expand Down
14 changes: 7 additions & 7 deletions neleval/annotation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python
"Representation of link standoff annotation and matching over it"
"Representation of link standoff annotation and measures over it"

from collections import Sequence, defaultdict
import operator
Expand Down Expand Up @@ -141,7 +141,7 @@ def from_string(cls, s):
raise SyntaxError('Need id, score and type when >1 candidates')


class Matcher(object):
class Measure(object):
__slots__ = ['key', 'filter', 'filter_fn', 'agg']

def __init__(self, key, filter=None, agg='sets-micro'):
Expand Down Expand Up @@ -179,7 +179,7 @@ def __repr__(self):
NON_CLUSTERING_AGG = ('sets-micro',) # 'sets-macro')

@property
def is_clustering_match(self):
def is_clustering(self):
return self.agg not in self.NON_CLUSTERING_AGG

def build_index(self, annotations):
Expand Down Expand Up @@ -211,7 +211,7 @@ def build_clusters(self, annotations):
return out

def count_matches(self, system, gold):
if self.is_clustering_match:
if self.is_clustering:
raise ValueError('count_matches is inappropriate '
'for {}'.format(self.agg))
gold_index = self.build_index(gold)
Expand All @@ -229,7 +229,7 @@ def get_matches(self, system, gold):
* fp [(None, other_item), ...]
* fn [(item, None), ...]
"""
if self.is_clustering_match:
if self.is_clustering:
raise ValueError('get_matches is inappropriate '
'for {}'.format(self.agg))
gold_index = self.build_index(gold)
Expand All @@ -244,7 +244,7 @@ def get_matches(self, system, gold):

def count_clustering(self, system, gold):
from . import coref_metrics
if not self.is_clustering_match:
if not self.is_clustering:
raise ValueError('evaluate_clustering is inappropriate '
'for {}'.format(self.agg))
try:
Expand All @@ -258,7 +258,7 @@ def count_clustering(self, system, gold):
return fn(gold_clusters, pred_clusters)

def contingency(self, system, gold):
if self.is_clustering_match:
if self.is_clustering:
p_num, p_den, r_num, r_den = self.count_clustering(system, gold)
ptp = p_num
fp = p_den - p_num
Expand Down
157 changes: 79 additions & 78 deletions neleval/configs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import textwrap
from collections import defaultdict
from .annotation import Matcher
from .annotation import Measure

try:
keys = dict.viewkeys
Expand All @@ -9,38 +9,38 @@
keys = dict.keys


MATCHERS = {
'strong_mention_match': Matcher(['span']),
'strong_linked_mention_match': Matcher(['span'], 'is_linked'),
'strong_link_match': Matcher(['span', 'kbid'], 'is_linked'),
'strong_nil_match': Matcher(['span'], 'is_nil'),
'strong_all_match': Matcher(['span', 'kbid']),
'strong_typed_all_match': Matcher(['span', 'type', 'kbid']),
'entity_match': Matcher(['docid', 'kbid'], 'is_linked'),
MEASURES = {
'strong_mention_match': Measure(['span']),
'strong_linked_mention_match': Measure(['span'], 'is_linked'),
'strong_link_match': Measure(['span', 'kbid'], 'is_linked'),
'strong_nil_match': Measure(['span'], 'is_nil'),
'strong_all_match': Measure(['span', 'kbid']),
'strong_typed_all_match': Measure(['span', 'type', 'kbid']),
'entity_match': Measure(['docid', 'kbid'], 'is_linked'),

'b_cubed_plus': Matcher(['span', 'kbid'], agg='b_cubed'),
'b_cubed_plus': Measure(['span', 'kbid'], agg='b_cubed'),
}

for name in ['muc', 'b_cubed', 'entity_ceaf', 'mention_ceaf', 'pairwise',
#'cs_b_cubed', 'entity_cs_ceaf', 'mention_cs_ceaf']:
]:
MATCHERS[name] = Matcher(['span'], agg=name)
MEASURES[name] = Measure(['span'], agg=name)


# Configuration constants
ALL_MATCHES = 'all'
ALL_MEASURES = 'all'
ALL_TAGGING = 'all-tagging'
ALL_COREF = 'all-coref'
TAC_MATCHES = 'tac'
TAC14_MATCHES = 'tac14'
TMP_MATCHES = 'tmp'
CORNOLTI_WWW13_MATCHES = 'cornolti'
HACHEY_ACL14_MATCHES = 'hachey'
LUO_MATCHES = 'luo'
CAI_STRUBE_MATCHES = 'cai'

MATCH_SETS = {
ALL_MATCHES: [
TAC_MEASURES = 'tac'
TAC14_MEASURES = 'tac14'
TMP_MEASURES = 'tmp'
CORNOLTI_WWW13_MEASURES = 'cornolti'
HACHEY_ACL14_MEASURES = 'hachey'
LUO_MEASURES = 'luo'
CAI_STRUBE_MEASURES = 'cai'

MEASURE_SETS = {
ALL_MEASURES: [
'all-tagging',
'all-coref',
],
Expand All @@ -64,29 +64,29 @@
#'cs_b_cubed',
'b_cubed_plus',
},
CORNOLTI_WWW13_MATCHES: [
CORNOLTI_WWW13_MEASURES: [
'strong_linked_mention_match',
'strong_link_match',
'entity_match',
],
HACHEY_ACL14_MATCHES: [
HACHEY_ACL14_MEASURES: [
'strong_mention_match', # full ner
'strong_linked_mention_match',
'strong_link_match',
'entity_match',
],
LUO_MATCHES: [
LUO_MEASURES: [
'muc',
'b_cubed',
'mention_ceaf',
'entity_ceaf',
],
#CAI_STRUBE_MATCHES: [
#CAI_STRUBE_MEASURES: [
# 'cs_b_cubed',
# 'entity_cs_ceaf',
# 'mention_cs_ceaf',
#],
TAC_MATCHES: [
TAC_MEASURES: [
'strong_link_match', # recall equivalent to kb accuracy before 2014
'strong_nil_match', # recall equivalent to nil accuracy before 2014
'strong_all_match', # equivalent to overall accuracy before 2014
Expand All @@ -96,118 +96,119 @@
'b_cubed',
'b_cubed_plus',
],
TAC14_MATCHES: [
TAC14_MEASURES: [
'strong_typed_all_match', # wikification f-score for TAC 2014
],
TMP_MATCHES: [
TMP_MEASURES: [
'mention_ceaf',
'entity_ceaf',
'pairwise',
],
}

DEFAULT_MATCH_SET = ALL_MATCHES
DEFAULT_MATCH = 'strong_all_match'
DEFAULT_MEASURE_SET = ALL_MEASURES
DEFAULT_MEASURE = 'strong_all_match'


def _expand(matches):
if isinstance(matches, str):
if matches in MATCH_SETS:
matches = MATCH_SETS[matches]
def _expand(measures):
if isinstance(measures, str):
if measures in MEASURE_SETS:
measures = MEASURE_SETS[measures]
else:
return [matches]
if isinstance(matches, Matcher):
return [Matcher]
if len(matches) == 1:
return _expand(matches[0])
return [m for group in matches for m in _expand(group)]
return [measures]
if isinstance(measures, Measure):
return [measures]
if len(measures) == 1:
return _expand(measures[0])
return [m for group in measures for m in _expand(group)]


def parse_matches(in_matches, incl_clustering=True):
def parse_measures(in_measures, incl_clustering=True):
# flatten nested sequences and expand group names
matches = _expand(in_matches)
measures = _expand(in_measures)
# remove duplicates while maintaining order
seen = set()
matches = [seen.add(m) or m
for m in matches if m not in seen]
measures = [seen.add(m) or m
for m in measures if m not in seen]

# TODO: make sure resolve to valid matchers
not_found = set(matches) - keys(MATCHERS)
# TODO: make sure these resolve to valid measures
not_found = set(measures) - keys(MEASURES)
invalid = []
for m in not_found:
try:
get_matcher(m)
get_measure(m)
except Exception:
raise
invalid.append(m)
if invalid:
raise ValueError('Could not resolve matchers: {}'.format(sorted(not_found)))
raise ValueError('Could not resolve measures: '
'{}'.format(sorted(not_found)))

if not incl_clustering:
matches = [m for m in matches
if not get_matcher(m).is_clustering_match]
measures = [m for m in measures
if not get_measure(m).is_clustering]
# TODO: remove clustering metrics given flag
# raise error if empty
if not matches:
msg = 'Could not resolve {!r} to any matches.'.format(in_matches)
if not measures:
msg = 'Could not resolve {!r} to any measures.'.format(in_measures)
if not incl_clustering:
msg += ' Clustering measures have been excluded.'
raise ValueError(msg)
return matches
return measures


def get_matcher(name):
if isinstance(name, Matcher):
def get_measure(name):
if isinstance(name, Measure):
return name
if name.count(':') == 2:
return Matcher.from_string(name)
return MATCHERS[name]
return Measure.from_string(name)
return MEASURES[name]


def get_match_choices():
return sorted(MATCH_SETS.keys()) + sorted(MATCHERS.keys())
def get_measure_choices():
return sorted(MEASURE_SETS.keys()) + sorted(MEASURES.keys())


MATCH_HELP = ('Which metrics to use: specify a name (or group name) from the '
'list-metrics command. This flag may be repeated.')
MEASURE_HELP = ('Which measures to use: specify a name (or group name) from '
'the list-measures command. This flag may be repeated.')


def _wrap(text):
return '\n'.join(textwrap.wrap(text))


class ListMeasures(object):
"""List matching schemes available for evaluation"""
"""List measures schemes available for evaluation"""

def __init__(self, matches=None):
self.matches = matches
def __init__(self, measures=None):
self.measures = measures

def __call__(self):
matches = parse_matches(self.matches or get_match_choices())
measures = parse_measures(self.measures or get_measure_choices())
header = ['Name', 'Aggregate', 'Filter', 'Key Fields', 'In groups']
rows = [header]

set_membership = defaultdict(list)
for set_name, match_set in sorted(MATCH_SETS.items()):
for name in parse_matches(match_set):
for set_name, measure_set in sorted(MEASURE_SETS.items()):
for name in parse_measures(measure_set):
set_membership[name].append(set_name)

for name in sorted(matches):
matcher = get_matcher(name)
rows.append((name, matcher.agg, str(matcher.filter),
'+'.join(matcher.key),
for name in sorted(measures):
measure = get_measure(name)
rows.append((name, measure.agg, str(measure.filter),
'+'.join(measure.key),
', '.join(set_membership[name])))

col_widths = [max(len(row[i]) for row in rows)
for i in range(len(header))]
rows.insert(1, ['=' * w for w in col_widths])
fmt = '\t'.join('{:%ds}' % w for w in col_widths[:-1]) + '\t{}'
ret = _wrap('The following lists possible values for --match (-m) in '
'evaluate, confidence and significance. The name from '
ret = _wrap('The following lists possible values for --measure (-m) '
'in evaluate, confidence and significance. The name from '
'each row or the name of a group may be used. ') + '\n\n'
ret = '\n'.join(textwrap.wrap(ret)) + '\n\n'
ret += '\n'.join(fmt.format(*row) for row in rows)
ret += '\n\nDefault evaluation group: {}'.format(DEFAULT_MATCH_SET)
ret += '\n\nDefault evaluation group: {}'.format(DEFAULT_MEASURE_SET)
ret += '\n\n'
ret += _wrap('In all measures, a set of tuples corresponding to Key '
'Fields is produced from annotations matching Filter. '
Expand All @@ -217,12 +218,12 @@ def __call__(self):
ret += '\n\n'
ret += ('A measure may be specified explicitly. Thus:\n'
' {}\nmay be entered as\n {}'
''.format(DEFAULT_MATCH, get_matcher(DEFAULT_MATCH)))
''.format(DEFAULT_MEASURE, get_measure(DEFAULT_MEASURE)))
return ret

@classmethod
def add_arguments(cls, p):
p.add_argument('-m', '--match', dest='matches', action='append',
metavar='NAME', help=MATCH_HELP)
p.add_argument('-m', '--measure', dest='measures', action='append',
metavar='NAME', help=MEASURE_HELP)
p.set_defaults(cls=cls)
return p
Loading

0 comments on commit 97c0bed

Please sign in to comment.