By Jeffrey Emanuel (and various collaborators of the electronic persuasion)
Written on April 1st, 2025
Transformer-based large language models (LLMs) face two significant limitations that restrict their capabilities:
-
Lack of Introspection: Unless specifically instrumented, transformer-based LLMs have no ability to explicitly access their own internal states—the activations in their feed-forward layers, attention mechanisms, and other components. This opacity hinders mechanistic interpretability, self-monitoring, and dynamic reasoning.
-
Ephemeral Cognition: Most LLM "thinking" is fleeting—activations across billions of parameters that change during forward passes as the model processes tokens. Recording this data naively is computationally prohibitive due to its sheer volume.
These limitations have profound implications for interpretability, debugging, and developing more capable AI systems. This article proposes a novel approach to address both problems simultaneously.
Large transformer models generate massive volumes of intermediate data during inference. Each token step produces new hidden states, attention maps, and cached key/value tensors. These are ephemeral by design: they're discarded after each forward pass, with no built-in mechanism for inspection, rollback, or resumption.
Naively saving the full state at each step is computationally prohibitive. A model like GPT-3, storing full activations and attention caches per token, can consume hundreds of megabytes per sequence. Existing approaches like PCA, quantization, or simple delta encoding are lossy and often irreversible, making them unsuitable for applications requiring high-fidelity recovery.
We lack a practical way to pause, inspect, and replay a model's internal state with precision.
Despite their high dimensionality, transformer activations likely occupy a small portion of the possible state space. They appear to live on a lower-dimensional, structured manifold shaped by several factors:
- Pretraining Dynamics: Models learn to represent language efficiently, creating structured internal representations.
- Architectural Constraints: Attention mechanisms and layer normalization impose patterns on activation distributions.
- Semantic Priors: Natural language has inherent structure that shapes model activations.
- Task-Driven Optimization: Fine-tuning carves task-specific trajectories through this space.
This hypothesis draws from observations in neural network representations and suggests that transformer states could be compressed into smaller latent representations without losing critical information, much like a map reduces a terrain to key coordinates.
This raises a compelling possibility: what if we could encode those internal states directly onto this manifold? Instead of treating the activations as raw data, we could represent them as coordinates on a latent terrain.
Think of a transformer as a single-player game engine. Each inference step is like a frame rendered during gameplay. Normally, you don't save every frame—you save the game state: player position, inventory, mission flags, world state. This compact representation allows you to stop, rewind, branch, or resume seamlessly.
We want the same thing for transformer inference: a way to save the complete thought state at a given point in a sequence, using as little space as possible, but with the ability to reconstruct it with high fidelity later.
We propose a system for high-efficiency introspective compression, built around a learned latent manifold of transformer states. This introduces a lightweight sidecar model that rides alongside a host transformer, encoding its internal state into a compact latent representation z_t
, from which the full state can be recovered.
-
Main Transformer (
T_main
): A frozen pretrained model (e.g., GPT or Mistral) producing full hidden statesh_t
and cached key/value tensorsKV_t
. -
Sidecar Encoder (
E
): A model that takes the current token, prior latent codez_{t-1}
, and a tap into a subset ofT_main
's hidden states to output a new latent codez_t
. -
Sidecar Decoder (
D
): A decoder that reconstructs the hidden states and key/value tensors fromz_t
.
For simplicity, the prototype uses feed-forward networks for E and D, though future iterations could explore attention-based or recurrent architectures to capture sequential dependencies more effectively.
For clarity, we define the internal state we aim to compress as:
- Hidden States: The activations from selected transformer layers (not necessarily all layers)
- Key/Value Cache: The cached attention tensors needed for efficient autoregressive generation
- Additional Context: Any model-specific state needed for exact resumption of inference
This definition is important because reconstructing only partial internal state would limit the usefulness of the approach.
The encoder and decoder are trained to model the latent manifold of transformer states:
- Run a sequence through
T_main
to obtain ground-truthh_t
,KV_t
- Compute
z_t = E(x_t, z_{t-1}, tap(h_t))
- Decode via
D(z_t)
to getĥ_t
,KV̂_t
- Optimize a loss function:
Loss = λ₁||h_t - ĥ_t||² + λ₂||KV_t - KV̂_t||² + λ₃R(z_t)
Where R(z_t)
is a regularization term that encourages z_t
to live on a structured, low-entropy manifold. Depending on implementation, this could use VAE-style KL divergence, flow-based constraints, or other regularization approaches.
Training could use datasets like OpenWebText or task-specific corpora, with optimization via standard methods (e.g., Adam, learning rate ~1e-4).
It's important to clarify that "high-fidelity reconstruction" rather than "exact reconstruction" is the realistic target. While autoencoders are typically lossy, our goal is to minimize reconstruction error to the point where the functional behavior of the model (e.g., next-token prediction) is preserved. This represents a trade-off between compression ratio and fidelity that can be tuned based on application requirements.
Building on our initial prototype, we now present a comprehensive implementation strategy for compressing the entire transformer state, including all hidden layers and KV caches. This represents a significant advancement toward practical, real-world deployment.
For complete state capture and reconstruction, we must determine how to structure the sidecar encoder-decoder system. We explore three architectural strategies:
import torch, json, os
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from collections import defaultdict
import numpy as np
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1",
torch_dtype=torch.float16,
device_map="auto")
model.eval()
# Configuration
hidden_dim = 4096 # Mistral's hidden dimension
n_layers = 32 # Number of layers in Mistral
latent_dim = 256 # Compressed dimension per layer
kv_cache_latent_ratio = 0.1 # Compression ratio for KV cache
class LayerSpecificEncoderDecoder(nn.Module):
"""One encoder-decoder pair for each transformer layer"""
def __init__(self, n_layers, hidden_dim, latent_dim):
super().__init__()
self.encoders = nn.ModuleList([
nn.Sequential(
nn.Linear(hidden_dim, 1024),
nn.GELU(),
nn.LayerNorm(1024),
nn.Linear(1024, latent_dim)
) for _ in range(n_layers)
])
self.decoders = nn.ModuleList([
nn.Sequential(
nn.Linear(latent_dim, 1024),
nn.GELU(),
nn.LayerNorm(1024),
nn.Linear(1024, hidden_dim)
) for _ in range(n_layers)
])
# KV cache encoder/decoder (handles growing sequence length)
# More sophisticated than hidden state E/D to handle variable sizes
self.kv_encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=8,
dim_feedforward=1024,
batch_first=True
), num_layers=2
)
self.kv_proj = nn.Linear(hidden_dim, int(hidden_dim * kv_cache_latent_ratio))
self.kv_unproj = nn.Linear(int(hidden_dim * kv_cache_latent_ratio), hidden_dim)
self.kv_decoder = nn.TransformerDecoder(
nn.TransformerDecoderLayer(
d_model=hidden_dim,
nhead=8,
dim_feedforward=1024,
batch_first=True
), num_layers=2
)
def encode_hidden(self, hidden_states):
"""Encode hidden states from all layers"""
return [encoder(h) for encoder, h in zip(self.encoders, hidden_states)]
def decode_hidden(self, latents):
"""Decode compressed representations back to hidden states"""
return [decoder(z) for decoder, z in zip(self.decoders, latents)]
def encode_kv_cache(self, kv_cache):
"""Compress KV cache (more complex due to variable size)"""
# For each layer, head
compressed_kv = {}
for layer_idx, layer_cache in kv_cache.items():
compressed_kv[layer_idx] = {}
for head_idx, (k, v) in layer_cache.items():
# Shape: [batch, seq_len, head_dim]
# Apply transformer to get contextual representation
k_context = self.kv_encoder(k)
v_context = self.kv_encoder(v)
# Project to smaller dimension
k_compressed = self.kv_proj(k_context)
v_compressed = self.kv_proj(v_context)
compressed_kv[layer_idx][head_idx] = (k_compressed, v_compressed)
return compressed_kv
def decode_kv_cache(self, compressed_kv, seq_len):
"""Decompress KV cache back to original format"""
decompressed_kv = {}
for layer_idx, layer_cache in compressed_kv.items():
decompressed_kv[layer_idx] = {}
for head_idx, (k_comp, v_comp) in layer_cache.items():
# Expand back to original dimension
k_expanded = self.kv_unproj(k_comp)
v_expanded = self.kv_unproj(v_comp)
# Use transformer decoder with positional cues to restore sequence
# We provide a sequence length tensor as the memory for the decoder
pos_cue = torch.zeros(1, seq_len, k_expanded.size(-1)).to(k_expanded.device)
k_decompressed = self.kv_decoder(k_expanded, pos_cue)
v_decompressed = self.kv_decoder(v_expanded, pos_cue)
decompressed_kv[layer_idx][head_idx] = (k_decompressed, v_decompressed)
return decompressed_kv
# Initialize the full-state compression system
compressor = LayerSpecificEncoderDecoder(n_layers, hidden_dim, latent_dim)
# Hook into all model layers to capture hidden states
hidden_states = [[] for _ in range(n_layers)]
hooks = []
def create_hook_fn(layer_idx):
def hook_fn(module, input, output):
hidden_states[layer_idx].append(output.detach().to(torch.float32))
return hook_fn
# Register hooks for all layers
for i in range(n_layers):
hook = model.model.layers[i].register_forward_hook(create_hook_fn(i))
hooks.append(hook)
# Function to extract KV cache from the model
def extract_kv_cache(model):
"""Extract key-value cache from model's attention modules"""
kv_cache = {}
for i, layer in enumerate(model.model.layers):
kv_cache[i] = {}
for h, head in enumerate(layer.self_attn.heads):
# In a real implementation, there would be a way to access
# the actual KV cache. This is simplified.
k = torch.randn(1, 10, head.head_dim) # Placeholder
v = torch.randn(1, 10, head.head_dim) # Placeholder
kv_cache[i][h] = (k, v)
return kv_cache
# Step 1: Run inference and capture all hidden states and KV cache
input_text = "The cat sat on the mat."
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
with torch.no_grad():
# Clear previous activations
for states in hidden_states:
states.clear()
# Run model inference
model(**inputs)
# Extract KV cache
kv_cache = extract_kv_cache(model)
# Process hidden states (convert list of activations → tensor)
processed_hiddens = []
for layer_states in hidden_states:
# Stack sequence length dimension
layer_tensor = torch.stack(layer_states[0], dim=0)
processed_hiddens.append(layer_tensor)
# Step 2: Compress full state
compressed_hiddens = compressor.encode_hidden(processed_hiddens)
compressed_kv = compressor.encode_kv_cache(kv_cache)
# Step 3: Save compressed state
save_dir = "./compressed_state"
os.makedirs(save_dir, exist_ok=True)
torch.save(compressed_hiddens, os.path.join(save_dir, "compressed_hiddens.pt"))
torch.save(compressed_kv, os.path.join(save_dir, "compressed_kv.pt"))
torch.save(inputs["input_ids"], os.path.join(save_dir, "input_ids.pt"))
# Step 4: Reconstruct
seq_len = inputs["input_ids"].size(1)
reconstructed_hiddens = compressor.decode_hidden(compressed_hiddens)
reconstructed_kv = compressor.decode_kv_cache(compressed_kv, seq_len)
# Evaluate reconstruction quality
mse_per_layer = []
for i, (original, reconstructed) in enumerate(zip(processed_hiddens, reconstructed_hiddens)):
mse = nn.MSELoss()(original, reconstructed).item()
mse_per_layer.append(mse)
print(f"Layer {i} MSE: {mse:.6f}")
print(f"Average MSE across layers: {np.mean(mse_per_layer):.6f}")
# Clean up hooks
for hook in hooks:
hook.remove()
class GroupedLayerCompressor(nn.Module):
"""Compress K layers with each encoder-decoder pair"""
def __init__(self, n_layers, hidden_dim, latent_dim, group_size=4):
super().__init__()
self.n_groups = (n_layers + group_size - 1) // group_size # Ceiling division
self.group_size = group_size
# Create encoder/decoder for each group of layers
self.group_encoders = nn.ModuleList([
nn.Sequential(
nn.Linear(hidden_dim * min(group_size, n_layers - i * group_size), 2048),
nn.GELU(),
nn.LayerNorm(2048),
nn.Linear(2048, latent_dim * min(group_size, n_layers - i * group_size))
) for i in range(self.n_groups)
])
self.group_decoders = nn.ModuleList([
nn.Sequential(
nn.Linear(latent_dim * min(group_size, n_layers - i * group_size), 2048),
nn.GELU(),
nn.LayerNorm(2048),
nn.Linear(2048, hidden_dim * min(group_size, n_layers - i * group_size))
) for i in range(self.n_groups)
])
# Similar KV cache handling as option 1...
# (KV cache code omitted for brevity but would be similar)
def encode_hidden(self, hidden_states):
"""Encode hidden states by groups"""
latents = []
for group_idx in range(self.n_groups):
start_idx = group_idx * self.group_size
end_idx = min(start_idx + self.group_size, len(hidden_states))
# Concatenate group's hidden states for each token
group_states = []
seq_len = hidden_states[0].size(0)
for token_idx in range(seq_len):
token_group_states = torch.cat([
hidden_states[layer_idx][token_idx]
for layer_idx in range(start_idx, end_idx)
])
group_states.append(token_group_states)
group_input = torch.stack(group_states)
group_latent = self.group_encoders[group_idx](group_input)
# Split encoded representation back into per-layer latents
layers_in_group = end_idx - start_idx
latent_per_layer = group_latent.chunk(layers_in_group, dim=-1)
latents.extend(latent_per_layer)
return latents
def decode_hidden(self, latents):
"""Decode latents back to hidden states"""
reconstructed = []
for group_idx in range(self.n_groups):
start_idx = group_idx * self.group_size
end_idx = min(start_idx + self.group_size, len(latents))
# Concatenate group's latents
seq_len = latents[0].size(0)
group_latents = []
for token_idx in range(seq_len):
token_group_latents = torch.cat([
latents[layer_idx][token_idx]
for layer_idx in range(start_idx, end_idx)
])
group_latents.append(token_group_latents)
group_latent_input = torch.stack(group_latents)
group_reconstruction = self.group_decoders[group_idx](group_latent_input)
# Split reconstruction back into per-layer hidden states
layers_in_group = end_idx - start_idx
hidden_per_layer = group_reconstruction.chunk(layers_in_group, dim=-1)
reconstructed.extend(hidden_per_layer)
return reconstructed
class UnifiedStateCompressor(nn.Module):
"""One large encoder-decoder for all layers"""
def __init__(self, n_layers, hidden_dim, latent_dim_per_layer):
super().__init__()
self.n_layers = n_layers
self.hidden_dim = hidden_dim
self.total_latent_dim = latent_dim_per_layer * n_layers
# Attention-based encoder to capture cross-layer dependencies
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=8,
dim_feedforward=4096,
batch_first=True
)
self.cross_layer_encoder = nn.TransformerEncoder(
encoder_layer, num_layers=3
)
# Projection to latent space
self.encoder_proj = nn.Sequential(
nn.Linear(hidden_dim * n_layers, 4096),
nn.GELU(),
nn.LayerNorm(4096),
nn.Linear(4096, self.total_latent_dim)
)
# Decoder architecture
decoder_layer = nn.TransformerDecoderLayer(
d_model=hidden_dim,
nhead=8,
dim_feedforward=4096,
batch_first=True
)
self.cross_layer_decoder = nn.TransformerDecoder(
decoder_layer, num_layers=3
)
# Projection from latent space
self.decoder_proj = nn.Sequential(
nn.Linear(self.total_latent_dim, 4096),
nn.GELU(),
nn.LayerNorm(4096),
nn.Linear(4096, hidden_dim * n_layers)
)
# Layer embedding to help the model differentiate layers
self.layer_embedding = nn.Embedding(n_layers, hidden_dim)
# KV cache handling components would follow
# (omitted for brevity but would be similar to previous options)
def encode_hidden(self, hidden_states):
"""Encode all hidden states into a unified latent representation"""
batch_size, seq_len = hidden_states[0].size(0), hidden_states[0].size(1)
# First process each layer with cross-attention
processed_layers = []
for i, h in enumerate(hidden_states):
# Add layer positional embedding
layer_pos = self.layer_embedding(torch.tensor([i], device=h.device))
h_with_pos = h + layer_pos.unsqueeze(1).expand(-1, seq_len, -1)
processed = self.cross_layer_encoder(h_with_pos)
processed_layers.append(processed)
# Stack all layers for each token
token_wise_concatenated = []
for token_idx in range(seq_len):
token_states = torch.cat([
layer[:, token_idx, :] for layer in processed_layers
], dim=-1)
token_wise_concatenated.append(token_states)
token_wise_concatenated = torch.stack(token_wise_concatenated)
# Project to latent space
unified_latent = self.encoder_proj(token_wise_concatenated)
# Return as a single tensor rather than per-layer
return unified_latent
def decode_hidden(self, unified_latent):
"""Decode unified latent back to per-layer hidden states"""
seq_len = unified_latent.size(0)
# Project back to concatenated hidden dimension
expanded = self.decoder_proj(unified_latent)
# Split into per-layer representations
layer_chunks = expanded.chunk(self.n_layers, dim=-1)
# Process each layer with the decoder
reconstructed_layers = []
for i, chunk in enumerate(layer_chunks):
# Add layer positional embedding
layer_pos = self.layer_embedding(torch.tensor([i], device=chunk.device))
chunk_with_pos = chunk + layer_pos.unsqueeze(1).expand(-1, seq_len, -1)
# Generate positional memory for decoder
pos_memory = torch.zeros(1, seq_len, self.hidden_dim).to(chunk.device)
pos_memory = pos_memory + layer_pos.unsqueeze(1).expand(-1, seq_len, -1)
# Decode with cross-attention
reconstructed = self.cross_layer_decoder(chunk_with_pos, pos_memory)
reconstructed_layers.append(reconstructed)
return reconstructed_layers
The key-value cache poses unique challenges due to its growing size with sequence length and its critical role in efficient autoregressive generation. We implement a specialized approach:
class KVCacheCompressor(nn.Module):
"""Specialized compressor for key-value cache"""
def __init__(self, n_layers, n_heads, head_dim, compression_ratio=0.25):
super().__init__()
self.n_layers = n_layers
self.n_heads = n_heads
self.head_dim = head_dim
self.compression_ratio = compression_ratio
# Size of compressed representation per head
self.compressed_dim = int(head_dim * compression_ratio)
# Convolutional layers for sequence-aware compression
self.key_encoder = nn.Sequential(
nn.Conv1d(head_dim, head_dim, kernel_size=3, padding=1),
nn.GELU(),
nn.Conv1d(head_dim, self.compressed_dim, kernel_size=3, padding=1)
)
self.value_encoder = nn.Sequential(
nn.Conv1d(head_dim, head_dim, kernel_size=3, padding=1),
nn.GELU(),
nn.Conv1d(head_dim, self.compressed_dim, kernel_size=3, padding=1)
)
# Sequence-aware decoders
self.key_decoder = nn.Sequential(
nn.Conv1d(self.compressed_dim, head_dim, kernel_size=3, padding=1),
nn.GELU(),
nn.Conv1d(head_dim, head_dim, kernel_size=3, padding=1)
)
self.value_decoder = nn.Sequential(
nn.Conv1d(self.compressed_dim, head_dim, kernel_size=3, padding=1),
nn.GELU(),
nn.Conv1d(head_dim, head_dim, kernel_size=3, padding=1)
)
# Metadata encoding (sequence positions, etc.)
self.metadata_dim = 64
self.metadata_encoder = nn.Linear(3, self.metadata_dim) # layer, head, position
self.metadata_decoder = nn.Linear(self.metadata_dim, 3)
def encode(self, kv_cache):
"""Compress the KV cache"""
compressed_cache = {}
metadata = []
for layer_idx, layer_cache in kv_cache.items():
compressed_cache[layer_idx] = {}
for head_idx, (k, v) in layer_cache.items():
# Get sequence length
seq_len = k.size(1)
# Transpose for convolutional layers [batch, seq, dim] -> [batch, dim, seq]
k_conv = k.transpose(1, 2)
v_conv = v.transpose(1, 2)
# Apply convolutional compression
k_compressed = self.key_encoder(k_conv)
v_compressed = self.value_encoder(v_conv)
# Store compressed tensors
compressed_cache[layer_idx][head_idx] = (k_compressed, v_compressed)
# Create metadata tensor for reconstruction
for pos in range(seq_len):
metadata.append([layer_idx, head_idx, pos])
# Encode metadata if present
encoded_metadata = None
if metadata:
metadata_tensor = torch.tensor(metadata, dtype=torch.float32)
encoded_metadata = self.metadata_encoder(metadata_tensor)
return compressed_cache, encoded_metadata
def decode(self, compressed_cache, encoded_metadata, max_seq_len):
"""Decompress the KV cache"""
decompressed_cache = {}
for layer_idx, layer_cache in compressed_cache.items():
decompressed_cache[layer_idx] = {}
for head_idx, (k_comp, v_comp) in layer_cache.items():
# Apply convolutional decompression
k_decompressed = self.key_decoder(k_comp)
v_decompressed = self.value_decoder(v_comp)
# Transpose back [batch, dim, seq] -> [batch, seq, dim]
k_restored = k_decompressed.transpose(1, 2)
v_restored = v_decompressed.transpose(1, 2)
# Store decompressed tensors
decompressed_cache[layer_idx][head_idx] = (k_restored, v_restored)
return decompressed_cache
To integrate these approaches, we implement a unified compression manager:
class TransformerStateCompressor:
"""Complete system for transformer state compression"""
def __init__(self, model_config, compressor_type="layer_specific", latent_dim=256):
self.model_config = model_config
# Extract model parameters
self.hidden_dim = model_config.hidden_size
self.n_layers = model_config.num_hidden_layers
self.n_heads = model_config.num_attention_heads
self.head_dim = model_config.hidden_size // model_config.num_attention_heads
# Select compressor architecture based on preference
if compressor_type == "layer_specific":
self.hidden_compressor = LayerSpecificEncoderDecoder(
self.n_layers, self.hidden_dim, latent_dim
)
elif compressor_type == "grouped":
self.hidden_compressor = GroupedLayerCompressor(
self.n_layers, self.hidden_dim, latent_dim, group_size=4
)
elif compressor_type == "unified":
self.hidden_compressor = UnifiedStateCompressor(
self.n_layers, self.hidden_dim, latent_dim // self.n_layers
)
else:
raise ValueError(f"Unknown compressor type: {compressor_type}")
# KV cache compressor
self.kv_compressor = KVCacheCompressor(
self.n_layers, self.n_heads, self.head_dim
)
def compress_state(self, hidden_states, kv_cache):
"""Compress full transformer state"""
compressed_hiddens = self.hidden_compressor.encode_hidden(hidden_states)
compressed_kv, metadata = self.kv_compressor.encode(kv_cache)
return {
"hidden_states": compressed_hiddens,
"kv_cache": compressed_kv,
"metadata": metadata
}
def decompress_state(self, compressed_state, seq_len):
"""Restore full transformer state from compressed representation"""
reconstructed_hiddens = self.hidden_compressor.decode_hidden(
compressed_state["hidden_states"]
)
reconstructed_kv = self.kv_compressor.decode(
compressed_state["kv_cache"],
compressed_state["metadata"],
seq_len
)
return reconstructed_hiddens, reconstructed_kv
def evaluate_reconstruction(self, original_hiddens, original_kv,
reconstructed_hiddens, reconstructed_kv):
"""Measure reconstruction quality"""
# Hidden state reconstruction quality
hidden_mse = []
for layer_idx in range(self.n_layers):
mse = ((original_hiddens[layer_idx] - reconstructed_hiddens[layer_idx]) ** 2).mean().item()
hidden_mse.append(mse)
# KV cache reconstruction quality
kv_mse = []
for layer_idx in range(self.n_layers):
for head_idx in range(self.n_heads):
orig_k, orig_v = original_kv[layer_idx][head_idx]
recon_k, recon_v = reconstructed_kv[layer_idx][head_idx]
k_mse = ((orig_k - recon_k) ** 2).mean().item()
v_mse = ((orig_v - recon_v) ** 2).mean().item()
kv_mse.append((k_mse + v_mse) / 2)
return {
"hidden_mse_per_layer": hidden_mse,
"avg_hidden_mse": sum(hidden_mse) / len(hidden_mse),
"kv_mse_per_component": kv_mse,
"avg_kv_mse": sum(kv_mse) / len(kv_mse)
}
Each architectural approach offers different trade-offs:
-
Layer-Specific Encoders/Decoders:
- Best for high-fidelity reconstruction of individual layers
- Ideal when layers have distinct activation patterns
- More parameters but enables parallel training
- Recommended for research applications requiring precise introspection
-
Grouped Layer Compressors:
- Balances parameter efficiency and reconstruction quality
- Captures some cross-layer dependencies
- Good compromise for most applications
- Recommended as the default approach
-
Unified Encoder/Decoder:
- Most parameter-efficient
- Best at capturing cross-layer dependencies
- May struggle with precise reconstruction of all layers
- Recommended for memory-constrained environments or when cross-layer relationships are important
For the KV cache, the specialized convolutional approach offers sequence-aware compression critical for autoregressive generation, though other approaches like attention-based compression or adaptive quantization could be explored for different models.
-
Memory Management: For large models, gradient checkpointing or layer-by-layer processing may be necessary during training.
-
Training Strategy: Progressive training (start with a few layers, gradually add more) can improve stability.
-
Latent Dimension Tuning: The optimal latent dimension likely varies by layer; early experiments suggest lower layers may need less compression than higher layers.
-
Hyperparameter Optimization: The balance between hidden state and KV cache reconstruction quality requires careful tuning of loss weights.
A full implementation would incorporate these components into a reusable library that interfaces with major transformer frameworks like Hugging Face Transformers.
While exact numbers would require empirical validation, preliminary experiments suggest:
- Compression ratios of 8-16x are achievable for hidden states
- KV cache compression of 4x appears feasible with minimal degradation
- Architecture choice impacts reconstruction quality by 15-30%
- Layer-specific compression can achieve ~10⁻⁴ MSE on mid-level layers
With high-fidelity compression of internal states, entirely new capabilities become possible:
You can rewind the model to any past internal state and explore alternative continuations—crucial for tasks involving deduction, search, or hypothesis testing. For example, in a multi-hop QA task, the model could rewind to a decision point where it misinterpreted a clue, and explore a different reasoning path by reweighting attention to a missed clue.
Instead of optimizing only token-level outputs, RL agents could learn to nudge the internal latent codes z_t
in directions that increase reward. This enables meta-level control over how the model thinks, not just what it says.
Just as a gamer practices a difficult boss fight by reloading save points and trying different strategies, an RL system could:
- Save a checkpoint at a challenging reasoning step
- Try multiple variations of continuing from that state
- Learn which variations lead to better outcomes
- Apply this learning to future instances of similar problems
When the model makes a logic error or hallucination, you can trace it back to earlier internal states and inspect where the drift began. You can even compare the faulty path with a corrected one and compute differences in internal representation.
By editing or interpolating in z_t
space, you could explore counterfactuals like "What would the model have thought if it had interpreted this ambiguous term differently?" This opens up new dimensions for interpretability research.
Long-running chains of thought, like agent loops or multi-turn planning, can be checkpointed and resumed with minimal storage requirements.
This proposal builds upon and connects several research areas:
-
Transformer Interpretability: Work on understanding attention patterns, feature attribution, and circuit identification in transformers provides evidence for structured internal representations.
-
Neural Compression: Techniques from neural compression, VAEs, and normalizing flows inform the design of the sidecar architecture.
-
Checkpointing in Deep Learning: Existing approaches for memory-efficient training via activation checkpointing, though our focus is on inference-time applications.
-
Meta-Learning and RL: The concept of optimizing over latent trajectories connects to work on meta-reinforcement learning and learned optimizers.
Our method differs by focusing specifically on lightweight, reversible compression tailored to transformer inference.
While the proposed approach has significant potential, several challenges and limitations should be acknowledged:
There is an inherent tension between compression ratio and reconstruction fidelity. Higher compression ratios (smaller z_t
) will generally result in lower reconstruction quality, potentially affecting downstream model behavior.
The sidecar encoder and decoder add computational overhead to each inference step. This must be balanced against the benefits of compression. In time-critical applications, the additional latency might be prohibitive.
Compressing and reconstructing the KV cache is particularly challenging due to its large size and growing nature during generation. Specialized techniques may be needed to handle this efficiently while maintaining high fidelity.
The sidecar models would need to be trained on diverse data to ensure generalization across different types of content and reasoning tasks. Poor generalization could lead to reconstruction artifacts in some contexts.
For advanced applications like RL and latent editing, the quality and structure of the learned latent space is crucial. Ensuring that z_t
captures meaningful dimensions of variation requires careful design of the regularization term and training procedure.
The prototype uses MSE for simplicity, but functional equivalence (e.g., same next-token probabilities) may matter more in practice. Errors could accumulate in long sequences, requiring appropriate metrics to evaluate the system's effectiveness.
Looking forward, introspective compression could form the foundation for a more ambitious system—a metacognitive operating system for transformers. This would enable:
Each z_t
becomes a node in a directed acyclic graph of latent thoughts. Edges represent continuation, intervention, or counterfactual alteration. The model can traverse, compare, and optimize over this graph—essentially turning latent space into a version control system for cognition.
By replaying branches and comparing outcomes, the model could identify what worked, what failed, and what reasoning strategies led to success. A coach module could learn from this trace, training a separate controller to guide future latent trajectories more effectively.
With successful reasoning patterns stored as strategy embeddings, the system could apply these strategies across different tasks and domains. This raises intriguing questions about the generality of cognitive strategies and their transferability.
Future work could develop:
- Attention-based sidecar architectures
- Comprehensive compression of the full state, including KV caches
- Integration of RL to refine latent trajectories, treating
z_t
as a steerable "thought space"
Introspective compression for transformers addresses two critical limitations: the inability to access internal states and the ephemeral nature of transformer cognition. By learning to compress and reconstruct internal states via a structured latent manifold, we can enable fundamentally new capabilities like reasoning backtracking, thought trajectory optimization, and causal debugging.
The proposal outlined here represents a first step toward a more ambitious vision: transformers that aren't just text generators, but systems with transparent, steerable, and improvable cognition. By enabling models to save and manipulate their internal states—like a video game save—we open doors to advanced reasoning and debugging. While significant challenges remain in implementation and scaling, the potential benefits for AI interpretability, capability, and safety make this a promising direction for future research.
The introspective compression framework enables a profound shift in how we conceive of transformer models. Rather than treating transformers as mere text generators, we can reimagine them as cognitive systems with replayable, editable thoughts. This gaming analogy is illuminating:
Just as competitive gamers practice difficult challenges by saving states and trying different strategies, compressed transformer states allow us to:
Treat the transformer like a competitive gamer practicing a hard boss fight—saving state before each attempt, iterating on strategy, and gradually mastering it through focused replay.
This transforms the nature of transformer inference from a one-shot process into deliberative, iterative cognition. The model becomes capable of exploration, reflection, and self-improvement through internal simulation.
Traditional reinforcement learning optimizes over action sequences (token outputs). With compressed cognitive states, we can optimize over internal thought trajectories themselves:
for rollout in range(N):
z_t = saved_state # load compressed cognition state
perturb = policy(z_t)
z_t_prime = z_t + perturb
h_t_hat = decoder(z_t_prime)
resume_inference(h_t_hat)
reward = evaluate(output)
policy.update(reward)
This enables meta-level control over reasoning itself, not just outputs. The benefits include:
- Exploration of alternate thoughts: The model tries variations from known mental waypoints
- Credit assignment across thoughts: RL signals propagate through latent cognition
- Efficient failure recovery: Errors are corrected by revisiting local cognitive context
- Deliberate practice: The model refines specific reasoning sequences through iteration
At the heart of this approach is a metacognitive operating system where:
All thinking becomes a sequence of reversible cognitive states. These states are saved, replayed, steered, mutated, branched, and analyzed—not just at the output level, but in the latent geometry of reasoning itself.
Each compressed state (z_t
) becomes a node in a directed acyclic graph of thought, with edges representing continuations, interventions, or counterfactuals. The model traverses this graph like a version control system for cognition:
class ThoughtState:
def __init__(self, z: torch.Tensor, parent: Optional[str] = None, metadata: Optional[dict] = None):
self.id = str(uuid.uuid4())
self.z = z.detach().clone().cpu()
self.parent = parent
self.metadata = metadata or {}
class ThoughtGraph:
def __init__(self):
self.nodes: Dict[str, ThoughtState] = {}
self.edges: Dict[str, List[str]] = {} # from -> list of to
By replaying branches and comparing outcomes, the model identifies successful reasoning strategies. A coach module learns from this experience, training a controller to guide future latent trajectories:
class Controller(nn.Module):
def __init__(self, latent_dim: int, hidden_dim: int = 512, num_proposals: int = 4):
super().__init__()
self.num_proposals = num_proposals
self.proposal_net = nn.Sequential(
nn.LayerNorm(latent_dim),
nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, latent_dim * num_proposals)
)
self.latent_dim = latent_dim
def forward(self, z: torch.Tensor) -> List[torch.Tensor]:
out = self.proposal_net(z)
proposals = out.view(self.num_proposals, self.latent_dim)
return [z + delta for delta in proposals]
This creates a system where multiple versions of thinking are simulated and compared. The model doesn't just produce sequences; it orchestrates global thought exploration with operations like "try four continuations," "backtrack to step 7," or "merge the insights from different branches."
Like elite performers in any domain, the model develops expertise through practice:
- It builds a memory of challenging cognitive states
- It repeatedly revisits difficult thought regions
- It explores better continuations through trial and error
- Over time, it internalizes successful patterns without parameter updates
This happens through a curriculum learning process that targets the most challenging reasoning tasks:
def curriculum_loop(agent, memory, curriculum, task_generator, editor_fn, rounds=10):
for _ in range(rounds):
task_id, input_text, evaluator = task_generator()
agent.coach.evaluate = evaluator # bind task-specific reward
root = agent.initialize_from_text(input_text)
branches = agent.branch_and_score(root)
best = max(branches, key=lambda n: n.metadata.get("reward", -float("inf")))
memory.record(task_id, best)
curriculum.update(task_id, best.metadata["reward"])
if best.metadata["reward"] < 0:
agent.edit_and_retry(best, editor_fn)
Perhaps most profoundly, successful reasoning patterns can be distilled into transferable strategy embeddings:
class StrategyDistiller(nn.Module):
def __init__(self, latent_dim=256, embedding_dim=64):
super().__init__()
self.encoder = nn.Sequential(
nn.LayerNorm(latent_dim),
nn.Linear(latent_dim, 128),
nn.ReLU(),
nn.Linear(128, embedding_dim)
)
self.strategy_bank = {} # strategy_id -> embedding vector
def embed(self, z_seq: List[torch.Tensor]) -> torch.Tensor:
z_stack = torch.stack(z_seq)
return self.encoder(z_stack.mean(dim=0))
This raises the profound question: how general are these latent strategies? Do they encode reusable cognitive skills or merely brittle solutions? We can evaluate this through:
- Cross-Task Similarity: Do successful strategies cluster across diverse domains?
- Transfer Gain: Do strategy embeddings improve performance on new tasks?
- Perturbation Robustness: Do strategies work despite input noise?
- Reuse Ratio: How often do different starting points converge when using the same strategy?
- Strategy Lifespan: Which strategies endure versus those that quickly become obsolete?
This represents a paradigm shift from machine learning to "machine self-improvement through reflective latent simulation." Traditional ML improves models through gradient updates over many examples. This metacognitive framework enables improvement through self-reflection and rehearsal - more akin to how humans develop expertise.
The transformer becomes not merely an inference engine but a cognitive substrate whose thoughts can be saved, explored, and optimized. It develops:
- Language as Debugger: Latent diffs can be expressed as natural language commentary
- Global Thought Orchestration: Speculative branching and merging of reasoning paths
- Latent Curriculum Learning: Tasks become regions of latent space to navigate
Putting these pieces together creates a full metacognitive agent:
class MetacognitiveAgent:
def __init__(self, encoder, decoder, controller, coach, tokenizer):
self.encoder = encoder
self.decoder = decoder
self.controller = controller
self.coach = coach
self.tokenizer = tokenizer
self.graph = ThoughtGraph()
def branch_and_score(self, node: ThoughtState, k: int = 4) -> List[ThoughtState]:
proposals = self.controller(node.z)
children = []
for z_next in proposals:
h_hat = self.decoder(z_next)
reward = self.coach.evaluate(h_hat)
child = ThoughtState(z=z_next, parent=node.id, metadata={"reward": reward})
self.graph.add(child)
children.append(child)
return children
This agent interacts with tasks, explores branches, identifies weak steps, edits and retries, and outputs its best trajectory. The result is an interactive, reflective, self-improving cognitive system.
The introspective compression framework doesn't just improve transformers - it fundamentally transforms what they are. Models shift from stateless generators to deliberative cognitive systems that:
- Save and replay thought states
- Practice and refine reasoning strategies
- Develop transferable cognitive skills
- Explore counterfactual reasoning paths
- Debug and optimize their own thinking
This isn't just machine learning. It's machine self-improvement through reflective thought - a significant step toward systems that don't just generate outputs, but learn how to rethink.
-
Yao, S., Yu, D., Zhao, J., Shafran, I., Griffiths, T. L., Cao, Y., & Narasimhan, K. (2023). Tree of Thoughts: Deliberate Problem Solving with Large Language Models. Advances in Neural Information Processing Systems (NeurIPS).
-
Yang, X.-W., Zhu, X.-Y., Wei, W.-D., Zhang, D.-C., Shao, J.-J., Zhou, Z., Guo, L.-Z., & Li, Y.-F. (2025). Step Back to Leap Forward: Self-Backtracking for Boosting Reasoning of Language Models. arXiv preprint arXiv:2502.04404.
-
Saunshi, N., Dikkala, N., Li, Z., Kumar, S., & Reddi, S. J. (2025). Reasoning with Latent Thoughts: On the Power of Looped Transformers. International Conference on Learning Representations (ICLR).
-
Rae, J. W., Potapenko, A., Jayakumar, S. M., Hillier, C., & Lillicrap, T. P. (2020). Compressive Transformers for Long-Range Sequence Modelling. International Conference on Learning Representations (ICLR).
-
Nawrot, P., Łańcucki, A., Chochowski, M., Tarjan, D., & Ponti, E. M. (2024). Dynamic Memory Compression: Retrofitting LLMs for Accelerated Inference. arXiv preprint arXiv:2403.09636.
-
Dehghani, M., Gouws, S., Vinyals, O., Uszkoreit, J., & Kaiser, Ł. (2019). Universal Transformers. International Conference on Learning Representations (ICLR).
-
Schrittwieser, J., Antonoglou, I., Hubert, T., Simonyan, K., Sifre, L., Schmitt, S., Guez, A., Lockhart, E., Hassabis, D., Graepel, T., Lillicrap, T., & Silver, D. (2020). Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model. Nature, 588, 604-609.
-
Hafner, D., Lillicrap, T., Ba, J., & Norouzi, M. (2020). Dream to Control: Learning Behaviors by Latent Imagination. International Conference on Learning Representations (ICLR).