Jax: jax-v0.4.1 Release

Release date:
December 13, 2022
Previous version:
jax-v0.4.1-rc (released December 13, 2022)

Release Notes Published

  • Changes
    • Support for Python 3.7 has been dropped, in accordance with JAX's version support policy.
    • We introduce jax.Array, a unified array type that subsumes the DeviceArray, ShardedDeviceArray, and GlobalDeviceArray types in JAX. The jax.Array type helps make parallelism a core feature of JAX, simplifies and unifies JAX internals, and allows us to unify jit and pjit. jax.Array is enabled by default in JAX 0.4 and introduces some breaking changes to the pjit API. The jax.Array migration guide can help you migrate your codebase to jax.Array. You can also read the Distributed arrays and automatic parallelization tutorial to understand the new concepts.
    • PartitionSpec and Mesh are no longer experimental. The new API endpoints are jax.sharding.PartitionSpec and jax.sharding.Mesh. jax.experimental.maps.Mesh and jax.experimental.PartitionSpec are deprecated and will be removed in 3 months.
    • The new public endpoint for with_sharding_constraint is jax.lax.with_sharding_constraint.
    • When ABSL flags are used together with jax.config, the ABSL flag values are no longer read or written after the JAX configuration options are initially populated from them. This change improves the performance of reading jax.config options, which are used pervasively in JAX.
    • For TF lowering, the jax2tf.call_tf function now uses the first TF device of the same platform as the embedding JAX computation. Previously, it used the 0th device of the JAX default backend.
    • A number of jax.numpy functions now have their arguments marked as positional-only, matching NumPy.
    • jnp.msort is now deprecated, following the deprecation of np.msort in NumPy 1.24. It will be removed in a future release, in accordance with the API compatibility policy. It can be replaced with jnp.sort(a, axis=0).
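As a rough illustration of the unified jax.Array type, the sketch below checks that both an array created with jax.numpy and the output of a jit-compiled function are instances of jax.Array, and prints the sharding attribute that every jax.Array carries. It assumes a recent JAX release running on a single default device; the function name `double` is hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# An array created through jax.numpy is a jax.Array.
x = jnp.arange(8.0)

@jax.jit
def double(v):  # hypothetical example function
    return v * 2

# The output of a jitted computation is also a jax.Array; there is no
# separate DeviceArray / ShardedDeviceArray type to check for.
y = double(x)
print(isinstance(x, jax.Array), isinstance(y, jax.Array))  # True True

# Every jax.Array records how its data is laid out across devices.
print(y.sharding)
```

Because there is one array type, code that used to branch on DeviceArray versus ShardedDeviceArray can, in this sketch's setting, test against jax.Array alone.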
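The following is a minimal sketch of the now-public jax.sharding.Mesh and jax.sharding.PartitionSpec endpoints combined with jax.lax.with_sharding_constraint. It assumes a recent JAX release in which with_sharding_constraint accepts a jax.sharding.NamedSharding (a class not mentioned in these notes) inside plain jax.jit; under 0.4.1 itself the exact calling convention may differ. The axis name "data" and the function `scale` are hypothetical. On a single-device machine the mesh has one device, so the constraint is trivially satisfied but the code still runs.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1-D mesh over all available devices with one named axis, "data".
devices = np.array(jax.devices())
mesh = Mesh(devices, ("data",))

# PartitionSpec("data") shards dimension 0 across the "data" mesh axis.
sharding = NamedSharding(mesh, PartitionSpec("data"))

@jax.jit
def scale(v):  # hypothetical example function
    # Ask the compiler to keep this intermediate sharded along "data".
    w = jax.lax.with_sharding_constraint(v * 3.0, sharding)
    return w + 1.0

# Keep dimension 0 divisible by the number of devices in the mesh.
n = 4 * len(devices)
x = jnp.arange(n, dtype=jnp.float32)
out = scale(x)
print(out.sharding)
```

The constraint is only a hint about layout: the numerical result is the same as computing `x * 3.0 + 1.0` without it.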
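Since jnp.msort sorts along the first axis, the suggested replacement is jnp.sort with axis=0. The small example below (the input values are arbitrary) shows that each column is sorted independently.

```python
import jax.numpy as jnp

a = jnp.array([[3, 1],
               [1, 2],
               [2, 0]])

# Replacement for the deprecated jnp.msort(a): sort along axis 0,
# i.e. each column is sorted independently.
sorted_a = jnp.sort(a, axis=0)
# sorted_a holds [[1, 0], [2, 1], [3, 2]]
```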