"""Extract recurring tree fragments from constituency treebanks.
NB: there is a known bug in multiprocessing which makes it impossible to detect
Ctrl-C or fatal errors like segmentation faults in child processes; this causes
the master program to wait forever for output from its children. Therefore, if
you want to abort, kill the program manually (e.g., press Ctrl-Z and issue
'kill %1'). If the program seems stuck, re-run without multiprocessing
(pass --numproc 1) to see if there might be a bug."""
import io
import os
import re
import sys
import codecs
import logging
import tempfile
if sys.version_info[0] == 2:
from itertools import imap as map # pylint: disable=E0611,W0622
import multiprocessing
from collections import defaultdict
from getopt import gnu_getopt, GetoptError
from .tree import brackettree, discbrackettree
from .treebank import writetree
from .treetransforms import unbinarize
from . import _fragments
from .util import workerfunc
from .containers import Vocabulary
SHORTUSAGE = '''Usage: discodop fragments <treebank1> [treebank2] [options]
or: discodop fragments --batch=<dir> <treebank1> <treebank2>... [options]'''
# Boolean command line flags (``--flag``); main() turns each into a key of
# PARAMS whose value says whether the flag was given.
FLAGS = (
    'approx', 'indices', 'nofreq', 'complete', 'alt',
    'relfreq', 'adjacent', 'debin', 'debug', 'quiet', 'help',
)
# Command line options that take a value (``--option=value``).
OPTIONS = (
    'fmt=', 'numproc=', 'numtrees=', 'encoding=', 'batch=', 'cover=',
    'twoterms=',
)
# Module-wide parameter store, filled in by main(), recurringfragments(),
# allfragments() and the worker initializers.
PARAMS = dict()
# Patterns used by altrepr(): a frontier node such as ``(NN )`` ...
FRONTIERRE = re.compile(r'\(([^ ()]+) \)')
# ... and a preterminal with its terminal such as ``(DT a)``.
TERMRE = re.compile(r'\(([^ ()]+) ([^ ()]+)\)')
def main(argv=None):
    """Command line interface to fragment extraction.

    :param argv: command line arguments (options and treebank filenames);
        defaults to ``sys.argv[2:]``, i.e., what follows
        ``discodop fragments``.

    Options are stored in the module-level ``PARAMS`` dictionary. With
    ``--batch``, :func:`batch` does the work; otherwise :func:`regular`
    extracts fragments, which are written to standard output or to the
    file given with ``-o``."""
    if argv is None:
        argv = sys.argv[2:]
    try:
        opts, args = gnu_getopt(argv, 'ho:', FLAGS + OPTIONS)
    except GetoptError as err:
        print('error:', err, file=sys.stderr)
        print(SHORTUSAGE)
        sys.exit(2)
    opts = dict(opts)
    if '-h' in opts or '--help' in opts:
        # -h/--help are accepted by the option parser; honor them.
        print(SHORTUSAGE)
        return
    for flag in FLAGS:
        PARAMS[flag] = '--' + flag in opts
    PARAMS['fmt'] = opts.get('--fmt', 'bracket')
    PARAMS['disc'] = PARAMS['fmt'] != 'bracket'
    numproc = int(opts.get('--numproc', 1))
    if numproc == 0:  # 0 means: use all detected CPUs
        numproc = cpu_count()
    if numproc < 1:  # reject negative values
        raise ValueError('numproc should be an integer > 0. got: %r' % numproc)
    limit = int(opts.get('--numtrees', 0)) or None
    PARAMS['cover'] = None
    if '--cover' in opts:
        # --cover=depth[,frontier]; without a frontier limit, use 999
        if ',' in opts['--cover']:
            a, b = opts['--cover'].split(',')
            PARAMS['cover'] = int(a), int(b)
        else:
            PARAMS['cover'] = int(opts['--cover']), 999
    PARAMS['twoterms'] = opts.get('--twoterms')
    encoding = opts.get('--encoding', 'utf8')
    batchdir = opts.get('--batch')
    if not args:
        print('missing treebank argument', file=sys.stderr)
        print(SHORTUSAGE)
        sys.exit(2)
    if batchdir is None and len(args) not in (1, 2):
        print('incorrect number of arguments:', args, file=sys.stderr)
        print(SHORTUSAGE)
        sys.exit(2)
    if batchdir and numproc != 1:
        raise ValueError('Batch mode only supported in single-process '
                'mode. Use the xargs command for multi-processing.')
    # validate everything before touching stdin, so that no temporary file
    # is left behind when a check fails
    for fname in args:
        if fname != '-' and not os.path.exists(fname):
            raise ValueError('not found: %r' % fname)
    if numproc != 1 and args.count('-') > 1:
        raise ValueError('can only read from stdin once.')
    if PARAMS['complete']:
        if len(args) < 2:
            raise ValueError('need at least two treebanks with --complete.')
        if PARAMS['twoterms'] or PARAMS['adjacent']:
            raise ValueError('--twoterms and --adjacent are incompatible '
                    'with --complete.')
        if PARAMS['approx'] or PARAMS['nofreq']:
            raise ValueError('--complete is incompatible with --nofreq '
                    'and --approx')
    readstdin = None
    if numproc != 1 and '-' in args:
        # Copy stdin to a temporary file so that the contents can be read
        # in multiple processes; closefd=False leaves the real stdin open.
        readstdin = args.index('-')
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            with io.open(sys.stdin.fileno(), 'rb', closefd=False) as stdin:
                tmp.write(stdin.read())
        args[readstdin] = tmp.name
    try:
        level = logging.WARNING if PARAMS['quiet'] else logging.DEBUG
        logging.basicConfig(level=level, format='%(message)s')
        if PARAMS['debug'] and numproc > 1:
            logger = multiprocessing.log_to_stderr()
            logger.setLevel(multiprocessing.SUBDEBUG)
        logging.info('Disco-DOP Fragment Extractor')
        logging.info('parameters:\n%s', '\n'.join(' %s:\t%r' % kv
                for kv in sorted(PARAMS.items())))
        logging.info('\n'.join('treebank%d: %s' % (n + 1, a)
                for n, a in enumerate(args)))
        if batchdir:  # numproc == 1 is guaranteed by the check above
            batch(batchdir, args, limit, encoding, PARAMS['debin'])
        else:
            fragmentkeys, counts = regular(args, numproc, limit, encoding)
            if PARAMS['debin']:
                fragmentkeys = debinarize(fragmentkeys)
            if '-o' in opts:
                # close (and flush) the output file when done
                with io.open(opts['-o'], 'w', encoding=encoding) as out:
                    printfragments(fragmentkeys, counts, out=out)
            else:
                printfragments(fragmentkeys, counts)
    finally:
        if readstdin is not None:
            os.unlink(args[readstdin])
