diff --git a/backend/utils/input_sanitizer.py b/backend/utils/input_sanitizer.py index 36772727e6..8ac4b305eb 100644 --- a/backend/utils/input_sanitizer.py +++ b/backend/utils/input_sanitizer.py @@ -1,7 +1,11 @@ +import logging import re +from typing import NoReturn from rest_framework.serializers import ValidationError +logger = logging.getLogger(__name__) + # Pattern to detect HTML/script tags (closed tags and unclosed tags starting with a letter) # The second alternative catches unclosed tags like " NoReturn: + logger.warning( + "input_validation_rejected", + extra={"field": field_name, "reason": reason}, + ) + raise ValidationError(message) + + def validate_no_html_tags(value: str, field_name: str = "This field") -> str: """Reject values containing HTML/script tags.""" if HTML_TAG_PATTERN.search(value): - raise ValidationError(f"{field_name} must not contain HTML or script tags.") + _reject( + field_name, + "html_tag", + f"{field_name} must not contain HTML or script tags.", + ) if JS_PROTOCOL_PATTERN.search(value): - raise ValidationError(f"{field_name} must not contain dangerous URI protocols.") + _reject( + field_name, + "js_protocol", + f"{field_name} must not contain dangerous URI protocols.", + ) if EVENT_HANDLER_PATTERN.search(value): - raise ValidationError(f"{field_name} must not contain event handler attributes.") + _reject( + field_name, + "event_handler", + f"{field_name} must not contain event handler attributes.", + ) return value diff --git a/backend/utils/serializer/__init__.py b/backend/utils/serializer/__init__.py new file mode 100644 index 0000000000..0f2b61acb5 --- /dev/null +++ b/backend/utils/serializer/__init__.py @@ -0,0 +1,15 @@ +from utils.serializer.integrity_error_mixin import IntegrityErrorMixin +from utils.serializer.sanitization import ( + HyperlinkedModelSerializer, + ModelSerializer, + SanitizedSerializerMixin, + Serializer, +) + +__all__ = [ + "HyperlinkedModelSerializer", + "IntegrityErrorMixin", + "ModelSerializer", + "SanitizedSerializerMixin", + "Serializer", +] diff --git a/backend/utils/serializer/sanitization.py b/backend/utils/serializer/sanitization.py new file mode 100644 index 0000000000..3c53ed73d4 --- /dev/null +++ b/backend/utils/serializer/sanitization.py @@ -0,0 +1,55 @@ +"""Sanitized serializer base classes. + +`SanitizedSerializerMixin` attaches `validate_no_html_tags` to every writable +`CharField` on the serializer, so new serializers are protected from stored +XSS without per-field wiring. Opt out per-serializer via +`Meta.html_safe_fields = (...)` for fields that legitimately accept HTML-like +content (e.g. prompt text, regex literals). + +Use the pre-mixed `ModelSerializer` / `Serializer` / `HyperlinkedModelSerializer` +classes instead of importing from `rest_framework` directly: + + from utils.serializer import ModelSerializer + + class FooSerializer(ModelSerializer): + class Meta: + model = Foo + fields = ["name", "description", "prompt"] + html_safe_fields = ("prompt",) # opt-out +""" + +from functools import partial + +from rest_framework import serializers as drf + +from utils.input_sanitizer import validate_no_html_tags + + +class SanitizedSerializerMixin: + """Attach `validate_no_html_tags` to every writable CharField.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + meta = getattr(self, "Meta", None) + exempt = set(getattr(meta, "html_safe_fields", ()) or ()) + for name, field in self.fields.items(): + if name in exempt or field.read_only: + continue + if isinstance(field, drf.CharField): + # partial binds the field name at iteration time, avoiding the + # late-binding closure trap of a bare lambda. + field.validators.append(partial(validate_no_html_tags, field_name=name)) + + +class ModelSerializer(SanitizedSerializerMixin, drf.ModelSerializer): + pass + + +class Serializer(SanitizedSerializerMixin, drf.Serializer): + pass + + +class HyperlinkedModelSerializer( + SanitizedSerializerMixin, drf.HyperlinkedModelSerializer +): + pass diff --git a/backend/utils/tests/test_sanitized_serializer_mixin.py b/backend/utils/tests/test_sanitized_serializer_mixin.py new file mode 100644 index 0000000000..20f4833a2a --- /dev/null +++ b/backend/utils/tests/test_sanitized_serializer_mixin.py @@ -0,0 +1,113 @@ +import pytest +from rest_framework import serializers as drf +from rest_framework.exceptions import ValidationError + +from utils.serializer import ( + HyperlinkedModelSerializer, + ModelSerializer, + SanitizedSerializerMixin, + Serializer, +) + + +class PlainSerializer(Serializer): + name = drf.CharField(max_length=100) + description = drf.CharField(max_length=500, allow_blank=True) + count = drf.IntegerField(required=False) + + +class WithOptOutSerializer(Serializer): + name = drf.CharField(max_length=100) + prompt = drf.CharField() + + class Meta: + html_safe_fields = ("prompt",) + + +class WithReadOnlySerializer(Serializer): + name = drf.CharField(max_length=100) + computed = drf.CharField(read_only=True) + + +class TestSanitizedSerializerMixin: + def test_rejects_html_in_writable_charfield(self): + s = PlainSerializer(data={"name": "", "description": ""}) + assert not s.is_valid() + assert "name" in s.errors + + def test_rejects_html_in_description(self): + s = PlainSerializer(data={"name": "ok", "description": ""}) + assert not s.is_valid() + assert "description" in s.errors + + def test_clean_input_passes(self): + s = PlainSerializer(data={"name": "My Workflow", "description": "Plain text."}) + assert s.is_valid(), s.errors + + def test_does_not_touch_non_charfield(self): + s = PlainSerializer(data={"name": "ok", "description": "", "count": 42}) + assert s.is_valid(), s.errors + + def test_each_field_gets_its_own_validator(self): + """Default-arg closure capture: the field_name in the error must match the offender.""" + s = PlainSerializer(data={"name": "ok", "description": ""}) + assert not s.is_valid() + assert "description" in s.errors + msg = str(s.errors["description"][0]) + assert "description" in msg.lower() + + def test_html_safe_fields_opts_out(self): + s = WithOptOutSerializer( + data={"name": "ok", "prompt": "step 1"} + ) + assert s.is_valid(), s.errors + + def test_html_safe_fields_does_not_leak_to_other_fields(self): + s = WithOptOutSerializer( + data={"name": "