"""Treebank-independent tree transformations.
This file contains three main transformations:
- A straightforward binarization: binarize(), based on NLTK code.
Provides some additional Markovization options.
- An optimal binarization for LCFRS: optimalbinarize()
Cf. Gildea (2010): Optimal parsing strategies for linear
context-free rewriting systems. http://aclweb.org/anthology/N10-1118
- Converting discontinuous trees to continuous trees and back:
splitdiscnodes(). Cf. Boyd (2007): Discontinuity revisited.
http://aclweb.org/anthology/W07-1506"""
# Original notice:
# Natural Language Toolkit: Tree Transformations
#
# Copyright (C) 2005-2007 Oregon Graduate Institute
# Author: Nathan Bodenstab <bodenstab@cslu.ogi.edu>
# URL: <http://www.nltk.org/>
# For license information, see LICENSE.TXT
import re
from operator import attrgetter
from itertools import islice
from collections import defaultdict, Counter
from .tree import Tree, ImmutableTree, isdisc, bitfanout, HEAD
from .util import OrderedSet, PyAgenda
# Matches labels of components produced by splitdiscnodes(), e.g. 'VP_2*0':
# group 1: 'VP_2'; group 2: '0' (the component index, only present when
# splitdiscnodes() was called with markorigin=True); group 3: None here
# (an optional suffix of the form 'x!y').
SPLITLABELRE = re.compile(r'(.*)\*(?:([0-9]+)([^!]+![^!]+)?)?$')
# Matches labels of artificial nodes such as 'VP|<NP,VB>^<S>':
# group 1: label of the original node; group 2: optional context ending in
# ';' (e.g., the head label added with markhead); group 3: horizontal
# context; group 4: optional parent annotation '^<...>' (with delimiters);
# group 5: any remaining suffix.
# NOTE(review): group 3 is greedy, so for 'VP|<NP,VB>^<S>' it matches
# 'NP,VB>^<S' and group 4 stays empty -- confirm whether this is intended.
MARKOVRE = re.compile(r'^(.*)\|<(.*;)?(.*)>(\^<.*>)?(.*)?$')
def binarize(tree, factor='right', horzmarkov=999, vertmarkov=1,
        revhorzmarkov=0, markhead=False, headoutward=False,
        childchar='|', parentchar='^', tailmarker='',
        leftmostunary=False, rightmostunary=False, threshold=2,
        artpa=True, ids=None, filterlabels=(),
        labelfun=None, dot=False, direction=False):
    """
    Binarize a Tree object; the tree is modified in-place and returned.

    Artificial nodes introduced by the binarization are labeled with the
    label of the original node, ``childchar`` and a bracketed context, e.g.
    ``VP|<ADV>``; with ``vertmarkov > 1`` (and ``artpa``) a parent annotation
    such as ``^<S>`` is appended as well.

    :param factor: "left" or "right". Determines whether binarization proceeds
        from left to right or vice versa.
    :param horzmarkov: amount of horizontal context in labels. Default is
        infinity, such that no new generalizations are introduced by the
        binarization.
    :param vertmarkov: number of ancestors to include in labels.
        NB: 1 means only the direct parent, as in a normal tree.
        With 0, artificial nodes use the label '_' instead of the label of
        the node being binarized.
    :param revhorzmarkov: like ``horzmarkov``, but looks backwards.
    :param headoutward: nodes are marked as head in their function tags;
        the direction of binarization will be switched when it is
        encountered, to enable a head-outward binarization.
    :param markhead: include label of the head child in all auxiliary labels.
    :param leftmostunary, rightmostunary: introduce a unary production for the
        first/last child. When h=1, this enables the same generalizations
        for the first & last non-terminals as for other siblings.
    :param tailmarker: when given a non-empty string, add this to artificial
        nodes introducing the last symbol. This is useful when the last
        symbol is the head node, ensuring that it is not exchangeable with
        other non-terminals.
    :param dot: if True, horizontal context will include all siblings not yet
        generated, separated with a dot from the siblings that have been.
    :param artpa: whether to add parent annotation to the artificial nodes
        introduced by the binarization.
    :param ids: abbreviate artificial node labels using numeric IDs from this
        object; must have dictionary-like interface.
    :param threshold: constituents with more than this number of children are
        factored; i.e., for a value of 2, do a normal binarization; for a
        value of 1, also factor binary productions to include an artificial
        node, etc.
    :param filterlabels: filter any labels matching this sequence from the
        horizontal markovization context. If labels are of the form ``A/B``,
        only A is used to match against this sequence. Also, ``labelfun`` is
        first applied to the label, if given. Can be used to filter out
        modifiers, s.t. the context contains only required elements.
    :param labelfun: a function to derive a label from a node to be used for
        the horizontal markovization context; the default is to use
        ``child.label`` for a given child node.
    :param direction: if True, mark the direction of the binarization with
        l, r, or m; l is everything before the head, r to the right, and m
        just before introducing the head.
    :returns: ``tree`` itself (modified in-place).
    :raises ValueError: if ``factor`` is neither 'left' nor 'right'.

    >>> tree = Tree('(S (VP (PDS 0) (ADV 3) (VVINF 4)) (VMFIN 1) (PIS 2))')
    >>> tree[1].type = HEAD
    >>> sent = 'das muss man jetzt machen'.split()
    >>> print(binarize(tree, horzmarkov=1, headoutward=True))
    (S (VP (PDS 0) (VP|<ADV> (ADV 3) (VVINF 4))) (S|<VMFIN> (VMFIN 1) (PIS 2)))
    >>> tree = Tree('(S (X (A 0) (B 3) (C 4)) (D 1) (E 2))')
    >>> tree[1].type = HEAD
    >>> print(binarize(tree, headoutward=True, leftmostunary=True,
    ...         rightmostunary=True))  # doctest: +NORMALIZE_WHITESPACE
    (S (S|<X,D,E> (X (X|<A,B,C> (A 0) (X|<B,C> (B 3) (X|<C> (C 4)))))
    (S|<D,E> (S|<D,E> (D 1)) (E 2))))"""
    # FIXME: combination of factor='left' and headoutward=True is broken.
    # FIXME: horiz. markov label is wrong when direction switches
    # NOTE(review): the ``A/B`` splitting described for ``filterlabels`` is
    # only applied in the threshold branch below; the other sites compare
    # the full label -- confirm whether this is intended.
    # assume all nodes have homogeneous children, terminals have no siblings
    if factor not in ('left', 'right'):
        raise ValueError("factor should be 'left' or 'right'.")
    if labelfun is None:
        labelfun = attrgetter('label')
    treeclass = tree.__class__
    origfactor = factor
    # Traverse tree depth-first keeping a list of ancestor nodes to the root.
    agenda = [(tree, [tree.label])]
    while agenda:
        node, parent = agenda.pop()
        if not isinstance(node, Tree):
            continue
        # parent annotation
        parents = ''
        # with vertmarkov=0, artificial nodes get the placeholder label '_'
        origlabel = node.label if vertmarkov else '_'
        # headoutward may have flipped the direction for the previous node
        factor = origfactor
        if vertmarkov > 1 and node is not tree and isinstance(node[0], Tree):
            parents = '%s<%s>' % (parentchar, ','.join(parent))
            node.label += parents
            # ancestor list handed down to the children of this node
            parent = [origlabel] + parent[:vertmarkov - 2]
            if not artpa:
                parents = ''
        # add children to the agenda before we mess with them
        agenda.extend((child, parent) for child in node)
        headidx = None
        if headoutward or markhead:
            for i, child in enumerate(node):
                if isinstance(child, Tree) and child.type == HEAD:
                    headidx = i
                    break
        # binary form factorization
        if len(node) <= threshold:
            continue
        elif 1 <= len(node) <= 2:
            # only reached when threshold < 2:
            # insert an initial artificial nonterminal
            siblings = ''
            if isinstance(node[0], Tree):
                if direction and factor == 'left':
                    siblings += 'r:'
                elif direction and factor == 'right':
                    siblings += 'l:'
                if markhead and headidx is not None:
                    siblings += node[headidx].label + ';'
                siblings += ','.join(labelfun(child)
                        for child in node[:horzmarkov]
                        if labelfun(child).split('/', 1)[0] not in filterlabels)
                if dot:
                    siblings += '.'
            mark = '<%s>%s' % (siblings, parents)
            if ids is not None:  # numeric identifier
                mark = '<%s>' % ids[mark]
            # wrap the first (left) or second (right) child in the new node
            offset = 0 if factor == 'left' else 1
            childnodes = list(node[offset:offset + 1])
            node[offset:offset + 1] = []
            newnode = treeclass(
                    '%s%s%s' % (origlabel, childchar, mark), childnodes)
            offset = -1 if factor == 'left' else 1
            node[offset:offset] = [newnode]
        else:
            if isinstance(node[0], Tree):
                childlabels = [labelfun(child) for child in node]
            else:
                childlabels = []
            childnodes = list(node)
            numchildren = len(childnodes)
            # insert an initial artificial nonterminal
            node[:] = []
            i = 0
            if headoutward and i == headidx:
                factor = 'right' if factor == 'left' else 'left'
            if leftmostunary:
                if factor == 'right':
                    start = i
                    end = i + horzmarkov
                else:  # factor == 'left'
                    start = max(numchildren - i - horzmarkov + (headidx or 0),
                            0)
                    end = min(numchildren - i + (headidx or 0), numchildren)
                siblings = ''
                if direction and factor == 'left':
                    siblings += 'r:'
                elif direction and factor == 'right':
                    siblings += 'l:'
                if markhead and headidx is not None:
                    siblings += childlabels[headidx] + ';'
                if dot:
                    siblings += '.'
                siblings += ','.join(a for a in childlabels[start:end]
                        if a not in filterlabels)
                mark = '<%s>%s' % (siblings, parents)
                if ids is not None:  # numeric identifier
                    mark = '<%s>' % ids[mark]
                newnode = treeclass('%s%s%s' % (origlabel, childchar, mark), [])
                node.append(newnode)
                node = newnode
            # curnode is the node that receives the next child + new node
            curnode = node
            for i in range(1, numchildren - (not rightmostunary)):
                marktail = tailmarker if i + 1 == numchildren else ''
                newnode = treeclass('', [])
                if factor == 'right':
                    start = i
                    end = i + horzmarkov
                else:  # factor == 'left':
                    start = max(numchildren - i - horzmarkov + (headidx or 0),
                            (headidx or 0))
                    end = min(numchildren - i + (headidx or 0),
                            numchildren)
                if factor == 'right':
                    curnode[:] = [childnodes.pop(0), newnode]
                else:  # factor == 'left':
                    curnode[:] = [newnode, childnodes.pop()]
                siblings = ''
                if direction and factor == 'left':
                    siblings += 'r:'
                elif direction and factor == 'right':
                    siblings += 'l:'
                if markhead and headidx is not None:
                    siblings += childlabels[headidx] + ';'
                if dot:
                    siblings += ','.join(a for a in childlabels[:start]
                            if a not in filterlabels) + '.'
                if revhorzmarkov:
                    if factor == 'right':
                        siblings += ','.join(a for a in childlabels[
                                max(start - revhorzmarkov, 0):start]
                                if a not in filterlabels) + ';'
                    else:  # factor == 'left':
                        siblings += ','.join(a for a in childlabels[
                                end:end + revhorzmarkov]
                                if a not in filterlabels) + ';'
                siblings += ','.join([a for a in childlabels[start:]
                        if a not in filterlabels][:horzmarkov])
                mark = '<%s>%s' % (siblings, parents)
                if ids is not None:  # numeric identifier
                    mark = '<%s>' % ids[mark]
                newnode.label = ''.join((origlabel, childchar, marktail, mark))
                curnode = newnode
                # switch direction upon encountering the head
                if headoutward and i == headidx:
                    factor = 'right' if factor == 'left' else 'left'
                if (headoutward and direction and i == headidx
                        and i + 1 != numchildren):
                    # insert unary for switch of direction
                    newnode = treeclass(curnode.label, curnode[:])
                    curnode[:] = [newnode]
                    if dot:
                        # replace the 'l'/'r' direction mark with 'm'
                        siblings = 'm' + siblings[1:]
                    else:
                        siblings = 'm:' if direction else ''
                        if markhead and headidx is not None:
                            siblings += childlabels[headidx] + ';'
                    newnode.label = ''.join((
                            origlabel, childchar, '<', siblings, '>', parents))
                    curnode = newnode
            # the last one or two children are attached to the innermost node
            assert len(childnodes) == 1 + (not rightmostunary)
            curnode.extend(childnodes)
    return tree
