Lessons Learned from Fine-Tuning BERT for Named Entity Recognition

Do’s and don’ts for fine-tuning on multifaceted NLP tasks

Charlene Chambliss
Gab41


In Part 1 of this 2-part series, I introduced the task of fine-tuning BERT for named entity recognition, outlined relevant prerequisites and prior knowledge, and gave a step-by-step outline of the fine-tuning process.

Here, I’ll discuss the interesting practical challenges that came up while building out the project, as well as what I plan to do differently next time I work on a similar project. These will be important to keep in mind if you’re interested in fine-tuning BERT yourself, especially on a deadline.

Lessons Learned

First, check your assumptions

I read all of the blog posts and papers I linked above, and took copious notes, but I was completely stumped when I started training the model and the metrics showed an odd pattern: the precision was extremely low (sub 0.5), while the recall was extremely high (in the .9s).

Even more puzzling, the loss seemed fine; it decreased at every epoch, as expected. My inputs and training loop code were functionally equivalent to the tutorial author’s code, and the author had gotten great results. What was going on?

Let’s think it through. First, NER is token-level classification, meaning that the model makes predictions on a word-by-word (or in BERT’s case, subword-by-subword) basis. This is in contrast to sequence classification, where the model makes one prediction for the entire sequence. As part of preprocessing, you need to add [PAD] tokens to “pad out” any sequence that is shorter than the maximum sequence length.

When you’re doing token-level classification, you run a tensor through the model, and it will predict on every single token in the sequence, padding included. Because a large number of our input tokens are simply padding tokens, this is not actually what we want.

Those padding tokens are “masked” from the model’s learning process during training using an attention mask, which tells the model to disregard those tokens and their labels when the loss is calculated at each step.

I incorrectly assumed that information from the attention mask would also tell the model not to make predictions on those tokens, because the author of the tutorial had not done anything to accommodate unwanted predictions either. (It turned out that the author was using an earlier version of pytorch-transformers that did not have this problem.)

BERT was indeed predicting on these padding tokens, and this was disastrous for the metrics; since the model didn’t learn anything about them, its predictions for [PAD] tokens’ labels were essentially random! I didn’t realize these padding tokens were being used for prediction until I inserted a confusion matrix into the training loop and saw that they were being included.

I then had to rewrite parts of the training and validation loop, creating a separate “prediction mask” to subset out the unwanted predictions prior to calculating the metrics.
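
As a rough illustration of the idea (not the code from my repo), here is a minimal sketch of a prediction mask. It assumes that padded positions carry a dedicated label id, here a hypothetical PAD_LABEL_ID of 0, and that predictions and labels have already been converted to integer ids of the same shape.

```python
import numpy as np

PAD_LABEL_ID = 0  # hypothetical id reserved for the [PAD] label


def drop_padding(pred_ids, label_ids, pad_label_id=PAD_LABEL_ID):
    """Flatten a batch and keep only positions whose gold label is not padding."""
    preds = np.asarray(pred_ids).ravel()
    labels = np.asarray(label_ids).ravel()
    prediction_mask = labels != pad_label_id  # True for real tokens only
    return preds[prediction_mask], labels[prediction_mask]


# Two sequences padded to length 5; label ids: 0=[PAD], 1=O, 2=B-PER, 3=I-PER
gold = [[2, 3, 1, 0, 0],
        [1, 1, 0, 0, 0]]
pred = [[2, 3, 1, 2, 3],   # the last two predictions sit on padding
        [1, 1, 1, 3, 2]]   # the last three predictions sit on padding

kept_pred, kept_gold = drop_padding(pred, gold)
print(kept_pred.tolist())  # [2, 3, 1, 1, 1]
print(kept_gold.tolist())  # [2, 3, 1, 1, 1]
```

Only the filtered arrays should be passed on to the precision, recall, and F1 calculations.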

The moral here is that you should always check your assumptions about the model’s behavior. (The second moral is to always use a confusion matrix for classification tasks!) Once I looked at the confusion matrix, it became obvious that the model does not use the attention mask to structure the prediction output (nor should it).
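
If you want a quick check of this kind in your own loop, a confusion matrix doesn’t need much machinery. The sketch below uses only the standard library; the tag names and the tiny example batch are made up for illustration. If any predictions show up in the [PAD] row, the metrics are being computed on padding.

```python
from collections import Counter


def confusion_counts(gold_tags, pred_tags):
    """Count (gold, predicted) pairs; each key is one cell of the confusion matrix."""
    return Counter(zip(gold_tags, pred_tags))


def print_confusion(counts, tags):
    """Print one row per gold tag and one column per predicted tag."""
    print("gold\\pred".ljust(10) + "".join(t.rjust(7) for t in tags))
    for g in tags:
        print(g.ljust(10) + "".join(str(counts[(g, p)]).rjust(7) for p in tags))


tags = ["[PAD]", "O", "B-PER", "I-PER"]
gold = ["B-PER", "I-PER", "O", "[PAD]", "[PAD]", "[PAD]"]
pred = ["B-PER", "I-PER", "O", "B-PER", "O", "I-PER"]

counts = confusion_counts(gold, pred)
print_confusion(counts, tags)

# Number of predictions that were scored against padding positions
pad_row_total = sum(counts[("[PAD]", p)] for p in tags)
print(pad_row_total)  # 3
```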

Confusion matrix once the padding token predictions were removed

Token classification is more challenging than sequence classification

More things can go wrong in token classification than in sequence classification, both in data preprocessing and in the training and inference phases. If you’re new to deep learning and looking for a project, I’d recommend starting with sequence classification first and then moving on to token classification problems later.

Aside from the [PAD] tokens issue already described, here are a few more problems that arise simply because we’re dealing with tokens and not sequences:

  1. During data preprocessing, one of the datasets I downloaded included annotation files with character-level indices showing where each entity started and stopped within the text. For example, an annotation might indicate that the span of text from indices 20 to 35 represents a Person entity, allowing you to easily align the words/spans in the text to the correct tags. However, from the very first file, these indices were drastically misaligned with the actual character positions. This rendered the data unusable. Since there were other datasets available for this task, I chose to exclude this one.
  2. Named entity recognition has a specific tagging scheme that needs to be used, since named entities are usually phrases rather than individual words. Instead of tokens belonging only to classes, such as “person,” “organization,” and so on, they also need to be prefixed with information about where they are in the phrase, such as “B” or “I”, which stand for “beginning” and “inside.” Example: “Charlene Chambliss” would be tagged as (B-PER, I-PER). This is called IOB format; BILUO is a closely related scheme that additionally marks the last token of an entity and single-token entities. Not all NER datasets will come with this done for you. Thankfully, instead of writing my own IOB tagger, I was able to use spaCy’s biluo_tags_from_offsets convenience function for the data that wasn’t already IOB-tagged.
  3. BERT uses WordPiece tokenization rather than whole-word tokenization (although there are whole words in its vocabulary). So if you have a word like “personally,” it may be broken up into “person” “##al” “##ly.” You now have a list of WordPieces which is much longer than, and misaligned with, your list of labels. You as the practitioner need to figure out how to handle cases where you have one tag per word, but some words get broken up into several chunks. I wrote a function called tokenize_and_preserve_labels to propagate a word’s original label to all of its pieces during tokenization. Other folks who implemented this have also used the following heuristic: leave the original label on the first token of the word, then use the label “X” for subwords of that word. Either method seems to result in equally good model performance.
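
To make the third point concrete, here is a minimal sketch of both labeling strategies. It is not the tokenize_and_preserve_labels function from my repo: the name tokenize_and_align, the propagate flag, and the toy tokenizer are all hypothetical stand-ins, and in practice you would pass in the tokenize method of a real WordPiece tokenizer.

```python
def tokenize_and_align(words, labels, tokenize, propagate=True):
    """Split each word into subwords and give every subword a label.

    propagate=True copies the word's label to all of its pieces.
    propagate=False keeps the label on the first piece and marks the rest "X".
    """
    tokens, aligned = [], []
    for word, label in zip(words, labels):
        pieces = tokenize(word)
        if not pieces:  # skip words the tokenizer drops entirely
            continue
        tokens.extend(pieces)
        if propagate:
            aligned.extend([label] * len(pieces))
        else:
            aligned.extend([label] + ["X"] * (len(pieces) - 1))
    return tokens, aligned


# Toy stand-in for a WordPiece tokenizer (hypothetical, for illustration only)
TOY_SPLITS = {"personally": ["person", "##al", "##ly"]}


def toy_tokenize(word):
    return TOY_SPLITS.get(word, [word])


words = ["Charlene", "personally", "agrees"]
labels = ["B-PER", "O", "O"]

tokens, propagated = tokenize_and_align(words, labels, toy_tokenize)
_, first_only = tokenize_and_align(words, labels, toy_tokenize, propagate=False)

print(tokens)      # ['Charlene', 'person', '##al', '##ly', 'agrees']
print(propagated)  # ['B-PER', 'O', 'O', 'O', 'O']
print(first_only)  # ['B-PER', 'O', 'X', 'X', 'O']
```

Either way, the token list and the label list end up the same length, which is what the model needs.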

Need an interactive demo or app? Keep it simple, use Streamlit

When you’re prototyping your interactive demo, I’d highly recommend using Streamlit rather than going straight to Flask. It is extremely easy to get started with (simply sprinkle a few extra lines of code into your existing model-prediction script), and the app updates in real time as you add functionality. It also offers the ability to add caching, meaning you don’t need to wait for your model to load every time the app code changes.

Thoughts on Performance

Performance for NER can be calculated in many different ways, but the standard is to use some form of F1 score, generally strict match, partial match, or relaxed match. In my research for this post, I noticed that the terms “partial match” and “relaxed match” have been used slightly differently from place to place. For our purposes, we define “relaxed match” to mean calculating performance based on the proportion of entity tokens that were identified as the correct entity type, regardless of whether the “boundaries” around the entity were correct.

For example: the model classifies “United States of America” as [B-LOC] [I-LOC] [O] [B-LOC], where the correct labels are [B-LOC] [I-LOC] [I-LOC] [I-LOC]. This would receive 75% credit rather than 50% credit. The last two tags are both “wrong,” in a strict classification label sense, but the model at least classified the fourth token as the correct entity type.
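
The arithmetic for this example can be sketched as follows. This is a simplified, single-entity illustration rather than the scoring code used in the project; the helper names entity_type and token_credit are made up for the sketch.

```python
def entity_type(tag):
    """Strip the B-/I- prefix: 'B-LOC' -> 'LOC', 'O' -> 'O'."""
    return tag.split("-", 1)[1] if "-" in tag else tag


def token_credit(gold, pred, relaxed=True):
    """Fraction of gold entity tokens that the prediction gets right."""
    entity_positions = [i for i, g in enumerate(gold) if g != "O"]
    if not entity_positions:
        return 0.0
    if relaxed:
        # Relaxed: only the entity type has to match, not the B-/I- prefix
        hits = sum(entity_type(gold[i]) == entity_type(pred[i])
                   for i in entity_positions)
    else:
        # Strict: the full tag has to match
        hits = sum(gold[i] == pred[i] for i in entity_positions)
    return hits / len(entity_positions)


# "United States of America"
gold = ["B-LOC", "I-LOC", "I-LOC", "I-LOC"]
pred = ["B-LOC", "I-LOC", "O", "B-LOC"]

relaxed_score = token_credit(gold, pred, relaxed=True)
strict_score = token_credit(gold, pred, relaxed=False)
print(relaxed_score)  # 0.75
print(strict_score)   # 0.5
```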

Since models’ decisions are used to highlight or not highlight words in the user interface, rather than to extract entity spans, I calculated scores based on relaxed match. For stricter use cases such as extraction, other evaluation methods should be used (likely including other tweaks, such as weighting recall higher than precision in the calculation of the overall F1 score).
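
For reference, weighting recall higher than precision is usually done with the F-beta score, where beta greater than 1 favors recall. The sketch below uses the standard formula; the function name f_beta is my own, and the inputs are the Person-tag precision and recall reported in the summary of this post.

```python
def f_beta(precision, recall, beta=1.0):
    """F-beta score; beta > 1 weights recall more heavily, beta = 1 gives F1."""
    if precision == 0 and recall == 0:
        return 0.0
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)


english_f1 = f_beta(0.96, 0.94)            # English model, Person tags
russian_f1 = f_beta(0.94, 0.92)            # Russian model, Person tags
english_f2 = f_beta(0.96, 0.94, beta=2.0)  # same inputs, recall weighted higher

print(round(english_f1, 2))  # 0.95
print(round(russian_f1, 2))  # 0.93
print(english_f2 < english_f1)  # True, because recall is the lower of the two
```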

The English model was trained on a combination of CoNLL-2003, the classic NER dataset for researchers, and Emerging Entities (a novel, challenging, and noisy user-generated dataset). The performance on a combined validation set drawn from both CoNLL and EE is as follows:

The Russian model was trained on a combination of data from factRuEval-2016 and the BSNLP 2019 Shared Task. SoTA for F1 on these datasets is around 0.8 and 0.93, respectively.

Here’s the performance of the Russian model on a combined validation set from both factRuEval and BSNLP:

ORG and MISC entities present a perpetual challenge in NER, which is reflected in the relatively lower performance for these classes. However, the class we were primarily interested in was PER (person), and the results there are excellent for both the English and Russian models.

My project satisfied Lab41’s goals for integrating multilingual NER into the UI. If I were to revisit this project in order to improve performance, I would spend time tuning hyperparameters, conducting error analysis on the models to better understand the contexts in which they’re failing, and collecting or labeling more data that align with those contexts. There was also more data available from BSNLP in similar languages to Russian, and I’d be curious to know whether mixing in some of this data during training would help or hurt performance on purely Russian text.

Summary

To recap, the goal of this project was to train BERT-based named-entity recognition models in Russian and English, with the end goal of integrating the NER capabilities into an error analysis front-end for machine translation. The BERT models performed extremely well on relaxed-match P/R/F1, with the English model achieving .96 P, .94 R, and .95 F1 for Person tags, and the Russian model achieving .94 P, .92 R, and .93 F1 for Person tags.

During this project, I learned valuable lessons on the importance of using a confusion matrix while training classification models, the unique challenges of working on token-level tasks as opposed to sequence-level tasks, and the value of exploring related repos and other work prior to diving in.

If you’d like to take a look at the full repo, the code is here.

Acknowledgements

I’d like to thank my advisor Nina Lopatina for her invaluable perspective at various parts of this process, particularly her advice on prioritization, timeline, and setting performance goals. It is incredibly valuable to have someone you can turn to when deciding what is a must and what is a “nice-to-have.” I really enjoyed working on this project knowing that it would be directly useful to Nina’s team, which works on quality estimation for Russian-English machine translation models.

I’d also like to thank SharpestMinds, which offers a fantastic project-based data science mentorship via income-sharing agreements. SM makes it possible for experienced practitioners like Nina to offer financially sound mentorship agreements to ambitious folks who are getting their start in data science and machine learning.

What practical challenges have you experienced while working on deep learning projects? Let us know in the comments!

Lab41 is a Silicon Valley challenge lab where experts from the U.S. Intelligence Community (IC), academia, industry, and In-Q-Tel come together to gain a better understanding of how to work with — and ultimately use — big data.

Learn more at lab41.org and follow us on Twitter: @_lab41

Machine Learning Engineer at Primer AI. I’m on Twitter @blissfulchar, and here’s my LinkedIn: https://www.linkedin.com/in/charlenechambliss/