
Introduction to Transformers - Part 01 - Implementation of a translation model

Authors
  • Jan Hardtke

Introduction

In the previous post, we explored the inner workings of the Transformer architecture, dissecting its components and understanding how it forms the backbone of state-of-the-art models like BERT, GPT, and T5. Now, it's time to turn theory into practice.

In this post, we will implement a full Transformer model from scratch in PyTorch and apply it to a translation task. Specifically, we will use the WMT14 English-German (en-de) dataset, a benchmark dataset widely used for machine translation tasks.

Dataset Preparation and Preprocessing

To train our Transformer model, we need a dataset and a way to preprocess it into a format suitable for the model. In this example, we use the WMT14 English-German (en-de) dataset from the Hugging Face datasets library. This dataset is a common benchmark for machine translation tasks.


Loading the Dataset

from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from datasets import load_dataset

dataset = load_dataset("wmt14", "de-en")

dataset['train'] = dataset['train'].select(range(10000))
dataset['test'] = dataset['test'].select(range(1000))
dataset['validation'] = dataset['validation'].select(range(1000))


The dataset contains parallel sentences in German (de) and English (en), ideal for a translation model.

Initializing the Tokenizer

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
src_seq_len = 128
trg_seq_len = 128
batch_size = 256

def preprocess_function(examples):
    inputs = [item["de"] for item in examples["translation"]]
    targets = [item["en"] for item in examples["translation"]]
    labels = [item["en"] for item in examples["translation"]]

    model_inputs = tokenizer(
        inputs,
        padding="max_length",
        truncation=True,
        max_length=src_seq_len,
        add_special_tokens=False
    )
    model_inputs['src_mask'] = model_inputs["attention_mask"]

    # as_target_tokenizer() is deprecated in recent transformers releases and
    # is a no-op for BERT tokenizers, so the English side is tokenized directly.
    targets = tokenizer(
        targets,
        padding="max_length",
        truncation=True,
        max_length=trg_seq_len,
        add_special_tokens=True
    )
    labels = tokenizer(
        labels,
        padding="max_length",
        truncation=True,
        max_length=trg_seq_len,
        add_special_tokens=True
    )

    model_inputs["target"] = targets["input_ids"]
    model_inputs["target_mask"] = targets["attention_mask"]
    model_inputs["labels"] = labels["input_ids"]

    return model_inputs

Tokenization

The AutoTokenizer from Hugging Face is initialized with the bert-base-cased model. This tokenizer is used to tokenize both the source (German) and target (English) sentences in the dataset. The tokenization process involves splitting the text into smaller units (tokens) and mapping them to numerical IDs that the model can process.


Handling Input and Output Sequences

The preprocessing function performs the following key steps:

  1. Extracting Text Data:

    • German sentences are extracted as inputs (source sequences).
    • English sentences are extracted as targets (decoder inputs) and labels (ground truth for loss computation).
  2. Tokenizing Source Sequences:

    • German sentences are tokenized and padded or truncated to a fixed length defined by src_seq_len (128 tokens in this case).
    • Special tokens marking the start and end of a sequence (<START> and <END>; with the bert-base-cased tokenizer these are [CLS] and [SEP]) are not added to source sequences, as they are only required for the decoder.
  3. Creating a Source Mask:

    • The tokenizer generates an attention mask (src_mask) for the source sequences. This mask marks valid tokens with 1 and padded tokens with 0. It is used in the attention layers of the encoder so that real tokens do not attend to padding tokens.
  4. Tokenizing Target Sequences:

    • English sentences are tokenized twice:
      • Targets: Used as input to the decoder during training. Special tokens like <START> and <END> are added.
      • Labels: Used for loss computation. These represent the expected output tokens for the decoder.
  5. Creating a Target Mask:

    • The tokenizer generates an attention mask (target_mask) for the target sequences, analogous to the source mask. It serves the same purpose in the decoder that src_mask serves in the encoder: it is used in the decoder's attention layers to mask out padding tokens.
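The training loop later in this post iterates over a train_dataloader whose batches expose the keys produced by the preprocessing function. A minimal sketch of how such a loader could be built is shown here. The helper names collate_batch and make_dataloader are hypothetical, and the commented usage assumes the tokenized splits come from the Hugging Face dataset.map call with batched=True.

```python
import torch
from torch.utils.data import DataLoader

# Keys written by preprocess_function that the training loop reads.
COLUMNS = ["input_ids", "src_mask", "target", "target_mask", "labels"]

def collate_batch(examples):
    """Stack a list of per-example dicts into one dict of LongTensors."""
    return {
        name: torch.tensor([example[name] for example in examples], dtype=torch.long)
        for name in COLUMNS
    }

def make_dataloader(split, batch_size, shuffle=False):
    """Wrap any indexable collection of example dicts in a DataLoader."""
    return DataLoader(split, batch_size=batch_size, shuffle=shuffle,
                      collate_fn=collate_batch)

# Assumed usage with the objects defined earlier in this post:
# tokenized = dataset.map(preprocess_function, batched=True)
# train_dataloader = make_dataloader(tokenized["train"], batch_size, shuffle=True)
```

Because every sequence was padded to a fixed length during tokenization, the collate function can stack the lists directly without any further padding logic.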

The Translation Model

Below, we describe the translation model, which builds upon the core principles of the Transformer architecture. This model incorporates the fundamental components of the Transformer while handling padding and the causal mask required for the decoder's masked self-attention. The architecture processes the input sequences by first preparing them with appropriate masks and then passing them through the encoder and decoder components.

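import math

import torch
import torch.nn as nn

# Run on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class TranslationModel(nn.Module):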
  def __init__(self,vocab_size,d_model,src_seq_len,trg_seq_len):
    super(TranslationModel, self).__init__()
    self.src_seq_len = src_seq_len
    self.trg_seq_len = trg_seq_len

    self.encoder = Encoder(d_model=d_model,)
    self.decoder = Decoder(d_model=d_model,)

    self.embedding_src = nn.Embedding(vocab_size,d_model)
    self.embedding_tgt = nn.Embedding(vocab_size,d_model)

    self.positional_encoding_src = nn.Embedding(src_seq_len,d_model)
    self.positional_encoding_tgt = nn.Embedding(trg_seq_len,d_model)

    self.src_pos = torch.arange(0,src_seq_len).unsqueeze(0).to(device)
    self.trg_pos = torch.arange(0,trg_seq_len).unsqueeze(0).to(device)

    self.ffn = nn.Linear(d_model,vocab_size)

  def forward(self,src,tgt,src_padding_mask,tgt_padding_mask):

    src_padding_mask = src_padding_mask.unsqueeze(1).repeat(1, self.src_seq_len, 1)
    tgt_padding_mask = tgt_padding_mask.unsqueeze(1).repeat(1, self.trg_seq_len, 1)
    tgt_causal_mask = torch.tril(torch.ones_like(tgt_padding_mask)).to(device)

    src_embed = self.embedding_src(src) + self.positional_encoding_src(self.src_pos)
    tgt_embed = self.embedding_tgt(tgt) + self.positional_encoding_tgt(self.trg_pos)

    enc_out = self.encoder(src_embed,src_padding_mask)
    dec_out = self.decoder(tgt = tgt_embed,
                           memory=enc_out,
                           tgt_padding_mask = tgt_padding_mask,
                           tgt_causal_mask = tgt_causal_mask,
                           memory_padding_mask = src_padding_mask)

    return self.ffn(dec_out)

As you can see, the first step in the model's design is the initialization of embedding layers, which hold the embeddings for the tokenized sequences. These embeddings represent each token as a dense vector in a high-dimensional space, providing the model with a numerical representation of the input.

In addition to the token embeddings, the model initializes two embedding layers for positional encoding. Unlike the fixed sinusoidal encoding discussed in the previous post, this implementation uses learnable positional encodings. This approach is more flexible, allowing the model to adapt the positional information to the specific dataset and task during training. In practice, learnable positional encodings usually perform on par with fixed ones (the original Transformer paper reported nearly identical results for the two), and they are simple to implement with an nn.Embedding layer. One trade-off is that they cannot represent positions beyond the maximum sequence length seen during training.

In the _forward_ method, three masks are created to handle the input sequences. The src_padding_mask ensures the encoder ignores padding tokens in the source sequence. Similarly, the tgt_padding_mask prevents the decoder from processing padding tokens in the target sequence. The tgt_causal_mask enforces causality in the decoder, ensuring tokens only attend to earlier tokens during prediction.

  src_padding_mask = src_padding_mask.unsqueeze(1).repeat(1, self.src_seq_len, 1)
  tgt_padding_mask = tgt_padding_mask.unsqueeze(1).repeat(1, self.trg_seq_len, 1)
  tgt_causal_mask = torch.tril(torch.ones_like(tgt_padding_mask)).to(device)

What we do here is transform the src_padding_mask and tgt_padding_mask from per-sentence vectors into matrix masks. The tokenizer only creates one vector of length seq_len per sentence, so we need to expand it into matrix form to use it in the attention layer. This is achieved by stacking each vector self.src_seq_len times, so that every row (query position) carries the same mask over the key positions and padding tokens are ignored during attention calculations. The same process is applied to the tgt_padding_mask. For the masked self-attention layer of the decoder, we additionally generate the causal mask. This lower-triangular mask enforces causality by ensuring that tokens can only attend to themselves or to tokens that came before them.
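To make the mask shapes concrete, here is a small sketch on a toy sentence with three real tokens and two padding tokens. The tensor values are made up for illustration; the operations mirror the ones used in the forward method and in the DecoderBlock.

```python
import torch

# Toy batch: one sentence with 3 real tokens followed by 2 padding tokens.
padding_mask = torch.tensor([[1, 1, 1, 0, 0]])
seq_len = padding_mask.shape[1]

# (batch, seq_len) -> (batch, seq_len, seq_len): every query row shares the key mask.
padding_matrix = padding_mask.unsqueeze(1).repeat(1, seq_len, 1)

# Lower-triangular matrix: position i may look at positions 0..i only.
causal_matrix = torch.tril(torch.ones_like(padding_matrix))

# Elementwise AND of both masks, as done inside the DecoderBlock.
combined = causal_matrix & padding_matrix

print(combined[0])
```

In the combined mask, the first three rows show the triangular pattern, and every row has zeros in the last two columns, which belong to the padding tokens.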

After preparing the masks, we proceed to retrieve the embeddings for both the source and target sequences. These embeddings are combined with positional encodings to inject positional information into the token representations. This step ensures the model understands the order of tokens within the sequences. Once the source embeddings are ready, they are passed into the encoder, where the input sequence is processed and contextualized. The output of the encoder is then used as input to the decoder, which generates the output sequence. With this general understanding, we will now take a closer look at the architecture of the encoder.

Encoder

Below is the architecture of the Encoder, which is composed of multiple EncoderBlocks. The Encoder class itself is straightforward: it initializes $n$ EncoderBlocks (num_layers in the code) and feeds the output of one block into the next in sequence.

class EncoderBlock(nn.Module):
  def __init__(self,d_model=256,heads=4,hidden=1024,dropout=0.1):
    super(EncoderBlock, self).__init__()
    self.multi_head_attention = MultiHeadAttention(heads=heads,d_model=d_model)
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    # Dropout is its own layer in the Sequential, not an argument of nn.Linear.
    self.mlp = nn.Sequential(nn.Linear(d_model,hidden),
                             nn.GELU(),
                             nn.Linear(hidden,d_model),
                             nn.Dropout(dropout))

  def forward(self,src,mask):
    att_output = self.multi_head_attention(src,src,src,mask)
    skip_con = att_output + src
    norm_output = self.norm1(skip_con)
    mlp_out = self.mlp(norm_output)
    return self.norm2(mlp_out + norm_output)


class Encoder(nn.Module):
  def __init__(self,num_layers=6,d_model=256,heads=4,hidden=1024):
      super(Encoder, self).__init__()
      self.layers = nn.ModuleList([EncoderBlock(d_model=d_model,heads=heads,hidden=hidden) for i in range(num_layers)])

  def forward(self,src,mask):
    for layer in self.layers:
      src = layer(src,mask)
    return src

The interesting part of the encoder lies in the EncoderBlock. Here, the source sequence is passed as Queries, Keys, and Values into the MultiHeadAttention block. Additionally, we pass the src_padding_mask to ensure that padding tokens are ignored during attention calculations. After computing the output of the MultiHeadAttention block, a skip connection is applied by adding the input back to the output. This result is then normalized using a LayerNorm to improve training stability. Next, the output is passed through a simple Feed-Forward Network (FFN) that ends in a Dropout layer for regularization, followed by another skip connection and a second LayerNorm.

Multi-Head Attention

class MultiHeadAttention(nn.Module):
  def __init__(self, heads=4,d_model=256):
    super(MultiHeadAttention, self).__init__()
    self.heads = heads
    self.d_model = d_model
    self.W_q = nn.Linear(d_model,d_model)
    self.W_k = nn.Linear(d_model,d_model)
    self.W_v = nn.Linear(d_model,d_model)

    self.W_o = nn.Linear(d_model,d_model)

  def forward(self, q,k,v,mask):
    d_k = self.d_model // self.heads
    q,k,v = self.W_q(q),self.W_k(k),self.W_v(v)

    q = q.view(q.shape[0],q.shape[1],self.heads,-1).transpose(1,2) # (batch,h,seq_len,d_k)
    k = k.view(k.shape[0],k.shape[1],self.heads,-1).transpose(1,2) # (batch,h,seq_len,d_k)
    v = v.view(v.shape[0],v.shape[1],self.heads,-1).transpose(1,2) # (batch,h,seq_len,d_k)

    res = torch.matmul(q,k.transpose(-2,-1))  /  math.sqrt(d_k)

    if mask is not None:
      mask = mask.unsqueeze(1)
      res = torch.masked_fill(res,mask==0,float('-inf'))

    attention = nn.functional.softmax(res,dim=-1)

    # (batch,h,seq_len,d_k) -> (batch,seq_len,h,d_k) -> (batch,seq_len,d_model)
    output = torch.matmul(attention,v).transpose(1,2).contiguous().view(q.shape[0],-1,self.d_model)
    output = self.W_o(output)

    return output

This implementation of multi-head attention closely follows the formulation outlined in the previous post.

    d_k = self.d_model // self.heads
    q,k,v = self.W_q(q),self.W_k(k),self.W_v(v)

First, the embedding dimension for each head is calculated by dividing the model's embedding dimension by the number of heads. This is done to split the input into multiple smaller attention heads, allowing the model to focus on different aspects of the input simultaneously. Next, learnable weight matrices are applied to the input sequences to project them into the Query (Q), Key (K), and Value (V) spaces.

    q = q.view(q.shape[0],q.shape[1],self.heads,-1).transpose(1,2) # (batch,h,seq_len,d_k)
    k = k.view(k.shape[0],k.shape[1],self.heads,-1).transpose(1,2) # (batch,h,seq_len,d_k)
    v = v.view(v.shape[0],v.shape[1],self.heads,-1).transpose(1,2) # (batch,h,seq_len,d_k)

Here we bring the tensors into the required shape: they have shape $(\text{batch}, \text{seq\_len}, d_{\text{model}})$ and need shape $(\text{batch}, h, \text{seq\_len}, d_k)$ for multi-head attention.

    res = torch.matmul(q,k.transpose(-2,-1))  /  math.sqrt(d_k)

    if mask is not None:
      mask = mask.unsqueeze(1)
      res = torch.masked_fill(res,mask==0,float('-inf'))

To compute the attention scores, we first calculate $\frac{QK^T}{\sqrt{d_k}}$. We then apply the mask to the attention scores: every position where the mask is 0 (a padding token, or a future token when the causal mask is included) has its score set to $-\infty$. This ensures that, when we apply the softmax function, those positions are effectively ignored. The formula for softmax is:

$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}}$$

By setting the masked values to $-\infty$, their exponential becomes:

$$e^{-\infty} = 0$$

This means the contributions of these positions to the numerator and denominator of the softmax equation are completely removed.
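The effect can be checked with a few lines of plain Python. This is only an illustrative sketch with made-up scores; the softmax helper is a hypothetical stand-in for nn.functional.softmax on a single row.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# The last position is masked, for example because it is a padding token.
scores = [2.0, 1.0, float("-inf")]
weights = softmax(scores)

# The remaining weights equal the softmax over the unmasked scores alone.
unmasked = softmax([2.0, 1.0])
print(weights, unmasked)
```

The masked position receives a weight of exactly zero, and the other weights are the same as if the masked score had never been there. Note that a row in which every score is masked would produce NaN values, so each query needs at least one visible key.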

  attention = nn.functional.softmax(res,dim=-1)

  # (batch,h,seq_len,d_k) -> (batch,seq_len,h,d_k) -> (batch,seq_len,d_model)
  output = torch.matmul(attention,v).transpose(1,2).contiguous().view(q.shape[0],-1,self.d_model)
  output = self.W_o(output)

After applying the mask, the next step is to apply the softmax function to each row of the attention matrix. Once the attention weights are multiplied with $V$, we handle the multiple heads by concatenating them. The tensor is reshaped from $(\text{batch}, h, \text{seq\_len}, d_k)$ back to $(\text{batch}, \text{seq\_len}, d_{\text{model}})$, restoring the original embedding dimension.
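The two reshaping steps are inverses of each other, which the following sketch verifies on a random tensor. The dimensions are small toy values chosen for illustration, not the ones used for training.

```python
import torch

batch, seq_len, heads, d_model = 2, 5, 4, 16
d_k = d_model // heads

x = torch.randn(batch, seq_len, d_model)

# Split: (batch, seq_len, d_model) -> (batch, heads, seq_len, d_k)
split = x.view(batch, seq_len, heads, d_k).transpose(1, 2)

# Merge: (batch, heads, seq_len, d_k) -> (batch, seq_len, d_model)
merged = split.transpose(1, 2).contiguous().view(batch, seq_len, d_model)

print(split.shape, merged.shape)
```

The call to contiguous() is needed because transpose only changes the strides of the tensor, and view requires a contiguous memory layout.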

Finally, the concatenated output is passed through the last weight matrix. This produces the final output of the attention layer, concluding its computations.

Decoder

Now that we have an understanding of the encoder, we turn our attention to the decoder. The decoder shares many similarities with the encoder but introduces two key distinctions. The first is the application of the causal mask in addition to the padding mask. The second is the inclusion of the cross-attention layer. This layer allows the decoder to attend to the encoder's outputs.


class DecoderBlock(nn.Module):
  def __init__(self,d_model=256,heads=4,hidden=1024,dropout=0.1):
    super(DecoderBlock, self).__init__()
    self.masked_multi_head_attention = MultiHeadAttention(heads=heads,d_model=d_model)
    self.cross_multi_head_attention = MultiHeadAttention(heads=heads,d_model=d_model)

    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.norm3 = nn.LayerNorm(d_model)

    self.mlp = nn.Sequential(nn.Linear(d_model,hidden),
                             nn.GELU(),
                             nn.Linear(hidden,d_model),
                             nn.Dropout(dropout))

  def forward(self,tgt,memory,tgt_padding_mask,tgt_causal_mask,memory_padding_mask):
    tgt_padding_mask = (tgt_causal_mask & tgt_padding_mask).int()

    att_output = self.masked_multi_head_attention(tgt,tgt,tgt,tgt_padding_mask)
    norm_output = self.norm1(att_output + tgt )
    cross_attn_output = self.cross_multi_head_attention(norm_output,memory,memory,memory_padding_mask)
    norm_output = self.norm2(cross_attn_output + norm_output)
    mlp_out = self.mlp(norm_output)
    return self.norm3(mlp_out + norm_output)

class Decoder(nn.Module):
  def __init__(self,num_layers=6,d_model=256,heads=4,hidden=1024):
      super(Decoder, self).__init__()
      self.layers = nn.ModuleList([DecoderBlock(d_model=d_model,heads=heads,hidden=hidden) for i in range(num_layers)])

  def forward(self,tgt,memory,tgt_padding_mask,tgt_causal_mask,memory_padding_mask):
    for layer in self.layers:
      tgt = layer(tgt,memory,tgt_padding_mask,tgt_causal_mask,memory_padding_mask)
    return tgt

The Decoder class assembles multiple DecoderBlocks and passes them the respective masks required for the attention mechanisms. Along with the target padding mask and causal mask, the decoder requires the memory (encoder output) and memory_padding_mask (source padding mask). In the DecoderBlock, the target padding mask is combined with the causal mask to ensure that tokens in the target sequence cannot attend to padding tokens or to tokens that come after them in the sequence. This combined mask is passed, together with the target tokens, into the first attention block. The output of this block is processed using a skip connection, followed by a LayerNorm to stabilize training.

Next, the cross-attention is computed. In this step, the target tokens (queries) attend to the encoder outputs (keys) and aggregate information from the encoder outputs (values) based on the attention weights. The memory_padding_mask is applied so that padding tokens in the encoder outputs are ignored. After that, the architecture of the decoder is the same as the encoder's.
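One detail is worth spelling out. In cross-attention the score matrix has shape (batch, heads, trg_seq_len, src_seq_len), while the model above reuses the source mask expanded to (batch, src_seq_len, src_seq_len). The two line up only because both sequence lengths are 128 in this post. A sketch of a mask that also works for differing lengths follows; cross_attention_mask is a hypothetical helper, not part of the model code above.

```python
import torch

def cross_attention_mask(src_padding_mask, trg_len):
    """Expand a (batch, src_len) padding mask to (batch, trg_len, src_len)."""
    return src_padding_mask.unsqueeze(1).expand(-1, trg_len, -1)

# Toy source sentence of length 4 whose last token is padding.
src_padding_mask = torch.tensor([[1, 1, 1, 0]])

# Six target positions, each of which may attend to the three real source tokens.
mask = cross_attention_mask(src_padding_mask, trg_len=6)
print(mask.shape)
```

Each of the six query rows carries the same key mask, so every target position ignores the padded source position.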

Training loop


warmup = 10
lr = 2.5e-4
min_lr = 1e-6

class InverseSquareRootLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_steps, init_lr, min_lr=1e-9, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.init_lr = init_lr
        self.min_lr = min_lr
        super(InverseSquareRootLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch + 1
        if step < self.warmup_steps:
            lr = self.init_lr * (step / self.warmup_steps)
        else:
            lr = self.init_lr * math.sqrt(self.warmup_steps / step)

        lr = max(lr, self.min_lr)

        return [lr for _ in self.base_lrs]

epochs = 100
model = TranslationModel(vocab_size=tokenizer.vocab_size,d_model=512,src_seq_len=src_seq_len,trg_seq_len=trg_seq_len)
model.to(device)
optim = torch.optim.AdamW(model.parameters(),lr=lr,betas=(0.9, 0.98))
lr_scheduler = InverseSquareRootLR(optim,warmup,lr,min_lr=min_lr)
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
for epoch in range(epochs):
    for batch in train_dataloader:
        src = batch["input_ids"].to(device)
        src_attention_mask = batch["src_mask"].to(device)
        tgt = batch["target"].to(device)
        tgt_attention_mask = batch["target_mask"].to(device)

        labels = batch["labels"].to(device)[:,1:]

        pad_tokens = torch.full((labels.shape[0], 1), tokenizer.pad_token_id, dtype=tgt.dtype).to(device)
        labels = torch.cat((labels, pad_tokens), dim=-1)
        output = model(src,tgt,src_attention_mask,tgt_attention_mask)

        output = output.view(-1,output.shape[-1])
        labels = labels.view(-1)

        loss = loss_fn(output,labels)
        optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optim.step()
        lr_scheduler.step()

    print(loss)
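Before walking through the loop, it helps to see what the InverseSquareRootLR scheduler does to the learning rate. The sketch below restates the formula from get_lr as a plain function. The name inverse_sqrt_lr is hypothetical, and the defaults are the values of warmup, lr, and min_lr defined above.

```python
import math

def inverse_sqrt_lr(step, warmup_steps=10, init_lr=2.5e-4, min_lr=1e-6):
    """Learning rate at a 1-based optimizer step, mirroring get_lr above."""
    if step < warmup_steps:
        # Linear warmup from init_lr / warmup_steps up to init_lr.
        lr = init_lr * (step / warmup_steps)
    else:
        # Inverse square root decay once warmup is over.
        lr = init_lr * math.sqrt(warmup_steps / step)
    # Never drop below the configured floor.
    return max(lr, min_lr)

for step in (1, 5, 10, 40, 1000):
    print(step, inverse_sqrt_lr(step))
```

The rate rises linearly during the first 10 steps, peaks at 2.5e-4, and then shrinks with the inverse square root of the step count: quadrupling the step from 10 to 40 halves the learning rate.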

Now let's first look at the model initialization code:

epochs = 100
model = TranslationModel(vocab_size=tokenizer.vocab_size,d_model=512,src_seq_len=src_seq_len,trg_seq_len=trg_seq_len)
model.to(device)
optim = torch.optim.AdamW(model.parameters(),lr=lr,betas=(0.9, 0.98))
lr_scheduler = InverseSquareRootLR(optim,warmup,lr,min_lr=min_lr)
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

In this step, we initialize the model, set up the optimizer, and define the loss function. The loss function used is CrossEntropyLoss, which is commonly employed for sequence-to-sequence tasks. CrossEntropyLoss is mathematically defined as:

$$\text{CrossEntropyLoss}(y, \hat{y}) = - \frac{1}{N} \sum_{i=1}^N \sum_{j=1}^C y_{ij} \log(\hat{y}_{ij})$$

Here:

  • $y$ represents the true distribution (one-hot encoded ground truth).
  • $\hat{y}$ represents the predicted probability distribution from the model.
  • $N$ is the number of samples in the batch.
  • $C$ is the number of classes (vocabulary size in this case).

This loss function calculates the negative log-likelihood of the correct class, weighted by the true distribution. It is averaged over all tokens in the batch, penalizing predictions that assign low probability to the true class. The ignore_index parameter tells CrossEntropyLoss to skip every position whose label equals pad_token_id: these positions contribute neither to the loss nor to the gradient, and they are left out of the average. This excludes the padding tokens from the loss computation.
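The behaviour of ignore_index can be verified on a toy example. In this sketch the logits, labels, and the padding id of 0 are made-up values for illustration. The manual computation averages the negative log-likelihood over the non-padding positions only.

```python
import torch
import torch.nn.functional as F

pad_id = 0  # assumed padding id for this toy example

# Three token positions, vocabulary of size 3.
logits = torch.tensor([[2.0, 0.5, 0.1],
                       [0.3, 1.5, 0.2],
                       [0.9, 0.1, 0.4]])
labels = torch.tensor([1, 2, pad_id])  # the last position is padding

loss = F.cross_entropy(logits, labels, ignore_index=pad_id)

# Manual check: mean negative log-likelihood over the two non-padding rows.
log_probs = F.log_softmax(logits, dim=-1)
manual = -(log_probs[0, 1] + log_probs[1, 2]) / 2

print(loss.item(), manual.item())
```

Both values agree, which shows that the padded position is dropped from the sum and from the denominator of the average.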

Next, we adjust the shape of the label tokens to ensure they align correctly with the target tokens. Both the target tokens and the label tokens currently contain <START> and <END> tokens. However, for training, the labels need to be shifted one token to the left.

This means the <START> token is removed from the labels, and a <PAD> token is appended at the end to keep the sequence length unchanged. After this shift, the label at position i is the token that follows the decoder input at position i, so each decoder position is trained to predict the next token.

src = batch["input_ids"].to(device)
src_attention_mask = batch["src_mask"].to(device)
tgt = batch["target"].to(device)
tgt_attention_mask = batch["target_mask"].to(device)
labels = batch["labels"].to(device)[:,1:]

pad_tokens = torch.full((labels.shape[0], 1), tokenizer.pad_token_id, dtype=tgt.dtype).to(device)
labels = torch.cat((labels, pad_tokens), dim=-1)
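The alignment is easiest to see on a single toy sequence written as plain Python lists. The token ids below are illustrative assumptions: 101 and 102 stand for the start and end tokens, 0 for padding, and 7, 8, 9 for ordinary words.

```python
PAD, START, END = 0, 101, 102  # toy ids, assumed for illustration

# Decoder input: start token, three words, end token, then padding.
target = [START, 7, 8, 9, END, PAD, PAD]

# Labels: drop the first token and pad on the right to keep the length.
labels = target[1:] + [PAD]

# Each pair is (what the decoder sees at position i, what it should predict).
pairs = list(zip(target, labels))
print(pairs)
```

The first pair is (START, 7): given only the start token, the decoder should predict the first word. The pair (9, END) teaches the model when to stop, and every pair whose label is PAD is skipped by the loss through ignore_index.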

We then feed these inputs into the model.

output = model(src,tgt,src_attention_mask,tgt_attention_mask)

output = output.view(-1,output.shape[-1])
labels = labels.view(-1)

loss = loss_fn(output,labels)

After processing, the model produces an output tensor of shape $(\text{batch}, \text{seq\_len}, \text{vocab\_size})$. This tensor contains the prediction vectors for the entire vocabulary at each position in the sequence. These vectors do not represent probabilities yet, as the CrossEntropyLoss function internally applies a softmax operation to convert them into probabilities. CrossEntropyLoss expects the class dimension to come second, which is simplest to satisfy with a 2D input. We therefore reshape the output tensor into the shape $(\text{batch} \times \text{seq\_len}, \text{vocab\_size})$, flattening the batch and sequence dimensions while keeping the vocabulary dimension intact.

Similarly, the label tensor is reshaped to have the shape $(\text{batch} \times \text{seq\_len})$. This ensures that each label aligns correctly with the corresponding prediction vector for the loss calculation.

Letting this run will train our model to translate German sentences into English. With the current configuration, the model contains roughly 75M parameters. In the next post, we will explore how to adjust this architecture to recreate the well-known Llama 3 foundational language model from Meta. While the current implementation of our Transformer adheres closely to its vanilla formulation, Llama 3 represents a state-of-the-art large language model (LLM) that builds on and extends the original Transformer design. We will delve into the architectural changes and optimizations that have been introduced since the original Transformer, examining how these innovations contribute to the performance of a model like Llama 3.