Medusaとは?LLM推論を高速化する複数デコードヘッドの仕組みと使い道

Medusaは、LLMに複数の軽量デコードヘッドを追加して複数トークンを並列予測し、推論を高速化する手法です。別のドラフトモデルを使わずに速くする考え方、仕組み、評価結果、実装へのヒントを技術的に解説します。

参考文献

Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads

Tianle Cai, Yuhong Li, Zhengyang Geng

論文を見る

今回の論文

今回取り上げるのは、Princeton University などの Tianle Cai、Yuhong Li、Zhengyang Geng、Hongwu Peng、Jason D. Lee、Deming Chen、Tri Dao による 2024 年の論文「Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads」です。公開元は arXiv で、研究分野は LLM 推論高速化です。URL は https://arxiv.org/abs/2401.10774 です。

この論文を選んだ理由は、今の LLM 開発でかなり現実的な課題である「品質をできるだけ落とさず、推論を速くしたい」に対して、実装しやすい構成を出しているからです。特に、別の小型ドラフトモデルを用意せず、既存モデルの上に軽量ヘッドを足して高速化する発想は、推論基盤、社内ツール、SaaS の応答改善などに直結しやすいです。

どんな技術か

Medusa は、LLM の最後の隠れ状態の上に複数の追加ヘッドを載せて、「次の1トークンだけでなく、その先の数トークンも同時に予測する」技術です。通常の自己回帰デコードでは、1 トークン生成するたびに巨大なモデルを 1 回まるごと通す必要がありますが、Medusa は 1 回の前向き計算で複数の候補続きをまとめて出し、それを一括で検証します。

ポイントは、よく知られた speculative decoding に近い狙いを持ちながら、ドラフト用の別モデルを持たないことです。つまり、「小さいモデルで下書きして大きいモデルで確認する」のではなく、「大きいモデル自身に軽い予測ヘッドを増設して、下書き候補を内部で作る」アプローチです。既存の LLM サービングに組み込みやすいのが大きな特徴です。

課題

この技術が解決しようとしている課題は、LLM の推論が本質的に逐次処理になりやすく、応答速度が頭打ちになりやすいことです。次のトークンが確定しないと、その次のトークンを正しく計算できないため、大きなモデルほど待ち時間が積み上がります。

難しいのは、LLM 推論の遅さが単に FLOPs の多さだけではなく、重いモデル重みを何度もメモリから読み出す構造にもある点です。論文では、この自己回帰デコードがアクセラレータのメモリ帯域に縛られやすいと説明しています。巨大モデルを毎ステップ呼び出して 1 トークンずつ出す方式では、ハードウェアの計算資源を十分に使い切れません。

既存の speculative decoding は有力な解決策ですが、そこには別の難しさがあります。小型ドラフトモデルを別途用意し、元モデルとの整合を取り、サービング環境にも 2 つのモデルを載せる必要があるからです。研究段階では成立しても、実運用ではモデル管理、デプロイ、量子化、分散推論、更新運用が一気に複雑になります。

なぜこの課題を解く必要があるかは明快です。チャット、コード補完、エージェント、RAG、社内検索支援など、現在の AI システムはほぼすべて推論レイテンシの影響を受けます。品質が高くても遅ければ UX が崩れ、GPU コストも上がります。Medusa は、この「品質は維持したいが、推論はもっと速くしたい」という現場の要請にかなり正面から答えた手法です。

用語解説

自己回帰デコード
1トークンずつ順番に生成する方式です。LLMの標準的な推論方法ですが、前の出力に依存するため並列化しにくく、Medusaが解こうとしているボトルネックそのものです。
Speculative Decoding
小さなドラフトモデルが先のトークン列を提案し、大きな元モデルがまとめて検証する高速化手法です。Medusaはこの考え方を踏まえつつ、ドラフトモデルを別に持たずに近い効果を狙います。
デコードヘッド
隠れ状態から次トークンの分布を出す出力層です。Medusaでは通常の次トークン用ヘッドに加えて、2個先、3個先といった位置向けの追加ヘッドを用意し、複数位置を同時に予測します。
Tree Attention
複数の候補続きを木構造としてまとめ、一度の検証で処理する仕組みです。単に候補をたくさん出すだけでは遅くなるため、共通プレフィックスを共有しながらまとめて評価する工夫が重要になります。
Typical Acceptance
元モデルの分布から見て極端に不自然でない候補を受け入れるための受理規則です。厳密な分布一致を守る rejection sampling より実務寄りで、速度と品質のバランスを取りやすいのが論文上の狙いです。