def regular(filenames, numproc, limit, encoding):
    """Non-batch processing; multiprocessing optional.

    :param filenames: a list of one or two treebank filenames.
    :param numproc: number of processes; 1 means no multiprocessing.
    :param limit: maximum number of trees to read, or None for no limit.
    :param encoding: character encoding of the treebank files.
    :returns: a tuple ``(fragmentkeys, counts)`` of aligned lists;
        ``counts`` is None with ``--nofreq``."""
    mult = 1
    filename2 = filenames[1] if len(filenames) == 2 else None
    if PARAMS['approx']:
        fragments = defaultdict(int)
    else:
        fragments = {}
    # detect corpus reading errors in this process (e.g., wrong encoding)
    initworker(filenames[0], filename2, limit, encoding)
    pool = None
    if numproc == 1:
        mymap, myworker = map, worker
        countworker = exactcountworker
    else:  # multiprocessing, start worker processes
        pool = multiprocessing.Pool(
                processes=numproc, initializer=initworker,
                initargs=(filenames[0], filename2, limit, encoding))
        mymap, myworker = pool.imap, mpworker
        countworker = mpexactcountworker
    try:
        if PARAMS['complete']:
            trees1, trees2 = PARAMS['trees1'], PARAMS['trees2']
            fragmentkeys, bitsets = _fragments.completebitsets(
                    trees1, PARAMS['vocab'],
                    max(trees1.maxnodes, trees2.maxnodes), PARAMS['disc'])
        else:
            numtrees = PARAMS['trees1'].len
            if limit is not None:
                numtrees = min(numtrees, limit)
            if len(filenames) == 1:
                # comparing a treebank to itself: balance tree pairs
                work = workload(numtrees, mult, numproc)
            else:
                chunk = numtrees // (mult * numproc) + 1
                work = [(a, a + chunk) for a in range(0, numtrees, chunk)]
            if numproc != 1:
                logging.info('work division:\n%s', '\n'.join(' %s:\t%r' % kv
                        for kv in sorted(dict(numchunks=len(work),
                            mult=mult).items())))
            for results in mymap(myworker, work):
                if PARAMS['approx']:
                    for frag, x in results.items():
                        fragments[frag] += x
                else:
                    fragments.update(results)
            fragmentkeys = list(fragments)
            bitsets = [fragments[a] for a in fragmentkeys]
        if PARAMS['nofreq']:
            counts = None
        elif PARAMS['approx']:
            counts = [fragments[a] for a in fragmentkeys]
        else:
            task = 'indices' if PARAMS['indices'] else 'counts'
            logging.info('dividing work for exact %s', task)
            countchunk = len(bitsets) // numproc + 1
            starts = list(range(0, len(bitsets), countchunk))
            work = [(n, len(starts), bitsets[a:a + countchunk])
                    for n, a in enumerate(starts)]
            counts = []
            logging.info('getting exact %s', task)
            for a in mymap(countworker, work):
                counts.extend(a)
        if PARAMS['cover']:
            maxdepth, maxfrontier = PARAMS['cover']
            before = len(fragmentkeys)
            cover = _fragments.allfragments(PARAMS['trees1'],
                    PARAMS['vocab'], maxdepth, maxfrontier, PARAMS['disc'],
                    PARAMS['indices'])
            # Compare against fragmentkeys: with --complete, ``fragments``
            # is never filled in.
            seen = set(fragmentkeys)
            for a in cover:
                if a not in seen:
                    fragmentkeys.append(a)
                    if counts is not None:  # counts is None with --nofreq
                        counts.append(cover[a])
            logging.info('merged %d cover fragments '
                    'up to depth %d with max %d frontier non-terminals.',
                    len(fragmentkeys) - before, maxdepth, maxfrontier)
        if pool is not None:
            pool.close()
            pool.join()
    finally:
        # make sure worker processes go away, also when an error occurred
        if pool is not None:
            pool.terminate()
    return fragmentkeys, counts
