diff --git a/crates/pyrefly_types/src/callable.rs b/crates/pyrefly_types/src/callable.rs index 5764e6a51..6297331cc 100644 --- a/crates/pyrefly_types/src/callable.rs +++ b/crates/pyrefly_types/src/callable.rs @@ -513,6 +513,8 @@ pub enum FunctionKind { NumbaJit, /// `numba.njit()` NumbaNjit, + /// `sqlalchemy.orm.mapped_column()` + SqlAlchemyMappedColumn, } impl Callable { @@ -860,6 +862,10 @@ impl FunctionKind { ("typing" | "typing_extensions", None, "disjoint_base") => Self::DisjointBase, ("numba.core.decorators", None, "jit") => Self::NumbaJit, ("numba.core.decorators", None, "njit") => Self::NumbaNjit, + ("sqlalchemy.orm", None, "mapped_column") + | ("sqlalchemy.orm._orm_constructors", None, "mapped_column") => { + Self::SqlAlchemyMappedColumn + } _ => Self::Def(Box::new(FuncId { module, cls, @@ -890,6 +896,7 @@ impl FunctionKind { Self::DisjointBase => ModuleName::typing(), Self::NumbaJit => ModuleName::from_str("numba"), Self::NumbaNjit => ModuleName::from_str("numba"), + Self::SqlAlchemyMappedColumn => ModuleName::from_str("sqlalchemy.orm"), Self::Def(func_id) => func_id.module.name().dupe(), } } @@ -916,6 +923,7 @@ impl FunctionKind { Self::DisjointBase => Cow::Owned(Name::new_static("disjoint_base")), Self::NumbaJit => Cow::Owned(Name::new_static("jit")), Self::NumbaNjit => Cow::Owned(Name::new_static("njit")), + Self::SqlAlchemyMappedColumn => Cow::Owned(Name::new_static("mapped_column")), Self::Def(func_id) => Cow::Borrowed(&func_id.name), } } @@ -938,6 +946,7 @@ impl FunctionKind { Self::RuntimeCheckable => None, Self::NumbaJit => None, Self::NumbaNjit => None, + Self::SqlAlchemyMappedColumn => None, Self::CallbackProtocol(cls) => Some(cls.class_object().dupe()), Self::AbstractMethod => None, Self::TotalOrdering => None, diff --git a/crates/pyrefly_types/src/class.rs b/crates/pyrefly_types/src/class.rs index 1a13b0cf9..4dcd8569b 100644 --- a/crates/pyrefly_types/src/class.rs +++ b/crates/pyrefly_types/src/class.rs @@ -170,7 +170,9 @@ impl ClassKind { ("enum", "property") => Self::Property(name.clone()), ("enum", "member") => Self::EnumMember, ("enum", "nonmember") => Self::EnumNonmember, - ("dataclasses", "Field") => Self::DataclassField, + ("dataclasses", "Field") + | ("sqlalchemy.orm", "MappedColumn") + | ("sqlalchemy.orm.properties", "MappedColumn") => Self::DataclassField, _ => Self::Class, } } diff --git a/pyrefly/lib/alt/call.rs b/pyrefly/lib/alt/call.rs index d0262717d..a0c175c43 100644 --- a/pyrefly/lib/alt/call.rs +++ b/pyrefly/lib/alt/call.rs @@ -1308,6 +1308,17 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { errors, ) } + Some(CalleeKind::Function(FunctionKind::SqlAlchemyMappedColumn)) => self + .call_sqlalchemy_mapped_column( + ty.clone(), + &args, + &kws, + x.func.range(), + x.arguments.range, + x, + hint, + errors, + ), // Treat assert_type and reveal_type like pseudo-builtins for convenience. Note that we still // log a name-not-found error, but we also assert/reveal the type as requested. None if ty.is_error() && is_special_name(&x.func, "assert_type") => self diff --git a/pyrefly/lib/alt/special_calls.rs b/pyrefly/lib/alt/special_calls.rs index 174f39710..7cb7497b1 100644 --- a/pyrefly/lib/alt/special_calls.rs +++ b/pyrefly/lib/alt/special_calls.rs @@ -16,6 +16,7 @@ use pyrefly_types::types::Union; use pyrefly_util::visit::Visit; use pyrefly_util::visit::VisitMut; use ruff_python_ast::Expr; +use ruff_python_ast::ExprCall; use ruff_python_ast::Keyword; use ruff_python_ast::name::Name; use ruff_text_size::Ranged; @@ -37,6 +38,7 @@ use crate::error::context::TypeCheckKind; use crate::types::callable::FunctionKind; use crate::types::callable::unexpected_keyword; use crate::types::class::Class; +use crate::types::class::ClassType; use crate::types::special_form::SpecialForm; use crate::types::tuple::Tuple; use crate::types::types::AnyStyle; @@ -277,6 +279,152 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ret } + pub fn call_sqlalchemy_mapped_column( + &self, + callee_ty: Type, + args: &[CallArg], + keywords: &[CallKeyword], + callee_range: TextRange, + arg_range: TextRange, + call: &ExprCall, + hint: Option, + errors: &ErrorCollector, + ) -> Type { + let ret = self.freeform_call_infer( + callee_ty, + args, + keywords, + callee_range, + arg_range, + hint, + errors, + ); + let Some(mut python_type) = self.sqlalchemy_mapped_column_python_type(call) else { + return ret; + }; + if self.sqlalchemy_mapped_column_is_nullable(call) { + python_type = Type::optional(python_type); + } + self.apply_sqlalchemy_mapped_python_type(ret, python_type) + } + + fn sqlalchemy_mapped_column_python_type(&self, call: &ExprCall) -> Option { + // Check up to the first two positional arguments for a SQLAlchemy TypeEngine-derived column type. + for expr in call.arguments.args.iter().take(2) { + if let Some(ty) = self.python_type_from_type_engine_expr(expr) { + return Some(ty); + } + } + for keyword in &call.arguments.keywords { + if let Some(arg) = &keyword.arg + && matches!(arg.as_str(), "__type_pos" | "type_") + && let Some(ty) = self.python_type_from_type_engine_expr(&keyword.value) + { + return Some(ty); + } + } + None + } + + fn sqlalchemy_mapped_column_is_nullable(&self, call: &ExprCall) -> bool { + let mut nullable = true; + let mut primary_key = false; + for keyword in &call.arguments.keywords { + let Some(arg) = &keyword.arg else { + continue; + }; + match arg.as_str() { + "nullable" => { + if let Some(value) = Self::expr_bool_literal(&keyword.value) { + nullable = value; + } + } + "primary_key" => { + if let Some(value) = Self::expr_bool_literal(&keyword.value) { + primary_key = value; + } + } + _ => {} + } + } + if primary_key { false } else { nullable } + } + + fn expr_bool_literal(expr: &Expr) -> Option { + match expr { + Expr::BooleanLiteral(value) => Some(value.value), + _ => None, + } + } + + fn python_type_from_type_engine_expr(&self, expr: &Expr) -> Option { + let ty = self.expr_infer(expr, &self.error_swallower()); + self.python_type_from_type_engine_type(&ty) + } + + fn python_type_from_type_engine_type(&self, ty: &Type) -> Option { + match ty { + Type::ClassType(cls) => self.python_type_from_type_engine_class(cls), + Type::ClassDef(cls) => { + let inst = self.instantiate(cls); + self.python_type_from_type_engine_type(&inst) + } + Type::Type(inner) => self.python_type_from_type_engine_type(inner), + Type::TypeAlias(alias) => self.python_type_from_type_engine_type(&alias.as_type()), + Type::Union(u) => { + let mut inferred = None; + for member in &u.members { + if let Some(member_ty) = self.python_type_from_type_engine_type(member) { + match &inferred { + Some(existing) if existing != &member_ty => return None, + None => inferred = Some(member_ty), + _ => {} + } + } + } + inferred + } + _ => None, + } + } + + fn python_type_from_type_engine_class(&self, cls: &ClassType) -> Option { + if Self::is_sqlalchemy_type_engine_class(cls.class_object()) { + return cls.targs().as_slice().first().cloned(); + } + let bases = self.get_base_types_for_class(cls.class_object()); + for base in bases.iter() { + if let Some(ty) = self.python_type_from_type_engine_class(base) { + return Some(ty); + } + } + None + } + + fn is_sqlalchemy_type_engine_class(class: &Class) -> bool { + class.has_toplevel_qname("sqlalchemy.sql.type_api", "TypeEngine") + } + + fn apply_sqlalchemy_mapped_python_type(&self, mut ty: Type, python_type: Type) -> Type { + ty.visit_mut(&mut |inner| { + if let Type::ClassType(class_type) = inner + && Self::is_sqlalchemy_mapped_class(class_type.class_object()) + { + if let Some(slot) = class_type.targs_mut().as_mut().get_mut(0) { + *slot = python_type.clone(); + } + } + }); + ty + } + + fn is_sqlalchemy_mapped_class(class: &Class) -> bool { + class.has_toplevel_qname("sqlalchemy.orm.base", "Mapped") + || class.has_toplevel_qname("sqlalchemy.orm.base", "_MappedAnnotationBase") + || class.has_toplevel_qname("sqlalchemy.orm.base", "_DeclarativeMapped") + || class.has_toplevel_qname("sqlalchemy.orm.properties", "MappedColumn") + } + pub fn call_isinstance( &self, obj: &Expr, diff --git a/pyrefly/lib/export/special.rs b/pyrefly/lib/export/special.rs index 35fe7fad5..0167ea3df 100644 --- a/pyrefly/lib/export/special.rs +++ b/pyrefly/lib/export/special.rs @@ -71,6 +71,7 @@ pub enum SpecialExport { BuiltinsFrozenset, BuiltinsFloat, Deprecated, + SqlAlchemyMappedColumn, } impl SpecialExport { @@ -133,6 +134,7 @@ impl SpecialExport { "frozenset" => Some(Self::BuiltinsFrozenset), "float" => Some(Self::BuiltinsFloat), "deprecated" => Some(Self::Deprecated), + "mapped_column" => Some(Self::SqlAlchemyMappedColumn), _ => None, } } @@ -199,6 +201,12 @@ impl SpecialExport { ), Self::PytestNoReturn => matches!(m.as_str(), "pytest"), Self::Deprecated => matches!(m.as_str(), "warnings" | "typing_extensions"), + Self::SqlAlchemyMappedColumn => { + matches!( + m.as_str(), + "sqlalchemy.orm" | "sqlalchemy.orm._orm_constructors" + ) + } } } diff --git a/pyrefly/lib/test/mod.rs b/pyrefly/lib/test/mod.rs index 1db0ae1cf..b1cd169ca 100644 --- a/pyrefly/lib/test/mod.rs +++ b/pyrefly/lib/test/mod.rs @@ -58,6 +58,7 @@ mod returns; mod scope; mod semantic_syntax_errors; mod simple; +mod sqlalchemy; mod state; mod subscript_narrow; mod suppression; diff --git a/pyrefly/lib/test/sqlalchemy.rs b/pyrefly/lib/test/sqlalchemy.rs new file mode 100644 index 000000000..4863e8160 --- /dev/null +++ b/pyrefly/lib/test/sqlalchemy.rs @@ -0,0 +1,120 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +use crate::test::util::TestEnv; +use crate::testcase; + +fn sqlalchemy_env() -> TestEnv { + let mut env = TestEnv::new(); + env.add_with_path( + "sqlalchemy", + "sqlalchemy/__init__.py", + "from . import orm\n", + ); + env.add_with_path( + "sqlalchemy.sql", + "sqlalchemy/sql/__init__.py", + "from . import sqltypes\nfrom . import type_api\n", + ); + env.add_with_path( + "sqlalchemy.sql.type_api", + "sqlalchemy/sql/type_api.py", + r#" +from typing import Generic, TypeVar + +_T = TypeVar("_T") + +class TypeEngine(Generic[_T]): + ... +"#, + ); + env.add_with_path( + "sqlalchemy.sql.sqltypes", + "sqlalchemy/sql/sqltypes.py", + r#" +from .type_api import TypeEngine + +class String(TypeEngine[str]): + ... + +class Integer(TypeEngine[int]): + ... +"#, + ); + env.add_with_path( + "sqlalchemy.orm.base", + "sqlalchemy/orm/base.py", + r#" +from typing import Generic, TypeVar + +_T = TypeVar("_T") + +class Mapped(Generic[_T]): + ... + +class _MappedAnnotationBase(Mapped[_T]): + ... + +class _DeclarativeMapped(_MappedAnnotationBase[_T]): + ... +"#, + ); + env.add_with_path( + "sqlalchemy.orm.properties", + "sqlalchemy/orm/properties.py", + r#" +from typing import Generic, TypeVar +from .base import _DeclarativeMapped + +_T = TypeVar("_T") + +class MappedColumn(_DeclarativeMapped[_T]): + def __init__(self) -> None: + ... +"#, + ); + env.add_with_path( + "sqlalchemy.orm", + "sqlalchemy/orm/__init__.py", + r#" +from typing import Any +from sqlalchemy.sql.type_api import TypeEngine +from .properties import MappedColumn + +__all__ = ["MappedColumn", "mapped_column"] + +def mapped_column( + type_: TypeEngine[Any] | type[TypeEngine[Any]] | None = None, + *args: Any, + **kw: Any, +) -> MappedColumn[Any]: + return MappedColumn() +"#, + ); + env +} + +testcase!( + test_sqlalchemy_mapped_column_infers_type, + sqlalchemy_env(), + r#" +from typing import reveal_type +from sqlalchemy.orm import mapped_column +from sqlalchemy.sql.sqltypes import Integer, String + +class Model: + name = mapped_column(String()) + quantity = mapped_column(Integer) + sku = mapped_column(String(), nullable=False) + pk = mapped_column(Integer, primary_key=True) + +reveal_type(Model.name) # E: revealed type: MappedColumn[str | None] +reveal_type(Model.quantity) # E: revealed type: MappedColumn[int | None] +reveal_type(Model.sku) # E: revealed type: MappedColumn[str] +reveal_type(Model.pk) # E: revealed type: MappedColumn[int] +"#, +);