Together AI Inference Optimization | Generated by AI

2026.06

Question: What does Tri Dao work on at Together AI to optimize inference?

Answer:

Tri Dao is co-founder and chief scientist of Together AI, and also an Assistant Professor at Princeton. His research sits at the intersection of ML and systems, focusing on hardware-aware algorithms and sequence models.

Here’s what he’s actually doing at Together AI on inference:


1. FlashAttention Series (Core Work)

The FlashAttention line is his main contribution, and it (or a descendant of it) is now the attention kernel in most serious training and inference stacks.

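The core insight: standard attention reads Q, K, and V from HBM (slow), computes the N×N score matrix, writes it back to HBM, then reads it again for softmax and again for the multiply by V. That is a large waste of memory bandwidth. FA tiles the computation into SRAM-resident blocks and fuses QK^T → softmax → PV into a single kernel, so the N×N matrix is never materialized in HBM. This works because softmax can be computed block by block, keeping a running row max and row sum and rescaling earlier partial results (online softmax). Extra memory drops from O(N²) to O(N), and HBM accesses drop from Θ(Nd + N²) to Θ(N²d²/M), where d is the head dimension and M the SRAM size; since d² is typically much smaller than M, that is many times fewer accesses.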
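A minimal NumPy sketch of that tiled forward pass with online-softmax rescaling, checked against a naive implementation. It assumes a single head, no causal mask, and arbitrary block sizes; it models the arithmetic only, not the GPU memory hierarchy, and the function names are illustrative:

```python
import numpy as np

def naive_attention(Q, K, V):
    """Reference: materializes the full N x N score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block_q=16, block_k=16):
    """Tiled forward pass: only a block_q x block_k score tile exists at any time."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, V.shape[1]))
    for qs in range(0, N, block_q):
        q = Q[qs:qs + block_q]                        # query tile (would sit in SRAM)
        m = np.full(q.shape[0], -np.inf)              # running row max
        row_sum = np.zeros(q.shape[0])                # running softmax denominator
        acc = np.zeros((q.shape[0], V.shape[1]))      # unnormalized output
        for ks in range(0, K.shape[0], block_k):
            k, v = K[ks:ks + block_k], V[ks:ks + block_k]
            s = q @ k.T * scale                       # scores for this tile only
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])
            correction = np.exp(m - m_new)            # rescale earlier partial sums
            row_sum = row_sum * correction + p.sum(axis=1)
            acc = acc * correction[:, None] + p @ v
            m = m_new
        O[qs:qs + block_q] = acc / row_sum[:, None]   # normalize once per query tile
    return O

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((64, 32)) for _ in range(3))
print(np.allclose(tiled_attention(Q, K, V), naive_attention(Q, K, V)))  # True
```

The division by the row sum is deferred to the end of each query tile, which is the same choice FlashAttention-2 makes to cut non-matmul work.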

FlashAttention-2: Rewritten from scratch using primitives from NVIDIA's CUTLASS 3.x and its core library CuTe. Roughly 2x faster than FA1, reaching up to 230 TFLOPs/s on A100 (FP16/BF16), and up to 225 TFLOPs/s per A100 (72% model FLOPs utilization) when training GPT-style models end to end.

FlashAttention-3 (H100-specific): FA2 achieved only about 35% utilization of the H100's theoretical max FLOPs. FA3 exploits Hopper-specific asynchronous instructions, WGMMA (warpgroup matrix multiply-accumulate) and TMA (Tensor Memory Accelerator), to overlap data movement with compute. FP16 forward throughput goes from ~350 TFLOPS (FA2) to ~540-570 TFLOPS, with a reported peak of about 740 TFLOPS (roughly 75% utilization). One key technique is warp specialization: producer warps issue TMA loads while consumer warpgroups run WGMMA, hiding memory latency behind compute. FA3 also overlaps softmax with the matmuls, across and within warpgroups, and adds an FP8 path.
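The producer/consumer split can be illustrated on the CPU. This is an analogy only: FA3 does it with warps, TMA, and WGMMA on the GPU, while the sketch below uses one Python thread as the producer, another as the consumer, and a bounded queue as a stand-in for the multi-stage shared-memory buffer. The function name, tile size, and stage count are illustrative assumptions, and no meaningful speedup should be expected from Python threads:

```python
import queue
import threading
import numpy as np

def pipelined_matmul(A, B, tile=32, stages=2):
    """C = A @ B computed as a producer/consumer pipeline over K-dimension tiles.
    Hypothetical helper: the producer plays the role of FA3's load (TMA) warps,
    the consumer the role of its compute (WGMMA) warpgroups."""
    buf = queue.Queue(maxsize=stages)   # bounded, like a ring of shared-memory stages
    n_tiles = -(-A.shape[1] // tile)    # ceiling division
    C = np.zeros((A.shape[0], B.shape[1]))

    def producer():
        for t in range(n_tiles):
            ks = slice(t * tile, (t + 1) * tile)
            # put() blocks while every stage is full, like waiting on a barrier
            buf.put((A[:, ks].copy(), B[ks, :].copy()))
        buf.put(None)                   # sentinel: no more tiles

    def consumer():
        nonlocal C
        while (item := buf.get()) is not None:
            a, b = item
            C += a @ b                  # accumulate this tile's partial product

    workers = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    return C

rng = np.random.default_rng(0)
A, B = rng.standard_normal((20, 100)), rng.standard_normal((100, 12))
print(np.allclose(pipelined_matmul(A, B), A @ B))  # True
```

The bounded buffer is the important design point: the producer can run ahead by at most `stages` tiles, which is what keeps shared-memory use fixed on the GPU.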

FlashAttention-4 (Blackwell): On Blackwell, MMA accumulators live in tensor memory (TMEM) rather than in registers as on Hopper, which makes it practical to keep multiple MMAs in flight while CUDA cores handle element-wise work (softmax, dS). In the backward pass, for example, the dK and dQ MMAs for tile j-1 are already issued while softmax for tile j is being computed.
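The payoff of that overlap can be estimated with a toy timeline. This is a simplified model, not FA4's actual scheduler: it assumes fixed per-tile costs for element-wise work on CUDA cores and MMA work on tensor cores, that each tile's MMA needs that tile's element-wise result, and that buffers are deep enough that the CUDA cores never stall. Function names and costs are illustrative:

```python
def serial_time(n_tiles, elem, mma):
    """Each tile's element-wise work, then its MMA, strictly one after the other."""
    return n_tiles * (elem + mma)

def overlapped_time(n_tiles, elem, mma):
    """Two units run concurrently: CUDA cores (element-wise) and tensor cores (MMA).
    Tile j's MMA waits for tile j's element-wise result, but the CUDA cores move
    on to tile j+1 as soon as tile j's MMA is issued."""
    cuda_free = 0.0   # time the CUDA cores become idle
    tc_free = 0.0     # time the tensor cores become idle
    for _ in range(n_tiles):
        elem_done = cuda_free + elem              # e.g. softmax / dS for this tile
        cuda_free = elem_done                     # start the next tile immediately
        tc_free = max(tc_free, elem_done) + mma   # MMA needs its input and a free unit
    return tc_free

print(serial_time(8, 1.0, 1.0), overlapped_time(8, 1.0, 1.0))  # 16.0 9.0
```

In this model the overlapped time is elem + mma + (n_tiles - 1) * max(elem, mma): with balanced costs the pipeline nearly halves the total, and when one side dominates, the total approaches that side's cost alone.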


2. Together Kernel Collection

At the heart of Together AI’s infrastructure is the Together Kernel Collection, which the company describes as “a breakthrough in AI system optimization developed under the leadership of Chief Scientist Tri Dao.” Together AI claims it delivers 2-3x faster inference than hyperscaler offerings. It presumably bundles FlashAttention with custom decode and MoE routing kernels, though that breakdown is an inference rather than a published inventory.


3. Three Inference Regimes (His Framework)

Tri Dao identifies three distinct inference optimization targets:

Each regime needs its own kernel and scheduling strategy; no single configuration fits all of them.


4. Hybrid Mamba-Transformer + Marconi Prefix Caching

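Hybrid Mamba-Transformer models are reported to match or exceed leading open models while delivering faster inference and million-token context. The Mamba (SSM) layers carry a fixed-size recurrent state instead of a KV cache that grows linearly with context length, so replacing most attention layers with SSM layers shrinks decode-time memory and per-token cost at long context.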
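The memory side of that claim is simple arithmetic. A back-of-the-envelope sketch for a single sequence, assuming a hypothetical 32-layer model with grouped-query attention (8 KV heads of dimension 128) in 16-bit precision, versus a hypothetical hybrid that keeps 4 attention layers and replaces the other 28 with SSM layers (d_inner 8192, d_state 128). These dimensions are illustrative, not taken from any specific model:

```python
def kv_cache_bytes(seq_len, n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    """Attention layers: K and V per token per layer, so linear in seq_len."""
    return 2 * seq_len * n_layers * n_kv_heads * head_dim * bytes_per_elem

def ssm_state_bytes(n_layers, d_inner, d_state, bytes_per_elem=2):
    """SSM layers: one fixed-size recurrent state per layer, independent of seq_len.
    (Ignores the small convolution buffer Mamba layers also keep.)"""
    return n_layers * d_inner * d_state * bytes_per_elem

GIB = 2**30
for seq in (8_192, 131_072, 1_048_576):
    dense = kv_cache_bytes(seq, n_layers=32, n_kv_heads=8, head_dim=128)
    hybrid = (kv_cache_bytes(seq, n_layers=4, n_kv_heads=8, head_dim=128)
              + ssm_state_bytes(n_layers=28, d_inner=8192, d_state=128))
    print(f"{seq:>9} tokens: dense {dense / GIB:.2f} GiB, hybrid {hybrid / GIB:.2f} GiB")
```

The hybrid still keeps a KV cache for its remaining attention layers, so cache growth is cut by roughly the fraction of attention layers removed rather than eliminated.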

Two recent papers from his group, listed on his publications page, target inference throughput directly: “Marconi: Prefix Caching for the Era of Hybrid LLMs” and “Opportunistic Expert Activation: Batch-Aware Expert Routing for Faster Decode Without Retraining.” The expert-activation paper is especially relevant to DeepSeek-style MoE deployments.
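The difficulty Marconi addresses can be shown with a toy cache. In a hybrid model, attention-layer KV entries are stored per token, so any matching prefix is reusable; an SSM layer's recurrent state, however, is updated in place and cannot be rolled back, so a prefix hit is only usable up to a position where an SSM state was saved. The class below is a hypothetical illustration of that constraint only; it does not implement Marconi's actual caching policies, and all names are made up for this example:

```python
def longest_common_prefix(a, b):
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

class HybridPrefixCache:
    """Toy cache: KV reusable for any prefix, SSM state only at saved checkpoints."""

    def __init__(self):
        self.sequences = []          # cached token sequences (per-token KV)
        self.ssm_checkpoints = {}    # tuple(prefix) -> SSM state saved at that position

    def insert(self, tokens, ssm_states_at):
        """ssm_states_at maps positions we chose to checkpoint to their SSM state."""
        self.sequences.append(list(tokens))
        for pos, state in ssm_states_at.items():
            self.ssm_checkpoints[tuple(tokens[:pos])] = state

    def reusable_prefix_len(self, tokens, hybrid=True):
        kv_hit = max((longest_common_prefix(s, tokens) for s in self.sequences), default=0)
        if not hybrid:
            return kv_hit            # pure attention: any matching prefix is reusable
        # Hybrid: also need an SSM state saved exactly at the reuse boundary.
        for n in range(kv_hit, 0, -1):
            if tuple(tokens[:n]) in self.ssm_checkpoints:
                return n
        return 0

cache = HybridPrefixCache()
system_prompt = [1, 2, 3, 4]
cache.insert(system_prompt + [10, 11, 12], ssm_states_at={4: "state@4", 7: "state@7"})
query = system_prompt + [10, 99]
print(cache.reusable_prefix_len(query, hybrid=False), cache.reusable_prefix_len(query))  # 5 4
```

The gap between the two numbers is why deciding which positions to checkpoint, and which entries to keep, matters more for hybrid models than for pure Transformers.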


TL;DR Stack

| Layer | What he does |
|---|---|
| Attention kernel | FlashAttention 1/2/3/4: SRAM tiling, IO-awareness, Hopper/Blackwell async pipelining |
| Architecture | Mamba/SSM hybrids that replace most of the linearly growing KV cache with fixed-size state |
| MoE | Opportunistic expert activation for faster decode |
| KV cache | Marconi prefix caching for hybrid LLMs |
| System | Together Kernel Collection, targeting 3 inference regimes |
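The MoE row refers to batch-aware expert routing at decode time. In memory-bound decode, an MoE layer's cost is driven largely by how many distinct experts' weights the batch forces the GPU to read. The sketch below illustrates that general idea with a hypothetical greedy policy (each token keeps its top `k_keep` experts, then prefers experts the batch already activates); it is not the paper's algorithm, the function names and parameters are made up, and a real router would also renormalize gate weights over the chosen experts:

```python
import numpy as np

def standard_topk(scores, k):
    """Per-token top-k routing that ignores the rest of the batch."""
    return np.argsort(-scores, axis=1)[:, :k]

def batch_aware_topk(scores, k, k_keep=1):
    """Hypothetical batch-aware policy (not the paper's algorithm): each token keeps
    its own top k_keep experts, then fills its remaining slots with its best-scoring
    experts that some token in the batch already activates, falling back to its
    own next-best experts only when there are not enough of those."""
    order = np.argsort(-scores, axis=1).tolist()
    active = {e for row in order for e in row[:k_keep]}   # experts loaded anyway
    routes = []
    for row in order:
        chosen = row[:k_keep]
        chosen += [e for e in row[k_keep:] if e in active][: k - k_keep]
        for e in row[k_keep:]:          # fallback: token's own next-best experts
            if len(chosen) == k:
                break
            if e not in chosen:
                chosen.append(e)
        routes.append(chosen)
    return np.array(routes)

rng = np.random.default_rng(0)
scores = rng.standard_normal((16, 64))     # 16 decode tokens, 64 experts (made up)
std, ba = standard_topk(scores, 4), batch_aware_topk(scores, 4)
print("distinct experts loaded:", len(np.unique(std)), "vs", len(np.unique(ba)))
```

Every expert this policy picks is either already needed by the batch or within the token's own top-k, so it never loads more distinct experts than standard top-k; the trade-off is that tokens may lose some of their preferred experts.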

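The throughput numbers matter for you concretely: if you’re doing RL rollouts or synthetic data generation on your MI300X, check the AMD ROCm support in Dao-AILab/flash-attention. FA3 and FA4 themselves are built on Hopper and Blackwell features (TMA, WGMMA, TMEM), so they do not run on MI300X as-is. Also note that the 35% versus 75%+ comparison is FLOPs utilization of the attention kernel on H100, not HBM utilization; treat it as a sign of how much headroom a well-tuned kernel can recover, not as a number that transfers directly to AMD hardware.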
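Those utilization percentages are achieved attention FLOPs/s divided by the GPU's peak tensor-core throughput. A quick check, assuming dense FP16/BF16 peaks of 312 TFLOPS for A100 and about 989 TFLOPS for H100 SXM from NVIDIA's published specs; MI300X has its own, different peak, so these ratios should not be expected to carry over:

```python
# Assumed dense FP16/BF16 tensor-core peaks (NVIDIA spec sheets; H100 = SXM variant).
PEAK_TFLOPS = {"A100": 312.0, "H100_SXM": 989.0}

def utilization(achieved_tflops, gpu):
    """Fraction of the GPU's peak tensor-core throughput actually achieved."""
    return achieved_tflops / PEAK_TFLOPS[gpu]

for label, tflops, gpu in [
    ("FA2 forward, A100", 230, "A100"),
    ("FA2 forward, H100", 350, "H100_SXM"),
    ("FA3 forward, H100", 570, "H100_SXM"),
]:
    print(f"{label}: {utilization(tflops, gpu):.0%}")

print(f"75% of H100 peak: {0.75 * PEAK_TFLOPS['H100_SXM']:.0f} TFLOPS")
```

So the 35% figure is FA2's ~350 TFLOPS measured against the H100 peak, and 75% utilization corresponds to roughly 740 TFLOPS.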