技術の仕組み

Medusa の基本アイデアは単純です。通常の LLM は「次の 1 トークン」を予測しますが、Medusa はその上に複数の補助ヘッドを載せて「次の 1 個先」「2 個先」「3 個先」も予測させます。これにより、1 回の forward で先読み候補をまとめて作れるようにします。

追加ヘッドで複数位置を予測する

通常の自己回帰生成では、トークンを 1 つ出すたびに次の隠れ状態を計算し直します。Medusa では、最後の隠れ状態から複数の軽量ヘッドを通し、未来の各位置に対する候補トークンを並列に出します。ここで重要なのは、元の巨大な Transformer 本体を何回も回す代わりに、安い補助計算で候補を増やしている点です。

ただし、各ヘッドが出すのはあくまで候補です。2 個先や 3 個先の予測は不確実性が高いので、そのまま確定はできません。そこで次の段階として、出した候補列を元モデルでまとめて検証します。

候補列を木構造にして一括検証する

もし各ヘッドの上位候補を素直に全組み合わせすると、候補列はすぐ爆発します。Medusa はここで tree attention を使い、候補列を木構造としてまとめます。共通するプレフィックスを共有したまま、複数候補を 1 回の検証パスで扱えるようにする設計です。

考え方としては、1 個先の候補 A から派生する 2 個先候補群、さらにその先の候補群を木にし、その木全体を 1 度に評価します。これにより、単に「先読み候補を増やしただけ」で終わらず、検証コストまで抑えられます。論文の本質は、追加ヘッドそのものより、この候補生成と検証を 1 ステップに束ねる流れにあります。

受理された最長プレフィックスを次ステップに使う

検証後は、候補列のうちどこまで受理するかを決めます。厳密には speculative decoding と同様の rejection sampling も使えますが、論文では typical acceptance というより実務向きの方式も提案しています。これは、元モデルの確率分布から見て十分もっともらしいトークン列なら採用し、最長の受理プレフィックスを次の生成状態に反映するものです。

ここで効くのは、1 ステップで 1 トークンではなく、複数トークンをまとめて確定できる可能性があることです。受理長が伸びるほど、重い本体モデルを呼び出す回数が減り、最終的な tokens/sec が上がります。

Medusa-1 と Medusa-2 の違い

論文では学習方法を 2 段階に分けています。Medusa-1 は、元の LLM 本体を凍結したまま、追加ヘッドだけを学習する方式です。元モデルの振る舞いを壊しにくく、既存モデルへ後付けしやすいのが利点です。

一方の Medusa-2 は、追加ヘッドだけでなく本体側も一緒に学習します。こちらはヘッドの予測精度が上がりやすく、より高い高速化が見込めますが、元モデルの能力を壊さないための学習レシピが必要になります。論文では二段階学習やウォームアップを使って、この破壊を避ける工夫を入れています。

学習データがない場合の self-distillation

現実には、元モデルの学習データをそのまま使えないケースも多いです。特に RLHF 後のモデルや公開されていない独自学習データを持つモデルでは、この問題が出ます。Medusa はその対策として、元モデル自身に推論させて教師信号を作る self-distillation も提案しています。

これは実務上かなり重要です。つまり、既存の運用モデルに対しても、完全な再学習より軽い形で Medusa 化できる余地があるということです。

実験と結果

論文では、Medusa が本当に速くなるのか、品質を壊さないのか、学習条件が違っても機能するのかを中心に検証しています。対象には Vicuna 7B、Vicuna 13B、Vicuna 33B、Zephyr 7B などが含まれており、単一モデルだけの特殊な結果ではないことを示そうとしています。

何を検証したのか

主な比較軸は 3 つです。1 つ目は通常の自己回帰デコードに対してどれだけ速度が上がるか、2 つ目は Medusa-1 と Medusa-2 のどちらがどの条件で有利か、3 つ目は self-distillation や typical acceptance のような拡張がどれだけ実用的かです。

評価は、単なる理論的なトークン削減ではなく、実際の生成速度と生成品質の両面で行われています。論文中では MT-Bench のカテゴリ別評価や、モデル品質比較、ツリー構造のアブレーションなども実施されています。

どんな結果が出たのか

