Vision Transformer Overview | Generated by AI
Introduction to Vision Transformer (ViT)
The Vision Transformer (ViT) is a neural network architecture that adapts the transformer model—originally designed for natural language processing (NLP)—to computer vision tasks, particularly image classification. Introduced in 2020 by Alexey Dosovitskiy and colleagues at Google in the paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale, ViT treats images as sequences of patches rather than processing them with traditional convolutional neural networks (CNNs). This allows it to leverage the transformer’s strengths in handling long-range dependencies and parallel computation.
ViT has shown competitive or superior performance to CNNs on large-scale datasets like ImageNet, especially when pre-trained on massive amounts of data (e.g., JFT-300M). Variants like DeiT (Data-efficient Image Transformers) make it more efficient for smaller datasets. Today, ViT-inspired models power many vision tasks in models like DALL-E, Stable Diffusion, and modern classifiers.
How ViT Works: Overall Architecture and Workflow
ViT’s core idea is to “tokenize” an image into a sequence of fixed-size patches, similar to how text is broken into words or tokens. This sequence is then processed by a standard transformer encoder (no decoder, unlike generative text models). Here’s a step-by-step breakdown of how it works:
- Image Preprocessing and Patch Extraction:
- Start with an input image of size \(H \times W \times C\) (height × width × channels, e.g., 224 × 224 × 3 for RGB).
- Divide the image into non-overlapping patches of fixed size \(P \times P\) (e.g., 16 × 16 pixels). This yields \(N = \frac{HW}{P^2}\) patches (e.g., 196 patches for a 224×224 image with 16×16 patches).
- Each patch is flattened into a 1D vector of length \(P^2 \cdot C\) (e.g., 768 dimensions for 16×16×3).
- Why patches? Raw pixels would create an impractically long sequence (e.g., millions for a high-res image), so patches act as “visual words” to reduce dimensionality.
- Patch Embedding:
- Apply a learnable linear projection (a simple fully connected layer) to each flattened patch vector, mapping it to a fixed embedding dimension \(D\) (e.g., 768, matching BERT-like transformers).
- This produces \(N\) embedding vectors, each of size \(D\).
- Optionally, add a special [CLS] token embedding (a learnable vector of size \(D\)) prepended to the sequence, similar to BERT for classification tasks.
- Positional Embeddings:
- Add learnable 1D positional embeddings to the patch embeddings to encode spatial information (transformers are permutation-invariant without this).
- The full input sequence is now: \([ \text{[CLS]}, \text{patch}_1, \text{patch}_2, \dots, \text{patch}_N ] + \text{positions}\), a matrix of shape \((N+1) \times D\).
- Transformer Encoder Blocks:
- Feed the sequence into \(L\) stacked transformer encoder layers (e.g., 12 layers).
- Each layer consists of:
- Multi-Head Self-Attention (MSA): Computes attention scores between all pairs of patches (including [CLS]). This allows the model to capture global relationships, like “this cat’s ear relates to the whisker 100 patches away,” unlike CNNs’ local receptive fields.
- Formula: Attention(Q, K, V) = \(\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V\), where Q, K, V are projections of the input.
- Multi-Layer Perceptron (MLP): A feed-forward network (two linear layers with GELU activation) applied position-wise.
- Layer normalization and residual connections: Input + MSA → Norm → MLP → Norm + Input.
- Multi-Head Self-Attention (MSA): Computes attention scores between all pairs of patches (including [CLS]). This allows the model to capture global relationships, like “this cat’s ear relates to the whisker 100 patches away,” unlike CNNs’ local receptive fields.
- Output: A sequence of refined embeddings, still \((N+1) \times D\).
- Classification Head:
- For image classification, extract the [CLS] token’s output (or take the mean of all patch embeddings).
- Pass it through a simple MLP head (e.g., one or two linear layers) to output class logits.
- During training, use cross-entropy loss on labeled data. Pre-training often involves masked patch prediction or other self-supervised tasks.
Key Hyperparameters (from the original ViT-Base model):
- Patch size \(P\): 16
- Embedding dim \(D\): 768
- Layers \(L\): 12
- Heads: 12
- Parameters: ~86M
ViT scales well: Larger models (e.g., ViT-Large with \(D=1024\), \(L=24\)) perform better but need more data/compute.
Training and Inference:
- Training: End-to-end on labeled data; benefits hugely from pre-training on billions of images.
- Inference: Forward pass through the encoder (~O(N²) time due to attention, but efficient with optimizations like FlashAttention).
- Unlike CNNs, ViT has no inductive biases like translation invariance—everything is learned.
Comparison to Text Transformers: Similarities and Differences
ViT is fundamentally the same architecture as the encoder part of text transformers (e.g., BERT), but adapted for 2D visual data. Here’s a side-by-side comparison:
| Aspect | Text Transformer (e.g., BERT) | Vision Transformer (ViT) |
|---|---|---|
| Input Representation | Sequence of tokens (words/subwords) embedded into vectors. | Sequence of image patches embedded into vectors. Patches are like “visual tokens.” |
| Sequence Length | Variable (e.g., 512 tokens for a sentence). | Fixed based on image size/patch size (e.g., 197 with [CLS]). |
| Positional Encoding | 1D (absolute or relative) for word order. | 1D (learnable) for patch order (e.g., row-major flattening). No built-in 2D structure. |
| Core Mechanism | Self-attention over tokens to model dependencies. | Self-attention over patches—same math, but attends to spatial “relationships” instead of syntactic ones. |
| Output/Tasks | Encoder for classification/Masked LM; decoder for generation. | Encoder-only for classification; can be extended for detection/segmentation. |
| Strengths | Handles long-range text dependencies. | Global context in images (e.g., whole-scene understanding). |
| Weaknesses | Needs huge text corpora. | Data-hungry; struggles on small datasets without CNN pre-training. |
| Prediction Style | Next-token prediction in decoders (autoregressive). | No “next” prediction inherently—classifies the whole image holistically. |
In essence, ViT is a “plug-and-play” swap: Replace token embeddings with patch embeddings, and you get a vision model. Both rely on attention to weigh relationships in a sequence, but text is inherently sequential/linear, while images are spatial (ViT learns this via attention).
Addressing “Next Token” vs. “Next Pixel” in ViT
No, ViT does not predict the “next pixel” like a text transformer predicts the “next token” in autoregressive generation (e.g., GPT). Here’s why:
-
Text Transformers (Autoregressive): In models like GPT, the decoder generates sequentially—one token at a time, conditioning on all previous ones. It’s pixel-by-pixel for images in some generative models (e.g., PixelRNN), but inefficient.
-
ViT’s Approach: ViT is non-autoregressive and holistic. It processes the entire image (all patches) in parallel via the encoder. There’s no sequential “next” prediction during inference—attention computes relationships across all patches simultaneously. For generation tasks (e.g., inpainting), ViT variants might use a decoder to predict masked patches autoregressively, but the base ViT is for classification, not generation.
-
Why Not Pixels? Treating individual pixels as tokens would explode the sequence length (e.g., 224×224×3 = ~150K tokens), making quadratic attention computationally infeasible (O(N²) = trillions of operations). Patches group ~256 pixels (16×16) into one “token,” keeping N manageable (~200).
If you want generative vision (like next-pixel or next-patch), look at models like VQ-VAE + Transformer or Parti, which build on ViT ideas.
For more details, check the original paper: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. If you’d like code examples (e.g., in PyTorch) or comparisons to specific variants, let me know!