Penzai: JAX research toolkit for building, editing, and visualizing neural nets (github.com/google-deepmind)
261 points by mccoyb 12 days ago | 51 comments





I like JAX, and I find most of its core functionality as an "accelerated NumPy" great. But ecosystem fragmentation and interop difficulties make adopting JAX hard.

There's too much fragmentation within the JAX NN library space, and Penzai isn't helping with that. I wish everyone using JAX could agree on a single set of libraries for NN, optimization, and data loading.

PyTorch code can't be called from JAX, so extending and iterating on prior work, which is what most research consists of, requires a lot of reimplementation. Custom CUDA kernels are a bit fiddly too; I haven't been able to bring Gaussian Splatting to JAX yet.


I'm curious what interop difficulties you've run into in JAX? In my experience, the JAX ecosystem is quite modular and most JAX libraries work pretty well together. Penzai's core visualization tooling should work for most JAX NN libraries out of the box, and Penzai's neural net components are compatible with existing JAX optimization libraries (like Optax) and data loaders (like tfds/seqio or grain).

(Interop with PyTorch seems more difficult, of course!)


It's mostly an ecosystem thing, being unable to use existing methods. In my experience, research goes something like

1. Milestone paper introducing novel method is published with green-field implementation

2. Bunch of papers extend milestone paper with brown-field implementation

3. Goto 1

Most things in 1 are written in PyTorch, meaning 2 also has to be in PyTorch. I know this isn't JAX's fault, but I don't think JAX's philosophy of staying unopinionated and low-level is helping. It seems like the community agreeing on a single set of DL libraries around JAX would help it gain some momentum.


That's my experience as well. PyTorch dominates the ecosystem.

Which is a shame, because JAX's approach is superior.[a]

---

[a] In my experience, anytime I've had to do anything in PyTorch that isn't well supported out-of-the-box, I've quickly found myself tinkering with Triton, which usually becomes... very frustrating. Meanwhile, JAX offers decent parallelization of anything I write in plain Python, plus really nice primitives like jax.lax.while_loop, jax.lax.associative_scan, jax.lax.select, etc. And yet, I keep using PyTorch... because of the ecosystem.
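To make the primitives concrete, here's a minimal sketch of two of them: `associative_scan` for a parallel prefix sum, and `while_loop` for data-dependent iteration that still compiles under `jit`:

```python
import jax
import jax.numpy as jnp

# Parallel prefix sum via an associative operator.
xs = jnp.arange(1.0, 6.0)
cumsum = jax.lax.associative_scan(jnp.add, xs)   # [1., 3., 6., 10., 15.]

# Data-dependent loop: keep doubling until the value exceeds 100.
def double_until_100(x):
    return jax.lax.while_loop(lambda v: v < 100.0, lambda v: v * 2.0, x)

result = jax.jit(double_until_100)(1.0)          # 1 -> 2 -> ... -> 128
```

The scan runs in O(log n) depth on parallel hardware, and the loop's trip count can depend on runtime values, which plain `jit`-traced Python loops cannot.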


The best is not always popular. JAX's situation is a lot like the Erlang programming language's.

> The best is not always popular.

I agree. Network effects routinely overpower better technology.


Another issue I've personally faced is debugging, although I'm speaking from experience more than a year ago, and maybe things are better now. I've used it mostly for optimization, and the error messages aren't helpful.

I’ve only been reading through the docs for a few moments, but I’m pleasantly surprised to find that the authors are using effect handlers to handle effectful computations in ML models. I was in the process of translating a model from torch to Jax using Equinox; this makes me think Penzai could be a better choice.

I was just reading this too! I think it's a really interesting choice in the design space.

So to elucidate this a little bit, the trade-off is that this is now incompatible with e.g. `jax.grad` or `lax.scan`: you can't compose things in the order `discharge_effect(jax.grad(your_model_here))`, or put an effectful `lax.scan` inside your forward pass, etc. The effect-discharging process only knows how to handle traversing pytree structures. (And they do mention this at the end of their docs.)

This kind of thing was actually something I explicitly considered later on in Equinox, but in part decided against as I couldn't see a way to make that work either. The goal of Equinox was always absolute compatibility with arbitrary JAX code.

Now, none of that should be taken as a bash at Penzai! They've made a different set of trade-offs, and if the above incompatibility doesn't affect your goals then indeed their effect system is incredibly elegant, so certainly give it a try. (Seriously, it's been pretty cool to see the release of Penzai, which explicitly acknowledges how much it's inspired by Equinox.)


Author of Penzai here! In idiomatic Penzai usage, you should always discharge all effects before running your model. While it's true you can't do `discharge_effect(jax.grad(your_model_here))`, you can still do `jax.grad(discharge_effect(your_model_here))`, which is probably what you meant to do anyway in most cases. Once you've wrapped your model in a handler layer, it has a pure interface again, which makes it fully compatible with all arbitrary JAX transformations. The intended use of effects is as an internal helper to simplify plumbing of values into and out of layers, not as something that affects the top-level interface of using the model!
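To illustrate the composition order in plain JAX (note: `discharge_effect` and `effectful_model` below are hypothetical stand-ins for the handler pattern being described, not Penzai's actual API):

```python
import jax

def effectful_model(params, x, side_input):
    # Imagine `side_input` is normally plumbed in by an effect handler.
    return params * x + side_input

def discharge_effect(model_fn):
    # The handler supplies the effect's value, yielding a pure function again.
    def pure_fn(params, x):
        return model_fn(params, x, side_input=1.0)
    return pure_fn

pure_model = discharge_effect(effectful_model)

# Correct order: transform the pure, handled model.
grad_fn = jax.grad(lambda p, x: pure_model(p, x) ** 2)
g = grad_fn(2.0, 3.0)   # d/dp (p*x + 1)^2 = 2*(p*x + 1)*x = 42 at p=2, x=3
```

Because the handler is applied first, `jax.grad` only ever sees a pure function of arrays, which is what makes the wrapped model compatible with arbitrary JAX transformations.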

(As an example of this, the GemmaTransformer example model uses the SideInput effect internally to do attention masking. But it exposes a pure functional interface by using a handler internally, so you can call it anywhere you could call an Equinox model, and you shouldn't have to think about the effect system at all as a user of the model.)

It's not clear to me what the semantics of ordinary JAX transformations like `lax.scan` should be if the model has side effects. But if you don't have any effects in your model, or if you've explicitly handled them already, then it's perfectly fine to use `lax.scan`. This is similar to how it works in ordinary JAX; if you try to do a `lax.scan` over a function that mutates Python state, you'll probably hit an error or get something unexpected. But if you mutate Python state internally inside `lax.scan`, it works fine.
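The "handle effects first, then transform" rule is easy to see with plain `lax.scan`, which is well-defined over a pure step function; a minimal sketch:

```python
import jax
import jax.numpy as jnp

# A pure scanned step: the carry is a running sum, and each step also
# emits the current partial sum as a per-step output.
def step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry   # (next carry, output for this step)

final, partials = jax.lax.scan(step, 0.0, jnp.arange(1.0, 5.0))
# final == 10.0, partials == [1., 3., 6., 10.]
```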

I'll also note that adding support for higher-order layer combinators (like "layer scan") is something that's on the roadmap! The goal would be to support some of the fancier features of libraries like Flax when you need them, while still admitting a simple purely-functional mental model when you don't.


Thanks! This is one of the more experimental design choices I made in designing Penzai, but so far I've found it to be quite useful.

The effect system does come with a few sharp edges at the moment if you want to use JAX transformations inside the forward pass of your model (see my reply to Patrick), but I'm hoping to make it more flexible as time goes on. (Figuring out how effect systems should compose with function transformations is a bit nontrivial!)

Please let me know if you run into any issues using Penzai for your model! (Also, most of Penzai's visualization and patching utilities should work with Equinox too, so you shouldn't necessarily need to fully commit to either one.)


This is something I’ve thought about in the past, since I messed around with trying to add monads to JAX. I think you made the right call with effect handlers. You might want to take a look at what Koka does; that was the best implementation of effect handlers the last time I checked.

I remember PyTorch has some pytree capability, no? So is it safe to say that any pytree-compatible modules here are already compatible with PyTorch?

Author here! I didn't know PyTorch had its own pytree system. It looks like it's separate from JAX's pytree registry, though, so Penzai's tooling probably won't work with PyTorch out of the box.

I implemented Jax’s pytrees in pure python. You can use it with whatever you want. https://github.com/shawwn/pytreez

The readme is a todo, but the tests are complete. They’re the same that Jax itself uses, but zero dependencies. https://github.com/shawwn/pytreez/blob/master/tests/test_pyt...

The concept is simple. The hard part is cross-pollination. Suppose you wanted to literally use Jax pytrees with PyTorch. Now you’ll have to import Jax, or my library, and register your modules with it. But anything else that ever uses pytrees needs to use the same pytree library, because the registry (the thing that keeps track of pytree-compatible classes) lives in the library you choose. They don’t share registries.

A better way of phrasing it is that if you use a jax-style pytree interface, it should work with any other pytree library. But to my knowledge, the only pytree library besides Jax itself is mine here, and only I use it. So when you ask if pytree-compatible modules are compatible with PyTorch, it’s equivalent to asking whether PyTorch projects use jax, and the answer tends to be no.

EDIT: perhaps I’m outdated. OP says that PyTorch has pytree functionality now. https://news.ycombinator.com/item?id=40109662 I guess yet again I was ahead of the times by a couple years; happy to see other ecosystems catch up. Hopefully seeing a simple implementation will clarify the tradeoffs.

The best approach for a universal pytree library would be to assume that any class with tree_flatten and tree_unflatten methods is pytreeable, and not require those classes to be explicitly registered. That way you don’t have to worry about whether you’re using Jax or PyTorch pytrees.

But I gave up trying to make library-agnostic ML modules; in practice it’s better just to choose Jax or PyTorch and be done with it, since making PyTorch modules run in Jax automatically (and vice versa) is a fool’s errand (I was the fool, and it was an errand) for many reasons, not the least of which is that Jax builds an explicit computation graph via jax.jit, a feature PyTorch has only recently (and reluctantly) embraced.

But of course, that means if you pick the wrong ecosystem, you’ll miss out on the best tools. Hello React vs Vue, or Unreal Engine vs Unity, or dozens of other examples.
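As a concrete sketch of the registration friction being described here (the `Params` class is hypothetical, but the registration call is JAX's real API):

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node

# A container that follows the tree_flatten/tree_unflatten convention.
class Params:
    def __init__(self, w, b):
        self.w, self.b = w, b

    def tree_flatten(self):
        return (self.w, self.b), None        # (children, static aux data)

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

# Even though the class exposes the convention, JAX only recognizes it
# after explicit registration in *its* registry; another pytree library
# would need its own separate registration.
register_pytree_node(Params, Params.tree_flatten, Params.tree_unflatten)

p = Params(jnp.ones(3), jnp.zeros(3))
doubled = jax.tree_util.tree_map(lambda a: 2 * a, p)   # maps over leaves w, b
```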


There are a couple more such libraries. One was inside tensorflow (nest) and then extracted into the standalone dm-tree: https://github.com/deepmind/tree

Or also: https://github.com/metaopt/optree

I think ideally you would try to use mostly standard types (dict, list, tuple, etc) which are supported by all those libraries in mostly the same way, so it's easy to switch.

You have to be careful with some of the small differences, though. E.g. which basic types are supported (e.g. dataclass, namedtuple, other instances derived from dict, tuple, etc.), or how None is handled.
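The None case is a good example of such a difference; in JAX specifically, None is treated as an empty subtree rather than a leaf:

```python
import jax

# JAX treats None as an empty subtree, so tree_map never applies the
# function to it; other pytree libraries may treat None as a leaf instead.
tree = {"a": [1, 2], "b": None}
out = jax.tree_util.tree_map(lambda x: x * 10, tree)
# -> {"a": [10, 20], "b": None}
```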


Does anyone know if and how well Penzai can work with Diffrax [1]? I currently use Diffrax + Equinox for scientific machine learning. Penzai looks like an interesting alternative to Equinox.

[1]: https://docs.kidger.site/diffrax/


Not sure about the specific combination, but since everything in Jax is functionally pure, it's generally really easy to compose libraries. E.g. I've written code that embedded a Flax model inside a Haiku model without much effort.

IIUC, Penzai is (deliberately) sacrificing support for higher-order operations like `lax.{while_loop, scan, cond}` or `diffrax.diffeqsolve`, in return for some of the other new features it is trying out (treescope, effects).

So it's slightly more framework-y than Equinox and will not be completely compatible with arbitrary JAX code. However, I've already had a collaborator demonstrate that, as long as you don't use any higher-order operations, treescope will actually work out-of-the-box with Equinox modules!

So I think the answer to your question is "sort of":

* As long as you only try to inspect things that are happening outside of your `diffrax.diffeqsolve` then you should be good to go. And moreover can probably do this simply by using e.g. Penzai's treescope directly alongside your existing Equinox code, without needing to move things over wholesale.

* But anything inside probably isn't supported + if I understand their setup correctly can never be supported. (Not bashing Penzai there, which I think genuinely looks excellent -- I think it's just fundamentally tricky at a technical level.)


Author of Penzai here. I think the answer is a bit more nuanced (and closer to "yes") than this:

- If you want to use the treescope pretty-printer or the pz.select tree manipulation utility, those should work out-of-the-box with both Equinox and Diffrax. Penzai's utilities are designed to be as modular as possible (we explicitly try not to be "frameworky") so they support arbitrary JAX pytrees; if you run into any problems with this please file an issue!

- If you want to call a Penzai model inside `diffrax.diffeqsolve`, that should also be fully supported out of the box. Penzai models expose a pure functional interface when called, so you should be able to call a Penzai model anywhere that you'd call an Equinox model. From the perspective of the model user, you should be able to think of the effect system as an implementation detail. Again, if you run into problems here, please file an issue.

- If you want to write your own Penzai layer that uses `diffrax.diffeqsolve` internally, that should also work. You can put arbitrary logic inside a Penzai layer as long as it's pure.

- The specific thing that is not currently fully supported is: (1) defining a higher-order Penzai combinator layer that uses `diffrax.diffeqsolve` internally, (2) and having that layer run one of its sublayers inside the `diffrax.diffeqsolve` function, (3) while simultaneously having that internal sublayer use an effect (like random numbers, state, or parameter sharing), (4) where the handler for that effect is placed outside of the combinator layer. This is because the temporary effect implementation node that gets inserted while a handler is running isn't a JAX array type, so you'll get a JAX error when you try to pass it through a function transformation.

This last case is something I'd like to support as well, but I still need to figure out what the semantics of it should be. (E.g. what does it even mean to solve a differential equation that has a local state variable in it?) I think having side effects inside a transformed function is fundamentally hard to get right!


I have a small YT channel that teaches JAX bit-by-bit, check it out! https://www.youtube.com/@TwoMinuteJAX

Looks great, but outside Google I do not personally know anyone who uses Jax, and I work in this space.

Not at Google, but currently using Jax to leverage TPUs, because AWS's GPU pricing is eye-gougingly expensive. For the lower-end A10 GPUs, the price-per-GPU for a 4 GPU machine is 1.5x the price for a 1 GPU machine, and the price-per-GPU for an 8 GPU machine is 2x the price of a 1 GPU machine! If you want an A100 or H100, the only option is renting an 8 GPU instance. With properly TPU-optimised code you get something like 30-50% cost savings on GCP TPUs compared to AWS (and I say that as someone who otherwise doesn't like Google as a company and would prefer to avoid GCP if there wasn't such a significant cost advantage).

I use it for GPU accelerated signal processing. It really delivers on the promise of "Numpy but for GPU" better than all competing libraries out there.

We've built our startup from scratch on JAX, selling text-to-image model finetuning, and it's given us a consistent edge not only in terms of pure performance but also in terms of "dollars per unit of work".

Is that gain from TPU usage or something else?

Mostly from the tight JAX-TPU integration, yeah.

Isn't JAX the most widely used framework in the GenAI space? Most companies there use it -- Cohere, Anthropic, CharacterAI, xAI, Midjourney etc.

Most of the GenAI players use both PyTorch and JAX, depending on the hardware they're running on. Character, Anthro, Midjourney, etc. are dual shops (they use both). xAI only uses JAX afaik.

just guessing that tech leadership at all of those traces back to Google somehow

Jax trends on papers with code:

https://paperswithcode.com/trends


Was gonna ask "What's that MindSpore thing that seems to be taking the research world by storm?" but I Googled and it's apparently Huawei's open-source AI framework. 1% to 7% market share in 2 years is nothing to sneeze at; that's a growth rate similar to Chrome's or Facebook's in their heyday.

It's telling that Huawei-backed MindSpore can go from 1% to 7% in 2 years, while Google-backed Jax is stuck at 2-3%. Contrary to popular narrative in the Western world, Chinese dominance is alive and well.


>It's telling that Huawei-backed MindSpore can go from 1% to 7% in 2 years, while Google-backed Jax is stuck at 2-3%. Contrary to popular narrative in the Western world, Chinese dominance is alive and well.

MindSpore has an advantage there because of its integrated support for Huawei's Ascend 910B, the only Chinese GPU that comes close to matching the A100. Given the US banned export of A100 and H100s to China, this creates artificial demand for the Ascend 910B chips and the MindSpore framework that utilises them.


No, MindSpore is rising because of the chip embargo.

No one is going to use stuff whose supply could be cut off one day.

This is one signal of why Nvidia listed Huawei as a competitor in 4 out of 5 categories in its earnings.


Its meteoric rise started well before the chip embargo. I've looked into it: it liberally borrows ideas from other frameworks, both PyTorch and Jax, and adds some of its own. You lose some of the conceptual purity, but it makes up for it in practical usability, assuming it works as it says on the tin, which it may or may not.

PyTorch also has support for Ascend as far as I can tell (https://github.com/Ascend/pytorch), so that support does not necessarily explain MindSpore's relative success. Why MindSpore is rising so rapidly is not entirely clear to me. It could be something as simple as preferring a domestic alternative that is adequate to the task and has better documentation in Chinese. It could be the cost of compute. It could be both. Nowadays, however, I do agree that the various embargoes would help it (as well as Huawei) a great deal.

As a side note, I wish Huawei could export its silicon to the West. I bet that'd result in dramatically cheaper compute.

This data might just be unreliable. It had a weird spike in Dec 2021 that looks unusual compared to all the other frameworks.

China publishes a looooootttttt of papers. A lot of it is careerist crap.

To be fair, a lot of US papers are also crap, but Chinese crap research is on another level. There's a reason a lot of top US researchers are Chinese - there's brain drain going on.


When I looked into a random sampling of these uses, my impression was that it was a common kind of project in China to take a common paper (or another repo) and implement it in Mindspore. That accounted for the vast majority of the implementations.

Note that most of Jax’s minuscule share is Google.

I’m in academia and I use jax because it’s closest to translate maths to code.

Same, Jax is extremely popular with the applied math/modeling crowd.

I use it all the time, and there's also a few classes at my uni that use Jax. It's really great for experimentation and research, you can do a lot of things in Jax you just can't in, say, PyTorch.

Like what?

Anytime you want to make something GPU-accelerated that doesn't fit as standard operations on tensors. For example, I often write RL environments in Jax, which is something you can't do in PyTorch. There are also things you can do in PyTorch but that would be far more difficult, for example an efficient implementation of MCTS.

I also used Jax a lot for differential equations, not even sure how I would do that with PyTorch.

Basically, Torch is more like a specialization of NumPy for neural networks, while Jax feels more like being able to write CUDA as Python, and also getting the Jacobians (jacs! jax!) and JVPs for free (of everything; you can even differentiate through your optimizer's hyperparameters, which is crazy).

In the end, when you're doing fundamental research and coming up with something new, I think Jax is just better. If all I had to do was implementation, I'd be a happy PyTorch user.
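The "Jacobians and JVPs for free" point can be made concrete with a small sketch using `jax.jacfwd` and `jax.jvp`:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0] ** 2 + x[1], jnp.sin(x[1])])

x = jnp.array([1.0, 0.0])
J = jax.jacfwd(f)(x)                 # full Jacobian, shape (2, 2)

v = jnp.array([1.0, 0.0])
y, tangent = jax.jvp(f, (x,), (v,))  # directional derivative J @ v,
                                     # computed without materializing J
```

`jvp` is the building block: it evaluates `f(x)` and the Jacobian-vector product in one forward pass, and `jacfwd` just maps it over basis vectors.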


A small addendum: the only people I know who uses Jax are people who work at Google, or people who had a big GCP grant and needed to use TPUs as a result.

That's cool -- but wouldn't it be more constructive to discuss "the ideas" in this package anyways?

For instance, it would be interesting to discern if the design of PyTorch (and their modules) preclude or admit the same sort of visualization tooling? If you have expertise in PyTorch, perhaps you could help answer this sort of question?

JAX's Pytrees are like "immutable structs, with array leaves" -- does PyTorch have a similar concept?


> does PyTorch have a similar concept

of course https://github.com/pytorch/pytorch/blob/main/torch/utils/_py...


Idk if you need that immutability actually. You could probably reconstruct enough to do this kind of viz from the autograd graph, or capture the graph and intermediates in the forward pass using hooks. My hunch is it should be doable.

If JAX had affine_grid() and grid_sample(), I'd be using it instead of PyTorch for my current project.

It would be great if we could have intelligent tools for building neural networks in PyTorch.

Would a comprehensive object-construction platform with schema support and the ability to hook up to a compiler (i.e. turn object data into code, for instance) be a useful tool in this domain?

ex: https://www.youtube.com/watch?v=fPnD6I9w84c

I am the developer, happy to answer questions.



