Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions crates/pyrefly_types/src/callable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,8 @@ pub enum FunctionKind {
NumbaJit,
/// `numba.njit()`
NumbaNjit,
/// `sqlalchemy.orm.mapped_column()`
SqlAlchemyMappedColumn,
}

impl Callable {
Expand Down Expand Up @@ -860,6 +862,10 @@ impl FunctionKind {
("typing" | "typing_extensions", None, "disjoint_base") => Self::DisjointBase,
("numba.core.decorators", None, "jit") => Self::NumbaJit,
("numba.core.decorators", None, "njit") => Self::NumbaNjit,
("sqlalchemy.orm", None, "mapped_column")
| ("sqlalchemy.orm._orm_constructors", None, "mapped_column") => {
Self::SqlAlchemyMappedColumn
}
_ => Self::Def(Box::new(FuncId {
module,
cls,
Expand Down Expand Up @@ -890,6 +896,7 @@ impl FunctionKind {
Self::DisjointBase => ModuleName::typing(),
Self::NumbaJit => ModuleName::from_str("numba"),
Self::NumbaNjit => ModuleName::from_str("numba"),
Self::SqlAlchemyMappedColumn => ModuleName::from_str("sqlalchemy.orm"),
Self::Def(func_id) => func_id.module.name().dupe(),
}
}
Expand All @@ -916,6 +923,7 @@ impl FunctionKind {
Self::DisjointBase => Cow::Owned(Name::new_static("disjoint_base")),
Self::NumbaJit => Cow::Owned(Name::new_static("jit")),
Self::NumbaNjit => Cow::Owned(Name::new_static("njit")),
Self::SqlAlchemyMappedColumn => Cow::Owned(Name::new_static("mapped_column")),
Self::Def(func_id) => Cow::Borrowed(&func_id.name),
}
}
Expand All @@ -938,6 +946,7 @@ impl FunctionKind {
Self::RuntimeCheckable => None,
Self::NumbaJit => None,
Self::NumbaNjit => None,
Self::SqlAlchemyMappedColumn => None,
Self::CallbackProtocol(cls) => Some(cls.class_object().dupe()),
Self::AbstractMethod => None,
Self::TotalOrdering => None,
Expand Down
4 changes: 3 additions & 1 deletion crates/pyrefly_types/src/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ impl ClassKind {
("enum", "property") => Self::Property(name.clone()),
("enum", "member") => Self::EnumMember,
("enum", "nonmember") => Self::EnumNonmember,
("dataclasses", "Field") => Self::DataclassField,
("dataclasses", "Field")
| ("sqlalchemy.orm", "MappedColumn")
| ("sqlalchemy.orm.properties", "MappedColumn") => Self::DataclassField,
_ => Self::Class,
}
}
Expand Down
11 changes: 11 additions & 0 deletions pyrefly/lib/alt/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,17 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
errors,
)
}
Some(CalleeKind::Function(FunctionKind::SqlAlchemyMappedColumn)) => self
.call_sqlalchemy_mapped_column(
ty.clone(),
&args,
&kws,
x.func.range(),
x.arguments.range,
x,
hint,
errors,
),
// Treat assert_type and reveal_type like pseudo-builtins for convenience. Note that we still
// log a name-not-found error, but we also assert/reveal the type as requested.
None if ty.is_error() && is_special_name(&x.func, "assert_type") => self
Expand Down
148 changes: 148 additions & 0 deletions pyrefly/lib/alt/special_calls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use pyrefly_types::types::Union;
use pyrefly_util::visit::Visit;
use pyrefly_util::visit::VisitMut;
use ruff_python_ast::Expr;
use ruff_python_ast::ExprCall;
use ruff_python_ast::Keyword;
use ruff_python_ast::name::Name;
use ruff_text_size::Ranged;
Expand All @@ -37,6 +38,7 @@ use crate::error::context::TypeCheckKind;
use crate::types::callable::FunctionKind;
use crate::types::callable::unexpected_keyword;
use crate::types::class::Class;
use crate::types::class::ClassType;
use crate::types::special_form::SpecialForm;
use crate::types::tuple::Tuple;
use crate::types::types::AnyStyle;
Expand Down Expand Up @@ -277,6 +279,152 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
ret
}

