aaron brooks

I Was Wrong to Sleep on JAX

· jax, pytorch, training

I was slow to give JAX a fair shot.

Some of that was honest inertia. I had spent a lot of time in PyTorch, lived through the usual distributed training bruises, and built enough around FSDP that switching mental models sounded like a tax I did not need to pay.

Some of it was pride wearing a practical hat.

I am glad I got over it.

PyTorch still feels like home

PyTorch is where I learned to move fast with models. It is direct. It debugs well enough. The ecosystem is enormous. When something breaks at the edges, there is usually a GitHub issue, a forum thread, or somebody nearby who has already stepped on the same rake.

That matters in production. Familiar tools are not a weakness.

But familiarity can turn into a fence. You start judging other systems by how quickly they let you do the PyTorch-shaped thing you already had in mind.

That is not a fair test.

JAX makes the shape of computation harder to ignore

The thing I came to appreciate is that JAX forces more of the computation model into the open.

jit, vmap, pmap, pjit, and arrays explicitly sharded across devices can feel fussy at first. Then the fussy parts start paying rent. You are pushed to think about pure functions, shapes, compilation boundaries, and how data moves between devices.
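Here is a minimal sketch of what that looks like, assuming a recent JAX install. The function `predict` and the shapes are made up for illustration. It shows a pure function, `jax.vmap` mapping it over a batch axis, and `jax.jit` compiling the result.

```python
import jax
import jax.numpy as jnp

# A pure function: the output depends only on the arguments, with no hidden state.
# (Hypothetical single-example "model" for illustration.)
def predict(w, b, x):
    return jnp.tanh(x @ w + b)

# vmap maps predict over the leading axis of x, sharing w and b across the batch.
batched_predict = jax.vmap(predict, in_axes=(None, None, 0))

# jit traces the batched function once per input shape/dtype and compiles it with XLA.
fast_predict = jax.jit(batched_predict)

w = jnp.ones((3, 2))
b = jnp.zeros(2)
xs = jnp.ones((4, 3))  # a batch of 4 examples, 3 features each

out = fast_predict(w, b, xs)
print(out.shape)  # (4, 2)
```

Nothing here is clever. The point is that batching and compilation are explicit function transformations instead of behavior buried inside a module.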

That is not free. Compilation surprises are real. Error messages can send you walking around the block. The ecosystem does not always hand you the same batteries PyTorch does.
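One common compilation surprise can be sketched under the same assumptions. `jax.jit` caches compiled code per input shape and dtype, so a call with a new shape triggers a fresh trace and compile. The `trace_count` counter and the `normalize` function are hypothetical debugging aids. The counter increments only while JAX traces the Python function and never on cached calls.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def normalize(x):
    global trace_count
    # Python side effect: runs during tracing only, not when a cached compile is reused.
    trace_count += 1
    return (x - x.mean()) / (x.std() + 1e-6)

normalize(jnp.ones(8))   # first call: trace and compile for shape (8,)
normalize(jnp.ones(8))   # same shape and dtype: reuses the cached compile
normalize(jnp.ones(16))  # new shape (16,): traced and compiled again
print(trace_count)  # 2
```

In a training loop, variable-length batches can quietly turn into a recompile per distinct shape. Padding to a fixed set of sizes is one common way around that.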

Still, I like tools that make hidden costs visible. JAX does that.
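As a concrete sense of what "visible" can mean, here is a small sketch with a hypothetical `layer` function. `jax.make_jaxpr` prints the traced computation, and `jax.jit(...).lower(...)` shows the program handed to the compiler for one specific set of input shapes.

```python
import jax
import jax.numpy as jnp

# Hypothetical one-layer computation used only for inspection.
def layer(w, x):
    return jnp.tanh(x @ w)

w = jnp.ones((3, 2))
x = jnp.ones((4, 3))

# The jaxpr: JAX's intermediate representation of what was traced.
print(jax.make_jaxpr(layer)(w, x))

# The lowered program jit would compile for these exact shapes and dtypes.
lowered = jax.jit(layer).lower(w, x)
print(lowered.as_text()[:400])  # first part of the text form
```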

The distributed story changed my mind

The part that made me stop being stubborn was distributed training.

With PyTorch and FSDP, I often feel like I am managing a powerful machine with a lot of knobs hidden behind curtains. It can work very well. It can also produce failure modes where you end up spelunking through wrapping policy, parameter sharding, optimizer state, and checkpoint behavior.

JAX did not make distributed training easy. Nothing worth running at scale is easy. But it made some of the layout and partitioning questions feel more explicit. More inspectable. Less like I was negotiating with a pile of implicit state.
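Here is a rough sketch of the explicitness I mean, assuming a recent JAX version with the `jax.sharding` module. It builds a one-axis device mesh named `"data"` (an arbitrary name), shards an array's rows across it, and then inspects where each shard lives. On a single CPU, the mesh has one device and the whole array is one shard, but the code stays the same. `col_sums` is a hypothetical stand-in for real work.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over every visible device, with its single axis named "data".
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Split rows across the "data" axis and keep columns whole on each device.
sharding = NamedSharding(mesh, P("data", None))

n = len(devices)
x = jax.device_put(
    jnp.arange(n * 2 * 4, dtype=jnp.float32).reshape(n * 2, 4), sharding
)

# The layout is something you can print, not something you infer.
print(x.sharding)
for shard in x.addressable_shards:
    print(shard.device, shard.index, shard.data.shape)  # each shard holds 2 rows

@jax.jit
def col_sums(x):
    # The compiler handles the cross-device reduction implied by the sharding.
    return x.sum(axis=0)

print(col_sums(x))
```

The arithmetic is trivial. What matters is that `x.sharding` and `addressable_shards` let you ask directly how the data is laid out.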

That is worth something.

The rule I took from it

I am not interested in framework religion. That stuff is not worth a hill of beans when the run is failing and the GPUs are billing.

PyTorch is still a great tool. JAX is not magic. Both can cut you.

The lesson for me was simpler: if a tool makes the real constraints clearer, it deserves a fair trial. Even if it makes you slower for a week. Even if it pokes at your habits.

Sometimes the thing that feels like friction is the tool showing you where the system was already complicated.