The quantization approach is basically identical to the 1.58-bit LLM paper:<p><a href="https://arxiv.org/abs/2402.17764" rel="nofollow">https://arxiv.org/abs/2402.17764</a><p>The main addition of the new paper seems to be the implementation of optimized and fused kernels using Triton, as seen here:<p><a href="https://github.com/ridgerchu/matmulfreellm/blob/master/mmfreelm/ops/fusedbitnet.py">https://github.com/ridgerchu/matmulfreellm/blob/master/mmfre...</a><p>This is quite useful, as it should make training this type of LLM much more efficient.<p>So this is a ternary-weight LLM using quantization-aware training (QAT). The activations are quantized to 8 bits. The matmul is still there, but it is multiplying the 8-bit activations by ternary values (-1, 0, +1).<p>Quantization-aware training with low-bit weights seems to lead to reduced overfitting through an intrinsic tendency to regularize. However, model capacity should also be reduced compared to a model with the same number of weights and more bits per weight. It's quite possible that this only becomes apparent after the models have been trained on a significant number of tokens, as LLMs seem to be quite sparse.<p>Edit: In addition to the QAT, they also changed the model architecture to use a linear transformer to reduce reliance on multiplications in the attention mechanism. Thanks to logicchains for pointing this out.
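For concreteness, here is a minimal PyTorch-style sketch of the BitNet b1.58-style scheme described above (absmean ternarization of weights, per-token absmax 8-bit quantization of activations); the function names are my own illustration, not code from the repo:

    import torch

    def ternarize_weights(w: torch.Tensor, eps: float = 1e-5):
        # absmean scaling: divide by mean |w|, then round to {-1, 0, +1}
        scale = w.abs().mean().clamp(min=eps)
        w_q = (w / scale).round().clamp(-1, 1)
        return w_q, scale

    def quantize_activations_8bit(x: torch.Tensor, eps: float = 1e-5):
        # per-token absmax quantization into the signed 8-bit range [-127, 127]
        scale = x.abs().max(dim=-1, keepdim=True).values.clamp(min=eps) / 127.0
        x_q = (x / scale).round().clamp(-127, 127)
        return x_q, scale

    # The "matmul" that remains: 8-bit activations times ternary weights.
    # Each product is a copy, a negation, or a skip, which is what the
    # fused Triton kernels exploit.
    x = torch.randn(4, 64)            # activations
    w = torch.randn(64, 128) * 0.02   # latent full-precision weights
    x_q, sx = quantize_activations_8bit(x)
    w_q, sw = ternarize_weights(w)
    y = (x_q @ w_q) * sx * sw         # dequantize the result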
Wow - this seems at first read to be really impressive work. They got scaling laws up to a reasonable size, 2.7B, and also ran a few downstream tasks. Would be interesting to see how a comparable model trained by someone else does, to check their scores against those.<p>They get real (61%!?) memory savings during training, and inference too.<p>On top of all that, they then go and build an FPGA core which is programmed with a custom assembler. And their code is posted and works seamlessly with Hugging Face transformers?! Absolutely going to test this out.
There was another matmul-free language model paper released a year ago FYI:<p><a href="https://arxiv.org/abs/2305.17190" rel="nofollow">https://arxiv.org/abs/2305.17190</a>
I feel like all of these transformer reductions to binary or ternary weights are essentially constructing an implicit decision tree, where each stage of the process answers a question with yes/no/I-don't-know, and "I don't know" invokes a continuation for further processing with more context.
Not sure if it's fair to call binary multiplication "multiplication free"; you can express any multiplication as a sequence of additions/subtractions.
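With ternary weights the inner loop degenerates to exactly that. A toy sketch, just to make the point explicit:

    def ternary_dot(activations, weights):
        # Dot product with weights restricted to {-1, 0, +1}:
        # no multiplications, only additions, subtractions, and skips.
        acc = 0
        for a, w in zip(activations, weights):
            if w == 1:
                acc += a
            elif w == -1:
                acc -= a
            # w == 0: contributes nothing
        return acc

    print(ternary_dot([3, -2, 5, 7], [1, 0, -1, 1]))  # 3 - 5 + 7 = 5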
The GitHub link in the paper: <a href="https://github.com/ridgerchu/matmulfreellm">https://github.com/ridgerchu/matmulfreellm</a><p>It is super easy to try out: the 2.7B, 1.3B, and 0.37B models are on Hugging Face, and the generate.py example just works if you have Triton 2.2 installed.
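Roughly what generate.py boils down to (the Hugging Face model id below is from memory and assumed, so check the repo README for the exact names):

    import mmfreelm  # importing mmfreelm should register the custom model classes
    from transformers import AutoModelForCausalLM, AutoTokenizer

    name = "ridger/MMfreeLM-2.7B"  # assumed id; 1.3B and 0.37B variants also exist
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name).cuda().half()

    prompt = "In a shocking finding, scientists discovered"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
    outputs = model.generate(input_ids, max_length=64, do_sample=True, top_p=0.4)
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])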
One thing I didn’t figure out from just the paper: how does one train these parameters that are not even approximately continuous? Specifically, most of the parameters are ternary (i.e. -1, 0, or 1). The approximate gradient discussed in the paper will (I think) give some <i>real</i> gradient on each parameter, and that can be further processed by the learning rate schedule, but the result is still a real number g_i for each parameter a_i. Normally one would update a_i to a_i + g_i, but with these ternary parameters, a_i + g_i isn’t ternary!<p>So what’s the extra trick to make the model stay quantized? Does one evaluate the gradients on a whole bunch of training inputs, add them up, apply some randomness, and then re-quantize the model? Or is it something else?
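For what it's worth, the standard trick in BitNet-style QAT is a straight-through estimator over latent full-precision weights: the optimizer updates a real-valued shadow copy of each parameter (so a_i + g_i lands there), and only the forward pass sees the ternarized values. A minimal sketch of that pattern, as my own illustration rather than code from the paper:

    import torch

    class TernaryLinear(torch.nn.Module):
        def __init__(self, in_features, out_features):
            super().__init__()
            # Latent full-precision weights: this is what the optimizer updates.
            self.weight = torch.nn.Parameter(torch.randn(out_features, in_features) * 0.02)

        def forward(self, x):
            scale = self.weight.abs().mean().clamp(min=1e-5)
            w_q = (self.weight / scale).round().clamp(-1, 1) * scale
            # Straight-through estimator: forward uses ternarized weights,
            # backward treats the quantizer as identity, so the real-valued
            # gradient flows into the latent weights.
            w_ste = self.weight + (w_q - self.weight).detach()
            return torch.nn.functional.linear(x, w_ste)

    layer = TernaryLinear(16, 8)
    opt = torch.optim.AdamW(layer.parameters(), lr=1e-3)
    loss = layer(torch.randn(4, 16)).sum()
    loss.backward()
    opt.step()  # updates the latent full-precision weights, not the ternary ones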
Reminds me of geohot's interview: <a href="https://youtu.be/wE1ZoMGIZHM" rel="nofollow">https://youtu.be/wE1ZoMGIZHM</a>