論文の中心的な結果はかなり明快です。Medusa-1 は、品質を落とさずに 2.2 倍超の高速化を達成し、Medusa-2 はさらに 2.3 倍から 3.6 倍の高速化を示しました。つまり、後付けしやすさを重視するなら Medusa-1、本気で速度を取りにいくなら Medusa-2 という整理ができます。

論文では、モデルサイズやタスクカテゴリによって伸び方が変わることも示されています。構造化された出力や先読みしやすい生成では、受理できるトークン列が長くなりやすく、高い高速化が出やすいと読めます。逆に、自由度が高く次トークン分布が散らばるタスクでは、先の位置ほど予測が難しくなるため、受理長が伸びにくい場面もあります。

また、typical acceptance は、高温度サンプリング時に rejection sampling で効率が落ちやすい問題への対策として機能します。論文のアブレーションでも、速度と品質のバランスを取る閾値設計が重要だと示されています。これは、Medusa が単なる理論実験ではなく、実際の生成設定を意識していることを意味します。

結果から何が言えるのか

この結果から言えるのは、LLM 推論高速化では「別モデルを増やす」以外の道が十分あるということです。Medusa は、重い本体モデルの呼び出し回数を減らす方向で効いており、しかもそのために新しいサービング構成を大きく変えなくてよいのが強いです。

同時に、速度改善は受理率に強く依存します。つまり Medusa は万能な一定倍率ブーストではなく、「追加ヘッドがどれだけ先を正しく当てられるか」「その候補をどれだけまとめて受理できるか」が性能を決める手法です。ここが、後続研究で Hydra などがさらに改良しようとしているポイントでもあります。

何に使える?

Medusa がまず役立ちそうなのは、応答品質よりも待ち時間や GPU 単価が大きな課題になる生成系アプリです。たとえば社内チャット、問い合わせ支援、コード生成、文章補助、エージェントの中間推論などでは、同じモデル品質のまま速度を底上げできる価値が大きいです。

チャットや社内ツールの体感速度改善

社内アシスタントやカスタマーサポート支援では、返答の最初の数秒が UX を大きく左右します。Medusa は 1 リクエストあたりのデコードステップ数を減らすので、特に 1 ユーザーずつ対話する低バッチ環境で効きやすいです。論文でも batch size 1 に近いシナリオを重視しており、ローカルホストや小規模サービングとの相性がよいと考えられます。

コード生成や定型出力

コード補完、SQL 生成、JSON 生成、フォーム自動入力のように、出力パターンが比較的予測しやすいタスクでは、先読み候補の受理率が上がりやすい可能性があります。論文でもタスクカテゴリごとに速度差が出ており、構造化出力の方が恩恵を受けやすいと読むのが自然です。

エージェントの内部推論コスト削減

エージェントは、最終回答だけでなく、ツール選択、計画、観察の要約など、短い生成を何度も繰り返します。このタイプのワークフローでは、1 回ごとの生成が短くても回数が多いため、デコードステップ削減の累積効果が出やすいです。特に、同じモデルを大量の補助タスクに使っているシステムでは、Medusa のような推論高速化が全体コストに効いてきます。

オンプレミスや小規模 GPU 運用

大規模クラウドだけでなく、社内 GPU サーバーやエッジ寄りの構成でも有効です。別ドラフトモデルを同時に積むより、単一モデルに追加ヘッドを持たせる方が運用設計を単純にしやすいからです。量子化や既存サービングへの組み込みも、理論上は 2 モデル運用より整理しやすいです。

開発や事業へのヒント

この論文から得られる一番大きなヒントは、LLM の差別化はモデル選定だけでなく、推論レイヤーの設計でも作れるということです。もし自分で AI アプリを作るなら、まず精度改善だけでなく、生成速度改善そのものをプロダクト価値として捉えるべきだとわかります。

既存サービスでも後付けしやすい改善余地がある

Medusa-1 の発想は、既存モデルを大きく壊さず、追加学習だけで高速化層を足すというものです。これは、すでに動いている社内 LLM サービスでも採用余地があります。特に「モデルは十分良いが、待ち時間が長い」という状態なら、モデルを入れ替えるより推論方式を変える方が安い可能性があります。

推論最適化は UX と原価の両方に効く

事業面では、同じ GPU でさばけるリクエスト数が増えることに直結します。応答が速くなれば継続利用率や業務定着率にも効きますし、原価が下がれば価格設計も柔軟になります。モデルの賢さばかり追うのではなく、推論パイプラインの工夫を競争力として扱う発想はかなり重要です。

