Jax: jax-v0.6.0 Release

Release date:
April 22, 2025
Previous version:
jax-v0.5.3 (released March 19, 2025)
Magnitude:
57,866 Diff Delta
Contributors:
81 total committers
Commits:

675 Commits in this Release

Ordered by the degree to which they evolved the repo in this version.

Authored April 16, 2025
Authored March 25, 2025
Authored April 11, 2025
Authored April 11, 2025
Authored March 20, 2025
Authored April 13, 2025
Authored April 9, 2025
Authored February 7, 2025
Authored March 24, 2025
Authored April 10, 2025
Authored March 5, 2025
Authored March 19, 2025
Authored March 19, 2025
Authored March 21, 2025
Authored April 9, 2025

Top Contributors in jax-v0.6.0

yashk2810
hawkinsp
dfm
superbobry
apaszke
a-googler
jakevdp
mattjj
justinjfu
danielsuo

Release Notes Published

  • Breaking changes

    • jax.numpy.array no longer accepts None. Passing None was deprecated in November 2023, and support for it has now been removed.
    • Removed the config.jax_data_dependent_tracing_fallback config option, which was added temporarily in v0.4.36 to allow users to opt out of the new "stackless" tracing machinery.
    • Removed the config.jax_eager_pmap config option.
    • Calling the lower and trace AOT APIs on the result of jax.jit is now disallowed if further wrappers have been applied to it. Previously this worked but silently ignored the wrappers. The workaround is to apply jax.jit last among the wrappers; the same applies to jax.pmap. See #27873.
    • The cuda12_pip extra for jax has been removed; use pip install jax[cuda12] instead.
  • Changes

    • The minimum cuDNN version is now v9.8.
    • JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain supported.
    • JAX package extras now use dashes instead of underscores, in line with PEP 685. For instance, if you previously installed JAX with pip install jax[cuda12_local], run pip install jax[cuda12-local] instead.
    • jax.jit now requires fun to be passed by position and all additional arguments to be passed by keyword. Doing otherwise results in a DeprecationWarning in v0.6.X and will be an error starting in v0.7.X.
  • Deprecations

    • jax.tree_util.build_tree is deprecated. Use jax.tree.unflatten instead.
    • Host callback handlers for CPU and GPU devices are now implemented using XLA's FFI; the previous CPU/GPU handlers based on XLA's custom call have been removed.
    • All APIs in jax.lib.xla_extension are now deprecated.
    • jax.interpreters.mlir.hlo and jax.interpreters.mlir.func_dialect, which were accidental exports, have been removed. If needed, they are available from jax.extend.mlir.
    • jax.interpreters.mlir.custom_call is deprecated. The APIs provided by jax.ffi should be used instead.
    • The deprecated use of jax.ffi.ffi_call with inline arguments is no longer supported. jax.ffi.ffi_call now unconditionally returns a callable.
    • The following exports in jax.lib.xla_client are deprecated: get_topology_for_devices, heap_profile, mlir_api_version, Client, CompileOptions, DeviceAssignment, Frame, HloSharding, OpSharding, Traceback.
    • The following internal APIs in jax.util are deprecated: HashableFunction, as_hashable_function, cache, safe_map, safe_zip, split_dict, split_list, split_list_checked, split_merge, subvals, toposort, unzip2, wrap_name, and wraps.
    • jax.dlpack.to_dlpack has been deprecated. You can usually pass a JAX Array directly to the from_dlpack function of another framework. If you need the functionality of to_dlpack, use the __dlpack__ attribute of an array.
    • jax.lax.infeed, jax.lax.infeed_p, jax.lax.outfeed, and jax.lax.outfeed_p are deprecated and will be removed in JAX v0.7.0.
    • Several previously deprecated APIs have been removed, including:
      • From jax.lib.xla_client: ArrayImpl, FftType, PaddingType, PrimitiveType, XlaBuilder, dtype_to_etype, ops, register_custom_call_target, shape_from_pyval, Shape, XlaComputation.
      • From jax.lib.xla_extension: ArrayImpl, XlaRuntimeError.
      • From jax: jax.treedef_is_leaf, jax.tree_flatten, jax.tree_map, jax.tree_leaves, jax.tree_structure, jax.tree_transpose, and jax.tree_unflatten. Replacements can be found in jax.tree or jax.tree_util.
      • From jax.core: AxisSize, ClosedJaxpr, EvalTrace, InDBIdx, InputType, Jaxpr, JaxprEqn, Literal, MapPrimitive, OpaqueTraceState, OutDBIdx, Primitive, Token, TRACER_LEAK_DEBUGGER_WARNING, Var, concrete_aval, dedup_referents, escaped_tracer_error, extend_axis_env_nd, full_lower, get_referent, jaxpr_as_fun, join_effects, lattice_join, leaked_tracer_error, maybe_find_leaked_tracers, raise_to_shaped, raise_to_shaped_mappings, reset_trace_state, str_eqn_compact, substitute_vars_in_output_ty, typecompat, and used_axis_names_jaxpr. Most have no public replacement, though a few are available at jax.extend.core.
      • The vectorized argument to jax.pure_callback and jax.ffi.ffi_call; use the vmap_method parameter instead.
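The wrapper-ordering rule listed under Breaking changes (apply jax.jit last before calling its AOT lower or trace methods) can be illustrated with a minimal sketch, assuming JAX v0.6.0 on any backend. The function loss is a hypothetical example; jax.grad is applied first and jax.jit outermost, so the lowered computation includes the gradient transformation.

```python
import jax
import jax.numpy as jnp


def loss(x):
    # Hypothetical example function: sum of squares of a vector.
    return jnp.sum(x ** 2)


# Apply jax.jit last (outermost) so its AOT methods see every inner wrapper.
# Wrapping a jitted function in something else and then calling .lower on
# the result is the pattern the release note describes as now disallowed.
grad_fn = jax.jit(jax.grad(loss))

x = jnp.arange(3.0)
lowered = grad_fn.lower(x)    # ahead-of-time lowering for x's shape and dtype
compiled = lowered.compile()  # compile the lowered computation
print(compiled(x))            # gradient of sum(x**2) is 2*x: [0. 2. 4.]
```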
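For the jax.jit calling convention listed under Changes, here is a minimal sketch, assuming JAX v0.6.0. The function scale is hypothetical, and static_argnames is one of jax.jit's existing keyword options. The function is passed by position and every other option by keyword, including in the decorator form built with functools.partial.

```python
import functools

import jax
import jax.numpy as jnp


def scale(x, factor):
    # Hypothetical example: multiply an array by a Python scalar.
    return x * factor


# fun passed by position; every other option passed by keyword.
scale_jit = jax.jit(scale, static_argnames=("factor",))


# The decorator form via functools.partial also passes fun positionally.
@functools.partial(jax.jit, static_argnames=("factor",))
def scale_decorated(x, factor):
    return x * factor


print(scale_jit(jnp.ones(2), factor=3))        # [3. 3.]
print(scale_decorated(jnp.ones(2), factor=3))  # [3. 3.]
```

Forms such as jax.jit(fun=scale), or passing any option positionally after the function, are the calls the note says trigger a DeprecationWarning in v0.6.X.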
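For code that still calls the removed top-level tree helpers, or the deprecated jax.tree_util.build_tree, a minimal migration sketch, assuming JAX v0.6.0, maps several old names to their jax.tree counterparts. The params dictionary is a hypothetical example pytree.

```python
import jax
import jax.numpy as jnp

# Hypothetical example pytree: a dict of arrays.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

# jax.tree_map (removed) -> jax.tree.map
doubled = jax.tree.map(lambda a: 2 * a, params)

# jax.tree_leaves (removed) -> jax.tree.leaves
leaves = jax.tree.leaves(params)

# jax.tree_structure (removed) -> jax.tree.structure
treedef = jax.tree.structure(params)

# jax.tree_flatten / jax.tree_unflatten (removed) -> jax.tree.flatten / jax.tree.unflatten.
# jax.tree.unflatten is also the suggested replacement for jax.tree_util.build_tree.
flat, treedef2 = jax.tree.flatten(params)
rebuilt = jax.tree.unflatten(treedef2, flat)

print(len(leaves))             # 2
print(treedef == treedef2)     # True
print(sorted(rebuilt.keys()))  # ['b', 'w']
```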