Jax: jax-v0.9.0 Release

Release date:
January 20, 2026
Previous version:
jax-v0.8.3 (released January 28, 2026)

Release Notes Published

  • New features:

    • Added jax.thread_guard, a context manager that detects when devices are used by multiple threads in multi-controller JAX.
  • Bug fixes:

    • Fixed a workspace size calculation error for pivoted QR (magma_zgeqp3_gpu) in MAGMA 2.9.0 when using use_magma=True and pivoting=True (#34145).
  • Deprecations:

    • The flag jax_collectives_common_channel_id was removed.
    • The jax_pmap_no_rank_reduction config state has been removed. The no-rank-reduction behavior is now the only supported behavior: a jax.pmapped function f sees inputs of the same rank as the input to jax.pmap(f). For example, if jax.pmap(f) receives shape (8, 128) on 8 devices, then f receives shape (1, 128).
    • Setting the jax_pmap_shmap_merge config state is deprecated in JAX v0.9.0 and will be removed in JAX v0.10.0.
    • jax.numpy.fix is deprecated, anticipating the deprecation of numpy.fix in NumPy v2.5.0. jax.numpy.trunc is a drop-in replacement.
  • Changes:

    • jax.export now supports explicit sharding. This required a new export serialization format version that records the NamedSharding, including its abstract mesh and partition spec. As part of this change, exported modules have a new restriction: when calling them, the abstract mesh must match the one used at export time, including the axis names. Previously, only the number of devices mattered.
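The jax.export entry concerns the export, serialize, deserialize, call workflow. The sketch below shows that round trip on a single device. It is a minimal illustration under stated assumptions. It does not use explicit sharding, so it does not exercise the new NamedSharding fields or the abstract-mesh check. With sharded inputs, the serialized module also records the NamedSharding, and the abstract mesh at call time must match the export-time mesh, including axis names. The function f is a hypothetical example.

```python
import jax
import jax.numpy as jnp
from jax import export

# Hypothetical function to export; any jittable function works.
def f(x):
    return jnp.sin(x) * 2.0

# Export for an abstract input: only shape and dtype are given, no sharding.
exported = export.export(jax.jit(f))(jax.ShapeDtypeStruct((4,), jnp.float32))

# Serialize to a bytearray using the serialization format version
# supported by the installed JAX.
blob = exported.serialize()

# Later, possibly in another process: deserialize and call the module.
restored = export.deserialize(blob)
x = jnp.arange(4, dtype=jnp.float32)
y = restored.call(x)
print(y)
```

When inputs carry a NamedSharding, the same call step is where the abstract mesh in effect must match the exported one, including axis names.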
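The jax.numpy.fix deprecation lists jax.numpy.trunc as a drop-in replacement. Both functions round toward zero. The sketch below applies jnp.trunc to values on both sides of zero and cross-checks the result against NumPy's trunc. The input values are arbitrary illustrations.

```python
import numpy as np
import jax.numpy as jnp

# Values on both sides of zero, including exact integers.
x = jnp.array([-2.7, -1.0, -0.5, 0.5, 1.0, 2.7], dtype=jnp.float32)

# jnp.trunc rounds toward zero, matching what jnp.fix computed.
# Migration: replace jnp.fix(x) with jnp.trunc(x).
y = jnp.trunc(x)

# Cross-check against NumPy's trunc on the same values.
expected = np.trunc(np.asarray(x))
print(y)
```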