def unbinarize(tree, _sent=None, expandunary=True,
        childchar='|', parentchar='^', unarychar='+'):
    """Restore a binarized tree to the original n-ary tree.

    Modifies tree in-place and returns it. Any node other than the root whose
    label contains ``childchar + '<'`` is considered artificial: it is spliced
    out of its parent and replaced by its children. For all other nodes, the
    label is cut off at ``parentchar + '<'`` (removing parent annotation) and,
    when ``expandunary`` is true, a label ``A+B`` (with ``unarychar`` '+')
    is expanded into a unary chain ``(A (B ...))``.

    NB: a malformed node such as ``(X|<Y> )`` which is not supposed to be empty
    will be silently discarded.

    :param _sent: unused.
    :returns: ``tree``.
    :raises IndexError: if an artificial node is not found among the children
        of the node it is supposed to hang from."""
    # include the opening bracket to reduce false matches
    artmarker = childchar + '<'
    pamarker = parentchar + '<'
    treeclass = tree.__class__
    # depth-first; each item pairs a node with the real (non-artificial)
    # node whose child list contains it
    stack = [(tree, [])]
    while stack:
        current, container = stack.pop()
        if not isinstance(current, Tree):
            continue
        if current is not tree and artmarker in current.label:
            # locate by identity, not equality
            for pos, sibling in enumerate(container):
                if sibling is current:
                    break
            else:
                raise IndexError
            # detach the children first so that they may get a new parent
            orphans = current[:]
            current[:] = []
            container[pos:pos + 1] = orphans
            pending, owner = orphans, container
        else:
            cut = current.label.find(pamarker)
            if cut != -1:
                current.label = current.label[:cut]
            if expandunary:
                cut = current.label.find(unarychar)
                if cut != -1:
                    lower = treeclass(current.label[cut + 1:], current[:])
                    current.label = current.label[:cut]
                    current[:] = [lower]
            # a real node: it becomes the container for its children
            pending, owner = current, current
        stack.extend((child, owner) for child in pending)
    return tree
