Some (too-)simple tricks for running a generative classifier (on CPU)

Jialin Lu, 2026-06-08

Code: github.com/LuxxxLucy/qwen3guard-test

TL;DR: This post started when one of my colleagues was assessing the Qwen3Guard model (an autoregressive language model that outputs text, finetuned on a safety corpus and reframed as a classifier) for practical use in scanning LLM prompts, which is our use case. The results on in-house evals looked relatively good, but he complained that it was just too slow to be useful for us. (He had adapted the original inference snippets from the README. We also have no GPUs in this use case; in my experience a model under 1B parameters should run fast enough on CPU anyway, so that should be fine. Caveat: not always, but mostly.)

I then started digging into it and found some really easy opportunities, the kind you would not miss if you have read some basics of LLM inference. (I only started reading about LLM inference recently, about two months ago, when I realized that some providers serve the same open-source models much faster than others.) Anyway, I will lay down these simple and apparent tricks here, and the results are pretty good: about an 8.5x speedup on CPU over the model-card reference, with the same full-precision fp32 model (the tricks combined with a faster runtime; every runtime we tried speeds up substantially, and the gain on GPU is similar), and quantization roughly halves the latency again.

A classifier of text

The task we consider is simple: we read a string of text and label it as safe, unsafe, or controversial.

In an oversimplified textbook setting, this classifier would be a function $f$ that takes an input $x$ and turns it into a vector of (unnormalized) scores $z = f(x) \in \mathbb{R}^K$, one per label (here $K = 3$). A softmax turns the scores into a proxy for the probability of each label, and the label with the largest probability is the prediction:

$$p_k = \frac{e^{z_k}}{\sum_{j=1}^{K} e^{z_j}}, \qquad \hat{y} = \arg\max_k p_k \qquad (1)$$
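A minimal sketch of this textbook classifier in plain Python: softmax over the three scores, then argmax. The label order and the example scores are made up for illustration; in practice the scores would come from whatever model produces $z$.

```python
import math

LABELS = ("safe", "unsafe", "controversial")  # K = 3 labels (order is an assumption)


def softmax(z):
    """Turn unnormalized scores into probabilities (max-subtracted for numerical stability)."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def classify(z):
    """Return the argmax label and the probability vector p."""
    p = softmax(z)
    k = max(range(len(p)), key=p.__getitem__)
    return LABELS[k], p


label, p = classify([2.0, -1.0, 0.5])
print(label, [round(v, 3) for v in p])
```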

The case of a generative classifier is different. As the name suggests, it uses a generative model of text: an autoregressive language model that takes an input string and extends it via next-token prediction, so more text is generated and appended at the end. Once generation is done, we search the generated text for the verdict: if it says the input is “safe”, we label it “safe”, and so on. Of course, this relies on the model being prompted and finetuned so that it follows this output format.

Several models use this recipe, such as Llama Guard (Meta Llama Team 2024), ShieldGemma (Zeng et al. 2024), and Qwen3Guard (Zhao et al. 2025). I think the main reason for this design is that the base model is already pretrained on a lot of language and world knowledge, so it can serve as a good and robust (i.e. generalizable) classifier with only a little tuning.

One additional advantage makes it particularly interesting: the user experience. We can now describe what counts as unsafe in plain human language, as part of the prompt (a.k.a. in-context learning). This makes network administrators especially happy: after all these years they can finally write policies in an easy way, and the policy can change on the fly without retraining the model, which is a huge plus.

Analysis, breakdowns and tricks

So we have an autoregressive language model reframed as a classifier. Now let us take a closer look at how it is used by default. Per the Hugging Face README, we call generate() with a prompt that includes the system prompt and the chat template, then parse the generated text to get the verdict:

input
→ render chat template
→ prefill (1 forward pass over the whole prompt)
→ decode × 9 steps (1 forward pass per token: "Safety: unsafe\nCategories: …")
→ regex-parse the text
→ verdict (+ categories)
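The last two steps, parsing the generated text into a verdict, might look like the sketch below. It is hypothetical: parse_verdict is not the model card's code, the output format is assumed from the decode line above, and treating "None" as an empty category list is an assumption too. Anchoring on the Safety: prefix matters, since a bare search for "safe" would also match "unsafe".

```python
import re

# Assumed output format: "Safety: <label>\nCategories: <comma-separated list>"
SAFETY_RE = re.compile(r"Safety:\s*(Safe|Unsafe|Controversial)", re.IGNORECASE)
CATEGORIES_RE = re.compile(r"Categories:\s*(.+)", re.IGNORECASE)


def parse_verdict(generated: str):
    """Return (verdict, categories), or (None, []) if no verdict is found."""
    m = SAFETY_RE.search(generated)
    if m is None:
        return None, []
    verdict = m.group(1).lower()
    c = CATEGORIES_RE.search(generated)  # "." stops at the newline, so one line only
    categories = [s.strip() for s in c.group(1).split(",")] if c else []
    # assumption: "None" is used as an explicit empty category list
    categories = [s for s in categories if s and s.lower() != "none"]
    return verdict, categories


print(parse_verdict("Safety: Unsafe\nCategories: Violent, PII"))
```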

Looking closely, something is off: there is redundant computation here.

L1: forced prefix

Our first observation is that the model always writes Safety: before the verdict, so generating it is wasted work. This can be seen as an extreme version of constrained decoding: we already know what comes next (the model was finetuned to write exactly this), so the decoding work for generating Safety: is unnecessary.

We can fix this simply by forcing that part of the text instead, which we call a forced prefix. We treat it as part of the input: append Safety: to the prompt, run one forward pass, and read the next-token distribution at the last position. The decoding steps for Safety: are removed; that text is now part of the prefill.

If we only want the verdict we can stop here; the Categories: line never needs to be decoded. Since benign samples in real traffic typically outnumber malicious ones by several orders of magnitude, we can check for Safety: unsafe and decode the Categories: line only in that case (conditional compute). This eliminates even more of the decode loop: a typical call needs one forward pass instead of about ten. (The model would write Safety: unsafe, then a Categories: line, about nine tokens, so generate() runs about ten forward passes. ShieldGemma’s card publishes the identical recipe; Llama Guard’s is the first-token-logit variant.)
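To make the pass counting concrete, here is a toy sketch. ToyGuard, generate, forced_prefix_verdict and classify are hypothetical names, and the "model" is a scripted stand-in that follows the Safety:/Categories: format, not Qwen3Guard; its output is also shorter than the real one (six passes rather than about ten). What it shows is the structure: the default path pays one forward pass per generated token, the forced-prefix path pays one, and the categories are decoded only for the rare unsafe verdict.

```python
LABEL_TOKENS = ("safe", "unsafe", "controversial")


class ToyGuard:
    """Hypothetical scripted model: each forward() call counts as one forward pass and
    returns next-token scores at the last position only."""

    def __init__(self):
        self.forward_passes = 0

    def forward(self, tokens):
        self.forward_passes += 1
        verdict = "unsafe" if "bomb" in tokens else "safe"  # toy "policy"
        script = {
            "<end_of_prompt>": "Safety:",
            "Safety:": verdict,
            "safe": "\n",
            "unsafe": "\n",
            "\n": "Categories:",
            "Categories:": "Violent" if verdict == "unsafe" else "None",
        }
        next_tok = script.get(tokens[-1], "<eos>")
        vocab = set(LABEL_TOKENS) | {next_tok}
        return {t: (5.0 if t == next_tok else 0.0) for t in vocab}


def generate(model, tokens, max_new_tokens=16):
    """Default path: greedy decoding, one forward pass per generated token."""
    tokens = list(tokens)
    start = len(tokens)
    for _ in range(max_new_tokens):
        scores = model.forward(tokens)
        nxt = max(scores, key=scores.get)
        if nxt == "<eos>":
            break
        tokens.append(nxt)
    return tokens[start:]


def forced_prefix_verdict(model, prompt):
    """L1: append "Safety:" to the input, run one pass, read only the label tokens."""
    scores = model.forward(list(prompt) + ["Safety:"])
    return max(LABEL_TOKENS, key=lambda t: scores[t])


def classify(model, prompt):
    """Verdict in one pass; decode the Categories line only for the rare unsafe case."""
    verdict = forced_prefix_verdict(model, prompt)
    if verdict != "unsafe":
        return verdict, []
    rest = generate(model, list(prompt) + ["Safety:", verdict])
    categories = rest[rest.index("Categories:") + 1:] if "Categories:" in rest else []
    return verdict, categories


benign = ["<bos>", "hello", "<end_of_prompt>"]
m = ToyGuard()
generate(m, benign)
print("generate():", m.forward_passes, "passes")
m = ToyGuard()
classify(m, benign)
print("forced prefix:", m.forward_passes, "pass")
```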

Figure 1: generate() runs a prefill pass then nine decode steps, one forward pass each; the forced-prefix path runs one prefill and reads the verdict, since the label is fixed by the end of the first step.

L2: LM-head trimming

We still have more redundant computation.

Here lm_head refers to the final linear layer that projects the last hidden state into the token-logit space. For a text of length $T$, the default forward pass projects every one of the $T$ positions onto the full vocabulary of roughly 150,000 tokens, but that is not needed.

I was actually surprised to find that this is the default behavior, but then I realized that PyTorch (and the model code built on it) is ultimately a framework for training, where the loss needs logits at every position, so the default makes sense. In inference, though, this work is wasted in two ways:

  1. First, we only care about the last position, so the projections at the other positions are wasted work: only 1 of the $T$ positions is needed. (In PyTorch this is logits_to_keep=1; in ONNX, a slice node on the graph; llama.cpp and most runtimes already return only the last position.)
  2. Even at that last position, we are only interested in the three label tokens, so projecting onto the whole vocabulary is also wasted work: only 3 of the roughly 150,000 vocabulary rows are needed.
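A toy check of both points in plain Python with random weights; the sizes, seed and label ids are hypothetical. The default head computes $T \times V$ dot products and then keeps three probabilities, while the trimmed head computes three. Because the vocabulary-wide softmax normalizer cancels when we renormalize over the label tokens, both give the same label distribution up to floating-point rounding.

```python
import math
import random

random.seed(0)
V, d, T = 1000, 16, 8    # toy vocabulary, hidden size and sequence length
LABEL_IDS = [11, 42, 7]  # hypothetical token ids of "safe", "unsafe", "controversial"

W = [[random.gauss(0, 1) for _ in range(d)] for _ in range(V)]  # lm_head: one row per token
H = [[random.gauss(0, 1) for _ in range(d)] for _ in range(T)]  # hidden states, one per position


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def full_head():
    """Default: project every position onto the whole vocabulary, then keep 3 numbers."""
    logits = [[dot(w, h) for w in W] for h in H]  # T x V dot products
    probs = softmax(logits[-1])                    # only the last row is ever used
    picked = [probs[i] for i in LABEL_IDS]
    total = sum(picked)
    return [p / total for p in picked]             # renormalize over the label tokens


def trimmed_head():
    """L2: last position only, and only the three label rows of W."""
    h = H[-1]
    return softmax([dot(W[i], h) for i in LABEL_IDS])  # 3 dot products


print(full_head())
print(trimmed_head())
```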

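This seems obvious and small, but for a 0.6B model it is wasted compute that should not be ignored: the vocabulary projection alone is a sizable share of the model (roughly 150,000 × 1,024 ≈ 155M weights, assuming a hidden size of 1,024), and by default it is applied at all $T$ positions.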

Figure 2: The read itself: at the last position the lm_head gives a distribution over the whole vocabulary, and we keep the three label tokens and renormalize to get the label distribution $p$.

This means a much smaller multiplication and an updated version of Equation 1: if we only care about classification, the classification head really is just three rows of the lm_head all along.
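Written out, with notation introduced here for illustration ($W \in \mathbb{R}^{V \times d}$ for the lm_head weight with rows $w_v$, $h_T(x) \in \mathbb{R}^d$ for the hidden state at the last position after the forced Safety: prefix, and $t_k$ for the vocabulary id of label $k$), the scores $z_k$, probabilities $p_k$ and prediction $\hat{y}$ become:

```latex
\begin{aligned}
z_k &= w_{t_k}^{\top} h_T(x), \qquad k = 1, \dots, K \quad (K = 3) \\
p_k &= \frac{\operatorname{softmax}\big(W h_T(x)\big)_{t_k}}
            {\sum_{j=1}^{K} \operatorname{softmax}\big(W h_T(x)\big)_{t_j}}
     = \frac{e^{z_k} \big/ \sum_{v=1}^{V} e^{w_v^{\top} h_T(x)}}
            {\sum_{j=1}^{K} e^{z_j} \big/ \sum_{v=1}^{V} e^{w_v^{\top} h_T(x)}}
     = \frac{e^{z_k}}{\sum_{j=1}^{K} e^{z_j}} \\
\hat{y} &= \arg\max_k p_k
\end{aligned}
```

The vocabulary-wide normalizer cancels, so the three trimmed logits give exactly the renormalized distribution from Figure 2. The lm_head cost drops from $T \cdot V \cdot d$ multiply-adds to $K \cdot d$; for a prompt of, say, $T \approx 400$ tokens and $V \approx 150{,}000$, that is a factor of $TV/K \approx 400 \times 150{,}000 / 3 = 2 \times 10^7$.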

Figure 3: The optimized read (L1 + L2): the forced Safety: prefix, one forward pass, and the lm_head run at the last position only. The projections at every other position, dashed, are the work L2 skips; the decode loop, gone, is the work L1 skips.

L3: KV cache

Now the usual game: cache the keys and values (KV) of the system prompt. This needs little explanation, and it does not even need an extra data copy per call: we keep one copy persistently in memory and reuse it. One caveat: the real layout is a system-prompt head, then the user text, then a system-prompt tail and the forced Safety: prefix. Only the head is a constant prefix, so only its keys and values are cacheable; the tail sits after the variable user text and must be recomputed on every call. The diagrams simplify this to one fixed prefix.
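A toy sketch of the head-only prefix cache in plain Python. Everything here is hypothetical: ToyCausalLM stands in for a causal transformer (the entry at position i depends only on tokens 0..i, which is what makes prefix caching valid), and the token ids are made up. The head's entries are computed once and afterwards only read, so no per-call copy is needed; the user text, the system-prompt tail and Safety: are recomputed on every call.

```python
SYSTEM_HEAD = [101, 102, 103, 104, 105, 106]  # constant system-prompt head: cacheable
SYSTEM_TAIL = [201, 202]                      # tail after the user text: recomputed
SAFETY = [301]                                # the forced "Safety:" prefix


class ToyCausalLM:
    """Stand-in for a causal transformer: entry i depends only on tokens 0..i."""

    def __init__(self):
        self.positions_computed = 0

    def extend(self, prefix_kv, tokens):
        """Compute entries for `tokens`, attending to a read-only prefix cache."""
        prefix_sum, n_prefix = sum(prefix_kv), len(prefix_kv)
        new_kv = []
        for tok in tokens:
            self.positions_computed += 1
            context = prefix_sum + sum(new_kv)  # toy "attention" over all earlier positions
            n_context = n_prefix + len(new_kv)
            new_kv.append(tok + 0.5 * context / (n_context + 1))
        return new_kv

    def read(self, prefix_kv, new_kv):
        """Stand-in for the last-position readout over the whole sequence."""
        return (sum(prefix_kv) + sum(new_kv)) / (len(prefix_kv) + len(new_kv))


def full_recompute(model, user_tokens):
    kv = model.extend([], SYSTEM_HEAD + user_tokens + SYSTEM_TAIL + SAFETY)
    return model.read([], kv)


def with_head_cache(model, head_kv, user_tokens):
    # head_kv is only read, never mutated, so one shared copy serves every call
    kv = model.extend(head_kv, user_tokens + SYSTEM_TAIL + SAFETY)
    return model.read(head_kv, kv)


model = ToyCausalLM()
head_kv = model.extend([], SYSTEM_HEAD)  # computed once, e.g. at startup
model.positions_computed = 0
print(with_head_cache(model, head_kv, [7, 8, 9]), model.positions_computed)
```

In a real runtime the cache holds per-layer key and value tensors, and whether the runtime appends to a cache object in place differs between libraries, so check your backend's cache API before sharing one object across calls.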

Figure 4: The optimized read with L3: the fixed prefix is cached, so the transformer runs over the suffix only. L1 and L2 still hold, one forward pass and a single last-position lm_head.

Results

First, correctness: the optimizations should not change any verdict, and they do not. On every sample the optimized fp32 path returns the same verdict as the model-card path; the match is exact.

The tricks change how much computation is run, not how it is run. That means they hold on any backend and in any precision, and the savings are largest where each forward pass is expensive, which on CPU is often determined by memory bandwidth: how fast the weights can be moved between main memory and the CPU caches. So on the same machine, different runtime backends should see similar relative speedups (assuming each runtime's fixed overhead is similar).

Here we present results for Qwen3Guard-Gen-0.6B, batch-one inference, on sixteen cores of a Kunpeng 920 server CPU (the work per call is small, so going past 12 cores gives diminishing returns, but I settled on 16 as a reasonable number), across three backends: PyTorch, ONNX Runtime (Microsoft 2018), and llama.cpp (Gerganov and the llama.cpp contributors 2023). The input is a few hundred tokens, and each number is the median of 100 timed calls after a few warm-up calls.

Figure 5: The ladder on one backend (PyTorch fp32). The decode loop, removed by L1, is most of the cost; L2 and L3 take the rest.

Table 1: Per-call latency, p50 ms, walking the ladder on three backends. L0 is the model-card generate() path, the same code the Qwen3Guard-Gen card shows. llama.cpp returns only the last position by default, so L2 is already in its baseline; that row is also 8-bit quantized, which is why it starts lower.

backend        | L0 baseline | +L1 | +L2         | +L3
PyTorch fp32   | 2148        | 688 | 555         | 408
ONNX fp32      | 1671        | 598 | 485         | 253
llama.cpp q8_0 | 643         | 261 | in baseline | 129

On every backend L1 removes most of the time, then L2 and L3 take more off: the three tricks bring PyTorch from 2148 to 408 ms, about five times faster. If we keep the tricks and switch to a faster fp32 runtime, ONNX reaches 253 ms, 8.5 times faster than the model-card reference, still at fp32 and still returning identical verdicts.

Quantization is orthogonal to the three tricks, and it is the obvious next thing to try. Storing the weights in 8 bits instead of 32 shrinks the model and speeds up every matrix multiply, on top of the tricks already in place.
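A back-of-envelope view of why both the tricks and 8-bit weights help on a bandwidth-bound CPU: every forward pass has to stream the weights from memory. This sketch assumes roughly 0.6e9 parameters, about ten passes for the model-card path and one for the forced-prefix path; it ignores activations, the KV cache, quantization scales and the compute-bound part of prefill, so it is a rough lower bound on weight traffic, not a latency model.

```python
PARAMS = 600_000_000  # assumption: ~0.6B parameters, rounded


def weight_bytes_per_call(forward_passes, bytes_per_param):
    """Rough lower bound: each forward pass reads every weight from memory once."""
    return forward_passes * PARAMS * bytes_per_param


configs = [
    ("model-card generate(), fp32", 10, 4),
    ("forced prefix, fp32", 1, 4),
    ("forced prefix, 8-bit weights", 1, 1),
]
for name, passes, bytes_per_param in configs:
    gb = weight_bytes_per_call(passes, bytes_per_param) / 1e9
    print(f"{name:30s} {gb:5.1f} GB of weights read per call")
```

The measured gains are smaller than these ratios, since prefill over a few hundred tokens and per-call runtime overhead do not shrink the same way.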

Table 2: Best fp32 path against two 8-bit paths, all with the tricks applied, p50 ms. The reference is the 2148 ms model-card path.

config (with the tricks) | p50 ms | speedup vs. reference
ONNX fp32 (no quant)     | 253    | 8.5×
ONNX int8                | 164    | 13×
llama.cpp q8_0           | 129    | 17×

The fp32 paths reproduce the reference verdict on every sample. The 8-bit paths do not: int8 agrees with fp32 on about 98 of 100 inputs, and the two it misses are borderline, near the safe/controversial line. (This is weight-only 8-bit quantization with fp32 accumulation. The drift is a fraction of a logit, enough to flip a verdict only where the top two labels were already almost tied.) For most uses that is a fine trade; where no borderline verdict may change, fp32 stays exact and keeps the 8.5×.
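Checking this on your own data takes a few lines: run both paths on the same inputs, count verdict flips, and look at how close the fp32 top two labels were on the flipped samples. A sketch with hypothetical helper names and made-up probabilities over (safe, unsafe, controversial), not our evaluation code:

```python
def argmax(p):
    return max(range(len(p)), key=p.__getitem__)


def top2_margin(p):
    """Gap between the two most likely labels; a small gap means a borderline sample."""
    first, second = sorted(p, reverse=True)[:2]
    return first - second


def agreement_report(ref_probs, quant_probs):
    """Compare per-sample label distributions from the fp32 (reference) and 8-bit paths.

    Returns the agreement rate and the reference margins of the samples that flipped.
    """
    flipped_margins = []
    for ref, q in zip(ref_probs, quant_probs):
        if argmax(ref) != argmax(q):
            flipped_margins.append(top2_margin(ref))
    rate = 1 - len(flipped_margins) / len(ref_probs)
    return rate, flipped_margins


# made-up probabilities, for illustration only
ref = [[0.97, 0.02, 0.01], [0.50, 0.02, 0.48], [0.05, 0.90, 0.05], [0.70, 0.10, 0.20]]
q = [[0.96, 0.03, 0.01], [0.47, 0.02, 0.51], [0.06, 0.89, 0.05], [0.69, 0.10, 0.21]]
print(agreement_report(ref, q))
```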

The full results across every backend we measured are listed below. Here we test with two system prompt templates: original (the official 296-token system prompt from Qwen3Guard) and test-200 (a compressed and simplified policy prompt, about 130 tokens).

backend | variant | original (p50 / p99) | test-200 (p50 / p99)
pytorch fp32 | L0 | 2148.1 / 2830.1 | 875.8 / 1230.0
pytorch fp32 | +L1 forced prefix | 687.6 / 790.8 | 423.7 / 433.0
pytorch fp32 | +L2 LM-head trimming | 554.9 / 711.0 | 352.8 / 360.1
pytorch fp32 | +L3 KV cache | 407.5 / 428.9 | 310.6 / 330.1
onnx fp32 | L0 | 1670.6 / 1709.6 | 1277.5 / 1308.7
onnx fp32 | +L1 forced prefix | 598.4 / 620.2 | 315.0 / 327.9
onnx fp32 | +L2 LM-head trimming | 485.1 / 502.8 | 239.5 / 254.8
onnx fp32 | +L3 KV cache | 253.2 / 265.5 | 147.6 / 160.7
onnx int8 | L0 | 2136.3 / 2155.9 | 1842.4 / 1865.4
onnx int8 | +L1 forced prefix | 382.0 / 389.2 | 209.1 / 215.8
onnx int8 | +L2 LM-head trimming | 280.1 / 286.7 | 155.6 / 161.3
onnx int8 | +L3 KV cache | 163.7 / 167.9 | 113.9 / 118.5
llamacpp f32 (L2 baked) | L0 | 1589.7 / 1625.2 | 1239.7 / 1282.7
llamacpp f32 (L2 baked) | +L1 forced prefix | 719.6 / 750.6 | 426.7 / 467.5
llamacpp f32 (L2 baked) | +L3 KV cache | 434.0 / 458.9 | 317.1 / 394.1
llamacpp f32 +kernel-opt (L2 baked) | L0 | 1278.0 / 1292.5 | 966.6 / 974.1
llamacpp f32 +kernel-opt (L2 baked) | +L1 forced prefix | 511.9 / 523.5 | 237.5 / 268.3
llamacpp f32 +kernel-opt (L2 baked) | +L3 KV cache | 242.2 / 249.0 | 147.8 / 156.7
llamacpp f16 (L2 baked) | L0 | 1496.6 / 1527.9 | 1156.0 / 1189.0
llamacpp f16 (L2 baked) | +L1 forced prefix | 928.4 / 960.5 | 619.9 / 652.4
llamacpp f16 (L2 baked) | +L3 KV cache | 653.4 / 691.0 | 542.6 / 571.8
llamacpp q8_0 (L2 baked) | L0 | 643.1 / 650.9 | 437.4 / 445.2
llamacpp q8_0 (L2 baked) | +L1 forced prefix | 261.2 / 273.5 | 111.5 / 115.5
llamacpp q8_0 (L2 baked) | +L3 KV cache | 128.7 / 133.1 | 69.8 / 74.1
rust-candle fp32 | L0 | 6149.1 / 6227.9 | 5205.9 / 5252.7
rust-candle fp32 | +L1 forced prefix | 1270.6 / 1346.8 | 536.3 / 550.3
rust-candle fp32 | +L3 KV cache | 726.5 / 769.8 | 374.1 / 388.4
ctranslate2 fp32 (L2 baked) | L0 | n/a | n/a
ctranslate2 fp32 (L2 baked) | +L1 forced prefix | 1718.3 / 1780.0 | 973.2 / 991.1
mnn-llm fp16 (L2 baked) | L0 | 1336.8 / 1431.6 | 1037.8 / 1127.1
mnn-llm fp16 (L2 baked) | +L1 forced prefix | 571.1 / 586.7 | 287.9 / 301.4
The full CPU sweep on Kunpeng 920 aarch64, p50 / p99 ms, batch one, 16 threads. Rows are cumulative within a backend; “(L2 baked)” means the backend already returns only the last position, so it has no separate +L2 row.

Wrap up

Here we listed three tricks for optimizing Qwen3Guard, a generative classifier, that are quite simple and apparent once you look at them. Together, the three tricks (forced prefix, LM-head trimming, KV cache) take a two-second CPU call down to about 250 ms on ONNX Runtime, 8.5 times faster without quantization; quantization roughly halves it again.

In fact, these tricks are so apparent that vLLM already implements them.

Bibliography

  • Gerganov, Georgi, and the llama.cpp contributors. 2023. “llama.cpp.”
  • Meta Llama Team. 2024. “Llama Guard 3-8B Model Card.”
  • Microsoft. 2018. “ONNX Runtime.”
  • Zeng, Wenjun, and others. 2024. “ShieldGemma: Generative AI Content Moderation Based on Gemma.”
  • Zhao, Haiyang, and others. 2025. “Qwen3Guard Technical Report.”

Appendix: GPU results

Additional GPU results are presented here. The GPU is an RTX 3090.

On an RTX 3090, Qwen3Guard-Gen-0.6B with L1 runs about 29 ms p50 against 237 ms for the model-card path at a comparable input length. L2 and L3 do not help on the GPU, where the vocabulary projection and prompt re-reading are cheap next to the per-call overhead. On CPU the opposite holds, which is why the full ladder matters there.

The stage-by-stage cost at a representative input of about 369 tokens on the 3090 makes this plain: prefill takes about 21 ms, then nine decode steps take roughly 17 ms each. That decode loop (about 9 × 17 ≈ 150 ms) is most of the default path’s 237 ms, and the forced-prefix path drops it entirely, leaving prefill and a sub-millisecond read, about 29 ms.

Figure 6: Qwen3Guard-Gen-0.6B on an RTX 3090: the model-card path against the forced-prefix path (L1), p99 latency across input lengths. The forced-prefix path stays under a 200 ms budget out to about 2048 input tokens.

The same construction works on the larger sizes: with the forced-prefix path, 0.6B stays under a 200 ms p99 budget up to about 2048 input tokens, 4B up to about 256, and 8B up to 128.

Figure 7: The forced-prefix path across the three Qwen3Guard-Gen sizes (0.6B / 4B / 8B) on an RTX 3090. All three get the same trick; the 0.6B clears a 200 ms budget at the longest inputs, the larger two hit it sooner.