"""Evaluation of (discontinuous) parse trees.
Designed to behave like the reference implementation EVALB [1] for regular
parse trees, with a natural extension to the discontinuous case. Also provides
additional, alternative parse tree evaluation metrics (leaf ancestor, tree-edit
distance, unlabeled dependencies), as well as facilities for error analysis.
[1] http://nlp.cs.nyu.edu/evalb/"""
# pylint: disable=cell-var-from-loop
import io
import sys
from getopt import gnu_getopt, GetoptError
from decimal import Decimal, InvalidOperation
from collections import defaultdict, Counter # == multiset
from itertools import count, zip_longest
from . import grammar
from .tree import Tree, DrawTree, isdisc, bitfanout
from .treebank import READERS, dependencies, handlefunctions
from .treetransforms import getbits
from .treebanktransforms import functions
from .treedist import treedist, newtreedist
SHORTUSAGE = 'Usage: discodop eval <gold> <parses> [param] [options]'
# Two header lines printed above the per-sentence result lines.
HEADER = [
    ' Sentence Matched Brackets Corr POS',
    ' ID Length Recall Precis Bracket gold cand Words POS Accur.',
]
class Evaluator(object):
    """Incremental evaluator for syntactic trees.

    Sentence pairs are added one at a time with ``add()``; ``breakdowns()``
    and ``summary()`` report on all pairs added so far."""

    def __init__(self, param, keylen=8):
        """Initialize evaluator object with given parameters.

        :param param: a dictionary of parameters, as read by ``readparam``.
        :param keylen: the length of the longest sentence ID, for padding
            purposes."""
        self.param = param
        self.keylen = keylen
        self.acc = EvalAccumulator(param['DISC_ONLY'])
        # Second accumulator restricted to sentences of at most CUTOFF_LEN
        # words; stays None when no cutoff is configured.
        self.acc40 = None
        if param['CUTOFF_LEN'] is not None:
            self.acc40 = EvalAccumulator(param['DISC_ONLY'])
        if param['DEBUG'] >= 1:
            print('Parameters:')
            for a in param:
                print('%s\t%s' % (a, param[a]))
            for a in HEADER:
                print(' ' * (self.keylen - 4) + a)
            print('', '_' * ((self.keylen - 5) + len(HEADER[-1])))

    def add(self, n, gtree, gsent, ctree, csent):
        """Add a pair of gold and candidate trees to the evaluation.

        :param n: a unique identifier for this sentence.
        :param gtree, ctree: ParentedTree objects (will be modified in-place)
        :param gsent, csent: lists of tokens.
        :returns: a ``TreePairResult`` object."""
        treepair = TreePairResult(n, gtree, gsent, ctree, csent, self.param)
        self.acc.add(treepair)
        if (self.param['CUTOFF_LEN'] is not None
                and treepair.lengpos <= self.param['CUTOFF_LEN']):
            self.acc40.add(treepair)
        if self.param['DEBUG'] > 1:
            treepair.debug()
            # repeat the column header after the detailed debug output
            for a in HEADER:
                print(' ' * (self.keylen - 4) + a)
            print('', '_' * ((self.keylen - 5) + len(HEADER[-1])))
        if self.param['DEBUG'] >= 1:
            treepair.info('%%%ds ' % self.keylen)
        return treepair

    def breakdowns(self):
        """Print breakdowns for the most frequent rules, labels, tags.

        At most 10 items per table are shown unless DEBUG > 0, in which case
        all items are shown. The POS table is only printed when tagging
        accuracy is not perfect."""
        limit = 10 if self.param['DEBUG'] <= 0 else None
        self.rulebreakdowns(limit)
        self.catbreakdown(limit)
        if self.acc.candfun:
            self.funcbreakdown(limit)
        try:
            acc = accuracy(self.acc.goldpos, self.acc.candpos)
        except InvalidOperation:  # no POS tags at all (0 / 0)
            pass
        else:
            if acc != 1:
                self.tagbreakdown(limit)
        print()

    def rulebreakdowns(self, limit=10):
        """Print breakdowns for the most frequent rule mismatches.

        :param limit: maximum number of rows per table; None for all."""
        acc = self.acc
        # NB: unary nodes not handled properly
        # NOTE(review): pyintbitcount is not among the imports visible at the
        # top of this file; confirm it is defined elsewhere in the module.
        gmismatch = {(n, indices): rule
                     for n, indices, rule in acc.goldrule - acc.candrule}
        wrong = Counter((rule, gmismatch[n, indices]) for n, indices, rule
                        in acc.candrule - acc.goldrule
                        if pyintbitcount(indices) > 1
                        and (n, indices) in gmismatch)
        print('\n Rewrite rule mismatches (for given span)')
        print(' count cand / gold rules')
        for (crule, grule), cnt in wrong.most_common(limit):
            print(' %7d %s' % (cnt, grammar.printrule(*crule)))
            print(' %7s %s' % (' ', grammar.printrule(*grule)))
        gspans = {(n, indices) for n, indices, _ in acc.goldrule}
        wrong = Counter(rule for n, indices, rule
                        in acc.candrule - acc.goldrule
                        if pyintbitcount(indices) > 1
                        and (n, indices) not in gspans)
        print('\n Rewrite rules (span not in gold trees)')
        print(' count rule in candidate parses')
        for crule, cnt in wrong.most_common(limit):
            print(' %7d %s' % (cnt, grammar.printrule(*crule)))
        cspans = {(n, indices) for n, indices, _ in acc.candrule}
        wrong = Counter(rule for n, indices, rule
                        in acc.goldrule - acc.candrule
                        if pyintbitcount(indices) > 1
                        and (n, indices) not in cspans)
        print('\n Rewrite rules (span missing from candidate parses)')
        print(' count rule in gold standard set')
        for grule, cnt in wrong.most_common(limit):
            print(' %7d %s' % (cnt, grammar.printrule(*grule)))

    def catbreakdown(self, limit=10):
        """Print breakdowns for the most frequent labels.

        :param limit: maximum number of rows per table; None for all."""
        acc = self.acc
        print('\n Attachment errors (correct labeled bracketing, wrong parent)')
        print(' label cand gold count')
        print(' ' + 33 * '_')
        # gold parent label for each gold bracketing not attached identically
        gmismatch = dict(acc.goldbatt - acc.candbatt)
        wrong = Counter((label, cparent, gmismatch[n, label, indices])
                        for (n, label, indices), cparent
                        in acc.candbatt - acc.goldbatt
                        if (n, label, indices) in gmismatch)
        for (cat, cparent, gparent), cnt in wrong.most_common(limit):
            print('%s %s %s %7d' % (cat.rjust(7), cparent.rjust(7),
                                    gparent.rjust(7), cnt))
        print('\n Category Statistics (%s categories / errors)' % (
            ('%d most frequent' % limit) if limit else 'all'))
        print(' label % gold recall prec. F1',
              ' cand gold count')
        print(' ' + 38 * '_' + 8 * ' ' + 24 * '_')
        # same span, different label: (candidate label, gold label) pairs
        gmismatch = {(n, indices): label
                     for n, (label, indices) in acc.goldb - acc.candb}
        wrong = Counter((label, gmismatch[n, indices])
                        for n, (label, indices) in acc.candb - acc.goldb
                        if (n, indices) in gmismatch)
        freqcats = sorted(set(acc.goldbcat) | set(acc.candbcat),
                          key=lambda x: len(acc.goldbcat[x]), reverse=True)
        ngold = len(acc.goldb)
        for cat, mismatch in zip_longest(freqcats[:limit],
                                         wrong.most_common(limit)):
            if cat is None:
                print(39 * ' ', end='')
            else:
                # report NaN instead of crashing when there are no gold
                # brackets but the candidates do contain brackets
                goldshare = (100 * sum(acc.goldbcat[cat].values()) / ngold
                             if ngold else float('nan'))
                print('%s %6.2f %s %s %s' % (
                    cat.rjust(7),
                    goldshare,
                    nozerodiv(lambda: recall(
                        acc.goldbcat[cat], acc.candbcat[cat])),
                    nozerodiv(lambda: precision(
                        acc.goldbcat[cat], acc.candbcat[cat])),
                    nozerodiv(lambda: f_measure(
                        acc.goldbcat[cat], acc.candbcat[cat])),
                ), end='')
            if mismatch is not None:
                print(' %s %7d' % (' '.join((mismatch[0][0].rjust(8),
                                             mismatch[0][1].ljust(8))),
                                   mismatch[1]), end='')
            print()

    def funcbreakdown(self, limit=10):
        """Print breakdowns for the most frequent function tags.

        :param limit: maximum number of rows per table; None for all."""
        acc = self.acc
        print('\n Function Tag Statistics (%s tags / errors)' % (
            ('%d most frequent' % limit) if limit else 'all'))
        print(' func. % gold recall prec. F1',
              ' cand gold count')
        print(' ' + 38 * '_' + 8 * ' ' + 24 * '_')
        gmismatch = {(n, span): tag
                     for n, (span, tag) in acc.goldfun - acc.candfun}
        wrong = Counter((tag, gmismatch[n, span])
                        for n, (span, tag) in acc.candfun - acc.goldfun
                        if (n, span) in gmismatch)
        freqcats = sorted(set(acc.goldbfunc) | set(acc.candbfunc),
                          key=lambda x: len(acc.goldbfunc[x]), reverse=True)
        for cat, mismatch in zip_longest(freqcats[:limit],
                                         wrong.most_common(limit)):
            if cat is None:
                print(39 * ' ', end='')
            else:
                print('%s %s %s %s %s' % (
                    cat.rjust(7),
                    nozerodiv(lambda: sum(acc.goldbfunc[cat].values())
                              / len(acc.goldfun)),
                    nozerodiv(lambda: recall(
                        acc.goldbfunc[cat], acc.candbfunc[cat])),
                    nozerodiv(lambda: precision(
                        acc.goldbfunc[cat], acc.candbfunc[cat])),
                    nozerodiv(lambda: f_measure(
                        acc.goldbfunc[cat], acc.candbfunc[cat])),
                ), end='')
            if mismatch is not None:
                print(' %s %7d' % (' '.join((mismatch[0][0].rjust(8),
                                             mismatch[0][1].ljust(8))),
                                   mismatch[1]), end='')
            print()

    def tagbreakdown(self, limit=10):
        """Print breakdowns for the most frequent tags.

        :param limit: maximum number of rows per table; None for all."""
        acc = self.acc
        print('\n POS Statistics (%s tags / errors)' % (
            ('%d most frequent' % limit) if limit else 'all'), end='')
        print('\n tag % gold recall prec. F1',
              ' cand gold count')
        print(' ' + 38 * '_' + 12 * ' ' + 20 * '_')
        tags = Counter(acc.goldpos)
        # (candidate tag, gold tag) pairs for each mistagged token
        wrong = Counter((c, g) for c, g
                        in zip(acc.candpos, acc.goldpos) if c != g)
        for tag, mismatch in zip_longest(tags.most_common(limit),
                                         wrong.most_common(limit)):
            if tag is None:
                print(''.rjust(40), end='')
            else:
                # only one tag per index may occur, but multiset is required
                # by metrics
                goldtag = Counter(n for n, t in enumerate(acc.goldpos)
                                  if t == tag[0])
                candtag = Counter(n for n, t in enumerate(acc.candpos)
                                  if t == tag[0])
                print('%s %6.2f %6.2f %6.2f %6.2f' % (
                    tag[0].rjust(7),
                    100 * len(goldtag) / len(acc.goldpos),
                    100 * recall(goldtag, candtag),
                    100 * precision(goldtag, candtag),
                    100 * f_measure(goldtag, candtag)), end='')
            if mismatch is not None:
                print(' %s %7d' % (' '.join((mismatch[0][0].rjust(8),
                                             mismatch[0][1].ljust(8))).rjust(12),
                                   mismatch[1]),
                      end='')
            print()

    def summary(self):
        """:returns: a string with an overview of scores for all sentences.

        A single column is reported when no cutoff is configured
        (CUTOFF_LEN is None) or when no sentence exceeds the cutoff; otherwise
        scores for sentences up to the cutoff are shown next to those for all
        sentences."""
        acc = self.acc
        acc40 = self.acc40
        discbrackets = sum(1 for _, (_, a) in acc.candb.elements()
                           if bitfanout(a) > 1)
        gdiscbrackets = sum(1 for _, (_, a) in acc.goldb.elements()
                            if bitfanout(a) > 1)
        # acc40 is None iff CUTOFF_LEN is None; check it first so that the
        # length comparison never involves None.
        if acc40 is None or acc.maxlenseen <= self.param['CUTOFF_LEN']:
            msg = ['%s' % ' Summary (ALL) '.center(35, '_'),
                   'number of sentences: %6d' % (acc.sentcount),
                   'longest sentence: %6d' % (acc.maxlenseen)]
            if gdiscbrackets or discbrackets:
                msg.extend(['gold brackets (disc.): %6d (%d)' % (
                    len(acc.goldb), gdiscbrackets),
                    'cand. brackets (disc.): %6d (%d)' % (
                        len(acc.candb), discbrackets)])
            else:
                msg.extend(['gold brackets: %6d' % len(acc.goldb),
                            'cand. brackets: %6d' % len(acc.candb)])
            msg.extend([
                'labeled recall: %s' % (
                    nozerodiv(lambda: recall(acc.goldb, acc.candb))),
                'labeled precision: %s' % (
                    nozerodiv(lambda: precision(acc.goldb, acc.candb))),
                'labeled f-measure: %s' % (
                    nozerodiv(lambda: f_measure(acc.goldb, acc.candb))),
                'exact match: %s' % (
                    nozerodiv(lambda: acc.exact / acc.sentcount))])
            if self.param['LA']:
                msg.append('leaf-ancestor: %s' % (
                    nozerodiv(lambda: mean(acc.lascores))))
            if self.param['TED']:
                msg.append('tree-dist (Dice micro avg) %s' % (
                    nozerodiv(lambda: 1 - acc.dicenoms / acc.dicedenoms)))
            if self.param['DEP']:
                msg.append('unlabeled dependencies: %s' % (
                    nozerodiv(lambda: accuracy(acc.golddep, acc.canddep))))
            if acc.candfun:
                msg.append('function tags: %s' %
                           nozerodiv(lambda: f_measure(acc.goldfun, acc.candfun)))
            msg.append('pos accuracy: %s' % (
                nozerodiv(lambda: accuracy(acc.goldpos, acc.candpos))))
            return '\n'.join(msg)
        discbrackets40 = sum(1 for _, (_, a) in acc40.candb.elements()
                             if bitfanout(a) > 1)
        gdiscbrackets40 = sum(1 for _, (_, a) in acc40.goldb.elements()
                              if bitfanout(a) > 1)
        msg = ['%s <= %d ______ ALL' % (
            ' Summary '.center(27, '_'), self.param['CUTOFF_LEN']),
            'number of sentences: %6d %6d' % (
                acc40.sentcount, acc.sentcount),
            'longest sentence: %6d %6d' % (
                acc40.maxlenseen, acc.maxlenseen),
            'gold brackets: %6d %6d' % (
                len(acc40.goldb), len(acc.goldb)),
            'cand. brackets: %6d %6d' % (
                len(acc40.candb), len(acc.candb))]
        if gdiscbrackets or discbrackets:
            msg.extend(['disc. gold brackets: %6d %6d' % (
                gdiscbrackets40, gdiscbrackets),
                'disc. cand. brackets: %6d %6d' % (
                    discbrackets40, discbrackets)])
        msg.extend(['labeled recall: %s %s' % (
            nozerodiv(lambda: recall(acc40.goldb, acc40.candb)),
            nozerodiv(lambda: recall(acc.goldb, acc.candb))),
            'labeled precision: %s %s' % (
                nozerodiv(lambda: precision(acc40.goldb, acc40.candb)),
                nozerodiv(lambda: precision(acc.goldb, acc.candb))),
            'labeled f-measure: %s %s' % (
                nozerodiv(lambda: f_measure(acc40.goldb, acc40.candb)),
                nozerodiv(lambda: f_measure(acc.goldb, acc.candb))),
            'exact match: %s %s' % (
                nozerodiv(lambda: acc40.exact / acc40.sentcount),
                nozerodiv(lambda: acc.exact / acc.sentcount))])
        if self.param['LA']:
            msg.append('leaf-ancestor: %s %s' % (
                nozerodiv(lambda: mean(acc40.lascores)),
                nozerodiv(lambda: mean(acc.lascores))))
        if self.param['TED']:
            msg.append('tree-dist (Dice micro avg) %s %s' % (
                nozerodiv(lambda: (1 - acc40.dicenoms / acc40.dicedenoms)),
                nozerodiv(lambda: (1 - acc.dicenoms / acc.dicedenoms))))
        if self.param['DEP']:
            msg.append('unlabeled dependencies: %s %s (%d / %d)' % (
                nozerodiv(lambda: accuracy(acc40.golddep, acc40.canddep)),
                nozerodiv(lambda: accuracy(acc.golddep, acc.canddep)),
                sum(a[0] == a[1] for a in zip(acc.golddep, acc.canddep)),
                len(acc.golddep)))
        if acc.candfun:
            msg.append('function tags: %s %s' % (
                nozerodiv(lambda: f_measure(acc40.goldfun, acc40.candfun)),
                nozerodiv(lambda: f_measure(acc.goldfun, acc.candfun))))
        msg.append('pos accuracy: %s %s' % (
            nozerodiv(lambda: accuracy(acc40.goldpos, acc40.candpos)),
            nozerodiv(lambda: accuracy(acc.goldpos, acc.candpos))))
        return '\n'.join(msg)
