Hi, I'm one of the authors of this blog post (Horace He), along with Driss Guessous, Yanbo Liang, and Joy Dong.<p>We're quite happy with this abstraction, and glad to answer any questions about it!
It's interesting that optimizing a computation that can be described in a single line of math takes so much work. It took forever even to discover FlashAttention, and in the 6 years since transformers were invented, thousands of papers have worked on making it faster.<p>Attention(Q,K,V) = Softmax(Q*K^T/sqrt(d_k))*V<p>FlexAttention seems to have found the right abstraction for the task.
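The core of the abstraction, as I understand it, is a user-defined score_mod that gets applied to each Q*K^T score before the softmax. Roughly something like this, sketched from the examples in the post (the bias here is a toy one, and exact details may differ):<p>
    import torch
    from torch.nn.attention.flex_attention import flex_attention

    # score_mod sees the raw score plus batch, head, and position indices,
    # and returns a modified score; here a toy relative-position bias.
    def relative_bias(score, b, h, q_idx, kv_idx):
        return score + (q_idx - kv_idx)

    # The fused kernels target GPU; shapes are (batch, heads, seq_len, head_dim).
    q = k = v = torch.randn(1, 8, 1024, 64, device="cuda")
    out = flex_attention(q, k, v, score_mod=relative_bias)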
For most LLM workloads today (short text chats), a few hundred or a couple thousand tokens suffice, so attention doesn't dominate (< 30% of compute). But as the modalities inevitably grow, work on attention approximation/compression is going to be paramount.<p>Nice to see PyTorch already elegantly supporting this next step in research. A sparse pattern like a sliding window, which much of that approximation work relies on, seems to fall out of the same API; see the sketch below.
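A rough sketch following the blog post's sliding-window example (sizes are made up; exact signatures may differ):<p>
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    SLIDING_WINDOW = 256  # illustrative window size

    # mask_mod returns True where attention is allowed; FlexAttention can skip
    # fully-masked blocks entirely, which is where the sparsity savings come from.
    def sliding_window_causal(b, h, q_idx, kv_idx):
        return (q_idx >= kv_idx) & (q_idx - kv_idx <= SLIDING_WINDOW)

    block_mask = create_block_mask(sliding_window_causal, B=None, H=None, Q_LEN=8192, KV_LEN=8192)
    # out = flex_attention(q, k, v, block_mask=block_mask)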
I didn't see any notice of this being CUDA-only (like FlashAttention). I tried running it on my Mac M3 (Python 3.11.8), following the quickstart (with the deviation of running it in a new venv), and got the following error:<p>/attention-gym/.venv/lib/python3.11/site-packages/torch/_subclasses/functional_tensor.py:258: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_numpy.cpp:84.)
cpu = _conversion_method_template(device=torch.device("cpu"))
Traceback (most recent call last):
File "/attention-gym/attn_gym/masks/document_mask.py", line 7, in <module>
from torch.nn.attention.flex_attention import _mask_mod_signature
ModuleNotFoundError: No module named 'torch.nn.attention.flex_attention'
> FlexAttention achieves 90% of FlashAttention2’s performance in the forward pass and 85% in the backward pass.<p>It's very good. But note that FlashAttention-3 is 1.5-2x faster than FlashAttention-2.
I've always been curious to put something together with PyTorch, but it always seemed like either a steep learning curve or there wasn't a big motivator (a project, a problem to solve, something in my daily routine to optimize).<p>Does anybody have a good starting point for learning with hands-on projects, ideally one that could also accommodate FlexAttention?