JAX: jax-v0.7.0 Release

Release date: July 22, 2025
Previous version: jax-v0.6.2 (released June 17, 2025)
Magnitude: 58,426 Diff Delta
Contributors: 69 total committers

625 Commits in this Release

Top Contributors in jax-v0.7.0

apaszke
bchetioui
yashk2810
jakevdp
mattjj
justinjfu
hawkinsp
a-googler
superbobry
dimitar-asenov

Release Notes

  • New features:

    • Added jax.P, an alias for jax.sharding.PartitionSpec.
    • Added jax.tree.reduce_associative.
  • Breaking changes:

    • JAX is migrating from GSPMD to Shardy by default. See the migration guide for more information.
    • JAX autodiff is switching to direct linearization by default (instead of implementing linearization via JVP and partial eval). See the migration guide for more information.
    • jax.stages.OutInfo has been replaced with jax.ShapeDtypeStruct.
    • jax.jit now requires fun to be passed by position and any additional arguments to be passed by keyword. Doing otherwise results in an error starting in v0.7.x; in v0.6.x it raised a DeprecationWarning.
    • The minimum Python version is now 3.11. 3.11 will remain the minimum supported version until July 2026.
    • Layout API renames:
        • Layout, .layout, .input_layouts and .output_layouts have been renamed to Format, .format, .input_formats and .output_formats.
        • DeviceLocalLayout and .device_local_layout have been renamed to Layout and .layout.
    • The jax.experimental.shard module has been deleted and all of its APIs have moved to the jax.sharding endpoint. Use jax.sharding.reshard, jax.sharding.auto_axes and jax.sharding.explicit_axes instead of their experimental counterparts.
    • lax.infeed and lax.outfeed were removed after being deprecated in JAX 0.6. The transfer_to_infeed and transfer_from_outfeed methods were also removed from Device objects.
    • The jax.extend.core.primitives.pjit_p primitive has been renamed to jit_p, and its name attribute has changed from "pjit" to "jit". This affects the string representations of jaxprs. The same primitive is no longer exported from the jax.experimental.pjit module.
    • The (undocumented) function jax.extend.backend.add_clear_backends_callback has been removed. Users should use jax.extend.backend.register_backend_cache instead.
  • Deprecations:

    • jax.dlpack.SUPPORTED_DTYPES is deprecated; please use the new jax.dlpack.is_supported_dtype function.
    • jax.scipy.special.sph_harm has been deprecated following a similar deprecation in SciPy; use jax.scipy.special.sph_harm_y instead.
    • From jax.interpreters.xla, the previously deprecated symbols abstractify and pytype_aval_mappings have been removed.
    • jax.interpreters.xla.canonicalize_dtype is deprecated. For canonicalizing dtypes, prefer jax.dtypes.canonicalize_dtype. For checking whether an object is a valid jax input, prefer jax.core.valid_jaxtype.
    • From jax.core, the previously deprecated symbols AxisName, ConcretizationTypeError, axis_frame, call_p, closed_call_p, get_type, trace_state_clean, typematch, and typecheck have been removed.
    • From jax.lib.xla_client, the previously deprecated symbols DeviceAssignment, get_topology_for_devices, and mlir_api_version have been removed.
    • jax.extend.ffi was removed after being deprecated in v0.5.0. Use jax.ffi instead.
    • jax.lib.xla_bridge.get_compile_options is deprecated, and replaced by jax.extend.backend.get_compile_options.