小規模プロダクトでも発想を流用できる

論文そのままの実装でなくても、「複数ステップを 1 ステップに束ねられないか」という考え方は応用できます。たとえば、ドラフトモデル型 speculative decoding、キャッシュ最適化、構造化出力での先読み、テンプレート区間のまとめ生成など、周辺の推論最適化にもつながります。Medusa は、その設計空間の基準点として理解しておく価値があります。

今後注目すべき方向性

今後は、より高い受理率を出すヘッド設計、サービングエンジンとの統合、量子化との相性、バッチ推論時の効果、マルチモーダル生成への拡張が重要になります。実際、Medusa のあとには Hydra のように「ヘッド間の依存関係」まで改善する研究が出ており、この分野がまだ伸びることを示しています。

限界

Medusa の限界としてまずあるのは、追加ヘッドがあっても未来トークンの予測は本質的に難しいことです。先の位置ほど予測誤差が増えるため、受理率が伸びなければ速度改善も頭打ちになります。つまり、追加ヘッドを増やせば無限に速くなるわけではありません。

計算資源の面では、巨大な本体モデルより軽いとはいえ、追加ヘッド分のパラメータとメモリは必要です。推論そのものは速くなっても、モデルサイズや GPU メモリ使用量が少し増える点には注意が必要です。

実装面では、tree attention の扱い、候補木の最適化、受理規則の閾値調整など、素朴なデコード実装より複雑になります。既存のサービングエンジンに入れるにはエンジニアリングが必要で、論文でもバッチ推論への統合には追加実装が必要だと示唆されています。

また、品質を完全に壊さないためには学習レシピが重要です。特に Medusa-2 は本体も一緒に学習するので、雑に実装すると元モデルの品質を落とす可能性があります。self-distillation も便利ですが、元モデルの癖をそのまま蒸留するため、特定タスクでの偏りが残る懸念はあります。

最後に、実運用での効果はタスク依存です。自由記述が多く、分布の広い生成では、論文ほどきれいに受理長が伸びない可能性があります。したがって導入時は、平均 tokens/sec だけでなく、実際のアプリの入力分布と出力形式で評価する必要があります。

よくある質問

Q. Medusa は speculative decoding と何が違うのですか?

A. 狙いはかなり近いですが、Medusa は別のドラフトモデルを持たず、元モデルの上に追加ヘッドを載せて候補を作る点が違います。そのため、2 モデル運用より実装と運用を単純化しやすいのが利点です。

Q. Medusa-1 と Medusa-2 はどちらを選ぶべきですか?

A. 既存モデルに後付けしたいなら Medusa-1 が現実的です。本体を凍結するので品質を壊しにくいからです。より大きな高速化を狙うなら Medusa-2 が有力ですが、学習コストと実装難度は上がります。

Q. どんなアプリで特に効きやすいですか?

A. コード生成、構造化出力、短い対話を大量に回すエージェント、低バッチのチャットサーバーなどで効きやすいと考えられます。理由は、先読み候補が比較的当たりやすく、デコードステップ削減の効果がそのまま UX に出やすいからです。

Q. RAG やエージェントにも関係ありますか?

A. 直接は推論高速化の論文ですが、かなり関係あります。RAG やエージェントは 1 回の長文生成だけでなく、小さな生成を何度も繰り返すため、1 ステップごとの短縮が全体の待ち時間や GPU コストに効きやすいです。

Q. すぐに本番導入できるほど簡単ですか?

A. 発想はシンプルですが、実装はそこまで簡単ではありません。追加学習、候補木の制御、サービング統合、量子化との整合などを詰める必要があります。ただ、別モデルを増やす speculative decoding よりは運用しやすい場面が多いです。

今日の学び

この論文は、LLM の自己回帰デコードが遅く、特に推論レイテンシとメモリ帯域がボトルネックになる課題を扱いました。これに対して Medusa は、元モデルの上に複数のデコードヘッドを追加し、候補続きを木構造でまとめて検証することで、複数トークンを一度に受理できるようにしました。

ここから得られるヒントは、LLM の改善はモデル本体の性能向上だけではないということです。推論をどう束ねるか、どこで並列性を取り戻すかを設計するだけでも、実用上は大きな差になります。速さそのものがプロダクト価値になる時代には、こうした推論最適化の発想を早めに押さえておく意味があります。

関連記事