deepseek-ai

FlashMLA

#AI #DeepSeek #Attention #CUDA #PyTorch #LLM
12,617

// summary

FlashMLA is a library of high-performance attention kernels designed to power the DeepSeek-V3 and DeepSeek-V3.2 models. It provides optimized sparse and dense attention implementations for both the prefill and decoding stages. The library supports features such as an FP8 KV cache and runs on NVIDIA GPU architectures including SM90 (Hopper) and SM100 (Blackwell).

// technical analysis

FlashMLA is a specialized library of high-performance attention kernels for DeepSeek models that use Multi-Head Latent Attention (MLA). It provides optimized implementations of both dense and sparse attention, targeting the performance bottlenecks of large-scale transformer inference in the prefill and decoding stages. Its design prioritizes hardware-level efficiency on NVIDIA GPUs and uses techniques such as FP8 KV cache quantization to raise throughput while preserving model accuracy.

// key highlights

01
Provides specialized sparse attention kernels that enable DeepSeek Sparse Attention (DSA) for both prefill and decoding.
02
Supports FP8 KV cache quantization, which shrinks the KV cache memory footprint during decoding while the attention computation itself is still carried out in bfloat16.
03
Achieves high performance on NVIDIA H800 GPUs, reaching up to 660 TFLOPS in compute-bound workloads.
04
Includes optimized dense attention kernels for standard Multi-Head Attention (MHA) and Multi-Query Attention (MQA) modes.
05
Offers a flexible API for integrating custom sparse indices, allowing developers to control attention computation at the token level.
06
Supports modern NVIDIA GPU architectures, including SM90 (Hopper) and SM100 (Blackwell), with architecture-specific optimizations.
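Highlights 01, 04, and 05 can be illustrated with a plain-Python reference sketch of token-level sparse attention in MQA style. Several query heads share one KV head, and each head attends only to the token positions listed in an index list. The function 'sparse_mqa_attention' and its arguments are hypothetical and are not FlashMLA's API. The real kernels operate on GPU tensors, and how the indices are chosen is outside the scope of this sketch.

```python
import math
from typing import List, Sequence

Vector = List[float]


def sparse_mqa_attention(
    q_heads: Sequence[Vector],
    keys: Sequence[Vector],
    values: Sequence[Vector],
    indices: Sequence[int],
    scale: float,
) -> List[Vector]:
    """Reference token-level sparse attention (hypothetical helper, not FlashMLA's API).

    q_heads: one query vector per head; all heads share a single KV head (MQA-style).
    keys, values: per-token vectors of that shared KV head.
    indices: token positions this query may attend to; every other token is skipped.
    """
    if not indices:
        raise ValueError("indices must select at least one token")
    dim_v = len(values[0])
    outputs = []
    for q in q_heads:
        # Scores are computed only for the selected tokens, not the whole sequence.
        scores = [scale * sum(a * b for a, b in zip(q, keys[t])) for t in indices]
        m = max(scores)  # subtract the max score for a numerically stable softmax
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out = [0.0] * dim_v
        for w, t in zip(weights, indices):
            for d in range(dim_v):
                out[d] += (w / total) * values[t][d]
        outputs.append(out)
    return outputs


if __name__ == "__main__":
    keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
    values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
    q_heads = [[1.0, 0.0], [0.0, 1.0]]  # two query heads sharing one KV head
    print(sparse_mqa_attention(q_heads, keys, values, indices=[0, 2], scale=0.5))
```

If 'indices' lists every position, the function reduces to dense MQA attention, so the same reference can check a dense kernel. The sparse case simply skips work for unselected tokens.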

// use cases

01
Token-level sparse attention for prefill and decoding stages
02
Dense attention kernels for high-performance prefill and decoding
03
FP8 KV cache support for optimized memory and compute efficiency
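The following sketch makes the FP8 KV cache idea concrete with a simulated quantize/dequantize round trip in plain Python. It assumes the FP8 E4M3 format (3 mantissa bits, largest finite value 448) and one float scale per tile of values. The tile granularity and the helper names are assumptions for illustration and do not describe FlashMLA's documented cache layout. Python floats stand in for the bfloat16 values the kernels compute with.

```python
import math
from typing import List, Sequence, Tuple

# Assumption: FP8 E4M3 with 3 mantissa bits; the largest finite magnitude is 448.
E4M3_MAX = 448.0


def round_to_e4m3(x: float) -> float:
    """Round a finite float to the nearest E4M3 value (ties to even), saturating at +/-448."""
    if x == 0.0:
        return 0.0
    ax = min(abs(x), E4M3_MAX)
    _, e = math.frexp(ax)        # ax == m * 2**e with 0.5 <= m < 1
    e = max(e, -5)               # below 2**-6 values are subnormal, with a fixed spacing of 2**-9
    step = 2.0 ** (e - 4)        # 1 implicit + 3 explicit mantissa bits -> 8 steps per binade
    q = round(ax / step) * step  # Python's round() breaks ties to even
    return math.copysign(min(q, E4M3_MAX), x)


def quantize_tile(tile: Sequence[float]) -> Tuple[List[float], float]:
    """Scale a tile so its largest magnitude maps to E4M3_MAX, then round each value.

    Hypothetical helper: one float scale per tile is an assumed granularity.
    """
    amax = max(abs(v) for v in tile)
    scale = amax / E4M3_MAX if amax > 0 else 1.0
    codes = [round_to_e4m3(v / scale) for v in tile]
    return codes, scale


def dequantize_tile(codes: Sequence[float], scale: float) -> List[float]:
    """Undo the scaling; plain floats stand in for the bfloat16 values used in computation."""
    return [c * scale for c in codes]


if __name__ == "__main__":
    tile = [0.5, -1.25, 3.0, 2.2, -0.7, 0.05]
    codes, scale = quantize_tile(tile)
    restored = dequantize_tile(codes, scale)
    for original, approx in zip(tile, restored):
        print(f"{original:+.4f} -> {approx:+.4f}")
```

Under these assumptions, a tile of 128 values takes 128 bytes of FP8 codes plus 4 bytes for a float32 scale, versus 256 bytes in bfloat16. The cost is a relative rounding error of at most 2^-4 for values that land in E4M3's normal range.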

// getting started

To begin using FlashMLA, clone the repository, initialize its submodules, and install the package with 'pip install -v .'. Once it is installed, integrate the kernels into your inference pipeline by calling 'get_mla_metadata' to prepare the tile scheduler metadata, then calling 'flash_mla_with_kvcache' inside your decoding loop. The test scripts in the 'tests/' directory provide concrete usage examples.
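The getting-started steps name 'get_mla_metadata' and 'flash_mla_with_kvcache' but do not give their signatures, so the sketch below does not call them. It instead illustrates split-KV decoding in plain Python. One query's KV sequence is cut into chunks that could be processed in parallel, and the partial results are merged exactly using their log-sum-exp values. Treating this as the kind of work the tile scheduler metadata helps distribute is an assumption. All function names here, including the 'num_splits' parameter of this sketch, are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]


def partial_attention(q: Vector, keys: Sequence[Vector], values: Sequence[Vector],
                      start: int, end: int, scale: float) -> Tuple[Vector, float]:
    """Softmax attention over tokens [start, end); returns (output, log-sum-exp of scores)."""
    scores = [scale * sum(a * b for a, b in zip(q, keys[t])) for t in range(start, end)]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = [sum(w * values[t][d] for w, t in zip(weights, range(start, end))) / total
           for d in range(len(values[0]))]
    return out, m + math.log(total)


def combine_splits(partials: Sequence[Tuple[Vector, float]]) -> Tuple[Vector, float]:
    """Merge per-split outputs exactly by weighting each with exp(lse_i - lse_total)."""
    m = max(lse for _, lse in partials)
    lse_total = m + math.log(sum(math.exp(lse - m) for _, lse in partials))
    out = [0.0] * len(partials[0][0])
    for part_out, lse in partials:
        w = math.exp(lse - lse_total)
        for d, v in enumerate(part_out):
            out[d] += w * v
    return out, lse_total


def split_kv_attention(q: Vector, keys: Sequence[Vector], values: Sequence[Vector],
                       num_splits: int, scale: float) -> Tuple[Vector, float]:
    """Split the KV sequence into num_splits contiguous chunks (hypothetical scheduler)."""
    n = len(keys)
    bounds = [n * i // num_splits for i in range(num_splits + 1)]
    partials = [partial_attention(q, keys, values, bounds[i], bounds[i + 1], scale)
                for i in range(num_splits) if bounds[i] < bounds[i + 1]]  # skip empty chunks
    return combine_splits(partials)


if __name__ == "__main__":
    q = [0.3, -0.2, 0.5]
    keys = [[0.1 * i, -0.05 * i, 0.2] for i in range(7)]
    values = [[float(i), float(i * i)] for i in range(7)]
    whole, _ = split_kv_attention(q, keys, values, num_splits=1, scale=0.5)
    split, _ = split_kv_attention(q, keys, values, num_splits=3, scale=0.5)
    print(whole)
    print(split)  # agrees with `whole` up to floating-point rounding
```

Each chunk carries its own log-sum-exp, so chunks can be computed independently and merged without approximation. On a GPU, independent chunks could run on different streaming multiprocessors. Balancing chunk sizes across sequences of different lengths is the kind of scheduling decision that metadata like this would encode.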