Source code for discodop.fragments

"""Extract recurring tree fragments from constituency treebanks.

NB: there is a known bug in multiprocessing which makes it impossible to detect
Ctrl-C or fatal errors like segmentation faults in child processes which causes
the master program to wait forever for output from its children. Therefore if
you want to abort, kill the program manually (e.g., press Ctrl-Z and issue
'kill %1'). If the program seems stuck, re-run without multiprocessing
(pass --numproc 1) to see if there might be a bug."""

import io
import os
import re
import sys
import codecs
import logging
import tempfile
if sys.version_info[0] == 2:
	from itertools import imap as map  # pylint: disable=E0611,W0622
import multiprocessing
from collections import defaultdict
from getopt import gnu_getopt, GetoptError
from .tree import brackettree, discbrackettree
from .treebank import writetree
from .treetransforms import unbinarize
from . import _fragments
from .util import workerfunc
from .containers import Vocabulary

SHORTUSAGE = '''Usage: discodop fragments <treebank1> [treebank2] [options]
  or: discodop fragments --batch=<dir> <treebank1> <treebank2>... [options]'''
FLAGS = ('approx', 'indices', 'nofreq', 'complete', 'alt',
		'relfreq', 'adjacent', 'debin', 'debug', 'quiet', 'help')
OPTIONS = ('fmt=', 'numproc=', 'numtrees=', 'encoding=', 'batch=', 'cover=',
		'twoterms=')
PARAMS = {}
FRONTIERRE = re.compile(r'\(([^ ()]+) \)')  # for altrepr()
TERMRE = re.compile(r'\(([^ ()]+) ([^ ()]+)\)')  # for altrepr()


