Deep Transition Dependency Parser in PyTorch

(Adapted from gt-nlp-class)

In this problem set, you will implement a deep transition dependency parser in PyTorch. PyTorch is a popular deep learning framework that provides a variety of components for constructing neural networks. You will see how network architectures more complicated than the simple feed-forward networks from earlier classes can be used to solve a structured prediction problem.

In [1]:
import gtnlplib.parsing as parsing
import gtnlplib.data_tools as data_tools
import gtnlplib.constants as consts
import gtnlplib.evaluation as evaluation
import gtnlplib.utils as utils
import gtnlplib.feat_extractors as feat_extractors
import gtnlplib.neural_net as neural_net

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as ag

from collections import defaultdict
In [2]:
# Read in the dataset
dataset = data_tools.Dataset(consts.TRAIN_FILE, consts.DEV_FILE, consts.TEST_FILE)

# Assign each word a unique index, including the two special tokens
word_to_ix = { word: i for i, word in enumerate(dataset.vocab) }
In [3]:
# Some constants to keep around
LSTM_NUM_LAYERS = 1
TEST_EMBEDDING_DIM = 5
WORD_EMBEDDING_DIM = 64
STACK_EMBEDDING_DIM = 100
NUM_FEATURES = 3

# Hyperparameters
ETA_0 = 0.01
DROPOUT = 0.0
In [4]:
def make_dummy_parser_state(sentence):
    dummy_embeds = [ w + "-EMBEDDING" for w in sentence ] + [consts.END_OF_INPUT_TOK + "-EMBEDDING"]
    return parsing.ParserState(sentence + [consts.END_OF_INPUT_TOK], dummy_embeds, utils.DummyCombiner())

High-Level Overview of the Parser

Be sure that you have reviewed the notes on transition-based dependency parsing, and are familiar with the relevant terminology. One small difference is that the text describes arc-left and arc-right actions, which create arcs between the top of the stack and the front of the buffer; in contrast, the parser you will implement here uses reduce-left and reduce-right actions, which create arcs between the top two items on the stack.

Parsing will proceed as follows:

  • Initialize your parsing stack and input buffer.
  • At each step, extract some features. These can be anything: words in the sentence, the configuration of the stack, the configuration of the input buffer, the previous action, etc.
  • Send these features through a multi-layer perceptron (MLP) to get a probability distribution over actions (SHIFT, REDUCE_L, REDUCE_R). The next action you choose is the one with the highest probability.
  • If the action is either reduce left or reduce right, you use a neural network to combine the items being reduced and get a dense output to place back on the stack.

The key classes you will fill in are:

  • Feature extraction in feat_extractors.py
  • The ParserState class in parsing.py, which keeps track of the input buffer and parse stack and offers a public interface for performing the parsing actions that update the state
  • The TransitionParser class in parsing.py, a PyTorch module that contains the core parsing logic
  • The neural network components in neural_net.py

The network components are compartmentalized as follows:

  • Parsing: TransitionParser is the base component that contains and coordinates the other substitutable components.
  • Embedding Lookup: You will implement two flavors of embedding lookup. These embeddings are used to initialize the input buffer, and will be shifted onto the stack / serve as inputs to the combiner networks (explained below).
    • VanillaWordEmbeddingLookup just gets embeddings from a lookup table, one per word in the sentence.
    • BiLSTMWordEmbeddingLookup is fancier, running a sequence model in both directions over the sentence. The hidden state at step t is the embedding for the t'th word of the sentence.
  • Action Choosing: This is a simple multilayer perceptron (MLP) that outputs log probabilities over actions.
  • Combiners: These networks take the two embeddings of the items being reduced and combine them into a single embedding. You will create two versions of this:
    • MLPCombinerNetwork takes the two input embeddings and gives a dense output.
    • LSTMCombinerNetwork runs a sequence model, where the output embedding is the hidden state of the next timestep.

Example

The following is how the input buffer and stack look at each step of a parse, up to the first reduction. The input sentence is "the dog ran away". Our action chooser network takes the top two elements of the stack plus a one-token lookahead in the input buffer. $C(x,y)$ refers to calling our combiner network on arguments $x, y$. Also let $A$ be the set of actions: $\{ \text{SHIFT}, \text{REDUCE-L}, \text{REDUCE-R} \}$, and let $q_w$ be the embedding for word $w$.

Step 1.

  • Input Buffer: $\left[ q_\text{the}, q_\text{dog}, q_\text{ran}, q_\text{away}, q_\text{END-INPUT} \right]$
  • Stack: $\left[ q_\text{NULL-STACK}, q_\text{NULL-STACK} \right]$
  • Action: $ \text{argmax}_{a \in A} \ \text{ActionChooser}(q_\text{NULL-STACK}, q_\text{NULL-STACK}, \overbrace{q_\text{the}}^\text{lookahead}) \Rightarrow \text{SHIFT}$

Step 2

  • Input Buffer: $\left[ q_\text{dog}, q_\text{ran}, q_\text{away}, q_\text{END-INPUT} \right]$
  • Stack: $\left[ q_\text{NULL-STACK}, q_\text{NULL-STACK}, q_\text{the} \right]$
  • Action: $ \text{argmax}_{a \in A} \ \text{ActionChooser}(q_\text{NULL-STACK}, q_\text{the}, q_\text{dog}) \Rightarrow \text{SHIFT}$

Step 3

  • Input Buffer: $\left[ q_\text{ran}, q_\text{away}, q_\text{END-INPUT} \right]$
  • Stack: $\left[ q_\text{NULL-STACK}, q_\text{NULL-STACK}, q_\text{the}, q_\text{dog} \right]$
  • Action: $ \text{argmax}_{a \in A} \ \text{ActionChooser}(q_\text{the}, q_\text{dog}, q_\text{ran}) \Rightarrow \text{REDUCE-L}$

Step 4

  • Input Buffer: $\left[ q_\text{ran}, q_\text{away}, q_\text{END-INPUT} \right]$
  • Stack: $\left[ q_\text{NULL-STACK}, q_\text{NULL-STACK}, C(q_\text{dog}, q_\text{the}) \right]$

For each word $w_m$, the parser keeps track of: the embedding $q_{w_m}$, the word itself $w_m$, and the word's position in the sentence $m$. A reduction should store the word and the index of the head word of the relation; the combined embedding may be a function of the embeddings of both the head and modifier words.

Before beginning, I recommend completing the parse by hand, drawing the input buffer and stack at each step, and explicitly listing the arguments to the action chooser.

1. Managing and Updating the Parser State (15 points)

In this part of the assignment, you will work with the ParserState class, which keeps track of the parser's input buffer and stack.

Deliverable 1.1: Implementing Reduce (10 points)

Implement the reduction operation of the ParserState in parsing.py, in the function _reduce.

The way reduction is done here is slightly different from the notes: in the notes, reduction takes place between the top element of the stack and the first element of the input buffer; here, it takes place between the top two elements of the stack.

At this stage of the assignment there are no real embeddings yet (the test below uses None placeholders and a DummyCombiner), but don't forget that your implementation must make the call to the combiner network component. A sketch of the idea follows the hints.

Hints:

  • Before starting, read the comments in _reduce, and look at the __init__ function of ParserState to see how it represents the stack and input buffer.
  • The StackEntry and DepGraphEdge tuples will be part of your solution, so take a look at how these are used elsewhere in the source.
  • In particular, you will want to push a new StackEntry onto the stack, and return a DepGraphEdge.
  • If you have trouble understanding the representation, print parser_state.stack or parser_state.input_buffer directly. (If you just print parser_state, it will output a pretty-printed version).
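Here is a minimal, self-contained sketch of a left reduction on a plain Python stack, not the required gtnlplib code: the StackEntry fields used below (headword, headword_pos, embedding) are assumptions for illustration, so check parsing.py for the real tuple layout, and note how the head/modifier convention matches the test output in the next cell.

from collections import namedtuple

# Hypothetical tuple layouts for illustration only; the real ones are in parsing.py.
StackEntry = namedtuple("StackEntry", ["headword", "headword_pos", "embedding"])
DepGraphEdge = namedtuple("DepGraphEdge", ["head", "modifier"])

def sketch_reduce_left(stack, combiner):
    head = stack.pop()        # in a left reduction, the top of the stack is the head
    modifier = stack.pop()    # ... and the item below it is the modifier
    # Combine the two embeddings and push the result back, keeping the head word's identity
    combined = combiner(head.embedding, modifier.embedding)
    stack.append(StackEntry(head.headword, head.headword_pos, combined))
    return DepGraphEdge(head=(head.headword, head.headword_pos),
                        modifier=(modifier.headword, modifier.headword_pos))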
In [6]:
test_sentence = "They can fish".split()+[consts.END_OF_INPUT_TOK]
parser_state = parsing.ParserState(test_sentence, [None] * len(test_sentence), utils.DummyCombiner())

print parser_state

parser_state.shift()
parser_state.shift()
print parser_state

reduction = parser_state.reduce_left()
print "Reduction Made Edge: Head: {}, Modifier: {}".format(reduction[0], reduction[1]), "\n"
print parser_state
Stack: []
Input Buffer: ['They', 'can', 'fish', '<END-OF-INPUT>']

Stack: ['They', 'can']
Input Buffer: ['fish', '<END-OF-INPUT>']

Reduction Made Edge: Head: ('can', 1), Modifier: ('They', 0) 

Stack: ['can']
Input Buffer: ['fish', '<END-OF-INPUT>']

Deliverable 1.2: Parser Terminating Condition (5 points)

In this short (one line) deliverable, implement done_parsing() in ParserState. Note that we add an END_OF_INPUT_TOK to the end of the sentence (this token could be a helpful feature). Think about what the input buffer and stack look like at the end of a parse.
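As a hint in code form, here is a sketch of the termination check, assuming the stack and input buffer are exposed as the lists parser_state.stack and parser_state.input_buffer (as suggested by the hints above); the real method belongs on ParserState itself.

def sketch_done_parsing(parser_state):
    # Done when only the root of the parse remains on the stack and the input
    # buffer holds nothing but the end-of-input token (see the printout below).
    return len(parser_state.stack) == 1 and len(parser_state.input_buffer) == 1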

In [7]:
print parser_state, parser_state.done_parsing(),'\n'
parser_state.shift()
print parser_state, parser_state.done_parsing(),'\n'
parser_state.reduce_right()
print parser_state, parser_state.done_parsing(),'\n'
Stack: ['can']
Input Buffer: ['fish', '<END-OF-INPUT>']
 False 

Stack: ['can', 'fish']
Input Buffer: ['<END-OF-INPUT>']
 False 

Stack: ['can']
Input Buffer: ['<END-OF-INPUT>']
 True 

2. Neural Network for Action Decisions (35 points)

In this part of the assignment, you will use PyTorch to create a neural network which examines the current state of the parse and makes the decision to either shift, reduce left, or reduce right.

Deliverable 2.1: Word Embedding Lookup (10 points)

Implement the class VanillaWordEmbeddingLookup in neural_net.py.

This involves adding code to the __init__ and forward methods.

  • In the __init__ method, you want to make sure that instances of the class can store the embeddings
  • In the forward method, you should return a list of PyTorch Variables, representing the looked-up embeddings for each word in the sequence

If you didn't do the tutorial, you will want to read the docs on how to create a lookup table for your word embeddings.

Hint: You will have to turn the input, which is a list of strings (the words in the sentence), into a format that your embedding lookup table can take, which is a torch.LongTensor. So that we can automatically backprop, it is wrapped in a Variable. utils.sequence_to_variable takes care of this for you.
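For orientation, here is a standalone sketch of a lookup-table embedder in the same spirit; it approximates utils.sequence_to_variable inline, and its class and attribute names are made up, so do not treat it as the required VanillaWordEmbeddingLookup.

import torch
import torch.nn as nn
import torch.autograd as ag

class SketchWordEmbeddingLookup(nn.Module):
    def __init__(self, word_to_ix, embedding_dim):
        super(SketchWordEmbeddingLookup, self).__init__()
        self.word_to_ix = word_to_ix
        # the lookup table: one row of size embedding_dim per word in the vocabulary
        self.word_embeds = nn.Embedding(len(word_to_ix), embedding_dim)

    def forward(self, sentence):
        # map words to indices and wrap them in a Variable of a LongTensor
        ix = ag.Variable(torch.LongTensor([self.word_to_ix[w] for w in sentence]))
        embeds = self.word_embeds(ix)  # shape: (len(sentence), embedding_dim)
        # return one 1 x embedding_dim Variable per word, since the parser expects a list
        return [embeds[i].view(1, -1) for i in range(len(sentence))]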

In [52]:
torch.manual_seed(1) # DO NOT CHANGE
reload(neural_net)
test_sentence = "William Faulkner".split()
test_word_to_ix = { "William": 0, "Faulkner": 1 }

word_embedder = neural_net.VanillaWordEmbeddingLookup(test_word_to_ix, TEST_EMBEDDING_DIM)
embeds = word_embedder(test_sentence)
print type(embeds)
print len(embeds), "\n"
print "Embedding for William:\n {}".format(embeds[0])
<type 'list'>
2 

Embedding for William:
 Variable containing:
-2.9718  1.7070 -0.4305 -2.2820  0.5237
[torch.FloatTensor of size 1x5]

Deliverable 2.2: Feature Extraction (5 points)

Fill in the SimpleFeatureExtractor class in feat_extractors.py to give the following three features:

  • The embedding of the second-to-top item of the stack
  • The embedding of the top of the stack
  • The embedding of the next token in the input buffer (one-token lookahead)

If at this point you have not poked around ParserState to see how it stores the state, now would be a good time.
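The feature logic itself is just indexing. Here is a sketch written against plain lists of embeddings rather than the real ParserState (whose get_features takes the state object, as in the cell below), so the argument names are illustrative only.

class SketchFeatureExtractor(object):
    def get_features(self, stack_embeds, buffer_embeds):
        # [second-to-top of stack, top of stack, one-token lookahead into the buffer]
        return [stack_embeds[-2], stack_embeds[-1], buffer_embeds[0]]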

In [53]:
torch.manual_seed(1)
test_sentence = "The Sound and the Fury".split()
test_word_to_ix = { word: i for i, word in enumerate(set(test_sentence)) }

embedder = neural_net.VanillaWordEmbeddingLookup(test_word_to_ix, TEST_EMBEDDING_DIM)
embeds = embedder(test_sentence)

state = parsing.ParserState(test_sentence, embeds, utils.DummyCombiner())

state.shift()
state.shift()
feat_extractor = feat_extractors.SimpleFeatureExtractor()
feats = feat_extractor.get_features(state)

print "Embedding for 'The':\n {}".format(feats[0])
print "Embedding for 'Sound':\n {}".format(feats[1])
print "Embedding for 'and' (from buffer lookahead):\n {}".format(feats[2])
Embedding for 'The':
 Variable containing:
 0.8407  0.5510  0.3863  0.9124 -0.8410
[torch.FloatTensor of size 1x5]

Embedding for 'Sound':
 Variable containing:
-2.9718  1.7070 -0.4305 -2.2820  0.5237
[torch.FloatTensor of size 1x5]

Embedding for 'and' (from buffer lookahead):
 Variable containing:
 0.0004 -1.2039  3.5283  0.4434  0.5848
[torch.FloatTensor of size 1x5]

Deliverable 2.3: MLP for Choosing Actions (10 points)

Implement the class neural_net.ActionChooserNetwork according to the specification in neural_net.py.

You will want to use the utils.concat_and_flatten function. We provide this function because the tensor reshaping code can get somewhat tedious. It takes the list of embeddings passed in (which come from your feature extractor) and concatenates them into one long row vector.

This network takes as input the features from your feature extractor, concatenates them, runs them through an MLP and outputs log probabilities over actions.
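As a sketch of the shape of such a network (the single ReLU hidden layer and its size are assumptions, and torch.cat stands in for utils.concat_and_flatten):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SketchActionChooser(nn.Module):
    def __init__(self, input_dim, hidden_dim=32, num_actions=3):
        super(SketchActionChooser, self).__init__()
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, num_actions)

    def forward(self, feature_embeds):
        # concatenate the feature embeddings into one long row vector
        x = torch.cat(feature_embeds, 1)
        # MLP: hidden layer, then log probabilities over SHIFT / REDUCE_L / REDUCE_R
        return F.log_softmax(self.out(F.relu(self.hidden(x))))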

In [54]:
torch.manual_seed(1) # DO NOT CHANGE, you can compare my output below to yours
act_chooser = neural_net.ActionChooserNetwork(TEST_EMBEDDING_DIM * NUM_FEATURES)
feats = [ ag.Variable(torch.randn(1, TEST_EMBEDDING_DIM)) for _ in xrange(NUM_FEATURES) ] # make some dummy feature embeddings
log_probs = act_chooser(feats)
print log_probs
Variable containing:
-1.5347 -1.3445 -0.6466
[torch.FloatTensor of size 1x3]

Deliverable 2.4: Network for Combining Stack Items (10 points)

Implement the class neural_net.MLPCombinerNetwork according to the specification in neural_net.py. Again, utils.concat_and_flatten will come in handy.

Recall that this component takes two embeddings, the head and the modifier, during a reduction and outputs a combined embedding, which is then pushed back onto the stack during parsing.
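A sketch of one possible shape for such a combiner (the tanh hidden layer is an assumption; only the input and output sizes are dictated by the spec):

import torch
import torch.nn as nn

class SketchMLPCombiner(nn.Module):
    def __init__(self, embedding_dim):
        super(SketchMLPCombiner, self).__init__()
        self.hidden = nn.Linear(2 * embedding_dim, embedding_dim)
        self.out = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, head_embed, modifier_embed):
        x = torch.cat([head_embed, modifier_embed], 1)  # 1 x 2d
        return self.out(torch.tanh(self.hidden(x)))     # 1 x d, pushed back onto the stack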

In [55]:
torch.manual_seed(1) # DO NOT CHANGE
combiner = neural_net.MLPCombinerNetwork(TEST_EMBEDDING_DIM)

# Again, make dummy inputs
head_feat = ag.Variable(torch.randn(1, TEST_EMBEDDING_DIM))
modifier_feat = ag.Variable(torch.randn(1, TEST_EMBEDDING_DIM))
combined = combiner(head_feat, modifier_feat)
print combined
Variable containing:
 0.6063 -0.0110  0.6530 -0.6196 -0.1051
[torch.FloatTensor of size 1x5]

3. Return of the Parser (20 points)

Deliverable 3.1: Parser Training Code (15 points)

Note: There are two unit tests for this deliverable, one worth 1 point, one worth 0.5.

You will implement the forward() function in gtnlplib.parsing.TransitionParser. It is important to understand the difference between the following tasks:

  • Training: Training the model involves passing it sentences along with the correct sequence of actions, and updating weights.
  • Evaluation: We can evaluate the parser by passing it sentences along with the correct sequence of actions, and see how many actions it predicts correctly. This is identical to training, except the weights are not updated after making a prediction.
  • Prediction: After setting the weights, we give it a raw sentence (no gold-standard actions) and ask it for the correct dependency graph.

At this point, you need all of the components above in place in order to construct the parser.

The parsing logic is roughly as follows:

  • Loop until parsing state is in its terminating state (deliverable 1.2)
  • Get the features from the parsing state (deliverable 2.1)
  • Send them through your action chooser network to get log probabilities over actions (deliverable 2.3)
  • If you have gold_actions, perform the next gold action. Otherwise (when predicting), take the argmax of your log probabilities and perform that action.
    • Argmax is gross in PyTorch, so a function is provided for you in utils.argmax.
    • While the gold actions will always be valid, if you are not provided gold actions, you must make sure that any action you do is legal. You cannot shift when the input buffer contains only END_OF_INPUT_TOK (this token should NOT be shifted onto the stack), and you cannot reduce when the stack contains fewer than 2 elements. If your network chooses SHIFT when it is not legal, just do REDUCE_R instead.

Make sure to keep track of everything the function is expected to return (a rough sketch of the whole loop follows this list):

  • Do all of your actions by calling the appropriate function on your parser_state
  • Append each output Variable from your action_chooser to the outputs list
  • Append each action you do to actions_done
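Here is a rough sketch of the control flow described above, not the exact TransitionParser.forward code. The action indices (SHIFT=0, REDUCE_L=1, REDUCE_R=2) match the actions_done output shown below, but the real constants live in gtnlplib, and the fallback used when a reduce is chosen on a too-small stack is an assumption.

def sketch_forward(parser_state, feat_extractor, action_chooser, gold_actions=None):
    SHIFT, REDUCE_L, REDUCE_R = 0, 1, 2
    action_to_ix = {"SHIFT": SHIFT, "REDUCE_L": REDUCE_L, "REDUCE_R": REDUCE_R}
    outputs, actions_done = [], []
    gold = iter(gold_actions) if gold_actions is not None else None
    while not parser_state.done_parsing():
        feats = feat_extractor.get_features(parser_state)
        log_probs = action_chooser(feats)
        outputs.append(log_probs)
        if gold is not None:
            act = action_to_ix[next(gold)]     # training / evaluation: follow the gold action
        else:
            act = utils.argmax(log_probs)      # prediction: greedy argmax
            if act == SHIFT and len(parser_state.input_buffer) == 1:
                act = REDUCE_R                 # never shift END_OF_INPUT_TOK
            elif act != SHIFT and len(parser_state.stack) < 2:
                act = SHIFT                    # assumed fallback: too few items to reduce
        if act == SHIFT:
            parser_state.shift()
        elif act == REDUCE_L:
            parser_state.reduce_left()
        else:
            parser_state.reduce_right()
        actions_done.append(act)
    # (the real forward also assembles and returns the dependency graph)
    return outputs, actions_done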
In [8]:
test_sentence = "The man ran away".split()
test_word_to_ix = { word: i for i, word in enumerate(set(test_sentence)) }
test_word_to_ix[consts.END_OF_INPUT_TOK] = len(test_word_to_ix)
test_sentence_vocab = set(test_sentence)
gold_actions = ["SHIFT", "SHIFT", "REDUCE_L", "SHIFT", "REDUCE_L", "SHIFT", "REDUCE_R"]
In [9]:
feat_extractor = feat_extractors.SimpleFeatureExtractor()
word_embedding_lookup = neural_net.VanillaWordEmbeddingLookup(test_word_to_ix, STACK_EMBEDDING_DIM)
action_chooser = neural_net.ActionChooserNetwork(STACK_EMBEDDING_DIM * NUM_FEATURES)
combiner_network = neural_net.MLPCombinerNetwork(STACK_EMBEDDING_DIM)
parser = parsing.TransitionParser(feat_extractor, word_embedding_lookup,
                                     action_chooser, combiner_network)
In [10]:
output, depgraph, actions_done = parser(test_sentence, gold_actions)
print depgraph
print actions_done
set([DepGraphEdge(head=('ran', 2), modifier=('away', 3)), DepGraphEdge(head=('ran', 2), modifier=('man', 1)), DepGraphEdge(head=('<ROOT>', -1), modifier=('ran', 2)), DepGraphEdge(head=('man', 1), modifier=('The', 0))])
[0, 0, 1, 0, 1, 0, 2]

Now Train the Parser!

Training your parser may take some time. On the test below, I get about 5 seconds per loop (i7 6700k).

  • There are 10,000 training sentences, so multiply this measurement by 100 to get your training time.
  • One optimization trick: if you can do several things with a single PyTorch call, that will probably be faster than a PyTorch call that does one thing, called several times.
In [11]:
torch.manual_seed(1)
feat_extractor = feat_extractors.SimpleFeatureExtractor()
word_embedding_lookup = neural_net.VanillaWordEmbeddingLookup(word_to_ix, STACK_EMBEDDING_DIM)
action_chooser = neural_net.ActionChooserNetwork(STACK_EMBEDDING_DIM * NUM_FEATURES)
combiner_network = neural_net.MLPCombinerNetwork(STACK_EMBEDDING_DIM)
parser = parsing.TransitionParser(feat_extractor, word_embedding_lookup,
                                     action_chooser, combiner_network)
optimizer = optim.SGD(parser.parameters(), lr=ETA_0)
In [12]:
%%timeit
parsing.train(dataset.training_data[:100], parser, optimizer)
Number of instances: 100    Number of network actions: 3898
Acc: 0.687275525911  Loss: 27.0280368376
Number of instances: 100    Number of network actions: 3898
Acc: 0.837609030272  Loss: 15.3356624376
Number of instances: 100    Number of network actions: 3898
Acc: 0.91713699333  Loss: 8.99367914472
Number of instances: 100    Number of network actions: 3898
Acc: 0.956900974859  Loss: 4.93015527138
1 loop, best of 3: 3.96 s per loop
In [15]:
# if this call doesn't work, something is wrong with your parser's behavior when gold labels aren't provided
parser.predict(dataset.dev_data[0].sentence)
Out[15]:
{DepGraphEdge(head=('<ROOT>', -1), modifier=('restrict', 4)),
 DepGraphEdge(head=('RTC', 6), modifier=('the', 5)),
 DepGraphEdge(head=('Treasury', 8), modifier=(',', 11)),
 DepGraphEdge(head=('Treasury', 8), modifier=('RTC', 6)),
 DepGraphEdge(head=('Treasury', 8), modifier=('only', 10)),
 DepGraphEdge(head=('Treasury', 8), modifier=('to', 7)),
 DepGraphEdge(head=('agency', 14), modifier=('the', 13)),
 DepGraphEdge(head=('authorization', 18), modifier=('congressional', 17)),
 DepGraphEdge(head=('authorization', 18), modifier=('specific', 16)),
 DepGraphEdge(head=('intends', 2), modifier=('The', 0)),
 DepGraphEdge(head=('intends', 2), modifier=('bill', 1)),
 DepGraphEdge(head=('only', 10), modifier=('borrowings', 9)),
 DepGraphEdge(head=('receives', 15), modifier=('agency', 14)),
 DepGraphEdge(head=('receives', 15), modifier=('authorization', 18)),
 DepGraphEdge(head=('restrict', 4), modifier=('.', 19)),
 DepGraphEdge(head=('restrict', 4), modifier=('intends', 2)),
 DepGraphEdge(head=('restrict', 4), modifier=('to', 3)),
 DepGraphEdge(head=('restrict', 4), modifier=('unless', 12)),
 DepGraphEdge(head=('unless', 12), modifier=('Treasury', 8)),
 DepGraphEdge(head=('unless', 12), modifier=('receives', 15))}
In [7]:
# train the thing for a while here.
# Shouldn't take too long, even on a laptop
for epoch in xrange(1):
    print "Epoch {}".format(epoch+1)
    parsing.train(dataset.training_data[:1000], parser, optimizer, verbose=True)
    
    print "Dev Evaluation"
    parsing.evaluate(dataset.dev_data, parser, verbose=True)
    print "F-Score: {}".format(evaluation.compute_metric(parser, dataset.dev_data, evaluation.fscore))
    print "Attachment Score: {}".format(evaluation.compute_attachment(parser, dataset.dev_data))
    print "\n"
Epoch 1
Number of instances: 997    Number of network actions: 39025
Acc: 0.825035233824  Loss: 17.4108023486
Dev Evaluation
Number of instances: 399    Number of network actions: 15719
Acc: 0.823843755964  Loss: 16.9211852149
F-Score: 0.500566790988
Attachment Score: 0.486784960913


Deliverable 3.2: Test Data Predictions (5 points)

Run the code below to output your predictions on the test data and dev data. You can run the dev test to verify you are correct up to this point. The test data evaluation is for us.

In [8]:
dev_sentences = [ sentence for sentence, _ in dataset.dev_data ]
evaluation.output_preds(consts.D3_2_DEV_FILENAME, parser, dev_sentences)
In [9]:
evaluation.output_preds(consts.D3_2_TEST_FILENAME, parser, dataset.test_data)

4. Evaluation and Training Improvements (30 points)

Deliverable 4.1: Better Word Embeddings (10 points)

Implement the class BiLSTMWordEmbeddingLookup in neural_net.py. This class can replace your VanillaWordEmbeddingLookup. It implements a sequence model over the sentence, where the t'th word's embedding is the hidden state at timestep t. This means that, rather than having the embeddings on the stack capture the semantics of a single word only, they will contain information from all parts of the sentence (the LSTM will, in principle, learn what information is relevant).
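For reference, a standalone sketch of a bidirectional LSTM embedder in the same spirit; the split of the hidden size across the two directions and the zero initial state are assumptions, not the required implementation.

import torch
import torch.nn as nn
import torch.autograd as ag

class SketchBiLSTMEmbedder(nn.Module):
    def __init__(self, word_to_ix, word_embedding_dim, hidden_dim, num_layers=1, dropout=0.0):
        super(SketchBiLSTMEmbedder, self).__init__()
        self.word_to_ix = word_to_ix
        self.word_embeds = nn.Embedding(len(word_to_ix), word_embedding_dim)
        # each direction gets hidden_dim / 2 units, so every timestep's output is hidden_dim wide
        self.lstm = nn.LSTM(word_embedding_dim, hidden_dim // 2, num_layers=num_layers,
                            dropout=dropout, bidirectional=True)

    def forward(self, sentence):
        ix = ag.Variable(torch.LongTensor([self.word_to_ix[w] for w in sentence]))
        embeds = self.word_embeds(ix).view(len(sentence), 1, -1)  # (seq_len, batch=1, dim)
        lstm_out, _ = self.lstm(embeds)  # zero initial state by default
        # the hidden state at timestep t is the embedding for the t'th word
        return [lstm_out[t].view(1, -1) for t in range(len(sentence))]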

In [63]:
torch.manual_seed(1) # DO NOT CHANGE
test_sentence = "Michael Collins".split()
test_word_to_ix = { "Michael": 0, "Collins": 1 }

lstm_word_embedder = neural_net.BiLSTMWordEmbeddingLookup(test_word_to_ix,
                                                          WORD_EMBEDDING_DIM,
                                                          STACK_EMBEDDING_DIM,
                                                          num_layers=LSTM_NUM_LAYERS,
                                                          dropout=DROPOUT)
    
lstm_embeds = lstm_word_embedder(test_sentence)
print type(lstm_embeds)
print len(lstm_embeds), "\n"
print "Embedding for Michael:\n {}".format(lstm_embeds[0])
<type 'list'>
2 

Embedding for Michael:
 Variable containing:

Columns 0 to 9 
-0.0134 -0.0766 -0.0746  0.0530 -0.0202  0.1845 -0.1455 -0.0734 -0.0072  0.0781

Columns 10 to 19 
 0.0354 -0.0723  0.0160  0.0915 -0.0200  0.1126  0.1395  0.0041  0.0919  0.0251

Columns 20 to 29 
 0.3126  0.0233  0.1408  0.1407 -0.2879 -0.1591 -0.0579  0.0207  0.0364 -0.3148

Columns 30 to 39 
-0.4017  0.1126  0.2589  0.0505 -0.1529 -0.0149  0.0705  0.0419 -0.1842  0.1084

Columns 40 to 49 
-0.1632 -0.0252 -0.0965 -0.0090  0.1427  0.1717  0.1267 -0.0724  0.3383 -0.0991

Columns 50 to 59 
 0.2505 -0.1585 -0.0338  0.2543  0.1364  0.1747 -0.0128  0.0472 -0.0284 -0.1095

Columns 60 to 69 
-0.2905  0.1631  0.0890  0.1824  0.0406  0.0039 -0.0506 -0.0266  0.0073  0.1715

Columns 70 to 79 
 0.0092 -0.3738 -0.0689  0.0460  0.1567 -0.0565  0.1381  0.0503 -0.0933  0.1842

Columns 80 to 89 
-0.0477  0.1206  0.0543  0.0678 -0.0886  0.0467 -0.2502  0.0426 -0.0566 -0.0431

Columns 90 to 99 
 0.0637 -0.0667  0.0312 -0.1330 -0.1285 -0.0477  0.0292 -0.1092 -0.0594  0.0528
[torch.FloatTensor of size 1x100]

Deliverable 4.2: Pretrained Embeddings (5 points)

Fill in the function initialize_with_pretrained in utils.py.

It will take a word embedding lookup component and initialize its lookup table with pretrained embeddings.

Note that you can create a Torch tensor from a list of floats using torch.Tensor(). Googling for more information about how Torch stores parameters is allowed; I don't think you'll find the exact answer online (corollary: do not post the answer online).
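A sketch of the idea, under the assumption that the embedding lookup component exposes a word_to_ix dictionary and an nn.Embedding attribute (called word_embeds below, a made-up name) whose weight rows we overwrite:

import torch

def sketch_initialize_with_pretrained(pretrained_embeds, word_embedding_component):
    for word, ix in word_embedding_component.word_to_ix.items():
        if word in pretrained_embeds:
            # copy the pretrained vector into row ix of the lookup table
            word_embedding_component.word_embeds.weight.data[ix] = torch.Tensor(pretrained_embeds[word])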

In [21]:
import cPickle
pretrained_embeds = cPickle.load(open(consts.PRETRAINED_EMBEDS_FILE))
print pretrained_embeds['four'][:5]
[0.12429751455783844, -0.11472601443529129, -0.5684014558792114, -0.396965891122818, 0.22938089072704315]
In [22]:
embedder = neural_net.VanillaWordEmbeddingLookup(word_to_ix,64)
In [23]:
embedder.forward(['four'])[0][0,:5]
Out[23]:
Variable containing:
 0.6730
-1.7911
 1.4701
 1.5589
-2.0735
[torch.FloatTensor of size 5]
In [24]:
reload(utils);
utils.initialize_with_pretrained(pretrained_embeds,embedder)
print embedder.forward(['four'])[0][0,:5]
Variable containing:
 0.1243
-0.1147
-0.5684
-0.3970
 0.2294
[torch.FloatTensor of size 5]

Deliverable 4.3: Better Reduction Combination (10 points)

Before, in order to combine two embeddings during a reduction, we just passed them through an MLP and got a dense output. Now we will instead use a sequence model over the stack: the combined embedding from a reduction is the output at the next timestep of an LSTM. Implement LSTMCombinerNetwork in neural_net.py.
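For reference, a standalone sketch of an LSTM combiner that keeps its hidden state across calls, so each reduction is the next timestep of a running sequence model; the exact way the real class manages and resets this state may differ.

import torch
import torch.nn as nn

class SketchLSTMCombiner(nn.Module):
    def __init__(self, embedding_dim, num_layers=1, dropout=0.0):
        super(SketchLSTMCombiner, self).__init__()
        self.lstm = nn.LSTM(2 * embedding_dim, embedding_dim,
                            num_layers=num_layers, dropout=dropout)
        self.hidden = None  # (h, c) carried over from one reduction to the next

    def forward(self, head_embed, modifier_embed):
        # one timestep of input: the concatenated head and modifier embeddings (1 x 1 x 2d)
        x = torch.cat([head_embed, modifier_embed], 1).view(1, 1, -1)
        out, self.hidden = self.lstm(x, self.hidden)  # a None hidden state defaults to zeros
        return out.view(1, -1)  # the new stack embedding for this reduction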

In [25]:
reload(neural_net);
TEST_EMBEDDING_DIM = 5
combiner = neural_net.LSTMCombinerNetwork(TEST_EMBEDDING_DIM, 1, 0.0)
head_feat = ag.Variable(torch.randn(1, TEST_EMBEDDING_DIM))
modifier_feat = ag.Variable(torch.randn(1, TEST_EMBEDDING_DIM))
In [26]:
utils.concat_and_flatten([head_feat,modifier_feat]).view(1,1,-1)
Out[26]:
Variable containing:
(0 ,.,.) = 

Columns 0 to 8 
  -0.7641 -0.2784  0.6002  0.0032 -1.3923 -0.5975  0.2761  1.4585  1.2168

Columns 9 to 9 
  -0.0510
[torch.FloatTensor of size 1x1x10]
In [27]:
# note that the output keeps changing, because of the recurrent update
for _ in xrange(3):
    print combiner(head_feat,modifier_feat)
Variable containing:
 0.2730  0.0426 -0.1535  0.0089  0.0604
[torch.FloatTensor of size 1x5]

Variable containing:
 0.3838  0.0589 -0.1849  0.0171  0.0936
[torch.FloatTensor of size 1x5]

Variable containing:
 0.4329  0.0679 -0.1930  0.0236  0.1142
[torch.FloatTensor of size 1x5]

Retrain with the new components

The code below retrains your parser using all the new components that you just wrote.

In [16]:
feat_extractor = feat_extractors.SimpleFeatureExtractor()

# BiLSTM over word embeddings
word_embedding_lookup = neural_net.BiLSTMWordEmbeddingLookup(word_to_ix,
                                                             WORD_EMBEDDING_DIM,
                                                             STACK_EMBEDDING_DIM,
                                                             num_layers=LSTM_NUM_LAYERS,
                                                             dropout=DROPOUT)
# pretrained inputs
utils.initialize_with_pretrained(pretrained_embeds, word_embedding_lookup)

action_chooser = neural_net.ActionChooserNetwork(STACK_EMBEDDING_DIM * NUM_FEATURES)

# LSTM reduction operations
combiner = neural_net.LSTMCombinerNetwork(STACK_EMBEDDING_DIM,
                                          num_layers=LSTM_NUM_LAYERS,
                                          dropout=DROPOUT)

parser = parsing.TransitionParser(feat_extractor, word_embedding_lookup,
                                  action_chooser, combiner)

optimizer = optim.SGD(parser.parameters(), lr=ETA_0)
In [17]:
%%timeit
# The LSTMs will make this take longer
parsing.train(dataset.training_data[:100], parser, optimizer)
Number of instances: 100    Number of network actions: 3898
Acc: 0.668291431503  Loss: 28.3679159951
Number of instances: 100    Number of network actions: 3898
Acc: 0.806824012314  Loss: 18.1205399126
Number of instances: 100    Number of network actions: 3898
Acc: 0.855310415598  Loss: 13.5849320753
Number of instances: 100    Number of network actions: 3898
Acc: 0.880708055413  Loss: 10.9105038324
1 loop, best of 3: 3.9 s per loop
In [18]:
for epoch in xrange(1):
    print "Epoch {}".format(epoch+1)
    
    parser.train() # turn on dropout layers if they are there
    parsing.train(dataset.training_data[:1000], parser, optimizer, verbose=True)
    
    print "Dev Evaluation"
    parser.eval() # turn them off for evaluation
    parsing.evaluate(dataset.dev_data, parser, verbose=True)
    print "F-Score: {}".format(evaluation.compute_metric(parser, dataset.dev_data, evaluation.fscore))
    print "Attachment Score: {}".format(evaluation.compute_attachment(parser, dataset.dev_data))
    print "\n"
Epoch 1
Number of instances: 997    Number of network actions: 39025
Acc: 0.896809737348  Loss: 10.2798197525
Dev Evaluation
Number of instances: 399    Number of network actions: 15719
Acc: 0.898403206311  Loss: 9.99048229483
F-Score: 0.698743236193
Attachment Score: 0.691897257724


Deliverable 4.4: Test Predictions (5 points)

Run the code below to generate test predictions

In [19]:
dev_sentences = [ sentence for sentence, _ in dataset.dev_data ]
evaluation.output_preds(consts.D4_4_DEV_FILENAME, parser, dev_sentences)
In [20]:
evaluation.output_preds(consts.D4_4_TEST_FILENAME, parser, dataset.test_data)

Using CUDA

You can use CUDA to train your network, and you should expect significant speedup if you have a decent GPU and the CUDA toolkit installed. If you want to use CUDA in this assignment, change the HAVE_CUDA variable to True in constants.py, and uncomment the .to_cuda() and .to_cpu() lines below.

We are not officially supporting CUDA, though. If you have problems installing or running CUDA, please just use the CPU; we cannot help you debug it.

In [ ]:
#training example for fun
for epoch in xrange(NUM_EPOCHS):
    print "Epoch {}".format(epoch+1)
    
    #parser.to_cuda() # uncomment to train on your GPU
    parser.train() # turn on dropout layers if they are there
    
    # train on full training data
    parsing.train(dataset.training_data, parser, optimizer, verbose=True)
    
    print "Dev Evaluation"
    #parser.to_cpu() #TODO fix evaluation so you dont have to ship everything back to the CPU
    parser.eval() # turn them off for evaluation
    parsing.evaluate(dataset.dev_data, parser, verbose=True)
    print "F-Score: {}".format(evaluation.compute_metric(parser, dataset.dev_data, evaluation.fscore))
    print "Attachment Score: {}".format(evaluation.compute_attachment(parser, dataset.dev_data))
    print "\n"