
Interesting, thanks for your input. I already tried to adhere to the JAX paradigm as laid out in the documentation, so I already have a fully static graph.

I would test how much of the hardware's total FLOP capability you are actually using. Take the first-order terms of your model and estimate how many FLOPs you need per data point (a good guide for training is 6 × the parameter count, if the model is mostly large matrix multiplies plus nonlinearity/norm layers). Then compare the real-time throughput you measure for a given input size against the theoretical peak of your GPU (e.g. roughly 1e15 FLOP/s in bfloat16 for an H100 or H200). If you are already above 50% utilization, big gains are unlikely without very considerable effort, and plain JAX or PyTorch is most likely not sufficient at that point. If you are in the 2–20% range, there is probably some low-hanging fruit left, and the closer you are to using only 1%, the easier it is to see dramatic gains.
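A rough sketch of that back-of-the-envelope check, assuming the 6 × params rule of thumb above and a step function you can call repeatedly. The names `measure_items_per_second` and `estimate_mfu` are hypothetical helpers, and the example numbers (1e9 parameters, 20,000 tokens/s, 1e15 FLOP/s peak) are made up for illustration, not measurements:

```python
import time
from typing import Callable


def measure_items_per_second(step: Callable[[], object], items_per_step: int,
                             n_steps: int = 20, warmup: int = 3) -> float:
    """Time `step` and return processed items (e.g. tokens) per second.

    `step` must block until its work is actually finished. With JAX's
    asynchronous dispatch that means calling `.block_until_ready()` on the
    result inside `step`; otherwise you only time the dispatch.
    """
    for _ in range(warmup):  # exclude compilation / first-call overhead
        step()
    start = time.perf_counter()
    for _ in range(n_steps):
        step()
    elapsed = time.perf_counter() - start
    return items_per_step * n_steps / elapsed


def estimate_mfu(n_params: float, items_per_second: float,
                 peak_flops_per_second: float,
                 flops_per_param_per_item: float = 6.0) -> float:
    """Fraction of the hardware's peak FLOP/s used by the model's math.

    Assumes ~6 * params FLOPs per training item, a first-order estimate
    that holds best when the model is dominated by large matrix multiplies.
    """
    achieved = flops_per_param_per_item * n_params * items_per_second
    return achieved / peak_flops_per_second


if __name__ == "__main__":
    # Hypothetical numbers: a 1e9-parameter model training at 20,000 tokens/s
    # on a GPU with a ~1e15 FLOP/s bfloat16 peak.
    mfu = estimate_mfu(n_params=1e9, items_per_second=2e4,
                       peak_flops_per_second=1e15)
    print(f"MFU: {mfu:.1%}")  # 12.0%, i.e. inside the 2-20% band
```

By the rule of thumb above, a result like 12% would suggest there is still low-hanging fruit, whereas anything above 50% would point to diminishing returns.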

