SFT Explained Simply | Generated by AI
Question: What is SFT? I just saw a PyTorch call, but how does it work underneath?
Answer:
SFT stands for Supervised Fine-Tuning. It is one of the simplest and most common ways to adapt a pretrained large language model (LLM) to follow instructions or perform a specific task.
At its core, SFT is just next-token prediction on a dataset of instruction-response pairs. There is no special algorithm: it is essentially the same training procedure used during pretraining, applied to a much smaller, curated dataset.
For example, suppose your dataset contains:
User: What is 2 + 2?
Assistant: 4
After a chat template is applied, the model sees something like this (the exact markers vary from model to model):
<user>
What is 2 + 2?
</user>
<assistant>
4
</assistant>
The tokenizer then converts this text into token IDs (the values below are only illustrative):
[123, 456, 789, 234, ...]
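Real tokenizers split text into subword pieces drawn from a learned vocabulary. As a rough illustration only, here is a hypothetical word-level tokenizer that assigns IDs in order of first appearance. The names `build_vocab` and `encode` are made up for this sketch, and the "?" is separated by a space so that a plain whitespace split works.

```python
def build_vocab(texts):
    """Assign an integer ID to each whitespace-separated token, in order of first appearance."""
    vocab = {}
    for text in texts:
        for token in text.split():
            if token not in vocab:
                vocab[token] = len(vocab)
    return vocab


def encode(text, vocab):
    """Map each whitespace-separated token to its ID (raises KeyError for unknown tokens)."""
    return [vocab[token] for token in text.split()]


sample = "<user> What is 2 + 2 ? </user> <assistant> 4 </assistant>"
vocab = build_vocab([sample])
ids = encode(sample, vocab)
print(ids)  # [0, 1, 2, 3, 4, 3, 5, 6, 7, 8, 9]
```

Both occurrences of "2" map to the same ID (3). The model only ever sees these integers, never the characters.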
The model learns to predict each token from the tokens before it:
Input:
What is 2 +
Predict:
2
Next step:
Input:
What is 2 + 2
Predict:
?
...
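Eventually:
Input:
<user> What is 2 + 2? </user> <assistant>
Predict:
4
During training, all of these predictions are made in parallel in a single forward pass, and the true preceding tokens are always given as input (this is called teacher forcing). The model learns by minimizing the prediction error, which in SFT is usually counted only on the assistant's response.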
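The step-by-step view above can be sketched as a list of (prefix, next token) pairs. In a real training step, every one of these pairs is scored in one forward pass. The word-level split is a simplifying assumption, since real tokenizers produce subword pieces. `next_token_pairs` is a hypothetical helper.

```python
def next_token_pairs(tokens):
    """Return (prefix, target) pairs: each prefix is used to predict the token that follows it."""
    return [(tokens[:t], tokens[t]) for t in range(1, len(tokens))]


tokens = "<user> What is 2 + 2 ? </user> <assistant> 4 </assistant>".split()
for prefix, target in next_token_pairs(tokens):
    print(" ".join(prefix), "->", target)
```

A sequence of n tokens yields n - 1 training targets. The pair whose prefix ends in `<assistant>` has the answer `4` as its target.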
Under the hood
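Suppose you have:
input_ids = tokenizer(text, return_tensors="pt").input_ids  # tensor of shape (1, seq_len)
The model then computes:
logits = model(input_ids).logits  # Hugging Face models return an output object; .logits holds the scores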
Internally, the data flows through:
input_ids
↓
Embedding
↓
Transformer blocks
↓
Linear projection to the vocabulary (the LM head)
↓
logits
The output shape is
(batch, sequence_length, vocab_size)
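For example, with a batch of 8 sequences of 512 tokens and a 50,000-token vocabulary:
(8, 512, 50000)
The vector at position t holds a score for every vocabulary token as a candidate for position t + 1.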
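To make the shape concrete, here is a deliberately tiny forward pass in plain Python. It does an embedding lookup followed by a linear projection to vocabulary scores. This is a sketch under strong assumptions: the Transformer blocks are omitted (treated as identity), the weights are random, and `toy_forward` is a hypothetical name. It illustrates shapes only, not a real model.

```python
import random


def toy_forward(batch_ids, vocab_size, d_model, seed=0):
    """Return logits as nested lists with shape (batch, seq_len, vocab_size).

    Embedding -> (Transformer blocks omitted) -> linear projection to the vocabulary.
    """
    rng = random.Random(seed)
    # Embedding table: one d_model-dimensional vector per token ID
    embedding = [[rng.gauss(0, 1) for _ in range(d_model)] for _ in range(vocab_size)]
    # Output projection ("LM head"): one row of weights per vocabulary token
    lm_head = [[rng.gauss(0, 1) for _ in range(d_model)] for _ in range(vocab_size)]

    logits = []
    for seq in batch_ids:
        seq_logits = []
        for token_id in seq:
            hidden = embedding[token_id]  # a real model would run Transformer blocks here
            scores = [sum(w * h for w, h in zip(row, hidden)) for row in lm_head]
            seq_logits.append(scores)
        logits.append(seq_logits)
    return logits


batch_ids = [[0, 1, 2, 3], [3, 2, 1, 0]]  # batch=2, seq_len=4
logits = toy_forward(batch_ids, vocab_size=5, d_model=3)
print(len(logits), len(logits[0]), len(logits[0][0]))  # 2 4 5
```

Without the Transformer blocks, each position's scores depend only on its own token. For example, token 0 gets identical logits wherever it appears. Mixing in context from earlier tokens is exactly what the Transformer blocks add.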
Loss computation
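In PyTorch, the loss is typically computed with:
import torch.nn.functional as F
# logits: (batch, seq_len, vocab_size); labels: (batch, seq_len), already aligned
# so that labels[b, t] is the target for logits[b, t] (the shift is covered below)
loss = F.cross_entropy(
    logits.reshape(-1, vocab_size),  # flatten to (batch * seq_len, vocab_size)
    labels.reshape(-1),              # flatten to (batch * seq_len,)
)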
Internally, this is equivalent to:
Log-softmax over the vocabulary (applied to the raw logits)
↓
Negative log likelihood
↓
Average over tokens
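The three steps above can be written out directly. This plain-Python sketch mirrors what `F.cross_entropy` computes with its default mean reduction: a log-softmax over the vocabulary, the negative log probability of each target, and then the average. The function names are hypothetical. Note that the inputs are raw logits, not probabilities.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax over one vector of logits."""
    m = max(scores)
    log_total = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_total for s in scores]


def cross_entropy(logits, targets):
    """Mean negative log-likelihood of the target IDs.

    logits:  per-position score vectors, shape (num_positions, vocab_size)
    targets: target token IDs, shape (num_positions,)
    """
    losses = [-log_softmax(scores)[t] for scores, t in zip(logits, targets)]
    return sum(losses) / len(losses)


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
targets = [0, 2]
print(round(cross_entropy(logits, targets), 4))
```

Subtracting the maximum score before exponentiating does not change the result. It only prevents overflow when logits are large.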
For each position, the loss is
\[ \mathcal{L} = -\log p(\text{correct token}) \]
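Averaged over a sequence of tokens x_1, ..., x_n, the per-token loss becomes a mean over the set S of positions that contribute a target. S contains every position in plain language-model training, but only the response positions once the prompt is masked. Position t predicts token x_{t+1}:

```latex
\mathcal{L}(\theta) = -\frac{1}{|S|} \sum_{t \in S} \log p_\theta\left(x_{t+1} \mid x_1, \dots, x_t\right),
\qquad S \subseteq \{1, \dots, n-1\}
```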
If the correct next token is
"dog"
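and the model assigns these probabilities (the remaining 0.10 is spread over the rest of the vocabulary):
cat : 0.20
dog : 0.60
bird: 0.10
then
loss = -log(0.6) ≈ 0.51
If instead it predicts
dog : 0.99
the loss becomes very small: -log(0.99) ≈ 0.01.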
Why is it called “Supervised”?
Because every training example comes with the desired output, usually written or curated by humans.
Example:
Question:
How do I sort a list?
Desired answer:
Use Python's sorted() function...
The model is directly told what the correct response should be.
This is unlike reinforcement learning, where the model receives only a reward signal rather than the exact target output.
Masking the prompt
One important detail in modern SFT is that we usually do not train on the user’s prompt. Instead, we compute the loss only on the assistant’s reply.
For example:
User:
Tell me a joke.
Assistant:
Why did the chicken cross the road?
The labels might look like:
User tokens -> -100
Assistant tokens -> actual token IDs
In PyTorch:
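labels = input_ids.clone()  # input_ids: (batch, seq_len) tensor
labels[:, :prompt_len] = -100  # hypothetical prompt_len: number of prompt tokens at the start of each row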
-100 is the default ignore_index of PyTorch's CrossEntropyLoss (and F.cross_entropy), so those positions are skipped and excluded from the average.
This means the model is conditioned on the prompt, but it is only penalized based on how well it predicts the assistant’s response.
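Here is a plain-Python sketch of building the labels. It assumes the prompt occupies the first `prompt_len` tokens. The function names and `prompt_len` are illustrative, and real pipelines usually locate the prompt/response boundary from the chat template. The labels stay aligned with `input_ids`, as Hugging Face expects, because the one-position shift happens inside the model. The per-position log-probabilities in the demo are made-up numbers that show masked positions dropping out of the average.

```python
import math

IGNORE_INDEX = -100  # matches the default ignore_index of PyTorch's CrossEntropyLoss


def build_labels(input_ids, prompt_len):
    """Copy input_ids and replace the first prompt_len positions with IGNORE_INDEX."""
    labels = list(input_ids)
    for i in range(min(prompt_len, len(labels))):
        labels[i] = IGNORE_INDEX
    return labels


def masked_mean_nll(target_log_probs, labels):
    """Average -log p over positions whose label is not IGNORE_INDEX.

    Raises ZeroDivisionError if every position is masked.
    """
    kept = [-lp for lp, y in zip(target_log_probs, labels) if y != IGNORE_INDEX]
    return sum(kept) / len(kept)


input_ids = [11, 12, 13, 14, 21, 22]  # 4 prompt tokens, then 2 response tokens
labels = build_labels(input_ids, prompt_len=4)
print(labels)  # [-100, -100, -100, -100, 21, 22]

# Hypothetical log-probabilities the model gave each position's target
log_probs = [math.log(0.1)] * 4 + [math.log(0.5), math.log(0.5)]
print(masked_mean_nll(log_probs, labels))  # -log(0.5) ≈ 0.693; the prompt positions are ignored
```

The low probabilities on the prompt positions (0.1) have no effect on the result, because only the two response positions are averaged.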
Backpropagation
Once the loss is computed:
loss
↓
backward()
↓
compute gradients
↓
optimizer.step()
↓
update weights
In full fine-tuning, every weight in the Transformer is adjusted slightly in the direction that reduces the loss, using a gradient-based optimizer (commonly AdamW).
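For a single position, the gradient of the cross-entropy loss with respect to the logits has a simple closed form: softmax(logits) minus the one-hot target vector. The sketch below applies plain gradient descent directly to a hypothetical logits vector and shows the loss dropping. It is a simplification: in a real model, backward() carries this gradient further back into every weight, and SFT usually uses AdamW rather than plain gradient descent. The learning rate is illustrative.

```python
import math


def softmax(scores):
    """Numerically stable softmax over one vector of logits."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def loss_and_grad(logits, target):
    """Cross-entropy loss for one position and its gradient w.r.t. the logits."""
    probs = softmax(logits)
    loss = -math.log(probs[target])
    grad = [p - (1.0 if i == target else 0.0) for i, p in enumerate(probs)]
    return loss, grad


logits = [1.0, 2.0, 0.5]  # scores for, say, "cat", "dog", "bird"
target = 1                 # the correct next token is "dog"
lr = 0.5                   # learning rate (illustrative)

losses = []
for step in range(3):
    loss, grad = loss_and_grad(logits, target)
    losses.append(loss)
    print(f"step {step}: loss = {loss:.4f}")
    logits = [x - lr * g for x, g in zip(logits, grad)]  # gradient descent update
```

The gradient is negative only for the correct token, so each update raises the "dog" score and lowers the others.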
The actual PyTorch training loop
A simplified version looks like this:
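for batch in dataloader:
    # Forward pass; Hugging Face causal LMs compute the loss when labels are passed
    outputs = model(
        input_ids=batch["input_ids"],
        labels=batch["labels"],
    )
    loss = outputs.loss
    loss.backward()        # compute gradients
    optimizer.step()       # update weights
    optimizer.zero_grad()  # clear gradients before the next batch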
Inside
outputs = model(...)
the Hugging Face model typically does something equivalent to:
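import torch.nn.functional as F

# transformer: placeholder for the network body plus LM head
logits = transformer(input_ids)   # (batch, seq_len, vocab_size)
shift_logits = logits[:, :-1, :]  # predictions made at positions 0 .. T-2
shift_labels = labels[:, 1:]      # the tokens those positions should predict
loss = F.cross_entropy(
    shift_logits.reshape(-1, vocab_size),
    shift_labels.reshape(-1),
    ignore_index=-100,            # skip masked prompt/padding positions
)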
The labels are shifted by one position because the model predicts the next token.
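The shift-and-mask logic above can be mirrored in plain Python for one sequence, with nested lists standing in for tensors. This is a sketch of the same computation, not Hugging Face's actual code, and `causal_lm_loss` is a hypothetical name.

```python
import math

IGNORE_INDEX = -100


def causal_lm_loss(logits, labels):
    """Next-token loss for one sequence.

    logits: (seq_len, vocab_size) scores; logits[t] predicts the token at t + 1.
    labels: (seq_len,) token IDs aligned with the inputs, IGNORE_INDEX where masked.
    """
    shift_logits = logits[:-1]  # drop the last position: nothing follows it
    shift_labels = labels[1:]   # drop the first label: nothing predicts it
    total, count = 0.0, 0
    for scores, y in zip(shift_logits, shift_labels):
        if y == IGNORE_INDEX:
            continue  # masked: contributes to neither the sum nor the count
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_z - scores[y]  # equals -log softmax(scores)[y]
        count += 1
    if count == 0:
        raise ValueError("every position is masked")
    return total / count


# vocab_size = 3, seq_len = 4; the first two tokens are the prompt (masked)
logits = [[0.0, 0.0, 0.0], [0.0, 0.0, 5.0], [0.0, 5.0, 0.0], [5.0, 0.0, 0.0]]
labels = [IGNORE_INDEX, IGNORE_INDEX, 2, 1]
print(round(causal_lm_loss(logits, labels), 4))  # 0.0134
```

Position 1 strongly predicts token 2 and position 2 strongly predicts token 1, which are exactly the labels one step later, so the loss is small. Pairing logits[t] with labels[t] without the shift would score each prediction against the wrong token.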
How Hugging Face hides the complexity
When you write:
outputs = model(
input_ids=input_ids,
labels=labels
)
the model automatically:
- Runs the forward pass through the Transformer.
- Shifts logits and labels for next-token prediction.
- Computes the cross-entropy loss.
- Applies ignore_index=-100 to masked positions.
- Returns both the loss and the logits.
This is why an SFT training script can look deceptively simple.
Overall pipeline
Instruction dataset
│
▼
Tokenizer
│
▼
Token IDs
│
▼
Transformer forward pass
│
▼
Logits
│
▼
Shift by one token
│
▼
Cross-entropy loss
│
▼
Backpropagation
│
▼
AdamW optimizer
│
▼
Updated model weights
In summary, SFT is fundamentally the same next-token prediction objective used during pretraining, but applied to high-quality instruction-following examples. The key differences are the data (instruction-response pairs instead of raw text) and the common practice of masking the prompt so that only the assistant’s response contributes to the training loss.