@JaydevTonde: https://x.com/JaydevTonde/status/2068361821002846418
Summary
A detailed tutorial on implementing CUDA Graphs in an LLM inference server Tokn, covering FastAPI server setup, engine initialization, and CUDA Graph capture for optimized decode phases.
View Cached Full Text
Cached at: 06/21/26, 04:33 AM
CUDA Graph implementation in LLM Inference server
I recently implemented CUDA Graphs in my LLM inference server, Tokn.
The implementation is minimal, but it is worth looking at if you want to understand the core idea behind how CUDA Graphs are used in larger inference servers like **vLLM **or SGLang.
1. FastAPI Server
This part contains the FastAPI server code. It accepts incoming prompts and adds those requests to the processing queue inside the generate_async function. The start_engine_loop runs inside the engine. Its responsibility is to continuously monitor the scheduler’s waiting queue and decide whether each request should go through the prefill or decode path.
python
@app.on_event(“startup”) async def start_engine_loop(): if engine.background_task is None: engine.background_task = asyncio.create_task(engine.run_loop())
@app.post(“/completions”) async def completions(req: CompletionRequest): print(f“\n\nRequest received for with the prompt : {req.prompts}“)
tasks = [
engine.generate_async(prompt)
for prompt in req.prompts
]
outputs = await asyncio.gather(*tasks)
return {
idx: output
for idx, output in enumerate(outputs)
}
- Engine initialization Inside the Engine class initialization, I call capture_cudagraph.
pythonif self.device.startswith(“cuda”) and not self.enforce_eager: self.capture_cudagraph()
3. Capture CUDA Graphs
This is the heart of the CUDA Graph implementation.
In this step, I first read the required configs, like max_num_seqs, which defines the maximum number of sequences that can be scheduled for prefill or decode.
Then, I create dummy buffers/context for running dummy prefill and decode passes. These dummy inputs are required because CUDA Graphs need a fixed execution pattern to capture.
After that, I create a list of batch sizes for which I want to capture CUDA Graphs. For decode, I mainly need batch size 1, because each decode step processes one token per sequence.
Finally, I iterate over each batch size, capture a CUDA Graph for that shape, and store it in a dictionary. Later, during inference, the server can pick the right graph based on the batch size and replay it instead of launching kernels one by one.
[email protected]_mode() def capture_cudagraph(self): config = self.hf_config hidden_size = config.hidden_size max_bs = self.scheduler.max_num_seqs max_num_blocks = (self.max_model_len + self.block_size - 1) // self.block_size device = torch.device(self.device)
# Static buffers — addresses are frozen for the lifetime of the graphs.
input_ids = torch.zeros(max_bs, dtype=torch.long, device=device)
positions = torch.zeros(max_bs, dtype=torch.long, device=device)
slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device=device)
context_lens = torch.zeros(max_bs, dtype=torch.int32, device=device)
block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32, device=device)
outputs = torch.zeros(max_bs, hidden_size, dtype=self.dtype, device=device)
self.graph_bs = [bs for bs in (1,2,3,4) if bs <= max_bs]
if max_bs not in self.graph_bs:
self.graph_bs.append(max_bs)
for bs in reversed(self.graph_bs):
graph = torch.cuda.CUDAGraph()
set_context(
is_prefill=False,
slot_mapping=slot_mapping[:bs],
context_lens=context_lens[:bs],
block_tables=block_tables[:bs]
)
# Warmup run (lazy allocs / autotune happen OUTSIDE the graph).
outputs[:bs] = self.custom_model.model(input_ids[:bs], positions[:bs])
with torch.cuda.graph(graph, self.graph_pool):
outputs[:bs] = self.custom_model.model(input_ids[:bs], positions[:bs])
if self.graph_pool is None:
self.graph_pool = graph.pool()
self.graphs[bs] = graph
torch.cuda.synchronize()
reset_context()
self.graph_vars = dict(
input_ids=input_ids,
positions=positions,
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
outputs=outputs
)
4. run_model
This is the main model forward pass.
In this function, I use the captured CUDA Graph for the decode phase, because decode usually follows a fixed execution pattern and is a good fit for graph replay.
For prefill, I use the normal model forward pass, since prefill can have more dynamic shapes depending on the prompt length and chunk size.
So the flow is simple:
-
Prefill: normal forward pass
-
Decode: CUDA Graph replay
This keeps the implementation minimal while still showing the core idea of how CUDA Graphs can reduce kernel launch overhead during decoding.
python def run_model(self, input_ids, positions, is_prefill):
if is_prefill or self.enforce_eager or not self.graphs or input_ids.size(0) > self.graph_bs[-1]:
return self.custom_model(input_ids, positions)
bs = input_ids.size(0)
context = get_context()
graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
gv = self.graph_vars
gv["input_ids"][:bs] = input_ids
gv["positions"][:bs] = positions
gv["slot_mapping"].fill_(-1) # padding rows write nowhere
gv["slot_mapping"][:bs] = context.slot_mapping
gv["context_lens"].zero_() # padding rows attend to nothing
gv["context_lens"][:bs] = context.context_lens
gv["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
graph.replay()
return self.custom_model.compute_logits(gv["outputs"][:bs])
GitHub repo for my working LLM inference server, Tokn: https://github.com/jaytonde/Tokn
I will soon add benchmark results by comparing Tokn with official vLLM on the same dataset. This should give a clearer picture of where the current implementation stands and what needs to be optimized next.
That’s it for now. My next steps are to implement more LLM inference techniques in Tokn, mainly Distributed inference(TP), one speculative decoding and one quantization technique. I am planning to cover these over the next 15 days.
After that, I want to start learning C++, Triton, and CUDA programming more seriously. The goal is to revisit the same inference techniques again, but this time from the kernel-level perspective.
This should help me connect the high-level inference server design with the low-level GPU execution details.
Similar Articles
Every AI researcher should grasp inference acceleration—CUDA Graph is the heart of vLLM's GPU efficiency
A tweet urging AI researchers to learn inference-acceleration basics and spotlighting CUDA Graph as the key to vLLM’s GPU utilization.
@kazukifujii: Sakura Internet's Michishita-san's article comprehensively summarizes LLM Inference and comes highly recommended. It fe…
This article summarizes a presentation by Junda Chen on disaggregated inference for LLMs, explaining why goodput (throughput meeting latency SLOs) matters more than raw throughput, and how separating prefill and decode phases improves performance. It also highlights the influence on NVIDIA Dynamo.
An open handbook on LLM inference at scale (GPU internals, KV cache, batching, vLLM/SGLang/TensorRT-LLM) [P]
An open, in-progress handbook explaining LLM inference internals including GPU memory hierarchy, KV cache, batching, and popular inference engines like vLLM and TensorRT-LLM.
Real-time LLM Inference on Standard GPUs: 3k tokens/s per request
Kog AI launches a tech preview of the Kog Inference Engine, achieving 3,000 tokens/s per request on standard datacenter GPUs by co-designing model architecture, runtime, and low-level GPU code, targeting latency-critical AI agent workflows.
@TheAhmadOsman: Why do I focus on Inference Engines/Software Stacks for your hardware? - 2x RTX 3090s: ~14.5 tok/s → ~64 tok/s moving t…
Comparison of inference engine performance on different hardware: moving from baseline to vLLM with TP=2 on 2x RTX 3090s improves from ~14.5 tok/s to ~64 tok/s, and on RTX PRO 6000 moving to Sglang improves from ~32 tok/s to ~110 tok/s. Recommends vLLM/Sglang for CUDA/multi-GPU and llama.cpp for edge devices.