@JaydevTonde: https://x.com/JaydevTonde/status/2068361821002846418

X AI KOLs Timeline 工具

摘要

有关在LLM推理服务器Tokn中实现CUDA Graphs的详细教程,涵盖FastAPI服务器设置、引擎初始化以及用于优化解码阶段的CUDA Graph捕获。

https://t.co/3uZh75Aojr
查看原文
查看缓存全文

缓存时间: 2026/06/21 04:33

在 LLM 推理服务器中实现 CUDA Graph

我最近在我的 LLM 推理服务器 Tokn 中实现了 CUDA Graph

实现非常精简,但如果你想理解像 vLLMSGLang 这类大型推理服务器使用 CUDA Graph 背后的核心思想,这篇文章值得一看。

1. FastAPI 服务器

这部分包含 FastAPI 服务器代码。它接收传入的提示词,并将这些请求添加到 generate_async 函数内部的处理队列中。start_engine_loop 在引擎内部运行。其职责是持续监控调度器的等待队列,并决定每个请求应该走预填充路径还是解码路径。

@app.on_event("startup")
async def start_engine_loop():
    if engine.background_task is None:
        engine.background_task = asyncio.create_task(engine.run_loop())

@app.post("/completions")
async def completions(req: CompletionRequest):
    print(f"\n\nRequest received for with the prompt : {req.prompts}")

    tasks = [
        engine.generate_async(prompt)
        for prompt in req.prompts
    ]

    outputs = await asyncio.gather(*tasks)

    return {
        idx: output
        for idx, output in enumerate(outputs)
    }

2. 引擎初始化

Engine 类的初始化内部,我调用了 capture_cudagraph

if self.device.startswith("cuda") and not self.enforce_eager:
    self.capture_cudagraph()

3. 捕获 CUDA Graph

这是 CUDA Graph 实现的核心。

在这一步中,我首先读取所需的配置,例如 max_num_seqs,它定义了可以调度用于预填充或解码的最大序列数。

然后,我创建用于运行虚拟预填充和解码传递的虚拟缓冲区/上下文。这些虚拟输入是必需的,因为 CUDA Graph 需要固定的执行模式才能捕获。

之后,我创建一个批大小列表,这些批大小对应我要捕获 CUDA Graph 的批次。对于解码,我主要需要批大小 1,因为每个解码步骤每个序列只处理一个 token。

最后,我遍历每个批大小,为该形状捕获一个 CUDA Graph,并将其存储在字典中。在推理过程中,服务器可以根据批大小选择正确的图并重放,而不是逐个启动内核。

@torch.inference_mode()
def capture_cudagraph(self):
    config         = self.hf_config
    hidden_size    = config.hidden_size
    max_bs         = self.scheduler.max_num_seqs
    max_num_blocks = (self.max_model_len + self.block_size - 1) // self.block_size
    device         = torch.device(self.device)

    # 静态缓冲区 —— 地址在图的整个生命周期内固定。
    input_ids      = torch.zeros(max_bs, dtype=torch.long,  device=device)
    positions      = torch.zeros(max_bs, dtype=torch.long,  device=device)
    slot_mapping   = torch.zeros(max_bs, dtype=torch.int32, device=device)
    context_lens   = torch.zeros(max_bs, dtype=torch.int32, device=device)
    block_tables   = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32, device=device)
    outputs        = torch.zeros(max_bs, hidden_size, dtype=self.dtype, device=device)

    self.graph_bs  = [bs for bs in (1,2,3,4) if bs <= max_bs]

    if max_bs not in self.graph_bs:
        self.graph_bs.append(max_bs)

    for bs in reversed(self.graph_bs):
        graph = torch.cuda.CUDAGraph()
        set_context(
            is_prefill=False,
            slot_mapping=slot_mapping[:bs],
            context_lens=context_lens[:bs],
            block_tables=block_tables[:bs]
        )

        # 预热运行(惰性分配/自动调优在图形外部进行)。
        outputs[:bs] = self.custom_model.model(input_ids[:bs], positions[:bs])
        with torch.cuda.graph(graph, self.graph_pool):
            outputs[:bs] = self.custom_model.model(input_ids[:bs], positions[:bs])

        if self.graph_pool is None:
            self.graph_pool = graph.pool()

        self.graphs[bs] = graph
        torch.cuda.synchronize()
        reset_context()

    self.graph_vars = dict(
        input_ids=input_ids,
        positions=positions,
        slot_mapping=slot_mapping,
        context_lens=context_lens,
        block_tables=block_tables,
        outputs=outputs
    )

4. run_model

这是主要的模型前向传递函数。

在此函数中,我对解码阶段使用捕获的 CUDA Graph,因为解码通常遵循固定的执行模式,非常适合图重放。

对于预填充,我使用正常的模型前向传递,因为预填充的输入形状可能因提示长度和块大小而更具动态性。

所以流程很简单:

  • 预填充: 正常前向传递
  • 解码: CUDA Graph 重放

这样既保持了实现的简洁性,又展示了 CUDA Graph 如何减少解码期间的内核启动开销。

def run_model(self, input_ids, positions, is_prefill):

    if is_prefill or self.enforce_eager or not self.graphs or input_ids.size(0) > self.graph_bs[-1]:
        return self.custom_model(input_ids, positions)

    bs      = input_ids.size(0)
    context = get_context()
    graph   = self.graphs[next(x for x in self.graph_bs if x >= bs)]
    gv      = self.graph_vars

    gv["input_ids"][:bs] = input_ids
    gv["positions"][:bs] = positions
    gv["slot_mapping"].fill_(-1)          # 填充行不写入任何地方
    gv["slot_mapping"][:bs] = context.slot_mapping
    gv["context_lens"].zero_()            # 填充行不关注任何位置
    gv["context_lens"][:bs] = context.context_lens
    gv["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables

    graph.replay()

    return self.custom_model.compute_logits(gv["outputs"][:bs])

我工作的 LLM 推理服务器 Tokn 的 GitHub 仓库: https://github.com/jaytonde/Tokn

我很快会通过在同一数据集上比较 Tokn 与官方 vLLM 来添加基准测试结果。这将更清晰地展示当前实现的状态以及下一步需要优化的方向。

目前就到这里。我的下一步计划是在 Tokn 中实现更多 LLM 推理技术,主要是分布式推理(TP)、一种推测解码方法和一种量化技术。我计划在接下来的 15 天内完成这些。

之后,我想开始更认真地学习 C++TritonCUDA 编程。目标是再次审视同样的推理技术,但这次从内核级角度出发。

这将帮助我将高级推理服务器设计与底层 GPU 执行细节联系起来。

相似文章

标准GPU上的实时LLM推理:每请求3k tokens/秒

Hacker News Top

Kog AI 发布了 Kog Inference Engine 的技术预览版,通过协同设计模型架构、运行时和底层 GPU 代码,在标准数据中心 GPU 上实现了每请求 3,000 tokens/s 的性能,面向延迟敏感的 AI 代理工作流。