It's a little disingenuous to say that the 4000x speedup is due to Jax. I'm a huge Jax fanboy (one of the biggest), but the speedup here comes from running the simulation environment on a GPU. And as much as I love Jax, it's still extraordinarily difficult to implement even simple environments purely on a GPU.

My long-term ambition is to replicate OpenAI's Dota 2 reinforcement learning work, since it's one of the most impactful (or at least most entertaining) uses of RL. It would be more or less impossible to translate the game logic into Jax, short of transpiling C++ to Jax somehow. Which isn't a bad idea; someone should make that.

It should also be noted that there's a long history of RL being done on accelerators. AlphaZero's chess evaluations ran entirely on TPUs. PyTorch CUDA graphs also make it easier to implement this kind of thing nowadays, since (again, as much as I love Jax) some PyTorch constructs are simply easier to use than turning everything into a functional programming paradigm.

All that said, you should really try out Jax. The fact that you can calculate gradients of any arbitrary function is just amazing, and you have complete control over what gets JIT'ed into a GPU graph and what doesn't. It's a wonderful feeling compared to PyTorch's accursed .backward() accumulation scheme.

Can't wait for a framework that feels closer to pure, arbitrary Python. Maybe AI can figure out how to do it.
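For anyone who hasn't tried it, here's a minimal sketch of what I mean (the loss function and shapes are made up for illustration):

    import jax
    import jax.numpy as jnp

    # Any pure function of arrays can be differentiated -- no tape,
    # no graph-building, just a function transform.
    def loss(w, x, y):
        pred = jnp.tanh(x @ w)
        return jnp.mean((pred - y) ** 2)

    grad_loss = jax.grad(loss)      # d(loss)/d(w), itself an ordinary function
    fast_grad = jax.jit(grad_loss)  # you choose exactly what gets compiled

    w = jnp.zeros(3)
    x = jnp.ones((8, 3))
    y = jnp.zeros(8)
    print(fast_grad(w, x, y))       # no .backward(), no accumulated state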
Reminds me of this evergreen tweet from ryg: https://mobile.twitter.com/rygorous/status/1271296834439282690

    if you made something 2x faster, you might have done something smart
    if you made something 100x faster, you definitely just stopped doing something stupid
From what I understand of Jax, it feels similar in flavor to Julia, but tries to live within the language constraints (and ecosystem benefits) of Python.

I wonder how well Julia is placed for running reinforcement learning algorithms efficiently, particularly when the "environment" is wrapped in Python to fit some standardized interface.
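For context, the standardized interface in question is usually the Gym/Gymnasium reset/step loop, something like this (CartPole used purely as an example); a Julia or Jax learner driving it pays a Python-call cost on every step:

    import gymnasium as gym

    env = gym.make("CartPole-v1")
    obs, info = env.reset(seed=0)
    done = False
    while not done:
        action = env.action_space.sample()  # stand-in for a learned policy
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
    env.close()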
Strange that the author claims Jax's vmap is what's doing the heavy lifting, but doesn't use PyTorch's vmap to make the benchmark comparable.
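Both frameworks ship a vmap these days; a like-for-like benchmark would batch the same per-environment function with each. A rough sketch (the toy step function is invented for illustration):

    import jax, jax.numpy as jnp
    import torch

    def step_jax(state):      # toy per-environment physics step
        return state * 0.99 + jnp.sin(state)

    def step_torch(state):
        return state * 0.99 + torch.sin(state)

    batched_jax = jax.jit(jax.vmap(step_jax))  # vectorize over envs, then compile
    batched_torch = torch.vmap(step_torch)     # PyTorch's equivalent transform

    print(batched_jax(jnp.ones((4096, 8))).shape)    # (4096, 8)
    print(batched_torch(torch.ones(4096, 8)).shape)  # torch.Size([4096, 8])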
Neural differential equations are also easier with Jax. And sim2real may be easier with a simulator where some of the hard computations are replaced with neural approximations.
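A hedged sketch of what that hybrid looks like in Jax: hand-written dynamics plus a small learned correction, with gradients flowing through the whole rollout into the network weights (all names and shapes here are made up):

    import jax, jax.numpy as jnp

    def learned_correction(params, y):
        w, b = params
        return jnp.tanh(y @ w + b)  # stand-in for the neural approximation

    def dynamics(params, y):
        return -y + learned_correction(params, y)  # analytic part + learned part

    def rollout(params, y0, dt=0.01, steps=100):
        def euler_step(y, _):
            y_next = y + dt * dynamics(params, y)  # simple explicit Euler
            return y_next, y_next
        y_final, _ = jax.lax.scan(euler_step, y0, None, length=steps)
        return y_final

    params = (jnp.zeros((4, 4)), jnp.zeros(4))
    y0 = jnp.ones(4)
    # Differentiate through the entire integration, w.r.t. the weights:
    grads = jax.grad(lambda p: jnp.sum(rollout(p, y0)))(params)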
How does this compare with PyTorch / TensorFlow / etc.? Obviously, doing heavy data processing on the GPU will give a large speedup over a single CPU thread.

It's almost like the author is claiming credit for creating Nvidia, when in fact he's just calling its APIs.