def collapseunary(tree, collapsepos=False, collapseroot=False, joinchar='+'):
    """Collapse unary nodes into a new node indicated by 'joinchar'.

    For example ``(NP (NN John))`` becomes ``(NP+NN John)``.
    The tree is modified in-place and returned.

    :param collapsepos: when False (default), do not collapse preterminals
    :param collapseroot: when False (default) do not modify the root production
        if it is unary; e.g., TOP -> productions for the Penn WSJ treebank
    :param joinchar: A string used to connect collapsed node values
    :returns: ``tree``."""
    agenda = [tree]
    if not collapseroot and isinstance(tree, Tree) and len(tree) == 1:
        agenda = [tree[0]]
    # depth-first traversal of tree
    while agenda:
        node = agenda.pop()
        if not isinstance(node, Tree):
            continue
        child = node[0] if len(node) == 1 else None
        # an empty child has no first grandchild to inspect; only collapse
        # it when preterminals are collapsed anyway
        if (isinstance(child, Tree)
                and (collapsepos
                    or (len(child) > 0 and isinstance(child[0], Tree)))):
            node.label += joinchar + child.label
            # detach the grandchildren from the child before re-attaching
            # them to node, so that they may get a new parent (as done in
            # unbinarize / splitdiscnodes)
            grandchildren = child[:]
            child[:] = []
            node[:] = grandchildren
            # since we assigned the child's children to the current node,
            # evaluate the current node again
            agenda.append(node)
        else:
            agenda.extend(node)
    return tree
def introducepreterminals(tree, sent, ids=None):
    """Add preterminals with artificial POS-tags for terminals with siblings.

    The tree is modified in-place and returned. A terminal that is the only
    child of its node is left alone.

    :param ids: by default, artificial labels have the form
        ``parent_label/terminal``. When an iterator is passed, its values are
        used in place of ``terminal``.

    >>> tree = Tree('(S (X 0 1 (CD 2 3) 4))')
    >>> print(introducepreterminals(tree, ['a', 'b', 'c', 'd', 'e']))
    (S (X (X/a 0) (X/b 1) (CD (CD/c 2) (CD/d 3)) (X/e 4)))
    >>> tree = Tree('(S (X 0 1 2))')
    >>> print(introducepreterminals(tree, [None, None, None], ids=iter('abc')))
    (S (X (X/a 0) (X/b 1) (X/c 2)))
    """
    assert isinstance(tree, Tree)
    makenode = tree.__class__
    pending = [tree]
    while pending:
        current = pending.pop()
        wrap = len(current) > 1
        for idx, child in enumerate(current):
            if isinstance(child, Tree):
                pending.append(child)
                continue
            if not wrap:
                continue
            suffix = (sent[child] or '') if ids is None else next(ids)
            current[idx] = makenode('%s/%s' % (current.label, suffix), [child])
    return tree
def handledisc(tree):
    """Binarize discontinuous substitution sites.

    Every node with more than one child whose first child is a terminal is
    right-binarized in-place with ``threshold=1`` and a rightmost unary.
    Returns ``tree``.

    >>> print(handledisc(Tree('(S (X 0 2 4))')))
    (S (X 0 (X|<> 2 (X|<> 4))))
    >>> print(handledisc(Tree('(S (X 0 2))')))
    (S (X 0 (X|<> 2)))
    """
    def issite(node):
        return len(node) > 1 and isinstance(node[0], int)

    for site in tree.postorder(issite):
        binarize(site, rightmostunary=True, threshold=1)
    return tree
