Another noob question: so a ~50% reduction in the BERT parameters you actually need at inference? Let's see if I'm getting these numbers right. At inference time you only need a small fraction of the neurons in the feed-forward (FF) layers, selected based on the input and the result of each previous dot product. Here is some quick math for BERT-Base, which has 110M params according to the original paper:

----

    L (number of layers): 12 transformer blocks
    H (hidden size): 768 units in the hidden layers
    A (number of attention heads): 12 attention heads
Embedding Layers:

    WordPiece embeddings: 768 (hidden size) * 30,522 (vocab size) = 23,440,896 parameters
    Positional embeddings: 768 * 512 (max sequence length) = 393,216 parameters
    Segment embeddings: 768 * 2 (number of segments) = 1,536 parameters
    Total embedding parameters: 23,440,896 + 393,216 + 1,536 = 23,835,648
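A quick sanity check of that arithmetic in Python (just restating the BERT-Base constants above, nothing from the paper):

    V, H, P, S = 30_522, 768, 512, 2   # vocab size, hidden size, max positions, segment types
    wordpiece = V * H                  # 23,440,896
    positional = P * H                 # 393,216
    segment = S * H                    # 1,536
    print(f"{wordpiece + positional + segment:,}")  # 23,835,648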
Transformer Blocks:

    Each transformer block has the following components:

    Self-attention layer:
        Each attention head has 768 / 12 = 64 units.
        Query (Q), Key (K), Value (V) matrices: 3 * (64 * 768) = 147,456 parameters per head.
        Across 12 heads: 147,456 * 12 = 1,769,472 parameters.
        Output projection of the attention mechanism: 768 * 768 = 589,824 parameters.

    Feed-forward network (FFN):
        First layer: 768 (input) * 3,072 (intermediate size) = 2,359,296 parameters.
        Second layer: 3,072 * 768 = 2,359,296 parameters.
        Total FFN parameters per block: 2,359,296 + 2,359,296 = 4,718,592  <-- this is the number to keep in mind

    Total parameters per block: 1,769,472 (QKV) + 589,824 (output projection) + 4,718,592 (FFN) = 7,077,888.
    Total for 12 blocks: 7,077,888 * 12 = 84,934,656 parameters.

    Each block also has layer normalization and bias terms, which add a relatively small number of parameters.
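Same sanity check for the blocks (the per-head QKV accounting above works out to 3 * H * H per block):

    H, A, I = 768, 12, 3_072            # hidden size, attention heads, FFN intermediate size
    qkv = 3 * (H // A) * H * A          # 1,769,472, i.e. 3 * H * H
    attn_out = H * H                    # 589,824
    ffn = H * I + I * H                 # 4,718,592
    per_block = qkv + attn_out + ffn    # 7,077,888
    print(f"{per_block * 12:,}")        # 84,934,656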
Total Parameters:

    Embeddings: 23,835,648
    Transformer blocks: 84,934,656
    Subtotal: 108,770,304
    Layer norms, biases, and the rest: a small remainder that brings the total to the quoted ~110 million.
----

4,718,592 FF params per block * 12 blocks ≈ 56.6M out of the 110M total, i.e. the FFNs alone are roughly half the model. So if FFF only activates ~0.3% of the FF neurons per inference, you skip a staggering ~50% of the model's parameters at inference time??
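Putting that claim into numbers (a rough sketch: the 0.3% active-neuron fraction is the figure quoted above, and I'm assuming the attention layers are untouched by FFF):

    total = 110_000_000                 # BERT-Base parameter count
    ffn = 4_718_592 * 12                # 56,623,104 FFN params across all 12 blocks
    active = 0.003                      # assumed fraction of FF neurons FFF uses per token
    touched = (total - ffn) + ffn * active
    print(f"{touched / total:.1%}")     # ~48.7% of params touched, i.e. ~51% skipped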