I recently started using JAX for some ion-optics work in accelerator physics, and I have found it very good. The autodiff is magical for optimisation work, but even just as a compiled NumPy, I have found it easy to get highly performant code. For reference, I previously tried roughly the same thing in “numba” and couldn’t get anywhere near the same performance as JAX, even on the CPU, which I understand is JAX’s weakest backend. By and large I have just written idiomatic Python/NumPy code, sprinkled a few “vmap”s and “scan”s around, and got great results. I’m very pleased with JAX.
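To illustrate the “scan over elements, vmap over particles” pattern, here is a minimal sketch. It assumes purely linear optics in one transverse plane (x, x′), with thin-lens quadrupoles and drifts represented as 2x2 transfer matrices. The helper names (drift, thin_quad, track), the lattice and the focal lengths are illustrative assumptions, not the code described above.

```python
import jax
import jax.numpy as jnp

def drift(L):
    # Hypothetical helper: 2x2 transfer matrix for a field-free drift of length L
    return jnp.array([[1.0, L], [0.0, 1.0]])

def thin_quad(f):
    # Hypothetical helper: thin-lens quadrupole with focal length f (focusing if f > 0)
    return jnp.array([[1.0, 0.0], [-1.0 / f, 1.0]])

def track(state, matrices):
    # Propagate one (x, x') state through every element in order.
    # lax.scan carries the state and also stacks the state after each element.
    def step(s, M):
        s_next = M @ s
        return s_next, s_next
    final, history = jax.lax.scan(step, state, matrices)
    return final, history

# Illustrative FODO-like cell: element matrices stacked along the leading axis
lattice = jnp.stack([thin_quad(2.0), drift(1.0), thin_quad(-2.0), drift(1.0)])

# vmap maps over a batch of initial states (axis 0) while sharing the lattice,
# and jit compiles the whole batched tracking function
track_many = jax.jit(jax.vmap(track, in_axes=(0, None)))

initial = jnp.array([[1e-3, 0.0], [0.0, 1e-3], [-1e-3, 2e-4]])
final, history = track_many(initial, lattice)
```

Here `final` has shape (3, 2), one state per particle, and `history` has shape (3, 4, 2), the state after each of the four elements. Because the tracking function is ordinary JAX code, wrapping a scalar figure of merit computed from `final` in `jax.grad` would give derivatives with respect to lattice parameters, which is where the optimisation use case comes in.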