@@ -345,7 +345,7 @@ async def test_use_function_create_method(client_mode):
345
345
run = hotdog_detector .create (prompt = "hello world" )
346
346
347
347
# Assert that run is a Run object with a prediction
348
- from replicate .use import Run , AsyncRun
348
+ from replicate .use import AsyncRun , Run
349
349
350
350
if client_mode == ClientMode .ASYNC :
351
351
assert isinstance (run , AsyncRun )
@@ -621,6 +621,226 @@ async def async_iterator():
621
621
assert str (result ) == "['Hello', ' ', 'World']" # str() gives list representation
622
622
623
623
624
@pytest.mark.asyncio
@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
@respx.mock
async def test_iterator_output_returns_immediately(client_mode):
    """Test that OutputIterator is returned immediately without waiting for completion."""
    from replicate.use import OutputIterator

    # Model whose output schema is a concatenating string iterator.
    iterator_output_schema = {
        "openapi_schema": {
            "components": {
                "schemas": {
                    "Output": {
                        "type": "array",
                        "items": {"type": "string"},
                        "x-cog-array-type": "iterator",
                        "x-cog-array-display": "concatenate",
                    }
                }
            }
        }
    }
    mock_model_endpoints(versions=[create_mock_version(iterator_output_schema)])

    # Prediction starts out processing (not completed) and accumulates
    # output across successive polls.
    mock_prediction_endpoints(
        predictions=[
            create_mock_prediction({"status": "processing", "output": []}),
            create_mock_prediction({"status": "processing", "output": ["Hello"]}),
            create_mock_prediction(
                {"status": "succeeded", "output": ["Hello", " ", "World"]}
            ),
        ]
    )

    is_async = client_mode == ClientMode.ASYNC
    hotdog_detector = replicate.use("acme/hotdog-detector", use_async=is_async)

    # Asking for the output should hand back an iterator at once, even
    # though the prediction is still in flight.
    if is_async:
        run = await hotdog_detector.create(prompt="hello world")
        output_iterator = await run.output()
    else:
        run = hotdog_detector.create(prompt="hello world")
        output_iterator = run.output()

    assert isinstance(output_iterator, OutputIterator)

    # The prediction must still be processing when the iterator arrives.
    assert run.prediction.status == "processing"
683
@pytest.mark.asyncio
@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
@respx.mock
async def test_streaming_output_yields_incrementally(client_mode):
    """Test that OutputIterator yields results incrementally during polling.

    The prediction GET endpoint is mocked to return a progressively longer
    output list on each poll, so the iterator must surface items before the
    prediction reaches a terminal status.
    """
    mock_model_endpoints(
        versions=[
            create_mock_version(
                {
                    "openapi_schema": {
                        "components": {
                            "schemas": {
                                "Output": {
                                    "type": "array",
                                    "items": {"type": "string"},
                                    "x-cog-array-type": "iterator",
                                    "x-cog-array-display": "concatenate",
                                }
                            }
                        }
                    }
                }
            )
        ]
    )

    # Create a prediction that will be polled multiple times.
    prediction_id = "pred123"

    # Mock the initial prediction creation. One POST mock serves both client
    # modes — the request is identical either way, so no branching is needed.
    initial_prediction = create_mock_prediction(
        {"id": prediction_id, "status": "processing", "output": []},
        prediction_id=prediction_id,
    )
    respx.post("https://api.replicate.com/v1/predictions").mock(
        return_value=httpx.Response(201, json=initial_prediction)
    )

    # Mock incremental polling responses - each poll returns more data.
    poll_responses = [
        create_mock_prediction(
            {"status": "processing", "output": ["Hello"]}, prediction_id=prediction_id
        ),
        create_mock_prediction(
            {"status": "processing", "output": ["Hello", " "]},
            prediction_id=prediction_id,
        ),
        create_mock_prediction(
            {"status": "processing", "output": ["Hello", " ", "streaming"]},
            prediction_id=prediction_id,
        ),
        create_mock_prediction(
            {"status": "processing", "output": ["Hello", " ", "streaming", " "]},
            prediction_id=prediction_id,
        ),
        create_mock_prediction(
            {
                "status": "succeeded",
                "output": ["Hello", " ", "streaming", " ", "world!"],
            },
            prediction_id=prediction_id,
        ),
    ]

    # Mock the polling endpoint to return different responses in sequence.
    respx.get(f"https://api.replicate.com/v1/predictions/{prediction_id}").mock(
        side_effect=[httpx.Response(200, json=resp) for resp in poll_responses]
    )

    # Call use with "acme/hotdog-detector"; async-ness is configured here,
    # not on create() — extra create() kwargs become model input.
    hotdog_detector = replicate.use(
        "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC
    )

    # Get the output iterator immediately.
    if client_mode == ClientMode.ASYNC:
        run = await hotdog_detector.create(prompt="hello world")
        output_iterator = await run.output()
    else:
        run = hotdog_detector.create(prompt="hello world")
        output_iterator = run.output()

    # Assert that we get an OutputIterator immediately.
    from replicate.use import OutputIterator

    assert isinstance(output_iterator, OutputIterator)

    # Track when we receive each item to verify incremental delivery.
    collected_items = []

    if client_mode == ClientMode.ASYNC:
        async for item in output_iterator:
            collected_items.append(item)
            # Break after we get some incremental results to verify polling works.
            if len(collected_items) >= 3:
                break
    else:
        for item in output_iterator:
            collected_items.append(item)
            # Break after we get some incremental results to verify polling works.
            if len(collected_items) >= 3:
                break

    # Verify we got incremental streaming results; the items are the
    # concatenated string parts from the incremental output.
    assert len(collected_items) >= 3
    result = "".join(collected_items)
    assert "Hello" in result  # Should contain the first part we streamed
799
@pytest.mark.asyncio
@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
@respx.mock
async def test_non_streaming_output_waits_for_completion(client_mode):
    """Test that non-iterator outputs still wait for completion."""
    # Plain string output — no iterator markers in the schema.
    string_output_schema = {
        "openapi_schema": {
            "components": {"schemas": {"Output": {"type": "string"}}}
        }
    }
    mock_model_endpoints(versions=[create_mock_version(string_output_schema)])

    # First poll is still processing; the second carries the final value.
    mock_prediction_endpoints(
        predictions=[
            create_mock_prediction({"status": "processing", "output": None}),
            create_mock_prediction({"status": "succeeded", "output": "Final result"}),
        ]
    )

    is_async = client_mode == ClientMode.ASYNC
    hotdog_detector = replicate.use("acme/hotdog-detector", use_async=is_async)

    # For a non-iterator output, output() should block until the prediction
    # succeeds and then return the final value directly.
    if is_async:
        run = await hotdog_detector.create(prompt="hello world")
        output = await run.output()
    else:
        run = hotdog_detector.create(prompt="hello world")
        output = run.output()

    assert output == "Final result"
624
844
@pytest .mark .asyncio
625
845
@pytest .mark .parametrize ("client_mode" , [ClientMode .DEFAULT , ClientMode .ASYNC ])
626
846
@respx .mock
0 commit comments