class TreePairResult(object):
    """Holds the evaluation result of a pair of trees."""

    def __init__(self, n, gtree, gsent, ctree, csent, param):
        """Construct a pair of gold and candidate trees for evaluation.

        :param n: identifier of this sentence.
        :param gtree, ctree: gold and candidate trees; transformed in-place.
        :param gsent, csent: lists of tokens. The originals are kept as
            ``gsentorig``/``csentorig``; copies (``gsent``/``csent``) are
            modified by the transformation.
        :param param: dictionary of evaluation parameters.
        :raises ValueError: if the sentences differ after transformation."""
        self.n = n
        self.param = param
        self.csentorig, self.gsentorig = csent, gsent
        self.csent, self.gsent = csent[:], gsent[:]
        self.cpos, self.gpos = sorted(ctree.pos()), sorted(gtree.pos())
        # sentence length as used for the CUTOFF_LEN criterion
        self.lengpos = sum(1 for _, b in self.gpos
                           if b not in self.param['DELETE_LABEL_FOR_LENGTH'])
        # indices of preterminals that are direct children of the gold root
        grootpos = {child[0] for child in gtree
                    if child and isinstance(child[0], int)}
        # massage the data (in-place modifications)
        self.ctree = transform(ctree, self.csent, self.cpos,
                               alignsent(self.csent, self.gsent,
                                         dict(self.gpos)),
                               self.param, grootpos)
        self.gtree = transform(gtree, self.gsent, self.gpos,
                               dict(self.gpos), self.param, grootpos)
        if len(self.csent) != len(self.gsent):
            raise ValueError('sentence length mismatch. sents:\n%s\n%s' % (
                ' '.join(self.csent), ' '.join(self.gsent)))
        if self.csent != self.gsent:
            raise ValueError('candidate & gold sentences do not match:\n'
                             '%r // %r' % (' '.join(csent), ' '.join(gsent)))
        self.cbrack = bracketings(self.ctree, self.param['LABELED'],
                                  self.param['DELETE_LABEL'],
                                  self.param['DISC_ONLY'])
        self.gbrack = bracketings(self.gtree, self.param['LABELED'],
                                  self.param['DELETE_LABEL'],
                                  self.param['DISC_ONLY'])
        # NaN placeholders for scores that are not computed below
        self.lascore = self.ted = self.denom = Decimal('nan')
        self.cdep = self.gdep = ()
        self.pgbrack = Counter()
        self.pcbrack = Counter()
        self.grule = Counter()
        self.crule = Counter()
        # collect the function tags for correct bracketings & POS tags
        self.candfun = Counter((bracketing(a), b)
                               for a in self.ctree.subtrees()
                               for b in functions(a)
                               if bracketing(a) in self.gbrack or (
                                   a and isinstance(a[0], int)
                                   and self.gpos[a[0]] == a.label))
        self.goldfun = Counter((bracketing(a), b)
                               for a in self.gtree.subtrees()
                               for b in functions(a)
                               if bracketing(a) in self.cbrack or (
                                   a and isinstance(a[0], int)
                                   and self.cpos[a[0]] == a.label))
        if not self.gpos:
            return  # avoid 'sentences' with only punctuation.
        if self.param['LA']:
            self.lascore = leafancestor(self.gtree, self.ctree,
                                        self.param['DELETE_LABEL'])
        if self.param['TED']:
            self.ted, self.denom = treedisteval(
                self.gtree, self.ctree,
                includeroot=self.gtree.label not in self.param['DELETE_LABEL'])
        if self.param['DEP']:
            self.cdep = dependencies(self.ctree)
            self.gdep = dependencies(self.gtree)
        assert self.lascore != 1 or self.gbrack == self.cbrack, (
            'leaf ancestor score 1.0 but no exact match: (bug?)')
        self.pgbrack = parentedbracketings(
            self.gtree, labeled=True,
            dellabel=self.param['DELETE_LABEL'],
            disconly=self.param['DISC_ONLY'])
        self.pcbrack = parentedbracketings(
            self.ctree, labeled=True,
            dellabel=self.param['DELETE_LABEL'],
            disconly=self.param['DISC_ONLY'])
        self.grule = Counter((node.bitset, rule)
                             for node, rule in zip(
                                 self.gtree.subtrees(),
                                 grammar.lcfrsproductions(self.gtree,
                                                          self.gsent)))
        self.crule = Counter((node.bitset, rule)
                             for node, rule in zip(
                                 self.ctree.subtrees(),
                                 grammar.lcfrsproductions(self.ctree,
                                                          self.csent)))

    def info(self, fmt='%8s '):
        """Print one line with evaluation results.

        :param fmt: format string for the sentence ID column."""
        # NOTE(review): the other scores are passed to nozerodiv as fractions
        # while the leaf-ancestor score is pre-multiplied by 100; confirm
        # against nozerodiv's own scaling.
        print((fmt + '%5d %s %s %5d %5d %5d %5d %4d %s %s%s%s%s') % (
            self.n, self.lengpos,
            nozerodiv(lambda: recall(self.gbrack, self.cbrack)),
            nozerodiv(lambda: precision(self.gbrack, self.cbrack)),
            sum((self.gbrack & self.cbrack).values()),
            sum(self.gbrack.values()), sum(self.cbrack.values()),
            len(self.gpos),
            sum(1 for a, b in zip(self.gpos, self.cpos) if a == b),
            nozerodiv(lambda: accuracy(self.gpos, self.cpos)),
            nozerodiv(lambda: f_measure(self.goldfun, self.candfun))
            if self.candfun else '',
            nozerodiv(lambda: 100 * self.lascore)
            if self.param['LA'] else '',
            nozerodiv(lambda: self.ted) if self.param['TED'] else '',
            nozerodiv(lambda: accuracy(self.gdep, self.cdep))
            if self.param['DEP'] else ''))

    def debug(self):
        """Print detailed information."""
        print('Sentence:', ' '.join(self.gsent))
        print('Gold tree:\n%s\nCandidate tree:\n%s' % (
            self.visualize(gold=True), self.visualize(gold=False)))
        print('Gold brackets: ', strbracketings(self.gbrack))
        print('Candidate brackets: ', strbracketings(self.cbrack))
        print('Matched brackets: ',
              strbracketings(self.gbrack & self.cbrack))
        print('Unmatched brackets: ', strbracketings(
            (self.cbrack - self.gbrack) | (self.gbrack - self.cbrack)))
        goldpaths = leafancestorpaths(self.gtree, self.param['DELETE_LABEL'])
        candpaths = leafancestorpaths(self.ctree, self.param['DELETE_LABEL'])
        if self.candfun:
            # Function tag items are ((label, bitset), tag); show them as
            # (tag, bitset). Lists (not generators) are passed so that an
            # empty selection is recognized as empty by strbracketings.
            print('Function tags')
            print('gold: ', strbracketings(
                [(a, b) for (_, b), a in self.goldfun]))
            print('candidate: ', strbracketings(
                [(a, b) for (_, b), a in self.candfun]))
            print('matched: ', strbracketings(
                [(a, b) for (_, b), a in self.candfun & self.goldfun]))
            print('unmatched: ', strbracketings(
                [(a, b) for (_, b), a in (self.candfun - self.goldfun)
                 | (self.goldfun - self.candfun)]))
        print('%15s %8s %8s | %10s %36s : %s' % (
            'word', 'gold POS', 'cand POS',
            'path score', 'gold path', 'cand path'))
        for leaf in goldpaths:
            print('%15s %8s %8s %6.3g %40s : %s' % (
                self.gsent[leaf],
                self.gpos[leaf],
                self.cpos[leaf],
                pathscore(goldpaths[leaf], candpaths[leaf]),
                ' '.join(goldpaths[leaf][::-1]),
                ' '.join(candpaths[leaf][::-1])))
        if self.param['LA']:
            print('leaf-ancestor score: %6.3g' % self.lascore)
        if self.param['TED']:
            print('Tree-dist: %g / %g = %g' % (
                self.ted, self.denom, 1 - self.ted / Decimal(self.denom)))
            newtreedist(self.gtree, self.ctree, True)
        if self.param['DEP']:
            print('Sentence:', ' '.join(self.gsent))
            print('dependencies gold', ' ' * 35, 'cand')
            for (a, _, b), (c, _, d) in zip(self.gdep, self.cdep):
                # use original sentences because we don't delete
                # punctuation for dependency evaluation
                print('%15s -> %15s %15s -> %15s' % (
                    self.gsentorig[a - 1][0], self.gsentorig[b - 1][0],
                    self.csentorig[c - 1][0], self.csentorig[d - 1][0]))
        print()

    def scores(self):
        """Return precision, recall, f-measure for sentence pair.

        Also includes POS tagging accuracy (``POS``) and function tag
        F-measure (``FUN``), all formatted by ``nozerodiv``."""
        return dict(LP=nozerodiv(lambda: precision(self.gbrack, self.cbrack)),
                    LR=nozerodiv(lambda: recall(self.gbrack, self.cbrack)),
                    LF=nozerodiv(lambda: f_measure(self.gbrack, self.cbrack)),
                    POS=nozerodiv(lambda: accuracy(self.gpos, self.cpos)),
                    FUN=nozerodiv(lambda: f_measure(self.goldfun,
                                                    self.candfun)))

    def bracketings(self):
        """Return a string representation of bracketing errors."""
        msg = ''
        if self.cbrack - self.gbrack:
            msg += 'cand-gold=%s ' % strbracketings(self.cbrack - self.gbrack)
        if self.gbrack - self.cbrack:
            msg += 'gold-cand=%s' % strbracketings(self.gbrack - self.cbrack)
        return msg

    def visualize(self, gold=False):
        """Visualize candidate parse, highlight matching POS, bracketings.

        :param gold: by default, the candidate tree is visualized; if True,
            visualize the gold tree instead."""
        # highlight nodes of the shown tree that also occur in the other tree
        tree, brack, pos = self.ctree, self.gbrack, self.gpos
        if gold:
            tree, brack, pos = self.gtree, self.cbrack, self.cpos
        if not tree:  # avoid empty trees with just punctuation
            return ''
        if self.candfun:
            # copy, because handlefunctions() below modifies the tree
            tree = tree.copy(True)
        highlight = list(tree.subtrees(lambda n: bracketing(n) in brack))
        highlight.extend(tree.subtrees(lambda n: n and isinstance(n[0], int)
                                       and n.label == pos[n[0]]))
        highlight.extend(range(len(pos)))
        highlightfunc = ()
        if self.candfun:
            # computed once instead of once per (node, function tag)
            matchedfun = self.candfun & self.goldfun
            highlightfunc = [a for a in tree.subtrees()
                             if all((bracketing(a), b) in matchedfun
                                    for b in functions(a))]
            handlefunctions('add', tree)
        return DrawTree(tree, self.csent, highlight=highlight,
                        highlightfunc=highlightfunc).text(
            unicodelines=True, ansi=True,
            funcsep='-' if self.candfun else None)
