import glob,os,json,sys
from collections import defaultdict
import nltk

class NaiveBayesModel:
    """Count statistics for a Naive Bayes text classifier.

    The class only accumulates counts from labelled, tokenized documents:
    how many documents each class has, how often each word occurs in each
    class, how many tokens each class has in total, and the set of all
    words seen. It defines no prediction method itself.
    """

    def __init__(self):
        self.vocabulary = set()
        self.class_doc_counts = {}   # {label: numdocs}
        self.class_word_counts = {}  # {label: {word: numtokens}}
        self.class_total_tokens = {}  # {label: num tokens summed across all docs in this class}

    def train(self, docs, doclabels):
        """Add the counts from `docs` to the model.

        May be called more than once; counts accumulate across calls.

        docs: list of document dicts; each must have a 'docid' key and a
            'tokens' key holding an iterable of words.
        doclabels: dictionary {docid: label}.

        Raises KeyError if a document's 'docid' is missing from doclabels.
        """
        # Single-argument call form is valid on both Python 2 and Python 3.
        print("Training on %d documents" % len(docs))
        for doc in docs:
            label = doclabels[doc['docid']]
            if label not in self.class_doc_counts:
                self.class_doc_counts[label] = 0
                self.class_word_counts[label] = {}
                self.class_total_tokens[label] = 0
            self.class_doc_counts[label] += 1
            # Plain dicts with .get() (rather than defaultdicts converted
            # afterwards) keep a later train() call working for labels and
            # words that were not seen the first time.
            word_counts = self.class_word_counts[label]
            for w in doc['tokens']:
                word_counts[w] = word_counts.get(w, 0) + 1
                self.class_total_tokens[label] += 1
                self.vocabulary.add(w)

def dict_argmax(dct):
    """Return the key of `dct` whose associated value is the largest.

    Raises ValueError (from max) when `dct` is empty.
    """
    def value_of(key):
        return dct[key]

    return max(dct.keys(), key=value_of)

def read_keyfile(filename):
    """Read a two-column, whitespace-separated key file into a dict.

    The first field of each line becomes the key and the second the value
    (presumably file name -> label, going by the original variable name).
    Blank lines are skipped.

    Raises ValueError (from dict) if a non-blank line does not have exactly
    two fields.
    """
    # The with-block closes the file even if a malformed line raises.
    with open(filename) as keyfile:
        fields_per_line = (line.split() for line in keyfile)
        # split() on a blank line yields [], which is falsy and dropped here.
        return dict(fields for fields in fields_per_line if fields)

def read_jsons(filename):
    """Read a file holding one JSON value per line.

    Returns a list of the decoded values in file order. Blank lines are
    skipped. A line that is not valid JSON makes json.loads raise.
    """
    # The with-block closes the file even if a line fails to parse.
    with open(filename) as jsonfile:
        return [json.loads(line) for line in jsonfile if line.strip()]

def read_directory(dirpath):
    """Load and tokenize every ``*.txt`` file directly inside `dirpath`.

    Returns a list of dicts, one per file, in sorted path order:
        'text':   the file content decoded as UTF-8, with leading and
                  trailing whitespace removed
        'tokens': lower-cased tokens from nltk.word_tokenize(text)
        'docid':  the path of the file as returned by glob

    Raises UnicodeDecodeError if a file is not valid UTF-8.
    """
    # Sort so the document order does not depend on directory listing order.
    filenames = sorted(glob.glob(os.path.join(dirpath, "*.txt")))
    print("Reading %s files from %s" % (len(filenames), dirpath))
    sys.stdout.flush()
    docs = []
    for path in filenames:
        # Read bytes and decode explicitly: this behaves the same on
        # Python 2 and Python 3, and the with-block closes each file.
        with open(path, 'rb') as infile:
            text = infile.read().strip().decode('utf8')
        tokens = [w.lower() for w in nltk.word_tokenize(text)]
        docs.append({'text': text, 'tokens': tokens, 'docid': path})
    return docs

def files_to_jsons(dirpath):
    """Print each document found in `dirpath` as one JSON object per line.

    Documents come from read_directory(dirpath); output goes to stdout.
    """
    for doc in read_directory(dirpath):
        # Single-argument call form is valid on both Python 2 and Python 3.
        print(json.dumps(doc))

# Command-line dispatcher: the first argument is evaluated as a Python
# expression that must yield a callable, which is then called with the
# remaining arguments (all passed as strings).
if __name__=='__main__':
    # NOTE(review): eval() runs arbitrary Python taken from the command line.
    # Only invoke this script with trusted arguments; consider replacing it
    # with an explicit table of allowed function names.
    # Raises IndexError when no argument is given.
    eval(sys.argv[1])(*sys.argv[2:])
