Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions docs/relations/many-to-many.md
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,33 @@ class Student(ormar.Model):
)
```

## Making through relation columns non-nullable

By default the auto-generated foreign key columns on the through table are
nullable (matching SQLAlchemy's default for non primary-key columns). This can
be overridden per column with:

* `through_relation_nullable` - controls nullability of the column pointing to
the model where `ManyToMany` is declared (the owner side). Defaults to `True`.
* `through_reverse_relation_nullable` - controls nullability of the column
pointing to the target model. Defaults to `True`.

Set either (or both) to `False` when you want the database to enforce that a
through row always references both sides.

```python
class Student(ormar.Model):
ormar_config = base_ormar_config.copy()

id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
courses = ormar.ManyToMany(
Course,
through_relation_nullable=False,
through_reverse_relation_nullable=False,
)
```

## Through Fields

The through field is auto added to the reverse side of the relation.
Expand Down
6 changes: 6 additions & 0 deletions ormar/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ def __init__(self, **kwargs: Any) -> None:
self.through_reverse_foreign_key_name: Optional[str] = kwargs.pop(
"through_reverse_foreign_key_name", None
)
self.through_relation_nullable: bool = kwargs.pop(
"through_relation_nullable", True
)
self.through_reverse_relation_nullable: bool = kwargs.pop(
"through_reverse_relation_nullable", True
)

self.skip_reverse: bool = kwargs.pop("skip_reverse", False)
self.skip_field: bool = kwargs.pop("skip_field", False)
Expand Down
7 changes: 7 additions & 0 deletions ormar/fields/many_to_many.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,11 @@ def ManyToMany( # type: ignore
"through_reverse_foreign_key_name", None
)

through_relation_nullable = kwargs.pop("through_relation_nullable", True)
through_reverse_relation_nullable = kwargs.pop(
"through_reverse_relation_nullable", True
)

if through is not None and through.__class__ != ForwardRef:
forbid_through_relations(cast(type["Model"], through))

Expand Down Expand Up @@ -178,6 +183,8 @@ def ManyToMany( # type: ignore
through_reverse_relation_name=through_reverse_relation_name,
through_foreign_key_name=through_foreign_key_name,
through_reverse_foreign_key_name=through_reverse_foreign_key_name,
through_relation_nullable=through_relation_nullable,
through_reverse_relation_nullable=through_reverse_relation_nullable,
)

Field = type("ManyToMany", (ManyToManyField, BaseField), {})
Expand Down
2 changes: 2 additions & 0 deletions ormar/models/helpers/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ def register_reverse_model_fields(model_field: "ForeignKeyField") -> None:
skip_field=model_field.skip_reverse,
through_relation_name=model_field.through_reverse_relation_name,
through_reverse_relation_name=model_field.through_relation_name,
through_relation_nullable=model_field.through_reverse_relation_nullable,
through_reverse_relation_nullable=model_field.through_relation_nullable,
)
# register foreign keys on through model
model_field = cast("ManyToManyField", model_field)
Expand Down
6 changes: 6 additions & 0 deletions ormar/models/helpers/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,14 @@ def adjust_through_many_to_many_model(model_field: "ManyToManyField") -> None:
model_field=model_field,
field_name=parent_name,
foreign_key_name=model_field.through_reverse_foreign_key_name,
nullable=model_field.through_reverse_relation_nullable,
)
create_and_append_m2m_fk(
model=model_field.owner,
model_field=model_field,
field_name=child_name,
foreign_key_name=model_field.through_foreign_key_name,
nullable=model_field.through_relation_nullable,
)

create_pydantic_field(parent_name, model_field.to, model_field)
Expand All @@ -68,6 +70,7 @@ def create_and_append_m2m_fk(
model_field: "ManyToManyField",
field_name: str,
foreign_key_name: Optional[str] = None,
nullable: bool = True,
) -> None:
"""
Registers sqlalchemy Column with sqlalchemy.ForeignKey leading to the model.
Expand All @@ -83,6 +86,8 @@ def create_and_append_m2m_fk(
:type model_field: ManyToManyField field
:param foreign_key_name: optional override for the generated FK constraint name.
:type foreign_key_name: Optional[str]
:param nullable: whether the created column is nullable.
:type nullable: bool
"""
pk_alias = model.get_column_alias(model.ormar_config.pkname)
pk_column = next(
Expand All @@ -104,6 +109,7 @@ def create_and_append_m2m_fk(
onupdate="CASCADE",
name=foreign_key_name or default_name,
),
nullable=nullable,
)
model_field.through.ormar_config.columns.append(column)
model_field.through.ormar_config.table.append_column(column, replace_existing=True)
Expand Down
128 changes: 128 additions & 0 deletions tests/test_relations/test_through_relation_nullable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import pytest
import sqlalchemy

import ormar
from tests.lifespan import init_tests
from tests.settings import create_config

base_ormar_config = create_config()


class Subject(ormar.Model):
ormar_config = base_ormar_config.copy(tablename="trnn_subjects")

id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)


class Student(ormar.Model):
ormar_config = base_ormar_config.copy(tablename="trnn_students")

id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
subjects = ormar.ManyToMany(
Subject,
through_foreign_key_name="fk_trnn_stu_subj",
through_reverse_foreign_key_name="fk_trnn_stu_subj_rev",
)


class StudentStrict(ormar.Model):
ormar_config = base_ormar_config.copy(tablename="trnn_students_strict")

id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
subjects = ormar.ManyToMany(
Subject,
through_relation_nullable=False,
through_reverse_relation_nullable=False,
through_foreign_key_name="fk_trnn_stu_strict_subj",
through_reverse_foreign_key_name="fk_trnn_stu_strict_subj_rev",
)


class StudentOwnerNotNull(ormar.Model):
ormar_config = base_ormar_config.copy(tablename="trnn_students_owner_not_null")

id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
subjects = ormar.ManyToMany(
Subject,
through_relation_nullable=False,
through_foreign_key_name="fk_trnn_stu_owner_subj",
through_reverse_foreign_key_name="fk_trnn_stu_owner_subj_rev",
)


class StudentReverseNotNull(ormar.Model):
ormar_config = base_ormar_config.copy(tablename="trnn_students_reverse_not_null")

id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
subjects = ormar.ManyToMany(
Subject,
through_reverse_relation_nullable=False,
through_foreign_key_name="fk_trnn_stu_rev_subj",
through_reverse_foreign_key_name="fk_trnn_stu_rev_subj_rev",
)


create_test_database = init_tests(base_ormar_config)


def _through_table(model, field_name):
return model.ormar_config.model_fields[field_name].through.ormar_config.table


def test_default_through_columns_are_nullable():
table = _through_table(Student, "subjects")
assert table.c["student"].nullable is True
assert table.c["subject"].nullable is True


def test_through_relation_nullable_false_sets_owner_column_not_null():
table = _through_table(StudentOwnerNotNull, "subjects")
assert table.c["studentownernotnull"].nullable is False
assert table.c["subject"].nullable is True


def test_through_reverse_relation_nullable_false_sets_target_column_not_null():
table = _through_table(StudentReverseNotNull, "subjects")
assert table.c["studentreversenotnull"].nullable is True
assert table.c["subject"].nullable is False


def test_both_through_columns_can_be_not_null():
table = _through_table(StudentStrict, "subjects")
assert table.c["studentstrict"].nullable is False
assert table.c["subject"].nullable is False


@pytest.mark.asyncio
async def test_m2m_with_non_nullable_through_columns_works_at_runtime():
async with base_ormar_config.database:
async with base_ormar_config.database.transaction(force_rollback=True):
subject = await Subject.objects.create(name="math")
student = await StudentStrict.objects.create(name="Alice")
await student.subjects.add(subject)

fetched = await StudentStrict.objects.select_related("subjects").get(
id=student.id
)
assert len(fetched.subjects) == 1
assert fetched.subjects[0].name == "math"


@pytest.mark.asyncio
async def test_insert_null_into_non_nullable_through_column_fails():
async with base_ormar_config.database:
subject = await Subject.objects.create(name="chemistry")
through_table = _through_table(StudentStrict, "subjects")

async with base_ormar_config.database.connection() as conn:
with pytest.raises(sqlalchemy.exc.IntegrityError):
await conn.execute(
through_table.insert().values(
studentstrict=None, subject=subject.id
)
)