JAX: jax-v0.8.0 Release

Release date: October 15, 2025
Previous version: jax-v0.7.2 (released September 15, 2025)
Magnitude: 18,317 Diff Delta
Contributors: 45 total committers
Commits: 227 commits in this release

Top Contributors in jax-v0.8.0

dimitar-asenov
allanrenucci
justinjfu
danielsuo
jakevdp
jburnim
yueshengys
WindQAQ
basioli-k
brianwa84

Release Notes

  • Breaking changes:

    • JAX is changing the default jax.pmap implementation to one implemented in terms of jax.jit and jax.shard_map. jax.pmap is in maintenance mode and we encourage all new code to use jax.shard_map directly. See the migration guide for more information.
    • The auto= parameter of jax.experimental.shard_map.shard_map has been removed. This means that jax.experimental.shard_map.shard_map no longer supports nesting. If you want to nest shard_map calls, please use jax.shard_map.
    • JAX no longer allows passing objects that support __jax_array__ directly to, for example, jit-compiled functions. Call jax.numpy.asarray on them first.
    • jax.numpy.cov now returns NaN for empty arrays (JAX issue #32305) and matches NumPy 2.2 behavior for single-row design matrices (JAX issue #32308).
    • JAX no longer accepts Array values where a dtype value is expected. Call .dtype on these values first.
    • The deprecated function jax.interpreters.mlir.custom_call was removed.
    • The jax.util, jax.extend.ffi, and jax.experimental.host_callback modules have been removed. All public APIs within these modules were deprecated and removed in v0.7.0 or earlier.
    • The deprecated symbol jax.custom_derivatives.custom_jvp_call_jaxpr_p was removed.
    • jax.experimental.multihost_utils.process_allgather now raises an error when the input is a jax.Array that is not fully addressable and tiled=False. To fix this, pass tiled=True to your process_allgather call.
    • The deprecated symbols is_initialized and initialize_cache were removed from jax.experimental.compilation_cache.
    • The deprecated function jax.interpreters.xla.canonicalize_dtype was removed.
    • jaxlib.hlo_helpers has been removed. Use jax.ffi instead.
    • The option jax_cpu_enable_gloo_collectives has been removed. Use jax_cpu_collectives_implementation instead.
    • The previously-deprecated interpolation argument to jax.numpy.percentile and jax.numpy.quantile has been removed; use method instead.
    • The JAX-internal for_loop primitive was removed. Its functionality, reading from and writing to refs in the loop body, is now directly supported by jax.lax.fori_loop. If you need help updating your code, please file a bug.
    • jax.numpy.trim_zeros now errors for non-1D input.
    • The where argument to jax.numpy.sum and other reductions is now required to be boolean. Non-boolean values have resulted in a DeprecationWarning since JAX v0.5.0.
    • The deprecated functions in jax.dlpack, jax.errors, jax.lib.xla_bridge, jax.lib.xla_client, and jax.lib.xla_extension were removed.
    • jax.interpreters.mlir.dense_bool_array was removed. Use MLIR APIs to construct attributes instead.
  • Changes

    • jax.numpy.linalg.eig now returns a namedtuple (with attributes eigenvalues and eigenvectors) instead of a plain tuple.
    • jax.grad and jax.vjp will now always round primals to float32 if float64 mode is not enabled.
    • jax.dlpack.from_dlpack now accepts arrays with non-default layouts, for example, transposed.
    • The default nonsymmetric eigendecomposition on NVIDIA GPUs now uses cusolver. The magma and LAPACK implementations are still available via the new implementation argument to jax.lax.linalg.eig (JAX issue #27265). The use_magma argument is now deprecated in favor of implementation.
    • jax.numpy.trim_zeros now follows NumPy 2.2 in supporting multi-dimensional inputs.
  • Deprecations

    • jax.experimental.enable_x64 and jax.experimental.disable_x64 are deprecated in favor of the new non-experimental context manager jax.enable_x64.
    • jax.experimental.shard_map.shard_map is deprecated; going forward use jax.shard_map.
    • jax.experimental.pjit.pjit is deprecated; going forward use jax.jit.
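
Several of the breaking changes above (__jax_array__ objects, dtype arguments, the percentile method argument, and boolean where masks) come down to one-line fixes at the call site. The sketch below collects them, assuming JAX v0.8.0; the Wrapped class and all variable names are hypothetical stand-ins for user code.

```python
import jax
import jax.numpy as jnp

# 1. Objects implementing __jax_array__ must be converted before being passed
#    to a jit-compiled function. Wrapped is a hypothetical user class.
class Wrapped:
    def __init__(self, values):
        self.values = values

    def __jax_array__(self):
        return jnp.asarray(self.values)

@jax.jit
def double(x):
    return 2 * x

doubled = double(jnp.asarray(Wrapped([1.0, 2.0])))  # convert first

# 2. Pass a dtype, not an array, where a dtype is expected.
x = jnp.ones(3, dtype=jnp.float32)
zeros = jnp.zeros(3, dtype=x.dtype)  # use x.dtype rather than x

# 3. percentile/quantile: use method= instead of the removed interpolation=.
median = jnp.percentile(jnp.array([1.0, 2.0, 3.0, 4.0]), 50, method="linear")

# 4. The where= mask of reductions must be boolean; cast 0/1 masks explicitly.
vals = jnp.array([1.0, -2.0, 3.0])
int_mask = jnp.array([1, 0, 1])
masked_sum = jnp.sum(vals, where=int_mask.astype(bool))
```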
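
Since jax.experimental.pjit.pjit is deprecated in favor of jax.jit, code that passed in_shardings and out_shardings to pjit can pass the same arguments to jax.jit. A minimal sketch, assuming a 1-D mesh over the local devices; the function scale and the axis name "data" are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over the local devices; "data" is an arbitrary axis name.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
rows = NamedSharding(mesh, P("data"))  # shard the leading (row) axis

def scale(x):
    return 3.0 * x

# Previously: pjit(scale, in_shardings=rows, out_shardings=rows)
scale_sharded = jax.jit(scale, in_shardings=rows, out_shardings=rows)

n = len(jax.devices())
x = jnp.ones((2 * n, 4))  # row count divisible by the number of devices
y = scale_sharded(x)
```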