Defeating nondeterminism in LLM inference: the case for batch-invariant kernels


Deterministic outputs with batch-invariant kernels

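LLM endpoints feel nondeterministic to users primarily because kernel numerics change with batch size, not because of GPU “randomness.” Floating-point addition is not associative. A kernel that picks a different reduction strategy for a different batch size can therefore return slightly different results for the same request, depending on server load. By making RMSNorm, matmul, and attention kernels batch-invariant, especially via fixed split-size attention and consistent tiling, we can get bitwise-identical outputs at temperature 0. The payoff isn’t just neat unit tests: it unlocks true on-policy RL. The cost is real but manageable: 42 s versus 26 s for default vLLM in the Qwen-3-8B benchmark.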
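Here is a minimal plain-Python illustration of why reduction order matters. Because floating-point addition is not associative, a kernel that splits a sum into differently sized chunks can return a different answer for the same inputs. The `chunked_sum` helper is a hypothetical stand-in for a kernel's reduction strategy, not code from the article. It deliberately avoids the built-in `sum()`, which uses compensated summation for floats in recent Python versions.

```python
def sequential_sum(values):
    """Left-to-right accumulation, one rounded add at a time (no compensation)."""
    total = 0.0
    for v in values:
        total += v
    return total


def chunked_sum(values, chunk_size):
    """Sum each fixed-size chunk first, then sum the partial results.

    This mimics a split reduction: the chunk size decides the order
    in which floating-point additions happen.
    """
    partials = [
        sequential_sum(values[i:i + chunk_size])
        for i in range(0, len(values), chunk_size)
    ]
    return sequential_sum(partials)


# Associativity fails even for "nice" decimal inputs.
left = (0.1 + 0.2) + 0.3   # 0.6000000000000001
right = 0.1 + (0.2 + 0.3)  # 0.6

# Mixing large and small magnitudes makes the effect dramatic.
xs = [1e20, 1.0, -1e20, 1.0]
for size in (1, 2, 4):
    print(size, chunked_sum(xs, size))
# 1 1.0
# 2 0.0
# 4 1.0
```

Nothing here is random: for a fixed chunk size the result is identical on every run. Results change only when something outside the request, such as the current batch size, chooses the chunk size.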

Key points

  • Author and date: Horace He in collaboration with Thinking Machines (Sep 10, 2025), challenging the “concurrency + floating-point” explanation for LLM inference nondeterminism
  • Core diagnosis: forward-pass kernels are run-to-run deterministic, but non–batch-invariant numerics make outputs depend on server load (i.e., batch size), causing endpoint-level nondeterminism
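  • RMSNorm strategy: use a data-parallel reduction that keeps each batch element's reduction on a single core; do not switch to split reductions when the batch is too small to fill the GPU, and instead accept a modest slowdown at small batch sizes to preserve batch invariance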
  • Matmul strategy: fix kernel configuration (no Split-K, consistent tiling/tensor-core choices) to achieve batch invariance with roughly a 20% throughput hit versus cuBLAS
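  • Attention strategy: use Split-KV (FlashDecode) with a fixed split size rather than a fixed number of splits, and update the KV cache and page table before the attention kernel runs so keys and values always share one layout; the reduction order is then identical across prefill, decoding, and batch sizes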
  • Implementation: batch-invariant kernels are integrated via torch.Library and vLLM’s FlexAttention backend; the reference repo thinking-machines-lab/batch-invariant-ops includes a “deterministic” vLLM example
  • Determinism test (Qwen/Qwen3-235B-A22B-Instruct-2507): 1,000 temperature-0 runs of “Tell me about Richard Feynman” yielded 80 unique completions; divergence began at token 103; 992 said “Queens, New York,” 8 said “New York City”; batch-invariant kernels produced 1,000/1,000 identical completions
  • Performance (Qwen-3-8B, single-GPU API, 1,000 sequences of 90–110 tokens): vLLM default 26 s; unoptimized deterministic 55 s; with improved attention 42 s—usable, with major overhead from unoptimized FlexAttention
  • True on-policy RL (RLVR on Bigmath, policy from Qwen 2.5-VL instruct 8B, max rollout 4096): without importance weighting, reward collapses and KL spikes; with importance weighting, KL ≈ 0.001 with spikes; with deterministic sampler/trainer, KL stays at 0 and training remains stable
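The RMSNorm point can be sketched in plain Python:

- `pick_num_splits` is a hypothetical heuristic. It switches to a split reduction when the batch is too small to keep every core busy.
- The batch-invariant path always reduces each row in one fixed left-to-right order.
- The function names, the 4-row threshold and `eps` are illustrative assumptions, not the article's kernel.

```python
import math


def sum_of_squares(row, num_splits):
    """Reduce one row's squares, optionally split into contiguous chunks."""
    chunk = math.ceil(len(row) / num_splits)
    partials = []
    for start in range(0, len(row), chunk):
        acc = 0.0
        for x in row[start:start + chunk]:
            acc += x * x
        partials.append(acc)
    total = 0.0
    for p in partials:
        total += p
    return total


def pick_num_splits(batch_size):
    # Hypothetical heuristic: small batches leave cores idle, so split each
    # row's reduction to regain parallelism. This breaks batch invariance.
    return 2 if batch_size < 4 else 1


def rms_norm(batch, weight, eps=1e-6, batch_invariant=True):
    # Batch-invariant mode: one reduction order per row, whatever len(batch) is.
    num_splits = 1 if batch_invariant else pick_num_splits(len(batch))
    out = []
    for row in batch:
        mean_sq = sum_of_squares(row, num_splits) / len(row)
        scale = 1.0 / math.sqrt(mean_sq + eps)
        out.append([x * scale * w for x, w in zip(row, weight)])
    return out


row = [2.0 ** 27] + [1.0] * 7  # one huge square, seven tiny ones
print(sum_of_squares(row, 1) == 2.0 ** 54)        # True: each 1.0 is absorbed
print(sum_of_squares(row, 2) == 2.0 ** 54 + 4.0)  # True: four 1.0s summed first survive
```

A split reduction is not wrong, just differently rounded. The issue is only that it is selected by batch size.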
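For matmul, the same reasoning applies along the K dimension. Below is a plain-Python sketch of a matmul whose kernel config is chosen per call:

- `KernelConfig` and `heuristic_config` are hypothetical stand-ins for what a tuned library might do. The heuristic uses Split-K when M is small.
- `FIXED_CONFIG` mirrors the advice to use one configuration for every batch size.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class KernelConfig:
    split_k: int  # number of independent partial sums along K


FIXED_CONFIG = KernelConfig(split_k=1)  # never Split-K, whatever the batch size


def heuristic_config(m):
    # Hypothetical autotuner rule: with few rows there are too few output
    # tiles to fill the GPU, so split the K reduction across more workers.
    return KernelConfig(split_k=2 if m < 8 else 1)


def dot(a_row, b_col, split_k):
    """Dot product reduced in split_k contiguous chunks, then combined."""
    chunk = math.ceil(len(a_row) / split_k)
    partials = []
    for start in range(0, len(a_row), chunk):
        acc = 0.0
        for a, b in zip(a_row[start:start + chunk], b_col[start:start + chunk]):
            acc += a * b
        partials.append(acc)
    total = 0.0
    for p in partials:
        total += p
    return total


def matmul(a, b, config=None):
    """a: M x K list of rows, b: K x N list of rows."""
    cfg = config if config is not None else heuristic_config(len(a))
    cols = [list(col) for col in zip(*b)]
    return [[dot(row, col, cfg.split_k) for col in cols] for row in a]


a_one = [[2.0 ** 27] + [1.0] * 7]  # M = 1
a_eight = a_one + [[0.0] * 8] * 7  # same first row, M = 8
b = [[2.0 ** 27]] + [[1.0]] * 7    # K = 8, N = 1

print(matmul(a_one, b)[0] == matmul(a_eight, b)[0])  # False: split_k changed with M
print(matmul(a_one, b, FIXED_CONFIG)[0] == matmul(a_eight, b, FIXED_CONFIG)[0])  # True
```

The fixed config gives up the extra parallelism Split-K offers at small M. That trade-off is the source of the roughly 20% throughput gap versus cuBLAS mentioned above.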
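The fixed split-size Split-KV idea can be sketched for a single decoding query in plain Python. Each KV split produces a partial result: its max score, softmax denominator and unnormalized output. The partials are then merged with the standard log-sum-exp rescaling.

A few assumptions apply:

- `split_size=256` is illustrative.
- `batch_dependent_splits` is a hypothetical scheduler that contrasts with the fixed-size approach.

The property that matters is that fixed-size boundaries depend only on the KV length, never on how many requests share the GPU.

```python
import math
import random


def fixed_size_splits(kv_len, split_size=256):
    """Boundaries depend only on kv_len: [0, 256), [256, 512), ..."""
    return [(s, min(s + split_size, kv_len)) for s in range(0, kv_len, split_size)]


def batch_dependent_splits(kv_len, batch_size, target_parallelism=64):
    """Hypothetical scheduler: fewer concurrent requests -> more splits each."""
    num_splits = max(1, target_parallelism // batch_size)
    size = math.ceil(kv_len / num_splits)
    return [(s, min(s + size, kv_len)) for s in range(0, kv_len, size)]


def split_kv_attention(q, keys, values, splits):
    """Attention for one query, computed per KV split and merged in split order."""
    scale = 1.0 / math.sqrt(len(q))
    partials = []  # (max score, softmax denominator, unnormalized output) per split
    for start, end in splits:
        scores = [scale * sum(a * b for a, b in zip(q, keys[j])) for j in range(start, end)]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        out = [0.0] * len(values[0])
        for w, j in zip(weights, range(start, end)):
            for d, v in enumerate(values[j]):
                out[d] += w * v
        partials.append((m, sum(weights), out))

    # Log-sum-exp merge: rescale every partial to the global max, then normalize.
    m_all = max(m for m, _, _ in partials)
    denom = 0.0
    merged = [0.0] * len(values[0])
    for m, l, out in partials:
        factor = math.exp(m - m_all)
        denom += l * factor
        for d, v in enumerate(out):
            merged[d] += v * factor
    return [v / denom for v in merged]


random.seed(0)
dim, kv_len = 8, 1000
q = [random.uniform(-1, 1) for _ in range(dim)]
keys = [[random.uniform(-1, 1) for _ in range(dim)] for _ in range(kv_len)]
values = [[random.uniform(-1, 1) for _ in range(dim)] for _ in range(kv_len)]

out = split_kv_attention(q, keys, values, fixed_size_splits(kv_len))
ref = split_kv_attention(q, keys, values, [(0, kv_len)])  # one split: plain softmax attention
print(max(abs(a - b) for a, b in zip(out, ref)) < 1e-9)  # True
print(batch_dependent_splits(kv_len, 1) == batch_dependent_splits(kv_len, 32))  # False
```

In a real kernel the same boundaries must also hold when some keys come from the cache and some are computed in the current step. This is why the KV cache is updated before the attention kernel runs.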
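The Qwen3-235B determinism test reduces to two steps: counting distinct completions, then locating the first token where two of them diverge. Here is a minimal harness. It assumes a hypothetical `generate(prompt) -> list[int]` callable returning token ids from a temperature-0 request; in practice this would wrap your serving endpoint. The fake generator is a toy that flips one token on every 125th call. Its 992/8 split is chosen to echo the article's numbers and is not real data.

```python
from collections import Counter


def run_determinism_test(generate, prompt, n_runs=1000):
    """Return a Counter mapping each distinct completion to how often it appeared."""
    return Counter(tuple(generate(prompt)) for _ in range(n_runs))


def first_divergence(a, b):
    """Index of the first differing token, or None if the sequences are identical."""
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return i
    return None if len(a) == len(b) else min(len(a), len(b))


def make_fake_generate():
    """Toy stand-in for an endpoint whose numerics depend on server load."""
    calls = {"n": 0}

    def generate(prompt):
        calls["n"] += 1
        tokens = [10, 11, 12, 13, 14]
        if calls["n"] % 125 == 0:  # pretend this call landed in a different batch size
            tokens[3] = 99
        return tokens

    return generate


counts = run_determinism_test(make_fake_generate(), "Tell me about Richard Feynman")
print(len(counts))  # 2 distinct completions
(common, _), (rare, _) = counts.most_common(2)
print(first_divergence(common, rare))  # 3
```

With batch-invariant kernels, the expected result for a real endpoint is `len(counts) == 1`.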
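For the RL point, "true on-policy" has a concrete check. The sampler's and trainer's log-probabilities for the sampled tokens are bitwise equal, so every importance ratio is exactly 1 and the sampled KL estimate is exactly 0.

This sketch assumes per-token log-probabilities are available as lists of floats. It uses the simple "k1" estimator, the mean of sampler log-prob minus trainer log-prob over sampled tokens. That is one common choice and not necessarily the metric the article plotted. The log-prob values are made up for illustration.

```python
import math


def importance_ratios(trainer_logprobs, sampler_logprobs):
    """Per-token ratio pi_trainer / pi_sampler for tokens drawn by the sampler."""
    return [math.exp(t - s) for t, s in zip(trainer_logprobs, sampler_logprobs)]


def sampled_kl(sampler_logprobs, trainer_logprobs):
    """k1 estimate of KL(sampler || trainer) from tokens drawn by the sampler."""
    diffs = [s - t for s, t in zip(sampler_logprobs, trainer_logprobs)]
    return sum(diffs) / len(diffs)


sampler = [-0.11, -2.30, -0.02]               # illustrative values
trainer_bitwise = list(sampler)               # batch-invariant: identical numerics
trainer_drifted = [-0.1102, -2.3003, -0.0204]  # hypothetical small numeric mismatch

print(sampled_kl(sampler, trainer_bitwise))         # 0.0
print(importance_ratios(trainer_bitwise, sampler))  # [1.0, 1.0, 1.0]
print(sampled_kl(sampler, trainer_drifted) > 0.0)   # True
```

When the numbers drift, the ratios stop being 1 and the trainer must rely on importance weighting to correct for the mismatch. When they match bitwise, that correction becomes a no-op.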

Takeaway

Want stable, repeatable LLM outputs without making an offering to the GPU gods? Use batch-invariant kernels across RMSNorm, matmul, and attention, lock your kernel configs, and pick fixed split sizes for decoding. Accept a modest speed tax now to avoid the far pricier “why did my model say New York City this time?” debugging later. And if you’re doing RL, enjoy the luxury of true on-policy training—because zero KL beats crossed fingers every day.

Sources