from __future__ import division
import math,random,cgi

######## The IPython Notebook HTML-based visualization function #############

def show_alignment_posterior(alignments_matrix, ftoks=None, etoks=None):
    """Render a posterior alignment matrix as an HTML table for the IPython Notebook.

    Every matrix entry is drawn as a blue square inside a fixed-size table
    cell.  The side of the square is proportional to the square root of the
    entry, so the filled area is proportional to the entry itself.

    INPUTS
     - alignments_matrix: a list-of-lists of the posterior alignments, as
       returned by calc_alignment_posterior().  Rows are labelled with
       ftoks and columns with etoks.
     - ftoks, etoks: these are optional arguments.  If you give them, this
       function will add them as labels to the matrix visualization.

    Returns an IPython.display.HTML object wrapping the table markup.

    Raises AssertionError if ftoks / etoks are given but their lengths do
    not match the number of rows / columns of the matrix.
    """
    from IPython.display import HTML
    # cgi.escape() was removed in Python 3.8; html.escape() replaces it.
    # Fall back to cgi.escape() only where the html module does not exist.
    try:
        from html import escape as escape_html
    except ImportError:
        from cgi import escape as escape_html

    size = 3
    td_style = """
        width:{size}em; height:{size}em;
        border: 1px solid black; padding:0; margin:0""".format(size=size)
    word_style = """
        text-align:center; padding:0; """
    filldiv_style = """
        width:{mysize}em; height:{mysize}em;
        background:blue;
        margin-left:auto; margin-right:auto; padding:0; margin-top:auto; margin-bottom:auto"""
    label_cell = u"""<td style="{}"><div style="{}">{}</div></td>"""

    nrow = len(alignments_matrix)
    # An empty matrix has no first row to measure; treat it as zero columns.
    ncol = len(alignments_matrix[0]) if nrow else 0
    if ftoks: assert nrow==len(ftoks), "number of foreign tokens doesn't match matrix"
    if etoks: assert ncol==len(etoks), "number of english tokens doesn't match matrix"

    if not ftoks: ftoks = ['']*nrow
    if not etoks: etoks = ['']*ncol

    # Row 0 and column 0 hold the token labels; the top-left corner stays empty.
    html_cells = [[u"<td></td>"] * (ncol + 1) for _ in range(nrow + 1)]
    for i, ftok in enumerate(ftoks):
        # quote=False matches the escaping cgi.escape() performed by default.
        html_cells[i+1][0] = label_cell.format("", word_style, escape_html(ftok, quote=False))
    for j, etok in enumerate(etoks):
        html_cells[0][j+1] = label_cell.format("", word_style, escape_html(etok, quote=False))

    for i in range(nrow):
        for j in range(ncol):
            # Clamp at zero so a tiny negative value produced by floating-point
            # rounding does not make math.sqrt() raise ValueError.
            weight = max(alignments_matrix[i][j], 0.0)
            side_length = size * math.sqrt(weight)
            html_cells[i+1][j+1] = u"""
            <td style="{td_style}"><div style="{divstyle}"></div></td>
            """.format(td_style=td_style,
                    divstyle=filldiv_style.format(mysize=side_length))

    # Emit well-formed markup: every row and the table itself are closed.
    rows = [u"<tr>" + u"".join(row) + u"</tr>" for row in html_cells]
    html = u"<table cellpadding='0'>" + u"".join(rows) + u"</table>"
    return HTML(html)




######## Utility functions below, that you may use if you like. #############

def argmax(vec):
    """Give back the position of the biggest entry of vec, a list of numbers.
    When several entries tie for biggest, the earliest position wins, since
    max() returns the first maximal item it encounters."""
    def value_at(position):
        return vec[position]

    positions = range(len(vec))
    return max(positions, key=value_at)

def normalized_vector(vec):
    """
    Scale a Python list of numbers so that the result sums to 1.
    The input is left untouched; a fresh list is returned.
    [1,2,7] ==> [0.1, 0.2, 0.7]
    """
    total = sum(vec)
    scaled = []
    for value in vec:
        scaled.append(value / total)
    return scaled

def normalized_dict(dct):
    """
    Given a string-to-number map, build a new map with the same keys whose
    values are rescaled to sum to 1.
    {"a":4.0, "b":2.0} ==> {"a":0.6666, "b":0.3333}
    """
    total = sum(dct.values())
    scaled = {}
    for name, weight in dct.items():
        scaled[name] = weight / total
    return scaled

def weighted_draw_from_dict(choice_dict):
    """Randomly choose a key from a dict, where the values are the relative probability weights.

    A single random.uniform() draw picks a point in [0, total weight] and
    the keys are scanned in dict order until the running weight passes it.
    Keys with weight 0 are never chosen.

    Raises ValueError if the dict is empty or its weights do not sum to a
    positive number, since no key can be drawn in that case.
    """
    # http://stackoverflow.com/a/3679747/86684
    # Snapshot the items so the sum and the scan see the same sequence.
    choice_items = list(choice_dict.items())
    if not choice_items:
        raise ValueError("cannot draw from an empty dict")
    total = sum(weight for _, weight in choice_items)
    if not total > 0:
        raise ValueError(
            "weights must sum to a positive number, got %r" % (total,))

    threshold = random.uniform(0, total)
    running = 0
    last_positive = None
    for choice, weight in choice_items:
        if running + weight > threshold:
            return choice
        running += weight
        if weight > 0:
            last_positive = choice
    # random.uniform() may return its upper end point because of rounding,
    # in which case no running sum strictly exceeds the threshold.  Fall
    # back to the last key that carries positive weight.
    return last_positive
