jax-triton

The jax-triton repository contains integrations between JAX and Triton.

Documentation can be found here.

This is not an officially supported Google product.

Quickstart

The main function of interest is jax_triton.triton_call for applying Triton functions to JAX arrays, including inside jax.jit-compiled functions. For example, we can define a kernel from the Triton tutorial:

import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    length,
    output_ptr,
    block_size: tl.constexpr,
):
  """Adds two vectors."""
  pid = tl.program_id(axis=0)
  block_start = pid * block_size
  offsets = block_start + tl.arange(0, block_size)
  mask = offsets < length
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  output = x + y
  tl.store(output_ptr + offsets, output, mask=mask)
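
To make the indexing concrete, here is a pure-Python emulation of what each program instance of add_kernel computes. It is an illustrative sketch only and needs neither Triton nor a GPU. The helper name emulate_add_kernel and the use of Python lists in place of device pointers are assumptions made for this example, and real Triton program instances run in parallel rather than in a loop.

```python
def emulate_add_kernel(x, y, block_size, num_programs):
    """Emulates add_kernel on Python lists, one loop iteration per program id.

    Hypothetical helper for illustration only; it is not part of Triton or
    jax-triton.
    """
    length = len(x)
    output = [None] * length  # None marks elements that no program wrote to
    for pid in range(num_programs):  # stands in for tl.program_id(axis=0)
        block_start = pid * block_size
        # Stands in for block_start + tl.arange(0, block_size).
        offsets = [block_start + i for i in range(block_size)]
        for offset in offsets:
            # Stands in for mask = offsets < length: out-of-range lanes
            # neither load nor store.
            if offset < length:
                output[offset] = x[offset] + y[offset]
    return output


x = list(range(10))
y = list(range(10, 20))
# Two programs of 8 lanes cover all 10 elements; the mask disables the
# last 6 lanes of the second program.
result = emulate_add_kernel(x, y, block_size=8, num_programs=2)
print(result)
```

Launching only one program in this emulation leaves the last two elements unwritten, which is why the number of programs has to cover the full input length.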

Then we can apply it to JAX arrays using jax_triton.triton_call:

import jax
import jax.numpy as jnp
import jax_triton as jt
import triton

def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  block_size = 8
  return jt.triton_call(
      x,
      y,
      x.size,
      kernel=add_kernel,
      out_shape=out_shape,
      # Round up so a trailing partial block is still covered; the kernel's
      # mask disables the out-of-range lanes.
      grid=(triton.cdiv(x.size, block_size),),
      block_size=block_size)

x_val = jnp.arange(8)
y_val = jnp.arange(8, 16)
print(add(x_val, y_val))
print(jax.jit(add)(x_val, y_val))
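
The grid argument sets how many program instances are launched, so it has to cover the whole input. The sketch below compares floor and ceiling division for that count in plain Python. The helper names num_blocks_floor and num_blocks_ceil are hypothetical and used only here. Ceiling division (the same quantity triton.cdiv computes) is safe with add_kernel because its mask disables lanes past length.

```python
def num_blocks_floor(length: int, block_size: int) -> int:
    """Counts only full blocks; a trailing partial block is not covered."""
    return length // block_size


def num_blocks_ceil(length: int, block_size: int) -> int:
    """Counts full blocks plus one extra block for any remainder."""
    return -(-length // block_size)


block_size = 8
for length in (8, 10, 16, 17):
    floor_blocks = num_blocks_floor(length, block_size)
    ceil_blocks = num_blocks_ceil(length, block_size)
    # Elements that would be covered by each launch configuration.
    covered_floor = min(length, floor_blocks * block_size)
    covered_ceil = min(length, ceil_blocks * block_size)
    print(f"length={length}: floor covers {covered_floor}, "
          f"ceil covers {covered_ceil}")
```

The two counts agree whenever the length is a multiple of the block size, as with the 8-element arrays in the example above, and differ by one block otherwise.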

For more complete examples, see the examples directory, especially fused_attention.py and the fused attention notebook (ipynb).

Installation

$ pip install jax-triton

Make sure you have a CUDA-compatible jax installed. For example, you could run:

$ pip install "jax[cuda12]"

jax-triton currently requires building the latest version of triton from source.
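
After installing, it can help to confirm which versions of the relevant packages ended up in the environment. The following is a minimal sketch using only the Python standard library. The helper name installed_versions is hypothetical, the package list is an assumption about what is worth checking, and the script reports versions only; it does not verify that CUDA or Triton actually work.

```python
from importlib import metadata

# PyPI distribution names relevant to a jax-triton setup (assumed list).
PACKAGES = ("jax", "jaxlib", "jax-triton", "triton")


def installed_versions(packages=PACKAGES):
    """Maps each distribution name to its installed version, or None."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions().items():
    print(f"{name}: {version or 'not installed'}")
```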

Development

To develop jax-triton, you can clone the repo with:

$ git clone https://github.com/jax-ml/jax-triton.git

and do an editable install with:

$ cd jax-triton
$ pip install -e .

To run the jax-triton tests, you'll need pytest:

$ pip install pytest
$ pytest tests/