Understanding Protein Language Models Part I: Multiple Sequence Alignment in AlphaFold2
The fundamental concepts behind ESM2, ESM3, and AlphaFold
Note: This post is part of the “Understanding Protein Language Models” Series:
[This article] Understanding Protein Language Models Part I: Multiple Sequence Alignment in AlphaFold2
Over the past couple of months, I’ve been on a journey to understand protein language models - what exactly are they learning and how do they work? I began that journey by trying to understand the inner workings of AlphaFold2, given that it represented the first leap forward for AI in biology.
In particular, protein language models sought to replace the multiple sequence alignment (MSA) component of AlphaFold2 in order to achieve similar performance at much lower computational cost. So, in order to better understand what protein language models are doing, it is first important to understand what they replaced. To that end, in these notes I dive into the role of multiple sequence alignments in AlphaFold2 and how they drive the model’s ability to infer contacts between residues in the protein. I hope you find these notes useful!
What is MSA?
Multiple sequence alignment (MSA) refers to the process or the result of sequence alignment of three or more biological sequences, generally protein, DNA, or RNA. In many cases, the input set of query sequences is assumed to have an evolutionary relationship by which they share a linkage and are descended from a common ancestor. From the resulting MSA, sequence homology can be inferred, and phylogenetic analysis can be conducted to assess the sequences' shared evolutionary origins.
Visual depictions of the alignment (as in the image below) illustrate mutation events such as point mutations (single amino acid or nucleotide changes) that appear as differing characters in a single alignment column, and insertion or deletion mutations (indels or gaps) that appear as hyphens in one or more of the sequences in the alignment. Multiple sequence alignment is often used to assess sequence conservation of protein domains, tertiary and secondary structures, and even individual amino acids or nucleotides.
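To make this concrete, here is a toy illustration with four short, entirely hypothetical aligned fragments. Differing letters in a column correspond to point mutations, and hyphens mark indels; a simple helper flags columns where all sequences agree:

```python
# Toy MSA: four aligned (hypothetical) protein fragments.
# "-" marks a gap (indel); differing letters in a column are point mutations.
msa = [
    "MKT-AVLG",
    "MKTSAVLG",
    "MRT-AVIG",
    "MKT-AVLG",
]

# A column is "conserved" here if every non-gap character matches.
def conserved_columns(alignment):
    cols = []
    for j in range(len(alignment[0])):
        residues = {seq[j] for seq in alignment if seq[j] != "-"}
        if len(residues) == 1:
            cols.append(j)
    return cols

print(conserved_columns(msa))  # → [0, 2, 3, 4, 5, 7]
```

Note that columns 1 and 6 are not conserved (K/R and L/I substitutions) - it is exactly this kind of column-level variation, and how it co-varies between columns, that downstream models exploit.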
Thus, by including an MSA as input, AlphaFold2 is able to infer information about the target sequence by assessing its shared evolutionary history with a number of other sequences. This is a powerful concept - it gives the model a strong “starting point” to make predictions for a new sequence.
The MSA databases used by AlphaFold2 to identify evolutionarily similar sequences to the target sequence are:
- MGnify
- UniRef90
- Uniclust30
- BFD
AlphaFold2 Architecture
Now that we know what MSA is and why it is used, let’s sketch out the high-level architectural details of AlphaFold2. In the above image, we can see the various inputs & components that comprise the model. AlphaFold2 takes in three key inputs:
- The input sequence itself
- An MSA built using the input sequence as the query
- Template structures related to the input sequence
These three inputs are then distilled into two by using the template structures and input sequence to initialize a pair representation matrix. The pair representation matrix can be thought of as scores for “similarity” or “interaction” between each pair of amino acids i and j in the input sequence.
By contrast, the MSA representation can be thought of as storing a vector representation of each amino acid for each protein in the alignment. If we imagine the matrix as a 2D grid, each row represents a protein and each column represents a position in the aligned amino acid sequence (e.g., amino acid #5 in the sequence). In each cell of this matrix, we can imagine a vector that represents the specified amino acid. In reality, this is a tensor of shape (number of sequences, number of residues, channels).
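As a concrete picture of this layout, here is a minimal sketch with small, hypothetical dimensions (the real AlphaFold2 channel counts differ):

```python
import numpy as np

# Hypothetical sizes: s sequences, r aligned residue positions, c channels.
s, r, c = 4, 8, 16
rng = np.random.default_rng(0)

# The MSA representation: one c-dimensional vector per (sequence, position) cell.
msa_repr = rng.normal(size=(s, r, c))

# Row k holds the embeddings of every residue in sequence k ...
row_k = msa_repr[1]        # shape (r, c)
# ... and column i holds the embeddings of position i across all sequences.
col_i = msa_repr[:, 5]     # shape (s, c)

print(msa_repr.shape, row_k.shape, col_i.shape)
```

The row/column slices above are exactly the two axes that the Evoformer's row-wise and column-wise attention operate over.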
These two inputs then flow through the Evoformer block, which generates improved representations of the MSA and pair representation matrices for structure prediction. The journey for the full MSA matrix ends here, as we extract the representation for our input sequence from the first row of the MSA matrix and send it forward to the structure module.
Given that the processing for the MSA matrix takes place in the Evoformer block, we’ll dive in a bit deeper there.
Evoformer Block
The Evoformer block begins with components for processing the MSA representation:
- Row-wise gated self-attention
- Column-wise gated self-attention
- Transition
Following these three blocks, the MSA representation matrix is integrated into the pair representation matrix through the outer product mean block, whose output is added to the pair representation.
We’ll now dive deep into the three core components of the Evoformer block (alongside the outer product mean integration) to better understand how the MSA matrix is being updated.
MSA Row-wise Gated Self-Attention
Row-wise attention builds attention weights for residue pairs within the same sequence and integrates information from the pair representation as an additional bias term. The updated MSA representation matrix thus ensures that each sequence has a contextual representation of its residues - that is, for sequence k, the embedding of the residue at index i takes into account information from the residues at indices 1, …, i-1, i+1, …, r (where r is the number of aligned positions).
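A minimal single-head sketch of this operation, with hypothetical random weight matrices (layer norm and multi-head bookkeeping omitted; the pair bias here is a single (r, r) matrix rather than one per head):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def msa_row_attention(msa, pair_bias, wq, wk, wv, wg):
    """Single-head sketch. msa: (s, r, c); pair_bias: (r, r) from the pair repr."""
    q, k, v = msa @ wq, msa @ wk, msa @ wv          # each (s, r, d)
    d = q.shape[-1]
    # Attention within each row (sequence), biased by the pair representation.
    logits = q @ k.transpose(0, 2, 1) / np.sqrt(d) + pair_bias  # (s, r, r)
    attn = softmax(logits, axis=-1)
    out = attn @ v                                  # (s, r, d)
    gate = sigmoid(msa @ wg)                        # the "gated" part
    return gate * out

rng = np.random.default_rng(0)
s, r, c, d = 3, 5, 8, 8
msa = rng.normal(size=(s, r, c))
pair_bias = rng.normal(size=(r, r))
wq, wk, wv, wg = (rng.normal(size=(c, d)) for _ in range(4))
print(msa_row_attention(msa, pair_bias, wq, wk, wv, wg).shape)  # (3, 5, 8)
```

The key detail is the `+ pair_bias` term: it is the channel through which the pair representation feeds back into the MSA stack.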
MSA Column-wise Gated Self-Attention
Column-wise attention lets the elements that belong to the same target residue exchange information across sequences in the MSA. The updated MSA representation matrix thus ensures that each residue has a cross-sequence representation - that is, the embedding of residue i in sequence k also takes into account information from residue i in sequences 1, …, k-1, k+1, …, s (where s is the number of sequences).
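The same sketch, transposed: attention now runs over the sequence axis at each position (again with hypothetical weights; gating and multi-head details omitted for brevity):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def msa_column_attention(msa, wq, wk, wv):
    """Single-head, ungated sketch: attend over sequences at each position.
    msa: (s, r, c)."""
    cols = msa.transpose(1, 0, 2)                    # (r, s, c)
    q, k, v = cols @ wq, cols @ wk, cols @ wv
    d = q.shape[-1]
    logits = q @ k.transpose(0, 2, 1) / np.sqrt(d)   # (r, s, s)
    attn = softmax(logits, axis=-1)
    out = attn @ v                                   # (r, s, d)
    return out.transpose(1, 0, 2)                    # back to (s, r, d)

rng = np.random.default_rng(1)
s, r, c = 3, 5, 8
msa = rng.normal(size=(s, r, c))
wq, wk, wv = (rng.normal(size=(c, c)) for _ in range(3))
print(msa_column_attention(msa, wq, wk, wv).shape)  # (3, 5, 8)
```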
MSA Transition
After row-wise and column-wise attention, the MSA stack applies a 2-layer MLP as its transition layer, whose intermediate layer expands the number of channels by a factor of 4.
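This is the simplest piece of the stack - a position-wise MLP with a 4x expansion, sketched here with hypothetical weights (layer norm omitted):

```python
import numpy as np

def msa_transition(msa, w1, b1, w2, b2):
    """Position-wise 2-layer MLP; hidden width is 4x the input channels."""
    hidden = np.maximum(msa @ w1 + b1, 0.0)   # (s, r, 4c), ReLU nonlinearity
    return hidden @ w2 + b2                   # project back to (s, r, c)

rng = np.random.default_rng(2)
s, r, c = 3, 5, 8
msa = rng.normal(size=(s, r, c))
w1, b1 = rng.normal(size=(c, 4 * c)), np.zeros(4 * c)
w2, b2 = rng.normal(size=(4 * c, c)), np.zeros(c)
print(msa_transition(msa, w1, b1, w2, b2).shape)  # (3, 5, 8)
```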
Integration of MSA with Pair Representation
The “Outer product mean” block transforms the MSA representation into an update for the pair representation. Intuitively, its job is to take everything we’ve learned about residues across the aligned sequences and distill it back down into a single signal about each residue pair (i, j) in the target sequence — a signal we can then inject into the pair representation so the structure module can reason about contacts.
The mechanics work as follows. We start from the MSA representation of shape (s, r, c), where s is the number of sequences, r is the number of residues, and c is the channel dimension. After a layer norm, we apply two separate linear projections to produce a “left” tensor A and a “right” tensor B, each of shape (s, r, c’), where c’ is a smaller hidden dimension. You can think of A as the representation of residue i “as a left-hand partner” and B as the representation of residue j “as a right-hand partner” — two different views of the same residue, each tuned for its role in the pairwise interaction.
Then the core operation: For every sequence k in the MSA and every pair of residue positions (i, j), we form the outer product of A[k, i] and B[k, j]. This outer product is a c’ × c’ matrix that captures every pairwise interaction between the features of residue i and the features of residue j, as seen in sequence k. Where a simple dot product would collapse everything down to a single number, the outer product preserves the full grid of feature-by-feature interactions — much richer information about how the two residues relate in this particular sequence.
We then sum these outer products over all sequences in the MSA and divide by the number of sequences (with proper masking to ignore padded rows). This gives us, for each pair (i, j), the average feature-by-feature interaction between residue i and residue j across the entire alignment. Intuitively: if residues i and j consistently co-vary across the MSA, the hallmark of co-evolution and likely physical contact, that pattern will show up strongly and consistently in these outer products, and survive the averaging. If they vary independently, the signals across sequences will wash each other out.
The resulting tensor has shape (r, r, c’, c’). We flatten the last two dimensions to get (r, r, c’·c’), then apply a final linear projection to map down to c_z, the channel dimension of the pair representation. This produces a tensor of shape (r, r, c_z) that can be added directly into the pair representation.
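The whole pipeline - project, outer product, average over sequences, flatten, project to c_z - can be sketched in a few lines (hypothetical random weights; the masking of padded rows and the layer norm are omitted for brevity):

```python
import numpy as np

def outer_product_mean(msa, wa, wb, wo):
    """msa: (s, r, c) -> pair-representation update of shape (r, r, c_z)."""
    a = msa @ wa                      # (s, r, c1): residues as "left" partners
    b = msa @ wb                      # (s, r, c1): residues as "right" partners
    # Outer product for every sequence k and pair (i, j), then mean over k.
    # einsum indices: k = sequence, i/j = positions, p/q = hidden channels.
    outer = np.einsum("kip,kjq->ijpq", a, b) / msa.shape[0]   # (r, r, c1, c1)
    flat = outer.reshape(outer.shape[0], outer.shape[1], -1)  # (r, r, c1*c1)
    return flat @ wo                  # (r, r, c_z)

rng = np.random.default_rng(3)
s, r, c, c1, cz = 4, 6, 8, 3, 16
msa = rng.normal(size=(s, r, c))
wa, wb = rng.normal(size=(c, c1)), rng.normal(size=(c, c1))
wo = rng.normal(size=(c1 * c1, cz))
print(outer_product_mean(msa, wa, wb, wo).shape)  # (6, 6, 16)
```

Note that the output has no sequence axis: the average over k is exactly where the per-sequence information is distilled into a single per-pair signal.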
The end result is a clean handoff: the MSA stack uses row- and column-wise attention to figure out which residues behave similarly across evolutionary history, and the outer product mean then translates those cross-sequence patterns into a per-pair signal. This tells the pair representation, for every (i, j), how strongly the evolutionary record suggests these two residues are coupled. The structure module downstream uses this to guess which residues are in contact, which is what ultimately drives accurate structure prediction.
Conclusion - What is the MSA Representation Doing in AlphaFold?
So, putting this all together - the MSA steps compute a representation that captures similarity of residues, both:
1. Within sequences by using row-wise attention to attend across amino acids inside a given sequence
2. Across sequences by using column-wise attention to attend across sequences for a given amino acid index
This representation is then used to generate a measure of similarity between all possible residue pairs in the MSA representation, and we update the pair representation of the target sequence by adding these values. In essence, we use the MSA to "find out" which residues are similar to which other residues, and then add this information to the pair representation so that the structure module can guess which residues are in contact with one another (based on the fact that they co-evolve and are therefore similar in the MSA representation). This allows for highly accurate structure prediction, incorporating information from the evolutionary record to infer the folded structure of a given input protein.