diff --git a/py2k/utils/__init__.py b/py2k/utils/__init__.py index 0127128..60c265b 100644 --- a/py2k/utils/__init__.py +++ b/py2k/utils/__init__.py @@ -18,6 +18,13 @@ import pydantic +def _nullable_type(avro_type): + if isinstance(avro_type, list) and 'null' in avro_type: + return avro_type + + return ['null', avro_type] + + def process_properties(schema: Dict[str, Any]) -> Dict[str, Any]: """Processes a Pydantic generated schema to a confluent compliant schema @@ -46,7 +53,10 @@ def process_properties(schema: Dict[str, Any]) -> Dict[str, Any]: json_type = value['type'] value['type'] = json2avro_types.get(json_type, json_type) - if value.get('default') is not None: + if 'default' in value and value['default'] is None: + value['type'] = _nullable_type(value['type']) + + elif value.get('default') is not None: default = value['default'] value['default'] = python2avro_defaults.get(default, default) @@ -82,5 +92,7 @@ def update_optional_schema( for field, optional_field in product(_schema['fields'], optionals): if field['name'] == optional_field: - field.update({'type': ['null', field['type']]}) + field.update({'type': _nullable_type(field['type'])}) + if model.__fields__[optional_field].default is None: + field.setdefault('default', None) return _schema diff --git a/tests/test_schema.py b/tests/test_schema.py index dc799b9..72769f8 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -98,3 +98,17 @@ def test_field_with_default(python_value, avro_value): default_value = schema['fields'][0]['default'] assert default_value == avro_value + + +def test_field_with_null_default_is_nullable(): + MyRecord = create_model('MyRecord', + a=(bool, None), + __base__=KafkaRecord) + + schema = MyRecord().schema() + + assert schema['fields'][0] == { + 'type': ['null', 'boolean'], + 'default': None, + 'name': 'a' + }