ecahLang

Building a Lightweight LLM Inference Engine

From naive inference to optimized continuous batching with FlashInfer

PyTorch FlashInfer HuggingFace FastAPI

~1,500 lines of code · Any HuggingFace CausalLM · OpenAI-compatible API

The Problem: LLM Inference is Slow

  • LLMs are autoregressive — one token at a time
  • 384 tokens = 384 sequential forward passes
  • CPU kernel launch overhead dominates small batches
  • At 128 users, everything compounds

Goal: minimize TTFT (Time-to-First-Token) and ITL (Inter-Token Latency)

Naive Baseline Performance
ITL (1 user): 26.3 ms/token
E2E (1 user): 10.1 seconds
ITL (128 users): 120.8 ms/token
E2E (128 users): 47.3 seconds
Throughput (128 users): 1,040 tok/s
2.7x faster
after all optimizations

Two Phases of LLM Inference

Phase 1: Prefill — prompt processing (determines TTFT). The full prompt ("What is the capital of France ?", 7 tokens) passes through all layers at once; the keys and values for all 7 tokens are stored in the KV cache for reuse during decode, and the first output token ("The") is produced.

  • Compute-bound (matrix multiplications)
  • Happens once per request
  • Processes all prompt tokens in parallel

Phase 2: Decode — token generation (determines ITL). Each step feeds in the previous token plus a KV cache lookup: "The" → "capital" → "of" → ... The KV cache grows by one token per step.

  • Memory-bandwidth bound (reads the entire KV cache)
  • Happens N times (once per output token)
  • 1 token in → 1 token out per forward pass

Different bottlenecks → different optimizations.

Architecture Overview

HTTP client → POST /chat/completions → FastAPI app (OpenAI API, SSE streaming). Requests land in two asyncio.Queues — a prefill queue and a decode queue — each drained by its own background loop: process_queue(prefill=True) drives the BatchPrefillWrapper, process_queue(prefill=False) drives the BatchDecodeWrapper. Both run the HuggingFace model through the ecah_attention() hook, backed by the AutoKVCacheManager (paged KV cache). The queue loops live in main.py, the cache manager in manager.py. Separate loops for each phase.

Continuous Batching

Static Batching

Req A: 500 tokens
Req B: 200 tokens, then idle...
Req C: 400 tokens
Req D: blocked

All must finish before new requests start

Continuous Batching (ecahLang)

Req A: 500 tokens
Req B: 200 tokens
Req C: 400 tokens
Req D: joins when B finishes

Requests join/leave dynamically — no wasted GPU

  • Two independent asyncio.Queue loops — prefill and decode run as concurrent background tasks
  • Microsleep (0.1ms) groups incoming requests · Up to 128 requests per batch
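The two-loop pattern can be sketched with plain asyncio. This is a minimal sketch, not the real main.py: `process_queue`, the toy `prefill`/`decode` handlers, and all names here are illustrative; the real handlers run the model.

```python
import asyncio

async def process_queue(queue, handler, max_batch=128, microsleep=0.0001):
    # Block until one request arrives, then hold a ~0.1 ms window so
    # concurrent arrivals are grouped into the same batch.
    while True:
        batch = [await queue.get()]
        await asyncio.sleep(microsleep)
        while not queue.empty() and len(batch) < max_batch:
            batch.append(queue.get_nowait())
        await handler(batch)

async def main():
    prefill_q, decode_q = asyncio.Queue(), asyncio.Queue()
    finished = []

    async def prefill(batch):          # toy handler: hand off to decode
        for req in batch:
            await decode_q.put(req)

    async def decode(batch):           # toy handler: mark requests done
        finished.extend(batch)

    loops = [asyncio.create_task(process_queue(prefill_q, prefill)),
             asyncio.create_task(process_queue(decode_q, decode))]
    for i in range(5):
        await prefill_q.put(f"req{i}")
    await asyncio.sleep(0.05)          # let both loops drain the queues
    for t in loops:
        t.cancel()
    return finished
```

Because both loops are ordinary tasks on one event loop, prefill and decode interleave naturally without threads or locks.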

Paged KV Cache — The Problem

KV Cache Memory per Token

Qwen2.5-3B per-token KV size: 36 layers × 4 KV heads × 128 head dim × 2 (K+V) × 2 bytes = 72 KB per token. 1,000 tokens ≈ 72 MB · 2,048 tokens ≈ 147 MB.
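The arithmetic above is quick to check (numbers are those quoted for Qwen2.5-3B in float16):

```python
# Per-token KV cache footprint for Qwen2.5-3B
layers, kv_heads, head_dim = 36, 4, 128
bytes_per_elem = 2           # float16
kv_factor = 2                # one K and one V tensor per layer

per_token_bytes = layers * kv_heads * head_dim * kv_factor * bytes_per_elem
per_token_kb = per_token_bytes // 1024   # 72 KB per token
```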

Naive Pre-allocation

Request A: 200 tokens used, 1,848 wasted
Request B: 150 tokens used, 1,898 wasted
128 reqs × 2048 max × 72KB = 18.4 GB wasted
Most requests use <10% of allocated memory

Paged KV Cache — The Solution

Physical KV cache memory on the GPU is split into fixed-size blocks. Example: Request A (35 tokens) holds blocks 0–2 (16 + 16 + 3 tokens), Request B (20 tokens) holds blocks 3–4 (16 + 4 tokens), and the remaining blocks stay free.

How it works:
  • Divide GPU memory into fixed-size blocks (16 tokens each)
  • Allocate blocks on demand as tokens are generated · free-list allocator: O(1)
  • Auto-sizing via pynvml: max_blocks = (free_gpu_mem × utilization) / (layers × per_block_bytes)
kv_cache = torch.zeros(num_layers, max_blocks, 2, block_size, num_kv_heads, head_dim)  # auto-computed from GPU memory
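A free-list allocator in this style can be sketched as follows. `BlockAllocator` and its method names are illustrative assumptions, not the actual manager.py interface:

```python
import math

class BlockAllocator:
    # Free list over fixed-size KV pages: allocate and release are
    # O(1) per block, and no per-request maximum is ever reserved.
    def __init__(self, max_blocks, block_size=16):
        self.block_size = block_size
        self.free = list(range(max_blocks))   # all physical block ids

    def allocate(self, num_tokens):
        needed = math.ceil(num_tokens / self.block_size)
        if needed > len(self.free):
            raise MemoryError("paged KV cache exhausted")
        return [self.free.pop() for _ in range(needed)]

    def release(self, blocks):
        self.free.extend(blocks)     # returned pages are reusable at once

alloc = BlockAllocator(max_blocks=8)
req_a = alloc.allocate(35)   # 35 tokens -> 3 blocks (16 + 16 + 3)
req_b = alloc.allocate(20)   # 20 tokens -> 2 blocks (16 + 4)
alloc.release(req_a)         # request A finished; its pages go back
```

Freed pages return to the list immediately, so a finishing request's memory is available to the next one without any compaction.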

FlashInfer Integration

ecah_attention() — per layer:
  • query, key, value arrive as [1, H, L, D] → reshape to [L, H, D]
  • Append K, V to the paged KV cache: manager.append_paged_kv_cache()
  • FlashInfer paged attention: wrapper.run(query, kv_cache[layer])
  • Reshape the output back to [1, L, H, D]
Key design: Register custom attention via HuggingFace AttentionInterface — works with any CausalLM without modifying model code.
# Register custom attention (once)
AttentionInterface.register(
    "ecah_attention", ecah_attention
)

# Load model — any CausalLM works
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation="ecah_attention",
)

Two FlashInfer Wrappers

Wrapper                        Phase
BatchPrefillWithPagedKVCache   Prefill (causal)
BatchDecodeWithPagedKVCache    Decode (single-token)

Metadata (kv_indices, kv_indptr, kv_last_page_len) pre-computed once per batch, reused across all layers.
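The CSR-style metadata can be sketched like this. `build_kv_metadata` is a hypothetical helper, but the three arrays follow the kv_indices / kv_indptr / kv_last_page_len layout named above, with page_size matching the 16-token blocks:

```python
# Build batch metadata for a paged-KV attention wrapper (sketch).
def build_kv_metadata(seq_lens, page_tables, page_size=16):
    kv_indices, kv_indptr, kv_last_page_len = [], [0], []
    for seq_len, pages in zip(seq_lens, page_tables):
        kv_indices.extend(pages)              # physical page ids, concatenated
        kv_indptr.append(len(kv_indices))     # prefix sum of page counts
        kv_last_page_len.append((seq_len - 1) % page_size + 1)
    return kv_indices, kv_indptr, kv_last_page_len

# Two sequences: 35 tokens on pages [0, 1, 2], 20 tokens on pages [3, 4]
idx, indptr, last = build_kv_metadata([35, 20], [[0, 1, 2], [3, 4]])
```

Since these arrays depend only on batch composition, not on layer index, computing them once per batch and reusing them across all 36 layers is what makes the per-layer hook cheap.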

Chunked Prefill — The Problem

At high concurrency, many prefill requests arrive simultaneously:

One giant prefill: 128 requests × ~200 tokens = 25,600 tokens in a single batch, while the decode queue sits blocked, waiting. TTFT = 0.96s at 128 concurrency.
Result: Large prefill batches monopolize the GPU. All decode requests for existing users are blocked, ITL spikes for everyone, and new users wait for the entire batch to complete before getting their first token.

Chunked Prefill — How It Works

Set a token budget per prefill round (max_prefill_tokens = 2048). Iterate through queued requests, adding each one until the budget would be exceeded.

Budget: 2,048 tokens.
  • + req1 (200 tok) → total 200 ✓ fits
  • + req2 (150 tok) → total 350 ✓ fits
  • + req3 (300 tok) → total 650 ✓ fits
  • ... requests keep filling ... → total 1,900 ✓ fits
  • + req10 (250 tok) → total 2,150 ✗ over budget → process reqs 1–9 (1,900 tokens) now
Reqs 10–128 go into next_batch and are processed on the next loop iteration (no sleep) — reqs 1–9 start decoding while the rest are still prefilling.
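The budget loop might look like this. A sketch only: `take_prefill_chunk` is an illustrative name, and queued requests are represented by their prompt-token counts:

```python
# Take one token-budgeted chunk from the pending prefill queue.
def take_prefill_chunk(pending, max_prefill_tokens=2048):
    chunk, total = [], 0
    while pending and total + pending[0] <= max_prefill_tokens:
        total += pending[0]
        chunk.append(pending.pop(0))
    return chunk, total        # whatever remains runs on the next loop

queue = [200, 150, 300] + [250] * 10   # 13 queued requests
chunk, total = take_prefill_chunk(queue)
# first 8 requests fit (1,900 tokens); 5 requests stay queued for the next round
```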

Chunked Prefill — Impact

Early chunks finish and enter the decode queue while later chunks are still prefilling.

Without chunking: one giant 25,600-token prefill (0.96s TTFT). With chunking: chunks C1–C13 run back to back, interleaved with decode rounds — chunk 1's requests enter the decode queue immediately, with no waiting for C2–C13.
Concurrency   Baseline TTFT   Chunked TTFT   Change
1             0.046s          0.041s         -9%
16            0.146s          0.122s         -16%
64            0.493s          0.396s         -20%
128           0.964s          0.743s         -23%
-23% TTFT
at 128 concurrency (0.96s → 0.74s)
Bonus: Also improves ITL indirectly. Smaller prefill chunks release the GPU sooner, so the decode loop gets serviced more frequently.
ITL at 128: baseline 120.75ms → chunked 43.04ms

CUDA Graphs — The Problem

During decode, each token = tiny GPU workload. But CPU launch overhead dominates:

CPU: prepare → launch → wait, over and over; GPU: idle → run → idle. The GPU spends most of each decode step waiting on the CPU.
Per decode step:
  • Prepare tensors: ~0.2ms
  • Kernel launch: ~0.5ms
  • Actual GPU compute: ~2ms
  • Synchronization: ~0.3ms

CPU overhead ≈ 33% of total time

GPU idle >60%
at batch_size=1, GPU waits for CPU

CUDA Graphs — The Solution

Record entire forward pass once at startup, replay with a single CPU call:

CPU: copy inputs → replay, over and over; GPU: back-to-back forward passes with no idle gaps.

Capture at startup (once per bucket), replay on every decode step: 1 CPU call = entire forward pass. No CPU–GPU round-trips, no kernel launch overhead.
class CUDAGraphDecodeWrapper:
    def __init__(self, fn):
        self.fn = fn
        self.static_inputs, self.static_outputs, self.graphs = {}, {}, {}

    def warmup(self, bs, capture_stream, **inputs):
        self.static_inputs[bs] = {
            k: v.clone().cuda() for k, v in inputs.items()
        }
        for _ in range(2):                     # warmup runs before capture
            self.fn(**self.static_inputs[bs])
        torch.cuda.synchronize()
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g, stream=capture_stream):
            out = self.fn(**self.static_inputs[bs])
        self.graphs[bs] = g
        self.static_outputs[bs] = out          # replay writes into this buffer

    def run(self, bs, **new_inputs):
        for k, v in new_inputs.items():
            self.static_inputs[bs][k].copy_(v)
        self.graphs[bs].replay()  # single CPU call = entire forward pass
        return self.static_outputs[bs]

CUDA Graphs — Bucket Padding

CUDA Graphs require fixed tensor shapes. Batch sizes vary at runtime → round up to power-of-2 buckets and pad with dummy KV pages.
Reqs 1–5: real · slots 6–8: dummy padding

Batch of 5 → bucket 8 (next power of 2) · Padding results discarded

Bucket sizes

1 · 2 · 4 · 8 · 16 · 32 · 64 · 128

Per bucket: alloc dummy KV → plan wrapper → warmup 2x → capture graph → free dummies

At inference time

Batch arrives (N=5) → bucket = 8 → copy inputs + replay graph → take the first 5 results.
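Bucket selection is small enough to spell out (`bucket_for` is an illustrative name; the cap matches the largest captured graph):

```python
def bucket_for(batch_size, max_bucket=128):
    # Round a runtime batch size up to the next power-of-2 bucket,
    # capped at the largest bucket a graph was captured for.
    b = 1
    while b < batch_size:
        b *= 2
    return min(b, max_bucket)

bucket = bucket_for(5)   # 8: pad 3 dummy slots, replay, keep first 5 outputs
```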

CUDA Graphs — Impact

Concurrency   Baseline ITL   CUDA Graph ITL   Speedup
1             26.30 ms       8.06 ms          3.3x
4             28.76 ms       8.84 ms          3.3x
8             57.36 ms       19.02 ms         3.0x
32            71.55 ms       22.03 ms         3.3x
64            98.22 ms       27.69 ms         3.5x
128           120.75 ms      44.19 ms         2.7x
2.7–3.5x ITL
across all concurrency levels
10.1s → 3.1s E2E
at 1 user (3.2x faster)

torch.compile

Alternative to CUDA Graphs: JIT-compile the decode function. Same bucketing infrastructure, less rigid.
Aspect        CUDA Graphs           torch.compile
Mechanism     Record & replay ops   JIT → optimized CUDA
Flexibility   Fixed shapes only     Dynamic shapes possible
Overhead      Near-zero replay      Some per-call overhead
Speed         Faster                Good middle ground
@torch.compiler.disable  # FlashInfer ops can't compile
def ecah_attention(...):
    ...

decode = torch.compile(decode, mode="default", dynamic=False)
            Baseline   torch.compile    CUDA Graphs
ITL @ 1     26.3 ms    19.8 ms (1.3x)   8.1 ms (3.3x)
ITL @ 128   120.8 ms   75.9 ms (1.6x)   44.2 ms (2.7x)

CUDA Graphs and torch.compile are mutually exclusive (enforced); both share the same bucketing infrastructure.

CUDA Stream Overlap

Operations on different CUDA streams execute concurrently. ecahLang uses two streams.

Without (sequential):

Single stream: plan → forward (15ms) → plan → forward (15ms), strictly sequential.

With stream overlap:

Default stream: plan → CPU free → plan → CPU free
Forward stream: forward + sampling → forward + sampling
forward + sampling
fwd_stream.wait_stream(torch.cuda.current_stream())  # ensure plan() finished
with torch.cuda.stream(fwd_stream):
    output = model.forward(...)    # GPU on forward stream
    logits = output.logits[...]    # sampling also on forward stream
# CPU free to collect next batch, detokenize, etc.

Two-Phase Batch Collection

While GPU processes current batch, CPU pre-collects the next batch from the queue.

Without overlap:

GPU: Forward A → idle → Forward B → idle
CPU: plan → idle → collect B → plan → idle → collect

With two-phase collection:

GPU: Forward Batch A → Forward Batch B (back to back)
CPU: plan → collect B (while GPU runs) → plan → collect C (while GPU runs)
with torch.cuda.stream(fwd_stream):    # Phase 1: GPU working
    output = model.forward(...)
await asyncio.sleep(0)                  # Phase 2: yield to event loop
while not queue.empty():                # collect next batch while GPU runs
    next_batch.append(await queue.get())
fwd_stream.synchronize()                # Phase 3: GPU done, next batch ready
# Also: tokenizer.batch_decode() in thread pool via run_in_executor()

Pinned Sampling Buffers

Regular Transfer

CPU RAM (allocated each step) → staging buffer (extra copy!) → GPU, blocking the CPU. Allocate each step + blocking copy.

Pinned Memory (ecahLang)

Pinned CPU RAM (written through a numpy view) → DMA → GPU, non-blocking. Pre-allocated once + direct transfer.
# Startup: allocate ONCE
temp_cpu = torch.ones(max_batch, 1,
    dtype=torch.float32).pin_memory()

# Numpy views (zero-overhead scalar writes)
temp_np = temp_cpu.numpy()

# Persistent GPU buffers
temp_gpu = torch.ones(max_batch, 1,
    dtype=torch.float32, device="cuda")


# Each decode step:
# 1. Write via numpy (zero Python overhead)
temp_np[:n, 0] = temperatures

# 2. Async DMA copy (non-blocking)
temp_gpu[:n].copy_(
    temp_cpu[:n], non_blocking=True
)

# 3. GPU reads temp_gpu during sampling
# No allocation, no staging, no blocking!

Multi-Step Decode

Run N decode steps in one batch round-trip instead of returning to the event loop after each token.

multi_step=1 (default):

loop → queue → fwd · loop → queue → fwd · loop → queue → fwd — one event-loop round-trip per token

multi_step=4:

loop → queue → fwd 1 → fwd 2 → fwd 3 → fwd 4 → detok → loop — one round-trip for four tokens

Active Sequence Tracking

Step   Seq A    Seq B     Seq C
0      "The"    "Hello"   "Yes"
1      "end"    <EOS>     ","
2      "."      skip      "it"
3      <EOS>    skip      "is"

Benefits

  • 1 queue round-trip for N tokens
  • Batch detokenization at the end
  • EOS-aware: inactive seqs skip forward passes
  • Not compatible with CUDA Graphs (enforced)
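The EOS-aware loop behind this table can be sketched as follows. `multi_step_decode` and `step` are illustrative stand-ins for the real decode path; `step` represents one forward+sample pass over the currently active sequences:

```python
# Run up to `multi_step` decode steps in one round-trip, skipping
# sequences that have already emitted EOS.
def multi_step_decode(num_seqs, step, eos="<EOS>", multi_step=4):
    active = set(range(num_seqs))
    outputs = [[] for _ in range(num_seqs)]
    for _ in range(multi_step):
        if not active:
            break                          # all sequences finished early
        idx = sorted(active)
        for i, tok in zip(idx, step(idx)):  # one pass, active seqs only
            outputs[i].append(tok)
            if tok == eos:
                active.discard(i)          # skip this seq from now on
    return outputs                         # detokenize once, after N steps

# Scripted stand-in reproducing the table above (seq B stops at step 1)
script = {0: ["The", "end", ".", "<EOS>"],
          1: ["Hello", "<EOS>"],
          2: ["Yes", ",", "it", "is"]}
pos = {0: 0, 1: 0, 2: 0}
def fake_step(idx):
    toks = [script[i][pos[i]] for i in idx]
    for i in idx:
        pos[i] += 1
    return toks

out = multi_step_decode(3, fake_step)
```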

FlashInfer-Native Sampling

Standard: 5 Kernel Launches

logits → Kernel 1: ÷ temperature → Kernel 2: top-k filter → Kernel 3: softmax → Kernel 4: top-p filter → Kernel 5: sample. 5× CPU–GPU launch overhead.

FlashInfer: 1 Fused Kernel

logits → single fused CUDA kernel (top-k + top-p + sampling, all fused on GPU) → token_id. One launch instead of five.
idx = flashinfer.sampling.top_k_top_p_sampling_from_logits(
    logits, top_k=top_k_t, top_p=top_p_t, deterministic=True
)

Benchmark Setup

Configuration

  • Model: Qwen2.5-3B-Instruct
  • Dtype: float16
  • Output: 384 tokens (ignore_eos)
  • Prompt: ~200 tokens

Concurrency

1, 2, 4, 8, 16, 32, 64, 128

All concurrent · SSE streaming

Compared Against

  • Baseline (no opts)
  • CUDA Graphs / torch.compile
  • Chunked Prefill
  • SGLang / vLLM
Metric       What it measures      Phase
TTFT         Time to First Token   Prefill speed
ITL          Inter-Token Latency   Decode speed
E2E          End-to-End Latency    Total request time
Throughput   Tokens/second         System capacity

Benchmark — TTFT (Time to First Token)

Lower is better. Measures prefill speed.

Concurrency   Baseline   CUDA Graph   Chunked Prefill   torch.compile   SGLang   vLLM
1             0.046s     0.037s       0.041s            0.051s          0.104s   0.061s
8             0.094s     0.094s       0.088s            0.164s          0.047s   0.047s
32            0.253s     0.264s       0.226s            0.271s          0.117s   0.171s
64            0.493s     0.579s       0.396s            0.520s          0.399s   0.245s
128           0.964s     1.033s       0.743s            0.942s          0.336s   0.412s
Takeaway: Chunked Prefill is the clear TTFT winner within ecahLang — 23% improvement at 128 concurrency. CUDA Graphs are decode-only and slightly hurt TTFT. SGLang/vLLM excel at high concurrency (advanced prefill-decode interleaving).

Benchmark — ITL (Inter-Token Latency)

Lower is better. Measures decode speed.

Concurrency   Baseline    CUDA Graph   Chunked Prefill   torch.compile   SGLang    vLLM
1             26.30 ms    8.06 ms      7.84 ms           19.79 ms        3.85 ms   3.86 ms
8             57.36 ms    19.02 ms     18.07 ms          38.56 ms        4.09 ms   4.35 ms
32            71.55 ms    22.03 ms     22.69 ms          45.61 ms        4.50 ms   5.25 ms
64            98.22 ms    27.69 ms     28.68 ms          57.63 ms        5.02 ms   6.66 ms
128           120.75 ms   44.19 ms     43.04 ms          75.93 ms        6.67 ms   9.85 ms
Takeaway: CUDA Graphs give 2.7–3.5x ITL improvement. Chunked Prefill nearly matches at high concurrency — indirectly helps decode. SGLang/vLLM maintain sub-10ms ITL even at 128 (custom fused kernels + advanced scheduling).

Benchmark — E2E & Throughput

E2E Latency (lower is better)

Conc.   Baseline   Best ecahLang   SGLang
1       10.12s     3.05s           1.58s
8       22.07s     7.01s           1.61s
32      27.67s     8.71s           1.84s
128     47.25s     17.23s          2.89s

Throughput (higher is better)

Conc.   Baseline   Best ecahLang   SGLang
1       38         126             243
8       139        438             1,903
32      444        1,411           6,664
128     1,040      2,835           16,940

Why the ~6x throughput gap vs SGLang/vLLM?

  • Advanced schedulers: interleave prefill + decode
  • Tensor parallelism (multi-GPU)
  • Prefix caching (KV reuse)
  • Custom fused CUDA kernels
  • Speculative decoding
  • ecahLang: simplicity & hackability (~1500 LOC)

Future Work — LMCache Integration

Goal: KV cache reuse across requests with common token prefixes (system prompts, RAG context). Skip prefill for cached portions → reduce TTFT.
Tokenize prompt → LMCache lookup: is the prefix cached?
  • Cache hit: load KV from CPU → GPU
  • Cache miss: normal full prefill
  • LMCache store (async): save KV → CPU for future reuse

Expected Impact

  • 1000-token cached system prompt = skip ~1000 tokens of prefill
  • CPU memory tier with LRU eviction
  • Significant TTFT reduction for repeated prefixes

Compatibility

  • Chunked Prefill: cache hits shrink the token count before budgeting
  • CUDA Graphs: decode-only, LMCache is prefill-only
  • torch.compile: same as CUDA Graphs

All Optimizations

#    Optimization           Phase     Key Idea                          Impact
1    Paged KV Cache         Both      OS-style virtual memory for KV    128+ concurrent requests
2    FlashInfer Attention   Both      Batched paged attention kernels   Faster than SDPA
3    Chunked Prefill        Prefill   Token-budgeted batch splitting    -23% TTFT
4    CUDA Graphs            Decode    Record & replay GPU ops           2.7-3.5x ITL
5    torch.compile          Decode    JIT compile decode path           1.6x ITL
6    Stream Overlap         Both      Concurrent CUDA streams           ~5-10% latency
7    Pinned Buffers         Decode    DMA + numpy views                 Zero-alloc per step
8    Two-Phase Batching     Both      Collect while GPU runs            Hide collection latency
9    Multi-Step Decode      Decode    N steps per round-trip            Reduce queue overhead
10   Fused Sampling         Decode    Single FlashInfer kernel          Fewer launches

Key Takeaways

Different phases need different optimizations
Prefill is compute-bound → chunked prefill
Decode is memory-bandwidth-bound → CUDA graphs, stream overlap, pinned buffers
Simplicity and hackability over maximum throughput
~1,500 LOC vs 100k+ in SGLang/vLLM · Works with any HuggingFace CausalLM · OpenAI-compatible API
2.8x ITL
@ 128 concurrency
2.7x E2E
@ 128 concurrency
2.7x Throughput
1,040 → 2,835 tok/s

github.com/Scicom-AI-Enterprise-Organization/ecahLang