fxtentacle
Isn’t fusing ops at a fine-grained level also the core benefit of JAX over TensorFlow? How does this work compare to JAX?

zhihaojia
JAX's operator fusion (https://apxml.com/courses/advanced-jax/chapter-2-optimizing-...) can fuse a few local operators (e.g., a matmul and the elementwise computation that follows it) into a single kernel. But JAX's approach cannot fuse an entire LLM with hundreds of operators into a single kernel, because fusing across many of those operators would require loop transformations.
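A minimal JAX sketch of this kind of local fusion (the function and variable names here are illustrative, and the shapes arbitrary): under jax.jit, XLA will typically fuse the elementwise bias-add and ReLU into the surrounding matmul computation, but each such fused region still compiles to its own kernel, so a full model remains many kernel launches.

```python
import jax
import jax.numpy as jnp

# XLA can typically fuse the bias add and ReLU into the matmul's
# epilogue, so this jit-compiled function runs as roughly one kernel.
# A stack of many such layers, however, still compiles to many kernels.
@jax.jit
def fused_layer(x, w, b):
    return jax.nn.relu(x @ w + b)

x = jnp.ones((8, 16))
w = jnp.ones((16, 32))
b = jnp.zeros((32,))
y = fused_layer(x, w, b)  # fused matmul + bias + ReLU
```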

MPK takes a different approach: instead of incrementally fusing local operators, it decomposes the operators into a fine-grained task graph and builds a runtime system within a single kernel that executes all tasks specified in the graph.
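For intuition, here is a plain-Python sketch of that execution model (Task, run_graph, and the task names are hypothetical illustrations, not MPK's actual API): each task keeps a counter of unfinished dependencies, and the scheduler loop executes tasks as their counters reach zero.

```python
from collections import deque

class Task:
    """One fine-grained unit of work decomposed from an operator."""
    def __init__(self, name, fn, deps=()):
        self.name = name
        self.fn = fn                   # computation this task performs
        self.deps = list(deps)         # tasks that must finish first
        self.dependents = []           # filled in by run_graph
        self.remaining = len(self.deps)

def run_graph(tasks):
    """Execute a task graph: run each task once all its deps are done."""
    for t in tasks:
        for d in t.deps:
            d.dependents.append(t)
    ready = deque(t for t in tasks if t.remaining == 0)
    while ready:
        t = ready.popleft()
        t.fn()
        for dep in t.dependents:
            dep.remaining -= 1         # one dependency satisfied
            if dep.remaining == 0:
                ready.append(dep)      # all deps done; schedule it

# Tiny example: two matmul tiles that can run in parallel, then a reduce.
a = Task("load", lambda: print("load"))
b = Task("matmul_tile0", lambda: print("tile0"), deps=[a])
c = Task("matmul_tile1", lambda: print("tile1"), deps=[a])
d = Task("reduce", lambda: print("reduce"), deps=[b, c])
run_graph([a, b, c, d])
```

In MPK itself, the analogous scheduling loop runs on the GPU inside one persistent kernel, so tasks from different operators execute without per-operator kernel launches.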
