Porting nanochat to Transformers: an AI modeling history lesson

There is a lot to learn about ML from nanochat, and even more to learn about the history of the transformer architecture.

Affiliation

Hugging Face

Published

Nov. 27, 2025

Recently we were working on helping students of the nanochat project share their models and discuss their learning on Hugging Face. In the process, we thought it would be useful if the model was integrated into the transformers library. This would let others use their nanochat models in loads of downstream libraries, like vLLM for inference or TRL for post-training.

You can now use nanochat models in transformers and tap into all those educational gains across the ecosystem. But along the way, we uncovered a further treasure trove of education about how canonical models relate to each other and the components they share. The lesson came from a simple teacher: class inheritance and the modular philosophy of transformers.

Now, let’s tuck into this deep dive on how nanochat relates to the lineage of transformer architectures.

What is nanochat?

On October 13th 2025, Andrej Karpathy unceremoniously dropped the nanochat repo into the unsuspecting AI world. To hype seekers, this was just a small and pretty average LLM. To ML devotees, this was nirvana: a raw, unadulterated chance to tinker, fiddle, and play with a transformer model defined in pure PyTorch. Nothing was hidden away in fancy torch methods or inherited from complex class structures. It was all there in a simple file.

Karpathy painstakingly implemented an end-to-end LLM system without most of the major libraries, even though real-world projects usually rely on transformers, tokenizers, datasets, trl, etc. This back-to-basics approach gives us the chance to genuinely learn and understand something from the ground up.

Personally, I found the process to be one of the most educational I can remember.

What is transformers?

Most of us know the transformers library as the backbone of modern machine learning, but if we dig a little deeper, it’s also a powerful educational resource.

If you don’t know… transformers is the de facto implementation of the modern AI models that share its name: transformers, the architecture behind model series like GPT, DeepSeek, and Claude. transformers is a special project because it contains implementations of all major open model architectures, and those architectures are modularized to reuse functionality from each other.

In general, scientists at AI research labs design, implement, and train their models in their framework of choice, be that torch, JAX, etc. When they come to share their open model with the community, they will open a PR on transformers and refactor their code to use relevant modules.

Because transformers contains most major model implementations, researchers have to inherit architectural components from other canonical models. In every sense, this is a ‘single source of truth’.

This practical feature of the library has an amazingly educational quality to it. We can read a model implementation as a series of references to other models that use the same architectural features. For example, when one model uses a certain type of RMSNorm, we can plainly see that it is the same implementation as another model’s, because it inherits that class entirely. Here is nanochat’s RMSNorm:

class NanoChatRMSNorm(Llama4TextL2Norm):
    pass

The transformers library then converts the modular_* implementation into a modeling_* implementation, which contains the complete torch native implementation:

class NanoChatRMSNorm(torch.nn.Module):
    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self._norm(x.float()).type_as(x)

    def extra_repr(self):
        return f"eps={self.eps}"

If we review a model in transformers, we can review both sides and learn from the math and literature of the model’s implementation. Due to the educational nature of nanochat, I thought that it was a perfect opportunity to explore this aspect of transformers and share what I learnt with students.

Why do we need nanochat in transformers?

It might seem counterintuitive to support an educational model like nanochat in a production-grade library like transformers. After all, we can see from nanochat’s benchmark scores that it does not rival state-of-the-art models like Qwen3, SmolLM3, Gemma3, or Olmo3. In fact, that’s the reason we think nanochat should be in transformers. Here’s what the community gains from its inclusion:

Firstly, as mentioned above, transformers shows us which modeling conventions Karpathy shares with other canonical implementations.

Secondly, because transformers is a standard within the ecosystem, it unlocks more downstream learning in post-training libraries, quantisation tools, inference libraries, and device integrations. In practical terms, nanochat students can build on top of transformers with projects like the inference, vLLM serving, and fine-tuning examples later in this post.

Finally, training AI models is expensive. Running nanochat’s speedrun.sh costs between $200 and $2k depending on the model size. That is little compared to the millions of dollars invested by frontier labs, but it is still a significant sum for students, who learn best by taking a few chances to fail and build experience.

In short, let’s unlock more opportunities for education!

The nanochat architecture

As described by Karpathy, nanochat uses an archetypal architecture that is common across the field, which makes it an excellent choice for an educational resource because folk get to learn from what works. The core model implementation demonstrates modern transformer architecture, with every design decision documented and justified.

The configuration uses a single complexity slider: depth. Set --depth=20 and everything else adjusts automatically. Model dimension equals depth × 64 (20 layers → 1,280 dimensions). Head dimension is fixed at 128, so the number of attention heads is model dimension ÷ 128, which works out to depth ÷ 2 (10 heads). This “aspect ratio philosophy” simplifies scaling: if you want a more capable model and have a bigger budget, just increase depth to 26 ($300 budget) or 30 ($1,000 budget).
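As a sketch of the scaling rule, here is a hypothetical helper (config_from_depth and DepthConfig are our names, not nanochat’s) that applies the rules above. It rejects depths whose model dimension is not a multiple of the head dimension rather than guessing how nanochat rounds those cases.

```python
from dataclasses import dataclass


@dataclass
class DepthConfig:
    depth: int       # number of transformer layers
    model_dim: int   # hidden size
    num_heads: int   # attention heads
    head_dim: int    # per-head dimension


def config_from_depth(depth: int, aspect_ratio: int = 64, head_dim: int = 128) -> DepthConfig:
    """Derive the model shape from the single depth knob described above."""
    model_dim = depth * aspect_ratio
    if model_dim % head_dim != 0:
        raise ValueError(f"model_dim={model_dim} is not divisible by head_dim={head_dim}")
    # With aspect_ratio=64 and head_dim=128 this reduces to depth // 2
    num_heads = model_dim // head_dim
    return DepthConfig(depth=depth, model_dim=model_dim, num_heads=num_heads, head_dim=head_dim)


for d in (20, 26, 30):
    print(config_from_depth(d))
```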

The architecture incorporates five key improvements over vanilla transformers. Let’s work through the components of this architecture and compare them across the two implementations:

Forward pass based on the Llama Architecture

The forward pass in nanochat handles both training and generation. It reads top to bottom: the input x is embedded, updated by each layer, normalized, and projected to logits by the head. During training, when targets are passed, a loss is calculated and returned instead of the logits themselves.

def forward(self, x, targets=None, loss_reduction='mean'):
    x = self.token_emb(x)
    for layer in self.layers:
        x = layer(x)
    x = self.ln_f(x)
    logits = self.lm_head(x)
    
    if targets is not None:
        loss = F.cross_entropy(
            logits.view(-1, self.vocab_size),
            targets.view(-1),
            ignore_index=-1,
            reduction=loss_reduction
        )
        return loss
    return logits

By returning loss directly when targets are provided, the training loop becomes trivial. No separate loss computation, no manual masking logic—just loss = model(inputs, targets) followed by loss.backward().

transformers has to make things a bit more complex to support a downstream ecosystem that uses logits in a broad spectrum of ways. The base model’s forward therefore stops at the final hidden states and returns a BaseModelOutputWithPast (hidden states plus the KV cache); the causal-LM head class projects those to logits and computes the loss only when labels are passed.
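Both conventions compute the same loss; they differ only in where the one-token shift happens. nanochat receives targets that are already shifted, while a transformers causal-LM head is given labels equal to input_ids and shifts them internally. The sketch below uses random logits in place of a real model and checks that the two views agree; the ignore indices (-100 and -1) mirror the two codebases.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, vocab_size = 2, 8, 50
input_ids = torch.randint(0, vocab_size, (batch, seq_len))
# Random stand-in for model output: one logit vector per input position
logits = torch.randn(batch, seq_len, vocab_size)

# transformers style: labels == input_ids, and position t is scored against token t + 1
labels = input_ids.clone()
hf_loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, vocab_size),
    labels[:, 1:].reshape(-1),
    ignore_index=-100,
)

# nanochat style: targets are pre-shifted; the last position has no target, so it is -1 (ignored)
targets = torch.full_like(input_ids, -1)
targets[:, :-1] = input_ids[:, 1:]
nanochat_loss = F.cross_entropy(
    logits.view(-1, vocab_size),
    targets.view(-1),
    ignore_index=-1,
)

print(hf_loss.item(), nanochat_loss.item())
```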

class NanoChatModel(LlamaModel):
    def __init__(self, config: NanoChatConfig):
        super().__init__(config)

        self.initial_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
        self.norm = NanoChatRMSNorm(eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position: torch.Tensor = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

        hidden_states = self.initial_norm(hidden_states)  # Additional norm before the layers
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )

Rotary Position Embeddings (RoPE)

Rotary Position Embeddings (RoPE) replace learned positional encodings by rotating query and key vectors using precomputed sin/cos frequencies:

def apply_rope(x, cos, sin):
    x1, x2 = x[..., ::2], x[..., 1::2]
    y1 = x1 * cos - x2 * sin
    y2 = x1 * sin + x2 * cos
    return torch.stack([y1, y2], dim=-1).flatten(-2)

In transformers, the rotary embeddings are implemented like so:

from ..llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaModel,
    LlamaPreTrainedModel,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
    eager_attention_forward,
)


class NanoChatRotaryEmbedding(LlamaRotaryEmbedding):
    pass


def rotate_half(x):
    """Rotates half the hidden dims of the input with flipped signs for NanoChat."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((x2, -x1), dim=-1)

NanoChatRotaryEmbedding inherits entirely from the original Llama series; the one RoPE-specific change is the sign inversion in rotate_half.
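The sign flip is easy to check numerically. The sketch below compares NanoChat’s rotate_half with Llama’s and with an explicit half-split rotation. The helper names, the RoPE base of 10000, and the tiny sizes are our assumptions, not taken from either codebase. NanoChat’s rotate_half is exactly the negation of Llama’s, which amounts to rotating each feature pair by −θ instead of +θ.

```python
import torch


def rotate_half_llama(x):
    # Standard Llama convention: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rotate_half_nanochat(x):
    # NanoChat convention from the modular file: (x1, x2) -> (x2, -x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x2, -x1), dim=-1)


def split_half_rotation(x, cos, sin):
    # Explicit rotation of feature pairs (i, i + d/2) by -theta
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return torch.cat((x1 * cos + x2 * sin, -x1 * sin + x2 * cos), dim=-1)


torch.manual_seed(0)
head_dim, seq_len = 8, 5
x = torch.randn(seq_len, head_dim)

# Standard RoPE frequencies (base 10000 is an assumption for this sketch)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, head_dim // 2)
cos_half, sin_half = angles.cos(), angles.sin()

# transformers-style tables repeat the frequencies across both halves
cos_full = torch.cat((cos_half, cos_half), dim=-1)
sin_full = torch.cat((sin_half, sin_half), dim=-1)

nanochat_style = x * cos_full + rotate_half_nanochat(x) * sin_full
explicit = split_half_rotation(x, cos_half, sin_half)
llama_with_negated_sin = x * cos_full + rotate_half_llama(x) * (-sin_full)

print(torch.allclose(nanochat_style, explicit, atol=1e-6))
print(torch.allclose(nanochat_style, llama_with_negated_sin, atol=1e-6))
```

Because queries and keys are rotated with the same convention, attention scores still depend only on the relative offset between positions; the sign just has to match the convention the weights were trained with.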

QK Normalization

NanoChat applies RMSNorm to queries and keys before computing attention to stabilize training.

In the original gpt.py, this is achieved via a functional norm helper applied directly inside the attention forward pass:

def norm(x):
    # Purely functional rmsnorm with no learnable params
    return F.rms_norm(x, (x.size(-1),))

class CausalSelfAttention(nn.Module):
    ...
    def forward(self, x, cos_sin, kv_cache):
        B, T, C = x.size()

        # Project the input to get queries, keys, and values
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)

        # Apply Rotary Embeddings to queries and keys to get relative positional encoding
        cos, sin = cos_sin
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
        q, k = norm(q), norm(k) # QK norm
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
        ...

In the modular transformers implementation, we see a fascinating mix of lineages. NanoChatRMSNorm inherits directly from Llama4TextL2Norm, while the attention mechanism inherits from Qwen3Attention. Qwen3 already normalizes queries and keys, but before RoPE; NanoChat swaps in its parameter-free norm and applies it after RoPE instead:


class NanoChatRMSNorm(Llama4TextL2Norm):
    pass

class NanoChatAttention(Qwen3Attention):
    def __init__(self, config: NanoChatConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        del self.sliding_window
        del self.layer_type

        self.q_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
        self.k_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # RoPE -> Norm (instead of usual Norm -> RoPE)
        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
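Normalizing queries and keys caps how large attention logits can get. With a weightless RMSNorm, every head vector ends up with an L2 norm of at most √head_dim, so each scaled dot product is bounded by √head_dim no matter how large the raw activations are. The sketch below checks that bound with random tensors, using a generic half-split rotation as a stand-in for RoPE (helper names are ours). It also shows a side effect of the norm having no learnable weight: because a rotation preserves the root-mean-square, RoPE → Norm and Norm → RoPE agree up to floating-point error. The ordering only becomes material once the norm carries a learnable per-channel weight, as Qwen3’s query and key norms do.

```python
import math

import torch


def rms_norm(x, eps=1e-6):
    # Parameter-free RMSNorm, the same math as NanoChatRMSNorm
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


def rotate(x, cos, sin):
    # Half-split rotation of feature pairs (i, i + d/2); like RoPE, it preserves vector norms
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return torch.cat((x1 * cos + x2 * sin, -x1 * sin + x2 * cos), dim=-1)


torch.manual_seed(0)
head_dim, seq_len = 64, 16
q = torch.randn(seq_len, head_dim) * 50  # deliberately large activations
k = torch.randn(seq_len, head_dim) * 50
angles = torch.rand(seq_len, head_dim // 2) * 2 * math.pi
cos, sin = angles.cos(), angles.sin()

# RoPE -> Norm, as in NanoChatAttention
q_n = rms_norm(rotate(q, cos, sin))
k_n = rms_norm(rotate(k, cos, sin))
scores = (q_n @ k_n.T) / math.sqrt(head_dim)

# Each normalized row has L2 norm at most sqrt(head_dim), so |score| <= sqrt(head_dim)
print(scores.abs().max().item(), "<=", math.sqrt(head_dim))

# Without a learnable weight, the norm commutes with the rotation
print(torch.allclose(rms_norm(rotate(q, cos, sin)), rotate(rms_norm(q), cos, sin), atol=1e-5))
```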

Untied Weights

Karpathy’s implementation deliberately unties the weights between the token embedding and the language model head to provide the model with more flexibility. In gpt.py, these are initialized as two completely separate modules:

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # ... (rest of init)

In the modular implementation, we inherit from Gemma2ForCausalLM. By inheriting the class, we pull in all the necessary machinery for causal generation, while the configuration object (defined elsewhere) ensures the weights remain untied. Gemma 2 itself ties its weights by default, so we inherit from it primarily for code-structure alignment and logit softcapping support; the tie_word_embeddings config flag controls the behavior, with _tied_weights_keys defining the mapping when tying is applied:

class NanoChatForCausalLM(Gemma2ForCausalLM):
    def forward(self, **super_kwargs) -> CausalLMOutputWithPast:
        super().forward(**super_kwargs)
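Tying weights is just parameter sharing: a tied head reuses the embedding matrix, while an untied head owns a second vocab_size × hidden_size matrix. The toy sketch below (arbitrary sizes, not nanochat’s; build and count_unique are hypothetical helpers) counts unique parameters in both setups, which makes the cost of the extra flexibility concrete.

```python
import torch.nn as nn

vocab_size, hidden_size = 1000, 64


def build(tied: bool) -> nn.ModuleDict:
    model = nn.ModuleDict({
        "wte": nn.Embedding(vocab_size, hidden_size),
        "lm_head": nn.Linear(hidden_size, vocab_size, bias=False),
    })
    if tied:
        # Share one Parameter: the head's weight is the embedding table itself
        model["lm_head"].weight = model["wte"].weight
    return model


def count_unique(model: nn.Module) -> int:
    # Module.parameters() yields each shared Parameter only once
    return sum(p.numel() for p in model.parameters())


untied, tied = build(tied=False), build(tied=True)
print("untied:", count_unique(untied))
print("tied:  ", count_unique(tied))
```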

ReLU² Activation

The original implementation replaces the standard GELU activation with ReLU², which is simply ReLU squared. This provides a faster alternative without performance loss. In gpt.py, this is hardcoded into the MLP block:

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square()
        x = self.c_proj(x)
        return x

In the modular file, we see another surprising inheritance: CLIPMLP. The CLIP architecture uses a structure that fits our needs perfectly, so we inherit the structural definition from CLIP and let the configuration drive the specific activation function (ReLU2):

class NanoChatMLP(CLIPMLP):
    def __init__(self, config):
        super().__init__(config)
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
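The port amounts to a renaming. Below is a sketch with two hypothetical classes: one follows the gpt.py layout (c_fc and c_proj), the other the CLIP-style layout (fc1, activation, fc2) with biases removed, as in NanoChatMLP. In transformers the activation is looked up from the config rather than hard-coded; here we plug ReLU² in directly. Copying c_fc into fc1 and c_proj into fc2 gives identical outputs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NanochatStyleMLP(nn.Module):
    """gpt.py layout: c_fc -> ReLU^2 -> c_proj, no biases."""

    def __init__(self, n_embd: int):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, 4 * n_embd, bias=False)
        self.c_proj = nn.Linear(4 * n_embd, n_embd, bias=False)

    def forward(self, x):
        return self.c_proj(F.relu(self.c_fc(x)).square())


class ClipStyleMLP(nn.Module):
    """CLIP-style layout: fc1 -> activation -> fc2, with biases removed and ReLU^2 plugged in."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.activation_fn = lambda t: F.relu(t).square()
        self.fc1 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.fc2 = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        return self.fc2(self.activation_fn(self.fc1(x)))


torch.manual_seed(0)
n_embd = 32
original = NanochatStyleMLP(n_embd)
ported = ClipStyleMLP(n_embd, 4 * n_embd)

# For this toy check, copying weights under new names is the whole port: c_fc -> fc1, c_proj -> fc2
with torch.no_grad():
    ported.fc1.weight.copy_(original.c_fc.weight)
    ported.fc2.weight.copy_(original.c_proj.weight)

x = torch.randn(3, n_embd)
print(torch.allclose(original(x), ported(x)))
```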

Multi-Query Attention (MQA)

NanoChat uses Multi-Query Attention (MQA) to reduce the memory footprint of the KV cache, using 6 query heads but only 1 key/value head (in the default config). This is a common configuration for smaller models like nanochat.

In gpt.py, this logic is handled by passing distinct head counts and relying on PyTorch’s functional attention to handle the broadcasting (or explicitly handling it during inference):

class CausalSelfAttention(nn.Module):
    # ...
    def forward(self, x, cos_sin, kv_cache):
        # ...
        # Attention: queries attend to keys/values autoregressively. A few cases to handle:
        enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
        if kv_cache is None or Tq == Tk:
            # During training (no KV cache), attend as usual with causal attention
            # And even if there is KV cache, we can still use this simple version when Tq == Tk
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
        elif Tq == 1:
            # During inference but with a single query in this forward pass:
            # The query has to attend to all the keys/values in the cache
            y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
        else:
            # During inference AND we have a chunk of queries in this forward pass:
            # First, each query attends to all the cached keys/values (i.e. full prefix)
            attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
            prefix_len = Tk - Tq
            if prefix_len > 0: # can't be negative but could be zero
                attn_mask[:, :prefix_len] = True
            # Then, causal attention within this chunk
            attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
        # ...

In modular_nanochat.py, we don’t need to write this logic at all. As seen in the QK Normalization section above, NanoChatAttention inherits from Qwen3Attention. The Qwen3 implementation is robust and fully supports GQA/MQA out of the box. By using this parent class, we get production-grade attention implementation “for free,” allowing us to focus solely on the unique normalizations required by NanoChat.
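To see why a generic parent class can cover MQA, here is a small sketch comparing two equivalent ways of letting six query heads share a single key/value head. The first expands K and V with a repeat_kv-style helper (our simplified stand-in for the helper transformers uses in its eager attention path) before calling PyTorch’s fused attention; the second broadcasts the single K/V head directly inside an explicit causal attention. Sizes are arbitrary, and the last line shows the KV cache saving per token and per layer.

```python
import math

import torch
import torch.nn.functional as F


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (B, H_kv, T, D) -> (B, H_kv * n_rep, T, D): each KV head is shared by n_rep query heads
    return x.repeat_interleave(n_rep, dim=1)


torch.manual_seed(0)
B, T, D = 1, 5, 16
n_head, n_kv_head = 6, 1  # MQA: six query heads share one key/value head
q = torch.randn(B, n_head, T, D)
k = torch.randn(B, n_kv_head, T, D)
v = torch.randn(B, n_kv_head, T, D)
n_rep = n_head // n_kv_head

# Path 1: expand K/V to match the query heads, then use PyTorch's fused attention
out_sdpa = F.scaled_dot_product_attention(q, repeat_kv(k, n_rep), repeat_kv(v, n_rep), is_causal=True)

# Path 2: explicit causal attention; the single K/V head broadcasts across query heads
scores = q @ k.transpose(-2, -1) / math.sqrt(D)  # (B, n_head, T, T)
causal = torch.ones(T, T, dtype=torch.bool).tril()
scores = scores.masked_fill(~causal, float("-inf"))
out_manual = scores.softmax(dim=-1) @ v

print(torch.allclose(out_sdpa, out_manual, atol=1e-5))

# KV cache elements per token per layer: keys plus values for every KV head
print("MHA cache:", 2 * n_head * D, "MQA cache:", 2 * n_kv_head * D)
```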

Conclusion

It’s clear that Andrej Karpathy’s implementation offers far more to learn from than the transformers version, which inherits almost entirely from existing models or features. That said, we can still take more away from the inherited modular modeling implementation. Models like Llama, Llama4, Gemma2, Qwen3, and CLIP are all reused to create a genuinely canonical implementation of a transformer.

Ok. Let’s cut the philosophy and see what we can do with nanochat in transformers.

Example 1: Inference on nanochat in Transformers

This first bonus tutorial shows how to run basic inference in transformers:

import torch
from transformers import AutoTokenizer, NanoChatForCausalLM

tokenizer = AutoTokenizer.from_pretrained("nanochat-students/nanochat-d20")
model = NanoChatForCausalLM.from_pretrained("nanochat-students/nanochat-d20")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

prompt = "Hello, how are you?"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
inputs.pop("token_type_ids", None)
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Inference in transformers with vLLM

Next, let’s use transformers as a backend for vLLM to serve the model for optimized inference.

We’ll need to install transformers from main:

pip install git+https://github.com/huggingface/transformers.git@main

Then we can start a vLLM server like so:

vllm serve nanochat-students/nanochat-d20 --enforce-eager --revision refs/pr/1

Finally, we can call the server like so:

curl -X POST "http://localhost:8000/v1/completions" \
	-H "Content-Type: application/json" \
	--data '{
		"model": "nanochat-students/nanochat-d20",
		"prompt": "Once upon a time,",
		"max_tokens": 512,
		"temperature": 0.5
	}'
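If you would rather call the server from Python, the same request can be built with the standard library alone. This is a sketch: the helper names are ours, and it assumes the server started above is listening on localhost:8000 and exposes the OpenAI-compatible /v1/completions route used in the curl example.

```python
import json
import urllib.request

BASE_URL = "http://localhost:8000"  # assumes the `vllm serve` command above is running locally


def build_completion_request(prompt: str, max_tokens: int = 512, temperature: float = 0.5) -> urllib.request.Request:
    """Build the same POST request as the curl example."""
    payload = {
        "model": "nanochat-students/nanochat-d20",
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    return urllib.request.Request(
        f"{BASE_URL}/v1/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def complete(prompt: str) -> str:
    """Send the request and return the generated text (requires a running server)."""
    with urllib.request.urlopen(build_completion_request(prompt)) as response:
        body = json.load(response)
    # OpenAI-style completion responses put the text under choices[0].text
    return body["choices"][0]["text"]
```

With the server running, print(complete("Once upon a time,")) should print a continuation of the prompt.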

Inference on your trained nanochat weights

Let’s say you’ve followed the nanochat repo and used it to train a model. Then you can add transformers compatibility to your model and use it in other libraries.

  1. Download any nanochat checkpoint from the Hub. Here we use Karpathy’s, but this could be yours:
hf download karpathy/nanochat-d34 --local-dir nanochat-d34
  2. Convert the checkpoint to transformers format using the conversion script:
uv run \
--with "transformers @ git+https://github.com/huggingface/transformers.git@main" \
--with "tiktoken>=0.12.0" \
https://raw.githubusercontent.com/huggingface/transformers/main/src/transformers/models/nanochat/convert_nanochat_checkpoints.py \
--input_dir ./nanochat-d34 \
--output_dir ./nanochat-d34-hf
  3. (Optional) Upload the converted checkpoint to the Hugging Face Hub:
hf upload <username>/nanochat-d34 nanochat-d34-hf
  4. As above, you can generate with your model in transformers:
import torch
from transformers import AutoTokenizer, NanoChatForCausalLM

tokenizer = AutoTokenizer.from_pretrained("./nanochat-d34-hf")
model = NanoChatForCausalLM.from_pretrained("./nanochat-d34-hf")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

prompt = "Hello, how are you?"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
inputs.pop("token_type_ids", None)
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Example 2: Supervised Fine-tuning in torch

Supervised Fine-Tuning (SFT) is the process of adapting a pre-trained language model to follow instructions by training it on curated input-output pairs. Unlike pre-training, which learns general language patterns from massive text corpora, SFT teaches the model how to respond: following a specific format, tone, or task structure.

In this tutorial, we’ll fine-tune the NanoChat model using pure PyTorch, giving you complete visibility into every step of the training process.

Want a production-ready solution? TRL is Hugging Face’s reinforcement learning library with battle-tested SFT implementations. Check out the SFT notebook to use it with your nanochat checkpoint.

Import model and tokenizer

We start by loading the pre-trained NanoChat model and its tokenizer. The revision parameter points to a specific model version—useful when models are updated frequently or you want reproducible results.

import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup


model_id = "karpathy/nanochat-d32"
revision = "refs/pr/1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=revision,
    torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
).to(device)

We use bfloat16 precision on GPU to reduce memory usage while maintaining training stability. On CPU, we fall back to float32 for compatibility.

Setup LoRA

Training all 1.8B parameters would require significant GPU memory and risk catastrophic forgetting. Instead, we use LoRA (Low-Rank Adaptation) which freezes the original weights and injects small trainable matrices into specific layers.

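The key parameters are the rank r, the scaling factor lora_alpha, the dropout rate, and the target_modules that receive adapters: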

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=1,
    lora_alpha=2,
    lora_dropout=0.00,
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "fc1", "fc2"]
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
trainable params: 1,179,648 || all params: 1,880,227,840 || trainable%: 0.0627

With LoRA, we’re only training 0.06% of the model’s parameters—just over 1 million weights instead of 1.8 billion. This makes fine-tuning feasible on consumer hardware.
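Under the hood, LoRA keeps the pre-trained weight W frozen and learns a low-rank update, so a layer computes W x + (lora_alpha / r) · B A x. The minimal sketch below shows that idea for a single projection; it is not peft’s implementation, and the class name is hypothetical. With r=1 and lora_alpha=2 as in the config above, a 2048 × 2048 projection (the width the depth × 64 rule gives for d32, used here only as an illustration) gains just 2 × 2048 trainable weights, and zero-initializing B means the wrapped layer starts out identical to the original.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Minimal LoRA wrapper: y = W x + (alpha / r) * B(A(x)), with W frozen."""

    def __init__(self, base: nn.Linear, r: int = 1, alpha: float = 2.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # the pre-trained weight stays frozen
        self.lora_A = nn.Linear(base.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base.out_features, bias=False)
        nn.init.kaiming_uniform_(self.lora_A.weight, a=5 ** 0.5)
        nn.init.zeros_(self.lora_B.weight)  # zero init: the adapter starts as a no-op
        self.scaling = alpha / r

    def forward(self, x):
        return self.base(x) + self.scaling * self.lora_B(self.lora_A(x))


torch.manual_seed(0)
base = nn.Linear(2048, 2048, bias=False)
layer = LoRALinear(base, r=1, alpha=2.0)

trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
total = sum(p.numel() for p in layer.parameters())
print(f"trainable: {trainable} / {total}")

x = torch.randn(4, 2048)
# Identical to the frozen layer until training updates lora_B
print(torch.allclose(layer(x), base(x)))
```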

Demo the model

Before training, let’s verify the model works correctly. We’ll test two modes: raw text completion and chat-formatted generation.

Plain autoregressive completion continues text naturally:

print("=" * 80)
print("TEST 1: Plain Autoregressive Prompt")
print("=" * 80)
prompt = "The Eiffel Tower stands in Paris and"
test_inputs = tokenizer(prompt, return_tensors="pt").to(device)


with torch.no_grad():
    test_outputs = model.generate(
        **test_inputs,
        max_new_tokens=64,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )

generated_tokens = test_outputs[0, test_inputs["input_ids"].shape[1] :]
print(f"Prompt: {prompt}")
print(f"\nGenerated: {tokenizer.decode(generated_tokens, skip_special_tokens=True)}")
print("=" * 80)
================================================================================
TEST 1: Plain Autoregressive Prompt
================================================================================
Prompt: The Eiffel Tower stands in Paris and

Generated:  is one of the most famous landmarks in the world. It is located on the Champ de Mars in the heart of the city. The tower was built for the 1889 World's Fair. It was designed by the French engineer Gustave Eiffel and took 2 years to build. The Eiffel Tower stands 324 meters
================================================================================

The chat template wraps the input in special tokens that the model learned during instruction tuning:

print("=" * 80)
print("TEST 2: Chat Template")
print("="*80)
conversation = [
    {"role": "user", "content": "What is the capital of France?"},
]

inputs = tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(device)

print(f"Formatted prompt: {tokenizer.decode(inputs['input_ids'][0])}")
print(f"Input IDs: {inputs['input_ids'][0].tolist()}")

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=64,
        do_sample=False
    )

generated_tokens = outputs[0, inputs["input_ids"].shape[1] :]
print(f"\nGenerated: {tokenizer.decode(generated_tokens)}")
print("=" * 80)
================================================================================
TEST 2: Chat Template
================================================================================
Formatted prompt: <|bos|><|user_start|>What is the capital of France?<|user_end|><|assistant_start|>
Input IDs: [65527, 65528, 1442, 309, 261, 3429, 281, 4215, 63, 65529, 65530]

Generated: The capital of France is Paris.<|assistant_end|>
================================================================================

Notice the special tokens: <|bos|>, <|user_start|>, <|assistant_start|>, etc. These delimiters help the model understand conversation structure.
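To make the format concrete, here is a hypothetical re-implementation of the prompt layout shown in the output above. It is not the tokenizer’s actual chat template: the layout of completed assistant turns is inferred from the <|assistant_end|> token in the generated reply, and the real template may handle more roles and edge cases.

```python
def render_chat(messages, add_generation_prompt=False):
    """Hypothetical sketch of the nanochat prompt layout (not the tokenizer's real template)."""
    parts = ["<|bos|>"]
    for message in messages:
        if message["role"] == "user":
            parts.append(f"<|user_start|>{message['content']}<|user_end|>")
        elif message["role"] == "assistant":
            parts.append(f"<|assistant_start|>{message['content']}<|assistant_end|>")
        else:
            raise ValueError(f"unsupported role: {message['role']}")
    if add_generation_prompt:
        # Leave the assistant turn open so the model writes the reply
        parts.append("<|assistant_start|>")
    return "".join(parts)


conversation = [{"role": "user", "content": "What is the capital of France?"}]
print(render_chat(conversation, add_generation_prompt=True))
```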

Dataset

For SFT, we need high-quality instruction-response pairs. We’ll use OpenThoughts, a dataset designed for training models to reason step-by-step before answering.

raw_dataset = load_dataset("HuggingFaceTB/smoltalk2", "SFT", split="OpenThoughts3_1.2M_think")
splits = raw_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = splits["train"]
eval_dataset = splits["test"]

Raw examples contain message lists that need to be converted into token sequences. The apply_chat_template method handles this conversion, inserting the appropriate special tokens.

We limit examples to 2048 tokens and cap the dataset size to make training tractable on limited hardware:

max_length = 2048
max_train_examples = 20000
max_eval_examples = 1000

def format_example(example):
    formatted = tokenizer.apply_chat_template(
        example["messages"],
        add_generation_prompt=False,
        truncation=True,
        max_length=max_length,
        padding=False,
        return_dict=True,
        return_tensors="pt",
    )
    return {
        "input_ids": formatted["input_ids"][0].tolist(),
        "attention_mask": formatted["attention_mask"][0].tolist(),
    }


train_dataset = train_dataset.select(range(min(len(train_dataset), max_train_examples)))
train_dataset = train_dataset.map(format_example, remove_columns=train_dataset.column_names)

eval_dataset = eval_dataset.select(range(min(len(eval_dataset), max_eval_examples)))
eval_dataset = eval_dataset.map(format_example, remove_columns=eval_dataset.column_names)

The collate function pads variable-length sequences to the same length within each batch and creates the labels tensor for loss computation:

def collate_fn(batch):
    batch_dict = {
        "input_ids": [record["input_ids"] for record in batch],
        "attention_mask": [record["attention_mask"] for record in batch],
    }
    padded = tokenizer.pad(batch_dict, padding=True, return_tensors="pt")
    labels = padded["input_ids"].clone()
    labels[padded["attention_mask"] == 0] = -100
    padded["labels"] = labels
    return padded

Setting padding tokens to -100 in labels tells the cross-entropy loss to ignore them, since we don’t want to penalize the model for not predicting padding.

Training

These hyperparameters control the training dynamics. We use conservative values that work well across different hardware:

train_batch_size = 2
eval_batch_size = 2
num_epochs = 1
gradient_accumulation_steps = 4
learning_rate = 1e-5
weight_decay = 0.0
warmup_ratio = 0.03
logging_frequency = 10

With the batch sizes defined, we can build the data loaders:

train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, collate_fn=collate_fn)
eval_loader = DataLoader(eval_dataset, batch_size=eval_batch_size, shuffle=False, collate_fn=collate_fn)

Key configuration choices include a conservative learning rate (1e-5). LoRA setups often use larger learning rates than full fine-tuning, because only the small adapter matrices are trained, so treat this value as a cautious starting point. Additionally, gradient accumulation gives a larger effective batch size (2 × 4 = 8 sequences per optimizer step), which helps when training on GPUs with limited memory.
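Gradient accumulation works because gradients add up across backward() calls. Here is a toy check, using a single linear layer and an MSE loss rather than the language model: dividing each micro-batch loss by the number of accumulation steps reproduces the gradient of one large batch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(8, 1)
x, y = torch.randn(8, 8), torch.randn(8, 1)
accumulation_steps = 4

# One big batch of 8 examples
model.zero_grad()
F.mse_loss(model(x), y).backward()
big_batch_grad = model.weight.grad.clone()

# Four micro-batches of 2, each loss divided by the number of accumulation steps
model.zero_grad()
for xb, yb in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    (F.mse_loss(model(xb), yb) / accumulation_steps).backward()
accumulated_grad = model.weight.grad.clone()

print(torch.allclose(big_batch_grad, accumulated_grad, atol=1e-6))
```

The match is exact here because every micro-batch has the same size. With padded token-level losses, each micro-batch averages over its own number of real tokens, so accumulation only approximates a single large batch.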

Optimizer

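AdamW is the standard optimizer for transformer fine-tuning. It combines Adam's adaptive per-parameter learning rates with decoupled weight decay, which shrinks the weights directly instead of being added to the gradient. With weight_decay set to 0.0, as here, it behaves like plain Adam: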

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)
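To make "decoupled weight decay" concrete, here is a scalar, pure-Python sketch of a single AdamW step. It is a simplified illustration, not PyTorch's implementation: betas and eps match the defaults of torch.optim.AdamW, while lr and weight_decay default to the values configured above. The decay multiplies the weight directly, so it is not rescaled by Adam's adaptive denominator.

```python
import math


def adamw_step(p, grad, m, v, t, lr=1e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
    """One AdamW update for a single scalar parameter; t is the 1-based step count."""
    beta1, beta2 = betas
    # Decoupled weight decay: shrink the weight directly, independent of the gradient
    p = p * (1 - lr * weight_decay)
    # Exponential moving averages of the gradient and of the squared gradient
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    # Bias correction for the zero-initialised averages
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # Adaptive step: each parameter's update is normalised by its own gradient scale
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v


p, m, v = adamw_step(p=1.0, grad=0.5, m=0.0, v=0.0, t=1)
print(p)  # roughly 1.0 - 1e-5: on the first step the update size is about lr
```

On the first step m_hat / sqrt(v_hat) equals the sign of the gradient, which is why the very first update has magnitude close to lr regardless of the gradient's scale.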

Learning Rate Scheduler

A linear schedule with warmup gradually increases the learning rate from zero at the start of training (the warmup phase), then linearly decreases it to zero by the final step. This helps stabilize early training and often improves final performance:

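import math
from transformers import get_linear_schedule_with_warmup

# Round up: the training loop also steps the optimizer on the final, possibly partial, accumulation window
num_update_steps_per_epoch = max(math.ceil(len(train_loader) / gradient_accumulation_steps), 1)
max_train_steps = num_epochs * num_update_steps_per_epoch
warmup_steps = max(1, int(max_train_steps * warmup_ratio))
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_train_steps)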
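The multiplier that this kind of schedule applies to the base learning rate fits in a few lines. The function below is an illustrative pure-Python re-implementation of the shape, not the library code: a linear ramp from 0 to 1 over the warmup steps, then a linear decay to 0 at the last step. With the base learning rate of 1e-5, the learning rates logged in the training output (1.33e-06 at step 10 through 6.67e-06 at step 50) are consistent with 75 warmup steps, so the sketch assumes that value; the total step count is also an assumption.

```python
def linear_warmup_decay(step, warmup_steps, total_steps):
    """LR multiplier: linear warmup from 0 to 1, then linear decay to 0 (illustrative)."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


base_lr = 1e-5
warmup_steps = 75    # assumption: the value implied by the logged learning rates
total_steps = 2500   # assumption: about 2,500 optimizer steps, so 3% warmup is 75 steps

for step in (10, 20, 30, 40, 50):
    lr = base_lr * linear_warmup_decay(step, warmup_steps, total_steps)
    print(f"step={step:05d} | lr={lr:.2e}")
```

The printed values reproduce the lr column of the log, which is a quick way to confirm the scheduler is configured as intended before committing to a long run.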

Now we bring everything together. The training loop follows the standard PyTorch pattern with gradient accumulation:

  1. Forward pass: Compute loss on a mini-batch
  2. Backward pass: Accumulate gradients
  3. Optimizer step: Update weights (every gradient_accumulation_steps batches)
  4. Logging: Track loss and learning rate
  5. Evaluation: Measure validation loss after each epoch

model.train()
global_step = 0
running_loss = 0.0  # summed loss since the last log line
running_steps = 0

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    epoch_loss = 0.0  # summed loss over the whole epoch
    epoch_steps = 0
    optimizer.zero_grad(set_to_none=True)
    for step, batch in enumerate(train_loader, start=1):
        batch = {key: value.to(device) for key, value in batch.items()}
        outputs = model(**batch)
        # Scale the loss so the accumulated gradient averages over the effective batch
        loss = outputs.loss / gradient_accumulation_steps
        loss.backward()

        batch_loss = outputs.loss.float().item()
        running_loss += batch_loss
        running_steps += 1
        epoch_loss += batch_loss
        epoch_steps += 1

        # Update every gradient_accumulation_steps batches, and on the last batch of the epoch
        if step % gradient_accumulation_steps == 0 or step == len(train_loader):
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
            global_step += 1

            if global_step % logging_frequency == 0:
                current_lr = scheduler.get_last_lr()[0]
                mean_loss = running_loss / running_steps
                print(f"step={global_step:05d} | loss={mean_loss:.4f} | lr={current_lr:.2e}")
                running_loss = 0.0
                running_steps = 0

    train_loss = epoch_loss / epoch_steps if epoch_steps > 0 else float("nan")
    print(f"Training loss after epoch {epoch + 1}: {train_loss:.4f}")

    # Evaluate on the held-out set without tracking gradients
    model.eval()
    losses = []
    with torch.no_grad():
        for batch in eval_loader:
            batch = {key: value.to(device) for key, value in batch.items()}
            losses.append(model(**batch).loss.float().item())
    model.train()
    val_loss = sum(losses) / len(losses) if losses else float("nan")

    print(f"Validation loss after epoch {epoch + 1}: {val_loss:.4f}")

print("Training complete.")

Epoch 1/1
step=00010 | loss=1.7586 | lr=1.33e-06
step=00020 | loss=1.8188 | lr=2.67e-06
step=00030 | loss=1.8235 | lr=4.00e-06
step=00040 | loss=1.7935 | lr=5.33e-06
step=00050 | loss=1.8029 | lr=6.67e-06
...

Example 3: Fine-tuning with TRL

Finally, we can implement the training loop above with TRL. This simplifies the code considerably and abstracts away much of the complexity (and, with it, some of the educational value), but it comes with all the bells and whistles of a production-ready solution.

We can define the training arguments like this:

from trl import SFTConfig

training_args = SFTConfig(
    per_device_train_batch_size=1,        # Batch size per GPU
    gradient_accumulation_steps=4,        # Gradients are accumulated over multiple steps → effective batch size = 1 * 4 = 4 per GPU
    warmup_steps=5,                       # Number of warmup steps at the start of training
    # num_train_epochs=1,                 # Number of full dataset passes. For shorter training, use `max_steps` instead (as here)
    max_steps=30,                         # Total number of optimizer update steps
    learning_rate=2e-4,                   # Learning rate for the optimizer
    optim="paged_adamw_8bit",             # Paged 8-bit AdamW (requires the bitsandbytes package)

    # Logging / reporting
    logging_steps=1,                      # Log training metrics every N steps
    report_to="trackio",                  # Experiment tracking tool
    trackio_space_id=output_dir,          # HF Space where the experiment tracking will be saved
    output_dir=output_dir,                # Where to save model checkpoints and logs

    max_length=1024,                      # Maximum input sequence length
    use_liger_kernel=True,                # Enable Liger kernel optimizations for faster training
    activation_offloading=True,           # Offload activations to CPU to reduce GPU memory usage
    gradient_checkpointing=True,          # Save memory by re-computing activations during backpropagation

    # Hub integration
    push_to_hub=True,                     # Automatically push the trained model to the Hugging Face Hub
                                          # The model will be saved under your Hub account in the repository named `output_dir`
)
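The batch-related settings above combine multiplicatively. As a quick sanity check, the hypothetical helper below (not part of TRL) computes the effective batch size, assuming plain data parallelism across devices. Because max_steps counts optimizer updates, 30 steps at an effective batch of 4 cover 120 training sequences on a single GPU.

```python
def effective_batch_size(per_device_batch_size, gradient_accumulation_steps, num_devices=1):
    """Sequences contributing to each optimizer update, assuming plain data parallelism."""
    return per_device_batch_size * gradient_accumulation_steps * num_devices


eff_batch = effective_batch_size(per_device_batch_size=1, gradient_accumulation_steps=4)
sequences_seen = 30 * eff_batch  # max_steps = 30 optimizer updates
print(eff_batch, sequences_seen)  # 4 120
```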

Next, we create the trainer. TRL takes care of data loading, batching, and the training loop for us.

from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    peft_config=peft_config
)

And then we can train the model like this:

trainer_stats = trainer.train()