def batch(outputdir, filenames, limit, encoding, debin):
    """Batch processing: compare the first treebank to one or more others.

    Compares the first treebank to all others, and writes the results
    to ``outputdir/A_B`` where ``A`` and ``B`` are the respective filenames.
    Counts/indices are from the other (B) treebanks.

    There are at least 2 use cases for this:

    1. Comparing one treebank to a series of others. The first treebank will
       only be loaded once.
    2. In combination with ``--complete``, the first treebank is a set of
       fragments used as queries on the other treebanks specified.

    :param outputdir: directory in which the output files are written.
    :param filenames: the first treebank followed by the treebanks to
        compare it with.
    :param limit: maximum number of trees to read, or None for no limit.
    :param encoding: encoding of the input treebanks and output files.
    :param debin: whether to debinarize fragments before writing them."""
    initworker(filenames[0], None, limit, encoding)
    trees1 = PARAMS['trees1']
    maxnodes = trees1.maxnodes
    if PARAMS['complete']:
        # the queries are fixed, so their bitsets are computed only once
        fragmentkeys, bitsets = _fragments.completebitsets(
                trees1, PARAMS['vocab'], maxnodes, PARAMS['disc'])
        fragments = True
    elif PARAMS['approx']:
        fragments = defaultdict(int)
    else:
        fragments = {}
    for filename in filenames[1:]:
        PARAMS.update(read2ndtreebank(filename, PARAMS['vocab'],
                PARAMS['fmt'], limit, encoding))
        trees2 = PARAMS['trees2']
        if not PARAMS['complete']:
            fragments = _fragments.extractfragments(trees1, 0, 0,
                    PARAMS['vocab'], trees2, disc=PARAMS['disc'],
                    debug=PARAMS['debug'], approx=PARAMS['approx'],
                    twoterms=PARAMS['twoterms'], adjacent=PARAMS['adjacent'])
            fragmentkeys = list(fragments)
            bitsets = [fragments[a] for a in fragmentkeys]
        maxnodes = max(trees1.maxnodes, trees2.maxnodes)
        counts = None
        if PARAMS['approx'] or not fragments:
            counts = fragments.values()
        elif not PARAMS['nofreq']:
            logging.info('getting %s for %d fragments',
                    'indices of occurrence' if PARAMS['indices']
                    else 'exact counts', len(bitsets))
            counts = _fragments.exactcounts(bitsets, trees1, trees2,
                    indices=PARAMS['indices'],
                    maxnodes=maxnodes)
        outputfilename = os.path.join(outputdir, '%s_%s' % (
                os.path.basename(filenames[0]), os.path.basename(filename)))
        # Write a fresh list so that output-time rewriting (debinarization,
        # --alt) never touches fragmentkeys, which --complete reuses for
        # every treebank.
        outkeys = debinarize(fragmentkeys) if debin else list(fragmentkeys)
        with io.open(outputfilename, 'w', encoding=encoding) as out:
            printfragments(outkeys, counts, out=out)
        logging.info('wrote to %s', outputfilename)
