From fa3b9fa8454ac1eb0df7ba788309a1dd123b96ad Mon Sep 17 00:00:00 2001 From: gabriel Date: Thu, 5 Mar 2026 10:46:39 +0100 Subject: [PATCH 1/4] mvp infer schema --- dataframely/__init__.py | 2 + dataframely/_generate_schema.py | 247 +++++++++++++++++++++++++++++ tests/test_infer_schema.py | 272 ++++++++++++++++++++++++++++++++ 3 files changed, 521 insertions(+) create mode 100644 dataframely/_generate_schema.py create mode 100644 tests/test_infer_schema.py diff --git a/dataframely/__init__.py b/dataframely/__init__.py index 399f9711..ca2b5866 100644 --- a/dataframely/__init__.py +++ b/dataframely/__init__.py @@ -12,6 +12,7 @@ from . import random from ._filter import filter +from ._generate_schema import infer_schema from ._rule import rule from ._typing import DataFrame, LazyFrame, Validation from .collection import ( @@ -78,6 +79,7 @@ "deserialize_schema", "read_parquet_metadata_schema", "read_parquet_metadata_collection", + "infer_schema", "Any", "Binary", "Bool", diff --git a/dataframely/_generate_schema.py b/dataframely/_generate_schema.py new file mode 100644 index 00000000..4fea1704 --- /dev/null +++ b/dataframely/_generate_schema.py @@ -0,0 +1,247 @@ +# Copyright (c) QuantCo 2025-2026 +# SPDX-License-Identifier: BSD-3-Clause +"""Infer schema from a Polars DataFrame.""" + +from __future__ import annotations + +import keyword +import re +from typing import TYPE_CHECKING, Literal, overload + +import polars as pl + +if TYPE_CHECKING: + from dataframely.schema import Schema + + +@overload +def infer_schema( + df: pl.DataFrame, + schema_name: str = ..., + *, + return_type: None = ..., +) -> None: ... + + +@overload +def infer_schema( + df: pl.DataFrame, + schema_name: str = ..., + *, + return_type: Literal["string"], +) -> str: ... + + +@overload +def infer_schema( + df: pl.DataFrame, + schema_name: str = ..., + *, + return_type: Literal["schema"], +) -> type[Schema]: ... + + +def infer_schema( + df: pl.DataFrame, + schema_name: str = "InferredSchema", + *, + return_type: Literal["string", "schema"] | None = None, +) -> str | type[Schema] | None: + """Infer a dataframely schema from a Polars DataFrame. + + This function inspects a DataFrame's schema and generates a corresponding + dataframely Schema. It can print the schema code, return it as a string, + or return an actual Schema class. + + Args: + df: The Polars DataFrame to infer the schema from. + schema_name: The name for the generated schema class. + return_type: Controls the return format: + + - ``None`` (default): Print the schema code to stdout, return ``None``. + - ``"string"``: Return the schema code as a string. + - ``"schema"``: Return an actual Schema class. + + Returns: + Depends on ``return_type``: + + - ``None``: Returns ``None`` (prints to stdout). + - ``"string"``: Returns the schema code as a string. + - ``"schema"``: Returns a Schema class that can be used directly. + + Example: + >>> import polars as pl + >>> import dataframely as dy + >>> df = pl.DataFrame({ + ... "name": ["Alice", "Bob"], + ... "age": [25, 30], + ... "score": [95.5, None], + ... }) + >>> dy.infer_schema(df, "PersonSchema") + class PersonSchema(dy.Schema): + name = dy.String() + age = dy.Int64() + score = dy.Float64(nullable=True) + >>> schema = dy.infer_schema(df, "PersonSchema", return_type="schema") + >>> schema.is_valid(df) + True + """ + code = _generate_schema_code(df, schema_name) + + if return_type is None: + print(code) # noqa: T201 + return None + if return_type == "string": + return code + if return_type == "schema": + import dataframely as dy + + namespace: dict = {"dy": dy} + exec(code, namespace) # noqa: S102 + return namespace[schema_name] + + msg = f"Invalid return_type: {return_type!r}" + raise ValueError(msg) + + +def _generate_schema_code(df: pl.DataFrame, schema_name: str) -> str: + """Generate schema code string from a DataFrame.""" + lines = [f"class {schema_name}(dy.Schema):"] + + for col_name, series in df.to_dict().items(): + if _is_valid_identifier(col_name): + attr_name = col_name + alias = None + else: + attr_name = _make_valid_identifier(col_name) + alias = col_name + col_code = _dtype_to_column_code(series, alias=alias) + lines.append(f" {attr_name} = {col_code}") + + return "\n".join(lines) + + +def _is_valid_identifier(name: str) -> bool: + """Check if a string is a valid Python identifier and not a keyword.""" + return name.isidentifier() and not keyword.iskeyword(name) + + +def _make_valid_identifier(name: str) -> str: + """Convert a string to a valid Python identifier.""" + # Replace invalid characters with underscores + result = re.sub(r"[^a-zA-Z0-9_]", "_", name) + # Ensure it doesn't start with a digit + if result and result[0].isdigit(): + result = "_" + result + # Ensure it's not empty + if not result: + result = "_column" + # Handle keywords + if keyword.iskeyword(result): + result = result + "_" + return result + + +def _format_args(*args: str, nullable: bool = False, alias: str | None = None) -> str: + """Format arguments for column constructor.""" + all_args = list(args) + if nullable: + all_args.insert(0, "nullable=True") + if alias: + all_args.insert(0, f'alias="{alias}"') + return ", ".join(all_args) + + +def _dtype_to_column_code(series: pl.Series, *, alias: str | None = None) -> str: + """Convert a Polars Series to dataframely column constructor code.""" + dtype = series.dtype + nullable = series.null_count() > 0 + + # Simple types + if dtype == pl.Boolean(): + return f"dy.Bool({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Int8(): + return f"dy.Int8({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Int16(): + return f"dy.Int16({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Int32(): + return f"dy.Int32({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Int64(): + return f"dy.Int64({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.UInt8(): + return f"dy.UInt8({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.UInt16(): + return f"dy.UInt16({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.UInt32(): + return f"dy.UInt32({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.UInt64(): + return f"dy.UInt64({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Float32(): + return f"dy.Float32({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Float64(): + return f"dy.Float64({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.String(): + return f"dy.String({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Binary(): + return f"dy.Binary({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Date(): + return f"dy.Date({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Time(): + return f"dy.Time({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Null(): + return f"dy.Any({_format_args(alias=alias)})" + if dtype == pl.Object(): + return f"dy.Object({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Categorical(): + return f"dy.Categorical({_format_args(nullable=nullable, alias=alias)})" + + # Datetime with parameters + if isinstance(dtype, pl.Datetime): + args = [] + if dtype.time_zone is not None: + args.append(f'time_zone="{dtype.time_zone}"') + if dtype.time_unit != "us": # us is the default + args.append(f'time_unit="{dtype.time_unit}"') + return f"dy.Datetime({_format_args(*args, nullable=nullable, alias=alias)})" + + # Duration with time_unit + if isinstance(dtype, pl.Duration): + return f"dy.Duration({_format_args(nullable=nullable, alias=alias)})" + + # Decimal with precision and scale + if isinstance(dtype, pl.Decimal): + args = [] + if dtype.precision is not None: + args.append(f"precision={dtype.precision}") + if dtype.scale != 0: + args.append(f"scale={dtype.scale}") + return f"dy.Decimal({_format_args(*args, nullable=nullable, alias=alias)})" + + # Enum with categories + if isinstance(dtype, pl.Enum): + categories = dtype.categories.to_list() + return ( + f"dy.Enum({_format_args(repr(categories), nullable=nullable, alias=alias)})" + ) + + # List with inner type + if isinstance(dtype, pl.List): + inner_code = _dtype_to_column_code(series.explode()) + return f"dy.List({_format_args(inner_code, nullable=nullable, alias=alias)})" + + # Array with inner type and shape + if isinstance(dtype, pl.Array): + inner_code = _dtype_to_column_code(series.explode()) + return f"dy.Array({_format_args(inner_code, f'shape={dtype.size}', nullable=nullable, alias=alias)})" + + # Struct with fields + if isinstance(dtype, pl.Struct): + fields_parts = [] + for field in dtype.fields: + field_code = _dtype_to_column_code(series.struct.field(field.name)) + fields_parts.append(f'"{field.name}": {field_code}') + fields_dict = "{" + ", ".join(fields_parts) + "}" + return f"dy.Struct({_format_args(fields_dict, nullable=nullable, alias=alias)})" + + # Fallback for unknown types + return f"dy.Any({_format_args(alias=alias)}) # Unknown dtype: {dtype}" diff --git a/tests/test_infer_schema.py b/tests/test_infer_schema.py new file mode 100644 index 00000000..2dfec525 --- /dev/null +++ b/tests/test_infer_schema.py @@ -0,0 +1,272 @@ +# Copyright (c) QuantCo 2025-2026 +# SPDX-License-Identifier: BSD-3-Clause + +import datetime +import textwrap + +import polars as pl + +import dataframely as dy + + +class TestInferSchema: + def test_basic_types(self) -> None: + df = pl.DataFrame( + { + "int_col": [1, 2, 3], + "float_col": [1.0, 2.0, 3.0], + "str_col": ["a", "b", "c"], + "bool_col": [True, False, True], + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="BasicSchema") + expected = textwrap.dedent("""\ + class BasicSchema(dy.Schema): + int_col = dy.Int64() + float_col = dy.Float64() + str_col = dy.String() + bool_col = dy.Bool()""") + assert result == expected + + def test_nullable_detection(self) -> None: + df = pl.DataFrame( + { + "nullable_int": [1, None, 3], + "non_nullable_int": [1, 2, 3], + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="NullableSchema") + expected = textwrap.dedent("""\ + class NullableSchema(dy.Schema): + nullable_int = dy.Int64(nullable=True) + non_nullable_int = dy.Int64()""") + assert result == expected + + def test_datetime_types(self) -> None: + df = pl.DataFrame( + { + "date_col": [datetime.date(2024, 1, 1)], + "time_col": [datetime.time(12, 0, 0)], + "datetime_col": [datetime.datetime(2024, 1, 1, 12, 0, 0)], + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="DatetimeSchema") + expected = textwrap.dedent("""\ + class DatetimeSchema(dy.Schema): + date_col = dy.Date() + time_col = dy.Time() + datetime_col = dy.Datetime()""") + assert result == expected + + def test_datetime_with_timezone(self) -> None: + df = pl.DataFrame( + { + "utc_time": pl.Series( + [datetime.datetime(2024, 1, 1)] + ).dt.replace_time_zone("UTC"), + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="TzSchema") + expected = textwrap.dedent("""\ + class TzSchema(dy.Schema): + utc_time = dy.Datetime(time_zone="UTC")""") + assert result == expected + + def test_enum_type(self) -> None: + df = pl.DataFrame( + { + "status": pl.Series(["active", "pending"]).cast( + pl.Enum(["active", "pending", "inactive"]) + ), + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="EnumSchema") + expected = textwrap.dedent("""\ + class EnumSchema(dy.Schema): + status = dy.Enum(['active', 'pending', 'inactive'])""") + assert result == expected + + def test_decimal_type(self) -> None: + df = pl.DataFrame( + { + "amount": pl.Series(["10.50"]).cast(pl.Decimal(precision=10, scale=2)), + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="DecimalSchema") + expected = textwrap.dedent("""\ + class DecimalSchema(dy.Schema): + amount = dy.Decimal(precision=10, scale=2)""") + assert result == expected + + def test_list_type(self) -> None: + df = pl.DataFrame( + { + "tags": [["a", "b"], ["c"]], + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="ListSchema") + expected = textwrap.dedent("""\ + class ListSchema(dy.Schema): + tags = dy.List(dy.String())""") + assert result == expected + + def test_struct_type(self) -> None: + df = pl.DataFrame( + { + "metadata": [{"key": "value"}, {"key": "other"}], + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="StructSchema") + expected = textwrap.dedent("""\ + class StructSchema(dy.Schema): + metadata = dy.Struct({"key": dy.String()})""") + assert result == expected + + def test_list_with_nullable_inner(self) -> None: + df = pl.DataFrame({"names": [["Alice"], [None]]}) + result = dy.infer_schema( + df, return_type="string", schema_name="ListNullableInnerSchema" + ) + expected = textwrap.dedent("""\ + class ListNullableInnerSchema(dy.Schema): + names = dy.List(dy.String(nullable=True))""") + assert result == expected + + def test_struct_with_nullable_field(self) -> None: + df = pl.DataFrame({"data": [{"key": "value"}, {"key": None}]}) + result = dy.infer_schema( + df, return_type="string", schema_name="StructNullableFieldSchema" + ) + expected = textwrap.dedent("""\ + class StructNullableFieldSchema(dy.Schema): + data = dy.Struct({"key": dy.String(nullable=True)})""") + assert result == expected + + def test_array_type(self) -> None: + df = pl.DataFrame({"vector": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]}).cast( + {"vector": pl.Array(pl.Float64(), 3)} + ) + result = dy.infer_schema(df, return_type="string", schema_name="ArraySchema") + expected = textwrap.dedent("""\ + class ArraySchema(dy.Schema): + vector = dy.Array(dy.Float64(), shape=3)""") + assert result == expected + + def test_invalid_identifier(self) -> None: + df = pl.DataFrame( + { + "123invalid": ["test"], + } + ) + result = dy.infer_schema( + df, return_type="string", schema_name="InvalidIdSchema" + ) + expected = textwrap.dedent("""\ + class InvalidIdSchema(dy.Schema): + _123invalid = dy.String(alias="123invalid")""") + assert result == expected + + def test_python_keyword(self) -> None: + df = pl.DataFrame( + { + "class": ["test"], + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="KeywordSchema") + expected = textwrap.dedent("""\ + class KeywordSchema(dy.Schema): + class_ = dy.String(alias="class")""") + assert result == expected + + def test_all_integer_types(self) -> None: + df = pl.DataFrame( + { + "i8": pl.Series([1], dtype=pl.Int8), + "i16": pl.Series([1], dtype=pl.Int16), + "i32": pl.Series([1], dtype=pl.Int32), + "i64": pl.Series([1], dtype=pl.Int64), + "u8": pl.Series([1], dtype=pl.UInt8), + "u16": pl.Series([1], dtype=pl.UInt16), + "u32": pl.Series([1], dtype=pl.UInt32), + "u64": pl.Series([1], dtype=pl.UInt64), + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="IntSchema") + assert "dy.Int8()" in result + assert "dy.Int16()" in result + assert "dy.Int32()" in result + assert "dy.Int64()" in result + assert "dy.UInt8()" in result + assert "dy.UInt16()" in result + assert "dy.UInt32()" in result + assert "dy.UInt64()" in result + + def test_float_types(self) -> None: + df = pl.DataFrame( + { + "f32": pl.Series([1.0], dtype=pl.Float32), + "f64": pl.Series([1.0], dtype=pl.Float64), + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="FloatSchema") + assert "dy.Float32()" in result + assert "dy.Float64()" in result + + +class TestInferSchemaReturnsSchema: + """Test that return_type='schema' produces working schemas.""" + + def test_inferred_schema_validates_dataframe(self) -> None: + """Verify inferred schema validates the original dataframe.""" + dataframes = [ + # Basic types + pl.DataFrame( + { + "int_col": [1, 2, 3], + "float_col": [1.0, 2.0, 3.0], + "str_col": ["a", "b", "c"], + "bool_col": [True, False, True], + } + ), + # Nullable + pl.DataFrame({"nullable_int": [1, None, 3], "non_nullable_int": [1, 2, 3]}), + # Datetime types + pl.DataFrame( + { + "date_col": [datetime.date(2024, 1, 1)], + "time_col": [datetime.time(12, 0, 0)], + "datetime_col": [datetime.datetime(2024, 1, 1, 12, 0, 0)], + } + ), + # Enum + pl.DataFrame( + { + "status": pl.Series(["active", "pending"]).cast( + pl.Enum(["active", "pending", "inactive"]) + ) + } + ), + # List and struct + pl.DataFrame({"tags": [["a", "b"], ["c"]]}), + pl.DataFrame({"metadata": [{"key": "value"}]}), + # Array + pl.DataFrame({"vector": [[1.0, 2.0, 3.0]]}).cast( + {"vector": pl.Array(pl.Float64(), 3)} + ), + # Invalid identifiers and keywords + pl.DataFrame({"123invalid": ["test"], "class": ["test"]}), + # Decimal + pl.DataFrame( + {"amount": pl.Series(["10.50"]).cast(pl.Decimal(precision=10, scale=2))} + ), + # Nested types + pl.DataFrame({"nested_list": [[["a", "b"]]]}), + pl.DataFrame({"nested_struct": [{"outer": {"inner": "value"}}]}), + # Nullable inner types + pl.DataFrame({"list_with_nulls": [["a"], [None]]}), + pl.DataFrame({"struct_with_nulls": [{"key": "value"}, {"key": None}]}), + ] + + for i, df in enumerate(dataframes): + schema = dy.infer_schema(df, f"Schema{i}", return_type="schema") + assert schema.is_valid(df), f"Schema{i} failed for {df.schema}" From 6c19bfab2ff5a006e2d99400297160d8e5e8dc3e Mon Sep 17 00:00:00 2001 From: gabriel Date: Thu, 5 Mar 2026 11:20:21 +0100 Subject: [PATCH 2/4] increase code coverage --- dataframely/_generate_schema.py | 2 +- tests/test_infer_schema.py | 110 ++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 1 deletion(-) diff --git a/dataframely/_generate_schema.py b/dataframely/_generate_schema.py index 4fea1704..0ce7c316 100644 --- a/dataframely/_generate_schema.py +++ b/dataframely/_generate_schema.py @@ -43,7 +43,7 @@ def infer_schema( def infer_schema( df: pl.DataFrame, - schema_name: str = "InferredSchema", + schema_name: str = "Schema", *, return_type: Literal["string", "schema"] | None = None, ) -> str | type[Schema] | None: diff --git a/tests/test_infer_schema.py b/tests/test_infer_schema.py index 2dfec525..e57aa950 100644 --- a/tests/test_infer_schema.py +++ b/tests/test_infer_schema.py @@ -5,6 +5,7 @@ import textwrap import polars as pl +import pytest import dataframely as dy @@ -213,6 +214,115 @@ def test_float_types(self) -> None: assert "dy.Float64()" in result +class TestInferSchemaReturnTypes: + """Test the different return_type options.""" + + def test_return_type_none_prints_to_stdout( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + df = pl.DataFrame({"col": [1, 2, 3]}) + result = dy.infer_schema(df, "TestSchema") + assert result is None + captured = capsys.readouterr() + assert "class TestSchema(dy.Schema):" in captured.out + assert "col = dy.Int64()" in captured.out + + def test_return_type_string(self) -> None: + df = pl.DataFrame({"col": [1, 2, 3]}) + result = dy.infer_schema(df, "TestSchema", return_type="string") + assert isinstance(result, str) + assert "class TestSchema(dy.Schema):" in result + + def test_return_type_schema(self) -> None: + df = pl.DataFrame({"col": [1, 2, 3]}) + schema = dy.infer_schema(df, "TestSchema", return_type="schema") + assert schema.is_valid(df) + + def test_invalid_return_type_raises_error(self) -> None: + df = pl.DataFrame({"col": [1]}) + with pytest.raises(ValueError, match="Invalid return_type"): + dy.infer_schema(df, "Test", return_type="invalid") # type: ignore[call-overload] + + def test_default_schema_name(self) -> None: + df = pl.DataFrame({"col": [1]}) + result = dy.infer_schema(df, return_type="string") + assert "class Schema(dy.Schema):" in result + + +class TestSpecialTypes: + """Test special column types.""" + + def test_binary_type(self) -> None: + df = pl.DataFrame({"data": pl.Series([b"hello"], dtype=pl.Binary)}) + result = dy.infer_schema(df, return_type="string", schema_name="BinarySchema") + assert "dy.Binary()" in result + + def test_null_type(self) -> None: + df = pl.DataFrame({"null_col": pl.Series([None, None], dtype=pl.Null)}) + result = dy.infer_schema(df, return_type="string", schema_name="NullSchema") + assert "dy.Any()" in result + + def test_object_type(self) -> None: + df = pl.DataFrame({"obj": pl.Series([object()], dtype=pl.Object)}) + result = dy.infer_schema(df, return_type="string", schema_name="ObjectSchema") + assert "dy.Object()" in result + + def test_categorical_type(self) -> None: + df = pl.DataFrame({"cat": pl.Series(["a", "b"]).cast(pl.Categorical())}) + result = dy.infer_schema(df, return_type="string", schema_name="CatSchema") + assert "dy.Categorical()" in result + + def test_duration_type(self) -> None: + df = pl.DataFrame( + {"dur": pl.Series([datetime.timedelta(days=1)], dtype=pl.Duration)} + ) + result = dy.infer_schema(df, return_type="string", schema_name="DurSchema") + assert "dy.Duration()" in result + + def test_datetime_with_time_unit_ms(self) -> None: + df = pl.DataFrame( + {"dt": pl.Series([datetime.datetime(2024, 1, 1)]).cast(pl.Datetime("ms"))} + ) + result = dy.infer_schema(df, return_type="string", schema_name="DtSchema") + assert 'time_unit="ms"' in result + + def test_datetime_with_time_unit_ns(self) -> None: + df = pl.DataFrame( + {"dt": pl.Series([datetime.datetime(2024, 1, 1)]).cast(pl.Datetime("ns"))} + ) + result = dy.infer_schema(df, return_type="string", schema_name="DtSchema") + assert 'time_unit="ns"' in result + + def test_decimal_without_scale(self) -> None: + df = pl.DataFrame( + {"amount": pl.Series(["10"]).cast(pl.Decimal(precision=5, scale=0))} + ) + result = dy.infer_schema(df, return_type="string", schema_name="DecSchema") + assert "precision=5" in result + assert "scale=" not in result + + +class TestMakeValidIdentifier: + """Test edge cases of _make_valid_identifier.""" + + def test_column_with_special_chars_replaced(self) -> None: + df = pl.DataFrame({"!!!": ["test"]}) + result = dy.infer_schema(df, return_type="string", schema_name="SpecialSchema") + assert '___ = dy.String(alias="!!!")' in result + + def test_column_empty_after_sanitization(self) -> None: + # Empty string column name results in _column fallback + df = pl.DataFrame({"": ["test"]}) + result = dy.infer_schema(df, return_type="string", schema_name="EmptySchema") + # Empty string alias is not included (falsy), but _column is generated + assert "_column = dy.String()" in result + + def test_column_with_spaces(self) -> None: + df = pl.DataFrame({"col name": ["test"]}) + result = dy.infer_schema(df, return_type="string", schema_name="SpaceSchema") + assert 'col_name = dy.String(alias="col name")' in result + + class TestInferSchemaReturnsSchema: """Test that return_type='schema' produces working schemas.""" From f0e07fb14426abd8ea6120d0fe2a90d70c9ec59a Mon Sep 17 00:00:00 2001 From: gabriel Date: Thu, 5 Mar 2026 11:36:07 +0100 Subject: [PATCH 3/4] copilot --- dataframely/_generate_schema.py | 11 +++++++++-- docs/api/schema/index.rst | 1 + docs/api/schema/inference.rst | 9 +++++++++ tests/test_infer_schema.py | 28 +++++++++++++++++----------- 4 files changed, 36 insertions(+), 13 deletions(-) create mode 100644 docs/api/schema/inference.rst diff --git a/dataframely/_generate_schema.py b/dataframely/_generate_schema.py index 0ce7c316..68a55017 100644 --- a/dataframely/_generate_schema.py +++ b/dataframely/_generate_schema.py @@ -85,7 +85,14 @@ class PersonSchema(dy.Schema): >>> schema = dy.infer_schema(df, "PersonSchema", return_type="schema") >>> schema.is_valid(df) True + + Raises: + ValueError: If ``schema_name`` is not a valid Python identifier. """ + if not schema_name.isidentifier(): + msg = f"schema_name must be a valid Python identifier, got {schema_name!r}" + raise ValueError(msg) + code = _generate_schema_code(df, schema_name) if return_type is None: @@ -146,9 +153,9 @@ def _format_args(*args: str, nullable: bool = False, alias: str | None = None) - """Format arguments for column constructor.""" all_args = list(args) if nullable: - all_args.insert(0, "nullable=True") + all_args.append("nullable=True") if alias: - all_args.insert(0, f'alias="{alias}"') + all_args.append(f'alias="{alias}"') return ", ".join(all_args) diff --git a/docs/api/schema/index.rst b/docs/api/schema/index.rst index 77e03239..5ed25baa 100644 --- a/docs/api/schema/index.rst +++ b/docs/api/schema/index.rst @@ -9,6 +9,7 @@ Schema validation io generation + inference conversion metadata diff --git a/docs/api/schema/inference.rst b/docs/api/schema/inference.rst new file mode 100644 index 00000000..29d335e1 --- /dev/null +++ b/docs/api/schema/inference.rst @@ -0,0 +1,9 @@ +========= +Inference +========= + +.. currentmodule:: dataframely +.. autosummary:: + :toctree: _gen/ + + infer_schema diff --git a/tests/test_infer_schema.py b/tests/test_infer_schema.py index e57aa950..62dbd72a 100644 --- a/tests/test_infer_schema.py +++ b/tests/test_infer_schema.py @@ -243,6 +243,13 @@ def test_invalid_return_type_raises_error(self) -> None: with pytest.raises(ValueError, match="Invalid return_type"): dy.infer_schema(df, "Test", return_type="invalid") # type: ignore[call-overload] + def test_invalid_schema_name_raises_error(self) -> None: + df = pl.DataFrame({"col": [1]}) + with pytest.raises( + ValueError, match="schema_name must be a valid Python identifier" + ): + dy.infer_schema(df, "Invalid Name") + def test_default_schema_name(self) -> None: df = pl.DataFrame({"col": [1]}) result = dy.infer_schema(df, return_type="string") @@ -324,11 +331,9 @@ def test_column_with_spaces(self) -> None: class TestInferSchemaReturnsSchema: - """Test that return_type='schema' produces working schemas.""" - - def test_inferred_schema_validates_dataframe(self) -> None: - """Verify inferred schema validates the original dataframe.""" - dataframes = [ + @pytest.mark.parametrize( + "df", + [ # Basic types pl.DataFrame( { @@ -356,8 +361,9 @@ def test_inferred_schema_validates_dataframe(self) -> None: ) } ), - # List and struct + # List pl.DataFrame({"tags": [["a", "b"], ["c"]]}), + # Struct pl.DataFrame({"metadata": [{"key": "value"}]}), # Array pl.DataFrame({"vector": [[1.0, 2.0, 3.0]]}).cast( @@ -375,8 +381,8 @@ def test_inferred_schema_validates_dataframe(self) -> None: # Nullable inner types pl.DataFrame({"list_with_nulls": [["a"], [None]]}), pl.DataFrame({"struct_with_nulls": [{"key": "value"}, {"key": None}]}), - ] - - for i, df in enumerate(dataframes): - schema = dy.infer_schema(df, f"Schema{i}", return_type="schema") - assert schema.is_valid(df), f"Schema{i} failed for {df.schema}" + ], + ) + def test_inferred_schema_validates_dataframe(self, df: pl.DataFrame) -> None: + schema = dy.infer_schema(df, "TestSchema", return_type="schema") + assert schema.is_valid(df) From 7ee32cf436d6a752b24b77675bed69ed6832c08c Mon Sep 17 00:00:00 2001 From: gabriel Date: Thu, 5 Mar 2026 13:46:59 +0100 Subject: [PATCH 4/4] pragma: no cover --- dataframely/_generate_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataframely/_generate_schema.py b/dataframely/_generate_schema.py index 68a55017..934c6575 100644 --- a/dataframely/_generate_schema.py +++ b/dataframely/_generate_schema.py @@ -251,4 +251,4 @@ def _dtype_to_column_code(series: pl.Series, *, alias: str | None = None) -> str return f"dy.Struct({_format_args(fields_dict, nullable=nullable, alias=alias)})" # Fallback for unknown types - return f"dy.Any({_format_args(alias=alias)}) # Unknown dtype: {dtype}" + return f"dy.Any({_format_args(alias=alias)}) # Unknown dtype: {dtype}" # pragma: no cover