[docs]def main(argv=None): """Command line interface to fragment extraction.""" if argv is None: argv = sys.argv[2:] try: opts, args = gnu_getopt(argv, 'ho:', FLAGS + OPTIONS) except GetoptError as err: print('error:', err, file=sys.stderr) print(SHORTUSAGE) sys.exit(2) opts = dict(opts) for flag in FLAGS: PARAMS[flag] = '--' + flag in opts PARAMS['disc'] = opts.get('--fmt', 'bracket') != 'bracket' PARAMS['fmt'] = opts.get('--fmt', 'bracket') numproc = int(opts.get('--numproc', 1)) if numproc == 0: numproc = cpu_count() if not numproc: raise ValueError('numproc should be an integer > 0. got: %r' % numproc) limit = int(opts.get('--numtrees', 0)) or None PARAMS['cover'] = None if '--cover' in opts and ',' in opts['--cover']: a, b = opts['--cover'].split(',') PARAMS['cover'] = int(a), int(b) elif '--cover' in opts: PARAMS['cover'] = int(opts.get('--cover', 0)), 999 PARAMS['twoterms'] = opts.get('--twoterms') encoding = opts.get('--encoding', 'utf8') batchdir = opts.get('--batch') if len(args) < 1: print('missing treebank argument') if batchdir is None and len(args) not in (1, 2): print('incorrect number of arguments:', args, file=sys.stderr) print(SHORTUSAGE) sys.exit(2) if batchdir: if numproc != 1: raise ValueError('Batch mode only supported in single-process ' 'mode. Use the xargs command for multi-processing.') readstdin = None for n, fname in enumerate(args): if fname == '-': if numproc != 1: # write to temp file so that contents can be read # in multiple processes if readstdin is not None: raise ValueError('can only read from stdin once.') with tempfile.NamedTemporaryFile(delete=False) as tmp: tmp.write(open(sys.stdin.fileno(), 'rb').read()) args[n] = tmp.name readstdin = n elif not os.path.exists(fname): raise ValueError('not found: %r' % fname) if PARAMS['complete']: if len(args) < 2: raise ValueError('need at least two treebanks with --complete.') if PARAMS['twoterms'] or PARAMS['adjacent']: raise ValueError('--twoterms and --adjacent are incompatible ' 'with --complete.') if PARAMS['approx'] or PARAMS['nofreq']: raise ValueError('--complete is incompatible with --nofreq ' 'and --approx') level = logging.WARNING if PARAMS['quiet'] else logging.DEBUG logging.basicConfig(level=level, format='%(message)s') if PARAMS['debug'] and numproc > 1: logger = multiprocessing.log_to_stderr() logger.setLevel(multiprocessing.SUBDEBUG) logging.info('Disco-DOP Fragment Extractor') logging.info('parameters:\n%s', '\n'.join(' %s:\t%r' % kv for kv in sorted(PARAMS.items()))) logging.info('\n'.join('treebank%d: %s' % (n + 1, a) for n, a in enumerate(args))) if numproc == 1 and batchdir: batch(batchdir, args, limit, encoding, '--debin' in opts) else: fragmentkeys, counts = regular(args, numproc, limit, encoding) out = (io.open(opts['-o'], 'w', encoding=encoding) if '-o' in opts else None) if '--debin' in opts: fragmentkeys = debinarize(fragmentkeys) printfragments(fragmentkeys, counts, out=out) if readstdin is not None: os.unlink(args[readstdin])
[docs]def regular(filenames, numproc, limit, encoding): """non-batch processing. multiprocessing optional.""" mult = 1 if PARAMS['approx']: fragments = defaultdict(int) else: fragments = {} # detect corpus reading errors in this process (e.g., wrong encoding) initworker( filenames[0], filenames[1] if len(filenames) == 2 else None, limit, encoding) if numproc == 1: mymap, myworker = map, worker else: # multiprocessing, start worker processes pool = multiprocessing.Pool( processes=numproc, initializer=initworker, initargs=(filenames[0], filenames[1] if len(filenames) == 2 else None, limit, encoding)) mymap, myworker = pool.imap, mpworker numtrees = (PARAMS['trees1'].len if limit is None else min(PARAMS['trees1'].len, limit)) if PARAMS['complete']: trees1, trees2 = PARAMS['trees1'], PARAMS['trees2'] fragmentkeys, bitsets = _fragments.completebitsets( trees1, PARAMS['vocab'], max(trees1.maxnodes, trees2.maxnodes), PARAMS['disc']) else: if len(filenames) == 1: work = workload(numtrees, mult, numproc) else: chunk = numtrees // (mult * numproc) + 1 work = [(a, a + chunk) for a in range(0, numtrees, chunk)] if numproc != 1: logging.info('work division:\n%s', '\n'.join(' %s:\t%r' % kv for kv in sorted(dict(numchunks=len(work), mult=mult).items()))) dowork = mymap(myworker, work) for results in dowork: if PARAMS['approx']: for frag, x in results.items(): fragments[frag] += x else: fragments.update(results) fragmentkeys = list(fragments) bitsets = [fragments[a] for a in fragmentkeys] if PARAMS['nofreq']: counts = None elif PARAMS['approx']: counts = [fragments[a] for a in fragmentkeys] else: task = 'indices' if PARAMS['indices'] else 'counts' logging.info('dividing work for exact %s', task) countchunk = len(bitsets) // numproc + 1 work = list(range(0, len(bitsets), countchunk)) work = [(n, len(work), bitsets[a:a + countchunk]) for n, a in enumerate(work)] counts = [] logging.info('getting exact %s', task) for a in mymap( exactcountworker if numproc == 1 else mpexactcountworker, work): counts.extend(a) if PARAMS['cover']: maxdepth, maxfrontier = PARAMS['cover'] before = len(fragmentkeys) cover = _fragments.allfragments(PARAMS['trees1'], PARAMS['vocab'], maxdepth, maxfrontier, PARAMS['disc'], PARAMS['indices']) for a in cover: if a not in fragments: fragmentkeys.append(a) counts.append(cover[a]) logging.info('merged %d cover fragments ' 'up to depth %d with max %d frontier non-terminals.', len(fragmentkeys) - before, maxdepth, maxfrontier) if numproc != 1: pool.close() pool.join() del dowork, pool return fragmentkeys, counts
[docs]def batch(outputdir, filenames, limit, encoding, debin): """batch processing: three or more treebanks specified. Compares the first treebank to all others, and writes the results to ``outputdir/A_B`` where ``A`` and ``B`` are the respective filenames. Counts/indices are from the other (B) treebanks. There are at least 2 use cases for this: 1. Comparing one treebank to a series of others. The first treebank will only be loaded once. 2. In combination with ``--complete``, the first treebank is a set of fragments used as queries on the other treebanks specified.""" initworker(filenames[0], None, limit, encoding) trees1 = PARAMS['trees1'] maxnodes = trees1.maxnodes if PARAMS['complete']: fragmentkeys, bitsets = _fragments.completebitsets( trees1, PARAMS['vocab'], maxnodes, PARAMS['disc']) fragments = True elif PARAMS['approx']: fragments = defaultdict(int) else: fragments = {} for filename in filenames[1:]: PARAMS.update(read2ndtreebank(filename, PARAMS['vocab'], PARAMS['fmt'], limit, encoding)) trees2 = PARAMS['trees2'] if not PARAMS['complete']: fragments = _fragments.extractfragments(trees1, 0, 0, PARAMS['vocab'], trees2, disc=PARAMS['disc'], debug=PARAMS['debug'], approx=PARAMS['approx'], twoterms=PARAMS['twoterms'], adjacent=PARAMS['adjacent']) fragmentkeys = list(fragments) bitsets = [fragments[a] for a in fragmentkeys] maxnodes = max(trees1.maxnodes, trees2.maxnodes) counts = None if PARAMS['approx'] or not fragments: counts = fragments.values() elif not PARAMS['nofreq']: logging.info('getting %s for %d fragments', 'indices of occurrence' if PARAMS['indices'] else 'exact counts', len(bitsets)) counts = _fragments.exactcounts(bitsets, trees1, trees2, indices=PARAMS['indices'], maxnodes=maxnodes) outputfilename = '%s/%s_%s' % (outputdir, os.path.basename(filenames[0]), os.path.basename(filename)) out = io.open(outputfilename, 'w', encoding=encoding) if debin: fragmentkeys = debinarize(fragmentkeys) printfragments(fragmentkeys, counts, out=out) logging.info('wrote to %s', outputfilename)
[docs]def readtreebanks(filename1, filename2=None, fmt='bracket', limit=None, encoding='utf8'): """Read one or two treebanks.""" vocab = Vocabulary() trees1 = _fragments.readtreebank(filename1, vocab, fmt, limit, encoding) trees2 = _fragments.readtreebank(filename2, vocab, fmt, limit, encoding) trees1.indextrees(vocab) if trees2: trees2.indextrees(vocab) return dict(trees1=trees1, trees2=trees2, vocab=vocab)
[docs]def read2ndtreebank(filename2, vocab, fmt='bracket', limit=None, encoding='utf8'): """Read a second treebank.""" trees2 = _fragments.readtreebank(filename2, vocab, fmt, limit, encoding) trees2.indextrees(vocab) logging.info('%r: %d trees; %d nodes (max %d); ' 'word tokens: %d\n%r', filename2, len(trees2), trees2.numnodes, trees2.maxnodes, trees2.numwords, PARAMS['vocab']) return dict(trees2=trees2, vocab=vocab)
[docs]def initworker(filename1, filename2, limit, encoding): """Read treebanks for this worker. We do this separately for each process under the assumption that this is advantageous with a NUMA architecture.""" PARAMS.update(readtreebanks(filename1, filename2, limit=limit, fmt=PARAMS['fmt'], encoding=encoding)) trees1 = PARAMS['trees1'] if PARAMS['debug']: print('\nproductions:') for a, b in sorted([(PARAMS['vocab'].prodrepr(n), n) for n in range(len(PARAMS['vocab'].prods))], key=lambda x: x[1]): print('%d. %s' % (b, a)) print('treebank 1:') for n in range(trees1.len): trees1.printrepr(n, PARAMS['vocab']) if not trees1: raise ValueError('treebank1 empty.') m = 'treebank1: %d trees; %d nodes (max: %d); %d word tokens.\n' % ( trees1.len, trees1.numnodes, trees1.maxnodes, trees1.numwords) if filename2: trees2 = PARAMS['trees2'] if PARAMS['debug']: print('treebank 2:') for n in range(trees2.len): trees2.printrepr(n, PARAMS['vocab']) if not trees2: raise ValueError('treebank2 empty.') m += 'treebank2: %d trees; %d nodes (max %d); %d word tokens.\n' % ( trees2.len, trees2.numnodes, trees2.maxnodes, trees2.numwords) logging.info('%s%r', m, PARAMS['vocab'])
[docs]def initworkersimple(trees, sents, trees2=None, sents2=None): """Initialization for a worker in which a treebank was already loaded.""" PARAMS.update(_fragments.getctrees(zip(trees, sents), None if trees2 is None else zip(trees2, sents2))) assert PARAMS['trees1'], PARAMS['trees1']
@workerfunc def mpworker(interval): """Worker function for fragment extraction (multiprocessing wrapper).""" return worker(interval)
[docs]def worker(interval): """Worker function for fragment extraction.""" offset, end = interval trees1 = PARAMS['trees1'] trees2 = PARAMS['trees2'] assert offset < trees1.len result = {} result = _fragments.extractfragments(trees1, offset, end, PARAMS['vocab'], trees2, approx=PARAMS['approx'], disc=PARAMS['disc'], debug=PARAMS['debug'], twoterms=PARAMS['twoterms'], adjacent=PARAMS['adjacent']) logging.debug('finished %d--%d', offset, end) return result
@workerfunc def mpexactcountworker(args): """Worker function for counts (multiprocessing wrapper).""" return exactcountworker(args)
[docs]def exactcountworker(args): """Worker function for counting of fragments.""" n, m, bitsets = args trees1 = PARAMS['trees1'] if PARAMS['complete']: results = _fragments.exactcounts(bitsets, trees1, PARAMS['trees2'], indices=PARAMS['indices']) logging.debug('complete matches chunk %d of %d', n + 1, m) return results results = _fragments.exactcounts( bitsets, trees1, trees1, indices=PARAMS['indices']) if PARAMS['indices']: logging.debug('exact indices chunk %d of %d', n + 1, m) else: logging.debug('exact counts chunk %d of %d', n + 1, m) return results
[docs]def workload(numtrees, mult, numproc): """Calculate an even workload. When *n* trees are compared against themselves, ``n * (n - 1)`` total comparisons are made. Each tree ``m`` has to be compared to all trees ``x`` such that ``m < x <= n`` (meaning there are more comparisons for lower *n*). :returns: a sequence of ``(start, end)`` intervals such that the number of comparisons is approximately balanced.""" # could base on number of nodes as well. if numproc == 1: return [(0, numtrees)] # here chunk is the number of tree pairs that will be compared goal = togo = total = 0.5 * numtrees * (numtrees - 1) chunk = total // (mult * numproc) + 1 goal -= chunk result = [] last = 0 for n in range(1, numtrees): togo -= numtrees - n if togo <= goal: goal -= chunk result.append((last, n)) last = n if last < numtrees: result.append((last, numtrees)) return result
[docs]def recurringfragments(trees, sents, numproc=1, disc=True, indices=True, maxdepth=1, maxfrontier=999): """Get recurring fragments with exact counts in a single treebank. :returns: a dictionary whose keys are fragments as strings, and indices as values. When ``disc`` is ``True``, keys are of the form ``(frag, sent)`` where ``frag`` is a unicode string, and ``sent`` is a list of words as unicode strings; when ``disc`` is ``False``, keys are of the form ``frag`` where ``frag`` is a unicode string. :param trees: a sequence of binarized Tree objects, with indices as leaves. :param sents: the corresponding sentences (lists of strings). :param numproc: number of processes to use; pass 0 to use detected # CPUs. :param disc: when disc=True, assume trees with discontinuous constituents; resulting fragments will be of the form (frag, sent); otherwise fragments will be strings with words as leaves. :param indices: when False, return integer counts instead of indices. :param maxdepth: when > 0, add 'cover' fragments to result, corresponding to all fragments up to given depth; pass 0 to disable. :param maxfrontier: maximum number of frontier non-terminals (substitution sites) in cover fragments; a limit of 0 only gives fragments that bottom out in terminals; the default 999 is unlimited for practical purposes.""" if numproc == 0: numproc = cpu_count() numtrees = len(trees) if not numtrees: raise ValueError('no trees.') mult = 1 # 3 if numproc > 1 else 1 fragments = {} trees = trees[:] work = workload(numtrees, mult, numproc) PARAMS.update(disc=disc, indices=indices, approx=False, complete=False, debug=False, adjacent=False, twoterms=None) initworkersimple(trees, list(sents)) if numproc == 1: mymap, myworker = map, worker else: logging.info('work division:\n%s', '\n'.join(' %s: %r' % kv for kv in sorted(dict(numchunks=len(work), numproc=numproc).items()))) # start worker processes pool = multiprocessing.Pool( processes=numproc, initializer=initworkersimple, initargs=(trees, list(sents))) mymap, myworker = pool.map, mpworker # collect recurring fragments logging.info('extracting recurring fragments') for a in mymap(myworker, work): fragments.update(a) fragmentkeys = list(fragments) bitsets = [fragments[a] for a in fragmentkeys] countchunk = len(bitsets) // numproc + 1 work = list(range(0, len(bitsets), countchunk)) work = [(n, len(work), bitsets[a:a + countchunk]) for n, a in enumerate(work)] logging.info('getting exact counts for %d fragments', len(bitsets)) counts = [] for a in mymap( exactcountworker if numproc == 1 else mpexactcountworker, work): counts.extend(a) # add all fragments up to a given depth if maxdepth: cover = _fragments.allfragments(PARAMS['trees1'], PARAMS['vocab'], maxdepth, maxfrontier, disc, indices) before = len(fragmentkeys) for a in cover: if a not in fragments: fragmentkeys.append(a) counts.append(cover[a]) logging.info('merged %d cover fragments ' 'up to depth %d with max %d frontier non-terminals.', len(fragmentkeys) - before, maxdepth, maxfrontier) if numproc != 1: pool.close() pool.join() del pool logging.info('found %d fragments', len(fragmentkeys)) return dict(zip(fragmentkeys, counts))
[docs]def allfragments(trees, sents, maxdepth, maxfrontier=999): """Return all fragments up to a certain depth, # frontiers.""" PARAMS.update(disc=True, indices=True, approx=False, complete=False, debug=False, adjacent=False, twoterms=None) initworkersimple(trees, list(sents)) return _fragments.allfragments(PARAMS['trees1'], PARAMS['vocab'], maxdepth, maxfrontier, disc=PARAMS['disc'], indices=PARAMS['indices'])
[docs]def altrepr(a): """Rewrite bracketed tree to alternative format. Replace double quotes with double single quotes: " -> '' Quote terminals with double quotes terminal: -> "terminal" Remove parentheses around frontier nodes: (NN ) -> NN >>> print(altrepr('(NP (DT a) (NN ))')) (NP (DT "a") NN) """ return FRONTIERRE.sub(r'\1', TERMRE.sub(r'(\1 "\2")', a.replace('"', "''")))
[docs]def debinarize(fragments): """Debinarize fragments; fragments that fail to debinarize left as-is.""" result = [] for origfrag in fragments: frag, sent = (discbrackettree(origfrag) if PARAMS['disc'] else brackettree(origfrag, detectdisc=False)) try: frag = writetree(unbinarize(frag), sent, 0, 'discbracket' if PARAMS['disc'] else 'bracket').strip() except Exception: # pylint: disable=broad-except result.append(origfrag) else: result.append(frag) return result
[docs]def printfragments(fragments, counts, out=None): """Dump fragments to standard output or some other file object.""" if out is None: out = sys.stdout if sys.stdout.encoding is None: out = codecs.getwriter('utf8')(out) if PARAMS['alt']: for n, a in enumerate(fragments): fragments[n] = altrepr(a) if PARAMS['complete']: logging.info('total number of matches: %d', sum(sum(a) for a in counts) if PARAMS['indices'] else sum(counts)) else: logging.info('number of fragments: %d', len(fragments)) if PARAMS['nofreq']: for a in fragments: out.write(a + '\n') return # a frequency of 0 is normal when counting occurrences of given fragments # in a second treebank if PARAMS['complete']: threshold = 0 zeroinvalid = False # a frequency of 1 is normal when comparing two treebanks # or when non-recurring fragments are added elif PARAMS.get('trees2') or PARAMS['cover'] or PARAMS['approx']: threshold = 0 zeroinvalid = True else: # otherwise, raise alarm. threshold = 1 zeroinvalid = True if PARAMS['indices']: for a, theindices in zip(fragments, counts): if len(theindices) > threshold: out.write('%s\t%s\n' % (a, [n for n in theindices if n - 1 in theindices or n + 1 in theindices] if PARAMS['adjacent'] else str(theindices)[len("array('I', "):-len(')')])) elif zeroinvalid: raise ValueError('invalid fragment--frequency=1: %r' % a) elif PARAMS['relfreq']: sums = defaultdict(int) for a, freq in zip(fragments, counts): if freq > threshold: sums[a[1:a.index(' ')]] += freq elif zeroinvalid: raise ValueError('invalid fragment--frequency=%d: %r' % ( freq, a)) for a, freq in zip(fragments, counts): out.write('%s\t%d/%d\n' % ( a, freq, sums[a[1:a.index(' ')]])) else: for a, freq in zip(fragments, counts): if freq > threshold: out.write('%s\t%d\n' % (a, freq)) elif zeroinvalid: raise ValueError('invalid fragment--frequency=1: %r' % a)
[docs]def cpu_count(): """Return number of CPUs or 1.""" try: return multiprocessing.cpu_count() except NotImplementedError: return 1
def test(): """Demonstration of fragment extractor.""" main('--fmt=export alpinosample.export'.split()) __all__ = ['main', 'regular', 'batch', 'readtreebanks', 'read2ndtreebank', 'initworker', 'initworkersimple', 'worker', 'exactcountworker', 'workload', 'recurringfragments', 'allfragments', 'debinarize', 'printfragments', 'altrepr', 'cpu_count']