JAX: jax-v0.10.0 Release

Release date: April 16, 2026
Previous version: jax-v0.9.2 (released March 18, 2026)
Magnitude: 62,264 Diff Delta
Contributors: 59 total committers
Commits: 591 commits in this release


Top Contributors in jax-v0.10.0

hawkinsp
superbobry
mattjj
a-googler
yashk2810
danielsuo
allanrenucci
bchetioui
rdyro
levskaya

Release Notes

  • New features:

    • Added ResizeMethod.CUBIC_PYTORCH to jax.image.resize to match PyTorch's bicubic resize (#15768).
    • We now support differentiation of jax.lax.linalg.qr for wide matrices and when full_matrices is True.
    • LAPACK operations are now parallelized along the batch dimension on CPU.
    • Added perturb_singular argument to jax.lax.linalg.tridiagonal_solve to handle singular matrices by perturbing near-zero pivots in the LU decomposition. This is useful for solving numerically singular systems when computing eigenvectors by inverse iteration.
    • jax.scipy.linalg.eigh_tridiagonal now supports computing eigenvectors on CPU and GPU.
    • Added the jax.numpy.ndarray.byteswap method.
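The new ResizeMethod.CUBIC_PYTORCH member is passed to jax.image.resize like any other method. A minimal sketch, assuming code that must also run on releases before 0.10.0: it looks the member up with getattr and otherwise falls back to JAX's existing cubic kernel, which should not be expected to reproduce PyTorch's output.

```python
import jax
import jax.numpy as jnp

ResizeMethod = jax.image.ResizeMethod

# Prefer the PyTorch-matching bicubic kernel when this JAX version has it
# (added in jax 0.10.0); otherwise fall back to JAX's own cubic kernel.
method = getattr(ResizeMethod, "CUBIC_PYTORCH", ResizeMethod.CUBIC)

image = jnp.arange(16.0).reshape(4, 4)  # toy single-channel 4x4 image
upsampled = jax.image.resize(image, (8, 8), method=method)
print(method, upsampled.shape)
```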
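The CPU batch parallelization needs no code changes; it applies when a LAPACK-backed routine receives a stack of matrices in its leading dimensions. A sketch of such a batched call, assuming jnp.linalg.cholesky is one of the routines dispatched to LAPACK on CPU (the note does not list which operations are affected):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
m = jax.random.normal(key, (64, 8, 8))  # a batch of 64 random 8x8 matrices

# Make each matrix symmetric positive definite so Cholesky is well defined.
spd = m @ jnp.swapaxes(m, -1, -2) + 8.0 * jnp.eye(8)

# One call over the whole batch; the leading axis is the batch dimension.
chol = jnp.linalg.cholesky(spd)
```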
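To illustrate the inverse-iteration use case mentioned for perturb_singular, here is a minimal sketch that finds the eigenvector of the smallest eigenvalue of a small symmetric tridiagonal matrix. It calls jax.lax.linalg.tridiagonal_solve with only its four array arguments and offsets the shift slightly from the eigenvalue so the system stays nonsingular; on 0.10.0, perturb_singular is meant to cover the case where the shift sits numerically on an eigenvalue. The matrix, the 1e-3 offset, and the iteration count are illustrative choices.

```python
import jax.numpy as jnp
from jax.lax.linalg import tridiagonal_solve

n = 4
d = jnp.full((n,), 2.0)         # main diagonal
off = jnp.full((n - 1,), -1.0)  # sub- and super-diagonal (symmetric matrix)
t_dense = jnp.diag(d) + jnp.diag(off, 1) + jnp.diag(off, -1)

lam = jnp.linalg.eigvalsh(t_dense)[0]  # smallest eigenvalue
shift = lam + 1e-3                     # near, but deliberately not at, lam

# tridiagonal_solve expects length-n diagonals with dl[0] == 0 and du[-1] == 0.
dl = jnp.concatenate([jnp.zeros((1,)), off])
du = jnp.concatenate([off, jnp.zeros((1,))])
ds = d - shift

# Inverse iteration: repeatedly solve (T - shift*I) v_new = v and normalize.
v = jnp.ones((n, 1))  # right-hand side must have shape (..., n, k)
for _ in range(3):
    v = tridiagonal_solve(dl, ds, du, v)
    v = v / jnp.linalg.norm(v)

# v should now satisfy T v ~= lam v.
residual = jnp.linalg.norm(t_dense @ v - lam * v)
```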
  • Breaking changes:

    • PartitionSpec objects no longer report themselves to be equal to tuples. Convert tuples to PartitionSpec objects before testing equality.
    • The .vma property has been removed from jax.core.ShapedArray. Use .manual_axis_type.varying instead.
    • JAX CPU devices now report their names as cpu:0, cpu:1, etc. instead of TFRT_CPU_0, TFRT_CPU_1.
    • The config state jax_pmap_shmap_merge has been removed. jax.pmap will now always use the new implementation that wraps jax.jit(jax.shard_map). Please see https://docs.jax.dev/en/latest/migrate_pmap.html for more information.
    • jax.device_put_sharded and jax.device_put_replicated have been removed from the public API and now raise an AttributeError when accessed. Please see https://docs.jax.dev/en/latest/migrate_pmap.html#drop-in-replacements for drop-in replacements.
    • The C++ pmap infrastructure has been removed. The following public APIs are no longer available:
      • jax.sharding.PmapSharding
      • From jaxlib.xla_extension: PmapFunction, pmap, NoSharding, Chunked, Unstacked, ShardedAxis, Replicated, ShardingSpec.
      • From jax.interpreters.pxla: MapTracer, PmapExecutable, parallel_callable, shard_args, xla_pmap_p, Chunked, NoSharding, Replicated, ShardedAxis, ShardingSpec, Unstacked, spec_to_indices.
    • The deprecated keyword arguments a, a_min, and a_max to jax.numpy.clip have been removed.
    • Functions jax.numpy.hstack, jax.numpy.vstack, jax.numpy.dstack, jax.numpy.column_stack, jax.numpy.atleast_1d, jax.numpy.atleast_2d, and jax.numpy.atleast_3d no longer accept non-ArrayLike inputs. Doing so previously issued a DeprecationWarning.
    • jax.scipy.stats.rankdata now returns floating point values in all cases, following a similar change in the SciPy 1.18 release.
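For the PartitionSpec equality change, code that may receive either form can normalize before comparing. A minimal sketch; as_partition_spec is a hypothetical helper, not part of JAX:

```python
from jax.sharding import PartitionSpec as P

def as_partition_spec(spec_or_tuple):
    """Hypothetical helper: normalize a legacy tuple spec to a PartitionSpec."""
    if isinstance(spec_or_tuple, P):
        return spec_or_tuple
    return P(*spec_or_tuple)

spec = P("data", None)
legacy = ("data", None)  # e.g. a spec stored as a plain tuple in older code

# From 0.10.0, `spec == legacy` no longer reports True; compare like with like.
same = spec == as_partition_spec(legacy)
```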
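Because CPU device names changed from TFRT_CPU_0 to cpu:0, code that parsed name strings is the most likely to break. A sketch that selects devices by their platform and id attributes instead, assuming a single-process setup:

```python
import jax

cpu_devices = jax.devices("cpu")

# Select devices by attributes rather than by parsing their name strings,
# which are formatted differently before and after jax 0.10.0.
by_id = {d.id: d for d in cpu_devices if d.platform == "cpu"}
first_cpu = by_id[min(by_id)]
```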
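Since jax.pmap is now implemented as jax.jit(jax.shard_map), new code can target that composition directly. A minimal sketch of a pmap-style cross-device sum over a one-axis mesh of all available devices; the axis name "x" is arbitrary, the fallback import assumes older releases expose shard_map only under jax.experimental, and the linked migration guide remains the authoritative reference.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # public location in recent releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # assumed older location

mesh = Mesh(np.array(jax.devices()), ("x",))

def local_then_global_sum(block):
    # Each device sums its own block, then psum combines the partial sums
    # across the "x" mesh axis, as a pmap body would with axis_name="x".
    return jax.lax.psum(block.sum(), "x")

total_fn = jax.jit(
    shard_map(local_then_global_sum, mesh=mesh, in_specs=P("x"), out_specs=P())
)

x = jnp.arange(4.0 * len(jax.devices()))  # length divisible by the device count
total = total_fn(x)
```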
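One way to reproduce the removed device_put_sharded and device_put_replicated helpers with jax.device_put and a NamedSharding over a one-axis mesh, assuming the caller wants the same leading-axis layout those helpers produced for pmap. This is a sketch, not necessarily the replacement the linked migration guide recommends.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
n = len(devices)
mesh = Mesh(np.array(devices), ("x",))
leading_axis_sharded = NamedSharding(mesh, P("x"))  # one slice per device

# In place of jax.device_put_sharded(shards, devices):
# stack the per-device shards along a new leading axis and shard that axis.
shards = [jnp.full((3,), float(i)) for i in range(n)]
sharded = jax.device_put(jnp.stack(shards), leading_axis_sharded)

# In place of jax.device_put_replicated(x, devices):
# give every device its own copy of x along a new leading axis.
x = jnp.arange(3.0)
replicated = jax.device_put(jnp.broadcast_to(x, (n, *x.shape)), leading_axis_sharded)
```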
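Calls to jax.numpy.clip that used the removed a, a_min, or a_max keywords can pass the array positionally and the bounds through the min and max keywords. A minimal sketch:

```python
import jax.numpy as jnp

x = jnp.array([-2.0, 0.5, 4.0])

# Previously: jnp.clip(a=x, a_min=-1.0, a_max=3.0), which now raises an error.
clipped = jnp.clip(x, min=-1.0, max=3.0)
```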
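For the stacking and atleast_* functions, a sketch of the migration, assuming the affected inputs are plain Python lists or tuples passed as individual elements or arguments: convert them with jnp.asarray first.

```python
import jax.numpy as jnp

pieces = [[1, 2], [3]]  # plain Python lists are not ArrayLike in JAX

# Convert each element explicitly instead of relying on implicit conversion.
row = jnp.hstack([jnp.asarray(p) for p in pieces])
grid = jnp.atleast_2d(jnp.asarray([1, 2]))
```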
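Code that used jax.scipy.stats.rankdata results as integer indices or counts can cast explicitly. A sketch using method="ordinal", where every rank is a whole number, so the cast is exact:

```python
import jax.numpy as jnp
from jax.scipy.stats import rankdata

scores = jnp.array([30, 10, 20])

# rankdata now returns floats; cast when integer ranks are required.
ranks = rankdata(scores, method="ordinal")
int_ranks = ranks.astype(jnp.int32)
```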
  • Deprecations:

    • A number of internal APIs in jax.core have been newly deprecated and some have been moved to jax.extend.core. These include CallPrimitive, DebugInfo, DropVar, Effect, Effects, InconclusiveDimensionOperation, JaxprTypeError, check_jaxpr, concrete_or_error, find_top_trace, gensym, get_opaque_trace_state, jaxprs_in_params, new_jaxpr_eqn, no_effects, nonempty_axis_env_DO_NOT_USE, primal_dtype_to_tangent_dtype, unsafe_am_i_under_a_jit_DO_NOT_USE, unsafe_am_i_under_a_vmap_DO_NOT_USE, unsafe_get_axis_names_DO_NOT_USE, valid_jaxtype, JaxprPpContext, JaxprPpSettings, OutputType, abstract_token, aval_mapping_handlers, call, concretization_function_error, custom_typechecks, is_concrete, is_constant_dim, is_constant_shape, literalable_types, no_axis_name, pytype_aval_mappings, and trace_ctx.
  • Changes:

    • The minimum supported SciPy version is now 1.14.
    • The vma parameter of jax.ShapeDtypeStruct has been replaced with manual_axis_type: jax.sharding.ManualAxisType, and the .vma property has been replaced with .manual_axis_type.varying.
    • Removed the experimental API jax.experimental.custom_dce.custom_dce.
    • jax.scipy.linalg.cho_solve, jax.scipy.linalg.lu_solve, and jax.scipy.linalg.solve_triangular now show a deprecation warning for batched 1D solves with b.ndim > 1. In the future these will be treated as batched 2D solves.
    • Added version 10 of the jax.export serialization format. It is an optimization for cases where the same abstract value, abstract mesh, or sharding occurs multiple times.
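To avoid the new warning and keep the current meaning of a batch of vector right-hand sides, one option is to add an explicit trailing axis so b is unambiguously a stack of single-column matrices, then drop that axis from the result. A sketch with jax.scipy.linalg.solve_triangular on a batch of two lower-triangular 3x3 systems (chosen so both solutions are all ones); the same reshaping pattern should apply to cho_solve and lu_solve.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# Batch of two lower-triangular systems, shape (2, 3, 3).
a = jnp.array([
    [[2.0, 0.0, 0.0], [1.0, 3.0, 0.0], [4.0, 5.0, 6.0]],
    [[1.0, 0.0, 0.0], [2.0, 1.0, 0.0], [3.0, 2.0, 1.0]],
])
# One right-hand-side vector per system, shape (2, 3): ambiguous with b.ndim > 1.
b = jnp.array([[2.0, 4.0, 15.0], [1.0, 3.0, 6.0]])

# Make each right-hand side an explicit (3, 1) column, solve, then squeeze.
x = solve_triangular(a, b[..., None], lower=True)[..., 0]
```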
  • Bug fixes:

    • Fixed a bug that led to differing output between CPU and GPU for non-symmetric multidimensional IRFFTs (#29325).
    • Fixed an error when tiny matrices were passed to jax.lax.linalg.tridiagonal_solve on GPU (#32487).
    • Fixed a bug in jax.scipy.fft.dctn and idctn where axes=None incorrectly defaulted to all axes when s was specified, instead of the last len(s) axes to match SciPy behavior (#29426).
    • Fixed a bug where calling jax.distributed.initialize() on a GCE TPU Managed Instance Group raised an IndexError (#36593). On a GCE VM, jax.distributed.initialize() queries the GCE metadata server to learn the addresses of all participating tasks. On Managed Instance Groups this metadata uses a format JAX did not expect, which caused the exception; JAX now parses this format correctly.
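Code that must produce the same jax.scipy.fft.dctn result on releases before and after the fix can pass axes explicitly. A sketch that selects the last len(s) axes, as SciPy does when axes is None, and compares against scipy.fft.dctn; the input array and s are illustrative.

```python
import numpy as np
import scipy.fft
import jax.numpy as jnp
from jax.scipy.fft import dctn

x = np.arange(18, dtype=np.float32).reshape(3, 6)
s = (4,)  # truncate the transformed axis to length 4

# SciPy's rule when s is given and axes is None: transform the last len(s) axes.
axes = tuple(range(x.ndim - len(s), x.ndim))  # here: (1,)

y = dctn(jnp.asarray(x), s=s, axes=axes)
reference = scipy.fft.dctn(x, s=s, axes=axes)
```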