From ec8bf5417bb654cbc3c3bb0a998bf4742dfef81b Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Fri, 24 Oct 2025 14:35:25 -0700 Subject: [PATCH 1/3] feat: filter None values from inputs before API requests This PR filters out None-valued inputs from all prediction and training creation methods before making API requests. This preserves legacy behavior and prevents potential API errors when users pass None values. Changes: - Add filter_none_values() utility function to recursively remove None values from dictionaries - Update encode_json() and async_encode_json() to filter None values when processing dicts - Apply filtering to deployments.predictions.create() and trainings.create() methods - Add comprehensive test suite for None filtering functionality Fixes https://linear.app/replicate/issue/DP-737 --- src/replicate/lib/_files.py | 26 ++++ .../resources/deployments/predictions.py | 5 +- src/replicate/resources/trainings.py | 5 +- tests/test_filter_none_values.py | 136 ++++++++++++++++++ 4 files changed, 168 insertions(+), 4 deletions(-) create mode 100644 tests/test_filter_none_values.py diff --git a/src/replicate/lib/_files.py b/src/replicate/lib/_files.py index ad49a4c..cd353a9 100644 --- a/src/replicate/lib/_files.py +++ b/src/replicate/lib/_files.py @@ -19,6 +19,26 @@ FileEncodingStrategy = Literal["base64", "url"] +def filter_none_values(obj: Any) -> Any: # noqa: ANN401 + """ + Recursively filter out None values from dictionaries. + + This preserves the legacy behavior where None-valued inputs are removed + before making API requests. + + Args: + obj: The object to filter. + + Returns: + The object with None values removed from all nested dictionaries. + """ + if isinstance(obj, dict): + return {key: filter_none_values(value) for key, value in obj.items() if value is not None} + if isinstance(obj, (list, tuple)): + return type(obj)(filter_none_values(item) for item in obj) + return obj + + try: import numpy as np # type: ignore @@ -35,12 +55,15 @@ def encode_json( ) -> Any: # noqa: ANN401 """ Return a JSON-compatible version of the object. + + None values are filtered out from dictionaries to prevent API errors. """ if isinstance(obj, dict): return { key: encode_json(value, client, file_encoding_strategy) for key, value in obj.items() # type: ignore + if value is not None } # type: ignore if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)): return [encode_json(value, client, file_encoding_strategy) for value in obj] # type: ignore @@ -70,12 +93,15 @@ async def async_encode_json( ) -> Any: # noqa: ANN401 """ Asynchronously return a JSON-compatible version of the object. + + None values are filtered out from dictionaries to prevent API errors. """ if isinstance(obj, dict): return { key: (await async_encode_json(value, client, file_encoding_strategy)) for key, value in obj.items() # type: ignore + if value is not None } # type: ignore if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)): return [ diff --git a/src/replicate/resources/deployments/predictions.py b/src/replicate/resources/deployments/predictions.py index aa22e7d..0f4bac7 100644 --- a/src/replicate/resources/deployments/predictions.py +++ b/src/replicate/resources/deployments/predictions.py @@ -17,6 +17,7 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) +from ...lib._files import filter_none_values from ..._base_client import make_request_options from ...types.prediction import Prediction from ...types.deployments import prediction_create_params @@ -176,7 +177,7 @@ def create( f"/deployments/{deployment_owner}/{deployment_name}/predictions", body=maybe_transform( { - "input": input, + "input": filter_none_values(input), "stream": stream, "webhook": webhook, "webhook_events_filter": webhook_events_filter, @@ -342,7 +343,7 @@ async def create( f"/deployments/{deployment_owner}/{deployment_name}/predictions", body=await async_maybe_transform( { - "input": input, + "input": filter_none_values(input), "stream": stream, "webhook": webhook, "webhook_events_filter": webhook_events_filter, diff --git a/src/replicate/resources/trainings.py b/src/replicate/resources/trainings.py index 51ab733..7f9abbd 100644 --- a/src/replicate/resources/trainings.py +++ b/src/replicate/resources/trainings.py @@ -18,6 +18,7 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) +from ..lib._files import filter_none_values from ..pagination import SyncCursorURLPage, AsyncCursorURLPage from .._base_client import AsyncPaginator, make_request_options from ..types.training_get_response import TrainingGetResponse @@ -187,7 +188,7 @@ def create( body=maybe_transform( { "destination": destination, - "input": input, + "input": filter_none_values(input), "webhook": webhook, "webhook_events_filter": webhook_events_filter, }, @@ -573,7 +574,7 @@ async def create( body=await async_maybe_transform( { "destination": destination, - "input": input, + "input": filter_none_values(input), "webhook": webhook, "webhook_events_filter": webhook_events_filter, }, diff --git a/tests/test_filter_none_values.py b/tests/test_filter_none_values.py new file mode 100644 index 0000000..313ba5a --- /dev/null +++ b/tests/test_filter_none_values.py @@ -0,0 +1,136 @@ +from replicate import Replicate +from replicate.lib._files import encode_json, async_encode_json, filter_none_values + + +def test_filter_none_values_simple_dict(): + """Test that None values are filtered from a simple dictionary.""" + input_dict = {"prompt": "banana", "seed": None, "width": 512} + result = filter_none_values(input_dict) + assert result == {"prompt": "banana", "width": 512} + assert "seed" not in result + + +def test_filter_none_values_nested_dict(): + """Test that None values are filtered from nested dictionaries.""" + input_dict = { + "prompt": "banana", + "config": {"seed": None, "temperature": 0.8, "iterations": None}, + "width": 512, + } + result = filter_none_values(input_dict) + assert result == { + "prompt": "banana", + "config": {"temperature": 0.8}, + "width": 512, + } + assert "seed" not in result["config"] + assert "iterations" not in result["config"] + + +def test_filter_none_values_all_none(): + """Test that a dict with all None values returns an empty dict.""" + input_dict = {"seed": None, "temperature": None} + result = filter_none_values(input_dict) + assert result == {} + + +def test_filter_none_values_empty_dict(): + """Test that an empty dict returns an empty dict.""" + input_dict = {} + result = filter_none_values(input_dict) + assert result == {} + + +def test_filter_none_values_with_list(): + """Test that lists are preserved and None values in dicts within lists are filtered.""" + input_dict = { + "prompts": ["banana", "apple"], + "seeds": [None, 42, None], + "config": {"value": None}, + } + result = filter_none_values(input_dict) + # None values in lists are preserved + assert result == { + "prompts": ["banana", "apple"], + "seeds": [None, 42, None], + "config": {}, + } + + +def test_filter_none_values_with_tuple(): + """Test that tuples are preserved.""" + input_dict = {"coords": (1, None, 3)} + result = filter_none_values(input_dict) + # Tuples are preserved as-is + assert result == {"coords": (1, None, 3)} + + +def test_filter_none_values_non_dict(): + """Test that non-dict values are returned as-is.""" + assert filter_none_values("string") == "string" + assert filter_none_values(42) == 42 + assert filter_none_values(None) is None + assert filter_none_values([1, 2, 3]) == [1, 2, 3] + + +def test_encode_json_filters_none(client: Replicate): + """Test that encode_json filters None values from dicts.""" + input_dict = {"prompt": "banana", "seed": None, "width": 512} + result = encode_json(input_dict, client) + assert result == {"prompt": "banana", "width": 512} + assert "seed" not in result + + +def test_encode_json_nested_none_filtering(client: Replicate): + """Test that encode_json recursively filters None values.""" + input_dict = { + "prompt": "banana", + "config": {"seed": None, "temperature": 0.8}, + "metadata": {"user": "test", "session": None}, + } + result = encode_json(input_dict, client) + assert result == { + "prompt": "banana", + "config": {"temperature": 0.8}, + "metadata": {"user": "test"}, + } + + +async def test_async_encode_json_filters_none(async_client): + """Test that async_encode_json filters None values from dicts.""" + input_dict = {"prompt": "banana", "seed": None, "width": 512} + result = await async_encode_json(input_dict, async_client) + assert result == {"prompt": "banana", "width": 512} + assert "seed" not in result + + +async def test_async_encode_json_nested_none_filtering(async_client): + """Test that async_encode_json recursively filters None values.""" + input_dict = { + "prompt": "banana", + "config": {"seed": None, "temperature": 0.8}, + "metadata": {"user": "test", "session": None}, + } + result = await async_encode_json(input_dict, async_client) + assert result == { + "prompt": "banana", + "config": {"temperature": 0.8}, + "metadata": {"user": "test"}, + } + + +def test_encode_json_preserves_false_and_zero(client: Replicate): + """Test that False and 0 are not filtered out.""" + input_dict = { + "prompt": "banana", + "seed": 0, + "enable_feature": False, + "iterations": None, + } + result = encode_json(input_dict, client) + assert result == { + "prompt": "banana", + "seed": 0, + "enable_feature": False, + } + assert "iterations" not in result From 3e4fdc9053f1f225c1ba2c790da6f2b7a9ed1104 Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Fri, 24 Oct 2025 14:37:35 -0700 Subject: [PATCH 2/3] fix: add type ignores for pyright compatibility --- src/replicate/lib/_files.py | 8 ++++++-- tests/test_filter_none_values.py | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/replicate/lib/_files.py b/src/replicate/lib/_files.py index cd353a9..04a17c3 100644 --- a/src/replicate/lib/_files.py +++ b/src/replicate/lib/_files.py @@ -33,9 +33,13 @@ def filter_none_values(obj: Any) -> Any: # noqa: ANN401 The object with None values removed from all nested dictionaries. """ if isinstance(obj, dict): - return {key: filter_none_values(value) for key, value in obj.items() if value is not None} + return { + key: filter_none_values(value) + for key, value in obj.items() # type: ignore[misc] + if value is not None + } if isinstance(obj, (list, tuple)): - return type(obj)(filter_none_values(item) for item in obj) + return type(obj)(filter_none_values(item) for item in obj) # type: ignore[arg-type, misc] return obj diff --git a/tests/test_filter_none_values.py b/tests/test_filter_none_values.py index 313ba5a..492f9d0 100644 --- a/tests/test_filter_none_values.py +++ b/tests/test_filter_none_values.py @@ -96,7 +96,7 @@ def test_encode_json_nested_none_filtering(client: Replicate): } -async def test_async_encode_json_filters_none(async_client): +async def test_async_encode_json_filters_none(async_client): # type: ignore[no-untyped-def] """Test that async_encode_json filters None values from dicts.""" input_dict = {"prompt": "banana", "seed": None, "width": 512} result = await async_encode_json(input_dict, async_client) @@ -104,7 +104,7 @@ async def test_async_encode_json_filters_none(async_client): assert "seed" not in result -async def test_async_encode_json_nested_none_filtering(async_client): +async def test_async_encode_json_nested_none_filtering(async_client): # type: ignore[no-untyped-def] """Test that async_encode_json recursively filters None values.""" input_dict = { "prompt": "banana", From 1267126e915b120d71d6809564ebb93fe110ba74 Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Fri, 24 Oct 2025 14:39:36 -0700 Subject: [PATCH 3/3] fix: add arg-type ignores for async_client --- tests/test_filter_none_values.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_filter_none_values.py b/tests/test_filter_none_values.py index 492f9d0..afabfd0 100644 --- a/tests/test_filter_none_values.py +++ b/tests/test_filter_none_values.py @@ -99,7 +99,7 @@ def test_encode_json_nested_none_filtering(client: Replicate): async def test_async_encode_json_filters_none(async_client): # type: ignore[no-untyped-def] """Test that async_encode_json filters None values from dicts.""" input_dict = {"prompt": "banana", "seed": None, "width": 512} - result = await async_encode_json(input_dict, async_client) + result = await async_encode_json(input_dict, async_client) # type: ignore[arg-type] assert result == {"prompt": "banana", "width": 512} assert "seed" not in result @@ -111,7 +111,7 @@ async def test_async_encode_json_nested_none_filtering(async_client): # type: i "config": {"seed": None, "temperature": 0.8}, "metadata": {"user": "test", "session": None}, } - result = await async_encode_json(input_dict, async_client) + result = await async_encode_json(input_dict, async_client) # type: ignore[arg-type] assert result == { "prompt": "banana", "config": {"temperature": 0.8},