As LLMs become larger, more capable, and more ubiquitous, the field of mechanistic interpretability -- that is, understanding the inner workings of these models -- becomes increasingly interesting and important. Similar to how software engineers benefit from having good mental models of file systems and networking, AI researchers and engineers should strive to have some theoretical basis for understanding the "intelligence" that emerges from LLMs. A strong mental model would improve our ability to harness the technology. In this post, I want to cover two fundamental and related concepts in the field (each with its own paper) that I find fascinating from a mathematical perspective: the linear representation hypothesis (Park et al.) and superposition (Anthropic).
The linear representation hypothesis (LRH) has existed for quite some time, ever since people noticed that the word embeddings produced by Word2Vec satisfied some interesting properties. If we let $E(x)$ be the embedding vector of a word, then you observe the approximate equivalence
$E(\text{``king''}) - E(\text{``man''}) + E(\text{``woman''}) \approx E(\text{``queen''})$.
Observations of this form suggest that concepts (gender, in this example) are represented linearly in the geometry of the embedding space, which is a simple but non-obvious claim.
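The analogy is easy to reproduce with off-the-shelf static word vectors. Here is a minimal sketch, assuming gensim and its downloadable GloVe vectors (any pretrained word embeddings would work similarly):

```python
# Minimal sketch: reproduce the king - man + woman ~= queen analogy with
# pretrained GloVe vectors via gensim (assumes internet access for the download).
import gensim.downloader as api

wv = api.load("glove-wiki-gigaword-100")  # 100-dimensional GloVe vectors

# most_similar ranks words by cosine similarity to (king - man + woman).
print(wv.most_similar(positive=["king", "woman"], negative=["man"], topn=3))
# "queen" should appear at or near the top of the list.
```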
Fast forward to modern LLMs, and the LRH remains a popular way to interpret what is going on inside these models. The Park et al. paper presents a mathematical framing that tries to formalize the hypothesis. It uses a simplified model of an LLM where most of the inner workings (multilayer perceptrons, attention, etc.) are treated as a black box, and the LRH is interpreted in two separate representation spaces, each with the same dimensionality as the model:
- The "embedding space" where the final hidden states of the network live ($E(x)$ for an input context $x$). This is similar to the word embedding formulation and is where you would perform interventions that affect the model's behavior.
- The "unembedding space" where the rows of the unembedding matrix live ($U(y)$ for each output token $y$). The concept direction measured by a linear probe over the hidden state (to evaluate the presence of the concept) corresponds to a vector in this space.
There are analogous formulations of the LRH in the two respective spaces. Suppose $C$ represents the directional concept of gender, i.e. male => female. Then any pair of input contexts differing only in that concept should satisfy, e.g.
$E(\text{``Long live the queen''}) - E(\text{``Long live the king''}) = \alpha \cdot E_C$
where $\alpha \ge 0$ and $E_C$ is a constant vector in the embedding space referred to as the embedding representation. Similarly, any pair of output tokens differing only in that concept should satisfy, e.g.
$U(``\text{queen"}) - U(``\text{king"}) = \beta \cdot U_C$
where $\beta \ge 0$ and $U_C$ is a constant vector in the unembedding space referred to as the unembedding representation. Basically, applying the concept has a linear effect in both spaces.
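To see the algebra at work, here is a toy numerical sketch with purely synthetic vectors (nothing below comes from the paper): if next-token logits are inner products $\langle E(x), U(y) \rangle$, then pushing the embedding along $E_C$ by $\alpha$ shifts the queen-vs-king logit gap by exactly $\alpha \beta \langle E_C, U_C \rangle$.

```python
# Toy sketch with synthetic vectors: intervening along E_C shifts the
# queen-vs-king logit gap by alpha * beta * <E_C, U_C>.
import numpy as np

rng = np.random.default_rng(0)
d = 64
E_C = rng.normal(size=d)              # embedding representation of male => female
U_C = E_C + 0.1 * rng.normal(size=d)  # unembedding representation (roughly aligned)

E_x = rng.normal(size=d)              # stand-in for E("Long live the king")
U_king = rng.normal(size=d)
U_queen = U_king + 2.0 * U_C          # beta = 2.0

def logit_gap(e):
    return e @ U_queen - e @ U_king

alpha = 1.5
shift = logit_gap(E_x + alpha * E_C) - logit_gap(E_x)
print(shift, alpha * 2.0 * (E_C @ U_C))  # the two numbers agree
```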
The paper goes into much more detail than I'll cover here, but the authors show that these two representations are isomorphic, which unifies the intervention and linear probe ideas. They then empirically verify on Llama 2 that they can find embedding and unembedding representations for a variety of concepts (e.g. present => past tense, noun => plural, English => French) that approximately fit their theoretical framework -- cool!
Okay, so let's assume concepts do in fact have linear representations. Then it would stand to reason that unrelated concepts have orthogonal directions. Otherwise, applying the male => female concept could influence the presence of the English => French concept, which doesn't make sense. One of the key results from Park et al. is that this orthogonality doesn't occur under the standard Euclidean inner product but instead under a "causal inner product" that is derived from the unembedding matrix. Only by looking at concept representations through that lens do we get the orthogonality we expect.
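Here is a rough sketch of what checking that might look like in code. I am assuming (based on my reading of the paper) that the causal inner product has the flavor $\langle u, v \rangle_C = u^\top \mathrm{Cov}(\gamma)^{-1} v$, where $\mathrm{Cov}(\gamma)$ is the covariance of the unembedding rows; treat the details, the single token-pair concept directions, and the use of GPT-2 as illustrative assumptions rather than a reproduction of their results.

```python
# Sketch: compare cosine similarity of two concept directions under the
# Euclidean inner product vs. a whitening-style inner product built from
# the unembedding matrix. Token-pair differences are a crude stand-in for
# the paper's concept estimators.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
U = model.lm_head.weight.detach()                 # (vocab_size, d_model)

# Covariance of the unembedding rows, lightly regularized before inverting.
centered = U - U.mean(dim=0)
cov = centered.T @ centered / U.shape[0]
cov_inv = torch.linalg.inv(cov + 1e-4 * torch.eye(cov.shape[0]))

def tid(word):
    return tok(word)["input_ids"][0]              # assumes a single-token word

gender = U[tid(" queen")] - U[tid(" king")]       # male => female (illustrative)
number = U[tid(" dogs")] - U[tid(" dog")]         # singular => plural (illustrative)

def cosine(u, v, M=None):
    dot = u @ v if M is None else u @ M @ v
    nu = u @ u if M is None else u @ M @ u
    nv = v @ v if M is None else v @ M @ v
    return (dot / (nu.sqrt() * nv.sqrt())).item()

print("Euclidean cosine:   ", cosine(gender, number))
print("Causal-style cosine:", cosine(gender, number, cov_inv))  # expected closer to 0
```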
But in these models, the representation space is relatively small (most ranging from 2K to 16K dimensions). So how do these spaces "fit" a number of language features that far exceeds their dimensionality? It's impossible for all such features to be mutually orthogonal, no matter the geometry.
This is where superposition comes into play. The intuition from low dimensions is that when you have $N$ vectors in a $d$-dimensional space with $N > d$, they must interfere substantially (some pairwise inner products have large magnitude). This is one of those cases where low-dimensional intuition does not extend to higher dimensions, however, as evidenced by the Johnson-Lindenstrauss lemma. An implication of the lemma is that you can choose exponentially many (in the number of dimensions) vectors that are almost orthogonal -- that is, the inner product between any pair is bounded by a small constant. You can think of this as the flip side of the curse of dimensionality.
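This is easy to see numerically: sample many more random unit vectors than dimensions and look at the worst-case pairwise interference. (Random directions are a construction in the spirit of the lemma, not a tight bound.)

```python
# Sketch: N random unit vectors in d dimensions with N >> d still have small
# pairwise inner products -- "almost orthogonal" rather than orthogonal.
import numpy as np

rng = np.random.default_rng(0)
d, N = 512, 5_000                      # roughly 10x more vectors than dimensions

V = rng.normal(size=(N, d))
V /= np.linalg.norm(V, axis=1, keepdims=True)

G = V @ V.T                            # Gram matrix of all pairwise inner products
np.fill_diagonal(G, 0.0)
print("max |<v_i, v_j>| =", np.abs(G).max())  # small (roughly 0.2-0.3), despite N >> d
```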
The Anthropic paper demonstrates the superposition phenomenon in toy models on small, synthetic datasets. One particularly interesting observation is that superposition does not occur without an activation function (purely linear computation), but it does occur with a nonlinear one (ReLU in their case). The idea is that the nonlinearity allows the model to manage the interference in a productive way. But this still only works well because of the natural sparsity of these features in the data -- models learn to superimpose features that are unlikely to be simultaneously present.
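For intuition, here is a compressed paraphrase of that setup (my own simplification, not Anthropic's code): $n$ sparse features are squeezed through $m < n$ hidden dimensions and reconstructed as $\mathrm{ReLU}(W^\top W x + b)$.

```python
# Compressed paraphrase of the toy-model setup: n sparse features, m < n
# hidden dimensions, reconstruction x_hat = ReLU(W^T W x + b). With high
# sparsity the model learns to pack more than m features into m dimensions.
import torch

n, m, sparsity = 20, 5, 0.95
W = torch.nn.Parameter(torch.randn(m, n) * 0.1)
b = torch.nn.Parameter(torch.zeros(n))
opt = torch.optim.Adam([W, b], lr=1e-2)

for step in range(5_000):
    # Each feature is active independently with probability (1 - sparsity).
    x = torch.rand(1024, n) * (torch.rand(1024, n) > sparsity)
    x_hat = torch.relu(x @ W.T @ W + b)     # down to m dims, back up, ReLU
    loss = ((x - x_hat) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

# Columns of W are the learned feature directions in the m-dimensional space;
# with m < n they cannot all be orthogonal, so inspect the pairwise interference.
Wn = W.detach() / (W.detach().norm(dim=0, keepdim=True) + 1e-8)
print((Wn.T @ Wn - torch.eye(n)).abs().max())
```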
In experimental setups where the features in the synthetic data are of equal importance and sparsity, they observe that the embedding vectors learned by the model form regular structures in the embedding space, e.g. a tetrahedron, pentagon, or square antiprism. Coincidentally, these are the same types of structures that I worked with in some old research on spherical codes, where they emerged from using gradient descent-like algorithms to minimize the energy of arrangements of points on unit hyperspheres (analogous to the Thomson problem). Fun to see the overlap of multiple fields!
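For the curious, a small sketch of that kind of energy minimization (the parameters here are arbitrary; with four points in 3D, gradient descent on the Coulomb-style energy recovers the regular tetrahedron):

```python
# Sketch: place N points on the unit sphere and minimize the sum of 1/r
# pairwise Coulomb-like energies with gradient descent. Points are normalized
# inside the loss, so the optimization effectively lives on the sphere.
import torch

N, d = 4, 3
pts = torch.nn.Parameter(torch.randn(N, d))
opt = torch.optim.Adam([pts], lr=1e-2)

for step in range(3_000):
    p = pts / pts.norm(dim=1, keepdim=True)             # project onto the unit sphere
    diff = p.unsqueeze(0) - p.unsqueeze(1)               # (N, N, d) pairwise differences
    dist = (diff.pow(2).sum(-1) + torch.eye(N)).sqrt()   # pad the diagonal to avoid 1/0
    energy = (1.0 / dist).sum() - N                      # subtract the N padded diagonal terms
    opt.zero_grad()
    energy.backward()
    opt.step()

p = (pts / pts.norm(dim=1, keepdim=True)).detach()
print(torch.cdist(p, p))  # off-diagonal distances all approach sqrt(8/3) ~ 1.63 (tetrahedron)
```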
To conclude, the picture of features as linear representations, even if not the complete story, is a valuable framework to help us interpret and intervene in LLMs. It has a solid theoretical basis that is backed up empirically. Sparsity, superposition, and the non-intuitive nature of high-dimensional spaces give us a window into understanding how the complexity of language (and intelligence?) gets captured by these models.