Hey y'all, author here!

Thank you for all the nice and constructive comments!

For clarity, this is ONLY the forward pass of the model. There's no training code, batching, KV cache for efficiency, GPU support, etc.

The goal here was to provide a simple yet complete technical introduction to GPT as an educational tool. Tried to make the first two sections something any programmer can understand, but yeah, beyond that you're gonna need to know some deep learning.

Btw, I tried to make the implementation as hackable as possible. For example, if you change the import from `import numpy as np` to `import jax.numpy as np`, the code becomes end-to-end differentiable:

    def lm_loss(params, inputs, n_head) -> float:
        # the targets are just the inputs shifted one position to the left
        x, y = inputs[:-1], inputs[1:]
        # forward pass -> [seq_len, vocab]
        output = gpt2(x, **params, n_head=n_head)
        # cross-entropy: take what the model assigns to each target token
        loss = np.mean(-np.log(output[np.arange(len(y)), y]))
        return loss

    grads = jax.grad(lm_loss)(params, inputs, n_head)
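And once you have gradients, a training step isn't far off. A minimal sketch of one (very naive) gradient descent update, assuming a learning rate `lr` (not something in the post itself):

    import jax

    # params and grads are pytrees with the same structure, so a parameter
    # update is just an element-wise tree_map (lr is an assumed hyperparameter)
    lr = 1e-4
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)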
You can even support batching with `jax.vmap` (https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html):

    # batch over the leading axis of the inputs; params and n_head are shared across the batch
    gpt2_batched = jax.vmap(lambda inputs: gpt2(inputs, **params, n_head=n_head))
    gpt2_batched(batched_inputs)  # [batch, seq_len] -> [batch, seq_len, vocab]
Of course, with JAX you also get built-in GPU and even TPU support (see the sketch below)!

As far as training code and a KV cache for inference efficiency, I leave those as an exercise for the reader lol
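To give a feel for the GPU/TPU part (just a hedged sketch, not something in the post): wrap the forward pass in `jax.jit` and XLA compiles it for whichever backend is available. `n_head` is a plain Python int used to split heads, so it's marked static here:

    import jax

    # compile the forward pass; XLA targets the available backend (CPU/GPU/TPU)
    # n_head drives reshapes/splits, so mark it as a static argument
    gpt2_jitted = jax.jit(gpt2, static_argnames="n_head")
    output = gpt2_jitted(inputs, **params, n_head=n_head)  # same output as the un-jitted call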