What you’ll learn
- the limitations of word2vec/doc2vec methods for text document classification
- a simple but efficient method for sentence embedding
- how to build a batch generator so that your model can scale on large datasets
- how convolutional neural nets can be used for NLP tasks such as sentiment analysis
- check my GitHub repo for the project!
As we saw in the previous post, it is possible to obtain useful representations of labelled documents using a simple doc2vec algorithm. However, we showed a specific example for which the doc2vec low-dimensional embedding had nice properties: the spam and non-spam messages were grouped together, so that it was very easy to draw a line between the two classes. Unfortunately, this nice representation is the exception rather than the rule. Generally speaking, it is not so easy to learn such interesting representations using word2vec/doc2vec. Word2vec is known to give pretty nice embeddings for words, clustering similar words together. In addition, once words are vectorized, it is also possible to use some kind of vector algebra on them (for instance: King – Man + Woman = Queen, see this link for more info).
However, doc2vec often struggles to cluster similar documents in the desired way. It is important to remember that these methods are built on one-hidden-layer neural networks, and have a limited capacity to learn complex representations (even though the Universal Approximation Theorem states that a one-hidden-layer neural net can approximate any continuous function defined on a compact domain, it is not so easy in practice!).
To learn sequential dependencies in sentences, it is natural to resort to recurrent neural networks, equipped with LSTM, GRU or other types of units. These neural networks have been specifically built to deal with sequential dependencies. Typical applications are time series prediction, text generation and, of course, document classification.
Nonetheless, we will describe here a different approach, which may sound quite counter-intuitive but works very well!
The first thing to do is to embed the sentences into some real number space. To do this, we build a vocabulary which assigns to each word a specific index. You can check the vocab.py file from the project GitHub repo for more details. Each embedded sentence is padded, so that each input sample has the same length (the length of the longest sentence in the document). This facilitates the training procedure. For instance, the sentence “I like parrot because they are really smart” may be encoded in the following way: [1,23, 34, 89, 16 ,…54, 0, 0, …,0]. 1 corresponds here to “I”, 23 to “like” and so on.
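The vocabulary and padding step can be sketched in a few lines of plain Python. This is a minimal illustration with hypothetical helper names, not the repo's actual `vocab.py` code; index 0 is assumed to be reserved for padding.

```python
def build_vocab(sentences):
    """Assign an integer index to each word; 0 is reserved for padding."""
    vocab = {}
    for sentence in sentences:
        for word in sentence.split():
            if word not in vocab:
                vocab[word] = len(vocab) + 1  # indices start at 1, 0 is padding
    return vocab

def encode_and_pad(sentences, vocab):
    """Map each sentence to word indices and pad to the longest sentence."""
    encoded = [[vocab[w] for w in s.split()] for s in sentences]
    max_len = max(len(e) for e in encoded)
    return [e + [0] * (max_len - len(e)) for e in encoded]

sentences = ["I like parrots", "they are really smart"]
vocab = build_vocab(sentences)
batch = encode_and_pad(sentences, vocab)
# every row now has the same length, so the batch forms a rectangular matrix
```

Because all rows end up with the same length, the batch can be fed to the network as a single dense tensor.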
First, a TensorFlow variable of size voc_size*embedding_size is created. Then, during training, the relevant rows of this matrix are retrieved with the tf.nn.embedding_lookup method. These embedding weights are trainable parameters. The embedding size is set to 128 by default in our implementation. Hence, if we take the same example as above, “I like parrot because they are really smart”, we only consider the 1st, 23rd, 34th, … rows of the embedding matrix of size voc_size*embedding_size.
Hence, as explained above, the scenario is the following: we have a large voc_size*embedding_size embedding matrix, and for each sentence sample, we retrieve the relevant rows. Note that each input sentence is padded so that the number of rows extracted from the embedding matrix is always the same.
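At its core, the embedding lookup is just row selection. Here is a sketch with plain Python lists and toy sizes (tf.nn.embedding_lookup performs the equivalent operation on a trainable variable, with gradients flowing back into the selected rows):

```python
import random

voc_size, embedding_size = 1000, 128
random.seed(0)
# stand-in for the trainable voc_size*embedding_size embedding variable
embedding_matrix = [[random.random() for _ in range(embedding_size)]
                    for _ in range(voc_size)]

sentence = [1, 23, 34, 89, 16, 0, 0]                # padded word indices
embedded = [embedding_matrix[i] for i in sentence]  # one row per word
# embedded has shape (sentence_length, embedding_size)
```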
The second step is convolution. We apply filters of heights 3, 4 and 5 to the maximum_sentence_length*embedding_size input matrix and add a ReLU activation function. We create 64 feature maps for each filter height. The width of each filter is the embedding_size and the stride is 1, so that each feature map has the same length as the input matrix, that is maximum_sentence_length, but a width of 1 (since each filter spans the full embedding width).
The third step is max pooling over each feature map: the maximum value of each single-column feature map is kept, and these maxima are gathered into one flat vector with one entry per feature map (size of this vector: 64 per filter height).
The results of the previous operations for the three filter heights are concatenated, so that we obtain one large one-dimensional vector (size: 3*64 = 192).
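The convolution → ReLU → max-pool → concatenate pipeline described above can be sketched in plain Python with toy sizes (2 feature maps per filter height instead of 64, constant filters, and no padding — assumptions made purely for illustration):

```python
def conv_relu(sentence_matrix, filt):
    """Valid 2-D convolution where the filter spans the full embedding width,
    so each feature map is a single column; ReLU is applied elementwise."""
    h = len(filt)
    out = []
    for i in range(len(sentence_matrix) - h + 1):
        s = sum(sentence_matrix[i + r][c] * filt[r][c]
                for r in range(h) for c in range(len(filt[0])))
        out.append(max(0.0, s))  # ReLU
    return out

embedding_size, sent_len = 4, 6
sentence_matrix = [[1.0] * embedding_size for _ in range(sent_len)]

pooled = []
for height in (3, 4, 5):
    for _ in range(2):  # 2 feature maps per height here (64 in the real model)
        filt = [[0.1] * embedding_size for _ in range(height)]
        feature_map = conv_relu(sentence_matrix, filt)
        pooled.append(max(feature_map))  # max pooling over the whole column

# pooled is the concatenated vector: (number of heights) * (maps per height)
```

With 64 maps per height, the same loop yields the 3*64 = 192 dimensional vector fed to the next layer.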
The next layer is a dropout layer (0.25 dropout probability), which is useful for regularisation.
The last layer is a standard fully-connected layer with a softmax activation.
The loss function is the regular cross entropy, and we use the Adam optimiser for gradient descent.
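For reference, here is what the softmax and cross-entropy computation boils down to, sketched in plain Python (TensorFlow performs this with its own fused, numerically stable op; the two-class logits below are illustrative values only):

```python
import math

def softmax(logits):
    m = max(logits)                       # shift for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(probs, true_class):
    """Negative log-likelihood of the correct class."""
    return -math.log(probs[true_class])

logits = [2.0, 0.5]          # example scores for (positive, negative)
probs = softmax(logits)      # probabilities summing to 1
loss = cross_entropy(probs, true_class=0)
```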
Training and results
We trained the network for 100 epochs on the STSA binary (Stanford Sentiment Treebank) dataset, with a learning rate of 0.001. This dataset has training and validation subsets. Our results on the validation data are quite good:
validation loss: 0.1654, validation accuracy: 0.9511
- Besides, the use of generators and a simple parsing system allows the model to scale to large datasets with minimal effort. This implementation has been tested on this Amazon review dataset (> 3 million reviews) and achieved an accuracy comparable to FastText (~91% precision/recall).
- Precision and recall on benchmark datasets such as STSA binary (Stanford Sentiment Treebank) and the Sentence Polarity dataset v1 are higher than 95% on validation data (validation loss: 0.1654, validation accuracy: 0.9511). For the STSA data we measured accuracy/precision/recall on the provided validation set, and for Sentence Polarity v1 we performed cross-validation with a random 0.9/0.1 split each time. This outperforms the benchmark of 87% accuracy for the STSA dataset.
- Even when training and building the vocabulary on the STSA data, precision and recall are in the order of 85% on the Sentence Polarity v1 validation set, which shows that the model generalises well to new data.
- This implementation also delivers good results on multi-label classification tasks. Here is the confusion matrix obtained on validation data for the MBTI Myers-Briggs Personality Type Dataset available on Kaggle. See my GitHub repo.
- The default implementation uses the MBTI dataset, but you can try it on any other labelled document dataset.
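The batch generator mentioned above can be sketched as follows. This is a minimal illustration assuming a hypothetical "label&lt;TAB&gt;text" line format (not necessarily the repo's actual parsing code); the point is that batches are yielded lazily, so the full corpus never has to sit in memory:

```python
def batch_generator(lines, batch_size):
    """Yield (texts, labels) batches lazily from an iterable of lines."""
    texts, labels = [], []
    for line in lines:
        label, text = line.rstrip("\n").split("\t", 1)
        labels.append(label)
        texts.append(text)
        if len(texts) == batch_size:
            yield texts, labels
            texts, labels = [], []
    if texts:                      # final, possibly smaller batch
        yield texts, labels

corpus = ["pos\tgreat movie", "neg\tterrible plot", "pos\tloved it"]
batches = list(batch_generator(corpus, batch_size=2))
```

Because the generator accepts any iterable of lines, it works unchanged whether the input is a small in-memory list or a file handle streaming millions of reviews from disk.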