def factorconstituent(node, sep='|', h=999, factor='right',
        markfanout=False, markyf=False, ids=None, threshold=2,
        filterlabels=(), labelfun=attrgetter('label')):
    """Binarize one constituent with a left/right factored binarization.

    Children remain unmodified. Nodes must be immutable and contain bitsets;
    use ``addbitsets()``. By default construct artificial labels using labels
    of child nodes. When markyf is True, each artificial label will include the
    yield function; this is necessary for a 'normal form' binarization that is
    equivalent to the original. When ids is given, it is used both as an
    interator (for new unique labels) and as a dictionary (to re-use labels).
    The first ID in a binarization will always be unique, while the others will
    be re-used for the same combination of labels and yield function.

    :param markfanout: append ``_n`` to artificial labels whose fan-out n
        is greater than 1.
    :returns: ``node`` itself if it has at most ``threshold`` children;
        otherwise a new ImmutableTree with bitsets on all new nodes.
    :raises ValueError: if ``factor`` is neither 'left' nor 'right' (only
        checked when more than two children need to be factored)."""
    if len(node) <= threshold:
        return node
    elif 1 <= len(node) <= 2:
        # only reached when threshold < 2: wrap all children in one
        # artificial node
        if ids is None:
            key = '%s%s' % (','.join(labelfun(child) for child in node[:h]
                    if labelfun(child) not in filterlabels),
                    getyf(node[0], node[1] if len(node) > 1 else None)
                    if markyf else '')
        else:
            key = next(ids)
        newlabel = '%s%s<%s>' % (node.label, sep, key)
        if markfanout:
            nodefanout = bitfanout(node.bitset)
            if nodefanout > 1:
                newlabel += '_%d' % nodefanout
        newnode = ImmutableTree(newlabel, node[:])
        # the artificial node covers exactly the same terminals as node
        newnode.bitset = node.bitset
        result = ImmutableTree(node.label, [newnode])
        result.bitset = node.bitset
    else:
        if factor == 'right':
            prev = node[-1]
            rng = range(len(node) - 2, 0, -1)
        elif factor == 'left':
            prev = node[0]
            rng = range(1, len(node) - 1)
        else:
            raise ValueError("factor should be 'left' or 'right'.")
        for i in rng:
            newbitset = node[i].bitset | prev.bitset
            if factor == 'right' and (ids is None or i > 1):
                key = ','.join(labelfun(child) for child in node[i:i + h]
                        if labelfun(child) not in filterlabels)
                if markyf:
                    key += getyf(node[i], prev)
                if ids is not None:
                    key = ids[key]
            elif factor == 'left' and (ids is None or i < len(node) - 2):
                key = ','.join(labelfun(child)
                        for child in node[max(0, i - h + 1):i + 1]
                        if labelfun(child) not in filterlabels)
                if markyf:
                    key += getyf(prev, node[i])
                if ids is not None:
                    key = ids[key]
            else:
                # the node directly below the root gets a unique ID
                key = next(ids)
            newlabel = '%s%s<%s>' % (node.label, sep, key)
            if markfanout:
                nodefanout = bitfanout(newbitset)
                if nodefanout > 1:
                    newlabel += '_%d' % nodefanout
            prev = ImmutableTree(newlabel,
                    [node[i], prev] if factor == 'right' else [prev, node[i]])
            prev.bitset = newbitset
        result = ImmutableTree(node.label,
                [node[0], prev] if factor == 'right' else [prev, node[-1]])
        result.bitset = (node[0].bitset if factor == 'right'
                else node[-1].bitset) | prev.bitset
    return result
def markovthreshold(trees, n, horzmarkov, vertmarkov):
    """Reduce Markov order of binarization labels occurring < n times.

    Labels of artificial nodes (those matching ``MARKOVRE``) that occur fewer
    than ``n`` times in ``trees`` are relabeled in-place, keeping at most
    ``horzmarkov`` siblings of horizontal context and, for the parent
    annotation, the ancestors corresponding to ``vertmarkov`` (where, as in
    ``binarize()``, 1 means no parent annotation).

    :param trees: a sequence of trees; may also be a one-shot iterable,
        since it is materialized before being traversed twice.
    :returns: a human-readable summary of the reduction."""
    # traversed twice below (count, then relabel); a generator would
    # otherwise be exhausted after counting
    trees = list(trees)
    freqs = Counter(node.label for tree in trees
            for node in tree.subtrees()
            if MARKOVRE.match(node.label))
    newlabels = {}
    for label, freq in freqs.items():
        if freq >= n:
            continue
        # always matches: freqs only counts matching labels
        match = MARKOVRE.match(label)
        newlabel = '%s|<%s%s,>' % (
                match.group(1),
                match.group(2) or '',
                ','.join(match.group(3).split(',')[:horzmarkov]))
        if match.group(4) and vertmarkov > 1:
            # group 4 includes its delimiters ('^<A,B>'); strip them before
            # truncating so they are not doubled. As in binarize(), order v
            # keeps v - 1 ancestors.
            ancestors = match.group(4)[2:-1].split(',')
            newlabel += '^<%s>' % ','.join(ancestors[:vertmarkov - 1])
        newlabels[label] = newlabel + (match.group(5) or '')
    for tree in trees:
        for node in tree.subtrees(lambda n: n.label in newlabels):
            node.label = newlabels[node.label]
    return ('markovization for labels with freq < %d reduced to h=%d v=%d.\n'
            '# labels before %d, after %d. %s' % (n, horzmarkov, vertmarkov,
            len(newlabels), len(set(newlabels.values())),
            ', '.join('%s -> %s' % a for a in islice(newlabels.items(), 5))))
def splitdiscnodes(tree, markorigin=False):
    """Return a continuous version of tree by splitting discontinuous nodes.

    Boyd (2007): Discontinuity revisited. http://aclweb.org/anthology/W07-1506
    Each discontinuous node is replaced by one node per continuous block of
    its children. If the node was marked as head, only the component(s)
    containing its head child keep the head marking. The tree is modified
    in-place, canonicalized and returned.

    :param markorigin:
        :False: VP* (bare label)
        :True: VP*1 (add index)

    >>> tree = Tree('(S (VP (VP (PP (APPR 0) (ART 1) (NN 2)) (CARD 4)'
    ...     '(VVPP 5)) (VAINF 6)) (VMFIN 3))')
    >>> print(splitdiscnodes(tree.copy(True)))
    ... # doctest: +NORMALIZE_WHITESPACE
    (S (VP* (VP* (PP (APPR 0) (ART 1) (NN 2)))) (VMFIN 3) (VP* (VP* (CARD 4)
        (VVPP 5)) (VAINF 6)))
    >>> print(splitdiscnodes(tree, markorigin=True))
    ... # doctest: +NORMALIZE_WHITESPACE
    (S (VP*0 (VP*0 (PP (APPR 0) (ART 1) (NN 2)))) (VMFIN 3) (VP*1 (VP*1
        (CARD 4) (VVPP 5)) (VAINF 6)))"""
    treeclass = tree.__class__
    for node in tree.postorder():
        nodes = list(node)
        node[:] = []
        for child in nodes:
            if not isdisc(child):
                node.append(child)
                continue
            childnodes = list(child)
            # detach, so that the children may get new parents
            child[:] = []
            for n, childsubset in enumerate(contsets(childnodes)):
                newlabel = ('%s*%d' % (child.label, n) if markorigin
                        else '%s*' % child.label)
                newchild = treeclass(newlabel, childsubset)
                newchild.source = child.source
                # child has been emptied above, so look for the head child
                # in this component's own children
                if child.type == HEAD and any(
                        isinstance(a, Tree) and a.type == HEAD
                        for a in childsubset):
                    newchild.type = HEAD
                node.append(newchild)
    return canonicalize(tree)
