TechEcho
Pytrees

132 points by f_devd almost 2 years ago

5 comments

time_to_smile almost 2 years ago
For those curious what the big deal is here: PyTrees make it wildly easier to take derivatives with respect to parameters that have a complex structure. This makes it much easier to organize code for non-trivial models.

As an example: if you want to implement logistic regression in JAX, you need to optimize the weights. This is easy enough, since the weights can be modeled as a single value, a matrix of weights. If you want to model a 2-layer MLP, you now have to use (at least) 2 weight matrices. You could treat them as two parameters to your function (which makes the derivative more complicated to manage), or you could concatenate the weights and split them up, etc. Annoying, but manageable.

When you get to something like a diffusion model, you need to manage parameters for a variety of different, quite complex, models. It really helps if you can keep track of all these parameters in whatever data structure you like, but also trivially call "grad" with respect to them and get your model's derivative with respect to its parameters.

Pytrees make this incredibly simple, and are a major quality-of-life improvement in automatic differentiation.
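To make the 2-layer MLP case from the comment concrete, here is a minimal sketch (not the commenter's code) that keeps all parameters in one nested dict and calls `jax.grad` on the whole structure at once. The names `init_params`, `mlp`, `loss`, and `sgd_step`, the layer sizes, the learning rate, and the synthetic data are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def init_params(key, n_in, n_hidden, n_out):
    # Parameters live in an ordinary nested dict; that dict is the pytree.
    k1, k2 = jax.random.split(key)
    return {
        "layer1": {"w": 0.1 * jax.random.normal(k1, (n_in, n_hidden)), "b": jnp.zeros(n_hidden)},
        "layer2": {"w": 0.1 * jax.random.normal(k2, (n_hidden, n_out)), "b": jnp.zeros(n_out)},
    }


def mlp(params, x):
    h = jnp.tanh(x @ params["layer1"]["w"] + params["layer1"]["b"])
    return h @ params["layer2"]["w"] + params["layer2"]["b"]


def loss(params, x, y):
    return jnp.mean((mlp(params, x) - y) ** 2)


def sgd_step(params, x, y, lr=0.1):
    # grad w.r.t. the first argument returns a dict nested exactly like params.
    grads = jax.grad(loss)(params, x, y)
    # Update every leaf in one call, with no manual concatenating or splitting.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


kp, kx = jax.random.split(jax.random.PRNGKey(0))
params = init_params(kp, 3, 8, 1)
x = jax.random.normal(kx, (32, 3))
y = jnp.sum(x, axis=1, keepdims=True)  # synthetic regression target

before = loss(params, x, y)
for _ in range(20):
    params = sgd_step(params, x, y)
after = loss(params, x, y)
print(float(before), float(after))
```

The same `sgd_step` would keep working if you added a third layer or swapped the dict for a list of tuples, because `jax.grad` and `tree_map` only require that the parameters form a pytree.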
albertzeyer almost 2 years ago
There is also the standalone library "tree" from DeepMind: https://github.com/deepmind/tree

It provides similar functionality but is standalone and does not depend on JAX, TF, or anything else.
iNic almost 2 years ago
JAX's use of pytrees is great! They implemented a lot of useful utility functions, notably `tree_map`, that make working with these objects easy and intuitive. I recommend looking at their example neural network library "stax".
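As a sketch of both points, assuming a recent JAX where `stax` lives in `jax.example_libraries.stax`: `stax.serial` returns parameters as a plain pytree (a list with one tuple per layer), so `jax.tree_util.tree_map` and `tree_leaves` work on it directly. The layer sizes and input shape are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# Each stax layer is an (init_fn, apply_fn) pair; serial composes them.
init_fn, apply_fn = stax.serial(stax.Dense(8), stax.Tanh, stax.Dense(1))

# -1 stands for an unspecified batch dimension.
out_shape, params = init_fn(jax.random.PRNGKey(0), (-1, 3))

# params is a list with one entry per layer: (W, b) for Dense, () for Tanh.
shapes = jax.tree_util.tree_map(lambda p: p.shape, params)
n_params = sum(p.size for p in jax.tree_util.tree_leaves(params))

# tree_map also turns whole-model transformations into one-liners,
# e.g. halving every parameter at once.
halved = jax.tree_util.tree_map(lambda p: 0.5 * p, params)

y = apply_fn(params, jnp.ones((4, 3)))
print(out_shape, shapes, n_params, y.shape)
```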
mccoyb almost 2 years ago
One curious thing I discovered a few months ago: you can sort of hack higher-order functions into JAX by defining “Pytree closures”, which introspect on normal closures and pull the JAX tracer data out of the closure environment (putting it back in when tracing is required). And this works! You can pass these Pytree closures in and out of JIT boundaries, etc.

I believe JAX has a utility for this somewhere, but I can’t quite remember what it’s called.

I typically think of JAX as quite restrictive, but I think the reality is that the only real limit on expressivity is that you can’t dynamically allocate inside unbounded control flow (e.g. creating new allocations inside a while loop).
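Here is a minimal sketch of the "Pytree closure" idea, not the commenter's actual code. A hypothetical `PytreeClosure` wrapper's flatten step pulls the captured values out of a normal Python closure's `__closure__` cells, and its unflatten step rebuilds the function with fresh cells, which may hold tracers. It assumes every captured variable is a valid JAX leaf (e.g. an array) and that the wrapped function has no keyword-only defaults. `make_scaler` and `apply` are illustrative.

```python
import types

import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class PytreeClosure:
    """Hypothetical wrapper exposing a closure's captured values as pytree leaves.

    Assumes every free variable captured by `fn` holds a JAX-compatible value.
    """

    def __init__(self, fn, template=None):
        self.fn = fn
        # The original function is kept as static (hashable) metadata so that
        # rebuilt closures share the same tree structure.
        self.template = fn if template is None else template

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    def tree_flatten(self):
        # Pull the captured values out of the closure cells.
        cells = self.fn.__closure__ or ()
        return tuple(cell.cell_contents for cell in cells), self.template

    @classmethod
    def tree_unflatten(cls, template, children):
        # Put (possibly traced) values back in by building fresh cells.
        cells = tuple(types.CellType(value) for value in children)
        fn = types.FunctionType(
            template.__code__,
            template.__globals__,
            template.__name__,
            template.__defaults__,
            cells or None,
        )
        return cls(fn, template)


def make_scaler(scale):
    def scaler(x):
        return scale * x
    return scaler


@jax.jit
def apply(closure, x):
    # Inside jit, `closure` was rebuilt with a tracer in place of `scale`.
    return closure(x)


c = PytreeClosure(make_scaler(jnp.array(3.0)))
print(apply(c, jnp.array(2.0)))  # → 6.0

# The captured value is also differentiable, like any other pytree leaf.
g = jax.grad(lambda clo: clo(2.0))(c)
print(g.fn.__closure__[0].cell_contents)  # → 2.0
```

Keeping the original function as the static part of the tree means every rebuilt copy has the same tree structure, so repeated `jit` calls with the same closure can reuse the compiled trace.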
patrickkidger almost 2 years ago
Shameless advert -- Equinox is a neural network library for JAX based entirely around pytrees:

https://github.com/patrick-kidger/equinox

(Now on 1.1k stars, so it's achieved some popularity!)

This makes model-building elegant (IMO), without any new abstractions to learn. Quite a PyTorch-like experience overall.