Jax: jax-v0.4.34 Release

Release date:
October 2, 2024
Previous version:
jax-v0.4.33 (released September 16, 2024)
Magnitude:
26,451 Diff Delta
Contributors:
54 total committers

326 Commits in this Release

Top Contributors in jax-v0.4.34

hawkinsp
jakevdp
apaszke
dfm
bythew3i
yashk2810
superbobry
rajasekharporeddy
sharadmv
gnecula


Release Notes Published

  • New Functionality

    • This release includes wheels for Python 3.13. Free-threading mode is not yet supported.
    • jax.errors.JaxRuntimeError has been added as a public alias for the formerly private XlaRuntimeError type.
  • Breaking changes

    • jax_pmap_no_rank_reduction flag is set to True by default.
    • array[0] on a pmap result now introduces a reshape (use array[0:1] instead).
    • The per-shard shape (accessible via jax_array.addressable_shards or jax_array.addressable_data(0)) now has a leading axis of size 1, i.e. shape (1, ...). Update code that directly accesses shards accordingly. The rank of the per-shard shape now matches that of the global shape, which is the same behavior as jit. This avoids costly reshapes when passing results from pmap into jit.
    • jax.experimental.host_callback has been deprecated since March 2024 (JAX version 0.4.26). The default value of the --jax_host_callback_legacy configuration option is now True, which means that calls to the jax.experimental.host_callback APIs are implemented in terms of the new jax.experimental.io_callback API. If this breaks your code, for a very limited time, you can set --jax_host_callback_legacy to False. That configuration option will be removed soon, so you should transition to the new JAX callback APIs instead. See #20385 for a discussion.
  • Deprecations

    • In jax.numpy.trim_zeros, non-arraylike arguments or arraylike arguments with ndim != 1 are now deprecated, and in the future will result in an error.
    • Internal pretty-printing tools jax.core.pp_* have been removed, after being deprecated in JAX v0.4.30.
    • jax.lib.xla_client.Device is deprecated; use jax.Device instead.
    • jax.lib.xla_client.XlaRuntimeError has been deprecated. Use jax.errors.JaxRuntimeError instead.
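
For the trim_zeros deprecation above, a minimal sketch of the still-supported 1-D usage (example values are illustrative):

```python
import jax.numpy as jnp

# Supported going forward: a 1-D array-like input.
x = jnp.array([0, 0, 1, 2, 0])
trimmed = jnp.trim_zeros(x)  # strips leading and trailing zeros
print(trimmed)               # [1 2]

# Deprecated: non-arraylike inputs, or inputs with ndim != 1, such as
# jnp.trim_zeros(jnp.zeros((2, 2))); these will become errors.
```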
  • Deletions

    • jax.xla_computation has been deleted, three months after its deprecation in the JAX 0.4.30 release. Please use the AOT APIs to get the same functionality as jax.xla_computation.
    • jax.xla_computation(fn)(*args, **kwargs) can be replaced with jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo').
    • You can also use the .out_info property of jax.stages.Lowered to get the output information (tree structure, shape, and dtype).
    • For cross-backend lowering, you can replace jax.xla_computation(fn, backend='tpu')(*args, **kwargs) with jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo').
    • jax.ShapeDtypeStruct no longer accepts the named_shape argument. The argument was only used by xmap which was removed in 0.4.31.
    • jax.tree.map(f, None, non-None), which previously emitted a DeprecationWarning, now raises an error. None is only a tree-prefix of itself. To preserve the current behavior, you can ask jax.tree.map to treat None as a leaf value by writing: jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None).
    • jax.sharding.XLACompatibleSharding has been removed. Please use jax.sharding.Sharding.
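
The jax.xla_computation replacements described above can be sketched as follows (fn and its input are illustrative; the cross-backend variant is omitted since it needs the target backend available):

```python
import jax
import jax.numpy as jnp

def fn(x):
    return jnp.sin(x) * 2.0

x = jnp.arange(3.0)

# Replacement for the deleted jax.xla_computation(fn)(x):
lowered = jax.jit(fn).lower(x)
hlo = lowered.compiler_ir('hlo')  # HLO module, as xla_computation returned

# Output metadata (tree structure, shape, dtype) via .out_info:
info = lowered.out_info
print(info.shape, info.dtype)
```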
  • Bug fixes

    • Fixed a bug where jax.numpy.cumsum would produce incorrect outputs if a non-boolean input was provided and dtype=bool was specified.
    • Reworked the implementation of jax.numpy.ldexp so that it produces correct gradients.