今回の論文
今回取り上げるのは、Tri Dao、Daniel Y. Fu、Stefano Ermon、Atri Rudra、Christopher Ré による 2022 年の論文「FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness」です。公開元は arXiv で、その後 NeurIPS 2022 に採択されています。研究分野は Transformer の高速化、長文処理、GPU メモリ最適化です。URL は https://arxiv.org/abs/2205.14135 です。
この論文を選んだ理由は、いまの LLM 実装でほぼ前提になっている FlashAttention が、単なる CUDA の高速化テクニックではなく、「Attention は FLOPs よりメモリアクセスが支配的である」という見立てから設計されているからです。RAG、エージェント、長文要約、コード生成など、文脈長が伸びるほど効いてくる考え方なので、開発への応用余地も大きいです。
どんな技術か
FlashAttention は、Transformer の self-attention を近似せずに、そのまま正確に計算しながら、高速化と省メモリ化を両立する技術です。ポイントは、Attention 行列を丸ごと GPU の大きいメモリに書き出さず、細かいブロック単位で計算して、その場で必要な集約まで終わらせることにあります。
従来の実装では、QK^T でできる大きなスコア行列や softmax 後の中間行列をいったん GPU の HBM に置くことが多く、ここが速度とメモリ使用量の両方のボトルネックになっていました。FlashAttention は、この中間行列をできるだけ「作らない、書かない、読み直さない」方向で作り直しています。
言い換えると、FlashAttention は新しいモデル構造ではありません。Transformer 自体を置き換えるのではなく、Attention の計算方法をハードウェア寄りに組み替える技術です。そのため、モデル品質を落とさずに長文対応やスループット改善を狙いやすいのが大きな利点です。
課題
FlashAttention が解決しようとしている課題は、Transformer の Attention が長い系列で急激に重くなることです。理論的には Attention は系列長に対して二乗で計算量とメモリ量が増えますが、実際の GPU 実装では、計算そのものよりも大きな中間テンソルの読み書きが重くなりやすいです。
何が難しいのかというと、Attention には QK^T、マスク、softmax、dropout、PV のように複数の処理が連続しており、そのたびに巨大な行列を HBM に置くと転送コストが膨らむことです。GPU は演算器自体は速いのですが、オンチップの SRAM に比べると HBM は遅く、しかも Attention ではこの読み書き回数が非常に多くなります。
既存の方法にも限界がありました。低ランク近似や疎行列化などの approximate attention は、理論上の FLOPs は減っても、実際の壁時計時間があまり縮まらないことがあります。理由は、演算量を減らしてもメモリアクセスやカーネル起動のオーバーヘッドが残るからです。つまり、数式上は軽いのに、GPU では思ったほど速くならないというギャップがあります。
なぜこの課題を解く必要があるかというと、実際の AI システムでは系列長を伸ばしたい場面が多いからです。RAG では大量の検索結果をまとめて読みたいですし、コーディング支援では長いリポジトリ文脈を持ち込みたくなります。音声、動画、マルチモーダルでも、時系列やフレーム列が長くなるほど同じ問題が出ます。Attention を近似なしで軽くできるなら、既存モデルを大きく変えずに実システムを伸ばしやすくなります。
用語解説
- Self-Attention
- 入力トークン同士の関係を計算し、それぞれのトークンがどこを参照すべきかを決める Transformer の中心処理です。FlashAttention はこの仕組み自体を変えるのではなく、同じ結果をもっとハードウェア効率よく計算する技術です。
- HBM と SRAM
- HBM は GPU の大容量だが相対的に遅いメモリ、SRAM はチップ上にある小容量だが非常に速いメモリです。FlashAttention を理解する鍵は、Attention の遅さを「演算不足」ではなく「HBM との往復の多さ」と見る点にあります。
- Softmax の正規化
- Attention スコアを確率的な重みに変換する処理です。通常は行全体を見て最大値や総和を計算しますが、FlashAttention ではこれをブロックごとに分割しても最終結果が一致するよう工夫しています。
- Tiling
- 大きな行列を小さなブロックに分けて、速いメモリに載る範囲で順に処理する手法です。FlashAttention の基本アイデアは、Q、K、V と Attention 計算をこのブロック処理に載せ替えることです。
- Recomputation(再計算)
- 中間結果を大量に保存する代わりに、必要になったときに計算し直す考え方です。FlashAttention では backward 時に巨大な Attention 行列を保存せず、正規化に必要な統計量だけ残して再計算することで、メモリ削減と速度の両立を狙っています。
技術の仕組み
FlashAttention の基本発想は、Attention を数式として近似するのではなく、メモリ階層に合わせて実装を組み替えることです。論文ではこれを IO-aware と呼んでいます。重要なのは、どれだけ少ない HBM 読み書きで exact attention を終えられるかです。
基本アイデア
通常の Attention では、まず S = QK^T を計算し、その後に softmax をかけ、最後に P V を計算します。このとき S や P は系列長 N に対して N × N の大きな行列になるため、長文ではここが非常に重くなります。
FlashAttention は、この巨大な S と P を HBM に materialize しません。代わりに、Q のブロックと K,V のブロックを順に SRAM に読み込み、その場で部分的なスコア計算、softmax の更新、出力の加算まで進めます。最終的な出力だけを HBM に書き戻すので、中間行列の保存コストを大きく減らせます。
モデル構造というより計算カーネルの変更
この技術で変わるのは Transformer ブロックの外見ではなく、中の Attention カーネルです。モデルのパラメータ数や理論上の注意機構はそのままで、Q, K, V の投影を作った後の実行方法だけを差し替えます。
そのため、実務では「FlashAttention 対応実装を使う」といった形で導入されます。アーキテクチャ刷新より導入障壁が低い一方で、GPU 世代、head dimension、mask 形状、フレームワーク統合など、カーネル側の実装制約は意識する必要があります。
Softmax をブロック単位で正しく計算する
Attention の厄介な点は、softmax が行全体を見て正規化することです。単純にブロックごとに softmax して足し合わせると、元の結果と一致しません。
FlashAttention はここで、各行について「これまで見たブロックの最大値」と「その最大値基準での指数和」を持ち回る形を使います。新しい K ブロックを読むたびに局所的なスコアを計算し、最大値と正規化係数を更新しながら、出力ベクトルも再スケールして加算していきます。これにより、全ブロックを見終わった時点で、通常の softmax Attention と同じ出力が得られます。
この工夫が重要なのは、FlashAttention が exact attention である理由がここにあるからです。近似ではなく、計算順序と集約方法を変えているだけなので、モデル品質を落とさずに高速化できます。
Tiling によるメモリアクセス削減
論文の中心は、計算量より HBM アクセス量を削ることです。Q のあるブロックを固定し、K と V のブロックを走査しながら必要な計算を終えることで、同じデータを何度も HBM から読み直す回数を抑えます。
GPU では SRAM が非常に速い一方で容量が小さいため、一度に全部は載せられません。FlashAttention は、SRAM に収まる大きさのタイルを選び、その範囲でスコア計算、softmax、値の重み付き和をまとめて処理します。論文では、この設計が標準的な Attention より少ない HBM access で済み、ある SRAM サイズ範囲では IO 的に最適であることまで示しています。
学習時の backward では再計算を使う
学習時は forward だけでなく backward も重いです。従来実装では backward 用に Attention 行列を保存しておくことが多く、これがさらにメモリを圧迫します。
FlashAttention では、forward 時に必要最小限の統計量だけ保持し、backward 時にブロックごとに Attention を再計算します。一見すると再計算は無駄に見えますが、HBM から巨大行列を読み戻すより速い場合が多く、結果としてメモリも速度も改善します。これは「計算を少し増やしても、IO を大きく減らせば全体は速くなる」という、この論文らしい発想です。
Block-sparse への拡張
論文では exact な FlashAttention だけでなく、block-sparse FlashAttention も示しています。これは、もともと見なくてよいブロックを計算対象から外しつつ、残るブロック群については FlashAttention の IO-aware 実装を使う形です。
つまり FlashAttention は、それ自体が一つの高速化手法であると同時に、他の sparse attention 系手法を実用速度に落とし込む基盤にもなります。近似 attention の理論をそのまま使っても速くならないことがありますが、FlashAttention 系の実装に載せると初めて実アプリで効きやすくなる、という見方ができます。
実験と結果
この論文の実験は、単体カーネルのベンチマークだけでなく、実際の学習時間、長文対応、精度改善まで見ているのが特徴です。単に「GPU 上で何ミリ秒速いか」ではなく、モデル学習や長文ベンチマークで何が起きるかを確認しています。
何を検証したのか
主な検証は 3 つあります。1 つ目は、FlashAttention が標準 Attention より実時間で速いかです。2 つ目は、長い系列を扱えることでモデル品質が上がるかです。3 つ目は、既存の approximate attention より現実的に使いやすいかです。
使ったデータセットや評価設定
学習速度の検証では、BERT-large と GPT-2 を使っています。長文タスクでは Long Range Arena を用い、ListOps、Text、Retrieval、Image、Pathfinder など複数の長系列ベンチマークで比較しています。さらに Path-X と Path-256 では、16K や 64K といった極端に長い系列長で Transformer がどこまで動けるかも検証しています。
評価指標は、学習時間、スループット、accuracy、perplexity などです。つまり FlashAttention は「速いけれど精度は不明」ではなく、速度・メモリ・品質を合わせて見た論文です。
学習速度の結果
論文では、BERT-large の学習で MLPerf 1.1 の Nvidia 実装に対して 15% の end-to-end wall-clock speedup を報告しています。具体的には、目標 accuracy 到達までの時間が 20.0 分から 17.4 分になっています。
GPT-2 では系列長 1K の学習で 3 倍の速度向上が報告されています。さらに、Attention 計算単体では PyTorch 実装比で最大 7.6 倍の高速化が示されています。ここで重要なのは、FlashAttention が理論値だけでなく end-to-end の学習時間にも効いていることです。
長文での精度と適用範囲
長い系列を扱えることで、単なる速度改善以上の効果も出ています。論文では GPT-2 の perplexity が 0.7 改善し、長文分類では 6.4 ポイントの改善が報告されています。これは「同じモデルを速くする」だけでなく、「今までメモリ制約で使えなかった長い文脈を読めるようになる」ことで品質も上がることを示しています。
Long Range Arena では、標準 Transformer の平均 accuracy 59.3 に対して、FlashAttention は 59.8 で、しかも 2.4 倍の speedup を出しています。Block-sparse FlashAttention は平均 59.6 で 2.8 倍です。つまり、高速化のために性能を大きく犠牲にしていません。
極端な長文ベンチマークでの結果
Path-X は系列長 16K、Path-256 は 64K まで伸びる難しいベンチマークです。従来の Transformer 系モデルはメモリ不足かランダム同然の性能に留まっていましたが、FlashAttention は Path-X で 61.4% accuracy を達成しました。さらに block-sparse FlashAttention は Path-256 で 63.1% accuracy を出しています。
この結果から言えるのは、FlashAttention の価値は「いまあるワークロードを少し速くする」だけではないことです。長さの制約で試せなかった問題設定自体を Transformer で扱えるようにする点が大きいです。
何に使える?
FlashAttention は、長いコンテキストを扱うか、学習・推論コストが重い Transformer 系システムで広く使えます。特に、モデル本体を変えずに性能改善したいケースと相性がよいです。
長文RAGとドキュメントQA
RAG では、検索結果の chunk 数を増やすと最終的な LLM 入力が長くなります。FlashAttention があると、同じ GPU でもより長い文脈を扱いやすくなり、文書全体の整合性を見た回答や、複数資料横断の要約を実装しやすくなります。
特に、検索結果を厳しく絞らずに少し多めに渡したい場面では有効です。retriever の精度だけに頼らず、generator 側に広めの文脈を読ませる設計が取りやすくなります。
コーディング支援やリポジトリ理解
コード補完、バグ調査、設計レビューのような用途では、複数ファイルや長い履歴を一度に入れたくなります。FlashAttention はそのまま「長いコード文脈を扱える LLM サービング基盤」に効くので、リポジトリ規模が大きい開発支援ほど恩恵を受けやすいです。
また、同じモデルサイズでも batch size や context を取りやすくなるため、社内向けの推論基盤でスループット改善にもつながります。
音声・動画・マルチモーダル系列処理
FlashAttention の発想はテキスト専用ではありません。フレーム列、音声トークン列、視覚トークン列など、系列が長くなりやすいモダリティでも効きます。特に動画理解や長時間会議の音声要約では、系列長の制約がボトルネックになりやすいため、exact attention を維持したまま伸ばせる価値があります。
学習基盤のコスト削減
BERT や GPT 系の事前学習、継続学習、ドメイン適応では、Attention のコストが積み上がります。FlashAttention を取り込むだけで学習時間が短くなれば、GPU コストの削減、試行回数の増加、より長いコンテキスト設定の検証がしやすくなります。
モデル構造を変える研究より、まず既存パイプラインを速くしたい現場にはかなり使いやすい技術です。
開発や事業へのヒント
FlashAttention から得られる示唆は、「AI の高速化はアルゴリズム近似だけではなく、メモリ移動の設計で大きく変わる」ということです。これは LLM アプリを作る側にもそのまま効く考え方です。
コンテキスト長を機能として売れる
プロダクトではモデル名だけで差別化しにくくなっていますが、「どれだけ長い文脈を安定して扱えるか」は依然として体験差になります。FlashAttention のような基盤最適化を入れると、長文PDF要約、議事録横断検索、コードベース理解などを、より安い構成で提供できる可能性があります。
近似より先に実装最適化を疑う
長文処理で遅いとき、すぐに sparse attention や別アーキテクチャへ飛ぶのではなく、まず標準 Attention の実装がメモリ効率よく動いているかを見るべきだ、という教訓があります。小規模プロダクトでも、モデルを作り替える前にランタイムやカーネル最適化で大きく改善することがあります。
RAG やエージェントの設計自由度が広がる
FlashAttention が効くと、入力文脈を少し厚めに渡す設計がしやすくなります。これは、RAG で多数の候補文書を比較させる、エージェントに長い行動履歴を持たせる、レビュー用モデルに差分全体を読ませる、といった設計の自由度につながります。
特に業務自動化や社内ツールでは、アルゴリズムの新規性より「既存モデルでどこまで実務文脈を持てるか」が重要なので、こうした基盤改善は価値に直結しやすいです。
今後注目すべき方向性
この論文以降、FlashAttention-2、FlashAttention-3、各種デコード最適化など、Attention をハードウェア寄りに最適化する流れが強くなっています。今後も、モデルの理論改善と同じくらい、実装カーネルやメモリ階層を意識した最適化が重要になるはずです。
特に、長文 LLM、マルチモーダル、推論サーバー多並列化では、FLOPs より IO が支配的な場面が増えるため、この論文の考え方は今後も長く使えます。
限界
FlashAttention にも注意点があります。まず、Attention の理論的な二乗スケーリング自体を消すわけではありません。メモリ使用量は線形に近づけられても、計算対象の組み合わせそのものが増える問題までは解決しません。極端に長い系列では、依然として近似や疎化が必要になる場合があります。
また、導入効果はハードウェアと実装条件に依存します。GPU 世代、ヘッド次元、mask 種類、フレームワーク、バッチサイズによって最適条件が変わるため、論文の倍率がそのまま再現されるとは限りません。
実装難度もあります。FlashAttention は考え方としては明快ですが、実際の高速化は低レベルカーネル実装に支えられています。そのため、自前カーネルを書かずに使う場合は、ライブラリや推論基盤の対応状況に制約されます。カスタムマスクや特殊 attention にすぐ適用できないこともあります。
さらに、この論文の主要評価は学習と long-sequence ベンチマーク中心です。実運用のオンライン推論、特に小バッチ・低レイテンシ要求のデコード段階では、後続研究の FlashAttention-2 やデコード専用最適化の知見も合わせて見る必要があります。したがって、実システムへの適用では「FlashAttention を入れれば常に最適」とは言い切れません。
よくある質問
Q. FlashAttention は新しいモデルですか?
A. いいえ、モデル構造そのものではなく、Attention の計算実装です。Transformer を別物に置き換えるのではなく、同じ Attention をもっと少ないメモリ移動で実行する技術です。
Q. 近似 attention と何が違うのですか?
A. 近似 attention は計算対象を減らしたり、低ランク近似を使ったりして結果自体を近似します。FlashAttention は結果を近似せず、計算順序とメモリ配置を変えることで exact attention のまま速くします。
Q. RAG やエージェント開発でも本当に意味がありますか?
A. あります。特に、複数の検索結果や長い会話履歴、行動ログをまとめてモデルに読ませたい場合に効きます。文脈長の上限や GPU メモリ制約が緩むと、アプリ設計の自由度が上がります。
Q. すでに長文対応モデルを使っているなら不要ですか?
A. 不要とは限りません。長文対応モデルでも内部では Attention 計算が重く、実装次第でスループットや最大入力長が大きく変わります。モデル能力とは別に、実行基盤の最適化として価値があります。
Q. 最初に実務で試すならどこですか?
A. 長文入力を扱う推論基盤か、Transformer 学習パイプラインです。たとえば RAG の generator、コード支援の長コンテキスト推論、継続事前学習や長文 fine-tuning などは効果を測りやすいです。
今日の学び
この論文は、Transformer の Attention が長文で遅く重くなる課題を扱いました。これに対して、Attention を近似するのではなく、HBM と SRAM のメモリアクセスを意識した IO-aware なブロック計算で exact attention を高速化しました。
ここから得られるヒントは、AI システムの性能はモデル構造だけで決まらないということです。長文 RAG、エージェント、コード支援のように文脈長が価値になるプロダクトでは、アルゴリズムそのものだけでなく、メモリ移動を減らす実装設計が競争力になりえます。