pub fn call_sqlalchemy_mapped_column(
&self,
callee_ty: Type,
args: &[CallArg],
keywords: &[CallKeyword],
callee_range: TextRange,
arg_range: TextRange,
call: &ExprCall,
hint: Option<HintRef>,
errors: &ErrorCollector,
) -> Type {
let ret = self.freeform_call_infer(
callee_ty,
args,
keywords,
callee_range,
arg_range,
hint,
errors,
);
let Some(mut python_type) = self.sqlalchemy_mapped_column_python_type(call) else {
return ret;
};
if self.sqlalchemy_mapped_column_is_nullable(call) {
python_type = Type::optional(python_type);
}
self.apply_sqlalchemy_mapped_python_type(ret, python_type)
}

fn sqlalchemy_mapped_column_python_type(&self, call: &ExprCall) -> Option<Type> {
// Check up to the first two positional arguments for a SQLAlchemy TypeEngine-derived column type.
for expr in call.arguments.args.iter().take(2) {
if let Some(ty) = self.python_type_from_type_engine_expr(expr) {
return Some(ty);
}
}
for keyword in &call.arguments.keywords {
if let Some(arg) = &keyword.arg
&& matches!(arg.as_str(), "__type_pos" | "type_")
&& let Some(ty) = self.python_type_from_type_engine_expr(&keyword.value)
{
return Some(ty);
}
}
None
}

fn sqlalchemy_mapped_column_is_nullable(&self, call: &ExprCall) -> bool {
let mut nullable = true;
let mut primary_key = false;
for keyword in &call.arguments.keywords {
let Some(arg) = &keyword.arg else {
continue;
};
match arg.as_str() {
"nullable" => {
if let Some(value) = Self::expr_bool_literal(&keyword.value) {
nullable = value;
}
}
"primary_key" => {
if let Some(value) = Self::expr_bool_literal(&keyword.value) {
primary_key = value;
}
}
_ => {}
}
}
if primary_key { false } else { nullable }
}

fn expr_bool_literal(expr: &Expr) -> Option<bool> {
match expr {
Expr::BooleanLiteral(value) => Some(value.value),
_ => None,
}
}

fn python_type_from_type_engine_expr(&self, expr: &Expr) -> Option<Type> {
let ty = self.expr_infer(expr, &self.error_swallower());
self.python_type_from_type_engine_type(&ty)
}

fn python_type_from_type_engine_type(&self, ty: &Type) -> Option<Type> {
match ty {
Type::ClassType(cls) => self.python_type_from_type_engine_class(cls),
Type::ClassDef(cls) => {
let inst = self.instantiate(cls);
self.python_type_from_type_engine_type(&inst)
}
Type::Type(inner) => self.python_type_from_type_engine_type(inner),
Type::TypeAlias(alias) => self.python_type_from_type_engine_type(&alias.as_type()),
Type::Union(u) => {
let mut inferred = None;
for member in &u.members {
if let Some(member_ty) = self.python_type_from_type_engine_type(member) {
match &inferred {
Some(existing) if existing != &member_ty => return None,
None => inferred = Some(member_ty),
_ => {}
}
}
}
inferred
}
_ => None,
}
}

fn python_type_from_type_engine_class(&self, cls: &ClassType) -> Option<Type> {
if Self::is_sqlalchemy_type_engine_class(cls.class_object()) {
return cls.targs().as_slice().first().cloned();
}
let bases = self.get_base_types_for_class(cls.class_object());
for base in bases.iter() {
if let Some(ty) = self.python_type_from_type_engine_class(base) {
return Some(ty);
}
}
None
}

fn is_sqlalchemy_type_engine_class(class: &Class) -> bool {
class.has_toplevel_qname("sqlalchemy.sql.type_api", "TypeEngine")
}

fn apply_sqlalchemy_mapped_python_type(&self, mut ty: Type, python_type: Type) -> Type {
ty.visit_mut(&mut |inner| {
if let Type::ClassType(class_type) = inner
&& Self::is_sqlalchemy_mapped_class(class_type.class_object())
{
if let Some(slot) = class_type.targs_mut().as_mut().get_mut(0) {
*slot = python_type.clone();
}
}
});
ty
}

fn is_sqlalchemy_mapped_class(class: &Class) -> bool {
class.has_toplevel_qname("sqlalchemy.orm.base", "Mapped")
|| class.has_toplevel_qname("sqlalchemy.orm.base", "_MappedAnnotationBase")
|| class.has_toplevel_qname("sqlalchemy.orm.base", "_DeclarativeMapped")
|| class.has_toplevel_qname("sqlalchemy.orm.properties", "MappedColumn")
}