def readtreebanks(filename1, filename2=None, fmt='bracket',
        limit=None, encoding='utf8'):
    """Read one or two treebanks that share a single vocabulary.

    :returns: a dict with keys ``trees1``, ``trees2`` and ``vocab``. The
        ``trees2`` value is whatever ``_fragments.readtreebank`` returns for
        *filename2* (presumably None when it is None -- TODO confirm); it is
        only indexed when truthy."""
    vocab = Vocabulary()
    first, second = [
            _fragments.readtreebank(name, vocab, fmt, limit, encoding)
            for name in (filename1, filename2)]
    first.indextrees(vocab)
    if second:
        second.indextrees(vocab)
    return {'trees1': first, 'trees2': second, 'vocab': vocab}
def read2ndtreebank(filename2, vocab, fmt='bracket',
        limit=None, encoding='utf8'):
    """Read a second treebank, using an existing vocabulary.

    :param filename2: the treebank file to read.
    :param vocab: the Vocabulary shared with the first treebank; it is
        passed to both the reader and the indexer.
    :returns: a dict with keys ``trees2`` and ``vocab``, suitable for
        ``PARAMS.update()``."""
    trees2 = _fragments.readtreebank(filename2, vocab,
            fmt, limit, encoding)
    trees2.indextrees(vocab)
    # report the vocabulary that was actually used (not PARAMS['vocab'],
    # which may be absent or differ from the argument)
    logging.info('%r: %d trees; %d nodes (max %d); '
            'word tokens: %d\n%r',
            filename2, len(trees2), trees2.numnodes, trees2.maxnodes,
            trees2.numwords, vocab)
    return dict(trees2=trees2, vocab=vocab)
