For anyone else outside of machine learning who was wondering what all of this is, here is my best explanation:<p>The training phase of a neural network attempts to minimize <i>error</i> or <i>loss</i> as defined by the user. This is done by iteratively applying gradient descent to the error function. Thus, the error function must have a known derivative, which can be difficult to obtain if the function contains loops and conditionals.<p>Autograd is a software package that produces derivatives of a function automatically. It builds a computation graph of the user-defined code, from which it can determine the derivative:<p><a href="https://en.wikipedia.org/wiki/Automatic_differentiation" rel="nofollow">https://en.wikipedia.org/wiki/Automatic_differentiation</a><p>XLA is a JIT compiler from the TensorFlow project that compiles common array operations. The JAX project on this GitHub page brings XLA's JIT optimizations to Autograd's automatic differentiation, which speeds up evaluating the error function and its gradients when training a neural network. Neat!<p>(I would be grateful for any corrections to my above explanation as I am not an expert in ML.)
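<p>For a concrete picture, here is a minimal sketch of the idea (my own example, not taken from the JAX README; the toy function f is made up): jax.grad turns an ordinary Python function into its derivative, and jax.jit compiles it with XLA.<p><pre><code>import jax
import jax.numpy as jnp

def f(x):
    # contrived function with a Python loop and a conditional
    total = 0.0
    for i in range(3):
        if i % 2 == 0:
            total = total + jnp.sin(x) * i
        else:
            total = total + x ** 2
    return total

df = jax.grad(f)       # a new function computing df/dx by autodiff
df_fast = jax.jit(df)  # the same derivative, compiled by XLA

print(df(2.0), df_fast(2.0))
</code></pre>The derivative comes from tracing the operations the function actually executes, so the loop and conditional above are resolved during tracing rather than requiring a closed-form symbolic expression.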
My $0.02.<p>I've been using JAX for a while now. A paper I'm an author on (<a href="https://arxiv.org/abs/1806.09597" rel="nofollow">https://arxiv.org/abs/1806.09597</a>, a follow-up to <a href="https://news.ycombinator.com/item?id=18633215" rel="nofollow">https://news.ycombinator.com/item?id=18633215</a>) resulted in an algorithm that required taking second derivatives on a per-example basis. This is extremely difficult in TensorFlow, but with JAX it was a 2-liner. Even better, it's _super_ fast, thanks to XLA's compile-to-GPU and JAX's auto-batching mechanics.<p>I highly recommend JAX to power users. It's nowhere near as feature-complete in the neural-network sense as, say, PyTorch, but it is very good at what it does, and its core developers are second to none in responsiveness.
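<p>For anyone curious what that pattern looks like, here is a rough sketch in the same spirit (not the actual code from the paper; the loss function below is a made-up stand-in): per-example second derivatives come from composing jax.hessian with jax.vmap, and jax.jit hands the whole thing to XLA.<p><pre><code>import jax
import jax.numpy as jnp

def loss(w, x):                       # hypothetical scalar per-example loss
    return jnp.tanh(jnp.dot(w, x)) ** 2

# one Hessian w.r.t. the parameters per example, auto-batched and XLA-compiled
per_example_hess = jax.jit(jax.vmap(jax.hessian(loss), in_axes=(None, 0)))

w = jnp.ones(3)                       # toy parameters
xs = jnp.ones((8, 3))                 # batch of 8 examples
print(per_example_hess(w, xs).shape)  # (8, 3, 3)
</code></pre>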
I wonder how performance compares to CuPy: <a href="https://cupy.chainer.org" rel="nofollow">https://cupy.chainer.org</a><p>Seems a little limited in terms of supported operations for autograd compared to Chainer: <a href="https://chainer.org" rel="nofollow">https://chainer.org</a> or Flux: <a href="http://fluxml.ai" rel="nofollow">http://fluxml.ai</a><p>Really cool to see though! XLA/TPU support is awesome.
Very cool! I love autograd: it had tape-based autodiff way before PyTorch, and the way it wraps numpy is much more convenient than TensorFlow/PyTorch. Been wanting GPU support in autograd for years now, so am very happy to see this.<p>I have some academic software (<a href="https://github.com/popgenmethods/momi2" rel="nofollow">https://github.com/popgenmethods/momi2</a>) that uses autograd. I was planning to port it to PyTorch since it's better supported/maintained, but now I'll have to consider JAX. Though I'm a little worried about the maturity of the project; it seems like the numpy/scipy coverage is not all the way there yet. Then again, it would be fun to contribute back to JAX. I did contribute a couple of PRs to autograd back in the day, so I think I could jump right into it...
This is very cool, and I can see all sorts of use cases where a tool like this could be valuable.<p>I'll definitely try to keep in mind that a tool for fast differentiation of arbitrary functions exists when starting my next project.<p>Thanks for posting!
> JAX can automatically differentiate native Python ... functions<p>sympy can differentiate functions, but they have to be set up as symbolic expressions first. How can JAX differentiate native Python functions?<p>(Or do they mean numerical differentiation, like a finite-difference estimate?)
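<p>To make the question concrete, here are my own toy snippets of the two setups (neither is from the JAX docs; the function f is made up):<p><pre><code>import sympy
import jax

# sympy: the expression has to be built symbolically up front
x = sympy.Symbol("x")
print(sympy.diff(x**3 + 2*x, x))   # 3*x**2 + 2

# JAX: grad is applied directly to a plain Python function
def f(y):
    return y**3 + 2*y

print(jax.grad(f)(1.5))            # 8.75, i.e. 3*1.5**2 + 2
</code></pre>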