def mergediscnodes(tree):
    """Reverse transformation of ``splitdiscnodes()``.

    Sibling components labeled like ``VP_2*`` (or ``VP_2*0``, ``VP_2*1``, ...)
    are merged into a single node labeled ``VP_2``. The merged node is marked
    as head if any of its components is. The tree is modified in-place and
    returned."""
    treeclass = tree.__class__
    for node in tree.subtrees():
        # component label -> queue of merged nodes waiting for that label,
        # e.g. pending['VP_2*'] = [Tree('VP_2', []), ...]
        # when origin is present (index after *), the node is moved to where
        # the next one is expected, e.g., VP_2*1 after VP_2*0 is added.
        pending = defaultdict(list)
        original = list(node)  # the original, unmerged children
        node[:] = []  # the new, merged children
        for child in original:
            match = (SPLITLABELRE.search(child.label)
                    if isinstance(child, Tree) else None)
            if match is None:
                node.append(child)
                continue
            label, part, _ = match.groups()
            grandchildren = list(child)
            child[:] = []
            queue = pending[child.label]
            if not queue:
                merged = treeclass(label, [])
                merged.source = child.source
                queue.append(merged)
                node.append(merged)
            target = queue[0]
            target.extend(grandchildren)
            if child.type == HEAD:
                target.type = HEAD
            if part:
                successor = '%s*%d' % (label, int(part) + 1)
                pending[successor].append(queue.pop(0))
    return tree
def raisediscnodes(tree):
    """Return a continuous version of tree by raising discontinuous components.

    Requires head information. This transformation cannot be reversed.
    After splitting with ``splitdiscnodes()``, the component containing the
    head child keeps the original label, while the children of the other
    components are attached directly to the parent. The tree is modified
    in-place and returned.

    :raises ValueError: if some node has no child marked as head.

    >>> tree = Tree('(S (VP (VP (PP (APPR 0) (ART 1) (NN 2)) (CARD 4)'
    ...     '(VVPP 5)) (VAINF 6)) (VMFIN 3))')
    >>> tree[1].type = tree[0, 1].type = tree[0, 0, 2].type = HEAD
    >>> tree[0, 0, 0, 0].type = HEAD
    >>> print(raisediscnodes(tree))  # doctest: +NORMALIZE_WHITESPACE
    (S (PP (APPR 0) (ART 1) (NN 2)) (VMFIN 3)
        (VP (VP (CARD 4) (VVPP 5)) (VAINF 6)))"""
    if any(all(isinstance(a, Tree) and a.type != HEAD for a in node)
            for node in tree.subtrees()):
        raise ValueError('No head information. Apply head rules.')
    splitdiscnodes(tree)
    for node in tree.postorder():
        kids = list(node)  # the original children
        node[:] = []  # the new children
        for child in kids:
            match = (SPLITLABELRE.search(child.label)
                    if isinstance(child, Tree) else None)
            if match is None:  # terminal, or not a split component
                node.append(child)
            elif any(a.type == HEAD for a in child):
                child.label = match.group(1)
                node.append(child)
            else:  # non-head disc component: raise its children
                orphans = list(child)
                child[:] = []
                node.extend(orphans)
    return tree
def canonicallyorderedtree(tree, sent):
    """Return a continuous version of tree by re-ordering the sentence.

    This transformation cannot be reversed. The tree is canonicalized with
    ``canonicalizealpino()``, its preterminals are renumbered in the order in
    which ``subtrees()`` visits them, and ``sent`` is permuted in-place to
    match.

    >>> tree = ('(top (smain (noun 0=Het) (verb 1=had) (inf (verb 5=kunnen)'
    ...     ' (inf (np (det 2=een) (adj 3=prachtige) (noun 4=dag))'
    ...     ' (verb 6=zijn) (pp (prep 7=in) (noun 8=Londen))))) (punct 9=.))')
    >>> from .tree import discbrackettree, writebrackettree
    >>> tree, sent = discbrackettree(tree)
    >>> tree = canonicallyorderedtree(tree, sent)
    >>> print(writebrackettree(tree, sent))  # doctest: +NORMALIZE_WHITESPACE
    (top (smain (noun Het) (verb had) (inf (verb kunnen) (inf (np (det een)
        (adj prachtige) (noun dag)) (verb zijn) (pp (prep in) (noun Londen)))))
        (punct .))
    """
    reordered = [None] * len(sent)
    preterminals = canonicalizealpino(tree).subtrees(
            lambda n: n and isinstance(n[0], int))
    for newidx, preterm in enumerate(preterminals):
        reordered[newidx] = sent[preterm[0]]
        preterm[0] = newidx
    sent[:] = reordered
    return tree
def addfanoutmarkers(tree):
    """Mark discontinuous constituents with '_n' where n = # gaps + 1.

    The tree is modified in-place and returned; a label that already ends
    with the right marker is left unchanged."""
    for subtree in tree.subtrees():
        covered = set(subtree.leaves())
        # each maximal run of consecutive indices starts at an index whose
        # predecessor is not covered
        nblocks = sum(1 for idx in covered if idx - 1 not in covered)
        marker = '_%d' % nblocks
        if nblocks > 1 and not subtree.label.endswith(marker):
            subtree.label += marker
    return tree
