Commit a00b5d2

Improve typing of OutputIterator
1 parent 050798e commit a00b5d2
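
The substantive change here: OutputIterator becomes generic over its element type, written with PEP 695 type-parameter syntax (class OutputIterator[T]:), which requires Python 3.12+. For readers unfamiliar with that syntax, a minimal stand-alone sketch; Box is an illustrative class, not part of replicate:

from typing import reveal_type

# PEP 695 syntax: the type parameter T is declared inline on the class.
class Box[T]:
    def __init__(self, item: T) -> None:
        self.item = item

    def get(self) -> T:
        return self.item

b = Box("hello")      # T is inferred as str
reveal_type(b.get())  # type checkers report str; at runtime this prints the type

The same mechanism lets OutputIterator[T] carry the chunk type through __iter__, __aiter__, and __await__ in the diff below.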

2 files changed: +255 −28 lines changed


replicate/use.py

Lines changed: 34 additions & 27 deletions
@@ -13,15 +13,16 @@
     Any,
     AsyncIterator,
     Callable,
+    Generator,
     Generic,
     Iterator,
+    List,
     Literal,
     Optional,
     ParamSpec,
     Protocol,
     Tuple,
     TypeVar,
-    Union,
     cast,
     overload,
 )
@@ -210,38 +211,38 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any:
     return output


-class OutputIterator:
+class OutputIterator[T]:
     """
     An iterator wrapper that handles both regular iteration and string conversion.
     Supports both sync and async iteration patterns.
     """

     def __init__(
-        self,
-        iterator_factory: Callable[[], Iterator[Any]],
-        async_iterator_factory: Callable[[], AsyncIterator[Any]],
+        self,
+        iterator_factory: Callable[[], Iterator[T]],
+        async_iterator_factory: Callable[[], AsyncIterator[T]],
         schema: dict,
-        *,
-        is_concatenate: bool
+        *,
+        is_concatenate: bool,
     ) -> None:
         self.iterator_factory = iterator_factory
         self.async_iterator_factory = async_iterator_factory
         self.schema = schema
         self.is_concatenate = is_concatenate

-    def __iter__(self) -> Iterator[Any]:
+    def __iter__(self) -> Iterator[T]:
         """Iterate over output items synchronously."""
         for chunk in self.iterator_factory():
             if self.is_concatenate:
-                yield str(chunk)
+                yield chunk
             else:
                 yield _process_iterator_item(chunk, self.schema)

-    async def __aiter__(self) -> AsyncIterator[Any]:
+    async def __aiter__(self) -> AsyncIterator[T]:
         """Iterate over output items asynchronously."""
         async for chunk in self.async_iterator_factory():
             if self.is_concatenate:
-                yield str(chunk)
+                yield chunk
             else:
                 yield _process_iterator_item(chunk, self.schema)

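Taken together, the hunk above means a concatenate iterator built from string-chunk factories is an OutputIterator[str], and both iteration paths yield T rather than Any (note the is_concatenate branch now yields the chunk as-is instead of coercing with str()). A hedged sketch of the effect, using stand-in factories and an empty schema, which the concatenate path never consults:

from replicate.use import OutputIterator

def chunks():
    yield from ["Hello", " ", "world"]

async def chunks_async():
    for c in ["Hello", " ", "world"]:
        yield c

it = OutputIterator(
    iterator_factory=chunks,
    async_iterator_factory=chunks_async,
    schema={},  # stand-in; unused when is_concatenate=True
    is_concatenate=True,
)  # inferred as OutputIterator[str]

for token in it:  # token is typed as str, not Any
    print(token, end="")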
@@ -252,9 +253,10 @@ def __str__(self) -> str:
         else:
             return str(list(self.iterator_factory()))

-    def __await__(self):
+    def __await__(self) -> Generator[Any, None, List[T] | str]:
         """Make OutputIterator awaitable, returning appropriate result based on concatenate mode."""
-        async def _collect_result():
+
+        async def _collect_result() -> List[T] | str:
             if self.is_concatenate:
                 # For concatenate iterators, return the joined string
                 segments = []
@@ -267,6 +269,7 @@ async def _collect_result():
             async for item in self:
                 items.append(item)
             return items
+
         return _collect_result().__await__()

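With the annotation on __await__, awaiting the iterator now has a checkable result type, List[T] | str: a joined string in concatenate mode, a list of processed items otherwise. A small sketch reusing the illustrative it from the previous note (the factories return fresh iterators on each call, so re-consumption is fine):

import asyncio

async def main() -> None:
    text = await it  # concatenate mode: chunks are collected and joined
    assert text == "Hello world"

asyncio.run(main())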
@@ -341,14 +344,10 @@ class Run[O]:

     def output(self) -> O:
         """
-        Wait for the prediction to complete and return its output.
+        Return the output. For iterator types, returns immediately without waiting.
+        For non-iterator types, waits for completion.
         """
-        self.prediction.wait()
-
-        if self.prediction.status == "failed":
-            raise ModelError(self.prediction)
-
-        # Return an OutputIterator for iterator output types (including concatenate iterators)
+        # Return an OutputIterator immediately for iterator output types
         if _has_iterator_output_type(self.schema):
             is_concatenate = _has_concatenate_iterator_output_type(self.schema)
             return cast(
@@ -361,6 +360,12 @@ def output(self) -> O:
             ),
         )

+        # For non-iterator types, wait for completion and process output
+        self.prediction.wait()
+
+        if self.prediction.status == "failed":
+            raise ModelError(self.prediction)
+
         # Process output for file downloads based on schema
         return _process_output_with_schema(self.prediction.output, self.schema)

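Behaviorally, Run.output() no longer blocks when the schema declares an iterator output: the OutputIterator comes back at once, polling happens as it is consumed, and the failed-status check now runs only on the non-iterator path, so failures are no longer raised before the iterator is handed back. A sketch of the sync call pattern, borrowing the acme/hotdog-detector fixture model from the tests below:

import replicate

hotdog_detector = replicate.use("acme/hotdog-detector")
run = hotdog_detector.create(prompt="hello world")

tokens = run.output()  # returns an OutputIterator immediately, no wait
for token in tokens:   # iteration drives polling; blocking happens here
    print(token, end="")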
@@ -483,14 +488,10 @@ class AsyncRun[O]:

     async def output(self) -> O:
         """
-        Wait for the prediction to complete and return its output asynchronously.
+        Return the output. For iterator types, returns immediately without waiting.
+        For non-iterator types, waits for completion.
         """
-        await self.prediction.async_wait()
-
-        if self.prediction.status == "failed":
-            raise ModelError(self.prediction)
-
-        # Return an OutputIterator for iterator output types (including concatenate iterators)
+        # Return an OutputIterator immediately for iterator output types
         if _has_iterator_output_type(self.schema):
             is_concatenate = _has_concatenate_iterator_output_type(self.schema)
             return cast(
@@ -503,6 +504,12 @@ async def output(self) -> O:
             ),
         )

+        # For non-iterator types, wait for completion and process output
+        await self.prediction.async_wait()
+
+        if self.prediction.status == "failed":
+            raise ModelError(self.prediction)
+
         # Process output for file downloads based on schema
         return _process_output_with_schema(self.prediction.output, self.schema)

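AsyncRun.output() mirrors this: the await resolves to the OutputIterator without waiting on the prediction, and async for consumes it. Same assumptions as the sync sketch above:

import asyncio
import replicate

async def main() -> None:
    hotdog_detector = replicate.use("acme/hotdog-detector", use_async=True)
    run = await hotdog_detector.create(prompt="hello world")
    tokens = await run.output()  # OutputIterator, returned immediately
    async for token in tokens:   # async iteration polls as needed
        print(token, end="")

asyncio.run(main())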
tests/test_use.py

Lines changed: 221 additions & 1 deletion
@@ -345,7 +345,7 @@ async def test_use_function_create_method(client_mode):
     run = hotdog_detector.create(prompt="hello world")

     # Assert that run is a Run object with a prediction
-    from replicate.use import Run, AsyncRun
+    from replicate.use import AsyncRun, Run

     if client_mode == ClientMode.ASYNC:
         assert isinstance(run, AsyncRun)
@@ -621,6 +621,226 @@ async def async_iterator():
     assert str(result) == "['Hello', ' ', 'World']"  # str() gives list representation


+@pytest.mark.asyncio
+@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
+@respx.mock
+async def test_iterator_output_returns_immediately(client_mode):
+    """Test that OutputIterator is returned immediately without waiting for completion."""
+    mock_model_endpoints(
+        versions=[
+            create_mock_version(
+                {
+                    "openapi_schema": {
+                        "components": {
+                            "schemas": {
+                                "Output": {
+                                    "type": "array",
+                                    "items": {"type": "string"},
+                                    "x-cog-array-type": "iterator",
+                                    "x-cog-array-display": "concatenate",
+                                }
+                            }
+                        }
+                    }
+                }
+            )
+        ]
+    )
+
+    # Mock prediction that starts as processing (not completed)
+    mock_prediction_endpoints(
+        predictions=[
+            create_mock_prediction({"status": "processing", "output": []}),
+            create_mock_prediction({"status": "processing", "output": ["Hello"]}),
+            create_mock_prediction(
+                {"status": "succeeded", "output": ["Hello", " ", "World"]}
+            ),
+        ]
+    )
+
+    # Call use with "acme/hotdog-detector"
+    hotdog_detector = replicate.use(
+        "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC
+    )
+
+    # Get the output iterator - this should return immediately even though prediction is processing
+    if client_mode == ClientMode.ASYNC:
+        run = await hotdog_detector.create(prompt="hello world")
+        output_iterator = await run.output()
+    else:
+        run = hotdog_detector.create(prompt="hello world")
+        output_iterator = run.output()
+
+    # Assert that we get an OutputIterator immediately (without waiting for completion)
+    from replicate.use import OutputIterator
+
+    assert isinstance(output_iterator, OutputIterator)
+
+    # Verify the prediction is still processing when we get the iterator
+    assert run.prediction.status == "processing"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
+@respx.mock
+async def test_streaming_output_yields_incrementally(client_mode):
+    """Test that OutputIterator yields results incrementally during polling."""
+    mock_model_endpoints(
+        versions=[
+            create_mock_version(
+                {
+                    "openapi_schema": {
+                        "components": {
+                            "schemas": {
+                                "Output": {
+                                    "type": "array",
+                                    "items": {"type": "string"},
+                                    "x-cog-array-type": "iterator",
+                                    "x-cog-array-display": "concatenate",
+                                }
+                            }
+                        }
+                    }
+                }
+            )
+        ]
+    )
+
+    # Create a prediction that will be polled multiple times
+    prediction_id = "pred123"
+
+    # Mock the initial prediction creation
+    initial_prediction = create_mock_prediction(
+        {"id": prediction_id, "status": "processing", "output": []},
+        prediction_id=prediction_id,
+    )
+
+    if client_mode == ClientMode.ASYNC:
+        respx.post("https://api.replicate.com/v1/predictions").mock(
+            return_value=httpx.Response(201, json=initial_prediction)
+        )
+    else:
+        respx.post("https://api.replicate.com/v1/predictions").mock(
+            return_value=httpx.Response(201, json=initial_prediction)
+        )
+
+    # Mock incremental polling responses - each poll returns more data
+    poll_responses = [
+        create_mock_prediction(
+            {"status": "processing", "output": ["Hello"]}, prediction_id=prediction_id
+        ),
+        create_mock_prediction(
+            {"status": "processing", "output": ["Hello", " "]},
+            prediction_id=prediction_id,
+        ),
+        create_mock_prediction(
+            {"status": "processing", "output": ["Hello", " ", "streaming"]},
+            prediction_id=prediction_id,
+        ),
+        create_mock_prediction(
+            {"status": "processing", "output": ["Hello", " ", "streaming", " "]},
+            prediction_id=prediction_id,
+        ),
+        create_mock_prediction(
+            {
+                "status": "succeeded",
+                "output": ["Hello", " ", "streaming", " ", "world!"],
+            },
+            prediction_id=prediction_id,
+        ),
+    ]
+
+    # Mock the polling endpoint to return different responses in sequence
+    respx.get(f"https://api.replicate.com/v1/predictions/{prediction_id}").mock(
+        side_effect=[httpx.Response(200, json=resp) for resp in poll_responses]
+    )
+
+    # Call use with "acme/hotdog-detector"
+    hotdog_detector = replicate.use(
+        "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC
+    )
+
+    # Get the output iterator immediately
+    if client_mode == ClientMode.ASYNC:
+        run = await hotdog_detector.create(prompt="hello world", use_async=True)
+        output_iterator = await run.output()
+    else:
+        run = hotdog_detector.create(prompt="hello world")
+        output_iterator = run.output()
+
+    # Assert that we get an OutputIterator immediately
+    from replicate.use import OutputIterator
+
+    assert isinstance(output_iterator, OutputIterator)
+
+    # Track when we receive each item to verify incremental delivery
+    collected_items = []
+
+    if client_mode == ClientMode.ASYNC:
+        async for item in output_iterator:
+            collected_items.append(item)
+            # Break after we get some incremental results to verify polling works
+            if len(collected_items) >= 3:
+                break
+    else:
+        for item in output_iterator:
+            collected_items.append(item)
+            # Break after we get some incremental results to verify polling works
+            if len(collected_items) >= 3:
+                break
+
+    # Verify we got incremental streaming results
+    assert len(collected_items) >= 3
+    # The items should be the concatenated string parts from the incremental output
+    result = "".join(collected_items)
+    assert "Hello" in result  # Should contain the first part we streamed
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
+@respx.mock
+async def test_non_streaming_output_waits_for_completion(client_mode):
+    """Test that non-iterator outputs still wait for completion."""
+    mock_model_endpoints(
+        versions=[
+            create_mock_version(
+                {
+                    "openapi_schema": {
+                        "components": {
+                            "schemas": {
+                                "Output": {"type": "string"}  # Non-iterator output
+                            }
+                        }
+                    }
+                }
+            )
+        ]
+    )
+
+    mock_prediction_endpoints(
+        predictions=[
+            create_mock_prediction({"status": "processing", "output": None}),
+            create_mock_prediction({"status": "succeeded", "output": "Final result"}),
+        ]
+    )
+
+    # Call use with "acme/hotdog-detector"
+    hotdog_detector = replicate.use(
+        "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC
+    )
+
+    # For non-iterator output, this should wait for completion
+    if client_mode == ClientMode.ASYNC:
+        run = await hotdog_detector.create(prompt="hello world")
+        output = await run.output()
+    else:
+        run = hotdog_detector.create(prompt="hello world")
+        output = run.output()
+
+    # Should get the final result directly
+    assert output == "Final result"
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
 @respx.mock
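
To exercise just the three new tests locally, something like the following should work, assuming the repo's test dependencies (pytest, pytest-asyncio, respx, httpx) are installed; the -k expression matches the new test names:

import pytest

# Equivalent to: pytest tests/test_use.py -k "<expression>"
pytest.main([
    "tests/test_use.py",
    "-k",
    "returns_immediately or yields_incrementally or waits_for_completion",
])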
