DeepSeek V4 in vLLM: Efficient Long-context Attention

17 min read
vLLM Team

We are excited to announce that vLLM now supports the DeepSeek V4 family of models (deepseek-ai/DeepSeek-V4-Pro and deepseek-ai/DeepSeek-V4-Flash).

These models feature an efficient long-context attention mechanism, purpose-built for tasks involving up to one million tokens. While the new attention design may appear intricate on first reading, its underlying principles are straightforward once examined systematically.

This blog post is organized into three sections:

  • Quickstart guide for serving DeepSeek V4 on vLLM
  • First-principles explanation of DeepSeek V4's new architectural design
  • Overview of our implementation approach and optimization challenges for this model on vLLM: hybrid KV cache, kernel fusion, and disaggregated serving.

This represents our initial release of model support, and further optimizations are actively underway. We hope the technical explanation that follows can help the open-source community understand both the attention mechanism itself and the rationale behind our current implementation decisions.

Running DeepSeek V4 on vLLM

DeepSeek V4 comes in two variants: the large 1.6T-parameter DeepSeek-V4-Pro and the smaller 285B-parameter DeepSeek-V4-Flash. Both models support up to 1 million tokens of context, and vLLM's implementation of the new attention mechanism is designed to scale to that context length.

DeepSeek-V4-Pro

Here we highlight a single-node deployment optimized for easy testing and prototyping, with several optional optimizations enabled, such as the FP4 indexer cache and MTP. The following command is runnable on 8xB200 or 8xB300.

docker run --gpus all \
  --ipc=host -p 8000:8000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  vllm/vllm-openai:deepseekv4-cu130 deepseek-ai/DeepSeek-V4-Pro \
  --trust-remote-code \
  --kv-cache-dtype fp8 \
  --block-size 256 \
  --enable-expert-parallel \
  --data-parallel-size 8 \
  --compilation-config '{"cudagraph_mode":"FULL_AND_PIECEWISE", "custom_ops":["all"]}' \
  --attention_config.use_fp4_indexer_cache=True \
  --tokenizer-mode deepseek_v4 \
  --tool-call-parser deepseek_v4 \
  --enable-auto-tool-choice \
  --reasoning-parser deepseek_v4

For more deployment strategies, including disaggregated serving and additional GPU architectures, please refer to the recipes.

DeepSeek-V4-Flash

Here we highlight a single-node deployment optimized for easy testing and prototyping, with several optional optimizations enabled, such as the FP4 indexer cache and MTP. The following command is runnable on 4xB200 or 4xB300.

docker run --gpus all \
  --ipc=host -p 8000:8000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  vllm/vllm-openai:deepseekv4-cu130 deepseek-ai/DeepSeek-V4-Flash \
  --trust-remote-code \
  --kv-cache-dtype fp8 \
  --block-size 256 \
  --enable-expert-parallel \
  --data-parallel-size 4 \
  --compilation-config '{"cudagraph_mode":"FULL_AND_PIECEWISE", "custom_ops":["all"]}' \
  --attention_config.use_fp4_indexer_cache=True \
  --tokenizer-mode deepseek_v4 \
  --tool-call-parser deepseek_v4 \
  --enable-auto-tool-choice \
  --reasoning-parser deepseek_v4

For more deployment strategies, including disaggregated serving and additional GPU architectures, please refer to the recipes.

DeepSeek V4's Attention Mechanism Explained

Long-context inference faces two main challenges:

  • KV cache memory growth: The KV cache scales linearly with context length. While DeepSeek-style models use Multi-head Latent Attention (MLA), which is substantially more memory-efficient than standard Multi-head Attention (MHA) or Multi-Query Attention (MQA), scaling to one million tokens remains difficult given the limited capacity of GPU memory.
  • Attention computation cost: Computing attention over long contexts is expensive. Even with prior techniques such as DeepSeek Sparse Attention (DSA), the computation remains a significant bottleneck.

To address these challenges, the DeepSeek team designed a new attention mechanism aimed at both compressing the KV cache and reducing attention computation time.

  1. Share key and value vectors (2x memory savings). For correctness, we apply an inverse RoPE operation to the attention output.
  2. Compress the KV cache across multiple tokens (4x to 128x memory savings). In DeepSeek V4, there are two ways to do this:
    • c4a: compress the KV cache by roughly 1/4. One compressed token is a weighted sum of 8 uncompressed tokens, with a stride of 4.
    • c128a: compress the KV cache by roughly 1/128. One compressed token is a weighted sum of 128 uncompressed tokens, with a stride of 128.
  3. DeepSeek Sparse Attention (bounded attention computation cost). Even after compressing the KV cache with c4a attention, a one-million-token sequence still has 250k compressed tokens. To accelerate the attention computation, we can use DeepSeek Sparse Attention (DSA) to attend to only the top-$k$ compressed tokens.
  4. Preserving locality: Short sliding window. DeepSeek V4 uses a sliding window of size 128 for local information, operating on the uncompressed tokens, so that a query token can attend to local information before it reaches the compression boundary.
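The visibility rules above can be sketched in a few lines. This is an illustrative sketch of ours, not vLLM code; the index conventions (compressed token $i$ mixes the 8 positions $[4i-4,\,4i+3]$ with stride 4, negative positions zero-padded, sliding window of 128) are our assumed reading of the scheme.

```python
WINDOW = 128          # sliding-window size, in uncompressed tokens
STRIDE, GROUP = 4, 8  # c4a: one compressed token per 4 positions, mixing 8

def c4a_visible_compressed(t: int) -> range:
    # Causality: compressed token i is visible only if every position it
    # mixes is <= t, i.e. 4*i + 3 <= t (assumed anchor convention).
    return range(0, (t - 3) // STRIDE + 1) if t >= 3 else range(0)

def swa_visible_uncompressed(t: int) -> range:
    # The sliding window exposes the last WINDOW uncompressed positions.
    return range(max(0, t - WINDOW + 1), t + 1)
```

For example, a query at position 7 can see compressed tokens 0 and 1, while a query at position 2 sees none and relies entirely on the sliding window for local context.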

To better illustrate this new attention mechanism, here's an animation of the c4a attention processing 13 tokens. With the details above in mind, the c128a case should be straightforward to follow as well. Launch the interactive version to hover over tokens and inspect the connections.

Animation of c4a attention

The efficient attention design leads to substantial KV cache savings. With bf16 KV cache, DeepSeek V4 only has 9.62 GiB KV cache per sequence at 1M context. That is about 8.7x smaller than the 83.9 GiB estimate for a 61-layer DeepSeek V3.2-style stack. In practice, we use fp4 for the indexer cache and fp8 for the attention cache, which further reduces the KV cache size by roughly 2x compared to the bf16 estimate!

Per-layer KV state in DeepSeek V3.2 versus DeepSeek V4.

For more detail on the arithmetic and the mathematical interpretation, please refer to the appendix.

vLLM's Implementation of DeepSeek V4

Despite the structural savings, the attention mechanism still carries intrinsic complexity, and realizing those savings efficiently in vLLM is a systems problem with several implementation challenges:

  • Similar to the DeepSeek V3.2 model, the attention kernel uses a bfloat16 KV cache for prefill and a partially token-wise fp8 KV cache for decode.
  • The model uses a mix of c4a and c128a attention, and some attention layers use purely a sliding window for local information without compression. The heterogeneous attention types make KV cache management much more complex.
  • When batching multiple sequences, they might have different states with respect to the KV cache compression boundary.
  • The model ships with native fp4 MoE weights, which require special handling in vLLM.

Aside from the attention mechanism itself, there are several other updates, including architecture changes like Manifold-Constrained Hyper-Connections and some changes to the MoE module. They are not covered in this post, as they are simpler model changes that were straightforward to adapt.

vLLM addresses these challenges with optimizations on two fronts: memory management and kernel efficiency.

Keeping the KV Cache Memory Packed

vLLM's KV cache memory allocator has to pack several kinds of KV state tightly in GPU memory while still working with prefix caching, prefill/decode disaggregation, CUDA graphs, and the rest of vLLM's serving path. Three design choices keep this manageable.

(1) A single logical block size

Different layers compress at different rates (1/4 for c4a, 1/128 for c128a, 1/1 for SWA). An obvious design is to size each layer's block around a round number of compressed entries. But then every layer gets its own page layout, and the allocator has to reason about all of them separately.

Instead, we fix the logical block at 256 native token positions for every compressed layer. A c4a block then physically holds 256 / 4 = 64 compressed entries, and a c128a block holds 256 / 128 = 2. Allocating a block always means reserving the next 256 native positions of a request's context, regardless of which layer owns it. Slot mapping, scheduler accounting, and prefix-hit detection can all use that same unit instead of branching on compress_ratio.
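This bookkeeping is simple enough to state in code. A minimal sketch (ours, not vLLM's actual allocator) of the "single logical block size" arithmetic:

```python
BLOCK_SIZE = 256  # native token positions per logical block, for every layer

# Native positions per physical cache entry for each layer type.
COMPRESS_RATIO = {"c4a": 4, "c128a": 128, "swa": 1}

def entries_per_block(layer_type: str) -> int:
    # A c4a block holds 64 compressed entries; a c128a block holds 2.
    return BLOCK_SIZE // COMPRESS_RATIO[layer_type]

def blocks_needed(context_len: int) -> int:
    # The scheduler reasons in 256-position units regardless of layer type.
    return -(-context_len // BLOCK_SIZE)  # ceiling division
```

Because every layer consumes blocks at the same rate in native positions, slot mapping and prefix-hit detection never branch on the compression ratio.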

(2) Compressor state as a sliding window

Each compressor layer also maintains a small rolling residual per request: an 8-token (overlapped) partial state for C4, and a 128-token partial state for C128. A natural first design is to keep that residual in a per-request side buffer. That works in isolation, but it becomes awkward once it has to interact with the rest of the serving stack.

With a side buffer, prefix caching would need to snapshot the rolling state at every cacheable boundary, key it alongside the prefix hash, and restore it on a hit. Disaggregated prefill would need a second transfer path that ships residuals from prefill workers to decode workers alongside the KV blocks. Each requirement is manageable on its own, but together they create another state-management path to maintain across features.

vLLM avoids this by treating the compressor state like sliding-window KV. The runtime invariant is the same: fixed size per request, advanced as decoding proceeds, with state outside the window either discarded or handled through caching. So we register the compressor state under the sliding-window KV cache spec, with sliding_window = coff * compress_ratio (8 for C4 and 128 for C128), and place it into SWA-style blocks under the same hybrid KV cache manager.

This lets several serving features reuse the same abstraction:

  • Prefix caching reuses the normal block semantics. A cache hit lands on a KV cache block boundary (the 256-position unit above), and the compressor state at that boundary is already the correct handoff point.
  • Disaggregated prefill treats the compressor state like SWA state. Only the blocks inside the window are transferred, which preserves the transfer-size savings without introducing a separate residual-specific transfer path.
  • CUDA graphs and MTP follow the same integration pattern as SWA, while keeping metadata and implementation details specific to the compressor state.
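The registration idea in (2) can be sketched as follows. The class and function names here are ours for illustration, not vLLM's real API; the only load-bearing fact is that the compressor residual is described by the same sliding-window spec as ordinary SWA state.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class SlidingWindowSpec:
    # Tokens of rolling state kept per request; everything older is evictable.
    sliding_window: int

def compressor_state_spec(compress_ratio: int, overlap_factor: int) -> SlidingWindowSpec:
    # C4 keeps an 8-token overlapped residual (2 * 4); C128 keeps 128 (1 * 128).
    # "overlap_factor" is our name for the multiplier in the text.
    return SlidingWindowSpec(sliding_window=overlap_factor * compress_ratio)
```

With the residual expressed this way, the hybrid KV cache manager, prefix caching, and disaggregated transfer all see one more sliding-window cache rather than a new kind of state.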

(3) Unifying page sizes

The two choices above are still not enough. A C4 indexer block, a c128a KV block, and a c4a compressor-state block still come in different page sizes (different numbers of bytes per block). If each cache kind gets its own block pool, we end up with the same cross-pool fragmentation we were trying to eliminate.

Fortunately, the page size of each cache kind is the product block_size * compress_ratio * per_entry_size, and all three factors are under our control. If we choose them carefully, the different cache kinds collapse into a small number of page-size buckets, and each bucket can be backed by a single shared block pool.

In our implementation, the entire five-way cache stack fits into three page sizes. Each pool is sized once at load time, and allocation becomes a bucket lookup. There is no runtime repartitioning, no per-kind accounting, and no fragmentation between cache kinds.

  • Largest bucket: c4a main KV, SWA KV, c4a compressor state, c128a compressor state.
  • Middle bucket: C4 indexer KV, C4 indexer compressor state.
  • Smallest bucket: c128a main KV.
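The bucketing in (3) amounts to a small amount of arithmetic at load time. In this sketch the per-entry byte sizes are hypothetical placeholders, not the model's real dimensions; the point is only that page size = (entries per block) × (bytes per entry) collapses the cache kinds into a few buckets, each backed by one shared pool.

```python
from collections import defaultdict

BLOCK_SIZE = 256  # native positions per logical block

def page_bytes(compress_ratio: int, per_entry_bytes: int) -> int:
    return (BLOCK_SIZE // compress_ratio) * per_entry_bytes

# kind -> (native positions per entry, bytes per entry); sizes illustrative only.
CACHE_KINDS = {
    "c4a_main_kv":      (4,   1024),
    "swa_kv":           (1,   256),
    "c4a_comp_state":   (1,   256),  # stored SWA-style at native granularity
    "c128a_comp_state": (1,   256),
    "c4a_indexer_kv":   (4,   256),
    "c128a_main_kv":    (128, 1024),
}

buckets = defaultdict(list)
for kind, (ratio, entry) in CACHE_KINDS.items():
    buckets[page_bytes(ratio, entry)].append(kind)
# With these placeholder sizes, the six kinds land in three page-size buckets.
```

Allocation then reduces to a dictionary lookup on the page size, with no runtime repartitioning between cache kinds.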

Keeping the GPU Busy

Memory layout is only half of the runtime story; the other half is keeping the GPU compute saturated.

vLLM integrates FlashMLA and FlashInfer, which provide optimized attention and MoE kernels. But this model requires many small, mostly memory-bound kernels. We need to avoid extra launches and HBM round-trips that would otherwise slow the full decode path.

c4a decode path: operator graph with kernel fusions (colored outlines) and multi-stream partitioning (default stream = blue band, indexer stream = amber band).

(1) Kernel Fusion

We deploy three fusions to cut memory round-trips. In the figure above, these appear as the colored outlines around groups of operators.

  • Compressor + RMSNorm + RoPE + cache insertion. After compression, the compressed K immediately goes through RMSNorm, RoPE, and insertion into the following attention's KV cache, either for main attention or for the indexer. Because these stages are almost entirely elementwise, we fuse them into one kernel. We keep separate kernels for the indexer K cache and the main-attention K cache so the parallelization strategy can still be tuned to each head dim. Overall we see a ~1.4-3x speedup over the unfused baseline.
  • Inverse RoPE + fp8 quant. After main attention, the output goes through inverse RoPE and then into the fp8 batched matmul for the o_lora projection. Fusing the two avoids a back-to-back HBM round trip and raises arithmetic intensity, for a ~2-3x speedup over the unfused version.
  • Fused Q norm + KV RoPE + K insert. Before main attention, we need KV cache insertion for both the compressed path and the sliding-window path. The compressed path is already covered by the first fusion, so what remains is elementwise work on the queries and the uncompressed SWA keys. We horizontally fuse that work into a single kernel with static warpID dispatch: each warp works independently on either a Q head or a K head, so no cross-warp communication is needed. This delivers a 10-20x speedup over the naive unfused kernels.
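The static warp-ID dispatch in the third fusion can be modeled in a few lines. This is a toy model of ours in Python, with placeholder head counts; the real kernel makes the same decision per warp in CUDA, at compile time, so no cross-warp communication is ever needed.

```python
# Placeholder head counts; the real model's values may differ.
NUM_Q_HEADS, NUM_K_HEADS = 16, 4

def warp_assignment(warp_id: int):
    # Warps [0, NUM_Q_HEADS) each own one Q head (norm + RoPE); the remaining
    # warps each own one uncompressed SWA K head (RoPE + cache insert).
    if warp_id < NUM_Q_HEADS:
        return ("q_norm_rope", warp_id)
    return ("k_rope_insert", warp_id - NUM_Q_HEADS)
```

Because the mapping is a pure function of the warp ID, every warp proceeds independently, which is what makes the horizontal fusion cheap.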

We also reuse fusions from our DeepSeek V3.2 work, including Q RoPE + quant + weight multiply, and the horizontal fusion of QK norm right after QK projection at the start of attention.

(2) Multi-stream

The operations before main attention are highly parallelizable. They break into three pieces: indexer computation, main-attention KV compression, and sliding-window token insertion. After the initial projection these branches are almost independent, so we overlap them across CUDA streams. The same figure can be read a second way here: the blue band marks the default stream, while the amber band marks the indexer stream.

  • For c128a layers, which have no indexer, we run main KV compression in parallel with SWA token insertion.
  • For c4a layers, we run the full indexer pipeline on its own stream in parallel with main KV compression and SWA token insertion (the latter two remain serial with respect to each other).

With these overlaps, we observe a 5-6% end-to-end latency reduction at low batch sizes, a useful sign that the decode path spends less time underutilizing the GPU.

On top of that, we use CUDA graphs to cut launch overhead on the decode path, as we do for every other model.

For the full implementation, see the PR.

Planned Work

We are actively working on the following optimizations to further improve the performance of DeepSeek V4 on vLLM:

  • DeepGEMM MegaMoE kernel
  • Paged prefill kernel

The current implementation mainly targets NVIDIA GPUs, including both the Hopper and Blackwell architectures. The deployment recipes for these accelerators can be found at our recipe website. With vLLM's extensible plugin system, hardware vendors can add support for models directly. For example, vllm-ascend and vllm-mlu both support DeepSeek V4 independently.

Acknowledgments

We want to thank the DeepSeek team for open-sourcing DeepSeek V4, as well as DeepSeek leadership for their trust and support in vLLM! The model support is made possible by the contributions from Inferact Inc., a company aiming to grow vLLM as the world's AI inference engine and accelerate AI progress by making inference cheaper and faster.

Appendix: The Math behind DeepSeek V4's Attention Mechanism

Why inverse RoPE is needed when key and value are shared

Given a query token at position $m$, the query representation after applying RoPE is $R_m q$, where $R_m$ is the rotation matrix with the rotation angles parameterized by the position $m$. Some basic properties of the rotation matrix are:

  • $R_m$ is an orthogonal matrix, i.e., $R_m^\top R_m = I$.
  • Rotations compose additively: $R_m R_n = R_{m+n}$, and therefore $R_m^\top R_n = R_{n-m}$.
Given a set of key tokens at positions $1, 2, \dots, n$, the key representations after applying RoPE are $R_1 k_1$, $R_2 k_2$, ..., $R_i k_i$, ..., $R_n k_n$.

For value vectors at positions $1, 2, \dots, n$, usually we don't apply RoPE to them. The value representations are simply $v_1$, $v_2$, ..., $v_i$, ..., $v_n$.

The attention output is then (omitting some details, such as the scaling factor, for simplicity):

$$o_m = \sum_i \operatorname{softmax}_i\big((R_m q)^\top R_i k_i\big)\, v_i$$

One nice property of the attention output is that it is translation invariant. Any factor that depends on position, namely $R_m$ and $R_i$, enters only through the product $(R_m q)^\top R_i k_i = q^\top R_{i-m} k_i$, which depends only on the relative position between the query and the key. This means the attention output is the same if we shift the query and the key by the same amount.

If we share the key and value vectors, the attention output will be:

$$o_m = \sum_i \operatorname{softmax}_i\big(q^\top R_{i-m} k_i\big)\, R_i k_i$$

Now the output carries absolute position information through the rotation matrix $R_i$ directly. This is not what we want. The way to fix it is simple: we apply an inverse RoPE operation to the attention output:

$$\tilde{o}_m = R_m^\top o_m = \sum_i \operatorname{softmax}_i\big(q^\top R_{i-m} k_i\big)\, R_{i-m} k_i$$

This way, the output only carries relative position information through the rotation matrix $R_{i-m}$, and it is translation invariant again.
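The argument above is easy to check numerically. The following NumPy snippet (ours, using a single 2-D RoPE frequency) shows that with shared K/V the raw attention output changes when all positions are shifted, while the inverse-RoPE output does not.

```python
import numpy as np

def rot(pos: float) -> np.ndarray:
    # One 2-D RoPE rotation; 0.1 is an arbitrary frequency for illustration.
    th = 0.1 * pos
    c, s = np.cos(th), np.sin(th)
    return np.array([[c, -s], [s, c]])

def shared_kv_attn(q, keys, q_pos, key_pos):
    qr = rot(q_pos) @ q
    ks = [rot(p) @ k for p, k in zip(key_pos, keys)]
    scores = np.array([qr @ k for k in ks])
    w = np.exp(scores - scores.max()); w /= w.sum()
    o = sum(wi * ki for wi, ki in zip(w, ks))  # value := rotated key (shared KV)
    return o, rot(q_pos).T @ o                 # raw output, inverse-RoPE output

q = np.array([1.0, 0.5])
keys = [np.array([0.3, -1.0]), np.array([0.8, 0.2])]
o_a, fixed_a = shared_kv_attn(q, keys, q_pos=5, key_pos=[2, 4])
o_b, fixed_b = shared_kv_attn(q, keys, q_pos=15, key_pos=[12, 14])  # shift +10
assert not np.allclose(o_a, o_b)      # raw output carries absolute position
assert np.allclose(fixed_a, fixed_b)  # inverse RoPE restores invariance
```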

Implementation details: exact position ranges and causality conditions

Care must be taken when processing the compressed KV cache. For each compressed index $i$, we first combine a fixed local group of original tokens, then apply RoPE once using the compressed token's anchor position, and then store that compressed token in the KV cache.

For c4a, the $i$-th compressed token is a weighted sum of tokens in position range $[4i-4,\ 4i+3]$, where $i$ starts from 0 and negative indices are treated as tokens with value 0. The position of the compressed token, when we apply RoPE to it, is $4i+3$.

For c128a, the $i$-th compressed token is a weighted sum of tokens in position range $[128i,\ 128i+127]$, where $i$ starts from 0. The position of the compressed token, when we apply RoPE to it, is $128i+127$.

For causality, we need to ensure that a query token at position $t$ only attends to the information produced by tokens in position range $[0, t]$. This means that for a query at position $t$ and the $i$-th compressed token in the KV cache, we need to ensure that $4i+3 \le t$ (for c4a) or $128i+127 \le t$ (for c128a).
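A compact restatement of these ranges and causality conditions, under our reading of the indexing ($i$ from 0, stride-4 overlapped groups of 8 for c4a, disjoint groups of 128 for c128a):

```python
def c4a_range(i: int):       # positions mixed into the i-th c4a token
    return (4 * i - 4, 4 * i + 3)  # negative positions are zero-padded

def c4a_anchor(i: int) -> int:     # RoPE position of the compressed token
    return 4 * i + 3

def c128a_range(i: int):
    return (128 * i, 128 * i + 127)

def c128a_anchor(i: int) -> int:
    return 128 * i + 127

def causal_ok(anchor: int, t: int) -> bool:
    # A query at position t may only attend to fully-past compressed tokens.
    return anchor <= t
```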

Implementation details: The exact value of k in c4a and c128a

For c4a attention in DeepSeek V4, the default value of $k$ is 512, and for c128a attention, the default value of $k$ is 8192. (For comparison, in DeepSeek V3.2, the default value of $k$ is 2048.)

The c128a attention has a larger compression ratio. With a 1 million-token context, it possesses at most 8k compressed tokens. 8k tokens are not a big deal for attention computation, so we can simply use full attention over the c128a compressed tokens. Implementation-wise, we can still frame the c128a attention as a sparse-attention problem whose top-$k$ value is 8192.

Implementation details: why the short sliding window is needed

With c128a, a query token at position $t < 127$ cannot attend to any compressed token in the KV cache, since the first compressed token contains information from position $0$ to $127$, but the query token cannot attend to information after position $t$ due to causality. With the short sliding window, the query token can attend to uncompressed tokens in position range $[\max(0,\ t-127),\ t]$, so it can still access local information.

Arithmetic behind the estimates for the 8.7x savings

For a sequence with 1M context:

DeepSeek V3.2 with bf16 KV cache:

  • MLA cache per token per layer: 576 × 2 = 1152 bytes.
  • Indexer cache per token per layer: 128 × 2 = 256 bytes.
  • Total cached state per token per layer: 1152 + 256 = 1408 bytes.
  • At 1,048,576 tokens: 1408 MiB ≈ 1.375 GiB per layer.
  • Over 61 layers: about 83.9 GiB.

DeepSeek V4 at 61 layers with bf16 KV cache:

  • Each shared-KV cached entry stores 512 × 2 = 1024 bytes.
  • Each c4a indexer cached entry stores 128 × 2 = 256 bytes.
  • c4a layer: shared-KV cache 262,144 × 1024 bytes = 256 MiB plus indexer cache 262,144 × 256 bytes = 64 MiB, for a total of about 320 MiB.
  • c128a layer: 8,192 × 1024 bytes = 8 MiB.
  • Total across 30 c4a layers and 31 c128a layers: about 9.62 GiB.
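The arithmetic can be written out as a short script. The per-entry byte sizes follow our reconstruction of the estimates in this post (1408 bf16 bytes per token per layer for the V3.2-style stack; 1024-byte shared-KV entries and 256-byte c4a indexer entries for V4):

```python
CTX = 1_048_576
GIB = 1024 ** 3

# DeepSeek V3.2-style stack: 61 layers, bf16, per-token caching.
v32_total = 61 * CTX * (1152 + 256) / GIB          # ~83.9 GiB

# DeepSeek V4: 30 c4a layers + 31 c128a layers, bf16.
c4a_layer = (CTX // 4) * (1024 + 256)              # shared KV + indexer entries
c128a_layer = (CTX // 128) * 1024                  # shared KV only
v4_total = (30 * c4a_layer + 31 * c128a_layer) / GIB  # ~9.62 GiB

savings = v32_total / v4_total                     # ~8.7x
```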