def initworker(filename1, filename2, limit, encoding):
    """Load the treebank(s) for the current process into PARAMS.

    Each worker process loads its own copy, under the assumption that this
    is advantageous with a NUMA architecture; the master process also calls
    this to detect reading errors early.

    :raises ValueError: if a treebank that was read turns out to be empty."""
    PARAMS.update(readtreebanks(filename1, filename2,
            limit=limit, fmt=PARAMS['fmt'], encoding=encoding))
    vocab = PARAMS['vocab']
    debug = PARAMS['debug']
    trees1 = PARAMS['trees1']
    if debug:
        print('\nproductions:')
        # listed in order of their numeric production id
        for prodno in range(len(vocab.prods)):
            print('%d. %s' % (prodno, vocab.prodrepr(prodno)))
        print('treebank 1:')
        for treeno in range(trees1.len):
            trees1.printrepr(treeno, vocab)
    if not trees1:
        raise ValueError('treebank1 empty.')
    summary = 'treebank1: %d trees; %d nodes (max: %d); %d word tokens.\n' % (
            trees1.len, trees1.numnodes, trees1.maxnodes, trees1.numwords)
    if filename2:
        trees2 = PARAMS['trees2']
        if debug:
            print('treebank 2:')
            for treeno in range(trees2.len):
                trees2.printrepr(treeno, vocab)
        if not trees2:
            raise ValueError('treebank2 empty.')
        summary += (
                'treebank2: %d trees; %d nodes (max %d); %d word tokens.\n'
                % (trees2.len, trees2.numnodes, trees2.maxnodes,
                    trees2.numwords))
    logging.info('%s%r', summary, vocab)
def initworkersimple(trees, sents, trees2=None, sents2=None):
    """Initialization for a worker in which a treebank was already loaded.

    Pairs each tree with its sentence, converts them with
    ``_fragments.getctrees`` and stores the result in PARAMS.

    :raises ValueError: if the first treebank is empty."""
    PARAMS.update(_fragments.getctrees(
            zip(trees, sents),
            None if trees2 is None else zip(trees2, sents2)))
    # an explicit check rather than ``assert``, which is skipped under -O
    if not PARAMS['trees1']:
        raise ValueError('treebank1 empty.')
@workerfunc
def mpworker(interval):
    """Run :func:`worker` on *interval* inside a worker process.

    The ``workerfunc`` decorator (from ``.util``) wraps this entry point;
    presumably it adds error reporting for child processes -- TODO confirm."""
    extracted = worker(interval)
    return extracted
def worker(interval):
    """Extract fragments for one chunk of the first treebank.

    :param interval: ``(offset, end)``, the range of trees of the first
        treebank to process; they are compared against the second treebank
        in PARAMS, or against the first one when there is no second.
    :returns: the result of ``_fragments.extractfragments``: a dict mapping
        fragments to bitsets, or to approximate counts with ``--approx``."""
    offset, end = interval
    trees1 = PARAMS['trees1']
    trees2 = PARAMS['trees2']
    assert offset < trees1.len
    result = _fragments.extractfragments(trees1, offset, end,
            PARAMS['vocab'], trees2, approx=PARAMS['approx'],
            disc=PARAMS['disc'],
            debug=PARAMS['debug'], twoterms=PARAMS['twoterms'],
            adjacent=PARAMS['adjacent'])
    logging.debug('finished %d--%d', offset, end)
    return result
@workerfunc
def mpexactcountworker(args):
    """Run :func:`exactcountworker` on *args* inside a worker process.

    The ``workerfunc`` decorator (from ``.util``) wraps this entry point;
    presumably it adds error reporting for child processes -- TODO confirm."""
    chunkcounts = exactcountworker(args)
    return chunkcounts
