
Commit e02faab

hawkinsp authored and Google-ML-Automation committed
Replace references to jax.readthedocs.io with docs.jax.dev.
PiperOrigin-RevId: 745156931
1 parent b8353d1 commit e02faab

File tree

112 files changed (+323 −323 lines)
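The change itself is a uniform string substitution: every `jax.readthedocs.io` host becomes `docs.jax.dev`, with paths unchanged. A sweep of that kind could be reproduced locally with a short script along these lines (illustrative only; the commit was exported via internal tooling, and the file-type filter below is an assumption):

```python
# Illustrative sketch only -- not the tool that produced this commit.
# Rewrites every jax.readthedocs.io reference to docs.jax.dev in text files.
import pathlib

OLD, NEW = "jax.readthedocs.io", "docs.jax.dev"
EXTENSIONS = {".md", ".rst", ".py", ".ipynb", ".html", ".cfg", ".toml"}  # assumed set

for path in pathlib.Path(".").rglob("*"):
    if path.is_file() and path.suffix in EXTENSIONS:
        text = path.read_text(encoding="utf-8", errors="ignore")
        if OLD in text:
            path.write_text(text.replace(OLD, NEW), encoding="utf-8")
```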


CHANGELOG.md

+34 −34
Large diffs are not rendered by default.

CONTRIBUTING.md

+1 −1

@@ -1,4 +1,4 @@
 # Contributing to JAX

 For information on how to contribute to JAX, see
-[Contributing to JAX](https://jax.readthedocs.io/en/latest/contributing.html)
+[Contributing to JAX](https://docs.jax.dev/en/latest/contributing.html)

README.md

+30 −30

@@ -11,8 +11,8 @@
 | [**Transformations**](#transformations)
 | [**Install guide**](#installation)
 | [**Neural net libraries**](#neural-network-libraries)
-| [**Change logs**](https://jax.readthedocs.io/en/latest/changelog.html)
-| [**Reference docs**](https://jax.readthedocs.io/en/latest/)
+| [**Change logs**](https://docs.jax.dev/en/latest/changelog.html)
+| [**Reference docs**](https://docs.jax.dev/en/latest/)


 ## What is JAX?
@@ -48,7 +48,7 @@ are instances of such transformations. Others are
 parallel programming of multiple accelerators, with more to come.

 This is a research project, not an official Google product. Expect
-[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
+[sharp edges](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html).
 Please help by trying it out, [reporting
 bugs](https://github.com/jax-ml/jax/issues), and letting us know what you
 think!
@@ -83,15 +83,15 @@ perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example gra
 ## Quickstart: Colab in the Cloud
 Jump right in using a notebook in your browser, connected to a Google Cloud GPU.
 Here are some starter notebooks:
-- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/quickstart.html)
+- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://docs.jax.dev/en/latest/quickstart.html)
 - [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)

 **JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU
 Colabs](https://github.com/jax-ml/jax/tree/main/cloud_tpu_colabs).

 For a deeper dive into JAX:
-- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
-- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
+- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)
+- [Common gotchas and sharp edges](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html)
 - See the [full list of
 notebooks](https://github.com/jax-ml/jax/tree/main/docs/notebooks).

@@ -105,7 +105,7 @@ Here are four transformations of primary interest: `grad`, `jit`, `vmap`, and

 JAX has roughly the same API as [Autograd](https://github.com/hips/autograd).
 The most popular function is
-[`grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad)
+[`grad`](https://docs.jax.dev/en/latest/jax.html#jax.grad)
 for reverse-mode gradients:

 ```python
@@ -129,13 +129,13 @@ print(grad(grad(grad(tanh)))(1.0))
 ```

 For more advanced autodiff, you can use
-[`jax.vjp`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp) for
+[`jax.vjp`](https://docs.jax.dev/en/latest/jax.html#jax.vjp) for
 reverse-mode vector-Jacobian products and
-[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) for
+[`jax.jvp`](https://docs.jax.dev/en/latest/jax.html#jax.jvp) for
 forward-mode Jacobian-vector products. The two can be composed arbitrarily with
 one another, and with other JAX transformations. Here's one way to compose those
 to make a function that efficiently computes [full Hessian
-matrices](https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html#jax.hessian):
+matrices](https://docs.jax.dev/en/latest/_autosummary/jax.hessian.html#jax.hessian):

 ```python
 from jax import jit, jacfwd, jacrev
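The trailing context above is the start of the README's own Hessian listing. As an illustrative sketch of the forward-over-reverse composition it describes (not the README's exact code):

```python
import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def hessian(fun):
    # Reverse mode gives the gradient; forward mode over it gives the full Hessian.
    return jit(jacfwd(jacrev(fun)))

def f(x):
    return jnp.dot(x, x) ** 2

print(hessian(f)(jnp.array([1.0, 2.0])))  # a 2x2 Hessian matrix
```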
@@ -160,15 +160,15 @@ print(abs_val_grad(-1.0)) # prints -1.0 (abs_val is re-evaluated)
 ```

 See the [reference docs on automatic
-differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+differentiation](https://docs.jax.dev/en/latest/jax.html#automatic-differentiation)
 and the [JAX Autodiff
-Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
+Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)
 for more.

 ### Compilation with `jit`

 You can use XLA to compile your functions end-to-end with
-[`jit`](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
+[`jit`](https://docs.jax.dev/en/latest/jax.html#just-in-time-compilation-jit),
 used either as an `@jit` decorator or as a higher-order function.

 ```python
@@ -189,12 +189,12 @@ You can mix `jit` and `grad` and any other JAX transformation however you like.

 Using `jit` puts constraints on the kind of Python control flow
 the function can use; see
-the tutorial on [Control Flow and Logical Operators with JIT](https://jax.readthedocs.io/en/latest/control-flow.html)
+the tutorial on [Control Flow and Logical Operators with JIT](https://docs.jax.dev/en/latest/control-flow.html)
 for more.

 ### Auto-vectorization with `vmap`

-[`vmap`](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) is
+[`vmap`](https://docs.jax.dev/en/latest/jax.html#vectorization-vmap) is
 the vectorizing map.
 It has the familiar semantics of mapping a function along array axes, but
 instead of keeping the loop on the outside, it pushes the loop down into a
@@ -259,7 +259,7 @@ differentiation for fast Jacobian and Hessian matrix calculations in
 ### SPMD programming with `pmap`

 For parallel programming of multiple accelerators, like multiple GPUs, use
-[`pmap`](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap).
+[`pmap`](https://docs.jax.dev/en/latest/jax.html#parallelization-pmap).
 With `pmap` you write single-program multiple-data (SPMD) programs, including
 fast parallel collective communication operations. Applying `pmap` will mean
 that the function you write is compiled by XLA (similarly to `jit`), then
@@ -284,7 +284,7 @@ print(pmap(jnp.mean)(result))
 ```

 In addition to expressing pure maps, you can use fast [collective communication
-operations](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators)
+operations](https://docs.jax.dev/en/latest/jax.lax.html#parallel-operators)
 between devices:

 ```python
@@ -341,20 +341,20 @@ for more.

 For a more thorough survey of current gotchas, with examples and explanations,
 we highly recommend reading the [Gotchas
-Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
+Notebook](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html).
 Some standouts:

 1. JAX transformations only work on [pure functions](https://en.wikipedia.org/wiki/Pure_function), which don't have side-effects and respect [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency) (i.e. object identity testing with `is` isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like `Exception: Can't lift Traced...` or `Exception: Different traces at same level`.
 1. [In-place mutating updates of
-arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
+arrays](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://docs.jax.dev/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
 1. [Random numbers are
-different](https://jax.readthedocs.io/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md).
+different](https://docs.jax.dev/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md).
 1. If you're looking for [convolution
-operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html),
+operators](https://docs.jax.dev/en/latest/notebooks/convolutions.html),
 they're in the `jax.lax` package.
 1. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and
 [to enable
-double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision)
+double-precision](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision)
 (64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at
 startup (or set the environment variable `JAX_ENABLE_X64=True`).
 On TPU, JAX uses 32-bit values by default for everything _except_ internal
@@ -368,14 +368,14 @@ Some standouts:
 and NumPy types aren't preserved, namely `np.add(1, np.array([2],
 np.float32)).dtype` is `float64` rather than `float32`.
 1. Some transformations, like `jit`, [constrain how you can use Python control
-flow](https://jax.readthedocs.io/en/latest/control-flow.html).
+flow](https://docs.jax.dev/en/latest/control-flow.html).
 You'll always get loud errors if something goes wrong. You might have to use
 [`jit`'s `static_argnums`
-parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
+parameter](https://docs.jax.dev/en/latest/jax.html#just-in-time-compilation-jit),
 [structured control flow
-primitives](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators)
+primitives](https://docs.jax.dev/en/latest/jax.lax.html#control-flow-operators)
 like
-[`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan),
+[`lax.scan`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan),
 or just use `jit` on smaller subfunctions.

 ## Installation
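The gotchas in the hunk above (functional array updates, explicit PRNG keys, and opt-in 64-bit precision) look roughly like this in practice; this is a brief illustrative sketch using public JAX APIs, not code taken from the README:

```python
import jax
import jax.numpy as jnp

# 1. No in-place mutation: use the functional .at[] updates instead of x[i] += y.
x = jnp.zeros(4)
x = x.at[1].add(3.0)          # returns a new array; buffers are reused under jit

# 2. Random numbers take an explicit key that you split; there is no global state.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
sample = jax.random.normal(subkey, (3,))

# 3. float32 by default; 64-bit must be enabled explicitly at startup.
# jax.config.update("jax_enable_x64", True)   # or JAX_ENABLE_X64=True in the env
print(x, sample, sample.dtype)
```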
@@ -403,7 +403,7 @@ Some standouts:
 | Mac GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |
 | Intel GPU | Follow [Intel's instructions](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md). |

-See [the documentation](https://jax.readthedocs.io/en/latest/installation.html)
+See [the documentation](https://docs.jax.dev/en/latest/installation.html)
 for information on alternative installation strategies. These include compiling
 from source, installing with Docker, using other versions of CUDA, a
 community-supported conda build, and answers to some frequently-asked questions.
@@ -417,7 +417,7 @@ for training neural networks in JAX. If you want a fully featured library for ne
 training with examples and how-to guides, try
 [Flax](https://github.com/google/flax) and its [documentation site](https://flax.readthedocs.io/en/latest/nnx/index.html).

-Check out the [JAX Ecosystem section](https://jax.readthedocs.io/en/latest/#ecosystem)
+Check out the [JAX Ecosystem section](https://docs.jax.dev/en/latest/#ecosystem)
 on the JAX documentation site for a list of JAX-based network libraries, which includes
 [Optax](https://github.com/deepmind/optax) for gradient processing and
 optimization, [chex](https://github.com/deepmind/chex) for reliable code and testing, and
@@ -452,7 +452,7 @@ paper.
 ## Reference documentation

 For details about the JAX API, see the
-[reference documentation](https://jax.readthedocs.io/).
+[reference documentation](https://docs.jax.dev/).

 For getting started as a JAX developer, see the
-[developer documentation](https://jax.readthedocs.io/en/latest/developer.html).
+[developer documentation](https://docs.jax.dev/en/latest/developer.html).

cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb

+1 −1

@@ -225,7 +225,7 @@
 "* Jacobian pre-accumulation for elementwise operations (like `gelu`)\n",
 "\n",
 "\n",
-"For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)."
+"For much more, see the [JAX Autodiff Cookbook (Part 1)](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)."
 ]
 },
 {
cloud_tpu_colabs/JAX_demo.ipynb

+1 −1

@@ -315,7 +315,7 @@
 "* Jacobian pre-accumulation for elementwise operations (like `gelu`)\n",
 "\n",
 "\n",
-"For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)."
+"For much more, see the [JAX Autodiff Cookbook (Part 1)](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)."
 ]
 },
 {

cloud_tpu_colabs/Pmap_Cookbook.ipynb

+2 −2

@@ -59,7 +59,7 @@
 "id": "2e_06-OAJNyi"
 },
 "source": [
-"A basic starting point is expressing parallel maps with [`pmap`](https://jax.readthedocs.io/en/latest/jax.html#jax.pmap):"
+"A basic starting point is expressing parallel maps with [`pmap`](https://docs.jax.dev/en/latest/jax.html#jax.pmap):"
 ]
 },
 {
@@ -407,7 +407,7 @@
 "source": [
 "When writing nested `pmap` functions in the decorator style, axis names are resolved according to lexical scoping.\n",
 "\n",
-"Check [the JAX reference documentation](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) for a complete list of the parallel operators. More are being added!\n",
+"Check [the JAX reference documentation](https://docs.jax.dev/en/latest/jax.lax.html#parallel-operators) for a complete list of the parallel operators. More are being added!\n",
 "\n",
 "Here's how to use `lax.ppermute` to implement a simple halo exchange for a [Rule 30](https://en.wikipedia.org/wiki/Rule_30) simulation:"
 ]
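As a minimal illustration of the `pmap`-plus-collectives pattern these notebook cells reference, here is a sketch that runs on however many local devices are available (illustrative; not taken from the cookbook):

```python
import jax
import jax.numpy as jnp
from jax import lax, pmap

n = jax.local_device_count()          # 1 on a plain CPU host, 8 on a TPU v2/v3 host
x = jnp.arange(float(n * 4)).reshape(n, 4)

# Each device holds one row; psum over the named axis sums across devices.
def normalize(row):
    return row / lax.psum(row, axis_name="i")

print(pmap(normalize, axis_name="i")(x))   # each column sums to 1 across devices
```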

cloud_tpu_colabs/README.md

+1 −1

@@ -4,7 +4,7 @@ The same JAX code that runs on CPU and GPU can also be run on TPU. Cloud TPUs
 have the advantage of quickly giving you access to multiple TPU accelerators,
 including in [Colab](https://research.google.com/colaboratory/). All of the
 example notebooks here use
-[`jax.pmap`](https://jax.readthedocs.io/en/latest/jax.html#jax.pmap) to run JAX
+[`jax.pmap`](https://docs.jax.dev/en/latest/jax.html#jax.pmap) to run JAX
 computation across multiple TPU cores from Colab. You can also run the same code
 directly on a [Cloud TPU
 VM](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm).

docs/README.md

+1 −1

@@ -1,2 +1,2 @@
 To rebuild the documentation,
-see [Update Documentation](https://jax.readthedocs.io/en/latest/developer.html#update-documentation).
+see [Update Documentation](https://docs.jax.dev/en/latest/developer.html#update-documentation).

docs/about.md

+8 −8

@@ -19,7 +19,7 @@ technology stack](#components). First, we design the `jax` module
 to be
 [composable](https://github.com/jax-ml/jax?tab=readme-ov-file#transformations)
 and
-[extensible](https://jax.readthedocs.io/en/latest/jax.extend.html), so
+[extensible](https://docs.jax.dev/en/latest/jax.extend.html), so
 that a wide variety of domain-specific libraries can thrive outside of
 it in a decentralized manner. Second, we lean heavily on a modular
 backend stack (compiler and runtime) to target different
@@ -42,10 +42,10 @@ scale.
 JAX's day-to-day development takes place in the open on GitHub, using
 pull requests, the issue tracker, discussions, and [JAX Enhancement
 Proposals
-(JEPs)](https://jax.readthedocs.io/en/latest/jep/index.html). Reading
+(JEPs)](https://docs.jax.dev/en/latest/jep/index.html). Reading
 and participating in these is a good way to get involved. We also
 maintain [developer
-notes](https://jax.readthedocs.io/en/latest/contributor_guide.html)
+notes](https://docs.jax.dev/en/latest/contributor_guide.html)
 that cover JAX's internal design.

 The JAX core team determines whether to accept changes and
@@ -56,7 +56,7 @@ intricate decision structure over time (e.g. with designated area
 owners) if/when it becomes useful to do so.

 For more see [contributing to
-JAX](https://jax.readthedocs.io/en/latest/contributing.html).
+JAX](https://docs.jax.dev/en/latest/contributing.html).

 (components)=
 ## A modular stack
@@ -71,7 +71,7 @@ and (b) an advancing hardware landscape, we lean heavily on
 While the JAX core library focuses on the fundamentals, we want to
 encourage domain-specific libraries and tools to be built on top of
 JAX. Indeed, [many
-libraries](https://jax.readthedocs.io/en/latest/#ecosystem) have
+libraries](https://docs.jax.dev/en/latest/#ecosystem) have
 emerged around JAX to offer higher-level features and extensions.

 How do we encourage such decentralized development? We guide it with
@@ -80,11 +80,11 @@ building blocks (e.g. numerical primitives, NumPy operations, arrays,
 and transformations), encouraging auxiliary libraries to develop
 utilities as needed for their domain. In addition, JAX exposes a
 handful of more advanced APIs for
-[customization](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html)
+[customization](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html)
 and
-[extensibility](https://jax.readthedocs.io/en/latest/jax.extend.html). Libraries
+[extensibility](https://docs.jax.dev/en/latest/jax.extend.html). Libraries
 can [lean on these
-APIs](https://jax.readthedocs.io/en/latest/building_on_jax.html) in
+APIs](https://docs.jax.dev/en/latest/building_on_jax.html) in
 order to use JAX as an internal means of implementation, to integrate
 more with its transformations like autodiff, and more.

docs/advanced-autodiff.md

+2 −2

@@ -876,7 +876,7 @@ There are two ways to define differentiation rules in JAX:
 1. Using {func}`jax.custom_jvp` and {func}`jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and
 2. Defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.

-This notebook is about #1. To read instead about #2, refer to the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html).
+This notebook is about #1. To read instead about #2, refer to the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html).


 ### TL;DR: Custom JVPs with {func}`jax.custom_jvp`
@@ -1608,7 +1608,7 @@ Array(-0.91113025, dtype=float32)

 #### Working with `list` / `tuple` / `dict` containers (and other pytrees)

-You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints.
+You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://docs.jax.dev/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints.

 Here's a contrived example with {func}`jax.custom_jvp`:
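As an independent sketch of the `jax.custom_jvp` pattern this passage describes (illustrative; not the document's own contrived example):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
    return jnp.log(1.0 + x)

@f.defjvp
def f_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    primal_out = f(x)
    tangent_out = x_dot / (1.0 + x)  # hand-written rule: d/dx log(1+x) = 1/(1+x)
    return primal_out, tangent_out

print(jax.grad(f)(0.5))  # ~0.6667, computed via the custom rule
```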

docs/aot.md

+1 −1

@@ -26,7 +26,7 @@ are arrays, JAX does the following in order:
 carries out this specialization by a process that we call
 _tracing_. During tracing, JAX stages the specialization of `F` to
 a jaxpr, which is a function in the [Jaxpr intermediate
-language](https://jax.readthedocs.io/en/latest/jaxpr.html).
+language](https://docs.jax.dev/en/latest/jaxpr.html).

 2. **Lower** this specialized, staged-out computation to the XLA compiler's
 input language, StableHLO.
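A short sketch of that specialize, lower, and compile sequence using JAX's public ahead-of-time entry points (illustrative; the function `f` and its input are arbitrary):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2 + 1.0

x = jnp.arange(4.0)

print(jax.make_jaxpr(f)(x))       # 1. the jaxpr tracing stages out for this shape/dtype
lowered = jax.jit(f).lower(x)     # 2. lower the staged computation toward XLA
print(lowered.as_text()[:300])    #    peek at the StableHLO text
compiled = lowered.compile()      # 3. XLA-compile to an executable
print(compiled(x))                #    run the compiled artifact
```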
