For those curious what the big deal is here: pytrees make it wildly easier to take derivatives with respect to parameters stored in a complex structure, which in turn makes it much easier to organize code for non-trivial models.

As an example: if you want to implement logistic regression in JAX, you need to optimize the weights. This is easy enough, since they can be modeled as a single value, a matrix of weights. If you want a 2-layer MLP, you now have (at least) 2 weight matrices. You could treat these as two separate parameters to your function (which makes the derivative more complicated to manage), or you could concatenate the weights and split them up later, etc. Annoying, but manageable.

When you get to something like a diffusion model, you need to manage parameters for a variety of different, quite complex, submodels. It really helps if you can keep track of all these parameters in whatever data structure you like, yet still trivially call "grad" and get your model's derivative with respect to its parameters.

Pytrees make this incredibly simple, and that is a major quality-of-life improvement in automatic differentiation.
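A rough sketch of the 2-layer MLP case (names and shapes here are just illustrative): the parameters live in an ordinary nested dict, and grad returns gradients with the exact same structure.

    import jax
    import jax.numpy as jnp

    # Parameters for a 2-layer MLP, kept in a plain dict: this dict is a pytree.
    params = {
        "layer1": {"W": jnp.ones((3, 4)), "b": jnp.zeros(4)},
        "layer2": {"W": jnp.ones((4, 1)), "b": jnp.zeros(1)},
    }

    def mlp(params, x):
        h = jnp.tanh(x @ params["layer1"]["W"] + params["layer1"]["b"])
        return h @ params["layer2"]["W"] + params["layer2"]["b"]

    def loss(params, x, y):
        return jnp.mean((mlp(params, x) - y) ** 2)

    # grad differentiates straight through the nested dict; no concatenating
    # or splitting of weight matrices required.
    x = jnp.ones((8, 3))
    y = jnp.zeros((8, 1))
    grads = jax.grad(loss)(params, x, y)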
There is also the standalone library "tree" from DeepMind: https://github.com/deepmind/tree

It provides similar functionality without depending on JAX, TF, or anything else.
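For example, its map_structure and flatten work on arbitrary nested containers (a minimal sketch):

    import tree  # pip install dm-tree

    structure = {"a": [1, 2], "b": (3, 4)}

    # map_structure applies a function to every leaf, preserving the structure.
    doubled = tree.map_structure(lambda x: x * 2, structure)
    # {'a': [2, 4], 'b': (6, 8)}

    # flatten returns the leaves in a deterministic order.
    leaves = tree.flatten(structure)
    # [1, 2, 3, 4]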
JAX's use of pytrees is great! They've implemented a lot of useful utility functions, notably `tree_map`, that make working with these objects easy and intuitive. I recommend looking at their neural network example library "stax".
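As a small illustration (the `params` and `grads` values are made up), a single `tree_map` call can express an SGD update across an arbitrarily nested parameter structure:

    import jax
    import jax.numpy as jnp

    params = {"W": jnp.ones((2, 2)), "b": jnp.zeros(2)}
    grads = {"W": jnp.full((2, 2), 0.1), "b": jnp.full(2, 0.1)}

    # tree_map accepts multiple trees with matching structure and applies
    # the function leafwise, returning a tree of the same shape.
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)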
One curious thing I discovered a few months ago: you can sort of hack higher-order functions into JAX by defining "pytree closures" which introspect on normal closures and pull the JAX tracer data out of the closure environment (and put it back in when tracing is required). This works! You can pass these pytree closures in and out of JIT boundaries, etc.

I believe JAX has a utility for this somewhere; I can't quite remember what it's called.

I typically think of JAX as quite restrictive, but I think the reality is that the only real limit on expressivity is that you can't dynamically allocate inside of unbounded control flow (e.g. creating new allocations inside a while loop).
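Possibly the utility meant here is jax.tree_util.Partial, which bundles a function with bound arguments and registers the result as a pytree, so it can cross JIT boundaries like any other argument. A minimal sketch:

    import jax
    import jax.numpy as jnp
    from jax.tree_util import Partial

    def apply(f, x):
        return f(x) + 1.0

    # Partial's bound arguments become pytree leaves (traced as usual),
    # while the function itself lives in the static treedef.
    scaled = Partial(lambda scale, x: scale * x, jnp.float32(3.0))

    result = jax.jit(apply)(scaled, jnp.float32(2.0))  # 7.0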
Shameless advert: Equinox is a neural network library for JAX based entirely around pytrees:

https://github.com/patrick-kidger/equinox

(Now at 1.1k stars, so it's achieved some popularity!)

It makes model-building elegant (IMO), without any new abstractions to learn. Quite a PyTorch-like experience overall.
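For a flavor (a minimal sketch using Equinox's public API; the hyperparameters are arbitrary), the model itself is a pytree, so it passes straight into JAX transformations:

    import equinox as eqx
    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    # An eqx.Module is a pytree, so the whole model is a valid grad argument.
    model = eqx.nn.MLP(in_size=2, out_size=1, width_size=8, depth=2, key=key)

    @eqx.filter_grad  # like jax.grad, but skips non-array leaves (e.g. activations)
    def loss(model, x, y):
        pred = jax.vmap(model)(x)
        return jnp.mean((pred - y) ** 2)

    x = jnp.ones((16, 2))
    y = jnp.zeros((16, 1))
    grads = loss(model, x, y)  # same pytree structure as the model itself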