def exactcountworker(args):
    """Compute exact counts (or indices) for one chunk of fragment bitsets.

    :param args: a tuple ``(n, m, bitsets)``: the 0-based chunk number, the
        total number of chunks, and the bitsets of this chunk.
    :returns: the result of ``_fragments.exactcounts`` for this chunk;
        with ``--complete`` the occurrences are looked up in the second
        treebank, otherwise in the first."""
    chunkno, numchunks, bitsets = args
    trees1 = PARAMS['trees1']
    if not PARAMS['complete']:
        results = _fragments.exactcounts(
                bitsets, trees1, trees1, indices=PARAMS['indices'])
        template = ('exact indices chunk %d of %d' if PARAMS['indices']
                else 'exact counts chunk %d of %d')
        logging.debug(template, chunkno + 1, numchunks)
        return results
    results = _fragments.exactcounts(bitsets, trees1, PARAMS['trees2'],
            indices=PARAMS['indices'])
    logging.debug('complete matches chunk %d of %d', chunkno + 1, numchunks)
    return results
def workload(numtrees, mult, numproc):
    """Calculate an even workload.

    When *n* trees are compared against themselves, ``n * (n - 1) / 2``
    comparisons are made: each tree ``m`` (0-based) is compared to all trees
    ``x`` such that ``m < x < n``, so lower values of *m* involve more
    comparisons.

    :param numtrees: the number of trees *n*.
    :param mult: multiplier for the number of intervals per process.
    :param numproc: number of processes; with 1, a single interval covering
        all trees is returned.
    :returns: a list of consecutive ``(start, end)`` intervals covering
        ``0..numtrees`` such that the number of comparisons per interval is
        approximately balanced (empty when *numtrees* is 0 and
        *numproc* > 1)."""
    # could base on number of nodes as well.
    if numproc == 1:
        return [(0, numtrees)]
    # total number of tree pairs; integer arithmetic keeps this exact
    total = numtrees * (numtrees - 1) // 2
    # chunk is the number of tree pairs each interval should cover
    chunk = total // (mult * numproc) + 1
    togo = total  # comparisons left for trees n, n + 1, ...
    goal = total - chunk
    result = []
    last = 0
    for n in range(1, numtrees):
        togo -= numtrees - n  # the comparisons of tree n - 1
        if togo <= goal:
            goal -= chunk
            result.append((last, n))
            last = n
    if last < numtrees:
        result.append((last, numtrees))
    return result
