Jax: jax-v0.4.13 Release

Release date:
June 22, 2023
Previous version:
jax-v0.4.13-rc (released June 22, 2023)

Release Notes Published

NOTE: This is the last JAX release that will include Python 3.8 support

  • Changes

    • jax.jit now allows None to be passed to in_shardings and out_shardings. The semantics are as follows:
      • For in_shardings, JAX will mark it as replicated, but this behavior may change in the future.
      • For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.
    • jax.experimental.pjit.pjit also allows None to be passed to in_shardings and out_shardings. The semantics are as follows:
      • If the mesh context manager is not provided, JAX has the freedom to choose whatever sharding it wants.
        • For in_shardings, JAX will mark it as replicated, but this behavior may change in the future.
        • For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.
      • If the mesh context manager is provided, None implies that the value will be replicated on all devices of the mesh.
    • Executable.cost_analysis() now works on Cloud TPU.
    • Added a warning if a non-allowlisted jaxlib plugin is in use.
    • Added jax.tree_util.tree_leaves_with_path.
  • Bug fixes

    • Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel is named cudnn89 instead of cudnn88.
  • Deprecations

    • The native_serialization_strict_checks parameter to jax.experimental.jax2tf.convert is deprecated in favor of the new native_serialization_disabled_checks parameter (#16347).
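The jax.jit change to in_shardings and out_shardings can be illustrated with a minimal sketch, assuming a recent JAX install running on a single CPU device. The function scale_and_sum and the input array are hypothetical. On one device, "replicated" and "whatever the partitioner picks" both collapse to a single-device placement, so this only demonstrates that None is accepted; the sharding behavior described in the note becomes visible on multi-device setups.

```python
import jax
import jax.numpy as jnp


def scale_and_sum(x):
    # Hypothetical example function: double every element, then sum over rows.
    return (2.0 * x).sum(axis=0)


# None for in_shardings: JAX marks the input as replicated (per the release
# note, this may change in the future). None for out_shardings: the XLA GSPMD
# partitioner chooses the output sharding.
f = jax.jit(scale_and_sum, in_shardings=None, out_shardings=None)

x = jnp.arange(8.0).reshape(4, 2)
y = f(x)
print(y)           # [24. 32.]
print(y.sharding)  # the sharding the compiler chose for the output
```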
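In public JAX code, cost_analysis() is usually reached through the compiled object returned by jax.jit(...).lower(...).compile() (a jax.stages.Compiled); treating that as the "Executable" named in the note is an assumption. The sketch below compiles a hypothetical 128x128 matrix multiply and reads the "flops" entry. Depending on the JAX version, the result has been either a list of per-computation dicts or a single dict, and a backend without cost analysis may return nothing useful, so the code normalizes defensively instead of assuming one shape.

```python
import jax
import jax.numpy as jnp


def matmul(a, b):
    # Hypothetical workload whose compiled cost we want to inspect.
    return a @ b


a = jnp.ones((128, 128), dtype=jnp.float32)
b = jnp.ones((128, 128), dtype=jnp.float32)

# Lower and compile ahead of time to obtain a Compiled object.
compiled = jax.jit(matmul).lower(a, b).compile()
analysis = compiled.cost_analysis()

# Normalize across versions: a list/tuple of dicts or a single dict,
# and possibly None if the backend provides no analysis.
if isinstance(analysis, (list, tuple)):
    analysis = analysis[0] if analysis else None

flops = analysis.get("flops") if analysis else None
print(flops)  # estimated floating-point operations, or None if unavailable
```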
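jax.tree_util.tree_leaves_with_path pairs each leaf of a pytree with its key path. A short sketch on a hypothetical nested parameter dict, using jax.tree_util.keystr to render each path as a readable string. JAX flattens dict entries in sorted key order, which is why 'b' is listed before 'w'.

```python
import jax

# Hypothetical nested parameters: a dict containing a dict and a tuple.
params = {"dense": {"w": 1.0, "b": 0.0}, "scale": (2.0, 3.0)}

# Each entry is (key_path, leaf); keystr turns the key path into a string.
labeled = [
    (jax.tree_util.keystr(path), leaf)
    for path, leaf in jax.tree_util.tree_leaves_with_path(params)
]

for name, leaf in labeled:
    print(name, leaf)
# ['dense']['b'] 0.0
# ['dense']['w'] 1.0
# ['scale'][0] 2.0
# ['scale'][1] 3.0
```

The leaves come back in the same order as jax.tree_util.tree_leaves, so the paths can be zipped against any other flattened view of the same tree.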