I love JAX, but I wish function transformations like these were easier to debug in Python. The stack traces you get from these patterns can be wild.<p>For example, an error in an unjitted function can be easier to debug in JAX than in anything else, while the same error inside a jitted function can be harder to debug in JAX than in anything else. The same goes for vmap, pmap, grad, and the rest of the transformations. Debugging gets nuts quickly.<p>But I don’t think there’s any way around that with these kinds of transformations in Python, is there?
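Here is a minimal sketch of the jit-vs-eager gap, assuming a recent JAX release. `clip_positive` and `scaled` are hypothetical example functions. Python-level control flow on an array value works eagerly but fails under `jax.jit`, because tracing replaces the array with an abstract tracer. `jax.disable_jit()` and `jax.debug.print` are real JAX utilities that help pin down where the problem is. They don't make the trace-time stack traces themselves any easier to read.

```python
import jax
import jax.numpy as jnp


def clip_positive(x):
    # Hypothetical example: Python `if` on an array value. Fine when run
    # eagerly, but under jit `x` is a tracer with no concrete value.
    if x.sum() > 0:
        return x
    return jnp.zeros_like(x)


x = jnp.array([1.0, -2.0, 3.0])

# Unjitted: runs eagerly, so any error points straight at the offending line.
eager_result = clip_positive(x)

# Jitted: the same code fails at trace time. In recent JAX versions the error is
# TracerBoolConversionError, a subclass of ConcretizationTypeError.
try:
    jax.jit(clip_positive)(x)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True

# Escape hatch 1: disable jit so the function body runs eagerly again.
with jax.disable_jit():
    debug_result = jax.jit(clip_positive)(x)


@jax.jit
def scaled(x):
    y = x * 2.0
    # Escape hatch 2: prints the concrete runtime value from inside the
    # compiled function. A plain print(y) here would only show a tracer.
    jax.debug.print("y = {y}", y=y)
    # Traceable replacement for the Python `if` above.
    return jnp.where(y.sum() > 0, y, jnp.zeros_like(y))


scaled_result = scaled(x)
```

Neither escape hatch changes how the traced error itself is reported. `disable_jit` sidesteps tracing, and `jax.debug.print` only shows values once the function already traces successfully.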