def recurringfragments(trees, sents, numproc=1, disc=True,
        indices=True, maxdepth=1, maxfrontier=999):
    """Get recurring fragments with exact counts in a single treebank.

    :returns: a dictionary whose keys are fragments as strings, and
        indices (or integer counts when ``indices`` is False) as values.
        When ``disc`` is ``True``, keys are of the form
        ``(frag, sent)`` where ``frag`` is a unicode string, and ``sent``
        is a list of words as unicode strings; when ``disc`` is ``False``, keys
        are of the form ``frag`` where ``frag`` is a unicode string.
    :param trees: a sequence of binarized Tree objects, with indices as leaves.
    :param sents: the corresponding sentences (lists of strings); any
        iterable is accepted.
    :param numproc: number of processes to use; pass 0 to use detected # CPUs.
    :param disc: when disc=True, assume trees with discontinuous constituents;
        resulting fragments will be of the form (frag, sent);
        otherwise fragments will be strings with words as leaves.
    :param indices: when False, return integer counts instead of indices.
    :param maxdepth: when > 0, add 'cover' fragments to result, corresponding
        to all fragments up to given depth; pass 0 to disable.
    :param maxfrontier: maximum number of frontier non-terminals (substitution
        sites) in cover fragments; a limit of 0 only gives fragments that
        bottom out in terminals; the default 999 is unlimited for practical
        purposes.
    :raises ValueError: if *trees* is empty."""
    if numproc == 0:
        numproc = cpu_count()
    numtrees = len(trees)
    if not numtrees:
        raise ValueError('no trees.')
    mult = 1  # 3 if numproc > 1 else 1
    fragments = {}
    trees = trees[:]
    # Materialize once: sents may be an iterator, and it is needed both in
    # this process and by every worker process.
    sents = list(sents)
    work = workload(numtrees, mult, numproc)
    PARAMS.update(disc=disc, indices=indices, approx=False, complete=False,
            debug=False, adjacent=False, twoterms=None)
    initworkersimple(trees, sents)
    pool = None
    if numproc == 1:
        mymap, myworker = map, worker
        countworker = exactcountworker
    else:
        logging.info('work division:\n%s', '\n'.join(' %s: %r' % kv
                for kv in sorted(dict(numchunks=len(work),
                    numproc=numproc).items())))
        # start worker processes
        pool = multiprocessing.Pool(
                processes=numproc, initializer=initworkersimple,
                initargs=(trees, sents))
        mymap, myworker = pool.map, mpworker
        countworker = mpexactcountworker
    try:
        # collect recurring fragments
        logging.info('extracting recurring fragments')
        for a in mymap(myworker, work):
            fragments.update(a)
        fragmentkeys = list(fragments)
        bitsets = [fragments[a] for a in fragmentkeys]
        countchunk = len(bitsets) // numproc + 1
        starts = list(range(0, len(bitsets), countchunk))
        work = [(n, len(starts), bitsets[a:a + countchunk])
                for n, a in enumerate(starts)]
        logging.info('getting exact counts for %d fragments', len(bitsets))
        counts = []
        for a in mymap(countworker, work):
            counts.extend(a)
        # add all fragments up to a given depth
        if maxdepth:
            cover = _fragments.allfragments(PARAMS['trees1'],
                    PARAMS['vocab'], maxdepth, maxfrontier, disc, indices)
            before = len(fragmentkeys)
            for a in cover:
                if a not in fragments:
                    fragmentkeys.append(a)
                    counts.append(cover[a])
            logging.info('merged %d cover fragments '
                    'up to depth %d with max %d frontier non-terminals.',
                    len(fragmentkeys) - before, maxdepth, maxfrontier)
        if pool is not None:
            pool.close()
            pool.join()
    finally:
        # make sure worker processes go away, also when an error occurred
        if pool is not None:
            pool.terminate()
    logging.info('found %d fragments', len(fragmentkeys))
    return dict(zip(fragmentkeys, counts))
def allfragments(trees, sents, maxdepth, maxfrontier=999):
    """Return all fragments up to a given depth and number of frontiers.

    Configures PARAMS for discontinuous trees with indices, loads the
    treebank via :func:`initworkersimple`, and delegates to
    ``_fragments.allfragments``."""
    settings = dict(disc=True, indices=True, approx=False, complete=False,
            debug=False, adjacent=False, twoterms=None)
    PARAMS.update(settings)
    initworkersimple(trees, list(sents))
    return _fragments.allfragments(
            PARAMS['trees1'], PARAMS['vocab'], maxdepth, maxfrontier,
            disc=PARAMS['disc'], indices=PARAMS['indices'])
def altrepr(a):
    """Rewrite a bracketed tree in an alternative format.

    - double quotes become two single quotes: " -> ''
    - terminals are wrapped in double quotes: terminal -> "terminal"
    - frontier nodes lose their parentheses: (NN ) -> NN

    >>> print(altrepr('(NP (DT a) (NN ))'))
    (NP (DT "a") NN)
    """
    escaped = a.replace('"', "''")
    quoted = TERMRE.sub(r'(\1 "\2")', escaped)
    return FRONTIERRE.sub(r'\1', quoted)
def debinarize(fragments):
    """Undo the binarization of fragments.

    Each fragment is parsed (as discbracket or bracket format, depending on
    PARAMS['disc']); if unbinarizing or writing it fails, the original
    fragment is kept as-is.

    :returns: a new list with one entry per fragment, in the same order."""
    result = []
    for origfrag in fragments:
        disc = PARAMS['disc']
        if disc:
            tree, sent = discbrackettree(origfrag)
        else:
            tree, sent = brackettree(origfrag, detectdisc=False)
        try:
            newfrag = writetree(unbinarize(tree), sent, 0,
                    'discbracket' if disc else 'bracket').strip()
        except Exception:  # pylint: disable=broad-except
            newfrag = origfrag
        result.append(newfrag)
    return result
