I like JAX, and find most of the core "accelerated NumPy" functionality great.
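A minimal sketch of that "accelerated NumPy" workflow, using nothing but core JAX:

    import jax
    import jax.numpy as jnp

    # NumPy-style code; jax.jit compiles it, jax.grad differentiates it.
    def loss(w, x, y):
        pred = x @ w
        return jnp.mean((pred - y) ** 2)

    grad_loss = jax.jit(jax.grad(loss))  # compiled gradient of the loss

    w = jnp.zeros(3)
    x = jnp.ones((8, 3))
    y = jnp.ones(8)
    print(grad_loss(w, x, y))  # gradient w.r.t. w, shape (3,)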
Ecosystem fragmentation and difficulties in interop make adopting JAX hard, though.

There's too much fragmentation within the JAX NN library space, and penzai isn't helping with that. I wish everyone using JAX could agree on a single set of libraries for NN, optimization, and data loading.

PyTorch code can't be called from JAX, which means a lot of reimplementation when extending and iterating on prior work, which is what most research involves. Custom CUDA kernels are a bit fiddly too; I haven't been able to bring Gaussian Splatting to JAX yet.
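For what it's worth, zero-copy array exchange between the two is at least possible via DLPack, though it doesn't help with calling PyTorch code from inside a jitted JAX function. A rough sketch (exact functions vary a bit across JAX/PyTorch versions):

    import jax
    import jax.dlpack
    import torch
    import torch.utils.dlpack

    # PyTorch tensor -> JAX array, without copying, via a DLPack capsule.
    t = torch.arange(4, dtype=torch.float32)
    x = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))

    # ...and back from JAX to PyTorch.
    y = jax.numpy.square(x)
    t2 = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(y))
    print(t2)  # tensor([0., 1., 4., 9.])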
I’ve only been reading through the docs for a few moments, but I’m pleasantly surprised to find that the authors are using effect handlers to handle effectful computations in ML models. I was in the process of translating a model from PyTorch to JAX using Equinox; this makes me think penzai could be a better choice.
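To illustrate the idea for anyone unfamiliar with it (a generic hand-rolled sketch of the pattern, not penzai's actual API; all names here are made up): a layer declares an abstract effect, and a handler later rewrites it into an explicit functional input that JAX can trace.

    import jax
    import jax.numpy as jnp

    class RandomRequest:
        """Hypothetical placeholder marking where a layer needs randomness."""

    def dropout(x, effect, rate=0.5):
        # If the effect was never handled, fail loudly instead of silently.
        if isinstance(effect, RandomRequest):
            raise RuntimeError("unhandled RandomRequest: wrap with handle_random")
        keep = jax.random.bernoulli(effect, 1.0 - rate, x.shape)
        return jnp.where(keep, x / (1.0 - rate), 0.0)

    def handle_random(model, key):
        """Handler: replaces the abstract effect with a concrete PRNG key."""
        def handled(x):
            return model(x, key)
        return handled

    def model(x, effect):
        return dropout(x, effect)

    x = jnp.ones(4)
    out = jax.jit(handle_random(model, jax.random.PRNGKey(0)))(x)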
I remember PyTorch has some pytree capability, no? So is it safe to say that any pytree-compatible modules here are already compatible w/ PyTorch?
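As far as I know they're separate registries, so pytree compatibility with JAX doesn't carry over automatically. Quick sketch of the JAX side (the Linear container is a hypothetical stand-in for a penzai module):

    import jax
    import jax.numpy as jnp
    from typing import NamedTuple

    # NamedTuples are automatically JAX pytrees; penzai modules are pytrees too.
    class Linear(NamedTuple):
        w: jnp.ndarray
        b: jnp.ndarray

    layer = Linear(w=jnp.ones((2, 3)), b=jnp.zeros(3))

    # tree_map recurses through the registered structure to the array leaves.
    scaled = jax.tree_util.tree_map(lambda leaf: leaf * 2, layer)

PyTorch's pytree machinery lives in torch.utils._pytree, a separate and semi-private registry, so a type has to be registered with each framework independently.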
Does anyone know if and how well Penzai can work with Diffrax [1]? I currently use Diffrax + Equinox for scientific machine learning. Penzai looks like an interesting alternative to Equinox.

[1]: https://docs.kidger.site/diffrax/
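For reference, Diffrax mainly needs the vector field to be a JAX-traceable callable, so the question is really whether a penzai module can fill that role the way an Equinox module does. The standard getting-started example from the Diffrax docs:

    import diffrax
    import jax.numpy as jnp

    # Solve dy/dt = -y from t=0 to t=1 with Tsit5 (exponential decay).
    def vector_field(t, y, args):
        return -y

    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Tsit5(),
        t0=0.0, t1=1.0, dt0=0.1,
        y0=jnp.array(1.0),
    )
    print(sol.ys)  # solution at t1, approximately exp(-1)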
I have a small YT channel that teaches JAX bit by bit; check it out: https://www.youtube.com/@TwoMinuteJAX