ホームLLMdeepseek-ai/FlashMLA
deepseek-ai

FlashMLA

AI#DeepSeek#Attention#CUDA#PyTorch#LLM
GitHub で見る →
12,617

// 概要

FlashMLA は DeepSeek-V3 および DeepSeek-V3.2 モデルを駆動するために特別に設計された高性能な attention kernel ライブラリです。Prefill および decoding ステージにおける sparse attention と dense attention の両方に対して最適化された実装を提供します。本ライブラリは FP8 KV cache のような高度な機能をサポートしており、SM90 や SM100 を含む様々な GPU アーキテクチャと互換性があります。

// 技術解説

FlashMLA は、 DeepSeek の Multi-Head Latent Attention (MLA) モデルを駆動するために設計された、高性能な attention カーネルの専門ライブラリです。 dense および sparse な attention の両方に対して高度に最適化された実装を提供することで、このプロジェクトは大規模な transformer 推論、特に prefill および decoding 段階で発生する計算上のボトルネックに対処します。その設計は NVIDIA アーキテクチャにおけるハードウェアレベルの効率を優先しており、 FP8 KV cache 量子化のような技術を活用して、モデルの精度を維持しつつスループットを最大化します。

// 主要ハイライト

01
prefill および decoding の両方で DeepSeek Sparse Attention (DSA) を可能にする、専門的な sparse attention カーネルを提供します。
02
FP8 KV cache 量子化をサポートし、 bfloat16 の計算精度を維持しながら decoding 中のメモリオーバーヘッドを大幅に削減します。
03
NVIDIA H800 GPU 上で高いパフォーマンスを実現し、計算負荷の高いワークロードにおいて最大 660 TFlops に到達します。
04
標準的な Multi-Head Attention (MHA) および Multi-Query Attention (MQA) モード向けの最適化された dense attention カーネルを含みます。
05
カスタム sparse インデックスを統合するための柔軟な API を提供し、開発者がトークンレベルで attention 計算を制御できるようにします。
06
SM90 や SM100 を含む最新の GPU アーキテクチャとの互換性を維持し、様々なハードウェア固有の最適化をサポートします。

// ユースケース

01
Prefill および decoding ステージ向けの token-level sparse attention
02
高性能な prefill および decoding を実現する dense attention kernel
03
メモリと計算効率を最適化する FP8 KV cache のサポート

// クイックスタート

FlashMLA の利用を開始するには、リポジトリを clone し、サブモジュールを初期化してから 'pip install -v .' を使用してパッケージをインストールします。インストールが完了したら、 'get_mla_metadata' を使用してタイルスケジューラのメタデータを準備し、続いて decoding ループ内で 'flash_mla_with_kvcache' を呼び出すことで、カーネルを推論パイプラインに統合できます。具体的な実装例については、 'tests/' ディレクトリ内の提供されているテストスクリプトを参照してください。