class EvalAccumulator(object):
    """Collect scores of evaluation."""

    def __init__(self, disconly=False):
        """:param disconly: if True, only collect discontinuous bracketings."""
        self.disconly = disconly
        self.maxlenseen, self.sentcount = Decimal(0), Decimal(0)
        self.exact = Decimal(0)
        self.dicenoms, self.dicedenoms = Decimal(0), Decimal(0)
        self.goldb, self.candb = Counter(), Counter()  # all brackets
        self.goldfun, self.candfun = Counter(), Counter()
        self.lascores = []
        self.golddep, self.canddep = [], []
        self.goldpos, self.candpos = [], []
        # extra accounting for breakdowns:
        self.goldbcat = defaultdict(Counter)  # brackets per category
        self.candbcat = defaultdict(Counter)
        self.goldbfunc = defaultdict(Counter)  # brackets by function tag
        self.candbfunc = defaultdict(Counter)
        self.goldbatt, self.candbatt = set(), set()  # attachments per category
        self.goldrule, self.candrule = Counter(), Counter()

    @staticmethod
    def _hasvalue(value):
        """Return True if ``value`` is an actual score (not None and not NaN).

        TreePairResult initializes uncomputed scores to ``Decimal('nan')``;
        NaN is the only value that does not compare equal to itself."""
        return value is not None and value == value

    def add(self, pair):
        """Add scores from given TreePairResult object.

        Leaf-ancestor and tree-distance scores that were not computed for
        this pair (NaN placeholders) are skipped, so that they do not turn the
        corpus-level averages into NaN."""
        if not self.disconly or pair.cbrack or pair.gbrack:
            self.sentcount += 1
        if self.maxlenseen < pair.lengpos:
            self.maxlenseen = pair.lengpos
        self.candb.update((pair.n, a) for a in pair.cbrack.elements())
        self.goldb.update((pair.n, a) for a in pair.gbrack.elements())
        if pair.cbrack == pair.gbrack:
            if not self.disconly or pair.cbrack or pair.gbrack:
                self.exact += 1
        self.goldpos.extend(pair.gpos)
        self.candpos.extend(pair.cpos)
        self.goldfun.update((pair.n, a) for a in pair.goldfun.elements())
        self.candfun.update((pair.n, a) for a in pair.candfun.elements())
        if self._hasvalue(pair.lascore):
            self.lascores.append(pair.lascore)
        if self._hasvalue(pair.ted) and self._hasvalue(pair.denom):
            self.dicenoms += pair.ted
            self.dicedenoms += pair.denom
        if pair.gdep is not None:
            self.golddep.extend(pair.gdep)
            self.canddep.extend(pair.cdep)
        # extra bookkeeping for breakdowns
        for a, n in pair.gbrack.items():
            self.goldbcat[a[0]][(pair.n, a)] += n
        for a, n in pair.cbrack.items():
            self.candbcat[a[0]][(pair.n, a)] += n
        for a, n in pair.goldfun.items():
            self.goldbfunc[a[1]][(pair.n, a)] += n
        for a, n in pair.candfun.items():
            self.candbfunc[a[1]][(pair.n, a)] += n
        for (label, indices), parent in pair.pgbrack:
            self.goldbatt.add(((pair.n, label, indices), parent))
        for (label, indices), parent in pair.pcbrack:
            self.candbatt.add(((pair.n, label, indices), parent))
        self.goldrule.update((pair.n, indices, rule)
                             for indices, rule in pair.grule.elements())
        self.candrule.update((pair.n, indices, rule)
                             for indices, rule in pair.crule.elements())

    def scores(self):
        """Return a dictionary with running scores for all added sentences."""
        return dict(lr=nozerodiv(lambda: recall(self.goldb, self.candb)),
                    lp=nozerodiv(lambda: precision(self.goldb, self.candb)),
                    lf=nozerodiv(lambda: f_measure(self.goldb, self.candb)),
                    ex=nozerodiv(lambda: self.exact / self.sentcount),
                    tag=nozerodiv(lambda: accuracy(self.goldpos,
                                                   self.candpos)),
                    fun=nozerodiv(lambda: f_measure(self.goldfun,
                                                    self.candfun)))