def removefanoutmarkers(tree):
    """Remove fan-out markers of the form ``_n`` (n a number).

    This is the inverse of ``addfanoutmarkers()``: only a trailing
    underscore followed by digits is removed, so labels that contain an
    underscore for another reason (e.g., ``NP_SBJ``) are left intact.
    The tree is modified in-place and returned."""
    for node in tree.subtrees(lambda x: '_' in x.label):
        base, _, suffix = node.label.rpartition('_')
        if re.fullmatch(r'[0-9]+', suffix):
            node.label = base
    return tree
def removeterminals(tree, sent, func):
    """Remove any terminal for which ``func`` is True, and any empty ancestors.

    The remaining terminals are renumbered consecutively and ``sent`` is
    updated in-place to match. The tree is modified in-place; nothing is
    returned.

    :param tree: a ParentedTree.
    :param sent: the tokens, indexed by terminal.
    :param func: a function with the signature (word, node) -> bool."""
    stack = [tree]
    kept = {}  # terminal index => preterminal node
    while stack:
        node = stack.pop()
        if not node:
            continue
        # iterate backwards so that deletions do not shift pending indices
        for n in reversed(range(len(node))):
            child = node[n]
            if not child:
                continue
            if isinstance(child[0], Tree):
                stack.append(child)
            elif func(sent[child[0]], child):
                del node[n]
                # delete empty ancestors
                while not node and node is not tree:
                    parent = node.parent
                    del parent[node.parent_index]
                    node = parent
            else:
                kept[child[0]] = child
    # renumber
    oldindices = sorted(kept)
    for newidx, oldidx in enumerate(oldindices):
        kept[oldidx][0] = newidx
    sent[:] = [sent[idx] for idx in oldindices]
def removeemptynodes(tree, sent):
    """Remove any empty nodes, and any empty ancestors.

    A terminal is empty when its word is None or '', or when its preterminal
    is labeled ``-NONE-``. Delegates to ``removeterminals()``; ``tree`` and
    ``sent`` are modified in-place."""
    def isempty(word, node):
        return word in (None, '') or node.label == '-NONE-'

    removeterminals(tree, sent, isempty)
def treebankfanout(trees):
    """Get maximal fan-out of a list of trees.

    :returns: a tuple ``(fanout, n)`` with the highest fan-out of any node
        with more than one child, and the index ``n`` of a tree in which it
        occurs (ties go to the highest index). If there is no such node
        (e.g., the 'treebank' only has unary productions), ``(1, 0)``.
    :raises ValueError: propagated from ``addbitsets()`` for invalid trees."""
    # default= avoids max() over an empty sequence without also swallowing
    # a ValueError raised by addbitsets() for invalid input
    return max(((fanout(a), n) for n, tree in enumerate(trees)
            for a in addbitsets(tree).subtrees(lambda x: len(x) > 1)),
            default=(1, 0))
def canonicalize(tree):
    """Restore canonical linear precedence order; tree is modified in-place.

    The children of every node that has more than one child, the first of
    which is a Tree, are sorted by their leaves. Returns ``tree``."""
    def byleaves(node):
        return node.leaves()

    def hasphrasalchildren(node):
        return len(node) > 1 and isinstance(node[0], Tree)

    for node in tree.postorder(hasphrasalchildren):
        node.children.sort(key=byleaves)
    return tree
def canonicalizealpino(tree):
    """Canonicalize tree using Alpino method.

    cf. https://github.com/rug-compling/Alpino/
    TreebankTools/bin/dtcanonicalize.py

    Children are sorted by their last leaf, then by all leaves, then by
    ``function()`` of the node. The tree is modified in-place and returned."""
    from .treebanktransforms import function

    def sortkey(node):
        leaves = node.leaves()
        return leaves[-1], leaves, function(node)

    for node in tree.postorder(
            lambda n: len(n) > 1 and isinstance(n[0], Tree)):
        node.children.sort(key=sortkey)
    return tree
def binarizetree(tree, binarization, relationalrealizational):
    """Binarize a single tree.

    :param binarization: settings object; ``binarization.method`` selects
        'default' (``binarize()``), 'optimal' or 'optimalhead'
        (``optimalbinarize()``); any other method returns tree unchanged.
    :param relationalrealizational: when truthy, a mapping whose
        'ignorefunctions' and 'adjunctionlabel' entries are filtered from
        the markovization context instead of ``binarization.filterlabels``.
    """
    method = binarization.method
    if method == 'optimal':
        return Tree.convert(optimalbinarize(tree))
    if method == 'optimalhead':
        return Tree.convert(optimalbinarize(
                tree, headdriven=True, h=binarization.h, v=binarization.v))
    if method != 'default':
        return tree
    hasheadrules = binarization.headrules is not None
    if relationalrealizational:
        filterlabels = (relationalrealizational['ignorefunctions']
                + (relationalrealizational['adjunctionlabel'], ))
    else:
        filterlabels = binarization.filterlabels
    return binarize(
            tree,
            factor=binarization.factor,
            tailmarker=binarization.tailmarker,
            horzmarkov=binarization.h,
            vertmarkov=binarization.v,
            revhorzmarkov=binarization.revh,
            leftmostunary=binarization.leftmostunary,
            rightmostunary=binarization.rightmostunary,
            markhead=binarization.markhead,
            headoutward=hasheadrules,
            direction=hasheadrules and binarization.direction,
            dot=binarization.dot,
            filterlabels=filterlabels,
            labelfun=binarization.labelfun)
def optimalbinarize(tree, sep='|', headdriven=False,
        h=None, v=1, fun=None):
    """Recursively binarize a tree, optimizing for given function.

    ``v=0`` is not implemented. Setting h to a nonzero integer restricts the
    possible binarizations to head driven binarizations.

    :param fun: scoring function; defaults to ``complexityfanout``.
    :returns: an ImmutableTree whose nodes carry bitsets."""
    if h is None:
        # children are only reordered when h is not given
        tree = canonicalize(Tree.convert(tree))
    scorer = fun if fun else complexityfanout
    limit = h if h else 999
    return _optimalbinarize(addbitsets(tree), scorer, sep, headdriven,
            limit, v, ())
