Add a visualization utility to render tokens and annotations in a notebook #508
Merged

n1t0 merged 14 commits into huggingface:master from feat/visualizer
talolard · 4 years ago

This follows on the discussion we had here.

What It Does

Users can get a visualization of the tokenized output with annotations:

from tokenizers import EncodingVisualizer
from tokenizers.viz.viztypes import Annotation
from tokenizers import BertWordPieceTokenizer

tokenizer = BertWordPieceTokenizer("/tmp/bert-base-uncased-vocab.txt", lowercase=True)
visualizer = EncodingVisualizer(tokenizer=tokenizer, default_to_notebook=True)
annotations = [...]
visualizer(text, annotations=annotations)

[screenshot of the rendered visualization]
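
For context, an Annotation here is just a labeled character span. A minimal sketch of filling in the annotations list above (assuming the (start, end, label) constructor this PR introduces, with character offsets rather than token indices):

    text = "Hello, this is a visualization demo."
    annotations = [
        Annotation(start=0, end=5, label="greeting"),   # covers "Hello"
        Annotation(start=17, end=30, label="topic"),    # covers "visualization"
    ]
    visualizer(text, annotations=annotations)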

Cool Features

  • [✔️] Automatically aligns annotations with tokens
  • [✔️] Supports annotations in any format (through a converter parameter; see the sketch after this list)
  • [✔️] Renders UNK tokens
  • [✔️] Preserves white space
  • [✔️] Easy to use
  • [✔️] Hover over an annotation to clearly see its borders
  • [✔️] Clear distinction of adjacent tokens with alternating shades of grey
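
The converter mentioned above lets you keep your own annotation format. A hedged sketch, assuming the converter ends up as a constructor argument (called annotation_converter here; the exact name may differ in this PR):

    from dataclasses import dataclass

    @dataclass
    class MySpan:              # hypothetical in-house annotation format
        first: int
        last: int
        tag: str

    def my_converter(span: MySpan) -> Annotation:
        return Annotation(start=span.first, end=span.last, label=span.tag)

    visualizer = EncodingVisualizer(
        tokenizer=tokenizer,
        default_to_notebook=True,
        annotation_converter=my_converter,
    )
    visualizer(text, annotations=[MySpan(0, 5, "PER")])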

Missing Stuff

  • [💩] Tests - I need some guidance on how to write tests for this lib
  • [💩] Code Review - Is this easy to understand and maintain?
  • [💩] Right To Left Support - Some edge cases still remain with RTL
  • [💩] Special Tokens - Doesn't currently render CLS, SEP etc. Need some help with that.
  • [🐛] Bug - When an annotation crosses part of a long UNK token, it renders multiple UNK tags on top.

Notebook

There is a notebook in the examples folder.

talolard · 4 years ago

@n1t0 can you give me some guidance on what's wrong with the docs build?

n1t0 · 4 years ago

I should be able to have a deeper look today! Will let you know

n1t0 · commented on 2020-11-09

Thank you @talolard, this is really nice and clean, I love it!

I'm not entirely sure about the namespacing (tokenizers.viz vs tokenizers.notebooks or tokenizers.tools) and will need to think about it, but that's a detail.

For the error in the CI about the documentation, I think it is because we need to modify the setup.py to have it include what's necessary when we run python setup.py install|develop. I was having the same problem locally when trying to do import tokenizers in a Python shell.

Another little detail: we actually use Google-style for docstrings. If you're not familiar with this syntax, don't worry I'll take care of it. We can also include everything in the API Reference in the Sphinx docs.

Last thing that we'll need to check is if everything works as expected when using a tokenizer like the one from GPT-2 or Roberta. Since they use a byte-level technique, we can have multiple tokens that have overlapping spans over the input, for example with emojis or other Unicode characters that don't have their own token.

Conversation is marked as resolved
bindings/python/py_src/tokenizers/viz/visualizer.py
def __init__(
    self,
    tokenizer: Tokenizer,
    default_to_notebook: bool = False,
n1t0 · 4 years ago

What do you think about always trying to display in a notebook if IPython is available, and always returning the content? That way it just displays when possible, without the need to change this value.

talolard · 4 years ago

I think that's the better default, will implement.
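
For reference, a rough sketch of that behavior (not the PR's actual code): try to display through IPython when it is importable, and always return the HTML either way.

    def display_if_possible(html: str) -> str:
        """Render in a notebook when IPython is available, and always return the HTML."""
        try:
            from IPython.display import HTML, display
            display(HTML(html))
        except ImportError:
            pass  # not running in a notebook / IPython not installed
        return html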

Conversation is marked as resolved
bindings/python/py_src/tokenizers/viz/visualizer.py
else:
    # Like above, but a different color so we can see the tokens alternate
    css_classes.append("even-token")

if encoding.tokens[first.token_ix] == "[UNK]":
n1t0 · 4 years ago

This won't work with a lot of tokenizers that do not use [UNK] as their unknown token. Unfortunately there is no easy way at the moment to get the unk token directly from the tokenizer, but maybe we can use something a bit broader. Maybe something like the regex /^(.{1}\b)?unk(\b.{1})?$/i would work for now (https://regex101.com/r/zh0He9/1/)

talolard · 4 years ago

Done

    unk_token_regex = re.compile(r'(.{1}\b)?unk(\b.{1})?', flags=re.IGNORECASE)

I took off the ^ and $ instead of the flags that regex101 applied.
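
For anyone following along, a quick check of this heuristic (note the raw string, so \b stays a word boundary rather than a backspace character):

    import re

    unk_token_regex = re.compile(r"(.{1}\b)?unk(\b.{1})?", flags=re.IGNORECASE)

    for token in ["[UNK]", "<unk>", "UNK", "unk"]:
        print(token, bool(unk_token_regex.search(token)))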

Conversation is marked as resolved
bindings/python/py_src/tokenizers/viz/visualizer.py
    )
)
res = HTMLBody(spans)  # Send the list of spans to the body of our html

with open("/tmp/out.html", "w") as f:
n1t0 · 4 years ago

I think this is a leftover and not meant to be kept 😄

Conversation is marked as resolved
bindings/python/py_src/tokenizers/viz/visualizer.py
241 """
242 word_map: PartialIntList = [None] * len(text)
243 token_map: PartialIntList = [None] * len(text)
244
for token_ix, word_ix in enumerate(encoding.words):
n1t0 · 4 years ago

I think you might be able to directly use char_to_token and char_to_word to populate word_map and token_map here. Something like this (not tested):

word_map = [encoding.char_to_word(c) for c in range(len(text))]
token_map = [encoding.char_to_token(c) for c in range(len(text))]
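
Building on that, a hedged sketch of how such a char-level token_map can be used to align an annotation's character span with token indices (the helper name is made up):

    from typing import List, Optional, Tuple

    def annotation_token_span(
        token_map: List[Optional[int]], start: int, end: int
    ) -> Optional[Tuple[int, int]]:
        """Return the (first, last) token index covered by a character span, if any."""
        covered = [t for t in token_map[start:end] if t is not None]
        return (min(covered), max(covered)) if covered else None
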
Conversation is marked as resolved
bindings/python/examples/using_the_visualizer.ipynb
n1t0 · 4 years ago

We might want to add a wget command that retrieves this file so that users can easily try the notebook without doing it manually.
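
Something along these lines could go in the first cell (the URL is an assumption; any mirror of the bert-base-uncased vocab file works):

    import urllib.request

    VOCAB_URL = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"  # assumed location
    urllib.request.urlretrieve(VOCAB_URL, "/tmp/bert-base-uncased-vocab.txt")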

Conversation is marked as resolved
.idea/.gitignore
# Default ignored files
n1t0 · 4 years ago

Can you remove all the .idea/* files?

Conversation is marked as resolved
bindings/python/py_src/tokenizers/__init__.pyi
from .processors import *
from .trainers import *

from .viz.visualizer import EncodingVisualizer
n1t0 · 4 years ago

I don't think we need this in the .pyi

talolard · 4 years ago

Last thing that we'll need to check is if everything works as expected when using a tokenizer like the one from GPT-2 or Roberta. Since they use a byte-level technique, we can have multiple tokens that have overlapping spans over the input, for example with emojis or other Unicode characters that don't have their own token.

I don't know enough about BPE to think of a test case. My poo emoji is a surrogate pair, but I guess the test needs to be on a surrogate pair that isn't in the vocab? Any ideas?

I also tried 'Z͑ͫ̓ͪ̂ͫ̽͏̴̙̤̞͉͚̯̞̠͍A̴̵̜̰͔ͫ͗͢L̠ͨͧͩ͘G̴̻͈͍͔̹̑͗̎̅͛́Ǫ̵̹̻̝̳͂̌̌͘!͖̬̰̙̗̿̋ͥͥ̂ͣ̐́́͜͞' and got back
[screenshot]
which I can't tell if it's good or bad, given that the len() of that string is 76

Can you think of an example to test?

talolard requested a review from n1t0 · 4 years ago
talolard · 4 years ago

I'm not entirely sure about the namespacing (tokenizers.viz vs tokenizers.notebooks or tokenizers.tools) and will need to think about it, but that's a detail.

My inclination would be towards tools or viz, but I have no strong conviction.

For the error in the CI about the documentation, I think it is because we need to modify the setup.py to have it include what's necessary when we run python setup.py install|develop. I was having the same problem locally when trying to do import tokenizers in a Python shell.

Could you handle that? I'm not sure what to do. When I run setup.py develop it works, presumably because of something I don't understand.

Another little detail: we actually use Google-style for docstrings. If you're not familiar with this syntax, don't worry I'll take care of it. We can also include everything in the API Reference in the Sphinx docs.

I made an attempt to use Google style docstrings. There were some places where I wasn't sure about how to write out the typings. If you comment on things to fix I'll learn and fix.

Last thing that we'll need to check is if everything works as expected when using a tokenizer like the one from GPT-2 or Roberta.

I added something to the notebook, but per my comment above, not sure exactly what to test for.

n1t0 · 4 years ago

I think the current version of the input text is great for checking whether it works as expected. I just ran some tests in a last cell with the following code:

encoding = roberta_tokenizer.encode(text)
[(token, offset, text[offset[0]:offset[1]]) for (token, offset) in zip(encoding.tokens, encoding.offsets)]

which gives this kind of output:

[
...,
 ('Ġadd', (212, 216), ' add'),
 ('Ġa', (216, 218), ' a'),
 ('Ġunit', (218, 223), ' unit'),
 ('Ġtest', (223, 228), ' test'),
 ('Ġthat', (228, 233), ' that'),
 ('Ġcontains', (233, 242), ' contains'),
 ('Ġa', (242, 244), ' a'),
 ('Ġpile', (244, 249), ' pile'),
 ('Ġof', (249, 252), ' of'),
 ('Ġpo', (252, 255), ' po'),
 ('o', (255, 256), 'o'),
 ('Ġ(', (256, 258), ' ('),
 ('ðŁ', (258, 259), '💩'),
 ('Ĵ', (258, 259), '💩'),
 ('©', (258, 259), '💩'),
 (')', (259, 260), ')'),
 ('Ġin', (260, 263), ' in'),
 ('Ġa', (263, 265), ' a'),
 ('Ġstring', (265, 272), ' string'),
...,
 ('©', (280, 281), '💩'),
 ('ðŁ', (281, 282), '💩'),
 ('Ĵ', (281, 282), '💩'),
 ('©', (281, 282), '💩'),
 ('ðŁ', (282, 283), '💩'),
 ('Ĵ', (282, 283), '💩'),
 ('©', (282, 283), '💩'),
 ('ðŁ', (283, 284), '💩'),
 ('Ĵ', (283, 284), '💩'),
 ('©', (283, 284), '💩'),
 ('ðŁ', (284, 285), '💩'),
 ('Ĵ', (284, 285), '💩'),
 ('©', (284, 285), '💩'),
 ('ðŁ', (285, 286), '💩'),
 ('Ĵ', (285, 286), '💩'),
 ('©', (285, 286), '💩'),
 ('Ġand', (286, 290), ' and'),
 ('Ġsee', (290, 294), ' see'),
 ('Ġif', (294, 297), ' if'),
 ('Ġanything', (297, 306), ' anything'),
 ('Ġbreaks', (306, 313), ' breaks'),
...,
]

As you can see, there are actually a lot of tokens that we can't see because they are representing sub-parts of an actual Unicode code point.
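
A quick way to see why: the byte-level BPE works on UTF-8 bytes, and a single code point can span several bytes, so one emoji can turn into several tokens.

    poo = "💩"
    print(len(poo))                    # 1 code point
    print(list(poo.encode("utf-8")))   # 4 bytes: [240, 159, 146, 169]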

I'd be curious to see what it looks like with a ByteLevelBPETokenizer trained on a language like Hebrew and see if the visualization actually makes sense in this case. Is this something you'd like to try?

Could you handle that, I'm not sure what to do. When I run setup.py develop it works, presumably because of something I don't understand.

Sure don't worry, I'll take care of anything left related to the integration!

talolard · 4 years ago

I'd be curious to see what it looks like with a ByteLevelBPETokenizer trained on a language like Hebrew and see if the visualization actually makes sense in this case. Is this something you'd like to try?

I actually did that and took it out because it was too much text. What do you think of a "gallery" notebook with a mix of languages and tokenizers?

n1t0 · 4 years ago

Sure! That'd be a great way to check that everything works as expected.

talolard · 4 years ago

Sure! That'd be a great way to check that everything works as expected.

I added a notebook with some examples in different languages.

@n1t0 I actually noticed something strange. When I use the BPE tokenizer, whitespace is included in the following token. I'm not sure if that's how it's supposed to be, or a bug in my code or in the tokenizers. Could you take a look at these pics and give guidance?

[screenshot]

n1t0 · 4 years ago

Thank you @talolard!

This is actually expected yes. The byte-level BPE also encodes the whitespace because it then allows it to decode back to the original sentence. From your pictures and the various examples in the notebooks, I think everything looks as expected for English (and probably other latin languages).
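
A tiny illustration of why the whitespace shows up inside the tokens: the GPT-2-style byte-to-unicode table maps the space byte to a printable stand-in character.

    # In the GPT-2 byte-to-unicode table, the space byte (0x20) maps to U+0120.
    print(chr(0x120))  # 'Ġ' -- the prefix seen on tokens that start with a space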

My current concern is about the other languages. I just tried checking the generated tokens with byte-level BPE using the example in Hebrew, and they don't seem to match the visualization at all.

>>> encoding = roberta_tokenizer.encode(texts["Hebrew"])
>>> [(token, offset, texts["Hebrew"][offset[0]:offset[1]]) for (token, offset) in zip(encoding.tokens, encoding.offsets)]
[('×ij', (0, 1), 'ב'),
 ('×', (1, 2), 'נ'),
 ('ł', (1, 2), 'נ'),
 ('×Ļ', (2, 3), 'י'),
 ('Ġ×', (3, 5), ' א'),
 ('IJ', (4, 5), 'א'),
 ('×', (5, 6), 'ד'),
 ('ĵ', (5, 6), 'ד'),
 ('×', (6, 7), 'ם'),
 ('Ŀ', (6, 7), 'ם'),
 ('Ġ×', (7, 9), ' ז'),
 ('ĸ', (8, 9), 'ז'),
 ('×ķ', (9, 10), 'ו'),
 ('ר', (10, 11), 'ר'),
 ('×Ļ×', (11, 13), 'ים'),
 ('Ŀ', (12, 13), 'ם'),
 ('Ġ×', (13, 15), ' ב'),
 ('ij', (14, 15), 'ב'),
 ('×', (15, 16), 'כ'),
 ('Ľ', (15, 16), 'כ'),
 ('׾', (16, 17), 'ל'),
 ('Ġ×', (17, 19), ' י'),
 ('Ļ', (18, 19), 'י'),
 ('×ķ', (19, 20), 'ו'),
 ('×', (20, 21), 'ם'),
 ('Ŀ', (20, 21), 'ם'),
 ('Ġ×', (21, 23), ' ל'),
 ('ľ', (22, 23), 'ל'),
 ('ר', (23, 24), 'ר'),
 ('×ķ', (24, 25), 'ו'),
 ('×', (25, 26), 'ח'),
 ('Ĺ', (25, 26), 'ח'),
 (',', (26, 27), ','),
...
]

As you can see, most tokens represent one character at most, and often there are two tokens for one character. Yet, in the visualization, it appears to be long tokens, which seems wrong.

I think the roberta byte-level BPE has been trained only on English and is not capable of tokenizing Hebrew correctly; that's why I wanted to see the result with a tokenizer trained specifically for Hebrew. I'd like to check whether, as I expect, the result would look accurate in this case. Does that make sense?
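
For anyone wanting to try it, a rough sketch of training such a tokenizer (the corpus path and hyperparameters are placeholders):

    from tokenizers import ByteLevelBPETokenizer

    tokenizer = ByteLevelBPETokenizer()
    tokenizer.train(files=["hebrew_corpus.txt"], vocab_size=30_000, min_frequency=2)

    encoding = tokenizer.encode("שלום עולם")
    print(list(zip(encoding.tokens, encoding.offsets)))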

talolard · 4 years ago (edited)

Nice catch.
I think you warned me about this before I started and I didn't understand. Also this is tricky...

What seems to be happening is that some characters are assigned to two tokens. I didn't anticipate that so the code just takes the last token of the pair. The coloring of tokens is done by alternating even and odd tokens so we end up skipping an even.
e.g. the two chars בנ
[screenshot]
are tokenized as three tokens
[screenshot]

And so we skip an even token in the CSS
[screenshot]

I can color "multi token single chars" uniquely and set a hover state that shows all the tokens assigned to a char, which should clear that up visually (though it's confusing AF and would probably surprise users).

But there's another weird thing I need your input on:

[screenshot]

These two overlap, e.g. the char at position 4 א is assigned to two different tokens. Is that expected?

n1t0 · 4 years ago

I can color "multi token single chars" uniquely and set a hover state that shows all the tokens assigned to a char, which should clear that up visually (though it's confusing AF and would probably surprise users).

I honestly don't know how we should handle this. I expect the BPE algorithm to learn the most common tokens without having any overlap in most cases, and small overlaps with rarely seen tokens that end up being decomposed, but I'm not sure at all. Maybe just making sure it alternates between the two shades of grey, while excluding any "multi token single chars" from the next tokens could be enough.

But there's another weird thing I need your input on:
[screenshot]
These two overlap, e.g. the char at position 4 א is assigned to two different tokens. Is that expected?

Yes, in this case, the Ġ at the beginning of the first token represents a whitespace. So the first of these tokens is composed of the whitespace and a part of the character א (two characters in the span 3, 5), while the second one is just another part of the character א (with 4, 5). This is hard to visualize here because the spans are expressed in terms of Unicode code-point while each token has a unique span of non-overlapping bytes.
For example, we often see × as the first part of some characters, and I guess this byte actually represents the offset of the Unicode block reserved for Hebrew, while the second part is the byte that represents the actual character in this block. I expect these to get merged when a tokenizer is trained with Hebrew in the first place.
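
That guess checks out: Hebrew letters share the UTF-8 lead byte 0xd7, which the byte-level alphabet happens to display as ×.

    for ch in "אבג":
        print(ch, [hex(b) for b in ch.encode("utf-8")])
    # each letter starts with 0xd7, followed by a byte that selects the letter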

Maybe this post can help understand how the byte-level works: #203 (comment)

talolard · 4 years ago

OK this requires actual thinking. I'll tinker with it on the weekend and come back with something

talolard · 4 years ago

Sooo....
I set it up to track how many tokens "participate" in each character and visualize "multi token chars" differently + add a tooltip on hover.
[screenshot]

I think it solves the case in Hebrew you pointed to, e.g. the second and third chars reflect that they are in different tokens:
[screenshot]

My confidence with BPE is < 100% so I'm not sure this covers it all. What do you think?
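
The bookkeeping described above could look roughly like this (a sketch based on the encoding's offsets, not the PR's actual code):

    from collections import defaultdict

    def tokens_per_char(encoding, text: str):
        """Count how many tokens cover each character position."""
        counts = defaultdict(set)
        for token_ix, (start, end) in enumerate(encoding.offsets):
            for c in range(start, end):
                counts[c].add(token_ix)
        return [len(counts[c]) for c in range(len(text))]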

cceyda · 4 years ago

How would this compare to spaCy's displacy? A while back I did something to visualize (token classification) outputs that way. Something similar can be done just for tokens + annotations (just gotta write the huggingface->spacy aligner/formatter).
BTW nice work on LightTag

talolard · 4 years ago

How would this compare to spaCy's displacy? A while back I did something to visualize (token classification) outputs that way. Something similar can be done just for tokens + annotations (just gotta write the huggingface->spacy aligner/formatter).
BTW nice work on LightTag

Thanks!
I think displacy and this solve a similar pain, and are optimized for different tokenizations/libraries. I think the upcoming spaCy 3 has a strong focus on transformers, so integrating would probably make sense. I'll try to PR that once this goes in.

n1t0 · 4 years ago

Thank you @talolard that looks great!
I'll try to have everything ready to merge this today!

talolard Draft functionality of visualization
108eee44
talolard Added comments to make code more intelligble
b8718e20
talolard polish the styles
87b97ad1
talolard Ensure colors are stable and comment the css
1bd70442
talolard Code clean up
f67df485
talolard Made visualizer importable and added some docs
5e86f182
talolard Fix styling
3724313c
talolard implement comments from PR
6e343380
talolard Fixed the regex for UNK tokens and examples in notebook
788dcffc
talolard Converted docs to google format
1471bc1d
talolard Added a notebook showing multiple languages and tokenizers
c1d4359e
talolard Added visual indication of chars that are tokenized with >1 token
ecc3c48f
n1t0 · 4 years ago

@talolard I think I don't have the authorization to push on the branch used for this PR. Maybe you disabled the option while opening the PR?

talolard · 4 years ago

@talolard I think I don't have the authorization to push on the branch used for this PR. Maybe you disabled the option while opening the PR?

Fixed

n1t0 Reorganize things a bit and fix import
6d12147d
n1t0 Update docs
7692a670
n1t0 · approved these changes on 2020-12-04

This is now ready to be merged! Sorry it took me so long to finalize it, I was a bit overwhelmed with things left to do last week, and was off this week.

Here is a summary of the little things I changed:

  • Everything now lives in a single file, under tools. So in order to import the visualizer and the annotations we can do (full example after this list):
from tokenizers.tools import EncodingVisualizer, Annotation
  • Updated the setup.py file to help it package the lib with the newly added files.
  • Updated the docstrings a bit, and included everything in the API Reference part of the docs.
  • I finally removed the language gallery. This notebook has been a great help in debugging what was happening with the various languages, but I fear that it might be misleading for the end-user. BERT and Roberta are both trained on English and so it does not represent the end result that a tokenizer trained on each specific language would produce.
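
Putting the final layout together, usage after the merge looks roughly like this:

    from tokenizers import BertWordPieceTokenizer
    from tokenizers.tools import EncodingVisualizer, Annotation

    tokenizer = BertWordPieceTokenizer("/tmp/bert-base-uncased-vocab.txt", lowercase=True)
    visualizer = EncodingVisualizer(tokenizer)
    visualizer("Hello world", annotations=[Annotation(start=0, end=5, label="greeting")])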

Thanks again @talolard, this is a really great addition to the library and will be very helpful in understanding the tokenization. It will be included in the next release!

n1t0 merged 8916b6bb into master · 4 years ago
talolard · 4 years ago

Yay!!

talolard deleted the feat/visualizer branch · 4 years ago
