Word is that internally at Google, among a few teams, and also externally, Trax/JAX are putting up real competition to TensorFlow. Some teams have moved off of TensorFlow entirely. Combined with the better research capabilities of PyTorch, the future of TensorFlow is not bright. That said, TensorFlow still provides the highest performance for production usage, and there are tons of legacy TensorFlow code strewn throughout the web.<p>I would argue that this is not the fault of TensorFlow, but rather the hazard of being the first implementation in an extremely complex space. It seems like there usually needs to be some sacrificial lamb in a software domain, somewhat like MapReduce being quickly displaced by Spark, which now has no real competitors.
Is it just me, or is there zero explanation of what this actually is?<p>It somehow "helps" me understand deep learning, but its tutorial/docs amount to one Python notebook with three cells, where some nondescript, undocumented API is called to train a transformer.<p>Huh?
Note that, in this space, there is also Flax[0], which is likewise built on top of JAX and brings more deep-learning-specific primitives (though, unlike Trax, it doesn't try to be TensorFlow-compatible, if I understand correctly).<p>[0]: <a href="https://github.com/google-research/flax/tree/prerelease" rel="nofollow">https://github.com/google-research/flax/tree/prerelease</a>
Is this a layer on top of TensorFlow to make it easier to get started? Is it meant to compete with PyTorch in that respect?<p>I wish the title and description were clearer. They make it sound like a course, but it's a library/command-line tool.
I was recently surprised to discover that JAX can't use a TPU's CPU, and that there are no plans to add this to JAX. <a href="https://github.com/google/jax/issues/2108#issuecomment-581541862" rel="nofollow">https://github.com/google/jax/issues/2108#issuecomment-58154...</a><p>A TPU's CPU is <i>the only reason</i> that TPUs are able to get such high performance on MLPerf benchmarks like ImageNet ResNet training. <a href="https://mlperf.org/training-results-0-6" rel="nofollow">https://mlperf.org/training-results-0-6</a><p>They do infeed processing (image transforms, etc.) on the TPU's CPU. Then the results are fed to each TPU core.<p>Without this capability, I don't know how you'd feed the TPUs with data in a timely fashion. It seems like your input pipeline would be starved.<p>Hopefully they'll bring JAX to parity with TensorFlow in this regard soon. Otherwise, given that JAX is a serious TensorFlow competitor, I'm not sure how the future of TPUs will play out.<p>(If this sounds like just a minor feature, consider how it would sound to say "We're selling this car, and it can go fast, but it has no seats." Seats are kind of a crucial feature of a car.)<p>Still, I think this is just a passing issue. There's no way Google is going to let their TPU fleet languish, not when they bring in >$1M/yr per TPU pod commitment.
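To make concrete why infeed matters, here's a minimal pure-Python sketch of the pattern described above (this is not the actual TPU infeed API, just an illustration): a host thread preprocesses examples ahead of the device and keeps a bounded prefetch buffer full, so the consumer (standing in for the accelerator) never waits on preprocessing.

```python
import queue
import threading
import time

def preprocess(example):
    # Stand-in for host-side work (image decode, augmentation, etc.)
    time.sleep(0.001)
    return example * 2

def producer(examples, q):
    # Host CPU runs ahead of the accelerator, filling a bounded queue.
    for ex in examples:
        q.put(preprocess(ex))
    q.put(None)  # sentinel: end of input

def consume(q):
    # Stand-in for the accelerator: pops ready batches from the buffer.
    results = []
    while (item := q.get()) is not None:
        results.append(item)
    return results

examples = list(range(8))
q = queue.Queue(maxsize=4)  # bounded buffer = prefetch depth
t = threading.Thread(target=producer, args=(examples, q))
t.start()
out = consume(q)
t.join()
print(out)  # [0, 2, 4, 6, 8, 10, 12, 14]
```

If the producer side is too slow (or, as in the parent's complaint, has nowhere fast to run), the queue drains and the consumer stalls, which is exactly what "starving the TPUs" means.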
Not sure why one would bother with this. It's a less mature version of PyTorch. And I know there's XLA and such, but I've yet to see any major benefit from it for research in particular. A ton of time in DL frameworks is spent in the kernels (which in most practical cases means CUDA/cuDNN), and those are hand-optimized far better than anything we'll ever get out of a compiler-based optimizer.