DeepSeek R1 | Generated by AI


DeepSeek R1 employs a novel approach to reinforcement learning (RL) that places a strong emphasis on reasoning capabilities in large language models (LLMs). The technique differs from traditional Reinforcement Learning from Human Feedback (RLHF) in several key respects.

Here’s a breakdown of the techniques DeepSeek used in R1’s reinforcement learning:

1. Pure Reinforcement Learning (RL) for Initial Reasoning (DeepSeek-R1-Zero):

DeepSeek-R1-Zero was trained by applying RL directly to the base model, with no supervised fine-tuning (SFT) beforehand. The algorithm is Group Relative Policy Optimization (GRPO): for each prompt, the model samples a group of responses, scores them with simple rule-based rewards (answer accuracy plus a format reward for reasoning wrapped in <think> tags), and updates the policy using each response's advantage relative to the rest of the group, removing the need for a separate critic model. A conceptual sketch of this group-relative advantage appears after this list.

2. Multi-Stage Training for Enhanced Readability and General Capabilities (DeepSeek-R1):

To address the limitations of DeepSeek-R1-Zero (such as poor readability and language mixing), DeepSeek-R1 employed a more comprehensive multi-stage training pipeline: a small "cold-start" supervised fine-tuning phase on curated long chain-of-thought examples, reasoning-oriented RL, rejection sampling of the resulting checkpoint to build a larger SFT dataset covering both reasoning and general tasks, and a final RL stage aimed at helpfulness and harmlessness across all scenarios.
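
The heart of GRPO, referenced in point 1 above, is that each sampled response is judged against the other responses in its own group rather than against a learned value function. The minimal PyTorch sketch below (illustrative only, not DeepSeek's implementation) shows how group-relative advantages can be computed from a group of rewards:

import torch

def group_relative_advantages(rewards: torch.Tensor) -> torch.Tensor:
    # rewards: shape (group_size,), one scalar reward per sampled response
    mean = rewards.mean()
    std = rewards.std()
    # Normalize each reward against its own group; no critic/value model is needed.
    return (rewards - mean) / (std + 1e-8)

# Example: four sampled answers to the same prompt
rewards = torch.tensor([1.5, 0.0, 0.5, 1.5])
print(group_relative_advantages(rewards))  # above-average answers get positive advantages

Responses with positive advantages have their likelihood pushed up and those with negative advantages pushed down, which is what steers the policy toward better reasoning without human preference labels.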

Key Differences from Traditional RLHF:

Traditional RLHF trains a neural reward model on human preference data and then optimizes the policy with PPO, which also requires a separate critic (value) network. DeepSeek's reasoning-focused RL instead relies on rule-based rewards (answer correctness and output format), which are cheap to compute and much harder to game, and uses GRPO, which replaces the critic with group-relative advantages. For DeepSeek-R1-Zero, RL is also applied directly to the base model with no supervised fine-tuning first, whereas conventional RLHF pipelines start from an instruction-tuned model.
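
For contrast, the reward signal in a standard RLHF pipeline comes from a learned model rather than from rules. A minimal sketch of that scoring step is shown below; reward_model_name is a placeholder, so substitute any preference-trained sequence-classification reward model:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Placeholder name: swap in a real preference-trained reward model checkpoint.
reward_model_name = "your-org/your-reward-model"
rm_tokenizer = AutoTokenizer.from_pretrained(reward_model_name)
reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_name)

def learned_reward(prompt: str, response: str) -> float:
    # The reward model reads the prompt/response pair and returns a scalar score
    # reflecting learned human preferences, not a hand-written rule.
    inputs = rm_tokenizer(prompt, response, return_tensors="pt", truncation=True)
    with torch.no_grad():
        score = reward_model(**inputs).logits[0, 0]
    return score.item()

DeepSeek-R1-Zero skips this component for reasoning tasks: correctness can be checked mechanically, so a rule is both cheaper and more reliable than a learned scorer.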

Code to Show Reinforcement Learning (Conceptual and Simplified):

It’s challenging to provide a direct, runnable code example that fully replicates DeepSeek’s RL training process, given its complexity and scale. The following conceptual PyTorch snippet instead illustrates the overall loop: sample a group of responses, score them with rule-based rewards, and update the policy (the update step here is heavily simplified relative to true GRPO):

import torch
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assume you have a pre-trained language model and tokenizer
model_name = "gpt2"  # Replace with a more suitable base model
policy_model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
optimizer = optim.AdamW(policy_model.parameters(), lr=5e-6)
device = "cuda" if torch.cuda.is_available() else "cpu"
policy_model.to(device)

def generate_responses(prompt, num_responses=4, max_length=128):
    input_tokens = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = policy_model.generate(
        input_tokens.input_ids,
        attention_mask=input_tokens.attention_mask,
        max_length=max_length,
        num_return_sequences=num_responses,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id
    )
    # Note: the decoded strings include the prompt text as well as the generated continuation.
    responses = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
    return responses

def calculate_accuracy_reward(response):
    # Simplified example for a math problem: "What is 2 + 2?"
    if "2 + 2" in response and "4" in response:
        return 1.0
    else:
        return 0.0

def calculate_format_reward(response):
    # Rule-based format reward: encourages reasoning wrapped in <think>...</think> tags
    if "<think>" in response and "</think>" in response:
        return 0.5
    else:
        return 0.0

def calculate_combined_reward(response):
    accuracy_reward = calculate_accuracy_reward(response)
    format_reward = calculate_format_reward(response)
    return accuracy_reward + format_reward
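
# Example reward values for illustration:
#   calculate_combined_reward("<think>2 + 2 is 4</think> The answer is 4.")  -> 1.5 (accuracy + format)
#   calculate_combined_reward("The answer might be five.")                   -> 0.0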

def train_step(prompt, num_samples=4):
    optimizer.zero_grad()
    responses = generate_responses(prompt, num_responses=num_samples)
    rewards = torch.tensor([calculate_combined_reward(resp) for resp in responses]).float().to(device)

    # Heavily simplified update: fine-tune on the single highest-reward sample
    # (a best-of-n / rejection-sampling flavor rather than the full GRPO objective).
    best_reward_index = torch.argmax(rewards).item()
    best_response = responses[best_reward_index]
    # best_response already contains the prompt text returned by generate_responses
    inputs = tokenizer(best_response, return_tensors="pt").to(device)
    outputs = policy_model(**inputs, labels=inputs.input_ids)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    return loss.item(), best_response, rewards.tolist()

# Training loop (very simplified)
num_episodes = 10
training_prompts = ["Solve: 2 + 2 = ?", "Explain the concept of gravity <think>", "Write a short story about a cat."]

for episode in range(num_episodes):
    prompt = training_prompts[episode % len(training_prompts)]
    loss, best_response, rewards = train_step(prompt)
    print(f"Episode {episode+1}, Loss: {loss:.4f}, Best Response: '{best_response}', Rewards: {rewards}")

print("Simplified RL Training Done!")

Important Notes on the Code:

- The update above only behavior-clones the single highest-reward sample. Real GRPO weights the log-likelihood of every sampled response by its group-normalized advantage and adds a KL penalty toward a reference policy (see the sketch below).
- The rewards are toy string checks. DeepSeek's rule-based rewards verify the final answer (for example, checking a math result or running test cases for code) as well as the required <think>/<answer> formatting.
- gpt2 is only a stand-in; DeepSeek ran this kind of training at the scale of DeepSeek-V3-Base, with large batches and many sampled responses per prompt.
- There is no KL regularization, clipping, batching, or evaluation here; the snippet exists purely to make the core loop legible.
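
For a step closer to the real objective, the sketch below (again illustrative, not DeepSeek's code) weights each sampled response's log-likelihood by its group-relative advantage instead of training only on the best sample. PPO-style ratio clipping and the KL penalty to a reference model are still omitted, and the helpers defined above are reused:

def grpo_like_step(prompt, num_samples=4):
    optimizer.zero_grad()
    responses = generate_responses(prompt, num_responses=num_samples)
    rewards = torch.tensor([calculate_combined_reward(resp) for resp in responses], device=device)
    # Group-relative advantages: each sample is compared to its own group
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

    total_loss = 0.0
    for response, advantage in zip(responses, advantages):
        # response already includes the prompt text; a faithful implementation
        # would mask the prompt tokens out of the loss.
        inputs = tokenizer(response, return_tensors="pt").to(device)
        outputs = policy_model(**inputs, labels=inputs.input_ids)
        # outputs.loss is the mean negative log-likelihood of the tokens, so weighting
        # it by the advantage raises the likelihood of above-average samples and
        # lowers it for below-average ones.
        total_loss = total_loss + advantage * outputs.loss
    total_loss = total_loss / num_samples
    total_loss.backward()
    optimizer.step()
    return total_loss.item()

A production implementation would also sample from a frozen snapshot of the policy, apply the clipped importance-ratio objective, and batch many prompts per step, but the advantage weighting above is the piece that distinguishes GRPO from best-of-n fine-tuning.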

This explanation and the conceptual code provide a foundational understanding of the innovative reinforcement learning techniques employed by DeepSeek in the development of their R1 model, particularly their focus on pure RL and rule-based rewards for fostering strong reasoning capabilities.


2025.04.04