Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions backend/utils/input_sanitizer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import logging
import re
from typing import NoReturn

from rest_framework.serializers import ValidationError

logger = logging.getLogger(__name__)

# Pattern to detect HTML/script tags (closed tags and unclosed tags starting with a letter)
# The second alternative catches unclosed tags like "<script" or "<img src=x" that could
# be completed by adjacent content in non-React rendering contexts (emails, PDFs, logs)
Expand All @@ -26,14 +30,34 @@
EVENT_HANDLER_PATTERN = re.compile(rf"\bon({_DOM_EVENTS})\s*=", re.IGNORECASE)


def _reject(field_name: str, reason: str, message: str) -> NoReturn:
    """Log a rejected input and abort validation.

    Emits a warning-level ``input_validation_rejected`` record whose ``extra``
    carries the field name and a short reason code, then raises. The offending
    value is never passed to this helper, so it cannot end up in the record.

    Args:
        field_name: Name of the field whose value was rejected.
        reason: Short code identifying which check failed.
        message: Human-readable text carried by the raised error.

    Raises:
        ValidationError: Always, built from ``message``.
    """
    log_context = {"field": field_name, "reason": reason}
    logger.warning("input_validation_rejected", extra=log_context)
    raise ValidationError(message)


def validate_no_html_tags(value: str, field_name: str = "This field") -> str:
    """Reject values containing HTML/script tags, dangerous URI protocols or event handlers.

    The checks run in a fixed order (HTML tag, URI protocol, event handler) and
    the first one that matches decides the error message. Every rejection goes
    through ``_reject`` so that it is logged before the error is raised.

    Args:
        value: The text to inspect.
        field_name: Name used as the prefix of the error message and recorded
            in the rejection log entry.

    Returns:
        ``value`` unchanged when none of the patterns match.

    Raises:
        ValidationError: If any of the three patterns is found in ``value``.
    """
    if HTML_TAG_PATTERN.search(value):
        _reject(
            field_name,
            "html_tag",
            f"{field_name} must not contain HTML or script tags.",
        )
    if JS_PROTOCOL_PATTERN.search(value):
        _reject(
            field_name,
            "js_protocol",
            f"{field_name} must not contain dangerous URI protocols.",
        )
    if EVENT_HANDLER_PATTERN.search(value):
        _reject(
            field_name,
            "event_handler",
            f"{field_name} must not contain event handler attributes.",
        )
    return value


Expand Down
15 changes: 15 additions & 0 deletions backend/utils/serializer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from utils.serializer.integrity_error_mixin import IntegrityErrorMixin
from utils.serializer.sanitization import (
HyperlinkedModelSerializer,
ModelSerializer,
SanitizedSerializerMixin,
Serializer,
)

# Names re-exported by this package; each one is imported above from a
# submodule of `utils.serializer`.
__all__ = [
    "HyperlinkedModelSerializer",
    "IntegrityErrorMixin",
    "ModelSerializer",
    "SanitizedSerializerMixin",
    "Serializer",
]
55 changes: 55 additions & 0 deletions backend/utils/serializer/sanitization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Sanitized serializer base classes.

`SanitizedSerializerMixin` attaches `validate_no_html_tags` to every writable
`CharField` on the serializer, so new serializers are protected from stored
XSS without per-field wiring. Opt out per-serializer via
`Meta.html_safe_fields = (...)` for fields that legitimately accept HTML-like
content (e.g. prompt text, regex literals).

Use the pre-mixed `ModelSerializer` / `Serializer` / `HyperlinkedModelSerializer`
classes instead of importing from `rest_framework` directly:

    from utils.serializer import ModelSerializer

    class FooSerializer(ModelSerializer):
        class Meta:
            model = Foo
            fields = ["name", "description", "prompt"]
            html_safe_fields = ("prompt",)  # opt-out
"""

from functools import partial

from rest_framework import serializers as drf

from utils.input_sanitizer import validate_no_html_tags


class SanitizedSerializerMixin:
    """Attach `validate_no_html_tags` to every writable CharField.

    After the normal serializer initialisation, each entry of ``self.fields``
    is inspected and the validator is appended to fields that are

    * not named in ``Meta.html_safe_fields`` (a missing ``Meta``, a missing
      attribute or a falsy value all mean "no exemptions"),
    * not read-only, and
    * instances of ``CharField``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        declared_safe = getattr(getattr(self, "Meta", None), "html_safe_fields", ())
        html_safe = set(declared_safe or ())
        for field_name, field in self.fields.items():
            # Same short-circuit order as a guard-then-check: exemption first,
            # then read-only, then the field type.
            needs_check = (
                field_name not in html_safe
                and not field.read_only
                and isinstance(field, drf.CharField)
            )
            if needs_check:
                # partial freezes the current field name into the validator, so
                # each field reports its own name rather than the loop's last one.
                check = partial(validate_no_html_tags, field_name=field_name)
                field.validators.append(check)


class ModelSerializer(SanitizedSerializerMixin, drf.ModelSerializer):
    """`rest_framework` ModelSerializer with `SanitizedSerializerMixin` applied."""


