from __future__ import division

import sys
import json
from itertools import groupby
from pprint import pprint

def crossprod_noself_uniquepairs(xs):
    """Yield every unordered pair of distinct positions in ``xs``.

    Pairs are produced as ``(xs[i], xs[j])`` with ``i < j``, row by row: all
    pairs whose first element is ``xs[0]``, then ``xs[1]``, and so on.  No
    element is paired with itself and no pair appears twice.

    :param xs: a list; any other type fails the assertion on first iteration.
    """
    assert isinstance(xs, list)
    for offset, left in enumerate(xs):
        for right in xs[offset + 1:]:
            yield left, right


def loadSingleDocFromMultilineJson(filename):
    """Load a file that holds exactly one (possibly multi-line) JSON document.

    The parsed JSON is passed through convertJsonDocIntoMyDoc.

    :param filename: path of the JSON file.
    :returns: the converted document dict.
    :raises ValueError: if the file content is not valid JSON.
    """
    # Context manager closes the handle promptly, even if parsing fails.
    with open(filename) as f:
        jdoc = json.load(f)
    return convertJsonDocIntoMyDoc(jdoc)

def loadJson(file, numTake=None):
    """Load documents from a JSON-lines file, falling back to a single JSON doc.

    Each non-blank line is expected to be a JSON document, optionally preceded
    by tab-separated fields; only the text after the last tab is parsed.  If
    any line fails to parse as JSON, the whole file is instead treated as one
    (multi-line) JSON document via loadSingleDocFromMultilineJson.

    :param file: path of the input file.
    :param numTake: if given, only the first ``numTake`` raw lines are used
        (blank lines count toward this limit but are skipped).
    :returns: list of documents converted by convertJsonDocIntoMyDoc.
    """
    with open(file) as f:
        lines = f.readlines()
    if numTake is not None:
        lines = lines[:numTake]
    # Only JSON parsing decides whether to fall back. Errors raised while
    # converting a successfully parsed doc are real data errors and propagate.
    try:
        rawDocs = [json.loads(line.split('\t')[-1])
                   for line in lines if line.strip()]
    except ValueError:
        return [loadSingleDocFromMultilineJson(file)]
    return [convertJsonDocIntoMyDoc(doc) for doc in rawDocs]


def convertJsonDocIntoMyDoc(doc):
    """Turn a raw JSON document into the mention-centric structure used by the coref code.

    ``doc['sentences']`` is a list of sentence dicts with ``'tokens'``,
    ``'pos'`` and ``'mentions'``.  Each mention is unpacked as
    ``(entid, (start, end))``, where ``start:end`` slices the sentence tokens.

    Every produced mention dict carries:
      mentionId       running counter over the whole document (0, 1, 2, ...)
      entid           entity id taken from the input
      firstInEntity   True for the first mention of its entid in document order
      start, end      token span within the sentence
      sentence        a {'pos', 'tokens'} dict shared by all mentions of the
                      same sentence
      sentenceIndex   index of the sentence in the document
      tokens, pos     the mention's own tokens and POS tags
      headTokenIndex  head offset *relative to the mention* (not the sentence),
                      chosen by the getLastNoun heuristic.  TODO: use the parse;
                      the heuristic errs on N-P-N phrases, e.g. "soldiers of
                      fortune" gets "fortune" as its head.
      positionInDoc   document-level index of the mention's first token

    :returns: ``{'mentions': [...], 'sentences': doc['sentences']}``
    """
    mentions = []
    seenEntids = set()
    tokenOffset = 0  # document-level index of the current sentence's first token
    for sentIdx, rawSent in enumerate(doc['sentences']):
        sentTokens = rawSent['tokens']
        sentTags = rawSent['pos']
        sharedSentence = {'pos': sentTags, 'tokens': sentTokens}
        for entid, (start, end) in rawSent['mentions']:
            ment = {
                'mentionId': len(mentions),
                'entid': entid,
                'firstInEntity': entid not in seenEntids,
                'start': start,
                'end': end,
                'sentence': sharedSentence,
                'sentenceIndex': sentIdx,
                'tokens': sentTokens[start:end],
                'pos': sentTags[start:end],
            }
            seenEntids.add(entid)
            ment['headTokenIndex'] = getLastNoun(sentTags[start:end], ment)
            ment['positionInDoc'] = start + tokenOffset
            mentions.append(ment)
        tokenOffset += len(sentTokens)
    return {'mentions': mentions, 'sentences': doc['sentences']}

def assignGroundTruthEntities(documents):
    """Group each document's mentions by their ground-truth entity id.

    For every document, sets ``doc['groundTruthEntities']`` to a dict mapping
    each ``entid`` to that entity's mentions, ordered by document position
    (via sortMentionsByPosition).  Documents are modified in place.

    :param documents: iterable of document dicts with a ``'mentions'`` list.
    """
    for doc in documents:
        mentions = doc['mentions']
        # groupby only merges *adjacent* equal keys, so sort by the same key first.
        entities = groupby(sorted(mentions, key=lambda m: m['entid']),
                           lambda m: m['entid'])
        doc['groundTruthEntities'] = {k: sortMentionsByPosition(v) for k, v in entities}