def main():
    """Command line interface for evaluation.

    Parses ``sys.argv[2:]`` as ``<gold> <parses> [param] [options]``, reads
    both treebanks, evaluates each candidate tree against the gold tree with
    the same key, and prints breakdowns (when LABELED is set and DEBUG is not
    -1) followed by the summary. Exits with status 2 on usage errors."""
    flags = {'help', 'verbose', 'debug', 'disconly', 'ted', 'la'}
    options = {'goldenc=', 'parsesenc=', 'goldfmt=', 'parsesfmt=', 'fmt=',
               'cutofflen=', 'headrules=', 'functions=', 'morphology='}
    # NOTE(review): -h/--help are accepted but not acted upon in this function.
    try:
        optlist, args = gnu_getopt(sys.argv[2:], 'h', flags | options)
    except GetoptError as err:
        print('error:', err, file=sys.stderr)
        print(SHORTUSAGE)
        sys.exit(2)
    opts = dict(optlist)
    if not 2 <= len(args) <= 3:
        print('error: Wrong number of arguments.', file=sys.stderr)
        print(SHORTUSAGE)
        sys.exit(2)
    goldfile, parsesfile = args[0], args[1]
    param = readparam(args[2] if len(args) == 3 else None)
    # command line options override the parameter file
    param['CUTOFF_LEN'] = int(opts.get('--cutofflen', param['CUTOFF_LEN']))
    param['DISC_ONLY'] = '--disconly' in opts
    param['DEBUG'] = max(param['DEBUG'],
                         '--verbose' in opts, 2 * ('--debug' in opts))
    param['TED'] |= '--ted' in opts
    param['LA'] |= '--la' in opts
    param['DEP'] = '--headrules' in opts
    # --fmt sets the format of both files at once
    goldfmt = opts.get('--goldfmt', 'export')
    parsesfmt = opts.get('--parsesfmt', 'export')
    if '--fmt' in opts:
        goldfmt = parsesfmt = opts['--fmt']
    goldreader, parsesreader = READERS[goldfmt], READERS[parsesfmt]
    readeropts = dict(functions=opts.get('--functions', 'remove'),
                      morphology=opts.get('--morphology'),
                      headrules=opts.get('--headrules'))
    gold = goldreader(goldfile,
                      encoding=opts.get('--goldenc', 'utf8'), **readeropts)
    parses = parsesreader(parsesfile,
                          encoding=opts.get('--parsesenc', 'utf8'),
                          **readeropts)
    goldtrees, goldsents = gold.trees(), gold.sents()
    candtrees, candsents = parses.trees(), parses.sents()
    if not goldtrees:
        raise ValueError('no trees in gold file')
    if not candtrees:
        raise ValueError('no trees in parses file')
    if param['DEBUG'] >= 2:
        print('gold:', goldfile)
        print('parses:', parsesfile, '\n')
    keylen = max(len(str(key)) for key in candtrees)
    evaluator = Evaluator(param, keylen)
    for key, ctree in candtrees.items():
        evaluator.add(key, goldtrees[key], goldsents[key], ctree,
                      candsents[key])
    if param['LABELED'] and param['DEBUG'] != -1:
        evaluator.breakdowns()
    print(evaluator.summary())
