
4000x Speedup in Reinforcement Learning with Jax

131 points by _hark about 2 years ago

8 comments

sillysaurusx about 2 years ago
It's a little disingenuous to say that the 4000x speedup is due to Jax. I'm a huge Jax fanboy (one of the biggest), but the speedup here is thanks to running the simulation environment on a GPU. And as much as I love Jax, it's still extraordinarily difficult to implement even simple environments purely on a GPU.

My long-term ambition is to replicate OpenAI's Dota 2 reinforcement learning work, since it's one of the most impactful (or at least most entertaining) uses of RL. It would be more or less impossible to translate the game logic into Jax, short of transpiling C++ to Jax somehow. Which isn't a bad idea; someone should make that.

It should also be noted that there's a long history of RL being done on accelerators. AlphaZero's chess evaluations ran entirely on TPUs. PyTorch CUDA graphs also make it easier to implement this kind of thing nowadays, since (again, as much as I love Jax) some PyTorch constructs are simply easier to use than turning everything into a functional programming paradigm.

All that said, you should really try out Jax. The fact that you can calculate gradients w.r.t. any arbitrary function is just amazing, and you have complete control over what's JIT'ed into a GPU graph and what's not. It's a wonderful feeling compared to using PyTorch's accursed .backward() accumulation scheme.

Can't wait for a framework that feels closer to pure arbitrary Python. Maybe AI can figure out how to do it.
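For readers unfamiliar with the APIs the comment is praising, here is a minimal sketch (a toy loss and made-up parameters, not from the article) of taking gradients of an arbitrary function with jax.grad and choosing exactly what gets JIT-compiled with jax.jit:

    # Minimal sketch: gradients of an arbitrary pure function, plus explicit
    # control over what gets compiled into a single accelerator graph.
    import jax
    import jax.numpy as jnp

    def loss(params, x):
        w, b = params                     # a toy "model": tanh(w*x + b)
        pred = jnp.tanh(w * x + b)
        return jnp.mean((pred - 1.0) ** 2)

    grad_loss = jax.grad(loss)            # d(loss)/d(params); no .backward() bookkeeping
    fast_grad = jax.jit(grad_loss)        # you decide what gets fused into one compiled graph

    params = (jnp.array(0.5), jnp.array(0.0))
    x = jnp.linspace(-1.0, 1.0, 8)
    print(fast_grad(params, x))           # a tuple of gradients matching params' structure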
percentcer about 2 years ago
Reminds me of this evergreen tweet from ryg: https://mobile.twitter.com/rygorous/status/1271296834439282690

    if you made something 2x faster, you might have done something smart
    if you made something 100x faster, you definitely just stopped doing something stupid
ssivark about 2 years ago
From what I understand of Jax, it feels somewhat similar in flavor to Julia, but trying to live with the language constraints (and ecosystem benefits) of Python.

I wonder how Julia is placed for running reinforcement learning algorithms (efficiently), particularly in cases when the “environment” is nicely wrapped in Python to fit some standardized interface.
ipsum2 about 2 years ago
Strange that the author claims Jax's vmap is what's doing the heavy lifting, but doesn't use PyTorch vmap to make the benchmark comparable.
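For context, a hypothetical sketch of the fairer baseline the comment is asking for: vectorizing a single-environment step with PyTorch's vmap (torch.func.vmap, available in PyTorch 2.x; the toy step function below is made up, not the article's benchmark):

    # Hypothetical baseline: batch a per-environment step with PyTorch's vmap
    # instead of an explicit Python loop over environments.
    import torch
    from torch.func import vmap

    def single_env_step(state, action):
        next_state = state + 0.1 * action            # toy dynamics for ONE environment
        reward = -(next_state ** 2).sum()
        return next_state, reward

    batched_step = vmap(single_env_step)             # maps over the leading batch dimension

    states = torch.randn(4096, 3)                    # 4096 environments, 3-dim state
    actions = torch.randn(4096, 3)
    next_states, rewards = batched_step(states, actions)
    print(next_states.shape, rewards.shape)          # [4096, 3] and [4096]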
schizo89 about 2 years ago
Neural differential equations are also easier with Jax. sim2real may also be easier with a simulator where some of the hard computations are replaced with neural approximations.
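A rough illustration of what "neural differential equations in Jax" can look like: a tiny Euler rollout whose vector field is a small neural net, differentiated end-to-end with jax.grad (toy shapes and names chosen for illustration; real work would typically use a library such as diffrax):

    # Toy "neural ODE": the solver is ordinary JAX code, so gradients flow
    # through the whole rollout.
    import jax
    import jax.numpy as jnp

    def dynamics(params, y):
        w1, w2 = params
        return w2 @ jnp.tanh(w1 @ y)                 # learned vector field dy/dt = f(y)

    def rollout(params, y0, dt=0.01, steps=100):
        def step(y, _):
            return y + dt * dynamics(params, y), None
        y_final, _ = jax.lax.scan(step, y0, None, length=steps)
        return y_final

    def loss(params, y0, target):
        return jnp.sum((rollout(params, y0) - target) ** 2)

    key = jax.random.PRNGKey(0)
    params = (jax.random.normal(key, (16, 2)), jax.random.normal(key, (2, 16)))
    grads = jax.grad(loss)(params, jnp.ones(2), jnp.zeros(2))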
xyzzy4747 about 2 years ago
How does this compare with PyTorch / TensorFlow / etc.? Obviously doing heavy data processing on the GPU will have a large speedup compared to a single thread on the CPU.

It's almost like the author is claiming credit for creating Nvidia, when in fact he is just calling its APIs.
nothrowaways about 2 years ago
It is misleading; the speedup is not just because of Jax. The devil is in the GPU.
_hark about 2 years ago
jax.vmap() is all you need?
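The submitter's one-liner names the pattern the thread is debating: write the step function for a single environment, then let jax.vmap (plus jax.jit) run thousands of copies in lockstep on the accelerator. A minimal sketch with a made-up toy environment, not the article's code:

    # vmap-over-environments: one env's step function, batched across 4096 copies.
    import jax
    import jax.numpy as jnp

    def env_step(state, action):
        next_state = state + 0.1 * action            # toy single-environment dynamics
        reward = -jnp.sum(next_state ** 2)
        return next_state, reward

    batched_step = jax.jit(jax.vmap(env_step))       # (N, dim) in -> (N, ...) out

    key = jax.random.PRNGKey(0)
    states = jax.random.normal(key, (4096, 3))       # 4096 parallel environments
    actions = jnp.zeros((4096, 3))
    states, rewards = batched_step(states, actions)
    print(rewards.shape)                             # (4096,)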