def _optimalbinarize(tree, fun, sep, headdriven, h, v, ancestors):
    """Helper function for postorder / bottom-up binarization.

    Children are binarized first; with ``v > 1`` each label gets a parent
    annotation of the nearest ``v - 1`` ancestors. With ``headdriven``, the
    last child is taken to be the head."""
    if not isinstance(tree, Tree):
        return tree
    if v > 1:
        parentstr = '^<%s>' % ','.join(ancestors[:v - 1])
    else:
        parentstr = ''
    lineage = (tree.label,) + ancestors
    children = [_optimalbinarize(child, fun, sep, headdriven, h, v, lineage)
            for child in tree]
    newtree = ImmutableTree(tree.label + parentstr, children)
    newtree.bitset = tree.bitset
    headidx = (len(tree) - 1) if headdriven else None
    return minimalbinarization(newtree, fun, sep, parentstr=parentstr, h=h,
            head=headidx)
def minimalbinarization(tree, score, sep='|', head=None, parentstr='', h=999):
    """Find optimal binarization according to a scoring function.

    Implementation of Gildea (2010): Optimal parsing strategies for linear
    context-free rewriting systems. http://aclweb.org/anthology/N10-1118
    Partial binarizations are explored from an agenda ordered by score; the
    first complete binarization that is dequeued is returned. Nodes with at
    most two children are returned unchanged, and continuous nodes with only
    continuous children get a default right-factored binarization.

    :param tree: ImmutableTree for which the optimal binarization of its top
        production will be searched. Nodes need to have a .bitset attribute,
        as produced by ``addbitsets()``.
    :param score: a function from binarized trees to scores, where lower is
        better (the scores can be anything else which supports comparisons).
        The scores are concatenated with a tuple below, so in practice they
        must be tuples, e.g., as returned by ``complexityfanout()``.
    :param head: an optional index of the head node, specifying it enables
        head-driven binarization (which constrains the possible binarizations).
    :param sep: separator between the label and the context of artificial
        labels.
    :param parentstr: parent annotation appended to artificial labels.
    :param h: horizontal markovization (only used with ``head``).
    :raises ValueError: if the agenda is exhausted without a binarization.

    >>> tree = '(X (A 0) (B 1) (C 2) (D 3) (E 4))'
    >>> tree2 = binarize(Tree(tree))
    >>> minimalbinarization(addbitsets(tree), complexityfanout, head=2) == tree2
    True
    >>> tree = addbitsets('(A (B1 (t 6) (t 13)) (B2 (t 3) (t 7) (t 10)) '
    ...     '(B3 (t 1) (t 9) (t 11) (t 14) (t 16)) (B4 (t 0) (t 5) (t 8)))')
    >>> a = minimalbinarization(tree, complexityfanout)
    >>> b = minimalbinarization(tree, fanoutcomplexity)
    >>> print(max(map(complexityfanout, a.subtrees())))
    (14, 6)
    >>> print(max(map(complexityfanout, b.subtrees())))
    (15, 5)"""
    def newproduction(a, b):
        """Return a new 'production' (here a tree) combining a and b."""
        # nonterms/labels are defined below in the enclosing scope
        if head is not None:
            siblings = (nonterms[a] | nonterms[b])[:h]
        else:
            siblings = getbits(nonterms[a] | nonterms[b])
        newlabel = '%s%s<%s>%s' % (tree.label, sep,
                ','.join(labels[x] for x in siblings), parentstr)
        new = ImmutableTree(newlabel, [a, b])
        new.bitset = a.bitset | b.bitset
        return new

    if len(tree) <= 2:
        return tree
    # don't bother with optimality if this particular node is not discontinuous
    # do default right factored binarization instead
    elif fanout(tree) == 1 and all(fanout(a) == 1 for a in tree):
        return factorconstituent(tree, sep=sep, h=h)
    labels = [a.label for a in tree]
    # the four main datastructures:
    # the agenda is a priority queue of partial binarizations to explore
    # the first complete binarization that is dequeued is the optimal one
    agenda = PyAgenda()
    # the working set contains all the optimal partial binarizations
    # keys are binarizations, values are their scores
    workingset = {}
    # for each of the optimal partial binarizations, this dictionary has
    # a bitset that describes which non-terminals from the input it covers
    # (an OrderedSet of child indices in the head-driven case)
    nonterms = {}
    # reverse lookup table for nonterms (from bitsets to binarizations)
    revnonterms = {}
    # the goal is a bitset that covers all non-terminals of the input
    goal = (1 << len(tree)) - 1
    if head is None:
        for n, a in enumerate(tree):
            nonterms[a] = 1 << n
            revnonterms[nonterms[a]] = a
            # last score element: sum of parsing complexities so far
            workingset[a] = score(a) + (0,)
            agenda[a] = workingset[a]
    else:
        # head driven binarization:
        # add all non-head nodes to the working set,
        # add all combinations of non-head nodes with head to agenda
        # caveat: Crescenzi et al. (2011) show that this problem is NP hard.
        # http://aclweb.org/anthology/P11-1046
        hd = tree[head]
        goal = OrderedSet(range(len(tree)))
        for n, a in enumerate(tree):
            nonterms[a] = OrderedSet([n])
            revnonterms[nonterms[a]] = a
            if n != head:
                workingset[a] = score(a) + (0,)
        for n, a in enumerate(tree):
            if n == head:
                continue
            # (add initial unary here)
            p = newproduction(a, hd)
            x = score(p)
            agenda[p] = workingset[p] = x + (x[0],)
            nonterms[p] = nonterms[a] | nonterms[hd]
            revnonterms[nonterms[p]] = p
    while agenda:
        p, x = agenda.popitem()
        if nonterms[p] == goal:
            # (add final unary here)
            # relabel the root of the binarization with the original label
            p = ImmutableTree(tree.label, p[:])
            p.bitset = tree.bitset
            return p
        for p1, y in list(workingset.items()):
            # p1 may have been replaced by a better item during this loop
            if p1 not in workingset:
                continue
            # this is inefficient. we should have a single query for all
            # items not overlapping with p
            elif nonterms[p] & nonterms[p1]:
                continue
            # if we do head-driven binarization, add one nonterminal at a time
            if head is None:
                p2 = newproduction(p, p1)
                p2nonterms = nonterms[p] | nonterms[p1]
            elif len(nonterms[p1]) == 1:
                p2 = newproduction(p1, p)
                p2nonterms = nonterms[p1] | nonterms[p]
            elif len(nonterms[p]) == 1:
                p2 = newproduction(p, p1)
                p2nonterms = nonterms[p] | nonterms[p1]
            else:
                continue
            scorep2 = score(p2)
            # important: the score is the maximum score up till now
            x2 = max((scorep2, y[:-1], x[:-1]))
            # add the sum of all previous parsing complexities as last item
            x2 += (scorep2[0] + x[-1] + y[-1],)
            # if new or better:
            # should we allow item when score is equal?
            if (p2nonterms not in revnonterms
                    or workingset[revnonterms[p2nonterms]] > x2):
                if p2nonterms in revnonterms:
                    # discard the worse item covering the same nonterminals
                    a = revnonterms[p2nonterms]
                    del nonterms[a], workingset[a]
                    if a in agenda:
                        del agenda[a]
                nonterms[p2] = p2nonterms
                revnonterms[p2nonterms] = p2
                agenda[p2] = workingset[p2] = x2
    raise ValueError('agenda exhausted without finding binarization.')