def readparam(filename):
    """Read an EVALB-style parameter file and return a dictionary.

    Every line that is not blank and does not start with ``#`` has the form
    ``KEY value``; surrounding whitespace is ignored.

    :param filename: path of a UTF-8 encoded parameter file; if ``None`` or
        empty, the default parameters are returned.
    :returns: a dictionary with all parameters. Single-valued keys map to
        ints; DELETE_LABEL, DELETE_LABEL_FOR_LENGTH and DELETE_WORD map to
        sets of strings; EQ_LABEL and EQ_WORD map to a dictionary from each
        element to the representative of its equivalence class.
    :raises ValueError: if a single-valued key is declared twice, a key is
        not recognized, a line has no value, or an EQ_* line does not have
        exactly two values."""
    # NB: we ignore MAX_ERROR, we abort immediately on error.
    validkeysonce = ('DEBUG', 'MAX_ERROR', 'CUTOFF_LEN', 'LABELED',
                     'DISC_ONLY', 'LA', 'TED', 'DEP', 'DELETE_ROOT_PRETERMS')
    param = {'DEBUG': 0, 'MAX_ERROR': 10, 'CUTOFF_LEN': 40,
             'LABELED': 1, 'DELETE_LABEL_FOR_LENGTH': set(),
             'DELETE_LABEL': set(), 'DELETE_WORD': set(),
             'EQ_LABEL': set(), 'EQ_WORD': set(),
             'DISC_ONLY': 0, 'LA': 0, 'TED': 0, 'DEP': 0,
             'DELETE_ROOT_PRETERMS': 0}
    seen = set()
    lines = []
    if filename:
        with io.open(filename, encoding='utf8') as inp:
            lines = inp.read().splitlines()
    for line in lines:
        # Strip first: otherwise trailing blanks end up in values (a label
        # 'TOP ' never matches 'TOP') and whitespace-only or indented comment
        # lines are parsed as parameters.
        line = line.strip()
        if not line or line.startswith('#'):
            continue
        fields = line.split(None, 1)
        if len(fields) != 2:
            raise ValueError('%s requires a value' % fields[0])
        key, val = fields
        if key in validkeysonce:
            if key in seen:
                raise ValueError('cannot declare %s twice' % key)
            seen.add(key)
            param[key] = int(val)
        elif key in ('DELETE_LABEL', 'DELETE_LABEL_FOR_LENGTH',
                     'DELETE_WORD'):
            param[key].add(val)
        elif key in ('EQ_LABEL', 'EQ_WORD'):
            # these are given as undirected pairs (A, B), (B, C), ...
            # to be represented as equivalence classes A => {A, B, C, D}
            try:
                b, c = val.split()
            except ValueError:
                raise ValueError('%s requires two values' % key)
            param[key].add((b, c))
        else:
            raise ValueError('unrecognized parameter key: %s' % key)
    for key in ('EQ_LABEL', 'EQ_WORD'):
        # from arbitrary pairs: [('A', 'B'), ('B', 'C')]
        # to eq classes: {'A': {'A', 'B', 'C'}}
        # to a mapping of all elements to their representative:
        # {'A': 'A', 'B': 'A', 'C': 'A'}
        param[key] = {x: k
                      for k, eqclass in transitiveclosure(param[key]).items()
                      for x in eqclass}
    return param