def assignPredictedEntities(documents):
    """Turn predicted coreference links into predicted entity clusters.

    The coref functions (in coref.py) store ``doc['coref']``: a dict mapping a
    mention index to the index of the earlier mention it links to.  Mentions
    without a link start a new entity.  For each document this function:

    * sets ``mention['predictedEntity']`` on every mention; entity ids are
      0, 1, 2, ... in order of each entity's first mention;
    * sets ``doc['predictedEntities']`` to a dict mapping entity id to that
      entity's mentions ordered by document position (used for printing and
      evaluation).

    Documents are modified in place.
    """
    for doc in documents:
        entityCount = 0
        mentions = doc['mentions']
        links = doc['coref']
        for i, ment in enumerate(mentions):
            # Test the dict directly: `i in links.keys()` scans a list in Python 2.
            if i in links:
                parent = links[i]
                # The antecedent must be an earlier mention so that its
                # predictedEntity was already set in this pass.  A negative
                # index would silently read mentions[-1], which may be unset
                # or hold a stale value from a previous call.
                assert 0 <= parent < i, \
                    "coref link %r -> %r must point to an earlier mention" % (i, parent)
                ment['predictedEntity'] = mentions[parent]['predictedEntity']
            else:
                ment['predictedEntity'] = entityCount
                entityCount += 1

        # groupby only merges adjacent equal keys, so sort by the same key first.
        entities = groupby(sorted(mentions, key=lambda m: m['predictedEntity']),
                           lambda m: m['predictedEntity'])
        doc['predictedEntities'] = {k: sortMentionsByPosition(v) for k, v in entities}

def sortMentionsByPosition(mentions):
    """Return a new list of ``mentions`` ordered by ``'positionInDoc'``.

    Accepts any iterable (e.g. a groupby group) and leaves the input
    untouched.  The sort is stable, so mentions at the same position keep
    their relative order.  Used when building entity clusters for printing.
    """
    ordered = list(mentions)
    ordered.sort(key=lambda mention: mention['positionInDoc'])
    return ordered

def mentionSetOverlap(ments1, ments2):
    """Count the mentions (compared by ``'mentionId'``) present in both collections.

    :param ments1: iterable of mention dicts.
    :param ments2: iterable of mention dicts.
    :returns: size of the intersection of the two sets of mention ids.
    """
    ids1 = {m['mentionId'] for m in ments1}
    ids2 = {m['mentionId'] for m in ments2}
    return len(ids1 & ids2)

def evalCorefPairwise(documents):
    """Score predicted coreference against gold with pairwise precision/recall/F1.

    Every unordered pair of mentions sharing a gold entity is a gold link, and
    every pair sharing a predicted entity is a predicted link:

    * true positive  -- the pair is linked in both gold and prediction;
    * false negative -- a gold pair whose mentions were predicted apart;
    * false positive -- a predicted pair whose mentions have different gold ids.

    Recomputes ``predictedEntities`` and ``groundTruthEntities`` on each
    document, then prints a one-line summary.

    :returns: ``(precision, recall, f1)``; each is 0 when its denominator is 0.
    """
    tp1, tp2, fn, fp = 0, 0, 0, 0
    for doc in documents:
        assignPredictedEntities([doc])
        assignGroundTruthEntities([doc])

        for mentions in doc['groundTruthEntities'].values():
            # A pair in the same gold cluster but different predicted clusters
            # is a false negative: it should have been linked.
            for m1, m2 in crossprod_noself_uniquepairs(mentions):
                if m1['predictedEntity'] == m2['predictedEntity']:
                    tp1 += 1
                else:
                    fn += 1

        for mentions in doc['predictedEntities'].values():
            # A pair in the same predicted cluster but with different gold ids
            # is a false positive link.
            for m1, m2 in crossprod_noself_uniquepairs(mentions):
                if m1['entid'] == m2['entid']:
                    tp2 += 1
                else:
                    fp += 1

        # Both passes count the same set of correctly linked pairs.
        assert tp1 == tp2, "bug in pairwise eval. the two different truepos calculations should give the same answer"

    prec = tp1 / (tp1 + fp) if tp1 + fp > 0 else 0
    rec = tp1 / (tp1 + fn) if tp1 + fn > 0 else 0
    f1 = 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0
    # Single-argument print(...) behaves identically as a Python 2 statement
    # and as the Python 3 function.
    print("Pairwise Prec = %.3f (%s/%s), Rec = %.3f (%s/%s), F1 = %.3f" % (
        prec, tp1, tp1 + fp, rec, tp1, tp1 + fn, f1))
    return prec, rec, f1

