
JAX: Numpy with Gradients, GPUs and TPUs

132 points by one-more-minute over 6 years ago

8 comments

chrisaycock over 6 years ago
For anyone else outside of machine learning who was wondering what all of this is, here is my best explanation:

The training phase of a neural network attempts to minimize *error* or *loss* as defined by the user. This is done by iteratively applying gradient descent to the error function. Thus, the error function must have a known derivative, which can be difficult if the function has loops and conditionals.

Autograd is a software package that produces derivatives of a function automatically. It creates a computation graph of the user-defined code from which it can determine the derivative:

https://en.wikipedia.org/wiki/Automatic_differentiation

XLA is a JIT compiler from TensorFlow that compiles common array functions. The JAX project from this GitHub page brings XLA's JIT optimizations to Autograd's automatic differentiation, which speeds up evaluating the error function and its gradients when training a neural network. Neat!

(I would be grateful for any corrections to my above explanation, as I am not an expert in ML.)
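A minimal sketch of the combination described above, using JAX's grad and jit transforms; the toy loss, weights, and data here are invented purely for illustration:

```python
import jax
import jax.numpy as jnp

# A loss function written as ordinary numpy-style Python code.
def loss(w, x, y):
    pred = jnp.tanh(x @ w)
    return jnp.mean((pred - y) ** 2)

# grad derives an exact gradient via automatic differentiation;
# jit compiles the result with XLA for CPU/GPU/TPU.
grad_loss = jax.jit(jax.grad(loss))

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 3))  # toy inputs
y = jnp.zeros(32)                    # toy targets
w = jnp.ones(3)                      # toy weights

w = w - 0.1 * grad_loss(w, x, y)     # one gradient-descent step
```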
duckworthd over 6 years ago
My $0.02.

I've been using JAX for a while now. A paper I'm an author on (https://arxiv.org/abs/1806.09597, a follow-up to https://news.ycombinator.com/item?id=18633215) resulted in an algorithm that required taking second derivatives on a per-example basis. This is extremely difficult in TensorFlow, but with JAX it was a 2-liner. Even better, it's _super_ fast, thanks to XLA's compile-to-GPU and JAX's auto-batching mechanics.

I highly recommend JAX to power users. It's nowhere near as feature-complete, in the neural-network sense, as, say, PyTorch, but it is very good at what it does, and its core developers are second to none in responsiveness.
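The per-example second-derivative pattern described here boils down to composing JAX transforms. A hedged sketch with a made-up scalar loss (not the paper's actual objective):

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example loss; the paper's real objective differs.
def loss(w, x):
    return jnp.sum(jnp.tanh(x * w))

# Compose grad twice for a second derivative w.r.t. w, then vmap over the
# batch axis of x so each example gets its own value (JAX's auto-batching).
per_example_d2 = jax.vmap(jax.grad(jax.grad(loss)), in_axes=(None, 0))

w = 0.7
batch = jnp.linspace(-1.0, 1.0, 8)
print(per_example_d2(w, batch))  # shape (8,): one second derivative per example
```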
buildbot over 6 years ago
I wonder how performance compares to CuPy: https://cupy.chainer.org

Seems a little limited in terms of supported operations for autograd compared to Chainer (https://chainer.org) or Flux (http://fluxml.ai).

Really cool to see, though! XLA/TPU support is awesome.
snackematician over 6 years ago
Very cool! I love autograd; it had tape-based autodiff way before PyTorch, and the way it wraps numpy is much more convenient than tensorflow/pytorch. I've been wanting GPU support in autograd for years now, so I'm very happy to see this.

I have some academic software (https://github.com/popgenmethods/momi2) that uses autograd. I was planning to port it to pytorch since it's better supported/maintained, but now I'll have to consider jax, though I'm a little worried about the maturity of the project; it seems like the numpy/scipy coverage is not all the way there yet. Then again, it would be fun to contribute back to JAX. I did contribute a couple of PRs to autograd back in the day, so I think I could jump right into it...
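For readers unfamiliar with the porting path mentioned here: autograd and JAX both expose a numpy-style namespace, so a first pass is often mostly a change of import (real ports need more care, e.g. JAX arrays are immutable and scipy coverage is partial). A rough sketch with a toy likelihood, not momi2's actual code:

```python
# autograd style:
#   import autograd.numpy as np
#   from autograd import grad
#
# JAX style:
import jax.numpy as jnp
from jax import grad, jit

def log_likelihood(theta, data):
    # toy Gaussian log-likelihood, invented for illustration
    return jnp.sum(-0.5 * (data - theta) ** 2)

score = jit(grad(log_likelihood))         # d(log-likelihood)/d(theta), compiled
print(score(0.3, jnp.array([0.1, 0.5])))
```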
whoisnnamdi over 6 years ago
This is very cool, and I can see all sorts of use cases where a tool like this could be valuable.

I'll definitely try to keep in mind that a tool for fast differentiation of arbitrary functions exists when starting my next project.

Thanks for posting!
hyperpallium over 6 years ago
> JAX can automatically differentiate native Python ... functions

sympy can differentiate functions, but they have to be set up properly. How can JAX differentiate native functions?

(Or do they mean numerical differentiation, like a finite-difference estimate?)
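JAX's approach is neither symbolic rewriting (sympy-style) nor finite differences: it traces the numpy-like primitives the function actually executes and applies chain-rule (JVP/VJP) rules to that trace. A small sketch; the function f is invented for illustration:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Ordinary Python control flow runs during tracing; only the jnp
    # operations that actually execute end up in the traced computation.
    for _ in range(3):
        x = jnp.sin(x) + x ** 2
    return x

df = jax.grad(f)
print(df(1.5))  # exact derivative of the traced computation, not a finite-difference estimate
```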
lostmsu over 6 years ago
What does it do if the function is not differentiable? Most of the commonly seen functions in real code aren't.
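A hedged illustration of what happens in practice: JAX differentiates the primitives the function traces to, and each primitive's derivative rule decides what to return at a non-differentiable point (the ReLU below is just an example, not a general guarantee):

```python
import jax
import jax.numpy as jnp

# ReLU is not differentiable at 0, yet jax.grad still returns a value there:
# whatever the derivative rule of the underlying maximum primitive picks,
# rather than an error or a symbolic "undefined".
relu = lambda x: jnp.maximum(x, 0.0)

print(jax.grad(relu)(2.0))   # 1.0 away from the kink
print(jax.grad(relu)(-2.0))  # 0.0 away from the kink
print(jax.grad(relu)(0.0))   # an implementation-chosen value at the kink
```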
p1esk over 6 years ago
How does it compare to CuPy (in terms of Numpy compatibility)?