def transitiveclosure(eqpairs):
    """Transitive closure of (undirected) EQ relations with DFS.

    Given a sequence of pairs denoting an equivalence relation,
    produce a dictionary with equivalence classes as values and
    a member of each class as key. When the elements can be ordered, the key
    is the smallest member of its class, so that the result does not depend
    on the iteration order of ``eqpairs`` (e.g., of a set of strings, whose
    order varies between interpreter runs).

    >>> result = transitiveclosure({('A', 'B'), ('B', 'C')})
    >>> sorted(result)
    ['A']
    >>> sorted(result['A'])
    ['A', 'B', 'C']"""
    neighbors = defaultdict(set)
    for a, b in eqpairs:
        neighbors[a].add(b)
        neighbors[b].add(a)
    try:
        order = sorted(neighbors)
    except TypeError:  # elements of mixed, unorderable types
        order = list(neighbors)
    eqclasses = {}
    seen = set()
    for elem in order:
        if elem in seen:
            continue
        # depth-first search from elem over the undirected edges
        eqclass = {elem}
        agenda = [elem]
        while agenda:
            current = agenda.pop()
            for other in neighbors[current]:
                if other not in eqclass:
                    eqclass.add(other)
                    agenda.append(other)
        seen |= eqclass
        eqclasses[elem] = eqclass
    return eqclasses
def alignsent(csent, gsent, gpos):
    """Map tokens of ``csent`` onto those of ``gsent``, and translate indices.

    Tokens are aligned greedily from left to right: each candidate token is
    matched with the next identical gold token; gold tokens in between are
    skipped. Alignment stops when the gold sentence is exhausted.

    :returns: a copy of ``gpos`` with indices of ``csent`` as keys,
        but tags from ``gpos``.

    >>> gpos = {0: "``", 1: 'RB', 2: '.', 3: "''"}
    >>> alignsent(['No'], ['``', 'No', '.', "''"], gpos) == {0: 'RB'}
    True"""
    aligned = {}
    gidx = 0
    for cidx, token in enumerate(csent):
        # advance through the gold sentence until the token is found
        while gidx < len(gsent) and not token == gsent[gidx]:
            gidx += 1
        if gidx >= len(gsent):
            break
        aligned[cidx] = gpos[gidx]
        gidx += 1
    return aligned