def printfragments(fragments, counts, out=None):
    """Dump fragments to standard output or some other file object.

    :param fragments: a sequence of fragments as strings; it is not modified.
    :param counts: a sequence aligned with *fragments*: frequencies, or
        arrays of tree indices with ``--indices``; unused with ``--nofreq``.
    :param out: a writable text file object; defaults to standard output.
    :raises ValueError: when a fragment's frequency does not exceed the
        expected minimum (e.g., a fragment that occurs only once when
        extracting recurring fragments from a single treebank)."""
    if out is None:
        out = sys.stdout
        # Only stdout is wrapped; callers in this module pass files opened
        # with io.open and an explicit encoding.
        if sys.stdout.encoding is None:
            out = codecs.getwriter('utf8')(out)
    if PARAMS['alt']:
        # build a new list instead of rewriting the caller's list in place
        fragments = [altrepr(a) for a in fragments]
    if PARAMS['complete']:
        # with --indices, each element of counts holds one tree index per
        # occurrence, so its length is the number of matches
        logging.info('total number of matches: %d',
                sum(len(a) for a in counts)
                if PARAMS['indices'] else sum(counts))
    else:
        logging.info('number of fragments: %d', len(fragments))
    if PARAMS['nofreq']:
        for a in fragments:
            out.write(a + '\n')
        return
    # a frequency of 0 is normal when counting occurrences of given fragments
    # in a second treebank
    if PARAMS['complete']:
        threshold = 0
        zeroinvalid = False
    # a frequency of 1 is normal when comparing two treebanks
    # or when non-recurring fragments are added
    elif PARAMS.get('trees2') or PARAMS['cover'] or PARAMS['approx']:
        threshold = 0
        zeroinvalid = True
    else:  # otherwise, raise alarm.
        threshold = 1
        zeroinvalid = True
    if PARAMS['indices']:
        for a, theindices in zip(fragments, counts):
            if len(theindices) > threshold:
                if PARAMS['adjacent']:
                    # keep only occurrences in neighboring trees
                    shown = [n for n in theindices
                            if n - 1 in theindices or n + 1 in theindices]
                else:
                    shown = str(theindices)[len("array('I', "):-len(')')]
                out.write('%s\t%s\n' % (a, shown))
            elif zeroinvalid:
                raise ValueError('invalid fragment--frequency=%d: %r' % (
                        len(theindices), a))
    elif PARAMS['relfreq']:
        def label(frag):
            """Return the root label of a bracketed fragment."""
            return frag[1:frag.index(' ')]
        sums = defaultdict(int)
        for a, freq in zip(fragments, counts):
            if freq > threshold:
                sums[label(a)] += freq
            elif zeroinvalid:
                raise ValueError('invalid fragment--frequency=%d: %r' % (
                        freq, a))
        for a, freq in zip(fragments, counts):
            out.write('%s\t%d/%d\n' % (a, freq, sums[label(a)]))
    else:
        for a, freq in zip(fragments, counts):
            if freq > threshold:
                out.write('%s\t%d\n' % (a, freq))
            elif zeroinvalid:
                raise ValueError('invalid fragment--frequency=%d: %r' % (
                        freq, a))
def cpu_count():
    """Return the number of CPUs, or 1 when it cannot be determined."""
    ncpus = 1
    try:
        ncpus = multiprocessing.cpu_count()
    except NotImplementedError:
        pass
    return ncpus
def test():
    """Demonstrate the fragment extractor on ``alpinosample.export``."""
    argv = ['--fmt=export', 'alpinosample.export']
    main(argv)
# Public API of this module.
__all__ = [
    'main', 'regular', 'batch', 'readtreebanks', 'read2ndtreebank',
    'initworker', 'initworkersimple', 'worker', 'exactcountworker',
    'workload', 'recurringfragments', 'allfragments', 'debinarize',
    'printfragments', 'altrepr', 'cpu_count',
]