JAX might be faster than PyTorch, I don’t know; I’m talking about TF. When I switched from TF to PyTorch three years ago, I saw no slowdown on any of the computer vision models I was using at the time. And I remember looking at a couple of independent benchmarks that also showed the two to be roughly the same speed.
It really depends on what your model is doing. For a time, sequence models were easier to do in PyTorch than in TF (due to control flow). On the efficiency side, for vanilla CV models I also did not observe major differences last time I looked. But once I started doing lots of things in parallel (multi-GPU training, heavy data augmentation), I found TF has some well-engineered capabilities that aren’t matched yet.
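To be concrete, the sort of thing I have in mind on the input side is tf.data’s parallel pipeline. A minimal sketch with dummy data (the augment function is just a stand-in, the pipeline plumbing is the real API):

    import tensorflow as tf

    # Dummy data purely for illustration.
    images = tf.random.uniform([1024, 32, 32, 3])
    labels = tf.random.uniform([1024], maxval=10, dtype=tf.int32)

    def augment(image, label):
        # Stand-in augmentation; imagine something much heavier here.
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, max_delta=0.1)
        return image, label

    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    ds = (ds.shuffle(10_000)
            .map(augment, num_parallel_calls=tf.data.AUTOTUNE)  # parallel CPU augmentation
            .batch(256)
            .prefetch(tf.data.AUTOTUNE))  # overlap the input pipeline with GPU compute

That map/prefetch combination keeping the GPUs fed while the CPU does augmentation is the part I didn’t find an equally polished equivalent for at the time.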
> The #1 problem with PyTorch is that it’s great if you want to use one videocard for training
Incorrect information, stated so confidently. Tons of research papers use more than one GPU for training; not sure what you’re referring to. Standard DDP works fine, for starters.
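For anyone who hasn’t tried it, the boilerplate really is small. A rough sketch with a toy model and fake batch, assuming you launch with `torchrun --nproc_per_node=N train.py` (which sets LOCAL_RANK for you):

    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(128, 10).cuda(local_rank)  # stand-in for a real model
    model = DDP(model, device_ids=[local_rank])
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    x = torch.randn(32, 128).cuda(local_rank)  # stand-in batch
    y = torch.randint(0, 10, (32,)).cuda(local_rank)

    opt.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss.backward()  # DDP all-reduces gradients across GPUs here
    opt.step()

    dist.destroy_process_group()

Each process owns one GPU, and the gradient sync happens inside backward(). This has been the standard single-node multi-GPU recipe for years.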
> how much better it is when you have actual control over which parts of your program are JITed
Can you elaborate? What’s the advantage of controlling which parts are jitted?
The #1 problem with PyTorch is that it’s great if you want to use one videocard for training. Facebook has completely failed to support research scientists that want to do more than this.
It’s no secret that I’m a JAX fanboy. But I drink the Kool-Aid because it tastes better than anyone else’s. PyTorch is gonna have a rude wake-up call in about… oh, four years. They’ll wake up and hear everyone else comparing them to TensorFlow, and it won’t be for the rosy reasons they currently enjoy. PyTorch devs are living in the dark ages without even realizing how much better it is when you have actual control over which parts of your program are JITed, along with an actual execution graph that you can walk and macroexpand Lisp-style.
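To make that concrete, a toy sketch: you pick the compilation boundary yourself, and the traced program is ordinary data you can inspect:

    import jax
    import jax.numpy as jnp

    def step(x):
        return jnp.sum(jnp.tanh(x @ x.T))

    # You choose what gets compiled: wrap just this function and leave
    # everything around it eager (handy for debugging, or for keeping
    # dynamic control flow out of the compiled region).
    fast_step = jax.jit(step)

    x = jnp.ones((4, 4))
    print(fast_step(x))  # traced and compiled on first call, cached after

    # The trace itself is a first-class data structure you can walk:
    jaxpr = jax.make_jaxpr(step)(x)
    for eqn in jaxpr.jaxpr.eqns:
        print(eqn.primitive)  # transpose, dot_general, tanh, reduce_sum

That jaxpr is the same structure autodidax builds up from scratch, which is why I keep pointing people at it.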
https://jax.readthedocs.io/en/latest/autodidax.html should be required reading for every ML dev, and I can hardly get anyone to look at it. Sometimes I wonder if people just don’t see the steamroller coming for PyTorch. Probably; JAX still reads to outsiders as a toy.