PRISM

Projection-based Relevance-Informed Steering Method

Training-free attention steering via differential cross-covariance decomposition.
Compatible with FlashAttention. Negligible memory overhead.

(Figure: PRISM-Delta method overview)

Method

Differential Cross-Covariance

SVD of $\Omega_\Delta = \Omega^+ - \Omega^-$ extracts directions that distinguish relevant from irrelevant conditions, automatically eliminating shared variance.

Softplus Head Weighting

Each attention head receives a continuous importance weight $w_{\ell,h} = \text{softplus}(D_{\ell,h} - \delta_{\min})$, replacing binary hard thresholds with smooth gradation.
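As a minimal NumPy sketch of this weighting (the helper name is illustrative, not from the release):

```python
import numpy as np

def head_weight(D, delta_min):
    """Continuous head-importance weight w = softplus(D - delta_min).

    D may be a scalar or an array of per-head separation scores;
    log1p(exp(x)) is a numerically friendly softplus for moderate x.
    """
    return np.log1p(np.exp(D - delta_min))
```

Unlike a hard threshold, the weight never drops exactly to zero: heads just below $\delta_{\min}$ still receive a small, smoothly decaying contribution.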

Dual-Channel Steering

Steers both Key (routing) and Value (content) channels simultaneously, enabling fine-grained control over information flow at inference time.

Spectral Learning of Relevance-Aligned Projections

Using token-level key embeddings from synthetic contrastive prompts — denoted $\mathbf{h}$ (neutral), $\mathbf{h}^+$ (positive), and $\mathbf{h}^-$ (negative) — we compute cross-covariance matrices for each transformer layer $\ell$ and head $h$:

$$ \Omega^+_{\ell,h} = \frac{\mathbf{h}^\top \mathbf{h}^+}{n}, \quad \Omega^-_{\ell,h} = \frac{\mathbf{h}^\top \mathbf{h}^-}{n} $$

SVD is then applied: $\Omega^+_{\ell,h} = U^+ S^+ V^{+\top}$, $\Omega^-_{\ell,h} = U^- S^- V^{-\top}$. The projection matrices are constructed from the top singular vectors:

$$ P^+_{\ell,h} = U^+_{\ell,h,:,:k^+} \left(U^+_{\ell,h,:,:k^+}\right)^\top, \quad P^-_{\ell,h} = U^-_{\ell,h,:,k^-:} \left(U^-_{\ell,h,:,k^-:}\right)^\top $$

where $k^+$ and $k^-$ are the smallest ranks that capture at least a proportion $\gamma$ of the total singular-value sum; for $k^+$:

$$ \frac{\sum_{i=1}^{k^+} S^+_{\ell,h,i}}{\sum_{i=1}^{d_k} S^+_{\ell,h,i}} \geq \gamma $$
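The learning step above can be sketched in NumPy for a single layer/head pair; the function name and shapes are illustrative assumptions, and only the positive projection $P^+$ is shown ($P^-$ is built analogously from the trailing singular vectors of $\Omega^-$):

```python
import numpy as np

def learn_projection(H, H_pos, gamma=0.9):
    """Spectral learning sketch for one layer/head.

    H, H_pos : (n, d_k) neutral and positive key embeddings.
    Returns (P_plus, k_plus): the rank-k+ projection capturing at
    least a fraction gamma of the singular-value mass of Omega+.
    """
    n = H.shape[0]
    omega = H.T @ H_pos / n                  # Omega+ = h^T h+ / n
    U, S, _ = np.linalg.svd(omega)           # S sorted descending
    ratio = np.cumsum(S) / S.sum()
    k = int(np.searchsorted(ratio, gamma) + 1)   # smallest k+ meeting gamma
    P = U[:, :k] @ U[:, :k].T                # P+ = U_{:, :k+} U_{:, :k+}^T
    return P, k
```

Because the retained columns of $U^+$ are orthonormal, the resulting $P^+$ is a symmetric, idempotent projector ($P^+ P^+ = P^+$).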

Inference-Time Spectral Editing

During inference, learned projections are injected into key embeddings before attention scores are computed. For each token key $\mathbf{k}_j \in \mathbb{R}^{d_k}$ at layer $\ell$ and head $h$:

$$ \mathbf{k}_j' = \mathbf{k}_j + \frac{g^+ \cdot P^+_{\ell,h}\, \mathbf{k}_j + g^- \cdot P^-_{\ell,h}\, \mathbf{k}_j}{2} $$

where $g^+, g^-$ are independently adjustable scalars controlling positive and negative steering gains. This is algebraically equivalent to augmenting the attention score matrix $A$ with a low-rank relevance bias $B$:

$$ \text{Logits}_{ij} = \underbrace{\frac{\mathbf{q}_i^\top \mathbf{k}_j}{\sqrt{d_k}}}_{A_{ij}} + \underbrace{\frac{\mathbf{q}_i^\top \left(\frac{g^+ \cdot P^+_{\ell,h}\, \mathbf{k}_j + g^- \cdot P^-_{\ell,h}\, \mathbf{k}_j}{2}\right)}{\sqrt{d_k}}}_{B_{ij}} $$

Because the method operates entirely on key representations prior to attention computation, it requires no access to the attention matrix, making it inherently compatible with FlashAttention.
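A small NumPy sketch of the key edit, assuming symmetric projections as constructed above (the function name is illustrative):

```python
import numpy as np

def steer_keys(K, P_pos, P_neg, g_pos, g_neg):
    """Apply k'_j = k_j + (g+ P+ k_j + g- P- k_j) / 2 to all keys.

    K            : (seq, d_k) keys for one head.
    P_pos, P_neg : (d_k, d_k) symmetric projection matrices.
    Works row-wise because P symmetric implies (P k_j)^T = k_j^T P.
    """
    return K + (g_pos * K @ P_pos + g_neg * K @ P_neg) / 2.0
```

The equivalence to a low-rank logit bias follows directly: computing $Q K'^\top / \sqrt{d_k}$ with the edited keys yields exactly $A + B$, without ever materialising the attention matrix, which is what makes the edit FlashAttention-compatible.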

Query-Adaptive Expert Routing

The adaptive variant maintains multiple domain-specific expert projections $\{U^+_{m,\ell,h}\}_{m=1}^M$ and routes them based on query alignment. At inference time, the query vector $\mathbf{q}_{\ell,h}$ of the last prompt token is used to compute dynamic coefficients:

$$ \alpha_{m,\ell,h}(\mathbf{q}_{\ell,h}) = \frac{\sum_{k=1}^{K} (\mathbf{q}_{\ell,h}^{\top}\, \mathbf{u}^{+(k)}_{m,\ell,h}) \cdot \sigma^{+(k)}_{m,\ell,h}}{\max_m \left| \sum_{k=1}^{K} (\mathbf{q}_{\ell,h}^{\top}\, \mathbf{u}^{+(k)}_{m,\ell,h}) \cdot \sigma^{+(k)}_{m,\ell,h} \right|} $$

The final projection is a weighted combination of expert projections:

$$ P_{\text{dyn},\ell,h}(\mathbf{q}_{\ell,h}) = \sum_{m=1}^{M} \alpha_{m,\ell,h}(\mathbf{q}_{\ell,h}) \cdot U^+_{m,\ell,h,:,:K}\, (U^+_{m,\ell,h,:,:K})^\top $$

The key transformation during inference becomes: $\mathbf{k}_j' = \mathbf{k}_j + g \cdot P_{\text{dyn},\ell,h}(\mathbf{q}_{\ell,h})\, \mathbf{k}_j$.
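The routing step can be sketched as follows; shapes and the function name are assumptions for illustration (per-expert singular vectors and values are taken as precomputed inputs):

```python
import numpy as np

def dynamic_projection(q, U_experts, S_experts, K):
    """Query-adaptive expert routing sketch.

    q         : (d_k,) query of the last prompt token.
    U_experts : (M, d_k, d_k) left singular vectors per expert.
    S_experts : (M, d_k) singular values per expert.
    Returns (alpha, P_dyn) as in the routing equations.
    """
    scores = np.array([
        (q @ U_experts[m][:, :K]) @ S_experts[m][:K]  # sum_k (q.u_k) sigma_k
        for m in range(len(U_experts))
    ])
    alpha = scores / np.abs(scores).max()             # max-|.| normalisation
    P_dyn = sum(a * U[:, :K] @ U[:, :K].T
                for a, U in zip(alpha, U_experts))
    return alpha, P_dyn
```

Note that the normalisation bounds every coefficient in $[-1, 1]$ and guarantees the dominant expert gets magnitude exactly 1, so the overall steering strength is carried by the single scalar gain $g$.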

Relevance-Sensitive Head Selection

Steering is most effective when applied selectively to KV heads that are naturally sensitive to prompt relevance. For each layer $\ell$ and head $h$, the average per-token $\ell_2$ distance between the positive and negative key embeddings is:

$$ D_{\ell,h} = \frac{1}{N} \sum_{i=1}^{N} \left\| \mathbf{h}^+_{\ell,h,i} - \mathbf{h}^-_{\ell,h,i} \right\|_2 $$

In the hard-selection variant, projection is applied only if $D_{\ell,h} \geq \delta_{\min}$, where $\delta_{\min}$ is a tunable threshold; the softplus weighting above replaces this binary gate with a continuous one. Mid-to-late layers consistently show larger separation, aligning with recent findings on retrieval-head localisation.
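Both the separation score and the hard-selection mask are a few lines of NumPy (helper names are illustrative):

```python
import numpy as np

def head_separation(H_pos, H_neg):
    """D_{l,h}: mean per-token L2 distance between contrastive keys.

    H_pos, H_neg : (N, d_k) positive / negative key embeddings
    for one layer/head pair.
    """
    return float(np.linalg.norm(H_pos - H_neg, axis=1).mean())

def select_heads(D, delta_min):
    """Hard-selection mask over heads: steer only where D >= delta_min.

    D : (layers, heads) array of separation scores.
    """
    return D >= delta_min
```

Swapping `select_heads` for the softplus weighting turns the binary mask into per-head gains, at no extra cost.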

At a Glance

- 5 models supported: Qwen3 (4B / 8B / 14B) · Gemma3 (4B / 12B)
- 3 benchmarks: BiasBios · CounterFact · PronChange
- 0 training required: pure inference-time · FlashAttention compatible