I really enjoy the simplicity of their "minGRU" architecture. It's basically just:
<p><pre><code>import torch
import torch.nn as nn

class MinGRU(nn.Module):
    def __init__(self, token_size, hidden_size):
        super().__init__()
        self.token_to_proposal = nn.Linear(token_size, hidden_size)
        self.token_to_mix_factors = nn.Linear(token_size, hidden_size)

    def forward(self, previous_hidden_state, current_token):
        proposed_hidden_state = self.token_to_proposal(current_token)
        mix_factors = torch.sigmoid(self.token_to_mix_factors(current_token))
        # per-dimension gate: (1 - z) * h_prev + z * h_proposed
        return torch.lerp(previous_hidden_state, proposed_hidden_state, mix_factors)
</code></pre>
And since each layer's proposed hidden states and mix factors depend only on the current token, you can compute all of them in parallel if you know the whole sequence ahead of time (like during training), and then combine them with a parallel scan in logarithmic depth and linear total work (I've put a rough sketch of what I mean at the bottom of this comment).
<p>The fact that this is competitive with transformers and state-space models in their small-scale experiments is gratifying to the "best PRs are the ones that delete code" side of me. That said, we won't know for sure whether this is a capital-B Breakthrough until someone scales it up to parameter and data counts comparable to SOTA models.
<p>One detail I found really interesting is that they seem to do all their calculations in log-space, according to the Appendix. They say it's for numerical stability, which is curious to me: I'm not sure I have a good intuition for why running everything in log-space makes the model more stable. Is it because they removed the tanh from the output, making it possible for values to explode if the calculations are done in linear space? (More on this at the bottom of the comment.)
<p>EDIT: Another thought: it's kind of fascinating that this sort of sequence modeling works at all. It's like if I gave you all the pages of a book, individually torn out and in a random order, and asked you to produce a vector representation of each page, plus instructions for how to mix that vector with the vector representing all previous pages, even though you have zero knowledge of what those previous pages are. Then I take all your page vectors, sequentially mix them together in order, and grade you on how well the final vector summarizes the whole book. Wild stuff.
<p>FURTHER EDIT: Yet <i>another</i> thought: right now, they're just using two dense linear layers to transform the token into the proposed hidden state and the lerp mix factors. I'm curious what would happen if you made those transforms MLPs instead of single linear layers.
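<p>Something along these lines is what I'm imagining; this is just my own untested sketch (the class name, MLP width, and choice of GELU are arbitrary):
<pre><code>import torch
import torch.nn as nn

# Speculative variant (my sketch, not from the paper): swap the two single
# linear layers for small token-conditioned MLPs.
class MinGRUWithMLPs(nn.Module):
    def __init__(self, token_size, hidden_size, mlp_width=256):
        super().__init__()
        def make_mlp():
            return nn.Sequential(
                nn.Linear(token_size, mlp_width),
                nn.GELU(),
                nn.Linear(mlp_width, hidden_size),
            )
        self.token_to_proposal = make_mlp()
        self.token_to_mix_factors = make_mlp()

    def forward(self, previous_hidden_state, current_token):
        proposed_hidden_state = self.token_to_proposal(current_token)
        mix_factors = torch.sigmoid(self.token_to_mix_factors(current_token))
        return torch.lerp(previous_hidden_state, proposed_hidden_state, mix_factors)
</code></pre>
<p>The nice property is preserved: the MLPs still never look at the previous hidden state, so everything token-dependent can still be precomputed for the whole sequence at once.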
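<p>Also, to make the parallel-scan point from the top of this comment concrete, here's a rough sketch of how I understand the training-time computation; this is my own code and naming, not the paper's, and the sequential loop at the end stands in for the actual scan:
<pre><code>import torch
import torch.nn as nn

def mingru_over_sequence(token_to_proposal, token_to_mix_factors, tokens, h0):
    """tokens: (seq_len, batch, token_size); h0: (batch, hidden_size)."""
    # Step 1: everything token-dependent is computed for the whole sequence at
    # once, since none of it looks at the hidden state.
    proposals = token_to_proposal(tokens)                 # (T, B, H)
    gates = torch.sigmoid(token_to_mix_factors(tokens))   # (T, B, H)

    # The update h_t = lerp(h_{t-1}, proposal_t, gate_t) is affine in h_{t-1}:
    #   h_t = a_t * h_{t-1} + b_t,  with  a_t = 1 - gate_t,  b_t = gate_t * proposal_t
    a = 1.0 - gates
    b = gates * proposals

    # Step 2: combine. Composing two affine steps gives another affine step,
    #   (a2, b2) after (a1, b1)  ->  (a2 * a1, a2 * b1 + b2),
    # and that operator is associative, which is what lets this loop be
    # replaced by a log-depth parallel prefix scan on a GPU.
    h, outputs = h0, []
    for t in range(tokens.shape[0]):
        h = a[t] * h + b[t]
        outputs.append(h)
    return torch.stack(outputs)   # (T, B, H)
</code></pre>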
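<p>And on the log-space question above: whatever their exact motivation, one concrete thing that does go wrong in linear space is that the scan involves long running products of gate-like values in (0, 1), which underflow for long sequences, whereas running sums of logs stay well-behaved. A toy illustration (mine, not the paper's):
<pre><code>import torch

# Partial products of gate-like values, e.g. prod_{k=j+1..t} a_k with a_k in (0, 1).
a = torch.full((200,), 0.5)

# Linear space: the running product underflows float32, so the partial product
# from step 100 to step 200 comes out as exactly 0.
running_prod = torch.cumprod(a, dim=0)
print(running_prod[-1] / running_prod[100])   # tensor(0.)

# Log space: the same quantity is a difference of running sums of logs,
# which stays representable (~0.5 ** 99, about 1.6e-30).
running_logsum = torch.cumsum(torch.log(a), dim=0)
print(torch.exp(running_logsum[-1] - running_logsum[100]))
</code></pre>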