JAX: jax-v0.2.27 Release

Release date: January 18, 2022
Previous version: jax-v0.2.26 (released December 8, 2021)
Magnitude: 11,661 Diff Delta
Contributors: 23 total committers

108 Commits in this Release

Top Contributors in jax-v0.2.27

hawkinsp
mattjj
jakevdp
froystig
yashk2810
gnecula
LenaMartens
tomhennigan
a-googler
ukoxyz

Release Notes Published

  • GitHub commits.

  • Breaking changes:

    • Support for NumPy 1.18 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
    • The host_callback primitives have been simplified to drop the special autodiff handling for hcb.id_tap and id_print. From now on, only the primals are tapped. The old behavior can be obtained (for a limited time) by setting the JAX_HOST_CALLBACK_AD_TRANSFORMS environment variable, or the --jax_host_callback_ad_transforms flag. Additionally, documentation was added on how to implement the old behavior using JAX custom AD APIs (JAX issue #8678).
    • Sorting now matches the behavior of NumPy for 0.0 and NaN regardless of the bit representation. In particular, 0.0 and -0.0 are now treated as equivalent, where previously -0.0 was treated as less than 0.0. Additionally, all NaN representations are now treated as equivalent and sorted to the end of the array. Previously, negative NaN values were sorted to the front of the array, and NaN values with different internal bit representations were not treated as equivalent and were sorted according to those bit patterns (JAX issue #9178).
    • jax.numpy.unique now treats NaN values in the same way as np.unique in NumPy versions 1.21 and newer: at most one NaN value will appear in the uniquified output (JAX issue #9184).
  • Bug fixes:

    • host_callback now supports ad_checkpoint.checkpoint (JAX issue #8907).
  • New features:

    • Added jax.block_until_ready (JAX issue #8941).
    • Added a new debugging flag/environment variable JAX_DUMP_IR_TO=/path. If set, JAX dumps the MHLO/HLO IR it generates for each computation to a file under the given path.
    • Added jax.ensure_compile_time_eval to the public API (JAX issue #7987).
    • jax2tf now supports a flag jax2tf_associative_scan_reductions to change the lowering for associative reductions, e.g., jnp.cumsum, to behave like JAX on CPU and GPU (to use an associative scan). See the jax2tf README for more details (JAX issue #9189).
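
The sorting and jax.numpy.unique changes listed under the breaking changes can be checked with a short script. The following is a minimal sketch, assuming JAX 0.2.27 or newer and NumPy 1.21 or newer; the input arrays and variable names are illustrative only and do not come from the release notes.

```python
import jax.numpy as jnp
import numpy as np

# Illustrative input: signed zeros, finite values, and NaNs with both sign bits.
values = np.array([np.nan, 0.0, -0.0, 1.0, -np.nan, -1.0], dtype=np.float32)

sorted_jax = jnp.sort(jnp.asarray(values))
sorted_np = np.sort(values)

# Per the release notes, every NaN should now sort to the end, as in NumPy.
nan_mask_jax = np.isnan(np.asarray(sorted_jax))
nan_mask_np = np.isnan(sorted_np)
print("JAX  :", sorted_jax)
print("NumPy:", sorted_np)
print("NaN positions agree:", bool((nan_mask_jax == nan_mask_np).all()))

# jnp.unique should keep at most one NaN in its output.
unique_vals = jnp.unique(jnp.array([1.0, jnp.nan, 1.0, jnp.nan]))
nan_count = int(jnp.isnan(unique_vals).sum())
print("unique:", unique_vals, "NaN count:", nan_count)
```

Code that relied on -0.0 sorting strictly before 0.0, or on negative NaN values appearing at the front of a sorted array, would need to be revisited after upgrading.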
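
Two of the new features above, jax.block_until_ready and jax.ensure_compile_time_eval, are easiest to understand from a small example. The sketch below assumes JAX 0.2.27 or newer; the functions `gram` and `scaled` and the array sizes are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def gram(x):
    # Hypothetical workload: a small matrix product.
    return jnp.dot(x, x.T)


x = jnp.ones((64, 64))

# JAX dispatches work asynchronously; jax.block_until_ready waits until the
# result is actually computed and returns it, which is useful when timing code.
y = jax.block_until_ready(gram(x))
print(y.shape)


@jax.jit
def scaled(v):
    # Inside this context, operations on constants are evaluated eagerly at
    # trace time instead of being staged out, so the result can be converted
    # to a Python float even though we are inside a jitted function.
    with jax.ensure_compile_time_eval():
        scale = float(jnp.sin(3.0))
    return v * scale


result = scaled(2.0)
print(result)
```

Without the context manager, calling float() on a value computed inside a jitted function would typically raise a tracer concretization error, which is the situation jax.ensure_compile_time_eval is meant to address.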