> our linear transformers are somewhat useless, as the positive impact from the speedup seen in long contexts is undermined by the negative impact of degraded learning.

> In a future post, we will explain how to improve the learning of linear transformers

So the techniques here are useless without the special secret sauce that they're not disclosing. Yet. Mamba is already out there solving similar problems, but the more the merrier. I hope they publish the useful part soon.
This is not a new algorithm. The same algorithm is described in Figure 4 (Theorem 3.1) of https://arxiv.org/pdf/2310.01655.pdf

(Disclaimer: I am an author on the linked paper.)
I don't understand something: why do they claim they go from O(N*N) to O(N), when all they say they are doing is removing one exponentiation operation, which is O(1) per score? Where does the factor of N they are removing come from?
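As I understand it, the factor of N comes from the order of the matrix products rather than the cost of the exp itself: once the exponential is gone, (Q K^T) V can be reassociated as Q (K^T V), so the N x N score matrix is never materialized. A rough NumPy sketch of that argument (the positive feature map `phi` and the normalization are my own placeholders, not necessarily what the article uses):

```python
import numpy as np

def softmax_attention(Q, K, V):
    """Standard attention: materializes an N x N score matrix, so O(N^2 * d)."""
    scores = Q @ K.T / np.sqrt(Q.shape[-1])               # (N, N)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V                                           # (N, d)

def linear_attention(Q, K, V):
    """Without the exp, (phi(Q) phi(K)^T) V reassociates to phi(Q) (phi(K)^T V).
    phi(K)^T V is only (d, d), so the whole thing costs O(N * d^2): linear in N."""
    phi = lambda x: np.maximum(x, 0.0) + 1e-6              # placeholder positive feature map
    Qf, Kf = phi(Q), phi(K)
    kv = Kf.T @ V                                          # (d, d), built in O(N * d^2)
    z = Qf @ Kf.sum(axis=0)                                # (N,), normalizer
    return (Qf @ kv) / z[:, None]                          # (N, d)

N, d = 2048, 64
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
print(softmax_attention(Q, K, V).shape, linear_attention(Q, K, V).shape)  # (2048, 64) (2048, 64)
```

So removing the exponential is cheap in itself, but it is what makes the reassociation (and hence the O(N) total cost) possible.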
To be honest, this makes me less excited about linear transformers.

Even heavily optimized, they are still (nearly) no better than normal flash attention up to context length 10^4.

And that's before you account for the degradation in learning.

Maybe if you're doing 100k attention at inference it starts making sense... But then there are other methods you can start using too.
Great writeup and interesting experiments. I can't help but wonder what would happen if you instead made a rectified linear attention. Is that even possible?
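Just to make the question concrete, something like the sketch below is what I'm picturing: a ReLU applied directly to the raw scores in place of the softmax. (A naive sketch, and note that on its own it doesn't remove the O(N^2) term; for that the rectification would have to move into a per-token feature map so the products can be reassociated, as in the earlier comment's example.)

```python
import numpy as np

def relu_attention(Q, K, V):
    """Replace softmax with a ReLU over the raw scores, crudely normalized by length.
    Still builds the full N x N score matrix, so it remains O(N^2 * d) as written."""
    N = Q.shape[0]
    scores = Q @ K.T / np.sqrt(Q.shape[-1])   # (N, N)
    w = np.maximum(scores, 0.0) / N           # rectified, length-normalized weights
    return w @ V                              # (N, d)

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((256, 64)) for _ in range(3))
print(relu_attention(Q, K, V).shape)          # (256, 64)
```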