diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index a173025cda29..ec421e6da238 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1900,6 +1900,15 @@ def can_return_none(self, type: TypeInfo, attr_name: str) -> bool: return is_subtype(NoneType(), node.type.ret_type) return False + def _replace_callable_return_type(self, tp: Type, replacement: Type) -> ProperType: + """Replace the return type of a callable or overloaded type with replacement.""" + ptp = get_proper_type(tp) + if isinstance(ptp, CallableType): + return ptp.copy_modified(ret_type=replacement) + if isinstance(ptp, Overloaded): + return Overloaded([c.copy_modified(ret_type=replacement) for c in ptp.items]) + return ptp + def analyze_type_type_callee(self, item: ProperType, context: Context) -> Type: """Analyze the callee X in X(...) where X is Type[item]. @@ -1932,10 +1941,25 @@ def analyze_type_type_callee(self, item: ProperType, context: Context) -> Type: # with typevar. callee = self.analyze_type_type_callee(get_proper_type(item.upper_bound), context) callee = get_proper_type(callee) + # When the TypeVar has a union bound, use the bound as the return type + # instead of the TypeVar itself, since calling the constructor should + # return an instance of one of the union members. + return_type = item + upper_bound = get_proper_type(item.upper_bound) + if isinstance(upper_bound, UnionType): + return_type = upper_bound if isinstance(callee, CallableType): - callee = callee.copy_modified(ret_type=item) + callee = callee.copy_modified(ret_type=return_type) elif isinstance(callee, Overloaded): - callee = Overloaded([c.copy_modified(ret_type=item) for c in callee.items]) + callee = Overloaded([c.copy_modified(ret_type=return_type) for c in callee.items]) + elif isinstance(callee, UnionType): + callee = UnionType( + [ + self._replace_callable_return_type(tp, return_type) + for tp in callee.relevant_items() + ], + callee.line, + ) return callee # We support Type of namedtuples but not of tuples in general if isinstance(item, TupleType) and tuple_fallback(item).type.fullname != "builtins.tuple": diff --git a/mypy/types.py b/mypy/types.py index d4ed728f4c9b..1ab06a338631 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -2311,6 +2311,15 @@ def type_object(self) -> mypy.nodes.TypeInfo: ret = get_proper_type(self.ret_type) if isinstance(ret, TypeVarType): ret = get_proper_type(ret.upper_bound) + if isinstance(ret, UnionType): + # When the TypeVar has a union bound, pick the first item's + # fallback. This is only used for is_protocol checks, which + # are not applicable to union-bound typevars. + first = get_proper_type(ret.items[0]) + if isinstance(first, Instance): + ret = first + else: + ret = self.fallback if isinstance(ret, TupleType): ret = ret.partial_fallback if isinstance(ret, TypedDictType):