class Serializer(SanitizedSerializerMixin, drf.Serializer):
    """`rest_framework` Serializer with `SanitizedSerializerMixin` applied."""


class HyperlinkedModelSerializer(SanitizedSerializerMixin, drf.HyperlinkedModelSerializer):
    """`rest_framework` HyperlinkedModelSerializer with `SanitizedSerializerMixin` applied."""
113 changes: 113 additions & 0 deletions backend/utils/tests/test_sanitized_serializer_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import pytest
from rest_framework import serializers as drf
from rest_framework.exceptions import ValidationError

from utils.serializer import (
HyperlinkedModelSerializer,
ModelSerializer,
SanitizedSerializerMixin,
Serializer,
)


class PlainSerializer(Serializer):
    """Fixture: two writable text fields and an optional integer, no Meta."""

    name = drf.CharField(max_length=100)
    description = drf.CharField(max_length=500, allow_blank=True)
    count = drf.IntegerField(required=False)


class WithOptOutSerializer(Serializer):
    """Fixture: `prompt` is listed in `Meta.html_safe_fields`, `name` is not."""

    name = drf.CharField(max_length=100)
    prompt = drf.CharField()

    class Meta:
        html_safe_fields = ("prompt",)


class WithReadOnlySerializer(Serializer):
    """Fixture: one writable text field next to a read-only text field."""

    name = drf.CharField(max_length=100)
    computed = drf.CharField(read_only=True)


class TestSanitizedSerializerMixin:
    """Tests for the sanitizing mixin and the pre-mixed serializer base classes."""

    def test_rejects_html_in_writable_charfield(self):
        s = PlainSerializer(data={"name": "<script>alert(1)</script>", "description": ""})
        assert not s.is_valid()
        assert "name" in s.errors

    def test_rejects_html_in_description(self):
        s = PlainSerializer(data={"name": "ok", "description": "<img onerror=x>"})
        assert not s.is_valid()
        assert "description" in s.errors

    def test_clean_input_passes(self):
        s = PlainSerializer(data={"name": "My Workflow", "description": "Plain text."})
        assert s.is_valid(), s.errors

    def test_does_not_touch_non_charfield(self):
        s = PlainSerializer(data={"name": "ok", "description": "", "count": 42})
        assert s.is_valid(), s.errors

    def test_each_field_gets_its_own_validator(self):
        """Every validator is bound to its own field name.

        Both text fields are made to fail at once; each error message must name
        the field it belongs to and must not name the other one. A validator
        that picked up the wrong field name would break one of these checks.
        """
        s = PlainSerializer(data={"name": "<script>", "description": "<x>"})
        assert not s.is_valid()
        name_msg = str(s.errors["name"][0]).lower()
        description_msg = str(s.errors["description"][0]).lower()
        assert "name" in name_msg
        assert "description" not in name_msg
        assert "description" in description_msg
        assert "name" not in description_msg

    def test_html_safe_fields_opts_out(self):
        s = WithOptOutSerializer(
            data={"name": "ok", "prompt": "<thinking>step 1</thinking>"}
        )
        assert s.is_valid(), s.errors

    def test_html_safe_fields_does_not_leak_to_other_fields(self):
        s = WithOptOutSerializer(
            data={"name": "<script>", "prompt": "<thinking>step 1</thinking>"}
        )
        assert not s.is_valid()
        assert "name" in s.errors

    def test_read_only_field_is_naturally_exempt(self):
        s = WithReadOnlySerializer(data={"name": "ok"})
        assert s.is_valid(), s.errors

    def test_missing_meta_does_not_break(self):
        class NoMetaSerializer(Serializer):
            name = drf.CharField()

        s = NoMetaSerializer(data={"name": "ok"})
        assert s.is_valid(), s.errors

    def test_meta_without_html_safe_fields_does_not_break(self):
        class MetaWithoutAttrSerializer(Serializer):
            name = drf.CharField()

            class Meta:
                pass

        s = MetaWithoutAttrSerializer(data={"name": "<script>"})
        assert not s.is_valid()
        assert "name" in s.errors

    def test_pre_mixed_classes_inherit_mixin(self):
        assert issubclass(ModelSerializer, SanitizedSerializerMixin)
        assert issubclass(Serializer, SanitizedSerializerMixin)
        assert issubclass(HyperlinkedModelSerializer, SanitizedSerializerMixin)

    def test_rejects_javascript_protocol(self):
        s = PlainSerializer(data={"name": "javascript:alert(1)", "description": ""})
        assert not s.is_valid()
        assert "name" in s.errors

    def test_rejects_event_handler(self):
        s = PlainSerializer(data={"name": "onclick=alert(1)", "description": ""})
        assert not s.is_valid()
        assert "name" in s.errors

    def test_validation_raises_validation_error_type(self):
        """End-to-end: raise_exception path surfaces a DRF ValidationError."""
        s = PlainSerializer(data={"name": "<script>", "description": ""})
        with pytest.raises(ValidationError):
            s.is_valid(raise_exception=True)