def parentedbracketings(tree, labeled=True, dellabel=(), disconly=False):
    """Return the labeled bracketings with parents for a tree.

    Only nonempty nodes whose first child is a ``Tree`` (i.e., not
    preterminals) are included; nodes with a label in ``dellabel`` are
    skipped, and if ``disconly`` is True, so are continuous nodes. The parent
    label is '' for nodes without a parent.

    :returns:
        multiset with items of the form ``((label, indices), parentlabel)``
    """
    result = Counter()
    for node in tree.subtrees():
        if not node or not isinstance(node[0], Tree):
            continue
        if node.label in dellabel:
            continue
        if disconly and not isdisc(node):
            continue
        parentlabel = getattr(node.parent, 'label', '')
        result[bracketing(node, labeled), parentlabel] += 1
    return result
def bracketings(tree, labeled=True, dellabel=(), disconly=False):
    """Return the labeled set of bracketings for a tree.

    For each nonterminal node, the set will contain a tuple with the label and
    the set of terminals which it dominates.
    ``tree`` must have been processed by ``transform()``.
    The argument ``dellabel`` is only used to exclude the ROOT node from the
    results (because it cannot be deleted by ``transform()`` when non-unary).

    >>> tree = Tree('(S (NP 1) (VP (VB 0) (JJ 2)))')
    >>> params = {'DELETE_LABEL': set(), 'DELETE_WORD': set(),
    ... 'EQ_LABEL': {}, 'EQ_WORD': {},
    ... 'DELETE_ROOT_PRETERMS': 0}
    >>> tree = transform(tree, tree.leaves(), tree.pos(), dict(tree.pos()),
    ... params, set())
    >>> for (label, span), cnt in sorted(bracketings(tree).items()):
    ... print(label, bin(span), cnt)
    S 0b111 1
    VP 0b101 1
    >>> tree = Tree('(S (NP 1) (VP (VB 0) (JJ 2)))')
    >>> params['DELETE_LABEL'] = {'VP'}
    >>> tree = transform(tree, tree.leaves(), tree.pos(), dict(tree.pos()),
    ... params, set())
    >>> for (label, span), cnt in sorted(bracketings(tree).items()):
    ... print(label, bin(span), cnt)
    S 0b111 1"""
    result = Counter()
    for node in tree.subtrees():
        if not node or not isinstance(node[0], Tree):
            continue  # empty node or preterminal
        if node.label in dellabel or (disconly and not isdisc(node)):
            continue
        result[bracketing(node, labeled)] += 1
    return result
def bracketing(node, labeled=True):
    """Generate bracketing ``(label, indices)`` for a given node.

    :param labeled: if False, the empty string is used instead of the label.
    :returns: a tuple of the label and the node's ``bitset``."""
    label = node.label if labeled else ''
    return label, node.bitset
def strbracketings(brackets):
    """Return a string with a concise representation of a bracketing.

    :param brackets: an iterable (set, Counter, list, generator, ...) of
        ``(label, bitset)`` pairs.
    :returns: the bracketings sorted and joined by ', ', or ``'{}'`` if there
        are none.

    >>> print(strbracketings({('S', 0b111), ('VP', 0b101)}))
    S[0-2], VP[0,2]
    >>> print(strbracketings(iter([])))
    {}
    """
    # Materialize before testing for emptiness: generators and iterators are
    # always truthy, so checking ``brackets`` itself misses empty input.
    items = sorted(brackets)
    if not items:
        return '{}'
    return ', '.join('%s[%s]' % (label, ','.join(
        '-'.join('%d' % y for y in sorted(set(x)))
        for x in intervals(bitset))) for label, bitset in items)
def leafancestorpaths(tree, dellabel):
    """Generate a list of ancestors for each leaf node in a tree.

    :param tree: a tree whose nodes have a ``bitset`` of dominated leaf
        indices.
    :param dellabel: labels of nodes that are not added to the paths.
    :returns: a dict mapping each leaf index to a list of labels and
        boundary markers. Nodes are visited top-down, so higher nodes come
        first in the list."""
    # uses [] to mark components, and () to mark constituent boundaries
    # deleted words/tags should not affect boundary detection
    paths = {a: [] for a in getbits(tree.bitset)}
    # do a top-down level-order traversal
    thislevel = [tree]
    while thislevel:
        nextlevel = []
        for n in thislevel:
            leaves = list(getbits(n.bitset))
            # skip empty nodes and POS tags
            if not leaves or (not n or not isinstance(n[0], Tree)):
                continue
            # set for O(1) adjacency tests; the list keeps the leaf order
            leafset = set(leaves)
            first, last = min(leaves), max(leaves)
            # skip root node if it is to be deleted
            if n.label not in dellabel:
                for b in leaves:
                    # mark end of constituents / components
                    if b + 1 not in leafset:
                        if b == last and ')' not in paths[b]:
                            paths[b].append(')')
                        elif b != last and ']' not in paths[b]:
                            paths[b].append(']')
                    # add this label to the lineage
                    paths[b].append(n.label)
                    # mark beginning of constituents / components
                    if b - 1 not in leafset:
                        if b == first and '(' not in paths[b]:
                            paths[b].append('(')
                        elif b != first and '[' not in paths[b]:
                            paths[b].append('[')
            nextlevel.extend(n)
        thislevel = nextlevel
    return paths
def pathscore(gold, cand):
    """Get edit distance for two leaf-ancestor paths.

    :returns: one minus the edit distance between the paths, divided by
        their combined length (or by 1 if both are empty), as a Decimal."""
    distance = Decimal(editdistance(cand, gold))
    combined = max(len(gold) + len(cand), 1)
    return 1 - distance / combined
def leafancestor(goldtree, candtree, dellabel):
    """Sampson, Babarcz (2002): A test of the leaf-ancestor metric [...].

    http://www.lrec-conf.org/proceedings/lrec2002/pdf/ws20.pdf p. 27;
    2003 journal paper: https://doi.org/10.1017/S1351324903003243

    :returns: the mean path score over the leaves of the gold tree."""
    goldpaths = leafancestorpaths(goldtree, dellabel)
    candpaths = leafancestorpaths(candtree, dellabel)
    scores = [pathscore(goldpath, candpaths[leaf])
              for leaf, goldpath in goldpaths.items()]
    return mean(scores)
def treedisteval(a, b, includeroot=False, debug=False):
    """Get tree-distance for two trees and compute the Dice normalization.

    :returns: a tuple ``(ted, denom)``: the tree edit distance, and the total
        number of nodes in both trees (the Dice denominator), minus 2 for the
        two root nodes unless ``includeroot`` is True."""
    ted = treedist(a, b, debug)
    nodecount = sum(1 for _ in a.subtrees()) + sum(1 for _ in b.subtrees())
    if includeroot:
        return ted, nodecount
    return ted, nodecount - 2  # discount the root node of each tree
