singledispatch.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. from __future__ import annotations
  2. from typing import Final, NamedTuple, Sequence, TypeVar, Union
  3. from typing_extensions import TypeAlias as _TypeAlias
  4. from mypy.messages import format_type
  5. from mypy.nodes import ARG_POS, Argument, Block, ClassDef, Context, SymbolTable, TypeInfo, Var
  6. from mypy.options import Options
  7. from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext, MethodSigContext
  8. from mypy.plugins.common import add_method_to_class
  9. from mypy.subtypes import is_subtype
  10. from mypy.types import (
  11. AnyType,
  12. CallableType,
  13. FunctionLike,
  14. Instance,
  15. NoneType,
  16. Overloaded,
  17. Type,
  18. TypeOfAny,
  19. get_proper_type,
  20. )
  class SingledispatchTypeVars(NamedTuple):
      """Plugin metadata unpacked from the type arguments of a ``_SingleDispatchCallable``.

      ``return_type`` is the instance's original type argument. ``fallback`` is the extra
      argument this plugin appends: the signature of the function that was passed to
      ``functools.singledispatch``.
      """

      return_type: Type  # the callable's original (typeshed-declared) type argument
      fallback: CallableType  # signature of the undecorated fallback implementation
  class RegisterCallableInfo(NamedTuple):
      """Type arguments carried by the synthetic register-callable instance.

      When ``register`` is given a class rather than a function, the plugin returns an
      instance of a fake class whose two type arguments are exactly these fields.
      """

      register_type: Type  # instance type of the class passed to ``register``
      singledispatch_obj: Instance  # the singledispatch object ``register`` was called on
  # Fully-qualified name of the class typeshed uses for the object singledispatch returns,
  # and of the two methods on it that this plugin inspects.
  SINGLEDISPATCH_TYPE: Final = "functools._SingleDispatchCallable"
  SINGLEDISPATCH_REGISTER_METHOD: Final = SINGLEDISPATCH_TYPE + ".register"
  SINGLEDISPATCH_CALLABLE_CALL_METHOD: Final = SINGLEDISPATCH_TYPE + ".__call__"
  def get_singledispatch_info(typ: Instance) -> SingledispatchTypeVars | None:
      """Unpack the plugin's metadata from a singledispatch object's type.

      Returns ``None`` unless ``typ`` carries exactly two type arguments. By this plugin's
      convention, that happens only once the fallback signature has been appended to the
      original single type argument.
      """
      if len(typ.args) != 2:
          return None
      return_type, fallback = typ.args
      return SingledispatchTypeVars(return_type, fallback)  # type: ignore[arg-type]
  T = TypeVar("T")


  def get_first_arg(args: list[list[T]]) -> T | None:
      """Return the first entry bound to the function's first formal parameter.

      ``args`` is indexed first by formal parameter and then by the actual arguments
      mapped to it (callers pass a plugin context's ``arg_types``). ``None`` is returned
      when there are no formal parameters or the first one received nothing.
      """
      first_formal = args[0] if args else None
      return first_formal[0] if first_formal else None
  # Name of the synthetic class returned when ``register`` receives a class, and the
  # fully-qualified name of its ``__call__`` method.
  REGISTER_RETURN_CLASS: Final = "_SingleDispatchRegisterCallable"
  REGISTER_CALLABLE_CALL_METHOD: Final = "functools." + REGISTER_RETURN_CLASS + ".__call__"
  def make_fake_register_class_instance(
      api: CheckerPluginInterface, type_args: Sequence[Type]
  ) -> Instance:
      """Build an instance of a synthetic ``functools._SingleDispatchRegisterCallable`` class.

      The class is created on the fly with ``object`` as its only base. It is given a
      ``__call__(name: Any) -> None`` method so that a plugin hook can intercept the call
      that applies the result of ``register(SomeClass)`` to a function. ``type_args``
      become the returned instance's type arguments.
      """
      class_def = ClassDef(REGISTER_RETURN_CLASS, Block([]))
      class_def.fullname = f"functools.{REGISTER_RETURN_CLASS}"
      type_info = TypeInfo(SymbolTable(), class_def, "functools")

      # Minimal class hierarchy: the synthetic class itself, then ``object``.
      object_info = api.named_generic_type("builtins.object", []).type
      type_info.bases = [Instance(object_info, [])]
      type_info.mro = [type_info, object_info]
      # Must be linked before the method is added below.
      class_def.info = type_info

      # The parameter is typed as Any; the real checking happens in the __call__ hook.
      call_param = Argument(
          Var("name"), AnyType(TypeOfAny.implementation_artifact), None, ARG_POS
      )
      add_method_to_class(api, class_def, "__call__", [call_param], NoneType())

      return Instance(type_info, type_args)
  # Either kind of plugin context; both expose the ``api`` and ``context`` used below, so
  # helpers can be shared between the function hook and the method hooks.
  PluginContext: _TypeAlias = Union[FunctionContext, MethodContext]


  def fail(ctx: PluginContext, msg: str, context: Context | None) -> None:
      """Report ``msg`` as a type-checking error.

      The error is attached to ``context`` when it is not ``None``. Otherwise it is
      attached to ``ctx.context``, the location of the call being checked. That fallback
      is needed because callers pass ``CallableType.definition``, which may be ``None``.
      """
      # TODO: figure out if there is some more reliable way of getting context information, so this
      # function isn't necessary
      ctx.api.fail(msg, ctx.context if context is None else context)
  def create_singledispatch_function_callback(ctx: FunctionContext) -> Type:
      """Function hook for calls to ``functools.singledispatch``.

      The decorated function must have at least one parameter, and the first must accept
      a positional argument. Otherwise an error is reported and the default return type is
      returned unchanged. When the check passes, the function's signature is appended as
      an extra type argument of the returned instance, so later hooks can recover the
      fallback implementation.
      """
      fallback_type = get_proper_type(get_first_arg(ctx.arg_types))
      if not isinstance(fallback_type, CallableType):
          return ctx.default_return_type

      if not fallback_type.arg_kinds:
          fail(
              ctx,
              "Singledispatch function requires at least one argument",
              fallback_type.definition,
          )
          return ctx.default_return_type

      if not fallback_type.arg_kinds[0].is_positional(star=True):
          fail(
              ctx,
              "First argument to singledispatch function must be a positional argument",
              fallback_type.definition,
          )
          return ctx.default_return_type

      # Per typeshed, singledispatch returns a functools._SingleDispatchCallable instance;
      # it is mutated in place to carry the fallback signature.
      dispatcher = get_proper_type(ctx.default_return_type)
      assert isinstance(dispatcher, Instance)
      dispatcher.args += (fallback_type,)
      return ctx.default_return_type
  def singledispatch_register_callback(ctx: MethodContext) -> Type:
      """Method hook for ``functools._SingleDispatchCallable.register``.

      Handles two cases:
      - a class argument: returns a synthetic register-callable instance that remembers
        the class and the singledispatch object;
      - a function argument: validates it, then returns the function's own type.
      Any other argument yields the default return type.
      """
      assert isinstance(ctx.type, Instance)
      # TODO: check that there's only one argument
      arg_type = get_proper_type(get_first_arg(ctx.arg_types))

      if isinstance(arg_type, (CallableType, Overloaded)) and arg_type.is_type_obj():
          # HACK: a class was passed to register. Typeshed types the result as a generic
          # Callable, which hides the function it is later applied to. Return an instance
          # of a synthetic class with a __call__ method instead, so a separate plugin hook
          # can see that function.
          # The constructor's return type is used because is_subtype doesn't work when
          # the right type is Overloaded.
          registered_class = arg_type.items[0].ret_type
          return make_fake_register_class_instance(
              ctx.api, RegisterCallableInfo(registered_class, ctx.type)
          )

      if isinstance(arg_type, CallableType):
          # TODO: do more checking for registered functions
          register_function(ctx, ctx.type, arg_type, ctx.api.options)
          # Typeshed says register returns Callable[..., T]. The same function actually
          # comes back, so return its precise type; direct uses of the registered function
          # (not through singledispatch) are then checked properly.
          return arg_type

      # Unrecognized argument: fall back to typeshed's signature.
      return ctx.default_return_type
  def register_function(
      ctx: PluginContext,
      singledispatch_obj: Instance,
      func: Type,
      options: Options,
      register_arg: Type | None = None,
  ) -> None:
      """Validate the dispatch type of a function registered on a singledispatch object.

      Despite the name, nothing is stored here. The function only reports an error when
      the dispatch type is not a subtype of the fallback implementation's first parameter
      type. The dispatch type is ``register_arg`` when given, otherwise ``func``'s first
      parameter type.

      Args:
          ctx: plugin context used to report errors.
          singledispatch_obj: the singledispatch object the function is registered on.
          func: type of the function being registered; non-callable types are ignored.
          options: used to format types in the error message.
          register_arg: type passed explicitly to ``register``, if any.
      """
      func = get_proper_type(func)
      if not isinstance(func, CallableType):
          return
      metadata = get_singledispatch_info(singledispatch_obj)
      if metadata is None:
          # if we never added the fallback to the type variables, we already reported an error, so
          # just don't do anything here
          return
      dispatch_type = get_dispatch_type(func, register_arg)
      if dispatch_type is None:
          # TODO: report an error here that singledispatch requires at least one argument
          # (might want to do the error reporting in get_dispatch_type)
          return

      # The fallback is guaranteed to have at least one parameter: functions without one
      # are rejected before the fallback is ever attached.
      fallback_dispatch_type = metadata.fallback.arg_types[0]
      if not is_subtype(dispatch_type, fallback_dispatch_type):
          fail(
              ctx,
              "Dispatch type {} must be subtype of fallback function first argument {}".format(
                  format_type(dispatch_type, options), format_type(fallback_dispatch_type, options)
              ),
              func.definition,
          )
  def get_dispatch_type(func: CallableType, register_arg: Type | None) -> Type | None:
      """Return the type a registered implementation dispatches on.

      An explicit ``register_arg`` (the class given to ``register(...)``) wins. Otherwise
      the type of ``func``'s first parameter is used. ``None`` means neither is available.
      """
      if register_arg is None:
          return func.arg_types[0] if func.arg_types else None
      return register_arg
  def call_singledispatch_function_after_register_argument(ctx: MethodContext) -> Type:
      """Method hook for calling the result of ``register(SomeClass)`` on a function.

      Unpacks the class and singledispatch object stored in the synthetic instance's
      type arguments and validates the function against them. It then returns the
      function's own type; if anything is missing, it returns the default return type.
      """
      register_callable = ctx.type
      if not isinstance(register_callable, Instance):
          return ctx.default_return_type

      stored = RegisterCallableInfo(*register_callable.args)  # type: ignore[arg-type]
      registered_func = get_first_arg(ctx.arg_types)
      if registered_func is None:
          return ctx.default_return_type

      register_function(
          ctx, stored.singledispatch_obj, registered_func, ctx.api.options, stored.register_type
      )
      # Return the function's own type (not the declared return type) so direct uses of
      # the registered function keep its precise signature.
      return registered_func
  def call_singledispatch_function_callback(ctx: MethodSigContext) -> FunctionLike:
      """Signature hook for ``functools._SingleDispatchCallable.__call__``.

      When the plugin has attached a fallback signature, that signature is used for the
      call. Otherwise typeshed's default signature is kept.
      """
      metadata = get_singledispatch_info(ctx.type) if isinstance(ctx.type, Instance) else None
      return ctx.default_signature if metadata is None else metadata.fallback