Introduction to Transformers - Part 02 - Llama 3 - Building a LLM from scratch

Author: Jan Hardtke
Meta has recently announced the release of LLaMA 3.3, the latest iteration of its family of state-of-the-art open-source foundational large language models (LLMs).

[Figure: benchmark comparison of the LLaMA 3.3 70B model]

This new version, particularly the LLaMA 3.3 70B model, demonstrates remarkable performance that either matches or surpasses the capabilities of OpenAI's closed-source GPT-4o across a wide range of benchmarks. LLaMA 3 represents a significant leap in large language model development, building upon Meta's ongoing commitment to open-source AI innovation. In this blog post, we will look at the architectural advancements and optimizations that set LLaMA 3 apart from earlier models and the original transformer framework.

Prerequisites

Before diving into the model's code, we will first explore the most significant architectural changes that Llama 3 makes compared to the original transformer.

Rotary Positional Encodings (RoPE)

Rotary Position Embeddings (RoPE) are an advanced method for encoding positional information in transformer models, designed to enhance their ability to process sequential data. Unlike fixed or learned positional embeddings, RoPE introduces a rotational transformation in the query and key vectors of the self-attention mechanism. This approach embeds positional relationships directly into the angular components of these vectors, allowing the model to generalize better across different sequence lengths and maintain relative positional information.

Given the vectors $q_m$ and $k_n$, which represent the embedded vectors of a word at positions $m$ and $n$ in a sequence, the self-attention score between these tokens is defined by $q_m^\top k_n$ (or $q_m k_n^\top$, depending on whether the vectors are treated as columns or rows).

To incorporate relative positional information, we aim to express the inner product of the query vector $q_m$ and the key vector $k_n$ as a function $g$. This function should depend solely on the word embeddings $x_m$ and $x_n$, as well as their relative position $m - n$. In other words, the inner product should encode positional information solely in a relative manner:

$$\langle f_q(x_m, m), f_k(x_n, n) \rangle = g(x_m, x_n, m - n).$$

The primary objective is to devise an encoding mechanism such that the functions $f_q(x_m, m)$ and $f_k(x_n, n)$ satisfy this relationship.

This is illustrated in the exemplary figure below. The later the word "dog" appears in the sentence, the more it is rotated. However, the most important observation is that the rotation angle between two word vectors (e.g., between "pig" and "dog") remains consistent if their relative distance stays the same, regardless of their absolute position or the length of the sentence. This is what the existence of such a function gg implies.

[Figure: word vectors such as "pig" and "dog" are rotated further the later they appear in the sentence; the angle between them depends only on their relative distance (image source)]

To derive RoPE, we will focus on the case where $x_q$ and $x_k$ are 2-dimensional. We define their positionally encoded counterparts as:

$$q_m = f_q(x_q, m), \qquad k_n = f_k(x_k, n),$$

where $m$ and $n$ denote their positions, such that:

$$q_m^\top k_n = \langle f_q(x_q, m), f_k(x_k, n) \rangle = g(x_q, x_k, m - n).$$

Additionally, we define the initial conditions for $m = n = 0$ as:

$$q = f_q(x_q, 0), \qquad k = f_k(x_k, 0).$$

To find solutions for $f_q$ and $f_k$, we note that our vectors $x_q$ and $x_k$ are 2-dimensional. This allows us to represent them as complex numbers $z_{(q,k)}$ using Euler's formula:

$$e^{i\phi} = \cos(\phi) + i\sin(\phi).$$

Using this representation, we can define $f_q$, $f_k$, and $g$ as follows:

$$\begin{aligned} f_q(x_q, m) &= R_q(x_q, m)\,e^{i\Theta_q(x_q, m)} = q_m, \\ f_k(x_k, n) &= R_k(x_k, n)\,e^{i\Theta_k(x_k, n)} = k_n, \\ g(x_q, x_k, m-n) &= R_g(x_q, x_k, m-n)\,e^{i\Theta_g(x_q, x_k, m-n)}. \end{aligned}$$

Here, $R$ represents the magnitude (scaling factor) and $\Theta$ the phase (angle of rotation) of the respective functions. This complex representation elegantly encodes both magnitude and phase, enabling us to incorporate relative positional information. As we want to satisfy $\langle f_q(x_q, m), f_k(x_k, n) \rangle = g(x_q, x_k, m - n)$, we can derive the following equations for the radial and phase functions $R$ and $\Theta$:

$$\begin{aligned} R_q(x_q, m)\,R_k(x_k, n) &= R_g(x_q, x_k, n - m), \\ \Theta_k(x_k, n) - \Theta_q(x_q, m) &= \Theta_g(x_q, x_k, n - m). \end{aligned}$$

The first equation follows directly from taking magnitudes on both sides. The $\Theta$ equality can be derived as follows:

$$\begin{aligned} \langle f_q(x_q, m), f_k(x_k, n) \rangle &= \left(R_q(x_q, m)\,e^{i\Theta_q(x_q, m)}\right)^{*}R_k(x_k, n)\,e^{i\Theta_k(x_k, n)} \\ &= R_q(x_q, m)\,e^{-i\Theta_q(x_q, m)}R_k(x_k, n)\,e^{i\Theta_k(x_k, n)} \\ &= R_q(x_q, m)\,R_k(x_k, n)\,e^{i(\Theta_k(x_k, n)-\Theta_q(x_q, m))} \\ &= R_g(x_q, x_k, m-n)\,e^{i\Theta_g(x_q, x_k, m-n)} \\ &= g(x_q, x_k, m-n). \end{aligned}$$

In this representation, our initial conditions can be defined as:

$$q = \|q\|e^{i\theta_q} = R_q(x_q, 0)\,e^{i\Theta_q(x_q, 0)}, \qquad k = \|k\|e^{i\theta_k} = R_k(x_k, 0)\,e^{i\Theta_k(x_k, 0)}.$$

With the following relationships:

$$\Theta_q(x_q, 0) = \theta_q, \qquad \Theta_k(x_k, 0) = \theta_k,$$

and

$$R_q(x_q, 0) = \|q\|, \qquad R_k(x_k, 0) = \|k\|.$$

Now, setting $m = n$ and leveraging our knowledge about the initial conditions, we obtain:

$$\begin{aligned} R_q(x_q, m)\,R_k(x_k, m) &= R_g(x_q, x_k, 0) = R_q(x_q, 0)\,R_k(x_k, 0) = \|q\|\|k\|, \\ \Theta_k(x_k, m) - \Theta_q(x_q, m) &= \Theta_g(x_q, x_k, 0) = \Theta_k(x_k, 0) - \Theta_q(x_q, 0) = \theta_k - \theta_q. \end{aligned}$$

From this, we can observe that a possible straightforward solution for $R_q$, $R_k$, and $R_g$ is:

$$\begin{aligned} R_q(x_q, m) &= R_q(x_q, 0) = \|q\|, \\ R_k(x_k, n) &= R_k(x_k, 0) = \|k\|, \\ R_g(x_q, x_k, n - m) &= R_g(x_q, x_k, 0) = \|q\|\|k\|. \end{aligned}$$

From this, we can see that these solutions for $R_q$, $R_k$, and $R_g$ are, in fact, independent of any positional information.

Additionally, we observe that the difference $\Theta_q(x_q, m) - \Theta_k(x_k, m) = \theta_q - \theta_k$ does not depend on $x_q$, $x_k$, or the position $m$. Therefore, there must exist some function $\psi(m)$ such that:

$$\Theta_q(x_q, m) = \psi(m) + \theta_q, \qquad \Theta_k(x_k, m) = \psi(m) + \theta_k.$$

Now let's set $n = m + 1$, which means we get

$$\begin{aligned} \Theta_k(x_k, n) - \Theta_q(x_q, m) &= \psi(m+1) + \theta_k - \psi(m) - \theta_q \\ &= \psi(m+1) - \psi(m) + \theta_k - \theta_q \\ &= \Theta_g(x_q, x_k, 1) \\ \iff \psi(m+1) - \psi(m) &= \Theta_g(x_q, x_k, 1) + \theta_q - \theta_k. \end{aligned}$$

As $\Theta_g(x_q, x_k, 1) + \theta_q - \theta_k$ is a constant independent of $m$, the differences $\psi(m+1) - \psi(m)$ are all equal, so $\psi$ is an arithmetic progression and we can write it as

$$\psi(m) = m\theta + \gamma$$

with $\theta = \Theta_g(x_q, x_k, 1) + \theta_q - \theta_k \neq 0$ and $\gamma \in \mathbb{R}$.

This means we now have a way to write our encoding functions $f_q$ and $f_k$:

$$\begin{aligned} f_q(x_q, m) &= \|q\|e^{i(\theta_q + m\theta + \gamma)} = \|q\|e^{i\theta_q}e^{i(m\theta + \gamma)} = q\,e^{i(m\theta + \gamma)}, \\ f_k(x_k, n) &= \|k\|e^{i(\theta_k + n\theta + \gamma)} = \|k\|e^{i\theta_k}e^{i(n\theta + \gamma)} = k\,e^{i(n\theta + \gamma)}. \end{aligned}$$

As we don't make any assumptions about $f_q(x_q, 0)$ and $f_k(x_k, 0)$, we are free to choose

$$q = f_q(x_q, 0) = W_q x_q, \qquad k = f_k(x_k, 0) = W_k x_k.$$

Setting $\gamma = 0$, we therefore arrive at the following solution:

$$f_q(x_q, m) = (W_q x_q)\,e^{im\theta}, \qquad f_k(x_k, n) = (W_k x_k)\,e^{in\theta}.$$

Instead of using Euler's formula with complex numbers, we can express rotations in the attention mechanism using rotation matrices. Let's define a rotation matrix $R^d_{\Theta,m}$ which rotates 2D vectors by $m\theta$:

$$R^d_{\Theta,m} = \begin{pmatrix} \cos(m\theta) & -\sin(m\theta)\\ \sin(m\theta) & \cos(m\theta) \end{pmatrix}.$$

We can now rewrite $f_q$ and $f_k$ as follows:

$$f_q(x_q, m) = R^d_{\Theta,m}\,(W_q x_q), \qquad f_k(x_k, n) = R^d_{\Theta,n}\,(W_k x_k).$$

Using this matrix representation, the standard attention mechanism can be expressed as:

$$q_m^\top k_n = (R^d_{\Theta,m} W_q x_m)^\top (R^d_{\Theta,n} W_k x_n) = x_m^\top W_q^\top \big((R^d_{\Theta,m})^\top R^d_{\Theta,n}\big) W_k x_n = x_m^\top W_q^\top R^d_{\Theta,n-m} W_k x_n.$$

The last equality highlights a crucial point: the applied rotation depends solely on the relative distance between the tokens and not on the sentence length or their absolute positions.
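
To make this property concrete, here is a tiny PyTorch sanity check (not part of the model code we build later; the angle and vectors are arbitrary) showing that the score between rotated 2D query and key vectors depends only on their positional offset:

import torch

def rot(angle: torch.Tensor) -> torch.Tensor:
    # 2D rotation matrix for a given angle
    c, s = torch.cos(angle), torch.sin(angle)
    return torch.stack([torch.stack([c, -s]), torch.stack([s, c])])

theta = torch.tensor(0.3)        # arbitrary base angle
q = torch.tensor([1.0, 2.0])     # toy 2D query
k = torch.tensor([0.5, -1.0])    # toy 2D key

# positions (m, n) = (2, 5) and (10, 13) share the same offset n - m = 3
score_a = (rot(2 * theta) @ q) @ (rot(5 * theta) @ k)
score_b = (rot(10 * theta) @ q) @ (rot(13 * theta) @ k)
print(torch.allclose(score_a, score_b))  # True: only the offset matters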

Currently, we have only considered 2D vectors for $x_q$ and $x_k$. To generalize this to $x_{q,k} \in \mathbb{R}^d$, we divide the $d$-dimensional vector into $d/2$ two-dimensional subspaces. Each subspace is rotated based on the position of the token in the sequence.

Now, given an $x_m \in \mathbb{R}^d$, we formulate $R^d_{\Theta,m}$ as:

$$R^d_{\Theta,m} = \begin{pmatrix} \cos m\theta_1 & -\sin m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\ \sin m\theta_1 & \cos m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\ 0 & 0 & \cos m\theta_2 & -\sin m\theta_2 & \cdots & 0 & 0 \\ 0 & 0 & \sin m\theta_2 & \cos m\theta_2 & \cdots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & \cdots & \cos m\theta_{d/2} & -\sin m\theta_{d/2} \\ 0 & 0 & 0 & 0 & \cdots & \sin m\theta_{d/2} & \cos m\theta_{d/2} \end{pmatrix},$$

with pre-defined parameters:

$$\Theta = \{\theta_i = 10000^{-2(i-1)/d},\ i \in [1, 2, \dots, d/2]\}.$$

This approach takes every 2D pair of $(W_{q,k}x_m)$ and rotates it independently, based on the position $m$ of the token and the index $i$ of the pair within $x_m$ (which selects $\theta_i$).

[Figure: each 2D subvector of a token embedding is rotated by its own angle $m\theta_i$ (image source)]

The above figure illustrates this process. For each token embedding, we take successive 2D subvectors and rotate them by a predefined angle $\theta_i$, where $i$ is the index of the subvector within the token embedding, multiplied by $m$, the position of the token in the sequence.

Note that the above formulation of $R^d_{\Theta,m}$ is highly inefficient, as it constructs a mostly empty $d \times d$ matrix. Therefore, the paper introduces simple element-wise vector operations that are equivalent to multiplying $W_{q,k}x_m$ with the very sparse $R^d_{\Theta,m}$. It defines

$$R^d_{\Theta,m}W_{q,k}x_m = \begin{pmatrix} x_1 \\ x_2 \\ x_3 \\ x_4 \\ \vdots \\ x_{d-1} \\ x_d \end{pmatrix} \otimes \begin{pmatrix} \cos(m\theta_1) \\ \cos(m\theta_1) \\ \cos(m\theta_2) \\ \cos(m\theta_2) \\ \vdots \\ \cos(m\theta_{d/2}) \\ \cos(m\theta_{d/2}) \end{pmatrix} + \begin{pmatrix} -x_2 \\ x_1 \\ -x_4 \\ x_3 \\ \vdots \\ -x_d \\ x_{d-1} \end{pmatrix} \otimes \begin{pmatrix} \sin(m\theta_1) \\ \sin(m\theta_1) \\ \sin(m\theta_2) \\ \sin(m\theta_2) \\ \vdots \\ \sin(m\theta_{d/2}) \\ \sin(m\theta_{d/2}) \end{pmatrix}$$

Now this is the operation that we will implement later in PyTorch.
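
As a small preview, the following stand-alone sketch applies exactly this element-wise formula to a single toy vector at position $m$ (the helper name rope_elementwise is made up; it is not the RoPE class we implement later, which operates on batched tensors via complex numbers):

import torch

def rope_elementwise(x: torch.Tensor, m: int, base: float = 10000.0) -> torch.Tensor:
    # x: a single d-dimensional vector (d even), m: its position in the sequence
    d = x.shape[-1]
    theta = base ** (-2 * torch.arange(d // 2, dtype=torch.float32) / d)   # theta_i
    angles = (m * theta).repeat_interleave(2)     # m*theta_1, m*theta_1, m*theta_2, ...
    x_pairs = x.view(-1, 2)
    # build (-x2, x1, -x4, x3, ...)
    x_rot = torch.stack((-x_pairs[:, 1], x_pairs[:, 0]), dim=-1).reshape(d)
    return x * angles.cos() + x_rot * angles.sin()

x = torch.randn(8)
print(rope_elementwise(x, m=3).shape)  # torch.Size([8])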

Grouped-Query Attention (GQA)

Grouped Query Attention (GQA) is an optimization of the attention mechanism, aimed at reducing the memory and computational cost of standard multi-head attention. Instead of computing separate keys and values for every query head, GQA groups multiple query heads together so that they share a single key and value representation.

[Figure: grouped-query attention, with groups of query heads sharing a single key/value head (image source)]

The above figure illustrates this clearly. Each rectangle in the figure represents one attention head. In standard multi-head attention, each query head attends to its own corresponding key head. In GQA, however, multiple query heads are grouped together and attend to a single shared key head. The standard attention mechanism can be expressed as:

$$Q = XW^Q, \quad K = XW^K, \quad V = XW^V,$$

where:

  • $X \in \mathbb{R}^{n \times d_{\text{model}}}$ is the input sequence of token embeddings,
  • $W^Q, W^K, W^V \in \mathbb{R}^{d_{\text{model}} \times d_{hs}}$ are the weight matrices for the query, key, and value, respectively,
  • $Q, K, V \in \mathbb{R}^{n \times d_{hs}}$ are the computed query, key, and value matrices,
  • $n$ is the sequence length, and $d_{\text{model}}$ is the embedding dimension.

Here, $d_{hs} = \text{heads} \times d_h$, where $d_h$ is the head dimension. In practice, $d_{\text{model}} = \text{heads} \times d_h$, as we later reshape the $Q$, $K$, and $V$ tensors into the shape:

$$(b, \text{seq\_len}, \text{heads} \cdot d_h) \implies (b, \text{seq\_len}, \text{heads}, d_h) \implies (b, \text{heads}, \text{seq\_len}, d_h).$$

The fundamental difference in GQA lies in the shape of the weight matrices. In standard attention, the weight matrices are of size:

$$d_{\text{model}} \times d_{hs},$$

where $d_{hs} = \text{heads} \times d_h$.

In GQA, the weight matrices for $K$ and $V$ are of size:

$$d_{\text{model}} \times (\text{heads}_{\text{kv}} \cdot d_h),$$

where $\text{heads}_{\text{kv}}$ refers to the number of key-value heads. For a group size of $\text{group\_size} = 2$, as shown in the figure, the number of key-value heads is reduced to:

$$\text{heads}_{\text{kv}} = \frac{\text{heads}}{\text{group\_size}} = 4.$$

This results in the shape $(b, \text{seq\_len}, \text{heads}_{\text{kv}} \cdot d_h)$ for our $K$ and $V$, which we can then reshape into:

$$(b, \text{seq\_len}, \text{heads}_{\text{kv}} \cdot d_h) \implies (b, \text{seq\_len}, \text{heads}_{\text{kv}}, d_h) \implies (b, \text{heads}_{\text{kv}}, \text{seq\_len}, d_h).$$

Now, since $\text{heads}_{\text{kv}} \leq \text{heads}$, we observe that by sharing key and value heads across query heads, we can drastically reduce the memory footprint, most notably the size of the key-value cache during inference.
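
To make the shapes concrete, here is a minimal GQA sketch with made-up sizes (b, heads, and group_size are arbitrary; the post's actual MultiHeadGQAttention class follows later). Each key/value head is simply repeated group_size times so that every query head has a matching key/value head to attend to:

import torch

b, seq_len, d_model = 2, 16, 256
heads, group_size = 8, 2
d_h = d_model // heads
heads_kv = heads // group_size

# projection outputs: Q keeps all heads, K/V only heads_kv heads
q = torch.randn(b, seq_len, heads * d_h)
k = torch.randn(b, seq_len, heads_kv * d_h)
v = torch.randn(b, seq_len, heads_kv * d_h)

q = q.view(b, seq_len, heads, d_h).transpose(1, 2)       # (b, heads,    seq, d_h)
k = k.view(b, seq_len, heads_kv, d_h).transpose(1, 2)    # (b, heads_kv, seq, d_h)
v = v.view(b, seq_len, heads_kv, d_h).transpose(1, 2)

# every K/V head is shared by group_size query heads
k = k.repeat_interleave(group_size, dim=1)               # (b, heads, seq, d_h)
v = v.repeat_interleave(group_size, dim=1)

scores = q @ k.transpose(-2, -1) / d_h ** 0.5
out = scores.softmax(dim=-1) @ v
print(out.shape)  # torch.Size([2, 8, 16, 32])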

Preparing the Dataset

We will train our Llama model on the TinyStories dataset. TinyStories is a collection of short, narrative-style stories designed for training natural language processing (NLP) models, particularly language models like Llama. It consists of small, self-contained stories covering a variety of writing styles, including simple plots, dialogues, and events, which makes it well suited for training language models on story generation. This is a small example from the TinyStories dataset:

Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong. One day, Beep was driving in the park when he saw a big tree. The tree had many leaves that were falling. Beep liked how the leaves fall and wanted to play with them. Beep drove under the tree and watched the leaves fall on him. He laughed and beeped his horn. Beep played with the falling leaves all day. When it was time to go home, Beep knew he needed more fuel. He went to the fuel place and got more healthy fuel. Now, Beep was ready to go fast and play again the next day. And Beep lived happily ever after.

Now let's look at the code that prepares this dataset for training:

import logging

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import DataCollatorForLanguageModeling, GPT2TokenizerFast

from model import Llama3, ModelArgs

torch.set_printoptions(profile="full")

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s: %(message)s",
    handlers=[logging.StreamHandler()],  # Logs to stdout
)
logger = logging.getLogger()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

max_seq_len = 256
batch_size = 70
dataset = load_dataset("roneneldan/TinyStories", split="train").select(range(1000))
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})


def tokenize_function(examples):
    return tokenizer(examples["text"], max_length=max_seq_len, truncation=True)


dataset = dataset.filter(
    lambda example: example["text"] is not None and example["text"].strip() != ""
)

tokenized_dataset = dataset.map(
    tokenize_function, batched=True, remove_columns=["text"]
)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

dataloader = DataLoader(
    tokenized_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=lambda feature: Llama3.gen_labels(feature, tokenizer, data_collator),
)

This loads the TinyStories dataset and prepares it for training a language model. It tokenizes the text using GPT2TokenizerFast, ensuring that the sequences are of a fixed length (max_seq_len). The dataset is filtered to remove empty or invalid texts. A DataLoader is created using a custom collate function (Llama3.gen_labels) to generate labels for training, while efficiently handling batch processing with DataCollatorForLanguageModeling, which pads sequences and generates attention masks.

class Llama3(nn.Module):
  (...)
  @staticmethod
  def __build_masks(seq_len, attention_mask, device, position=0, training=False):
      causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(0)
      (...)
      if attention_mask is None:
          return causal
      attention_mask = attention_mask.unsqueeze(1).repeat(1, seq_len, 1).int()

      return (causal & attention_mask).int()

  @staticmethod
  def gen_labels(labels, tokenizer, data_collator, ignore_index=-100):
      batch = data_collator(labels)
      labels = batch["labels"]
      attention_mask = batch["attention_mask"]
      for i in range(labels.shape[0]):
          l = torch.roll(labels[i], -1)
          target_indices = (l == ignore_index).nonzero(as_tuple=True)[0]
          if len(target_indices) == 0:
              l[-1] = tokenizer.eos_token_id
          else:
              seq_len = len(l) - len(target_indices)
              l[-1] = ignore_index
              l[seq_len - 1] = tokenizer.eos_token_id

          labels[i] = l

      batch["labels"] = labels
      batch["attention_mask"] = Llama3.__build_masks(
          labels.shape[1], attention_mask, device, training=True
      )
  (...)

The gen_labels function is designed to prepare the labels for training by shifting them one token to the left and appending an EOS (End of Sequence) token at the end. At the start of the function, the labels are just a copy of the input. The challenge is to place the EOS token at the correct position, which requires determining the length of the unpadded sequence. If padding is present, the function places the EOS token at the end of the actual data, not on top of padding tokens. If no padding exists, the EOS token is placed at the very end of the sequence.

Once the labels are shifted and the EOS token is placed correctly, the function must handle the attention masks. Since DataCollatorForLanguageModeling provides an array of attention values for a single sequence, the function needs to convert this into a matrix that can be used for causal language modeling. This involves combining the attention mask with a causal mask, which is a lower triangular matrix, ensuring that the model only attends to previous tokens in the sequence and not to any future tokens.
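
To illustrate what __build_masks produces, here is a small stand-alone example with a toy padded sequence (the sizes are arbitrary):

import torch

seq_len = 5
# per-token padding mask from the collator: 1 = real token, 0 = padding
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])                     # (batch, seq_len)

causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(0)
padding = attention_mask.unsqueeze(1).repeat(1, seq_len, 1).bool()   # (batch, seq_len, seq_len)

combined = (causal & padding).int()
print(combined[0])
# row i shows which positions token i may attend to: only past tokens, padding blocked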

Llama3 - Model Architecture

Llama3 Architecture Overview

Now that we have everything in place, let's look at how the Llama3 architecture is structured. Below is an overview of its design:

[Figure: Llama3 architecture overview (image source)]

As we can see, Llama3 is conceptually very similar to the vanilla transformer architecture. Most importantly, Llama3, like other GPT-style models, is a decoder-only transformer. This means that we won't use the encoder or the cross-attention, which we covered in the previous blog post. Another notable change that stands out is the use of RMSNorm instead of LayerNorm for normalization.

$$\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{n} \sum_{i=1}^{n} x_i^2 + \epsilon}} \cdot \gamma, \qquad \text{LayerNorm}(x) = \frac{x - \mu}{\sigma + \epsilon} \cdot \gamma + \beta$$

While LayerNorm normalizes the input using its mean and standard deviation, RMSNorm only rescales by the root mean square of the elements, making it simpler and computationally cheaper. LayerNorm has a learnable scale $\gamma$ and bias $\beta$, whereas RMSNorm keeps only the scale $\gamma$; it avoids computing the mean and standard deviation altogether, which is more efficient in practice.
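
For reference, here is a minimal RMSNorm sketch (SimpleRMSNorm is just an illustrative name); it should agree with the nn.RMSNorm module used in the model code below up to numerical precision:

import torch
import torch.nn as nn

class SimpleRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))   # scale only, no bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.gamma * x / rms

x = torch.randn(2, 4, 8)
print(torch.allclose(SimpleRMSNorm(8)(x), nn.RMSNorm(8, eps=1e-6)(x)))  # True (up to numerics)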

Additionally, we observe that the positional encoding (RoPE) is applied directly to the queries and keys in the attention layer, rather than at the embedding level. The values, however, are not position encoded.

After that, the output of the attention layer is again fed into an FFN. This time, however, we use the so-called SwiGLU activation function. SwiGLU (Swish-Gated Linear Unit) is a variation of standard activation functions like ReLU and GELU. It is designed to combine the benefits of the Swish activation function and the Gated Linear Unit (GLU): a gating mechanism in which one linear projection is modulated by a Swish-activated projection, which helps improve model performance, particularly in transformer-based architectures.

The SwiGLU activation function is mathematically defined as:

$$\text{SwiGLU}(x) = \text{Swish}(xW + b) \otimes (xV + c)$$

Where:

  • $\text{Swish}(x)$ is the Swish activation function defined as:

    $$\text{Swish}(x) = x \cdot \sigma(x)$$

    with $\sigma(x)$ being the sigmoid function:

    $$\sigma(x) = \frac{1}{1 + e^{-x}}$$
  • $W, V$ are weight matrices, $b, c$ are biases, and $\otimes$ is the element-wise product. This is exactly the structure of the Gated Linear Unit, $\text{GLU}(x) = (xW + b) \otimes \sigma(xV + c)$, with the sigmoid gate replaced by Swish. In the Llama FFN the biases are dropped, as shown in the small sketch below.
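
Numerically, this bias-free gated form boils down to a few lines. The following is only a toy sketch with random weights, not the FFN module shown later:

import torch
import torch.nn.functional as F

d_model, hidden = 8, 16
x = torch.randn(2, d_model)
W, V = torch.randn(d_model, hidden), torch.randn(d_model, hidden)

# Swish(xW) gates the linear projection xV (bias-free, as in the Llama FFN)
out = F.silu(x @ W) * (x @ V)
print(out.shape)  # torch.Size([2, 16])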

Now that we have an overview of the Llama3 architecture, let's quickly discuss the implementation of the rotary positional encodings in our model. As mentioned earlier, the direct formulation of a large, mostly empty $R^d_{\Theta,m}$ matrix for positional encodings is inefficient and wasteful in terms of memory and computation. Therefore, we pursue a different approach to implement positional encodings.

Below we have the RoPE class, which implements the positional encoding.

class RoPE:
    @staticmethod
    def compute_freq(head_dim: int, seq_len: int, base: int = 10000):
        exp = -torch.arange(0, head_dim, 2).float() / head_dim  # exponents -2(i-1)/d; arange already steps by 2
        thetas = torch.pow(base, exp)
        m = torch.arange(0, seq_len).float()
        freq = torch.outer(m, thetas).float()
        freq_comp = torch.polar(torch.ones_like(freq), freq)
        return freq_comp

    @staticmethod
    def apply_rotary_embedding(x: torch.Tensor, freq_cis):
        # batch,seq_len,heads,d_k
        x_comp = torch.view_as_complex(x.float().reshape(*(x.shape[:-1]), -1, 2))
        freq_com = freq_cis.unsqueeze(0).unsqueeze(2)

        x_out = torch.view_as_real(x_comp * freq_com).reshape(*x.shape)
        return x_out.float()

The compute_freq function precomputes the $m\theta_i$ for all positions $m < \text{seq\_len}$ and all subspace indices $i \leq \text{head\_dim}/2$.

While the first two lines implement the following equation:

$$\Theta = \{\theta_i = 10000^{-2(i-1)/d},\ i \in [1, 2, \dots, d/2]\},$$

the third line implements the outer product between the array of positions and the $\theta_i$.

Lastly, we use torch.polar, which turns each matrix entry into a complex number of magnitude one, using the entry as its angle (phase). Mathematically, this whole process can be expressed as:

$$\psi = M\Theta^\top = \begin{bmatrix} m_1 \theta_1 & m_1 \theta_2 & \dots & m_1 \theta_{d/2} \\ m_2 \theta_1 & m_2 \theta_2 & \dots & m_2 \theta_{d/2} \\ \vdots & \vdots & \ddots & \vdots \\ m_{\text{max}} \theta_1 & m_{\text{max}} \theta_2 & \dots & m_{\text{max}} \theta_{d/2} \end{bmatrix}, \text{ with } M = \begin{bmatrix} m_1 \\ \vdots \\ m_{\text{max}} \end{bmatrix},\ \Theta = \begin{bmatrix} \theta_1 \\ \vdots \\ \theta_{d/2} \end{bmatrix}$$

After the torch.polar operation, which converts the matrix entries into their polar representation, it looks like this:

$$\text{polar}(\psi) = \begin{bmatrix} \cos(m_1 \theta_1) + i \sin(m_1 \theta_1) & \cos(m_1 \theta_2) + i \sin(m_1 \theta_2) & \dots & \cos(m_1 \theta_{d/2}) + i \sin(m_1 \theta_{d/2}) \\ \cos(m_2 \theta_1) + i \sin(m_2 \theta_1) & \cos(m_2 \theta_2) + i \sin(m_2 \theta_2) & \dots & \cos(m_2 \theta_{d/2}) + i \sin(m_2 \theta_{d/2}) \\ \vdots & \vdots & \ddots & \vdots \\ \cos(m_{\text{max}} \theta_1) + i \sin(m_{\text{max}} \theta_1) & \cos(m_{\text{max}} \theta_2) + i \sin(m_{\text{max}} \theta_2) & \dots & \cos(m_{\text{max}} \theta_{d/2}) + i \sin(m_{\text{max}} \theta_{d/2}) \end{bmatrix}$$

As we can see, the radial components are all $1$.

We will now shift our focus to the apply_rotary_embedding function which, given queries $Q$ and keys $K$, encodes them based on their position in the sentence.

Let's demonstrate this by focusing on $\text{polar}(\psi)_j$, the row that contains the angular components for position $j$. Let $x \in \{q, k\} \subset \mathbb{R}^{1 \times d_{\text{model}}}$ stand for one of the queries ($Q$) or keys ($K$). We transform it in the following way:

$$x = \begin{bmatrix} x_1 \\ x_2 \\ x_3 \\ x_4 \\ \vdots \\ x_{d-1} \\ x_d \end{bmatrix} \xrightarrow{\text{reshape}} \begin{bmatrix} \begin{bmatrix} x_1 & x_2 \end{bmatrix} \\ \begin{bmatrix} x_3 & x_4 \end{bmatrix} \\ \vdots \\ \begin{bmatrix} x_{d-1} & x_d \end{bmatrix} \end{bmatrix} \xrightarrow{\text{convert into complex}} \begin{bmatrix} x_1 + ix_2 \\ x_3 + ix_4 \\ \vdots \\ x_{d-1} + ix_d \end{bmatrix} = \hat{x}$$

Now we take this vector of complex entries and multiply it element-wise with $\text{polar}(\psi)_j$:

$$\begin{aligned} \hat{x} \otimes \text{polar}(\psi)_j &= \begin{bmatrix} x_1 + ix_2 \\ x_3 + ix_4 \\ \vdots \\ x_{d-1} + ix_d \end{bmatrix} \otimes \begin{bmatrix} \cos(j \theta_1) + i \sin(j \theta_1) \\ \cos(j \theta_2) + i \sin(j \theta_2) \\ \vdots \\ \cos(j \theta_{d/2}) + i \sin(j \theta_{d/2}) \end{bmatrix}\\ &= \begin{bmatrix} (x_1 + ix_2)(\cos(j \theta_1) + i \sin(j \theta_1)) \\ (x_3 + ix_4)(\cos(j \theta_2) + i \sin(j \theta_2)) \\ \vdots \\ (x_{d-1} + ix_d)(\cos(j \theta_{d/2}) + i \sin(j \theta_{d/2})) \end{bmatrix} \\ &= \begin{bmatrix} x_1 \cos(j \theta_1) - x_2 \sin(j \theta_1) + i\left(x_1 \sin(j \theta_1) + x_2 \cos(j \theta_1)\right) \\ x_3 \cos(j \theta_2) - x_4 \sin(j \theta_2) + i\left(x_3 \sin(j \theta_2) + x_4 \cos(j \theta_2)\right) \\ \vdots \\ x_{d-1} \cos(j \theta_{d/2}) - x_d \sin(j \theta_{d/2}) + i\left(x_{d-1} \sin(j \theta_{d/2}) + x_d \cos(j \theta_{d/2})\right) \end{bmatrix} \end{aligned}$$

We can now view this result again as a real matrix and reshape it into its original dimension:

$$\begin{aligned} &\xrightarrow{\text{into real}} \begin{bmatrix} \begin{bmatrix} x_1 \cos(j \theta_1) - x_2 \sin(j \theta_1) & x_1 \sin(j \theta_1) + x_2 \cos(j \theta_1) \end{bmatrix} \\ \begin{bmatrix} x_3 \cos(j \theta_2) - x_4 \sin(j \theta_2) & x_3 \sin(j \theta_2) + x_4 \cos(j \theta_2) \end{bmatrix} \\ \vdots \\ \begin{bmatrix} x_{d-1} \cos(j \theta_{d/2}) - x_d \sin(j \theta_{d/2}) & x_{d-1} \sin(j \theta_{d/2}) + x_d \cos(j \theta_{d/2}) \end{bmatrix} \end{bmatrix}\\ &\xrightarrow{\text{reshape}} \begin{bmatrix} x_1 \cos(j \theta_1) - x_2 \sin(j \theta_1) \\ x_1 \sin(j \theta_1) + x_2 \cos(j \theta_1) \\ x_3 \cos(j \theta_2) - x_4 \sin(j \theta_2) \\ x_3 \sin(j \theta_2) + x_4 \cos(j \theta_2) \\ \vdots \\ x_{d-1} \cos(j \theta_{d/2}) - x_d \sin(j \theta_{d/2}) \\ x_{d-1} \sin(j \theta_{d/2}) + x_d \cos(j \theta_{d/2}) \end{bmatrix}\\ &= \begin{pmatrix} x_1 \\ x_2 \\ x_3 \\ x_4 \\ \vdots \\ x_{d-1} \\ x_d \end{pmatrix} \otimes \begin{pmatrix} \cos(j\theta_1) \\ \cos(j\theta_1) \\ \cos(j\theta_2) \\ \cos(j\theta_2) \\ \vdots \\ \cos(j\theta_{d/2}) \\ \cos(j\theta_{d/2}) \end{pmatrix} + \begin{pmatrix} -x_2 \\ x_1 \\ -x_4 \\ x_3 \\ \vdots \\ -x_d \\ x_{d-1} \end{pmatrix} \otimes \begin{pmatrix} \sin(j\theta_1) \\ \sin(j\theta_1) \\ \sin(j\theta_2) \\ \sin(j\theta_2) \\ \vdots \\ \sin(j\theta_{d/2}) \\ \sin(j\theta_{d/2}) \end{pmatrix}\\ &= R^d_{\Theta,j}\,x \end{aligned}$$

As we can see, we arrive at our desired result $R^d_{\Theta,j}\,x$, using exactly the operations performed in our PyTorch code in apply_rotary_embedding.
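
We can verify this equivalence directly with the RoPE class above. The following check compares one 2D subspace (position $m = 2$, subspace $i = 0$, both chosen arbitrarily) of a randomly rotated tensor against the explicit rotation formula:

import torch
# assumes the RoPE class defined above is in scope

torch.manual_seed(0)
head_dim, seq_len = 8, 4
freq_cis = RoPE.compute_freq(head_dim, seq_len)           # (seq_len, head_dim/2), complex

x = torch.randn(1, seq_len, 1, head_dim)                  # (batch, seq, heads, d_k)
x_rope = RoPE.apply_rotary_embedding(x, freq_cis)

# explicit rotation of the first pair (x1, x2) at position m = 2
m, i = 2, 0
theta_i = 10000 ** (-2 * i / head_dim)
c = torch.cos(torch.tensor(m * theta_i))
s = torch.sin(torch.tensor(m * theta_i))
x1, x2 = x[0, m, 0, 0], x[0, m, 0, 1]
print(torch.allclose(x_rope[0, m, 0, 0], x1 * c - x2 * s))  # True
print(torch.allclose(x_rope[0, m, 0, 1], x1 * s + x2 * c))  # True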

Now that we have covered the most important details that distinguish the Llama3 architecture from a vanilla decoder-only transformer, we can actually take a look at the code of the Llama class.


import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# assumes the RoPE and KV_Cache classes from this post are defined in the same file


class FFN(nn.Module):
    def __init__(self, d_model=256, multiple_of=256):
        super(FFN, self).__init__()
        hidden = 4 * d_model
        hidden = int(2 * hidden / 3)

        hidden = multiple_of * ((hidden + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(d_model, hidden, bias=False)
        self.v = nn.Linear(d_model, hidden, bias=False)
        self.w2 = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x):
        return self.w2(nn.functional.silu(self.w1(x)) * self.v(x))

class MultiHeadGQAttention(nn.Module):
    flash = False

    def __init__(self, heads=4, d_model=256, group_size=2, max_seq_len=256):
        super(MultiHeadGQAttention, self).__init__()
        self.heads = heads
        self.d_model = d_model
        self.group_size = group_size

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model // group_size, bias=False)
        self.W_v = nn.Linear(d_model, d_model // group_size, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(
        self,
        q,
        k,
        v,
        mask,
        freq_cis,
        position=-1,
    ):
        d_k = self.d_model // self.heads
        bs, seq_len = q.shape[:2]
        q, k, v = self.W_q(q), self.W_k(k), self.W_v(v)

        q = q.view(
            q.shape[0], q.shape[1], self.heads, -1
        )  # (batch,seq_len,heads,d_k) or (batch,1,heads,d_k)
        k = k.view(
            k.shape[0], k.shape[1], self.heads // self.group_size, -1
        )  # (batch,seq_len,heads,d_k) or (batch,1,heads,d_k)
        v = v.view(
            v.shape[0], v.shape[1], self.heads // self.group_size, -1
        )  # (batch,seq_len,heads,d_k) or (batch,1,heads,d_k)
        q = RoPE.apply_rotary_embedding(q, freq_cis)
        k = RoPE.apply_rotary_embedding(k, freq_cis)
        if not self.training:
            k, v = self.kv_cache.update(k, v, position)


        q, k, v = (x.transpose(1, 2) for x in (q, k, v))

        if MultiHeadGQAttention.flash:
            q, k, v = (x.contiguous() for x in (q, k, v))
            # during cached decoding no mask is passed, so fall back to attn_mask=None
            attn_mask = (mask == 1).unsqueeze(1) if mask is not None else None
            output = (
                F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
                .transpose(1, 2)
                .reshape(bs, seq_len, -1)
            )
        else:
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
            if mask is not None:
                mask = mask.unsqueeze(1)
                scores = torch.masked_fill(scores, mask == 0, float("-inf"))
            attention = nn.functional.softmax(scores, dim=-1)
            output = (
                torch.matmul(attention, v)
                .transpose(1, 2)
                .contiguous()
                .view(bs, seq_len, -1)
            )

        output = self.W_o(output)
        return output

class DecoderLayer(nn.Module):
    def __init__(self, d_model=256, heads=4, group_size=2, max_seq_len=256):
        super(DecoderLayer, self).__init__()
        self.norm1 = nn.RMSNorm(d_model, eps=1e-6)
        self.norm2 = nn.RMSNorm(d_model, eps=1e-6)
        self.ffn = FFN(d_model=d_model)
        self.attention = MultiHeadGQAttention(
            heads=heads, d_model=d_model, group_size=group_size, max_seq_len=max_seq_len
        )

    def forward(self, x, tgt_causal_mask, pos, freqs_cis):
        x_norm = self.norm1(x)
        x = x + self.attention(
            x_norm, x_norm, x_norm, tgt_causal_mask, freqs_cis, position=pos
        )
        return x + self.ffn(self.norm2(x))

class Llama3(nn.Module):
    def __init__(self, params: ModelArgs):
        super(Llama3, self).__init__()
        self.tokenizer = params.tokenizer
        self.max_seq_len = params.max_seq_len
        self.ignore_index = params.ignore_index
        self.num_layers = params.num_layers
        self.layers = nn.ModuleList(
            [
                DecoderLayer(
                    d_model=params.d_model,
                    heads=params.heads,
                    group_size=params.group_size,
                    max_seq_len=params.max_seq_len,
                )
                for _ in range(params.num_layers)
            ]
        )
        self.embedding = nn.Embedding(params.vocab_size, params.d_model)
        self.norm = nn.RMSNorm(params.d_model, eps=1e-6)
        self.ffn = nn.Linear(params.d_model, params.vocab_size, bias=False)

        self.d_k = params.d_model // params.heads
        self.freqs_cis = RoPE.compute_freq(
            head_dim=params.d_model // params.heads, seq_len=params.max_seq_len
        )
        self.d_model = params.d_model

        MultiHeadGQAttention.flash = params.use_flash
    (...)
    def calc_loss(self, logits, labels):
        loss = nn.functional.cross_entropy(
            logits.view(-1, logits.shape[-1]),
            labels.view(-1),
            ignore_index=self.ignore_index,
        )
        return loss

    def __run_model(self, tgt, attention_mask, position=0):

        tgt_embed = self.embedding(tgt)
        freqs_cis = self.freqs_cis[position : position + tgt_embed.shape[1]].to(
            tgt.device
        )

        for i in range(self.num_layers):
            tgt_embed = self.layers[i](tgt_embed, attention_mask, position, freqs_cis)
        return self.ffn(self.norm(tgt_embed))

    def forward(self, tgt, attention_mask, labels):
        logits = self.__run_model(tgt, attention_mask)
        return self.calc_loss(logits, labels)

As this is rather a lot to take in, we will start with the forward method of the Llama3 class. It is easy to see that this simply calls __run_model, which selects the precomputed frequency components for the positional encoding based on the position of the current token (this matters for the KV cache, which we will cover in a bit). The selected $m_j\theta_i$ are then passed to the decoder layers, together with the token embeddings and masks.

The DecoderLayer may be of more interest:

class DecoderLayer(nn.Module):
    def __init__(self, d_model=256, heads=4, group_size=2, max_seq_len=256):
        super(DecoderLayer, self).__init__()
        self.norm1 = nn.RMSNorm(d_model, eps=1e-6)
        self.norm2 = nn.RMSNorm(d_model, eps=1e-6)
        self.ffn = FFN(d_model=d_model)
        self.attention = MultiHeadGQAttention(
            heads=heads, d_model=d_model, group_size=group_size, max_seq_len=max_seq_len
        )

    def forward(self, x, tgt_causal_mask, pos, freqs_cis):
        x_norm = self.norm1(x)
        x = x + self.attention(
            x_norm, x_norm, x_norm, tgt_causal_mask, freqs_cis, position=pos
        )
        return x + self.ffn(self.norm2(x))

As we can see from the architecture overview above, we first apply an RMSNorm before feeding the result into the MultiHeadGQAttention layer. Applying the skip connection and the FFN is pretty self-explanatory, so the most interesting layer is MultiHeadGQAttention. Let's look at that:

class MultiHeadGQAttention(nn.Module):
    flash = False

    def __init__(self, heads=4, d_model=256, group_size=2, max_seq_len=256):
        super(MultiHeadGQAttention, self).__init__()
        self.heads = heads
        self.d_model = d_model
        self.group_size = group_size

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model // group_size, bias=False)
        self.W_v = nn.Linear(d_model, d_model // group_size, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(
        self,
        q,
        k,
        v,
        mask,
        freq_cis,
        position=-1,
    ):
        d_k = self.d_model // self.heads
        bs, seq_len = q.shape[:2]
        q, k, v = self.W_q(q), self.W_k(k), self.W_v(v)

        q = q.view(
            q.shape[0], q.shape[1], self.heads, -1
        )  # (batch,seq_len,heads,d_k) or (batch,1,heads,d_k)
        k = k.view(
            k.shape[0], k.shape[1], self.heads // self.group_size, -1
        )  # (batch,seq_len,heads,d_k) or (batch,1,heads,d_k)
        v = v.view(
            v.shape[0], v.shape[1], self.heads // self.group_size, -1
        )  # (batch,seq_len,heads,d_k) or (batch,1,heads,d_k)
        q = RoPE.apply_rotary_embedding(q, freq_cis)
        k = RoPE.apply_rotary_embedding(k, freq_cis)
        if not self.training:
            k, v = self.kv_cache.update(k, v, position)


        q, k, v = (x.transpose(1, 2) for x in (q, k, v))

        if MultiHeadGQAttention.flash:
            q, k, v = (x.contiguous() for x in (q, k, v))
            # during cached decoding no mask is passed, so fall back to attn_mask=None
            attn_mask = (mask == 1).unsqueeze(1) if mask is not None else None
            output = (
                F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
                .transpose(1, 2)
                .reshape(bs, seq_len, -1)
            )
        else:
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
            if mask is not None:
                mask = mask.unsqueeze(1)
                scores = torch.masked_fill(scores, mask == 0, float("-inf"))
            attention = nn.functional.softmax(scores, dim=-1)
            output = (
                torch.matmul(attention, v)
                .transpose(1, 2)
                .contiguous()
                .view(bs, seq_len, -1)
            )

        output = self.W_o(output)
        return output

As we can see, this is standard functionality, with most parts being identical to the attention layer of a standard decoder-only transformer. However, the first difference lies in the weight matrices for $Q$, $K$, and $V$. Specifically, we give the $K$ and $V$ projections the output dimension $\frac{d_{\text{model}}}{\text{group\_size}} = \text{heads}_{\text{kv}} \cdot d_h$.

In the forward method, we use the usual reshaping mechanism to split the resulting tensors into heads, giving $Q$ the shape $(\text{batch}, \text{seq\_len}, \text{heads}, d_h)$ and $K$, $V$ the shape $(\text{batch}, \text{seq\_len}, \text{heads}_{\text{kv}}, d_h)$.

Next, we apply positional encoding to the queries and keys using the previously defined apply_rotary_embedding function. As mentioned earlier, freq_cis contains the terms $m_j\theta_i$, where $m_j$ represents the positions of the tokens fed into the model. This step is crucial, especially during inference, as the entire sequence of tokens is not fed into the model repeatedly. We will delve into this aspect in more detail later.

Following this, we perform the now well-known attention operation between the encoded queries, the encoded keys, and the plain values. Additionally, there is an optional path through scaled_dot_product_attention. This function provides a hardware- and memory-efficient implementation of the attention operation, commonly referred to as flash attention.

We will not go into much detail about flash attention in this post, but we may cover it in another post. For now, it suffices to say that this function implements attention more efficiently in terms of both hardware and memory usage.
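
If you want to convince yourself that the fused kernel computes the same thing as the manual path, here is a small comparison on random tensors (all shapes are arbitrary):

import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, seq_len, d_k = 1, 2, 5, 8
q, k, v = (torch.randn(b, h, seq_len, d_k) for _ in range(3))
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# manual path
scores = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)
scores = scores.masked_fill(~causal, float("-inf"))
manual = scores.softmax(dim=-1) @ v

# fused path (flash / memory-efficient kernels where available)
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)

print(torch.allclose(manual, fused, atol=1e-5))  # True (up to numerics)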

Now this already completes the model code of Llama! We will use the following training code, along with the data preparation code from above, to make our tiny model tell us some stories!


warmup = 100
lr = 1e-4
min_lr = 7e-5

args = ModelArgs(
    vocab_size=len(tokenizer),
    tokenizer=tokenizer,
    d_model=256,
    heads=16,
    group_size=2,
    num_layers=32,
    max_seq_len=max_seq_len,
    use_flash=True,
)


epochs = 2

model = Llama3(args).to(device)
# model.load_state_dict(torch.load('tiny_stories_2.pth'))
model = torch.compile(model)
model.train()
optim = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=0)
# lr_scheduler = InverseSquareRootLR(optim, warmup, lr, min_lr=min_lr)
logger.info(
    f"Param. count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
)
scaler = torch.GradScaler()

for epoch in range(epochs):
    for i, batch in enumerate(dataloader):

        inputs = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        optim.zero_grad()

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            loss = model(inputs, attention_mask, labels)

        scaler.scale(loss).backward()

        scaler.unscale_(optim)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(optim)
        scaler.update()
        # lr_scheduler.step()

        logger.info(
            f"Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}"
        )
        if (i + 1) % 1000 == 0:
            torch.save(model._orig_mod.state_dict(), "tiny_stories_50M.pth")

As you can see, there is nothing special about this code. We set up our optimizer, compile our Llama3 model, and feed the inputs, labels, and attention mask into the model. Additionally, we use the torch.amp package, which provides methods for mixed-precision training.

Now that we have set up our training code, trained, and subsequently saved our model to tiny_stories_50M.pth, we can actually use this model to sample some stories from it. So let's do that!

Optimizing Inference - The KV-Cache

One of the most important inference optimizations in the Llama architecture is the so-called key-value cache (KV cache). Its purpose becomes clear if we look at the figure below:

[Figure: attention matrix during autoregressive decoding; only the last row is needed to predict the next token (image source)]

As you know, during inference we sample the next token and append it to the sequence before feeding this new sequence back into the transformer to predict the token after it. But as you can see in the figure, to predict token 5 we only need the query of token 4 to multiply with the keys. In other words, to predict token 5 we only need the last row of the attention matrix. Thus, instead of feeding in the entire sequence of $n$ tokens to predict token $n+1$, we just feed in the $n$-th token. For the attention scores and the subsequent multiplication with $V$, however, we still need the keys and values of all previous tokens. This is exactly where the key-value cache (KV cache) comes in.

[Figure: the KV-cache mechanism, keys and values of previous tokens are stored and reused (image source)]

For every token we process, we store its keys and values in the KV cache for later use in these multiplications. By doing this, we save a significant number of attention computations, as we only need to compute the last row of the attention matrix. Nice!

Now let's look at the code for the KV cache:

class KV_Cache(nn.Module):
    def __init__(self, batch_size, seq_length, n_kv_heads, head_dim):
        super(KV_Cache, self).__init__()
        device = "cuda"
        cache_shape = (batch_size, seq_length, n_kv_heads, head_dim)
        self.cache_k = torch.zeros(cache_shape, device=device)
        self.cache_v = torch.zeros(cache_shape, device=device)

    def update(self, xk, xv, pos):
        bx, seq_len = xk.shape[:2]
        self.cache_k[:bx, pos : pos + seq_len] = xk
        self.cache_v[:bx, pos : pos + seq_len] = xv
        return self.cache_k[:bx, : pos + seq_len], self.cache_v[:bx, : pos + seq_len]

Here we reserve as much space as we could possibly need by constructing a zero tensor of shape (batch_size, max_seq_length, n_kv_heads, head_dim). Then, during the update, we pass in the current token position and save the keys and values into the cache before returning all the cached values up to the current point in time.
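
As a quick sanity check of the shapes, here is a hypothetical single decoding step with made-up sizes (note that the class pins its buffers to CUDA, so this snippet needs a GPU):

import torch
# assumes the KV_Cache class above is in scope; sizes below are made up

cache = KV_Cache(batch_size=1, seq_length=256, n_kv_heads=2, head_dim=32)

# pretend 10 positions are already cached and we now append one new token
xk_new = torch.randn(1, 1, 2, 32, device="cuda")   # (batch, 1, n_kv_heads, head_dim)
xv_new = torch.randn(1, 1, 2, 32, device="cuda")

k, v = cache.update(xk_new, xv_new, pos=10)
print(k.shape)  # torch.Size([1, 11, 2, 32]): everything cached so far plus the new token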

Here is how we will use it in the MultiHeadGQAttention class:

class MultiHeadGQAttention(nn.Module):
    flash = False
    def __init__(self, heads=4, d_model=256, group_size=2, max_seq_len=256):
        super(MultiHeadGQAttention, self).__init__()
        (...)
        self.kv_cache = KV_Cache(
            batch_size=4,
            seq_length=max_seq_len,
            n_kv_heads=self.heads // self.group_size,
            head_dim=d_model // heads,
        )

    def __repeat_kv(self, x):
        bs, slen, n_kv_heads, head_dim = x.shape
        return (
            x[:, :, :, None, :]
            .expand(bs, slen, n_kv_heads, self.group_size, head_dim)
            .reshape(bs, slen, n_kv_heads * self.group_size, head_dim)
        )
    def forward(
        self,
        q,
        k,
        v,
        mask,
        freq_cis,
        position=-1,
    ):
        (...)
        q = RoPE.apply_rotary_embedding(q, freq_cis)
        k = RoPE.apply_rotary_embedding(k, freq_cis)
        if not self.training:
            k, v = self.kv_cache.update(k, v, position)

        k = self.__repeat_kv(k)
        v = self.__repeat_kv(v)
        (...)

As you can see, we store the keys and values in the cache after they have been positionally encoded. Now let's look at the generate_kv function, which actually handles sampling from the model using the KV cache:

 def generate_kv(self, prompt, tokenizer, temp=1.0, top_p=None):
        device = "cuda"
        tokenized = tokenizer(prompt, max_length=self.max_seq_len, truncation=True)
        tokens = torch.tensor(tokenized["input_ids"]).unsqueeze(0).to(device)
        sampled_token = None
        sampled_tokens = tokens.squeeze(0).tolist()

        token_len = tokens.shape[1]

        mask = (
            torch.tril(torch.ones(token_len, token_len, dtype=torch.bool))
            .unsqueeze(0)
            .to(device)
        )
        i = 0
        for block in self.layers:
            block.attention.kv_cache = KV_Cache(
                batch_size=1,
                seq_length=self.max_seq_len,
                n_kv_heads=self.heads // self.group_size,
                head_dim=self.d_model // self.heads,
            )
        while i < self.max_seq_len:

            logits = self.__run_model(tokens, mask, position=i)[:, -1, :] / temp
            probabilities = F.softmax(logits.float(), dim=-1).squeeze()
            if top_p is not None:
                sorted_probs, sorted_indices = torch.sort(
                    probabilities, descending=True
                )
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                cutoff_index = torch.searchsorted(cumulative_probs, top_p, right=True)

                sorted_probs[cutoff_index + 1 :] = 0
                sorted_probs = sorted_probs / sorted_probs.sum()
                probabilities = torch.zeros_like(probabilities).scatter(
                    0, sorted_indices, sorted_probs
                )

                sampled_token = torch.multinomial(probabilities, num_samples=1)
            else:
                sampled_token = torch.multinomial(probabilities, num_samples=1)

            tokens = sampled_token.unsqueeze(0)
            if sampled_token.item() != tokenizer.eos_token_id:
                sampled_tokens.append(sampled_token.item())
            else:
                break
            i = len(sampled_tokens) - 1  # cache position of the token we just sampled and will feed next
            mask = None

        return tokenizer.decode(sampled_tokens)

Now, this function takes in a prompt like "Once a little green car". We tokenize this sequence and create a causal mask. Next, we initialize the KV cache for each attention layer. The generation process then runs until we either exceed max_seq_len or sample the special EOS token.

For sampling, we have two distinct strategies. The first one is temperature sampling, and the second is top-p sampling. Using temperature sampling, we divide the logits by the temperature variable. If the temperature is below 1, this sharpens the probability distribution after applying the softmax. A temperature greater than 1 has the opposite effect, leading to a more spread-out distribution and making previously less likely tokens more likely. Top-p sampling will only consider the most likely tokens whose cumulative probability sum is less than or equal to pp. Typical values for pp are 0.80.8 or 0.90.9.
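
The same temperature and top-p logic as in generate_kv, written as a stand-alone sketch on toy logits (vocabulary size and values are made up):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(10)        # toy logits over a vocabulary of 10 tokens
temp, top_p = 0.8, 0.9

probs = F.softmax(logits / temp, dim=-1)                        # temperature scaling
sorted_probs, sorted_idx = torch.sort(probs, descending=True)
cumulative = torch.cumsum(sorted_probs, dim=-1)
cutoff = int(torch.searchsorted(cumulative, torch.tensor(top_p), right=True))

sorted_probs[cutoff + 1:] = 0                                   # drop the unlikely tail
sorted_probs = sorted_probs / sorted_probs.sum()                # renormalize
probs = torch.zeros_like(probs).scatter(0, sorted_idx, sorted_probs)

token = torch.multinomial(probs, num_samples=1)
print(token.item())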

Now using this code we can sample from the model like this:

import torch
from model import Llama3, ModelArgs
from transformers import GPT2TokenizerFast
from colorama import Fore, Back, Style, init
init()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
max_seq_len = 256

args = ModelArgs(
    vocab_size=len(tokenizer),
    tokenizer=tokenizer,
    d_model=256,
    heads=4,
    group_size=2,
    num_layers=32,
    max_seq_len=max_seq_len,
    use_flash=True,
)
model = Llama3.from_pretrained("tiny_stories_50M.pth", args).to(device)
model.eval()

res = model.generate_kv("There once a little green car called Beep", tokenizer=tokenizer, top_p=0.65)

print(Fore.GREEN + res)

Running this will print:

There once a little green car called Beep. Beep loved to drive around the town, but one day it got lost. Beep drove around looking for its way back, but it couldn't find its way back. Beep was so sad and started to cry. Suddenly, Beep saw a big green truck. The truck said, "Don't worry Beep. I will help you find your way home." Beep was so happy and the truck said, "Let's go! I know the way!" So Beep and the truck drove together, until they reached the other side of the town. When they arrived, Beep said, "Thank you for helping me. You're a good friend." The truck said, "You're welcome. It was nice to help." Wep drove back home, feeling happy and relieved. It knew that the green truck was right. Beep and the truck were friends forever.

As you can see, this is already a pretty good story! Our model currently has about 50 million parameters and was trained on the TinyStories dataset for just one epoch. Due to the limitations of my GPU, which has only 8GB of VRAM, I had to use a smaller batch size. While this increases training times, I am planning to upgrade my hardware in the near future!

The complete code will be updated and made available on my GitHub page. This concludes my post about LLaMA3, and I hope you found it an enjoyable and informative read. If you spot any mistakes or have suggestions, please feel free to reach out to me!