Testing FlashAttention 2 Verification | Generated by AI

2026.06

Question: How does this code test whether FlashAttention 2 is working correctly?

Answer:

Your script performs five levels of verification:

1. Verify PyTorch sees the GPU

import torch
print('torch:', torch.__version__)
print('GPU:', torch.cuda.get_device_name(0))

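This confirms that PyTorch imports and can see the GPU. On ROCm builds, AMD devices are exposed through the same torch.cuda API.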

Example output:

torch: 2.8.0+rocm
GPU: AMD Instinct MI300X

If this fails, FlashAttention cannot work because PyTorch itself cannot access the GPU.
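A slightly more defensive variant of this check, sketched on the assumption that the script should report a missing GPU rather than raise, might look like the following. The helper name `gpu_summary` is hypothetical; it relies only on `torch.cuda.is_available`, `torch.cuda.get_device_name`, and `torch.version.hip`, which is `None` on non-ROCm builds.

```python
import importlib.util


def gpu_summary():
    """Describe the GPU PyTorch sees, or return None if there is none."""
    if importlib.util.find_spec("torch") is None:
        return None  # PyTorch is not installed at all
    import torch

    if not torch.cuda.is_available():
        return None  # PyTorch imports but cannot reach a GPU
    # torch.version.hip is a version string on ROCm builds, None otherwise
    backend = "ROCm" if getattr(torch.version, "hip", None) else "CUDA"
    return f"{torch.cuda.get_device_name(0)} ({backend}, torch {torch.__version__})"


print("GPU:", gpu_summary() or "not visible to PyTorch")
```

Returning `None` instead of raising lets the later checks be skipped cleanly when no GPU is present.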


2. Verify FlashAttention 2 imports successfully

import flash_attn
print('flash_attn version:', flash_attn.__version__)

from flash_attn import flash_attn_func
print('flash_attn_func imported OK')

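This checks that the flash_attn package is installed and that flash_attn_func, which depends on the compiled extension, can be imported.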

Typical success:

flash_attn version: 2.8.3
flash_attn_func imported OK

Typical failure:

ImportError: cannot import name flash_attn_func

or

undefined symbol ...

which usually indicates a build problem.
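To report which stage failed instead of stopping at a traceback, the import can be wrapped in a small helper. This is a minimal sketch: `try_import` is a hypothetical name, and it assumes that both failure modes above reach Python as an ImportError or as a missing attribute on the module.

```python
import importlib


def try_import(module_name, attr=None):
    """Return (ok, message) for importing a module and, optionally, one attribute."""
    try:
        module = importlib.import_module(module_name)
    except ImportError as exc:
        # Covers a missing package as well as loader errors such as
        # "undefined symbol ..." raised while loading a compiled extension.
        return False, f"import failed: {exc}"
    if attr is not None and not hasattr(module, attr):
        return False, f"{module_name} has no attribute {attr}"
    return True, "ok"


ok, message = try_import("flash_attn", "flash_attn_func")
print("flash_attn_func:", "imported OK" if ok else message)
```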


3. Execute a real FlashAttention kernel

This is the most important part.

q = torch.randn(2, 1024, 12, 128,
                device='cuda',
                dtype=torch.bfloat16)

k = torch.randn(2, 1024, 12, 128,
                device='cuda',
                dtype=torch.bfloat16)

v = torch.randn(2, 1024, 12, 128,
                device='cuda',
                dtype=torch.bfloat16)

out = flash_attn_func(q, k, v, causal=True)

Shape meaning:

| Dimension | Value | Meaning |
|---|---|---|
| Batch | 2 | Two sequences |
| Seq Len | 1024 | Context length |
| Heads | 12 | Attention heads |
| Head Dim | 128 | Head size |

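FlashAttention 2 computes scaled dot-product attention:

\[ \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V \]

in a single fused kernel, where d is the head dimension (128 here). With causal=True, each query position is masked so that it cannot attend to keys to its right.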

Success output:

FA2 forward pass: OK, shape torch.Size([2, 1024, 12, 128])

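This shows that the compiled kernel launches on the GPU and returns a tensor with the same shape as q. It does not by itself check the numerical values.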
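A matching shape says nothing about the values, so a natural follow-up is to compare the output against a higher-precision reference. The sketch below assumes PyTorch 2.x, where `torch.nn.functional.scaled_dot_product_attention` is available, and uses a hypothetical helper name, `max_error_vs_reference`. It needs a GPU with flash_attn installed, so the imports sit inside the function.

```python
def max_error_vs_reference(batch=2, seqlen=1024, heads=12, head_dim=128):
    """Largest absolute difference between FA2 and a float32 reference."""
    import torch
    import torch.nn.functional as F
    from flash_attn import flash_attn_func

    shape = (batch, seqlen, heads, head_dim)  # flash_attn layout
    q, k, v = (
        torch.randn(shape, device="cuda", dtype=torch.bfloat16) for _ in range(3)
    )

    out = flash_attn_func(q, k, v, causal=True)

    # SDPA expects (batch, heads, seqlen, head_dim), so swap axes 1 and 2.
    # Upcasting to float32 makes the reference more precise than the kernel under test.
    ref = F.scaled_dot_product_attention(
        q.float().transpose(1, 2),
        k.float().transpose(1, 2),
        v.float().transpose(1, 2),
        is_causal=True,
    ).transpose(1, 2)

    return (out.float() - ref).abs().max().item()
```

On the GPU machine, `max_error_vs_reference()` can be checked against a tolerance. For bfloat16 inputs an error on the order of 1e-2 is a plausible starting threshold, but that figure is an assumption and should be tuned to the workload.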


4. Test sliding-window attention

out_sw = flash_attn_func(
    q, k, v,
    causal=True,
    window_size=(256, 0)
)

This verifies an additional FlashAttention feature.

Instead of attending to every earlier token, each token sees only itself plus:

256 tokens to the left
0 tokens to the right
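The window can be made concrete with a few lines of plain Python. This sketch assumes equal query and key lengths and the inclusive bounds described in the flash_attn_func docstring, where -1 means "unlimited" on that side. `visible_keys` is a hypothetical helper, not part of flash_attn.

```python
def visible_keys(i, seqlen, window_size):
    """Key positions that query position i may attend to (equal q/k lengths)."""
    left, right = window_size
    lo = 0 if left < 0 else max(0, i - left)  # -1 means no limit to the left
    hi = seqlen - 1 if right < 0 else min(seqlen - 1, i + right)
    return range(lo, hi + 1)  # both ends inclusive


keys = visible_keys(500, 1024, (256, 0))
print(keys[0], keys[-1], len(keys))  # 244 500 257

early = visible_keys(10, 1024, (256, 0))
print(early[0], early[-1], len(early))  # 0 10 11
```

Position 500 therefore sees 257 keys: itself plus the 256 before it. Tokens near the start of the sequence see fewer, because the window is clipped at position 0.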

Success:

FA2 sliding window: OK, shape torch.Size([2, 1024, 12, 128])

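This confirms that the installed build accepts the window_size argument and runs sliding-window (local) attention.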


5. Verify NanoChat detects FlashAttention

from nanochat.flash_attention import (
    flash_attn as fa,
    USE_FA2,
    HAS_FA2,
    USE_FA3
)

These flags determine what NanoChat will actually use.

Expected on MI300X:

HAS_FA2: True
USE_FA2: True
USE_FA3: False

Meaning:

| Flag | Meaning |
|---|---|
| HAS_FA2 | FlashAttention 2 installed |
| USE_FA2 | NanoChat will use FA2 |
| USE_FA3 | FlashAttention 3 active |
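The flag meanings above can be folded into a one-line summary. This is only an illustration: `describe_backend` is a hypothetical helper, and the precedence it applies (FA3 before FA2) is an assumption rather than something taken from NanoChat's code.

```python
def describe_backend(has_fa2, use_fa2, use_fa3):
    """Summarize the three flags as a single human-readable line."""
    if use_fa3:
        return "FlashAttention 3 active"
    if use_fa2:
        return "FlashAttention 2 active"
    if has_fa2:
        return "FlashAttention 2 installed but not selected"
    return "no FlashAttention backend detected"


try:
    from nanochat.flash_attention import HAS_FA2, USE_FA2, USE_FA3
    print(describe_backend(HAS_FA2, USE_FA2, USE_FA3))
except ImportError:
    print("nanochat.flash_attention is not importable here")
```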

On MI300X today, that combination (HAS_FA2 and USE_FA2 true, USE_FA3 false) is the ideal result.


Your test only shows that the kernels run and return the right shapes, not that they run fast.

To verify that the fused kernel is actually running efficiently:

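import time
import torch

# Warm up so one-time setup cost is not included in the measurement
for _ in range(10):
    flash_attn_func(q, k, v, causal=True)

torch.cuda.synchronize()  # drain pending GPU work before starting the clock

start = time.perf_counter()

for _ in range(100):
    out = flash_attn_func(q, k, v, causal=True)

torch.cuda.synchronize()  # kernels launch asynchronously; wait for them to finish

elapsed = time.perf_counter() - start

print("100 iterations:", elapsed, "s")
print("ms/iter:", elapsed * 1000 / 100)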

Compare this with PyTorch’s standard attention:

torch.nn.functional.scaled_dot_product_attention

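On an MI300X, FlashAttention 2 should generally be faster and use less memory than unfused attention at long sequence lengths, because it never materializes the full sequence-by-sequence score matrix. Note that scaled_dot_product_attention may itself dispatch to a fused kernel, in which case the gap can be small.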
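One way to set up that comparison is sketched below. The helper names `time_ms` and `compare_fa2_with_sdpa` are hypothetical, and the GPU-dependent imports are kept inside the second function so the timing helper can be reused on its own. The main detail to watch is layout: flash_attn_func takes (batch, seqlen, heads, head_dim), while scaled_dot_product_attention takes (batch, heads, seqlen, head_dim).

```python
import time


def time_ms(fn, sync=lambda: None, iters=100, warmup=10):
    """Average milliseconds per call of fn, with warm-up and synchronization."""
    for _ in range(warmup):
        fn()
    sync()  # drain pending work before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()  # wait for asynchronous kernels to finish
    return (time.perf_counter() - start) * 1000 / iters


def compare_fa2_with_sdpa(seqlen=1024):
    """Time flash_attn_func against SDPA on identical inputs (needs a GPU)."""
    import torch
    import torch.nn.functional as F
    from flash_attn import flash_attn_func

    q, k, v = (
        torch.randn(2, seqlen, 12, 128, device="cuda", dtype=torch.bfloat16)
        for _ in range(3)
    )
    # Transpose once, outside the timed loop, so layout conversion is not measured
    qt, kt, vt = (t.transpose(1, 2).contiguous() for t in (q, k, v))

    fa2 = time_ms(
        lambda: flash_attn_func(q, k, v, causal=True),
        torch.cuda.synchronize,
    )
    sdpa = time_ms(
        lambda: F.scaled_dot_product_attention(qt, kt, vt, is_causal=True),
        torch.cuda.synchronize,
    )
    return {"fa2_ms": fa2, "sdpa_ms": sdpa}
```

Running `compare_fa2_with_sdpa` at several values of `seqlen` shows how the gap changes with sequence length. A single measurement at 1024 tokens may not be representative.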


If all of the following succeed:

flash_attn_func imported OK
FA2 forward pass: OK
FA2 sliding window: OK
HAS_FA2: True
USE_FA2: True

then FlashAttention 2 is not merely installed: it is executing attention kernels successfully on your MI300X, and NanoChat can use it during training.

