This is a really great overview of JAX - the best I've seen outside of primary Google/DeepMind sources. Glad to see people besides Googlers getting familiar with its capabilities.
> Why Should I Care About JAX?<p>> In short - speed.<p>For me personally, the magic of JAX is that it is able to deliver this performance while coming as close as possible to <i>first-class differentiation</i> in Python. The latter is a far more important reason to use JAX. It can really change how you think about programming and ML. Rather than implementing a specific model, you can write up the parameterized solution to a problem and then <i>solve it</i>.<p>However, first-class differentiation ultimately isn't very useful unless you also solve the speed problem. That is what makes JAX incredible. From a programming perspective, JAX is to differentiable programming what Prolog is to logic programming; however, Prolog has always been limited by performance problems, whereas JAX is not.
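To make "write up the parameterized solution, then solve it" concrete, here is a minimal sketch, not anything taken from the article. The toy problem (fitting a line), the names `model`, `loss`, and `fit`, and the learning rate and step count are all assumptions chosen for illustration. The only JAX calls used are `jax.grad`, `jax.jit`, and `jax.tree_util.tree_map`.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy problem: describe the solution as a parameterized function...
def model(params, x):
    w, b = params
    return w * x + b

# ...and state what "solved" means as a scalar loss.
def loss(params, x, y):
    return jnp.mean((model(params, x) - y) ** 2)

# jax.grad differentiates with respect to the first argument (params);
# jax.jit compiles the resulting gradient function with XLA.
grad_loss = jax.jit(jax.grad(loss))

def fit(params, x, y, lr=0.1, steps=500):
    # Plain gradient descent; lr and steps are illustrative assumptions.
    for _ in range(steps):
        grads = grad_loss(params, x, y)
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params

# Synthetic data generated from y = 3x + 0.5.
x = jnp.linspace(-1.0, 1.0, 50)
y = 3.0 * x + 0.5
w, b = fit((jnp.array(0.0), jnp.array(0.0)), x, y)
```

Nothing in `model` or `loss` mentions derivatives; the gradient comes from transforming an ordinary Python function, and the same transformed function is what gets compiled. Swapping in a different `model` leaves the solving loop unchanged.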
With the year so prominently in the title, I thought this was going to be about a technology that is obviously not trendy anymore, i.e., JAX The Terrible Java XML Parser. I'm disappointed that it was just blog title spam about some non-controversial modern technology.
I wish JAX worked on Windows natively (without using WSL). I teach a very high-level intro to NumPy and would _love_ to have my students try JAX. These students are relatively new to programming, and the idea of using a Linux shell or having to compile anything themselves just wouldn't work.
These benchmarks are pretty crazy, especially as I presumed NumPy would do far better. Aren't TensorFlow etc. already doing GPU acceleration? Is this actually a fair comparison?