From 78cd51537ea2c0f92fb92a11bdcf5198e910637e Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Sat, 20 Dec 2025 20:54:10 +0900 Subject: [PATCH 1/3] fix --- crates/pyrefly_types/src/callable.rs | 9 +++ crates/pyrefly_types/src/class.rs | 4 +- pyrefly/lib/alt/call.rs | 11 +++ pyrefly/lib/alt/special_calls.rs | 114 +++++++++++++++++++++++++++ pyrefly/lib/export/special.rs | 8 ++ pyrefly/lib/test/mod.rs | 1 + pyrefly/lib/test/sqlalchemy.rs | 112 ++++++++++++++++++++++++++ 7 files changed, 258 insertions(+), 1 deletion(-) create mode 100644 pyrefly/lib/test/sqlalchemy.rs diff --git a/crates/pyrefly_types/src/callable.rs b/crates/pyrefly_types/src/callable.rs index 5764e6a51d..6297331ccf 100644 --- a/crates/pyrefly_types/src/callable.rs +++ b/crates/pyrefly_types/src/callable.rs @@ -513,6 +513,8 @@ pub enum FunctionKind { NumbaJit, /// `numba.njit()` NumbaNjit, + /// `sqlalchemy.orm.mapped_column()` + SqlAlchemyMappedColumn, } impl Callable { @@ -860,6 +862,10 @@ impl FunctionKind { ("typing" | "typing_extensions", None, "disjoint_base") => Self::DisjointBase, ("numba.core.decorators", None, "jit") => Self::NumbaJit, ("numba.core.decorators", None, "njit") => Self::NumbaNjit, + ("sqlalchemy.orm", None, "mapped_column") + | ("sqlalchemy.orm._orm_constructors", None, "mapped_column") => { + Self::SqlAlchemyMappedColumn + } _ => Self::Def(Box::new(FuncId { module, cls, @@ -890,6 +896,7 @@ impl FunctionKind { Self::DisjointBase => ModuleName::typing(), Self::NumbaJit => ModuleName::from_str("numba"), Self::NumbaNjit => ModuleName::from_str("numba"), + Self::SqlAlchemyMappedColumn => ModuleName::from_str("sqlalchemy.orm"), Self::Def(func_id) => func_id.module.name().dupe(), } } @@ -916,6 +923,7 @@ impl FunctionKind { Self::DisjointBase => Cow::Owned(Name::new_static("disjoint_base")), Self::NumbaJit => Cow::Owned(Name::new_static("jit")), Self::NumbaNjit => Cow::Owned(Name::new_static("njit")), + Self::SqlAlchemyMappedColumn => Cow::Owned(Name::new_static("mapped_column")), Self::Def(func_id) => Cow::Borrowed(&func_id.name), } } @@ -938,6 +946,7 @@ impl FunctionKind { Self::RuntimeCheckable => None, Self::NumbaJit => None, Self::NumbaNjit => None, + Self::SqlAlchemyMappedColumn => None, Self::CallbackProtocol(cls) => Some(cls.class_object().dupe()), Self::AbstractMethod => None, Self::TotalOrdering => None, diff --git a/crates/pyrefly_types/src/class.rs b/crates/pyrefly_types/src/class.rs index 1a13b0cf9c..4dcd8569b1 100644 --- a/crates/pyrefly_types/src/class.rs +++ b/crates/pyrefly_types/src/class.rs @@ -170,7 +170,9 @@ impl ClassKind { ("enum", "property") => Self::Property(name.clone()), ("enum", "member") => Self::EnumMember, ("enum", "nonmember") => Self::EnumNonmember, - ("dataclasses", "Field") => Self::DataclassField, + ("dataclasses", "Field") + | ("sqlalchemy.orm", "MappedColumn") + | ("sqlalchemy.orm.properties", "MappedColumn") => Self::DataclassField, _ => Self::Class, } } diff --git a/pyrefly/lib/alt/call.rs b/pyrefly/lib/alt/call.rs index d0262717dc..a0c175c435 100644 --- a/pyrefly/lib/alt/call.rs +++ b/pyrefly/lib/alt/call.rs @@ -1308,6 +1308,17 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { errors, ) } + Some(CalleeKind::Function(FunctionKind::SqlAlchemyMappedColumn)) => self + .call_sqlalchemy_mapped_column( + ty.clone(), + &args, + &kws, + x.func.range(), + x.arguments.range, + x, + hint, + errors, + ), // Treat assert_type and reveal_type like pseudo-builtins for convenience. Note that we still // log a name-not-found error, but we also assert/reveal the type as requested. None if ty.is_error() && is_special_name(&x.func, "assert_type") => self diff --git a/pyrefly/lib/alt/special_calls.rs b/pyrefly/lib/alt/special_calls.rs index 174f397108..aa7c7aafee 100644 --- a/pyrefly/lib/alt/special_calls.rs +++ b/pyrefly/lib/alt/special_calls.rs @@ -16,6 +16,7 @@ use pyrefly_types::types::Union; use pyrefly_util::visit::Visit; use pyrefly_util::visit::VisitMut; use ruff_python_ast::Expr; +use ruff_python_ast::ExprCall; use ruff_python_ast::Keyword; use ruff_python_ast::name::Name; use ruff_text_size::Ranged; @@ -37,6 +38,7 @@ use crate::error::context::TypeCheckKind; use crate::types::callable::FunctionKind; use crate::types::callable::unexpected_keyword; use crate::types::class::Class; +use crate::types::class::ClassType; use crate::types::special_form::SpecialForm; use crate::types::tuple::Tuple; use crate::types::types::AnyStyle; @@ -277,6 +279,118 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ret } + pub fn call_sqlalchemy_mapped_column( + &self, + callee_ty: Type, + args: &[CallArg], + keywords: &[CallKeyword], + callee_range: TextRange, + arg_range: TextRange, + call: &ExprCall, + hint: Option, + errors: &ErrorCollector, + ) -> Type { + let ret = self.freeform_call_infer( + callee_ty, + args, + keywords, + callee_range, + arg_range, + hint, + errors, + ); + let Some(python_type) = self.sqlalchemy_mapped_column_python_type(call) else { + return ret; + }; + self.apply_sqlalchemy_mapped_python_type(ret, python_type) + } + + fn sqlalchemy_mapped_column_python_type(&self, call: &ExprCall) -> Option { + // mapped_column's first two positional arguments correspond to the name/type. + for expr in call.arguments.args.iter().take(2) { + if let Some(ty) = self.python_type_from_type_engine_expr(expr) { + return Some(ty); + } + } + for keyword in &call.arguments.keywords { + if let Some(arg) = &keyword.arg + && matches!(arg.as_str(), "__type_pos" | "type_") + && let Some(ty) = self.python_type_from_type_engine_expr(&keyword.value) + { + return Some(ty); + } + } + None + } + + fn python_type_from_type_engine_expr(&self, expr: &Expr) -> Option { + let ty = self.expr_infer(expr, &self.error_swallower()); + self.python_type_from_type_engine_type(&ty) + } + + fn python_type_from_type_engine_type(&self, ty: &Type) -> Option { + match ty { + Type::ClassType(cls) => self.python_type_from_type_engine_class(cls), + Type::ClassDef(cls) => { + let inst = self.instantiate(cls); + self.python_type_from_type_engine_type(&inst) + } + Type::Type(inner) => self.python_type_from_type_engine_type(inner), + Type::TypeAlias(alias) => self.python_type_from_type_engine_type(&alias.as_type()), + Type::Union(u) => { + let mut inferred = None; + for member in &u.members { + if let Some(member_ty) = self.python_type_from_type_engine_type(member) { + match &inferred { + Some(existing) if existing != &member_ty => return None, + None => inferred = Some(member_ty), + _ => {} + } + } + } + inferred + } + _ => None, + } + } + + fn python_type_from_type_engine_class(&self, cls: &ClassType) -> Option { + if Self::is_sqlalchemy_type_engine_class(cls.class_object()) { + return cls.targs().as_slice().first().cloned(); + } + let bases = self.get_base_types_for_class(cls.class_object()); + for base in bases.iter() { + if let Some(ty) = self.python_type_from_type_engine_class(base) { + return Some(ty); + } + } + None + } + + fn is_sqlalchemy_type_engine_class(class: &Class) -> bool { + class.has_toplevel_qname("sqlalchemy.sql.type_api", "TypeEngine") + } + + fn apply_sqlalchemy_mapped_python_type(&self, mut ty: Type, python_type: Type) -> Type { + ty.visit_mut(&mut |inner| { + if let Type::ClassType(class_type) = inner + && Self::is_sqlalchemy_mapped_class(class_type.class_object()) + { + if let Some(slot) = class_type.targs_mut().as_mut().get_mut(0) { + *slot = python_type.clone(); + } + } + }); + ty + } + + fn is_sqlalchemy_mapped_class(class: &Class) -> bool { + class.has_toplevel_qname("sqlalchemy.orm.base", "Mapped") + || class.has_toplevel_qname("sqlalchemy.orm.base", "_MappedAnnotationBase") + || class.has_toplevel_qname("sqlalchemy.orm.base", "_DeclarativeMapped") + || class.has_toplevel_qname("sqlalchemy.orm.properties", "MappedColumn") + } + pub fn call_isinstance( &self, obj: &Expr, diff --git a/pyrefly/lib/export/special.rs b/pyrefly/lib/export/special.rs index 35fe7fad5d..0167ea3df5 100644 --- a/pyrefly/lib/export/special.rs +++ b/pyrefly/lib/export/special.rs @@ -71,6 +71,7 @@ pub enum SpecialExport { BuiltinsFrozenset, BuiltinsFloat, Deprecated, + SqlAlchemyMappedColumn, } impl SpecialExport { @@ -133,6 +134,7 @@ impl SpecialExport { "frozenset" => Some(Self::BuiltinsFrozenset), "float" => Some(Self::BuiltinsFloat), "deprecated" => Some(Self::Deprecated), + "mapped_column" => Some(Self::SqlAlchemyMappedColumn), _ => None, } } @@ -199,6 +201,12 @@ impl SpecialExport { ), Self::PytestNoReturn => matches!(m.as_str(), "pytest"), Self::Deprecated => matches!(m.as_str(), "warnings" | "typing_extensions"), + Self::SqlAlchemyMappedColumn => { + matches!( + m.as_str(), + "sqlalchemy.orm" | "sqlalchemy.orm._orm_constructors" + ) + } } } diff --git a/pyrefly/lib/test/mod.rs b/pyrefly/lib/test/mod.rs index 1db0ae1cff..b1cd169ca0 100644 --- a/pyrefly/lib/test/mod.rs +++ b/pyrefly/lib/test/mod.rs @@ -58,6 +58,7 @@ mod returns; mod scope; mod semantic_syntax_errors; mod simple; +mod sqlalchemy; mod state; mod subscript_narrow; mod suppression; diff --git a/pyrefly/lib/test/sqlalchemy.rs b/pyrefly/lib/test/sqlalchemy.rs new file mode 100644 index 0000000000..11f4ccff27 --- /dev/null +++ b/pyrefly/lib/test/sqlalchemy.rs @@ -0,0 +1,112 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +use crate::test::util::TestEnv; +use crate::testcase; + +fn sqlalchemy_env() -> TestEnv { + let mut env = TestEnv::new(); + env.add_with_path( + "sqlalchemy", + "sqlalchemy/__init__.py", + "from . import orm\n", + ); + env.add_with_path( + "sqlalchemy.sql", + "sqlalchemy/sql/__init__.py", + "from . import sqltypes\nfrom . import type_api\n", + ); + env.add_with_path( + "sqlalchemy.sql.type_api", + "sqlalchemy/sql/type_api.py", + r#" +from typing import Generic, TypeVar + +_T = TypeVar("_T") + +class TypeEngine(Generic[_T]): + ... +"#, + ); + env.add_with_path( + "sqlalchemy.sql.sqltypes", + "sqlalchemy/sql/sqltypes.py", + r#" +from .type_api import TypeEngine + +class String(TypeEngine[str]): + ... + +class Integer(TypeEngine[int]): + ... +"#, + ); + env.add_with_path( + "sqlalchemy.orm.base", + "sqlalchemy/orm/base.py", + r#" +from typing import Generic, TypeVar + +_T = TypeVar("_T") + +class Mapped(Generic[_T]): + ... + +class _MappedAnnotationBase(Mapped[_T]): + ... + +class _DeclarativeMapped(_MappedAnnotationBase[_T]): + ... +"#, + ); + env.add_with_path( + "sqlalchemy.orm.properties", + "sqlalchemy/orm/properties.py", + r#" +from typing import Generic, TypeVar +from .base import _DeclarativeMapped + +_T = TypeVar("_T") + +class MappedColumn(_DeclarativeMapped[_T]): + def __init__(self) -> None: + ... +"#, + ); + env.add_with_path( + "sqlalchemy.orm", + "sqlalchemy/orm/__init__.py", + r#" +from typing import Any +from sqlalchemy.sql.type_api import TypeEngine +from .properties import MappedColumn + +__all__ = ["MappedColumn", "mapped_column"] + +def mapped_column(type_: TypeEngine[Any] | type[TypeEngine[Any]] | None = None) -> MappedColumn[Any]: + return MappedColumn() +"#, + ); + env +} + +testcase!( + test_sqlalchemy_mapped_column_infers_type, + sqlalchemy_env(), + r#" +from typing import reveal_type +from sqlalchemy.orm import mapped_column +from sqlalchemy.sql.sqltypes import Integer, String + +class Model: + name = mapped_column(String()) + quantity = mapped_column(Integer) + +reveal_type(Model.name) # E: revealed type: MappedColumn[str] +reveal_type(Model.quantity) # E: revealed type: MappedColumn[int] +"#, +); From fc8df2684df6b692c4cc1db931574ef39c22927f Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Sat, 20 Dec 2025 21:35:47 +0900 Subject: [PATCH 2/3] check nullable --- pyrefly/lib/alt/special_calls.rs | 36 +++++++++++++++++++++++++++++++- pyrefly/lib/test/sqlalchemy.rs | 14 ++++++++++--- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/pyrefly/lib/alt/special_calls.rs b/pyrefly/lib/alt/special_calls.rs index aa7c7aafee..23ef0a63f6 100644 --- a/pyrefly/lib/alt/special_calls.rs +++ b/pyrefly/lib/alt/special_calls.rs @@ -299,9 +299,12 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { hint, errors, ); - let Some(python_type) = self.sqlalchemy_mapped_column_python_type(call) else { + let Some(mut python_type) = self.sqlalchemy_mapped_column_python_type(call) else { return ret; }; + if self.sqlalchemy_mapped_column_is_nullable(call) { + python_type = Type::optional(python_type); + } self.apply_sqlalchemy_mapped_python_type(ret, python_type) } @@ -323,6 +326,37 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { None } + fn sqlalchemy_mapped_column_is_nullable(&self, call: &ExprCall) -> bool { + let mut nullable = true; + let mut primary_key = false; + for keyword in &call.arguments.keywords { + let Some(arg) = &keyword.arg else { + continue; + }; + match arg.as_str() { + "nullable" => { + if let Some(value) = Self::expr_bool_literal(&keyword.value) { + nullable = value; + } + } + "primary_key" => { + if let Some(value) = Self::expr_bool_literal(&keyword.value) { + primary_key = value; + } + } + _ => {} + } + } + if primary_key { false } else { nullable } + } + + fn expr_bool_literal(expr: &Expr) -> Option { + match expr { + Expr::BooleanLiteral(value) => Some(value.value), + _ => None, + } + } + fn python_type_from_type_engine_expr(&self, expr: &Expr) -> Option { let ty = self.expr_infer(expr, &self.error_swallower()); self.python_type_from_type_engine_type(&ty) diff --git a/pyrefly/lib/test/sqlalchemy.rs b/pyrefly/lib/test/sqlalchemy.rs index 11f4ccff27..4863e8160e 100644 --- a/pyrefly/lib/test/sqlalchemy.rs +++ b/pyrefly/lib/test/sqlalchemy.rs @@ -87,7 +87,11 @@ from .properties import MappedColumn __all__ = ["MappedColumn", "mapped_column"] -def mapped_column(type_: TypeEngine[Any] | type[TypeEngine[Any]] | None = None) -> MappedColumn[Any]: +def mapped_column( + type_: TypeEngine[Any] | type[TypeEngine[Any]] | None = None, + *args: Any, + **kw: Any, +) -> MappedColumn[Any]: return MappedColumn() "#, ); @@ -105,8 +109,12 @@ from sqlalchemy.sql.sqltypes import Integer, String class Model: name = mapped_column(String()) quantity = mapped_column(Integer) + sku = mapped_column(String(), nullable=False) + pk = mapped_column(Integer, primary_key=True) -reveal_type(Model.name) # E: revealed type: MappedColumn[str] -reveal_type(Model.quantity) # E: revealed type: MappedColumn[int] +reveal_type(Model.name) # E: revealed type: MappedColumn[str | None] +reveal_type(Model.quantity) # E: revealed type: MappedColumn[int | None] +reveal_type(Model.sku) # E: revealed type: MappedColumn[str] +reveal_type(Model.pk) # E: revealed type: MappedColumn[int] "#, ); From 3131bc10ab064f7ae2412b4fbf2be50878494344 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Sat, 20 Dec 2025 22:12:51 +0900 Subject: [PATCH 3/3] Update pyrefly/lib/alt/special_calls.rs Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- pyrefly/lib/alt/special_calls.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrefly/lib/alt/special_calls.rs b/pyrefly/lib/alt/special_calls.rs index 23ef0a63f6..7cb7497b1d 100644 --- a/pyrefly/lib/alt/special_calls.rs +++ b/pyrefly/lib/alt/special_calls.rs @@ -309,7 +309,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } fn sqlalchemy_mapped_column_python_type(&self, call: &ExprCall) -> Option { - // mapped_column's first two positional arguments correspond to the name/type. + // Check up to the first two positional arguments for a SQLAlchemy TypeEngine-derived column type. for expr in call.arguments.args.iter().take(2) { if let Some(ty) = self.python_type_from_type_engine_expr(expr) { return Some(ty);