pub fn call_isinstance(
&self,
obj: &Expr,
Expand Down
8 changes: 8 additions & 0 deletions pyrefly/lib/export/special.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ pub enum SpecialExport {
BuiltinsFrozenset,
BuiltinsFloat,
Deprecated,
SqlAlchemyMappedColumn,
}

impl SpecialExport {
Expand Down Expand Up @@ -133,6 +134,7 @@ impl SpecialExport {
"frozenset" => Some(Self::BuiltinsFrozenset),
"float" => Some(Self::BuiltinsFloat),
"deprecated" => Some(Self::Deprecated),
"mapped_column" => Some(Self::SqlAlchemyMappedColumn),
_ => None,
}
}
Expand Down Expand Up @@ -199,6 +201,12 @@ impl SpecialExport {
),
Self::PytestNoReturn => matches!(m.as_str(), "pytest"),
Self::Deprecated => matches!(m.as_str(), "warnings" | "typing_extensions"),
Self::SqlAlchemyMappedColumn => {
matches!(
m.as_str(),
"sqlalchemy.orm" | "sqlalchemy.orm._orm_constructors"
)
}
}
}

Expand Down
1 change: 1 addition & 0 deletions pyrefly/lib/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ mod returns;
mod scope;
mod semantic_syntax_errors;
mod simple;
mod sqlalchemy;
mod state;
mod subscript_narrow;
mod suppression;
Expand Down
120 changes: 120 additions & 0 deletions pyrefly/lib/test/sqlalchemy.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

use crate::test::util::TestEnv;
use crate::testcase;

fn sqlalchemy_env() -> TestEnv {
let mut env = TestEnv::new();
env.add_with_path(
"sqlalchemy",
"sqlalchemy/__init__.py",
"from . import orm\n",
);
env.add_with_path(
"sqlalchemy.sql",
"sqlalchemy/sql/__init__.py",
"from . import sqltypes\nfrom . import type_api\n",
);
env.add_with_path(
"sqlalchemy.sql.type_api",
"sqlalchemy/sql/type_api.py",
r#"
from typing import Generic, TypeVar

_T = TypeVar("_T")

class TypeEngine(Generic[_T]):
...
"#,
);
env.add_with_path(
"sqlalchemy.sql.sqltypes",
"sqlalchemy/sql/sqltypes.py",
r#"
from .type_api import TypeEngine

class String(TypeEngine[str]):
...

class Integer(TypeEngine[int]):
...
"#,
);
env.add_with_path(
"sqlalchemy.orm.base",
"sqlalchemy/orm/base.py",
r#"
from typing import Generic, TypeVar

_T = TypeVar("_T")

class Mapped(Generic[_T]):
...

class _MappedAnnotationBase(Mapped[_T]):
...

class _DeclarativeMapped(_MappedAnnotationBase[_T]):
...
"#,
);
env.add_with_path(
"sqlalchemy.orm.properties",
"sqlalchemy/orm/properties.py",
r#"
from typing import Generic, TypeVar
from .base import _DeclarativeMapped

_T = TypeVar("_T")

class MappedColumn(_DeclarativeMapped[_T]):
def __init__(self) -> None:
...
"#,
);
env.add_with_path(
"sqlalchemy.orm",
"sqlalchemy/orm/__init__.py",
r#"
from typing import Any
from sqlalchemy.sql.type_api import TypeEngine
from .properties import MappedColumn

__all__ = ["MappedColumn", "mapped_column"]

def mapped_column(
type_: TypeEngine[Any] | type[TypeEngine[Any]] | None = None,
*args: Any,
**kw: Any,
) -> MappedColumn[Any]:
return MappedColumn()
"#,
);
env
}

testcase!(
test_sqlalchemy_mapped_column_infers_type,
sqlalchemy_env(),
r#"
from typing import reveal_type
from sqlalchemy.orm import mapped_column
from sqlalchemy.sql.sqltypes import Integer, String

class Model:
name = mapped_column(String())
quantity = mapped_column(Integer)
sku = mapped_column(String(), nullable=False)
pk = mapped_column(Integer, primary_key=True)

reveal_type(Model.name) # E: revealed type: MappedColumn[str | None]
reveal_type(Model.quantity) # E: revealed type: MappedColumn[int | None]
reveal_type(Model.sku) # E: revealed type: MappedColumn[str]
reveal_type(Model.pk) # E: revealed type: MappedColumn[int]
"#,
);