%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Training
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Bring the SMS data set into the workspace. sms.mat provides:
%
%   X_train  word counts for the training messages: row i corresponds to
%            the ith training message, and column j holds the number of
%            times the jth dictionary word appears in that message
%
%   X_test   same layout as X_train, but for the test set
%
%   y_train  entry i indicates whether training message i is spam
%
%   y_test   same as y_train, but for the test set
%
%   dict     cell array whose jth entry is the jth dictionary word
load('sms.mat');

% m = number of rows of X_train (messages), n = number of columns
[m, n] = size(X_train);

% Starting point for training: an all-zero column vector with one entry
% per column of X_train
theta_init = zeros(n, 1);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Train
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% YOUR CODE HERE: 
%  - learn theta by gradient descent 
%  - plot the cost history
%  - tune step size and # iterations if necessary


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Testing
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% YOUR CODE HERE
%  - use theta to make predictions for test set
%  - print the accuracy on the test set---i.e., the percent of messages classified correctly

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Inspect the model to see which words are indicative of spam and non-spam
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Print the dictionary words with the largest and the smallest weights.
% Needs theta (to be learned in the training section) and dict (loaded
% from sms.mat) in the workspace.

% Drop theta(1) and treat the remaining entries as per-word weights.
% NOTE(review): this presumes theta(1) is an intercept term and that
% theta(j+1) corresponds to dict{j} -- confirm against how X_train is built.
word_weights = theta(2:end);
[~, I] = sort(word_weights, 'descend');   % largest weights first
[~, J] = sort(word_weights, 'ascend');    % smallest weights first

% Show at most 10 words per list. Clamp to the number of available weights
% so that I(1:k) and J(1:k) never index past the end of the sort order.
k = min(10, numel(word_weights));

fprintf('Top %d spam words\n', k);
fprintf('  %s\n', dict{I(1:k)});
fprintf('\n');
fprintf('Top %d nonspam words\n', k);
fprintf('  %s\n', dict{J(1:k)});

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Predict for your own message
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Message to classify; replace the placeholder string with your own text.
% NOTE(review): the variable name `text` shadows MATLAB's built-in text()
% graphics function for the rest of the script.
text = 'write a text here...';
% Convert the message into a feature vector using the dictionary.
% NOTE(review): presumably x has the same column layout as a row of
% X_train -- confirm in sms_extract_feature before applying theta to it.
x = sms_extract_feature(text, dict);


% YOUR CODE HERE
%  - try a few texts of your own
%  - predict whether they are spam or non-spam