def fanout(tree):
    """Return fan-out of constituent. Requires ``bitset`` attribute.

    Anything that is not a Tree (i.e., a terminal) has fan-out 1."""
    if not isinstance(tree, Tree):
        return 1
    return bitfanout(tree.bitset)
def complexity(tree):
    """The degree of the time complexity of parsing with this rule.

    Cf. Gildea (2010). http://aclweb.org/anthology/N10-1118
    Computed as the fan-out of the node plus the fan-outs of its children."""
    return sum(map(fanout, tree), fanout(tree))
def complexityfanout(tree):
    """Return a tuple with the complexity and fan-out of a subtree.

    Default scoring function of ``optimalbinarize()``; as tuples compare
    lexicographically, complexity is minimized first, then fan-out."""
    own = fanout(tree)
    return own + sum(map(fanout, tree)), own
def fanoutcomplexity(tree):
    """Return a tuple with the fan-out and complexity of a subtree.

    As tuples compare lexicographically, fan-out is minimized first, then
    complexity."""
    own = fanout(tree)
    return own, own + sum(map(fanout, tree))
def contsets(nodes):
    """Partition children into continuous subsets.

    Yields lists of nodes, ordered by their first terminal, such that the
    terminals covered by each list form one contiguous range. Nodes may be
    Tree objects or bare terminal indices (ints).

    >>> tree = Tree('(VP (PP (APPR 0) (ART 1) (NN 2)) (CARD 4) (VVPP 5))')
    >>> for a in contsets(tree):
    ...     print(' / '.join('%s' % b for b in a))
    (PP (APPR 0) (ART 1) (NN 2))
    (CARD 4) / (VVPP 5)"""
    def leavesof(node):
        # a bare terminal covers only its own index
        return node.leaves() if isinstance(node, Tree) else [node]

    # map the first terminal of each node to that node
    mins = {min(leavesof(a)): a for a in nodes}
    rng, subset = -1, []
    for a in sorted(leaf for child in nodes for leaf in leavesof(child)):
        # a gap in the terminals closes the current subset
        if rng >= 0 and a != rng + 1:
            yield subset
            subset = []
        if a in mins:
            subset.append(mins[a])
        rng = a
    if subset:
        yield subset
def getbits(bitset):
    """Iterate over the indices of set bits in a bitset, in increasing order.

    >>> list(getbits(0b1011))
    [0, 1, 3]

    :raises ValueError: if bitset is negative; in two's complement it would
        have infinitely many set bits, so iteration would never terminate."""
    if bitset < 0:
        raise ValueError('bitset must be non-negative: %r' % bitset)
    while bitset:
        # isolate the lowest set bit, report its index, then clear it
        lowest = bitset & -bitset
        yield lowest.bit_length() - 1
        bitset ^= lowest
def addbitsets(tree):
    """Turn tree into an ImmutableTree and add bitset attribute.

    The bitset attribute is a Python integer corresponding to the information
    that leaves() would return for that node.

    :param tree: a string in bracket notation, or a Tree object (an
        ImmutableTree is used as is; other trees are frozen).
    :raises ValueError: for any other kind of argument."""
    if isinstance(tree, str):
        frozen = ImmutableTree(tree)
    elif isinstance(tree, ImmutableTree):
        frozen = tree
    elif isinstance(tree, Tree):
        frozen = tree.freeze()
    else:
        raise ValueError('expected string or tree object')
    for subtree in frozen.subtrees():
        subtree.bitset = sum(1 << leaf for leaf in subtree.leaves())
    return frozen
def getyf(left, right):
    """Return the yield function for two subtrees with bitsets.

    Scanning terminals from left to right, each maximal run covered by
    ``left`` is written as '0', each run covered by ``right`` as '1', and each
    gap as ','. ``right`` may be None (a unary production), in which case only
    ``left`` contributes components.

    :returns: string representation of yield function; e.g., ';01,10'."""
    result = [';']
    cur = ','
    # a missing right child covers no terminals
    rightbits = 0 if right is None else right.bitset
    bits = max(left.bitset.bit_length(), rightbits.bit_length())
    for n in range(bits):
        mask = 1 << n
        if left.bitset & mask:
            sym = '0'
        elif rightbits & mask:
            sym = '1'
        else:
            sym = ','
        # only emit a symbol when a new component or gap starts
        if sym != cur:
            cur = sym
            result.append(cur)
    return ''.join(result)
# Public API of this module (``_optimalbinarize`` is a private helper).
__all__ = [
    'binarize',
    'unbinarize',
    'collapseunary',
    'introducepreterminals',
    'handledisc',
    'factorconstituent',
    'markovthreshold',
    'splitdiscnodes',
    'mergediscnodes',
    'raisediscnodes',
    'canonicallyorderedtree',
    'addfanoutmarkers',
    'removefanoutmarkers',
    'removeterminals',
    'removeemptynodes',
    'treebankfanout',
    'canonicalize',
    'canonicalizealpino',
    'binarizetree',
    'optimalbinarize',
    'minimalbinarization',
    'fanout',
    'complexity',
    'complexityfanout',
    'fanoutcomplexity',
    'contsets',
    'getbits',
    'addbitsets',
    'getyf',
]