You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
bugs](https://github.com/jax-ml/jax/issues), and letting us know what you
54
54
think!
@@ -83,15 +83,15 @@ perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example gra
83
83
## Quickstart: Colab in the Cloud
84
84
Jump right in using a notebook in your browser, connected to a Google Cloud GPU.
85
85
Here are some starter notebooks:
86
-
-[The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/quickstart.html)
86
+
-[The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://docs.jax.dev/en/latest/quickstart.html)
87
87
-[Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
88
88
89
89
**JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU
-[The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
94
-
-[Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
93
+
-[The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)
94
+
-[Common gotchas and sharp edges](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html)
1. JAX transformations only work on [pure functions](https://en.wikipedia.org/wiki/Pure_function), which don't have side-effects and respect [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency) (i.e. object identity testing with `is` isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like `Exception: Can't lift Traced...` or `Exception: Different traces at same level`.
348
348
1.[In-place mutating updates of
349
-
arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
349
+
arrays](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://docs.jax.dev/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
350
350
1.[Random numbers are
351
-
different](https://jax.readthedocs.io/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md).
351
+
different](https://docs.jax.dev/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md).
Copy file name to clipboardExpand all lines: cloud_tpu_colabs/Pmap_Cookbook.ipynb
+2-2
Original file line number
Diff line number
Diff line change
@@ -59,7 +59,7 @@
59
59
"id": "2e_06-OAJNyi"
60
60
},
61
61
"source": [
62
-
"A basic starting point is expressing parallel maps with [`pmap`](https://jax.readthedocs.io/en/latest/jax.html#jax.pmap):"
62
+
"A basic starting point is expressing parallel maps with [`pmap`](https://docs.jax.dev/en/latest/jax.html#jax.pmap):"
63
63
]
64
64
},
65
65
{
@@ -407,7 +407,7 @@
407
407
"source": [
408
408
"When writing nested `pmap` functions in the decorator style, axis names are resolved according to lexical scoping.\n",
409
409
"\n",
410
-
"Check [the JAX reference documentation](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) for a complete list of the parallel operators. More are being added!\n",
410
+
"Check [the JAX reference documentation](https://docs.jax.dev/en/latest/jax.lax.html#parallel-operators) for a complete list of the parallel operators. More are being added!\n",
411
411
"\n",
412
412
"Here's how to use `lax.ppermute` to implement a simple halo exchange for a [Rule 30](https://en.wikipedia.org/wiki/Rule_30) simulation:"
Copy file name to clipboardExpand all lines: docs/advanced-autodiff.md
+2-2
Original file line number
Diff line number
Diff line change
@@ -876,7 +876,7 @@ There are two ways to define differentiation rules in JAX:
876
876
1. Using {func}`jax.custom_jvp` and {func}`jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and
877
877
2. Defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.
878
878
879
-
This notebook is about #1. To read instead about #2, refer to the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html).
879
+
This notebook is about #1. To read instead about #2, refer to the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html).
880
880
881
881
882
882
### TL;DR: Custom JVPs with {func}`jax.custom_jvp`
#### Working with `list` / `tuple` / `dict` containers (and other pytrees)
1610
1610
1611
-
You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints.
1611
+
You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://docs.jax.dev/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints.
1612
1612
1613
1613
Here's a contrived example with {func}`jax.custom_jvp`:
0 commit comments