# If the gold file contains n constituents with a given label for the same
# span, and the parsed file contains m constituents with that label and span,
# the scorer works as follows for those constituents:
#
# i) If m>n, then the precision is n/m, recall is 100%.
# ii) If n>m, then the precision is 100%, recall is m/n.
# iii) If n==m, recall and precision are both 100%.
def recall(reference, candidate):
    """Return the fraction of ``reference`` items also present in ``candidate``.

    Both arguments are multisets (``Counter`` objects). An item occurring
    ``n`` times in the reference and ``m`` times in the candidate contributes
    ``min(n, m)`` matches.

    :returns: matches divided by the total count of ``reference``, as a
        ``Decimal``; NaN when ``reference`` is empty."""
    if reference:
        overlap = reference & candidate
        hits = sum(min(reference[item], candidate[item]) for item in overlap)
        return Decimal(hits) / sum(reference.values())
    return Decimal('NaN')
def precision(reference, candidate):
    """Return the fraction of ``candidate`` items also present in ``reference``.

    Both arguments are multisets (``Counter`` objects). An item occurring
    ``n`` times in the reference and ``m`` times in the candidate contributes
    ``min(n, m)`` matches.

    :returns: matches divided by the total count of ``candidate``, as a
        ``Decimal``; NaN when ``candidate`` is empty."""
    if not candidate:
        return Decimal('NaN')
    shared = reference & candidate
    matched = sum(min(reference[key], candidate[key]) for key in shared)
    return Decimal(matched) / sum(candidate.values())
def f_measure(reference, candidate, alpha=Decimal(0.5)):
    """Return the weighted harmonic mean of precision and recall.

    Computes ``1 / (alpha / P + (1 - alpha) / R)`` for the multisets
    ``reference`` and ``candidate``. The default weight ``alpha=0.5``
    corresponds to the F_1-measure.

    :returns: a ``Decimal``. The result is NaN when precision or recall is
        zero, and also when either is NaN (an empty multiset), because NaN
        propagates through the arithmetic."""
    prec = precision(reference, candidate)
    rec = recall(reference, candidate)
    if prec != 0 and rec != 0:
        return Decimal(1) / (alpha / prec + (1 - alpha) / rec)
    return Decimal('NaN')
def accuracy(reference, candidate):
    """Return the fraction of positions at which two sequences agree.

    That is, the proportion of indices ``0 <= i < len(reference)`` such that
    ``candidate[i] == reference[i]``, as a ``Decimal``.

    :raises ValueError: if the sequences differ in length.
    Empty sequences are not special-cased: the resulting ``Decimal(0) / 0``
    raises ``decimal.InvalidOperation``."""
    if len(candidate) != len(reference):
        raise ValueError('Sequences must have the same length.')
    agreeing = sum(1 for gold, test in zip(reference, candidate)
                   if gold == test)
    return Decimal(agreeing) / len(reference)
def harmean(seq):
    """Return the harmonic mean of the numbers in ``seq`` as a ``Decimal``.

    Returns NaN when ``seq`` is empty, when it contains a zero (iteration
    stops at the first zero), or when the reciprocals sum to zero."""
    n = Decimal(0)
    reciprocals = Decimal(0)
    for value in seq:
        if value:
            n += 1
            reciprocals += Decimal(1) / value
        else:
            return Decimal('NaN')
    return n / reciprocals if reciprocals else Decimal('NaN')
def mean(seq):
    """Return the arithmetic mean of ``seq`` as a ``Decimal``.

    Elements are accumulated onto ``Decimal(0)``, so they must be values
    that ``Decimal`` can add (e.g. ``int`` or ``Decimal``, not ``float``).
    Returns NaN when ``seq`` is empty."""
    total = Decimal(0)
    size = 0
    for size, value in enumerate(seq, 1):
        total += value
    if size == 0:
        return Decimal('NaN')
    return total / size
def intervals(bitset):
    """Yield one ``(start, end)`` pair per maximal run of consecutive set bits.

    ``bitset`` is an integer representing a bitvector. Each pair satisfies
    ``start <= end`` and denotes the contiguous range of one bits ``x`` in
    ``bitset`` with ``start <= x <= end``. Positions come from ``getbits``,
    which is expected to report them in ascending order (as in the example).

    >>> list(intervals(0b111011011))  # NB: read from right to left
    [(0, 1), (3, 4), (6, 8)]"""
    run = None  # [start, end] of the run being extended
    for pos in getbits(bitset):
        if run is not None and pos == run[1] + 1:
            run[1] = pos
            continue
        if run is not None:
            yield run[0], run[1]
        run = [pos, pos]
    if run is not None:
        yield run[0], run[1]
def nozerodiv(func):
    """Call ``func`` and format its result as a 6-character percentage.

    :param func: a zero-argument callable returning a fraction (e.g. a
        ``Decimal``) or ``None``.
    :returns: ``100 * func()`` formatted with ``'%6.2f'``. Returns the
        right-aligned ``'  None'`` when ``func()`` returns ``None``, and
        ``' 0DIV!'`` when ``func`` raises ``ZeroDivisionError`` or
        ``decimal.InvalidOperation`` (e.g. ``Decimal(0) / 0``)."""
    try:
        result = func()
    except (ZeroDivisionError, InvalidOperation):
        return ' 0DIV!'
    if result is None:
        # Pad to the same 6-character width as the other cases so that
        # table columns stay aligned.
        return '%6s' % 'None'
    return '%6.2f' % (100 * result)
def editdistance(seq1, seq2):
    """Return the Levenshtein edit distance between two sequences.

    The distance is the minimum number of single-element substitutions,
    insertions, and deletions needed to transform ``seq1`` into ``seq2``.
    The sequences may be strings or lists (elements are compared with
    ``!=``). For example, transforming 'rain' into 'shine' takes three
    steps, two substitutions and one insertion:
    'rain' -> 'sain' -> 'shin' -> 'shine'. These operations could have been
    done in other orders, but at least three steps are needed.

    Uses the standard dynamic program, keeping only two rows of the table."""
    # previous[j] holds the distance between seq1[:i] and seq2[:j].
    previous = list(range(len(seq2) + 1))
    for i, item1 in enumerate(seq1, 1):
        current = [i]
        for j, item2 in enumerate(seq2, 1):
            current.append(min(
                previous[j] + 1,                     # drop item1
                previous[j - 1] + (item1 != item2),  # match or substitute
                current[j - 1] + 1))                 # insert item2
        previous = current
    return previous[-1]
def pyintbitcount(a):
    """Return number of set bits (1s) in a non-negative Python integer.

    >>> pyintbitcount(0b0011101)
    4

    :raises ValueError: if ``a`` is negative. In two's complement a negative
        int has infinitely many set bits, and the previous ``a &= a - 1``
        loop never terminated on such input."""
    if a < 0:
        raise ValueError('pyintbitcount() requires a non-negative integer')
    return bin(a).count('1')
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    'Evaluator',
    'TreePairResult',
    'EvalAccumulator',
    'main',
    'readparam',
    'transitiveclosure',
    'alignsent',
    'transform',
    'parentedbracketings',
    'bracketings',
    'bracketing',
    'strbracketings',
    'leafancestorpaths',
    'pathscore',
    'leafancestor',
    'treedisteval',
    'recall',
    'precision',
    'f_measure',
    'accuracy',
    'harmean',
    'mean',
    'intervals',
    'nozerodiv',
    'editdistance',
    'pyintbitcount',
]