def evalCorefAccuracy(documents):
    """Compute linking accuracy over all mentions of all documents.

    A mention counts as correct when either:

    * it has no link in ``doc['coref']`` and is the first mention of its gold
      entity (``firstInEntity``); or
    * it links to a mention with the same gold entity id.

    Per the original note, this scoring relies on singleton mentions not
    being annotated in the data.

    :returns: fraction of correct mentions, or 0.0 when there are no mentions
        (mirrors evalCorefPairwise returning 0 for empty denominators).
    """
    correct = 0.0
    total = 0
    for doc in documents:
        mentions = doc['mentions']
        links = doc['coref']
        for i, ment in enumerate(mentions):
            parentIndex = links.get(i)
            if parentIndex is None:
                if ment['firstInEntity']:
                    correct += 1
            elif entityIdForMention(ment) == entityIdForMention(mentions[parentIndex]):
                correct += 1
        total += len(mentions)

    if total == 0:
        return 0.0
    return correct / total

_PREP_TAGS = frozenset(['IN', 'TO'])
def isPrep(tag):
    """Return True if ``tag`` is a preposition-like POS tag ('IN' or 'TO')."""
    # The tag set is built once at import time rather than on every call.
    return tag in _PREP_TAGS

def getLastNoun(tags, ment):
    """Heuristically choose a mention's head token: its last noun-like tag.

    :param tags: POS tags of the mention's tokens.
    :param ment: the mention dict; currently unused (kept for existing callers).
    :returns: offset, relative to the mention, of the last tag for which
        isNoun() is true, or 0 if there is none.
    """
    # next() with a default replaces the Python-2-only generator .next()
    # method together with the StopIteration handler.
    return next((idx for idx in reversed(range(len(tags))) if isNoun(tags[idx])), 0)

# POS tags treated as "noun-like" by the head-finding heuristic.  Note that it
# includes pronouns (PRP, PRP$) and determiners (DT), not only nouns.
Nouns = {'NN', 'NNS', 'NNP', 'NNPS', 'PRP', 'PRP$', 'DT'}
def isNoun(tag):
    """Return True if ``tag`` is one of the noun-like tags in ``Nouns``.

    Only getLastNoun uses this; you should not need to call it directly.
    """
    found = tag in Nouns
    return found

def printEntities(entities):
    """Print one line per entity, showing each mention as ``[tok tok ...]``.

    :param entities: dict mapping entity id to a list of mention dicts (each
        needs ``'tokens'``), e.g. ``doc['predictedEntities']``.
    """
    for mentionsOfEntity in entities.values():
        # Single-argument print(...) prints identically under Python 2's print
        # statement and Python 3's print function.
        print(" ".join("[" + " ".join(m['tokens']) + "]" for m in mentionsOfEntity))

def entityIdForMention(ment):
    """Return the ground-truth entity id (``'entid'``) stored on a mention."""
    goldId = ment['entid']
    return goldId


#---------Below here are functions that you might find useful for printing and for designing features------------#

def printDocument(doc):
    """Pretty-print a document's text, then its predicted and gold entities.

    Recomputes ``doc['predictedEntities']`` and ``doc['groundTruthEntities']``
    first.  Intended for debugging and feature engineering.
    """
    assignPredictedEntities([doc])
    assignGroundTruthEntities([doc])
    # Single-argument print(...) gives the same output under Python 2 and 3.
    print("\n\nDocument Text:")
    for sent in doc['sentences']:
        sys.stdout.write(" ".join(sent['tokens']) + " ")
    print("\n")
    print("\n===Predicted Entities===")
    printEntities(doc['predictedEntities'])
    print("\n===Ground Truth Entities===")
    printEntities(doc['groundTruthEntities'])

def mentionPositionInDocument(ment):
    """Return the document-level token index of the mention's FIRST token.

    For mentions built by convertJsonDocIntoMyDoc, ``positionInDoc`` is the
    sentence's token offset plus the mention's ``start``.  It is therefore not
    the head token's position.  Useful for comparing distances between
    mentions.
    """
    return ment['positionInDoc']

def mentionString(m):
    """Return the full text of a mention: its tokens joined by single spaces."""
    words = m['tokens']
    return " ".join(words)

def sentenceIndexInDocument(ment):
    """Return the index, within the document, of the sentence containing ``ment``."""
    sentIdx = ment['sentenceIndex']
    return sentIdx

def headTokenPOSTag(m):
    """Return the POS tag of the mention's head token.

    ``headTokenIndex`` is an offset into the mention (not the sentence) chosen
    by the getLastNoun heuristic, so it can be wrong for phrases such as
    "soldiers of fortune".
    """
    mentionTags = m['pos']
    return mentionTags[m['headTokenIndex']]
    
def headToken(m):
    """Return the mention's head token, lower-cased.

    ``headTokenIndex`` is an offset into the mention (not the sentence) chosen
    by the getLastNoun heuristic, so it can be wrong for phrases such as
    "soldiers of fortune".
    """
    head = m['tokens'][m['headTokenIndex']]
    return head.lower()
