PyTorch is a generationally important project. I've never seen a tool that is so in line with how researchers learn and internalize a subject. Teaching machine learning before and after its adoption has been a completely different experience. It can never be said enough how cool it is that Meta fosters and supports it.<p>Viva PyTorch! (Jax rocks too)
The author got a couple of things wrong that are worth pointing out:<p>1. PyTorch is going all-in on torch.compile -- Dynamo is the frontend, Inductor is the backend -- with a strong default Inductor codegen powered by OpenAI Triton (which now has CPU, NVIDIA GPU and AMD GPU backends). The author's view that PyTorch is building towards a multi-backend future isn't really where things are going. PyTorch supports extensible backends (including XLA), but a disproportionate amount of effort goes into the default path. torch.compile is 2 years old; XLA is 7 years old. Compilers take a few years to mature. torch.compile will get there (and we have reasonable measures showing the compiler is on track to maturity).<p>2. PyTorch/XLA exists mainly to drive a TPU backend for PyTorch, since Google gives no other real way to access the TPU. It's not great to try to shoehorn XLA in as a backend for PyTorch -- XLA fundamentally doesn't have the flexibility that PyTorch supports by default (especially dynamic shapes). PyTorch on TPUs is unlikely to ever match the experience of JAX on TPUs, almost by definition.<p>3. JAX was developed at Google, not at DeepMind.
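For readers who haven't touched the default path: a minimal sketch of what that torch.compile flow looks like in practice. It assumes a recent PyTorch 2.x install (on GPU, Inductor generates Triton kernels); the function and shapes here are made up for illustration.

```python
import torch

def fused_op(x, y):
    # Dynamo traces this Python function; Inductor generates a fused kernel for it.
    return torch.nn.functional.relu(x @ y) + 1.0

# Dynamo frontend + Inductor backend is the default, no extra arguments needed.
compiled = torch.compile(fused_op)

x = torch.randn(128, 128)
y = torch.randn(128, 128)
out = compiled(x, y)  # first call triggers compilation; later calls reuse the compiled code
```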
PyTorch beat TensorFlow because it was much easier to use for research. Jax is much harder to use for exploratory research than PyTorch, due to requiring fixed-shape computation graphs, which makes implementing many custom model architectures very difficult.<p>Jax's advantages shine when it comes to parallelizing a new architecture across multiple GPUs/TPUs, which it makes much easier than PyTorch (no need for custom CUDA/networking code). Needing to scale up a new architecture across many GPUs is, however, not a common use case, and most teams that have the resources for large-scale multi-GPU training also have the resources for specialised engineers to do it in PyTorch.
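To make the parallelization point concrete, a toy sketch of JAX data parallelism. The function and shapes are invented for illustration; pmap is used only for brevity (newer sharding APIs exist), and on a single-device machine it simply runs on that one device.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

@jax.pmap  # replicate the computation across all local devices, no custom networking code
def step(x):
    return jnp.tanh(x) * 2.0

# The leading axis must equal the device count; each slice runs on one device.
x = jnp.ones((n_dev, 1024))
y = step(x)
print(y.shape)  # (n_dev, 1024)
```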
From an eng/industry perspective, back in 2016/2017 I watched the realtime decline of TensorFlow towards PyTorch.<p>The issue was that TF had too many interfaces to accomplish the same thing, and each one was rough in its own way. There was also some complexity around serving and experiment logging via TensorBoard, but that wasn't as bad, at least for me.<p>Keras was integrated in an attempt to help, but ultimately it wasn't enough, and people started using Torch more and more, even against the perception that TF was for prod workloads and Torch was for research.<p>TFA mentions interface complexity as starting to be a problem with Torch, but I don't think we're anywhere near the critical point that would cause people to abandon it in favor of JAX.<p>Additionally, with JAX you're just shoving the portability problems mentioned down to XLA, which brings its own issues and gotchas even if it hides the immediate reality of said problems from the end user.<p>I think the Torch maintainers should watch out not to repeat the mistakes of TF, but I think there's a long way to go before JAX is a serious contender. It's been years and JAX has stayed in relatively small usage.
PyTorch is developed by multiple companies and stakeholders, while JAX is Google-only, with internal tooling they don't share with the world. This alone is a major reason not to use JAX. Also, I think it is more the other way around: with torch.compile, the main advantage of JAX is disappearing.
Pushback notwithstanding, this article is 100% correct in all its PyTorch criticisms. PyTorch was a platform for fast experimentation with eager evaluation; now they shoehorn "compilers" into it -- "compilers" in quotes, because a lot of the work is done by g++ and Triton.<p>It is a messy and quickly expanding codebase with many surprises like segfaults and leaks.<p>Is scientific experimentation really sped up by these frameworks? Everyone uses the Transformer model and the same algorithms over and over again.<p>If researchers wrote directly in C or Fortran, perhaps they'd get new ideas. The core inference (see Karpathy's llama2.c) is ridiculously small. Core training does not seem much larger either.
Can we get the title changed to the actual title of the post? "The future of Deep Learning frameworks" sounds like a neutral and far wider-reaching article, and ends up being clickbait here (even if unintentionally).<p>"PyTorch is dead. Long live JAX." conveys exactly what the article is about, and is a much better title.
I wish dex-lang [1] had gotten more traction. It’s JAX without the limitations that come from being a Python DSL.
But ML researchers apparently don’t want to touch anything that doesn’t look exactly like Python.<p>[1]: <a href="https://github.com/google-research/dex-lang">https://github.com/google-research/dex-lang</a>
PyTorch is the JavaScript of ML. Sadly, "worse is better" software has better survival characteristics, even when there is consensus that technology X is theoretically better.
I think a lot of the commenters here are being rather unfair.<p>PyTorch has better adoption / network effects. JAX has stronger underlying abstractions.<p>I use both. I like both :)
I like PyTorch because all of academia releases their code with it.<p>I've never even heard of JAX, nor will I have the skills to use it.<p>I literally just want to know two things: 1) how much VRAM, and 2) how to run it on PyTorch.
One aspect of JAX that's rarely touched on is browser stuff. Completely aside from deep learning, it's straightforward to compile JAX to a graphics shader you can call in JS, which in this insane world is actually my preferred way to put numerical computing or linear algebra code on a web page.
The best thing about PyTorch is that we aren't stuck with Python; rather, we can enjoy it alongside proper performance by using the Java and C++ API surfaces instead.
Are modern NNs really just static functions, and are they going to continue to be in the future?<p>KV caching is directly in conflict with a purely functional approach.
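For what it's worth, the usual functional answer is to thread the cache through explicitly rather than mutate it in place. A toy sketch, with all names and shapes invented for illustration (not taken from any real model or library):

```python
import jax
import jax.numpy as jnp

def attend_step(cache, new_k, new_v, q, pos):
    # "Update" the cache functionally: build new arrays instead of mutating in place.
    k_cache = cache["k"].at[pos].set(new_k)
    v_cache = cache["v"].at[pos].set(new_v)
    # (A real jitted version would mask instead of slicing, to keep shapes static.)
    scores = jnp.einsum("d,td->t", q, k_cache[: pos + 1])
    weights = jax.nn.softmax(scores)
    out = jnp.einsum("t,td->d", weights, v_cache[: pos + 1])
    return {"k": k_cache, "v": v_cache}, out  # the caller carries the new cache forward

d, T = 64, 128
cache = {"k": jnp.zeros((T, d)), "v": jnp.zeros((T, d))}
cache, out = attend_step(cache, jnp.ones(d), jnp.ones(d), jnp.ones(d), pos=0)
```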
The article misses the multi-modal angle, which is the future. Sure, multi-modal models can be treated as separate things, like today, but that's probably not the best approach. Framework support could include partial training, easy component swapping, intermediate data caching, dynamic architectures, and automatic work balancing and scaling.
From the article it seems that JAX is a non-starter for me, as it doesn't support any kind of acceleration on Windows proper, and only experimental acceleration in WSL.
> I believe that all infrastructure built on Torch is just a huge pile of technical debt, that will haunt the field for a long, long time.<p>... from the company that pioneered the approach with tensorflow. I've worked with worse ML frameworks, but they're by now pretty obscure; i cannot remember (and i am very happy about it) the last time i saw MXNet in the wild, for example. You'll still find Caffe on some embedded systems, but you can mostly sidestep it.
JAX is well designed? That's nice. The only thing that matters is adoption. You can run this title when JAX's adoption surpasses PyTorch's. How does someone using _python_ not understand this?
> I’ve personally known researchers who set the seeds in the wrong file at the wrong place and they weren’t even used by torch at all - instead, were just silently ignored, thus invalidating all their experiments. (That researcher was me)<p>Some <i>assert</i>-ing won't hurt you. Seriously. It might even help keeping your sanity.
A more accurate title for the OP would be "I hope and wish PyTorch were dead, so Jax could become the standard."<p>Leaving aside the fact that PyTorch's ecosystem is 10x to 100x larger, depending on how one measures it, PyTorch's biggest advantage, in my experience, is that it can be picked up quickly by developers who are new to it. Jax, despite its superiority, or maybe because of it, cannot be picked up quickly.<p>Equinox does a great job of making Jax accessible, but Jax's functional approach is in practice more difficult to learn than PyTorch's object-oriented one.
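A toy illustration of that contrast; both snippets are sketches written for this comment (shapes and parameter names invented), not recommended idioms from either library's docs.

```python
import torch
import jax.numpy as jnp

# PyTorch: object-oriented -- parameters live inside the module object.
layer = torch.nn.Linear(3, 2)
y_torch = layer(torch.randn(5, 3))

# JAX: functional -- parameters are explicit data passed into a pure function.
params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}

def linear(params, x):
    return x @ params["w"] + params["b"]

y_jax = linear(params, jnp.ones((5, 3)))
```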
My main reason to avoid Jax is Google. Google doesn't provide good support even for things you pay them for. They do things because they want to, to get their internal promotions, irrespective of their customers or the impact on them.