Introduction
The Byte Pair Encoding (BPE) tokenizer is a popular method used in natural language processing (NLP) for subword tokenization. It is particularly effective at handling out-of-vocabulary words while keeping the vocabulary size bounded, both of which are crucial for training large language models. In this article, we will explore the mechanics of BPE, its advantages, and its applications in NLP.
The BPE algorithm was popularized for LLMs by the GPT-2 paper and the associated GPT-2 code release from OpenAI. Sennrich et al. (2015) is commonly cited as the original reference for the use of BPE in NLP applications. Today, virtually all modern LLMs (e.g. GPT, Llama, Mistral) use this algorithm to train their tokenizers.
How BPE Works
The BPE algorithm operates by iteratively replacing the most frequent pair of bytes (or characters) in a dataset with a new token. This process continues until a predefined vocabulary size is reached or no more pairs can be merged. This tokenization method represents the incoming text more densely, so more information can be captured in fewer tokens (and thus fit into a fixed, finite-sized context window), at the expense of a larger vocabulary.
Here’s a step-by-step breakdown of how BPE works:
- Initialization: Start with a vocabulary that consists of all individual characters in the training corpus. Select a desired target vocabulary size.
Suppose the training data to be encoded is
aaabdaaabac
The initial vocabulary is {a, b, c, d}, with indices {0, 1, 2, 3}.
- Count Pairs: Count the frequency of all adjacent byte pairs in the data. For the example above, the pairs and their counts are:
aa: 4
ab: 2
ba: 1
bd: 1
da: 1
ac: 1
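These counts are easy to verify with a couple of lines of standard Python (collections.Counter is used here purely as a sanity check; the real implementation below uses a plain dictionary):

from collections import Counter

data = "aaabdaaabac"
# Pair each character with its successor and tally the resulting pairs
pair_counts = Counter(zip(data, data[1:]))
print(pair_counts)
# Counter({('a', 'a'): 4, ('a', 'b'): 2, ('b', 'd'): 1, ('d', 'a'): 1, ('b', 'a'): 1, ('a', 'c'): 1})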
The byte pair “aa” occurs most often, so it will be replaced by a byte that is not used in the data, such as “Z”, and will be assigned the next available token index. Now there is the following data and replacement table:
ZabdZabac
-----------------
Z=aa | idx: 4
with the updated vocabulary {a, b, c, d, Z} and indices {0, 1, 2, 3, 4}.
Then the process is repeated with byte pair “ab”, replacing it with “Y”:
ZYdZYac
-----------------
Y=ab | idx: 5
Z=aa | idx: 4
The only literal byte pair left occurs only once, and the encoding might stop here. Alternatively, the process could continue with recursive byte pair encoding, replacing “ZY” with “X”:
XdXac
-----------------
X=ZY | idx: 6
Y=ab | idx: 5
Z=aa | idx: 4
This data cannot be compressed further by byte pair encoding because there are no pairs of bytes that occur more than once. The final vocabulary is {a, b, c, d, Z, Y, X}, with indices {0, 1, 2, 3, 4, 5, 6}.
To decompress the data, simply perform the replacements in the reverse order.
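For this toy example, the entire round trip can be checked with plain string replacement. Note that this only works because each merge here happens to match greedily left to right; it is a sanity check, not a general BPE implementation:

data = "aaabdaaabac"

# Apply the merges in order: aa -> Z, ab -> Y, ZY -> X
compressed = data.replace("aa", "Z").replace("ab", "Y").replace("ZY", "X")
print(compressed)  # XdXac

# Decompress by undoing the replacements in reverse order: X -> ZY, Y -> ab, Z -> aa
restored = compressed.replace("X", "ZY").replace("Y", "ab").replace("Z", "aa")
print(restored == data)  # True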
Minimal BPE implementation in Python
In this section, we will implement a minimal version of the BPE algorithm in Python. The implementation will consist of two main functions: one for counting consecutive pairs and another for merging them.
Counting Consecutive Pairs
The first function, get_consecutive_pair_count, takes a list of integers (representing tokens) and returns a dictionary of counts of consecutive pairs.
def get_consecutive_pair_count(
    ids: list[int],
    counts: dict[tuple[int, int], int] | None = None,
) -> dict[tuple[int, int], int]:
    """Given a list of integers, return a dictionary of counts of consecutive pairs.

    Args:
        ids: A list of integers.
        counts: An optional existing dictionary of counts to update in place.

    Returns:
        A dictionary with pairs of integers as keys and their counts as values.

    Example:
        [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    """
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts
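For example, running it on the list from the docstring, and then using the optional counts argument to accumulate across two calls:

ids = [1, 2, 3, 1, 2]
print(get_consecutive_pair_count(ids))
# {(1, 2): 2, (2, 3): 1, (3, 1): 1}

# The optional `counts` argument accumulates counts across multiple calls
running = get_consecutive_pair_count([1, 2])
running = get_consecutive_pair_count([2, 3], counts=running)
print(running)
# {(1, 2): 1, (2, 3): 1}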
Merging Consecutive Pairs
The second function, merge, takes a list of integers, a pair of integers to merge, and a new integer token to replace the pair. It returns a new list with the merged values.
def merge(ids: list[int], pair: tuple[int, int], idx: int) -> list[int]:
    """In the list of ids, replace all consecutive occurrences of pair with the new integer token idx.

    Args:
        ids: A list of integers.
        pair: A tuple of two integers to be merged.
        idx: The integer to replace the pair.

    Example:
        ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    merged_ids = []
    i = 0
    while i < len(ids):
        # If the current and next token form the target pair, emit the new token
        if ids[i] == pair[0] and i < len(ids) - 1 and ids[i + 1] == pair[1]:
            merged_ids.append(idx)
            i += 2
        else:
            merged_ids.append(ids[i])
            i += 1
    return merged_ids
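A quick check against the example in the docstring:

ids = [1, 2, 3, 1, 2]
print(merge(ids, pair=(1, 2), idx=4))
# [4, 3, 4] -- both occurrences of (1, 2) are replaced by the new token 4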
Putting It All Together
Now we can create a simple class to encapsulate the BPE functionality, including initialization, training, encoding, and decoding.
class BasicTokenizer:
    def __init__(self):
        # default: vocab size of 256 (all bytes), no merges, no patterns
        self.merges = {}  # (int, int) -> int
        self.pattern = ""  # str
        self.special_tokens = {}  # str -> int, e.g. {'<|endoftext|>': 100257}
        self.vocab = self._build_vocab()  # int -> bytes

    def _build_vocab(self) -> dict[int, bytes]:
        """Build the id -> bytes vocabulary from raw bytes, merges, and special tokens."""
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special_token, idx in self.special_tokens.items():
            vocab[idx] = special_token.encode("utf-8")
        return vocab

    def train(self, text: str, vocab_size: int, verbose: bool = False):
        """Train the tokenizer on the provided text."""
        assert vocab_size >= 256, "Vocabulary size must be at least 256."
        num_merges = vocab_size - 256

        # Input text preprocessing
        text_bytes = text.encode("utf-8")
        ids = list(text_bytes)  # Convert bytes to list of integers in range [0, 255]

        # Iteratively merge the most common pairs to create new tokens
        merges: dict[tuple[int, int], int] = {}
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for i in range(num_merges):
            # Get counts of consecutive pairs
            pair_counts = get_consecutive_pair_count(ids)
            if not pair_counts:
                break  # fewer than two tokens left; nothing to merge
            # Find the most common pair
            most_common_pair = max(pair_counts, key=pair_counts.get)  # type: ignore
            idx = 256 + i  # New token index
            ids = merge(ids, most_common_pair, idx)
            # Save the merge
            merges[most_common_pair] = idx
            vocab[idx] = vocab[most_common_pair[0]] + vocab[most_common_pair[1]]
            if verbose:
                print(
                    f"[{i+1} / {num_merges}] Merged {most_common_pair} -> {idx} with count {pair_counts[most_common_pair]}"
                )

        # Save instance variables
        self.merges = merges  # used in encode()
        self.vocab = vocab  # used in decode()

    def decode(self, ids: list[int]) -> str:
        """Decode a list of integers back into a string."""
        # Convert the list of integers back to bytes
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        return text_bytes.decode("utf-8", errors="replace")

    def encode(self, text: str) -> list[int]:
        """Encode a string into a list of token ids."""
        text_bytes = text.encode("utf-8")  # raw bytes
        ids = list(text_bytes)  # list of integers in range 0..255
        while len(ids) >= 2:
            # Find the pair with the lowest merge index
            stats = get_consecutive_pair_count(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # Subtle: if there are no more merges available, the key will
            # result in an inf for every single pair, and the min will be
            # just the first pair in the list, arbitrarily.
            # We can detect this terminating case by a membership check.
            if pair not in self.merges:
                break  # nothing else can be merged anymore
            # Otherwise, merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids
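A short round-trip test ties everything together. The training text and target vocabulary size below are arbitrary choices for illustration:

tokenizer = BasicTokenizer()
tokenizer.train("aaabdaaabac", vocab_size=256 + 3, verbose=True)

ids = tokenizer.encode("aaabdaaabac")
print(ids)                    # [258, 100, 258, 97, 99] -- 11 bytes compressed to 5 tokens
print(tokenizer.decode(ids))  # aaabdaaabac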