Maxtext: A simple, performant and scalable Jax LLM

(github.com)

115 points | by zerojames 11 days ago

2 comments

  • arjvik 11 days ago
    What are people's thoughts on how this compares to:

    - EasyLM [1]
    - Levanter [2]
    - T5X [3]
    - and more?

    [1]: https://github.com/young-geng/EasyLM
    [2]: https://github.com/stanford-crfm/levanter
    [3]: https://github.com/google-research/t5x

    Asking because I have worked extensively on training a large model on a TPU cluster, and started with Levanter, then tried MaxText, and finally ended up on EasyLM. My thoughts are:

    - Levanter is well-intentioned but is unproven and lacking in features. For instance, their sharding is odd in that it requires the embedding dimension to be a multiple of the number of devices, so I can't test using a model with embedding dimension 768 on a 512-device pod. Lost confidence in Levanter after finding some glaring correctness bugs (and helping get them fixed). Also, while I'm a huge fan of Equinox's approach, it's sadly underdeveloped (for instance, there's no way to specify non-default weight initialization strategies without manually doing model surgery to set weights).

    - MaxText was just very difficult to work with. We felt like we were fighting against it every time we needed to change something because we would be digging through numerous needless layers of abstraction. My favorite was after one long day of debugging, I found a function whose only purpose was to pass its arguments to another function untouched; this function's only purpose was to pass its arguments untouched to a new, third function, that then slightly changed them and passed them to a fourth function that did the work.

    - EasyLM is, as the name says, easy. But on a deeper dive, the sharding functionality seems to be underdeveloped. What they call "FSDP" is not necessarily true FSDP; it's literally just a single axis of the JAX mesh along which some data axes and some model weight axes happen to be sharded.

    I'm still searching for a "perfect" JAX LLM codebase - any pointers?
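The divisibility constraint described for Levanter above can be illustrated with a few lines of plain Python. This is only a sketch of the arithmetic, not Levanter's actual code; the helper names `evenly_shardable` and `shard_size` are hypothetical.

```python
def evenly_shardable(dim: int, num_devices: int) -> bool:
    """True if an axis of length `dim` splits into equal slices across devices."""
    return dim % num_devices == 0


def shard_size(dim: int, num_devices: int) -> int:
    """Per-device slice length, or an error if the axis does not split evenly."""
    if not evenly_shardable(dim, num_devices):
        raise ValueError(
            f"axis of length {dim} cannot be split evenly over {num_devices} devices"
        )
    return dim // num_devices


# The case from the comment: embedding dimension 768 on a 512-device pod.
small_model_ok = evenly_shardable(768, 512)   # 768 is not a multiple of 512
larger_model_slice = shard_size(1024, 512)    # 1024 / 512 = 2 elements per device
```

Under this assumption, any scheme that shards the embedding axis across every device rules out small test models on large pods, which is the problem described.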

    • logicchains 11 days ago
      >MaxText was just very difficult to work with. We felt like we were fighting against it every time we needed to change something because we would be digging through numerous needless layers of abstraction. My favorite was after one long day of debugging, I found a function whose only purpose was to pass its arguments to another function untouched; this function's only purpose was to pass its arguments untouched to a new, third function, that then slightly changed them and passed them to a fourth function that did the work

      Some of this complexity may be necessary for achieving optimal performance in Jax. E.g. extra indirection to avoid the compiler making some bad fusion decision, or multiple calls so something can be marked as static for the jit in the outer call. As far as I'm aware MaxText is the only public Jax codebase that's demonstrated scaling to models with 100s of billions of weights. I've just started evaluating it and it seems to scale better than the Torch implementation I was using previously (even on GPU). Most of the abstraction seems to have a reason behind it (at least for me since I'm making some modifications to the vanilla model, which is easier when the components are less tightly coupled).
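One of the patterns mentioned above, an outer call that exists so an argument can be marked static for `jax.jit`, might look roughly like the following. This is a minimal sketch and is not taken from MaxText; the function names `_scale_impl` and `scale` are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


def _scale_impl(x, num_repeats):
    # num_repeats drives a Python loop, so it has to be a concrete Python int
    # at trace time rather than a traced array value.
    for _ in range(num_repeats):
        x = x * 2.0
    return x


# The outer wrapper only forwards its arguments, but it is where the
# argument at position 1 is declared static for the jit.
@partial(jax.jit, static_argnums=(1,))
def scale(x, num_repeats):
    return _scale_impl(x, num_repeats)


result = scale(jnp.ones(3), 3)
```

Each distinct value of the static argument triggers a separate compilation, so this kind of wrapper trades recompilation for the ability to use ordinary Python control flow inside the traced function.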

      • gallabytes 11 days ago
        > Some of this complexity may be necessary for achieving optimal performance in Jax. E.g. extra indirection to avoid the compiler making some bad fusion decision, or multiple calls so something can be marked as static for the jit in the outer call

        certainly some of it is but not the lion's share - I have a much simpler (private) codebase which scales pretty similarly afaict.

        the complexity of Maxtext feels more Serious Engineering ™ flavored, following Best Practices.

    • bionhoward 11 days ago
      Is t5x an encoder/decoder architecture?

      Some more general options.

      The Flax ecosystem

      https://github.com/google/flax?tab=readme-ov-file

      or dm-haiku

      https://github.com/google-deepmind/dm-haiku

      were some of the best developed communities in the Jax AI field

      Perhaps the “trax” repo? https://github.com/google/trax

      Some HF examples https://github.com/huggingface/transformers/tree/main/exampl...

      Sadly it seems much of the work is proprietary these days, but one example could be Grok-1, if you customize the details. https://github.com/xai-org/grok-1/blob/main/run.py

      • terafo 11 days ago
        t5 is an architecture; t5x is a framework for training models that was created with that architecture in mind, but it can be used to train other architectures, including decoder-only ones (there is one in examples).
        • ma2rten 11 days ago
          t5x was used to train PaLM 1.
  • ubj 11 days ago
    This might be a tangent, but why does JAX only support the saving / serialization of AOT compilation executables for TPU [1]? It would be great to have the ability to save compiled functions and not have to JIT compile something every time you restart a session.

    (Julia has had this problem too, but they've made great progress on caching JIT compiled functions to reduce latency.)

    [1]: https://github.com/google/maxtext?tab=readme-ov-file#ahead-o...
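For context on the AOT question above, JAX does expose an explicit lower-then-compile path within a single session. The sketch below only shows that in-process step and does not write the executable to disk, so it does not address persistence across restarts; the function `double_plus_one` is a made-up example.

```python
import jax
import jax.numpy as jnp


def double_plus_one(x):
    return 2.0 * x + 1.0


x = jnp.arange(4, dtype=jnp.float32)

# Trace and lower for this specific input shape and dtype, then compile
# ahead of the first real call.
lowered = jax.jit(double_plus_one).lower(x)
compiled = lowered.compile()

# The compiled object is callable with arguments matching what it was lowered for.
result = compiled(x)
```

The compiled object is specialized to the shapes and dtypes it was lowered with, so calling it with a differently shaped input raises an error instead of recompiling.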