Add experimental replicate.use() function #438

Open pull request: 38 commits to merge into base branch `main`.

Commits
e4398b9
Add initial test for `use()` functionality
aron Jun 2, 2025
587ed82
Add unit tests for existing `use()` functionality
aron Jun 2, 2025
70e1da7
Fix warning for missing pytest setting
aron Jun 2, 2025
607150f
Remove unused ignores
aron Jun 2, 2025
736604b
Refactor tests to be easier to work with
aron Jun 2, 2025
40c97a7
Support conversion of file outputs into Path in use()
aron Jun 2, 2025
196aef0
Add support for returning an iterator from use()
aron Jun 2, 2025
9a293e6
Fix bug in output_iterator when prediction is terminal
aron Jun 2, 2025
83d8fad
Update OutputIterator to use polling implementation
aron Jun 2, 2025
f017857
Ensure OutputIterator objects are converted into strings
aron Jun 2, 2025
35eb88b
Implement PathProxy as a way to defer download of file data
aron Jun 2, 2025
8d85629
Skip downloading files passed directly into other models in use()
aron Jun 2, 2025
20a37d1
Add get_url_path() helper to get underlying URL for a PathProxy object
aron Jun 2, 2025
bae5dc8
Export the `use` function
aron Jun 2, 2025
ae1589f
Document the `use()` functionality
aron Jun 2, 2025
82e40ce
Linting
aron Jun 2, 2025
bad0ce4
Rework the async test variant to give better test names
aron Jun 2, 2025
35e66dc
Fix typing of Function create()
aron Jun 2, 2025
639f234
Add support for typing use() function
aron Jun 3, 2025
65a89d3
Clean up tests
aron Jun 3, 2025
b79a5cd
Clean up prediction fixtures
aron Jun 3, 2025
80ce4e5
Remove redundant fixture data
aron Jun 3, 2025
bc5d7d8
Speed up test runs by using REPLICATE_POLL_INTERVAL
aron Jun 3, 2025
4111e82
Actually use PathProxy
aron Jun 3, 2025
e8acdb2
Use new URLPath instead of PathProxy
aron Jun 3, 2025
2df34ed
Silence warning when using cog.current_scope()
aron Jun 3, 2025
c982d53
Add asyncio support to `use()` function
aron Jun 4, 2025
050798e
Document asyncio mode for `use()`
aron Jun 4, 2025
a00b5d2
Improve typing of OutputIterator
aron Jun 4, 2025
83793a0
Correctly resolve OutputIterator when passed to `create()`
aron Jun 4, 2025
2afd364
URLPath.__str__() uses __fspath__()
aron Jun 4, 2025
dd64e91
Implement use(ref, streaming=True) to return iterators
aron Jun 4, 2025
cd12cf4
Correctly handle concatenated output when not streaming
aron Jun 4, 2025
1185b7b
Remove useless comments
aron Jun 4, 2025
f160fef
Implement `streaming` argument for `use()`
aron Jun 4, 2025
57bab3e
Fix lint errors
aron Jun 4, 2025
adb4fa7
Remove top-level restrictions
aron Jun 4, 2025
3b5200b
Clean up linting issues
aron Jun 6, 2025
216 changes: 216 additions & 0 deletions README.md
@@ -503,6 +503,222 @@ replicate = Client(
> Never hardcode authentication credentials like API tokens into your code.
> Instead, pass them as environment variables when running your program.

## Experimental `use()` interface

Versions of `replicate >= 1.0.8` include a new experimental `use()` function that is intended to make running a model feel like calling a function rather than making an API request.

Some key differences from `replicate.run()`:

1. You "import" the model using the `use()` syntax, then call the model like a function.
2. The output type matches the model definition.
3. Built-in support for streaming for all models.
4. File outputs are represented as `PathLike` objects and downloaded to disk when first used*.

> [!NOTE]
> \* We've replaced the `FileOutput` implementation with `Path` objects. To avoid downloading files before they are needed, we've implemented a `PathProxy` class that defers the download until the first time the object is used. If you need the underlying URL of the `Path` object, use the `get_path_url(path: Path) -> str` helper.
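
The deferred download described in the note can be pictured with a small `os.PathLike` wrapper. This is only a minimal sketch, not the PR's implementation: the class name `LazyURLPath`, the injected `fetch` callable and the example URL are all assumptions made so the example runs without network access.

```py
import os
import tempfile
from typing import Callable, Optional


class LazyURLPath(os.PathLike):
    """Hypothetical stand-in for a path wrapper that defers its download."""

    def __init__(self, url: str, fetch: Callable[[str], bytes]) -> None:
        self.url = url
        self._fetch = fetch  # injected (assumption) so the sketch needs no network
        self._local_path: Optional[str] = None

    def __fspath__(self) -> str:
        # Download once, on first use, then reuse the cached temporary file.
        if self._local_path is None:
            fd, self._local_path = tempfile.mkstemp(suffix=os.path.basename(self.url))
            with os.fdopen(fd, "wb") as f:
                f.write(self._fetch(self.url))
        return self._local_path


calls: list[str] = []


def fake_fetch(url: str) -> bytes:
    """Record the request and return placeholder bytes instead of downloading."""
    calls.append(url)
    return b"fake image bytes"


path = LazyURLPath("https://example.com/output.webp", fake_fetch)
assert calls == []  # nothing has been fetched yet

with open(path, "rb") as f:  # open() calls __fspath__(), which triggers the fetch
    data = f.read()
```

Because `open()`, `pathlib.Path` and many third-party libraries accept any `os.PathLike`, a wrapper of this shape can be handed to them directly while the download cost is only paid on first use.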

### Examples

To use a model:

> [!IMPORTANT]
> For now `use()` MUST be called in the top level module scope. We may relax this in future.

```py
import replicate

flux_dev = replicate.use("black-forest-labs/flux-dev")
outputs = flux_dev(prompt="a cat wearing an amusing hat")

for output in outputs:
    print(output) # Path(/tmp/output.webp)
```

Models that implement iterators return the output of the completed run as a list unless explicitly streaming (see the Streaming section below). Language models that define `x-cog-array-display: concatenate` return strings:

```py
claude = replicate.use("anthropic/claude-4-sonnet")

output = claude(prompt="Give me a recipe for tasty smashed avocado on sourdough toast that could feed all of California.")

print(output) # "Here's a recipe to feed all of California (about 39 million people)! ..."
```
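
The list-versus-string rule above can be sketched as a small function over the model's output schema. This is a hedged illustration, not the library's code: `resolve_output` is a hypothetical helper, and the keys `x-cog-array-type` and `x-cog-array-display` are assumed to be how Cog marks iterator and concatenated outputs.

```py
from typing import Any


def resolve_output(output_schema: dict, chunks: list[Any]) -> Any:
    """Collapse a finished iterator output into a str or a list (hypothetical helper)."""
    is_iterator = output_schema.get("x-cog-array-type") == "iterator"
    is_concatenated = output_schema.get("x-cog-array-display") == "concatenate"
    if is_iterator and is_concatenated:
        # Language-model style output: join the chunks into one string.
        return "".join(str(chunk) for chunk in chunks)
    # Everything else is returned as a plain list of items.
    return list(chunks)


llm_schema = {
    "type": "array",
    "x-cog-array-type": "iterator",
    "x-cog-array-display": "concatenate",
}
image_schema = {"type": "array", "x-cog-array-type": "iterator"}

text = resolve_output(llm_schema, ["Here's a recipe ", "to feed all", " of California"])
files = resolve_output(image_schema, ["a.webp", "b.webp"])
```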

You can pass the results of one model directly into another:

```py
import replicate

flux_dev = replicate.use("black-forest-labs/flux-dev")
claude = replicate.use("anthropic/claude-4-sonnet")

images = flux_dev(prompt="a cat wearing an amusing hat")

result = claude(prompt="describe this image for me", image=images[0])
print(str(result)) # "This shows an image of a cat wearing a hat ..."
```

Review comment (Member): When I pass images[0] here, what's happening under the hood? Is it using the existing replicate.delivery HTTPS URL of the output file?

Reply (Contributor Author): Yes, file outputs are now a thin wrapper around Path (see PathProxy) that will defer the download of the remote file to disk until the first time they are used.

To get the underlying URL, there is a get_path_url() helper that will return the remote URL of the file.

This is then used internally so that the file is never downloaded in cases like this where an output is just passed as an input to another model.

Feedback on this would be greatly appreciated. This and the OutputIterator are the two most complex pieces of this PR.
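
One way to picture how a file output can be passed to another model without being downloaded is an input-encoding step that swaps URL-backed paths for their URLs. This is a sketch under stated assumptions: `RemoteFile` and `encode_inputs` are hypothetical names, not the library's API.

```py
import os
from typing import Any


class RemoteFile(os.PathLike):
    """Hypothetical URL-backed path; a real one would download in __fspath__()."""

    def __init__(self, url: str) -> None:
        self.url = url

    def __fspath__(self) -> str:
        # Raising here proves the sketch never touches the file contents.
        raise RuntimeError("download should not be triggered when chaining models")


def encode_inputs(value: Any) -> Any:
    """Recursively replace URL-backed paths with their URL strings."""
    if isinstance(value, RemoteFile):
        return value.url
    if isinstance(value, dict):
        return {key: encode_inputs(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [encode_inputs(item) for item in value]
    return value


image = RemoteFile("https://replicate.delivery/xyz")
payload = encode_inputs({"prompt": "describe this image for me", "image": image})
```

The payload sent to the second model then carries only the URL string, so no temporary file is ever created for the intermediate image.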

To create an individual prediction that has not yet resolved, use the `create()` method:

```py
claude = replicate.use("anthropic/claude-4-sonnet")

prediction = claude.create(prompt="Give me a recipe for tasty smashed avocado on sourdough toast that could feed all of California.")

prediction.logs() # get current logs (WIP)

prediction.output() # get the output
```

Review comment (Member): prediction = claude.create(...) ☝🏼 I like it!

### Streaming

Many models, particularly large language models (LLMs), will yield partial results as the model is running. To consume outputs from these models as they run you can pass the `streaming` argument to `use()`:

```py
claude = replicate.use("anthropic/claude-4-sonnet", streaming=True)

output = claude(prompt="Give me a recipe for tasty smashed avocado on sourdough toast that could feed all of California.")

for chunk in output:
    print(chunk) # "Here's a recipe ", "to feed all", " of California"
```
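
Since the TODO list in this README notes that streaming is currently built on polling, the behaviour can be approximated with a generator that re-reads a growing output list and yields only the new items. This is a rough sketch: `FakePrediction` and `poll_output` are hypothetical names standing in for the real prediction object and iterator.

```py
import time
from typing import Any, Iterator


class FakePrediction:
    """Hypothetical prediction whose output grows each time reload() is called."""

    def __init__(self, chunks: list[str]) -> None:
        self._pending = list(chunks)
        self.output: list[str] = []
        self.status = "processing"

    def reload(self) -> None:
        if self._pending:
            self.output.append(self._pending.pop(0))
        if not self._pending:
            self.status = "succeeded"


def poll_output(prediction: FakePrediction, interval: float = 0.0) -> Iterator[Any]:
    """Yield each output item once, polling until the prediction is terminal."""
    seen = 0
    while True:
        output = prediction.output or []
        yield from output[seen:]  # only items not yielded before
        seen = len(output)
        if prediction.status in ("succeeded", "failed", "canceled"):
            return  # terminal: everything available has been yielded
        time.sleep(interval)
        prediction.reload()


chunks = list(poll_output(FakePrediction(["Here's a recipe ", "to feed all", " of California"])))
```

A prediction that is already terminal when iteration starts yields its full output once and stops, which appears to be the case the `output_iterator` change in `replicate/prediction.py` addresses.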

### Downloading file outputs

Output files are provided as Python [os.PathLike](https://docs.python.org/3.12/library/os.html#os.PathLike) objects. These are supported by most of the Python standard library like `open()` and `Path`, as well as third-party libraries like `pillow` and `ffmpeg-python`.

The first time the file is accessed it will be downloaded to a temporary directory on disk ready for use.

Here's an example of how to use the `pillow` package to convert file outputs:

```py
import replicate
from PIL import Image

flux_dev = replicate.use("black-forest-labs/flux-dev")

images = flux_dev(prompt="a cat wearing an amusing hat")
for i, path in enumerate(images):
    with Image.open(path) as img:
        img.save(f"./output_{i}.png", format="PNG")
```

For libraries that do not support `Path` or `PathLike` instances, you can use `open()` as you would with any other file. For example, to use `requests` to upload the file to a different location:

```py
import replicate
import requests

flux_dev = replicate.use("black-forest-labs/flux-dev")

images = flux_dev(prompt="a cat wearing an amusing hat")
for path in images:
    with open(path, "rb") as f:
        r = requests.post("https://api.example.com/upload", files={"file": f})
```

### Accessing outputs as HTTPS URLs

If you do not need to download the output to disk, you can access the underlying URL for a `Path` object returned from a model call by using the `get_path_url()` helper.

```py
import replicate
from replicate import get_path_url

flux_dev = replicate.use("black-forest-labs/flux-dev")
outputs = flux_dev(prompt="a cat wearing an amusing hat")

for output in outputs:
    print(get_path_url(output)) # "https://replicate.delivery/xyz"
```

### Async Mode

By default `use()` will return a function instance with a sync interface. You can pass `use_async=True` to have it return an `AsyncFunction` that provides an async interface.

```py
import asyncio
from pathlib import Path

import replicate


async def main():
    flux_dev = replicate.use("black-forest-labs/flux-dev", use_async=True)
    outputs = await flux_dev(prompt="a cat wearing an amusing hat")

    for output in outputs:
        print(Path(output))


asyncio.run(main())
```

In streaming mode, an `AsyncIterator` is returned.

```py
import asyncio
import replicate

async def main():
    claude = replicate.use("anthropic/claude-3.5-haiku", streaming=True, use_async=True)
    output = await claude(prompt="say hello")

    # Stream the response as it comes in.
    async for token in output:
        print(token)

    # Wait until the model has completed. This returns either a `list` or a `str` depending
    # on whether the model uses AsyncIterator or ConcatenateAsyncIterator. You can check this
    # on the model schema by looking for `x-cog-array-display: concatenate`.
    print(await output)

asyncio.run(main())
```
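
An object that supports both `async for` and `await`, as the streaming example above relies on, can be built by defining `__aiter__` and `__await__` on the same class. The following is a minimal sketch with a hypothetical `AwaitableStream` class and canned chunks, not the PR's `OutputIterator`.

```py
import asyncio
from typing import AsyncIterator, Generator


class AwaitableStream:
    """Hypothetical output object usable with both `async for` and `await`."""

    def __init__(self, chunks: list[str], concatenate: bool) -> None:
        self._chunks = chunks
        self._concatenate = concatenate

    async def __aiter__(self) -> AsyncIterator[str]:
        for chunk in self._chunks:
            await asyncio.sleep(0)  # stand-in for waiting on the next poll
            yield chunk

    async def _collect(self):
        items = [chunk async for chunk in self]
        # Concatenating models resolve to a str, others to a list.
        return "".join(items) if self._concatenate else items

    def __await__(self) -> Generator:
        return self._collect().__await__()


async def main() -> tuple:
    stream = AwaitableStream(["say", " hello"], concatenate=True)
    tokens = [token async for token in stream]  # consume incrementally
    return tokens, await stream  # then resolve the final value


tokens, final = asyncio.run(main())
```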

### Typing

By default `use()` knows nothing about the interface of the model. To provide a better developer experience we provide two methods to add type annotations to the function returned by the `use()` helper.

**1. Provide a function signature**

The `use()` function accepts a function signature as an additional `hint` keyword argument. When provided, it uses this signature for the `model()` and `model.create()` functions.

```py
from pathlib import Path

from replicate import use

# Flux takes a required prompt string and optional image and seed.
def hint(*, prompt: str, image: Path | None = None, seed: int | None = None) -> str: ...

flux_dev = use("black-forest-labs/flux-dev", hint=hint)
output1 = flux_dev() # will warn that `prompt` is missing
output2 = flux_dev(prompt="str") # output2 will be typed as `str`
```

**2. Provide a class**

The second method requires creating a callable class with a `name` field. The name will be used as the function reference when passed to `use()`.

```py
from pathlib import Path

from replicate import use


class FluxDev:
    name = "black-forest-labs/flux-dev"

    def __call__(self, *, prompt: str, image: Path | None = None, seed: int | None = None) -> str: ...

flux_dev = use(FluxDev)
output1 = flux_dev() # will warn that `prompt` is missing
output2 = flux_dev(prompt="str") # output2 will be typed as `str`
```

> [!WARNING]
> Currently the typing system doesn't correctly support the `streaming` flag for models that return lists or use iterators. We're working on improvements here.

In future we hope to provide tooling to generate and provide these models as packages to make working with them easier. For now you may wish to create your own.

### TODO

There are several key things still outstanding:

1. Support for streaming text when available (rather than polling).
2. Support for streaming files when available (rather than polling).
3. Support for cleaning up downloaded files.
4. Support for streaming logs using `OutputIterator`.

## Development

See [CONTRIBUTING.md](CONTRIBUTING.md)
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -34,6 +34,7 @@ dev-dependencies = [

[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
testpaths = "tests/"

[tool.setuptools]
@@ -73,8 +74,6 @@ ignore = [
"ANN001", # Missing type annotation for function argument
"ANN002", # Missing type annotation for `*args`
"ANN003", # Missing type annotation for `**kwargs`
"ANN101", # Missing type annotation for self in method
"ANN102", # Missing type annotation for cls in classmethod
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed in {name}
"W191", # Indentation contains tabs
"UP037", # Remove quotes from type annotation
22 changes: 22 additions & 0 deletions replicate/__init__.py
@@ -1,6 +1,28 @@
from replicate.client import Client
from replicate.pagination import async_paginate as _async_paginate
from replicate.pagination import paginate as _paginate
from replicate.use import get_path_url, use

__all__ = [
    "Client",
    "use",
    "run",
    "async_run",
    "stream",
    "async_stream",
    "paginate",
    "async_paginate",
    "collections",
    "deployments",
    "files",
    "hardware",
    "models",
    "predictions",
    "trainings",
    "webhooks",
    "default_client",
    "get_path_url",
]

default_client = Client()

5 changes: 5 additions & 0 deletions replicate/client.py
@@ -352,6 +352,11 @@ def _get_api_token_from_environment() -> Optional[str]:
    """Get API token from cog current scope if available, otherwise from environment."""
    try:
        import cog  # noqa: I001 # pyright: ignore [reportMissingImports]
        import warnings

        warnings.filterwarnings(
            "ignore", message="current_scope", category=cog.ExperimentalFeatureWarning
        )

        for key, value in cog.current_scope().context.items():
            if key.upper() == "REPLICATE_API_TOKEN":
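
The filter added in this hunk relies on `warnings.filterwarnings()` treating `message` as a regular expression that must match the start of the warning text. The sketch below shows that behaviour in isolation; `ExperimentalFeatureWarning`, `current_scope()` and the warning wording are local stand-ins (assumptions), not the real `cog` objects.

```py
import warnings


class ExperimentalFeatureWarning(Warning):
    """Hypothetical stand-in for `cog.ExperimentalFeatureWarning`."""


def current_scope() -> str:
    # Assumed wording; the real message emitted by cog may differ.
    warnings.warn("current_scope is experimental", category=ExperimentalFeatureWarning)
    return "scope"


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Inserted ahead of the "always" filter, so it takes precedence.
    warnings.filterwarnings(
        "ignore", message="current_scope", category=ExperimentalFeatureWarning
    )
    current_scope()  # suppressed: the message starts with "current_scope"
    warnings.warn("unrelated experimental feature", category=ExperimentalFeatureWarning)
```

Only the second warning is recorded, which illustrates that the filter is narrow: other warnings of the same category still surface.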
11 changes: 11 additions & 0 deletions replicate/prediction.py
@@ -248,6 +248,11 @@ def output_iterator(self) -> Iterator[Any]:
        """
        Return an iterator of the prediction output.
        """
        if (
            self.status in ["succeeded", "failed", "canceled"]
            and self.output is not None
        ):
            yield from self.output

        # TODO: check output is list
        previous_output = self.output or []
@@ -270,6 +275,12 @@ async def async_output_iterator(self) -> AsyncIterator[Any]:
        """
        Return an asynchronous iterator of the prediction output.
        """
        if (
            self.status in ["succeeded", "failed", "canceled"]
            and self.output is not None
        ):
            for item in self.output:
                yield item

        # TODO: check output is list
        previous_output = self.output or []
4 changes: 2 additions & 2 deletions replicate/schema.py
@@ -15,12 +15,12 @@ def version_has_no_array_type(cog_version: str) -> Optional[bool]:

 def make_schema_backwards_compatible(
     schema: dict,
-    cog_version: str,
+    cog_version: str | None,
 ) -> dict:
     """A place to add backwards compatibility logic for our openapi schema"""

     # If the top-level output is an array, assume it is an iterator in old versions which didn't have an array type
-    if version_has_no_array_type(cog_version):
+    if cog_version and version_has_no_array_type(cog_version):
         output = schema["components"]["schemas"]["Output"]
         if output.get("type") == "array":
             output["x-cog-array-type"] = "iterator"