Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1900,6 +1900,15 @@ def can_return_none(self, type: TypeInfo, attr_name: str) -> bool:
return is_subtype(NoneType(), node.type.ret_type)
return False

def _replace_callable_return_type(self, tp: Type, replacement: Type) -> ProperType:
"""Replace the return type of a callable or overloaded type with replacement."""
ptp = get_proper_type(tp)
if isinstance(ptp, CallableType):
return ptp.copy_modified(ret_type=replacement)
if isinstance(ptp, Overloaded):
return Overloaded([c.copy_modified(ret_type=replacement) for c in ptp.items])
return ptp

def analyze_type_type_callee(self, item: ProperType, context: Context) -> Type:
"""Analyze the callee X in X(...) where X is Type[item].

Expand Down Expand Up @@ -1932,10 +1941,25 @@ def analyze_type_type_callee(self, item: ProperType, context: Context) -> Type:
# with typevar.
callee = self.analyze_type_type_callee(get_proper_type(item.upper_bound), context)
callee = get_proper_type(callee)
# When the TypeVar has a union bound, use the bound as the return type
# instead of the TypeVar itself, since calling the constructor should
# return an instance of one of the union members.
return_type = item
upper_bound = get_proper_type(item.upper_bound)
if isinstance(upper_bound, UnionType):
return_type = upper_bound
if isinstance(callee, CallableType):
callee = callee.copy_modified(ret_type=item)
callee = callee.copy_modified(ret_type=return_type)
elif isinstance(callee, Overloaded):
callee = Overloaded([c.copy_modified(ret_type=item) for c in callee.items])
callee = Overloaded([c.copy_modified(ret_type=return_type) for c in callee.items])
elif isinstance(callee, UnionType):
callee = UnionType(
[
self._replace_callable_return_type(tp, return_type)
for tp in callee.relevant_items()
],
callee.line,
)
return callee
# We support Type of namedtuples but not of tuples in general
if isinstance(item, TupleType) and tuple_fallback(item).type.fullname != "builtins.tuple":
Expand Down
9 changes: 9 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2311,6 +2311,15 @@ def type_object(self) -> mypy.nodes.TypeInfo:
ret = get_proper_type(self.ret_type)
if isinstance(ret, TypeVarType):
ret = get_proper_type(ret.upper_bound)
if isinstance(ret, UnionType):
# When the TypeVar has a union bound, pick the first item's
# fallback. This is only used for is_protocol checks, which
# are not applicable to union-bound typevars.
first = get_proper_type(ret.items[0])
if isinstance(first, Instance):
ret = first
else:
ret = self.fallback
if isinstance(ret, TupleType):
ret = ret.partial_fallback
if isinstance(ret, TypedDictType):
Expand Down
Loading