suggestions.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046
  1. """Mechanisms for inferring function types based on callsites.
  2. Currently works by collecting all argument types at callsites,
  3. synthesizing a list of possible function types from that, trying them
  4. all, and picking the one with the fewest errors that we think is the
  5. "best".
  6. Can return JSON that pyannotate can use to apply the annotations to code.
  7. There are a bunch of TODOs here:
  8. * Maybe want a way to surface the choices not selected??
  9. * We can generate an exponential number of type suggestions, and probably want
  10. a way to not always need to check them all.
  11. * Our heuristics for what types to try are primitive and not yet
  12. supported by real practice.
  13. * More!
  14. Other things:
  15. * This is super brute force. Could we integrate with the typechecker
  16. more to understand more about what is going on?
  17. * Like something with tracking constraints/unification variables?
  18. * No understanding of type variables at *all*
  19. """
from __future__ import annotations

import itertools
import json
import os
from contextlib import contextmanager
from typing import Callable, Iterator, NamedTuple, TypeVar, cast

from typing_extensions import TypedDict

from mypy.argmap import map_actuals_to_formals
from mypy.build import Graph, State
from mypy.checkexpr import has_any_type
from mypy.find_sources import InvalidSourceList, SourceFinder
from mypy.join import join_type_list
from mypy.meet import meet_type_list
from mypy.modulefinder import PYTHON_EXTENSIONS
from mypy.nodes import (
    ARG_STAR,
    ARG_STAR2,
    ArgKind,
    CallExpr,
    Decorator,
    Expression,
    FuncDef,
    LambdaExpr,
    MypyFile,
    RefExpr,
    ReturnStmt,
    SymbolNode,
    SymbolTable,
    TypeInfo,
    reverse_builtin_aliases,
)
from mypy.options import Options
from mypy.plugin import FunctionContext, MethodContext, Plugin
from mypy.server.update import FineGrainedBuildManager
from mypy.state import state
from mypy.traverser import TraverserVisitor
from mypy.typeops import make_simplified_union
from mypy.types import (
    AnyType,
    CallableType,
    FunctionLike,
    Instance,
    NoneType,
    ProperType,
    TupleType,
    Type,
    TypeAliasType,
    TypedDictType,
    TypeOfAny,
    TypeStrVisitor,
    TypeTranslator,
    TypeVarType,
    UninhabitedType,
    UnionType,
    get_proper_type,
)
from mypy.types_utils import is_optional, remove_optional
from mypy.util import split_target
class PyAnnotateSignature(TypedDict):
    """A function signature in the JSON shape consumed by pyannotate.

    Both fields hold already-formatted type strings.
    """

    # Formatted return type.
    return_type: str
    # Formatted argument types, in positional order.
    arg_types: list[str]
class Callsite(NamedTuple):
    """One recorded call of the target function, as captured by SuggestionPlugin.log."""

    # Path of the file containing the call.
    path: str
    # Line number of the call expression.
    line: int
    # Kinds of the actual arguments, grouped per formal parameter of the callee.
    arg_kinds: list[list[ArgKind]]
    # Names of the callee's formal parameters (None where a formal has no name).
    callee_arg_names: list[str | None]
    # Keyword names of the actual arguments, grouped like arg_kinds.
    arg_names: list[list[str | None]]
    # Types of the actual arguments, grouped like arg_kinds; get_args indexes
    # the outer list by formal position.
    arg_types: list[list[Type]]
  87. class SuggestionPlugin(Plugin):
  88. """Plugin that records all calls to a given target."""
  89. def __init__(self, target: str) -> None:
  90. if target.endswith((".__new__", ".__init__")):
  91. target = target.rsplit(".", 1)[0]
  92. self.target = target
  93. # List of call sites found by dmypy suggest:
  94. # (path, line, <arg kinds>, <arg names>, <arg types>)
  95. self.mystery_hits: list[Callsite] = []
  96. def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
  97. if fullname == self.target:
  98. return self.log
  99. else:
  100. return None
  101. def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | None:
  102. if fullname == self.target:
  103. return self.log
  104. else:
  105. return None
  106. def log(self, ctx: FunctionContext | MethodContext) -> Type:
  107. self.mystery_hits.append(
  108. Callsite(
  109. ctx.api.path,
  110. ctx.context.line,
  111. ctx.arg_kinds,
  112. ctx.callee_arg_names,
  113. ctx.arg_names,
  114. ctx.arg_types,
  115. )
  116. )
  117. return ctx.default_return_type
  118. # NOTE: We could make this a bunch faster by implementing a StatementVisitor that skips
  119. # traversing into expressions
  120. class ReturnFinder(TraverserVisitor):
  121. """Visitor for finding all types returned from a function."""
  122. def __init__(self, typemap: dict[Expression, Type]) -> None:
  123. self.typemap = typemap
  124. self.return_types: list[Type] = []
  125. def visit_return_stmt(self, o: ReturnStmt) -> None:
  126. if o.expr is not None and o.expr in self.typemap:
  127. self.return_types.append(self.typemap[o.expr])
  128. def visit_func_def(self, o: FuncDef) -> None:
  129. # Skip nested functions
  130. pass
  131. def get_return_types(typemap: dict[Expression, Type], func: FuncDef) -> list[Type]:
  132. """Find all the types returned by return statements in func."""
  133. finder = ReturnFinder(typemap)
  134. func.body.accept(finder)
  135. return finder.return_types
  136. class ArgUseFinder(TraverserVisitor):
  137. """Visitor for finding all the types of arguments that each arg is passed to.
  138. This is extremely simple minded but might be effective anyways.
  139. """
  140. def __init__(self, func: FuncDef, typemap: dict[Expression, Type]) -> None:
  141. self.typemap = typemap
  142. self.arg_types: dict[SymbolNode, list[Type]] = {arg.variable: [] for arg in func.arguments}
  143. def visit_call_expr(self, o: CallExpr) -> None:
  144. if not any(isinstance(e, RefExpr) and e.node in self.arg_types for e in o.args):
  145. return
  146. typ = get_proper_type(self.typemap.get(o.callee))
  147. if not isinstance(typ, CallableType):
  148. return
  149. formal_to_actual = map_actuals_to_formals(
  150. o.arg_kinds,
  151. o.arg_names,
  152. typ.arg_kinds,
  153. typ.arg_names,
  154. lambda n: AnyType(TypeOfAny.special_form),
  155. )
  156. for i, args in enumerate(formal_to_actual):
  157. for arg_idx in args:
  158. arg = o.args[arg_idx]
  159. if isinstance(arg, RefExpr) and arg.node in self.arg_types:
  160. self.arg_types[arg.node].append(typ.arg_types[i])
  161. def get_arg_uses(typemap: dict[Expression, Type], func: FuncDef) -> list[list[Type]]:
  162. """Find all the types of arguments that each arg is passed to.
  163. For example, given
  164. def foo(x: int) -> None: ...
  165. def bar(x: str) -> None: ...
  166. def test(x, y):
  167. foo(x)
  168. bar(y)
  169. this will return [[int], [str]].
  170. """
  171. finder = ArgUseFinder(func, typemap)
  172. func.body.accept(finder)
  173. return [finder.arg_types[arg.variable] for arg in func.arguments]
  174. class SuggestionFailure(Exception):
  175. pass
  176. def is_explicit_any(typ: AnyType) -> bool:
  177. # Originally I wanted to count as explicit anything derived from an explicit any, but that
  178. # seemed too strict in some testing.
  179. # return (typ.type_of_any == TypeOfAny.explicit
  180. # or (typ.source_any is not None and typ.source_any.type_of_any == TypeOfAny.explicit))
  181. # Important question: what should we do with source_any stuff? Does that count?
  182. # And actually should explicit anys count at all?? Maybe not!
  183. return typ.type_of_any == TypeOfAny.explicit
  184. def is_implicit_any(typ: Type) -> bool:
  185. typ = get_proper_type(typ)
  186. return isinstance(typ, AnyType) and not is_explicit_any(typ)
  187. class SuggestionEngine:
  188. """Engine for finding call sites and suggesting signatures."""
  189. def __init__(
  190. self,
  191. fgmanager: FineGrainedBuildManager,
  192. *,
  193. json: bool,
  194. no_errors: bool = False,
  195. no_any: bool = False,
  196. flex_any: float | None = None,
  197. use_fixme: str | None = None,
  198. max_guesses: int | None = None,
  199. ) -> None:
  200. self.fgmanager = fgmanager
  201. self.manager = fgmanager.manager
  202. self.plugin = self.manager.plugin
  203. self.graph = fgmanager.graph
  204. self.finder = SourceFinder(self.manager.fscache, self.manager.options)
  205. self.give_json = json
  206. self.no_errors = no_errors
  207. self.flex_any = flex_any
  208. if no_any:
  209. self.flex_any = 1.0
  210. self.max_guesses = max_guesses or 64
  211. self.use_fixme = use_fixme
  212. def suggest(self, function: str) -> str:
  213. """Suggest an inferred type for function."""
  214. mod, func_name, node = self.find_node(function)
  215. with self.restore_after(mod):
  216. with self.with_export_types():
  217. suggestion = self.get_suggestion(mod, node)
  218. if self.give_json:
  219. return self.json_suggestion(mod, func_name, node, suggestion)
  220. else:
  221. return self.format_signature(suggestion)
  222. def suggest_callsites(self, function: str) -> str:
  223. """Find a list of call sites of function."""
  224. mod, _, node = self.find_node(function)
  225. with self.restore_after(mod):
  226. callsites, _ = self.get_callsites(node)
  227. return "\n".join(
  228. dedup(
  229. [
  230. f"{path}:{line}: {self.format_args(arg_kinds, arg_names, arg_types)}"
  231. for path, line, arg_kinds, _, arg_names, arg_types in callsites
  232. ]
  233. )
  234. )
  235. @contextmanager
  236. def restore_after(self, module: str) -> Iterator[None]:
  237. """Context manager that reloads a module after executing the body.
  238. This should undo any damage done to the module state while mucking around.
  239. """
  240. try:
  241. yield
  242. finally:
  243. self.reload(self.graph[module])
  244. @contextmanager
  245. def with_export_types(self) -> Iterator[None]:
  246. """Context manager that enables the export_types flag in the body.
  247. This causes type information to be exported into the manager's all_types variable.
  248. """
  249. old = self.manager.options.export_types
  250. self.manager.options.export_types = True
  251. try:
  252. yield
  253. finally:
  254. self.manager.options.export_types = old
  255. def get_trivial_type(self, fdef: FuncDef) -> CallableType:
  256. """Generate a trivial callable type from a func def, with all Anys"""
  257. # The Anys are marked as being from the suggestion engine
  258. # since they need some special treatment (specifically,
  259. # constraint generation ignores them.)
  260. return CallableType(
  261. [AnyType(TypeOfAny.suggestion_engine) for _ in fdef.arg_kinds],
  262. fdef.arg_kinds,
  263. fdef.arg_names,
  264. AnyType(TypeOfAny.suggestion_engine),
  265. self.named_type("builtins.function"),
  266. )
  267. def get_starting_type(self, fdef: FuncDef) -> CallableType:
  268. if isinstance(fdef.type, CallableType):
  269. return make_suggestion_anys(fdef.type)
  270. else:
  271. return self.get_trivial_type(fdef)
  272. def get_args(
  273. self,
  274. is_method: bool,
  275. base: CallableType,
  276. defaults: list[Type | None],
  277. callsites: list[Callsite],
  278. uses: list[list[Type]],
  279. ) -> list[list[Type]]:
  280. """Produce a list of type suggestions for each argument type."""
  281. types: list[list[Type]] = []
  282. for i in range(len(base.arg_kinds)):
  283. # Make self args Any but this will get overridden somewhere in the checker
  284. if i == 0 and is_method:
  285. types.append([AnyType(TypeOfAny.suggestion_engine)])
  286. continue
  287. all_arg_types = []
  288. for call in callsites:
  289. for typ in call.arg_types[i - is_method]:
  290. # Collect all the types except for implicit anys
  291. if not is_implicit_any(typ):
  292. all_arg_types.append(typ)
  293. all_use_types = []
  294. for typ in uses[i]:
  295. # Collect all the types except for implicit anys
  296. if not is_implicit_any(typ):
  297. all_use_types.append(typ)
  298. # Add in any default argument types
  299. default = defaults[i]
  300. if default:
  301. all_arg_types.append(default)
  302. if all_use_types:
  303. all_use_types.append(default)
  304. arg_types = []
  305. if all_arg_types and all(
  306. isinstance(get_proper_type(tp), NoneType) for tp in all_arg_types
  307. ):
  308. arg_types.append(
  309. UnionType.make_union([all_arg_types[0], AnyType(TypeOfAny.explicit)])
  310. )
  311. elif all_arg_types:
  312. arg_types.extend(generate_type_combinations(all_arg_types))
  313. else:
  314. arg_types.append(AnyType(TypeOfAny.explicit))
  315. if all_use_types:
  316. # This is a meet because the type needs to be compatible with all the uses
  317. arg_types.append(meet_type_list(all_use_types))
  318. types.append(arg_types)
  319. return types
  320. def get_default_arg_types(self, fdef: FuncDef) -> list[Type | None]:
  321. return [
  322. self.manager.all_types[arg.initializer] if arg.initializer else None
  323. for arg in fdef.arguments
  324. ]
  325. def get_guesses(
  326. self,
  327. is_method: bool,
  328. base: CallableType,
  329. defaults: list[Type | None],
  330. callsites: list[Callsite],
  331. uses: list[list[Type]],
  332. ) -> list[CallableType]:
  333. """Compute a list of guesses for a function's type.
  334. This focuses just on the argument types, and doesn't change the provided return type.
  335. """
  336. options = self.get_args(is_method, base, defaults, callsites, uses)
  337. # Take the first `max_guesses` guesses.
  338. product = itertools.islice(itertools.product(*options), 0, self.max_guesses)
  339. return [refine_callable(base, base.copy_modified(arg_types=list(x))) for x in product]
  340. def get_callsites(self, func: FuncDef) -> tuple[list[Callsite], list[str]]:
  341. """Find all call sites of a function."""
  342. new_type = self.get_starting_type(func)
  343. collector_plugin = SuggestionPlugin(func.fullname)
  344. self.plugin._plugins.insert(0, collector_plugin)
  345. try:
  346. errors = self.try_type(func, new_type)
  347. finally:
  348. self.plugin._plugins.pop(0)
  349. return collector_plugin.mystery_hits, errors
  350. def filter_options(
  351. self, guesses: list[CallableType], is_method: bool, ignore_return: bool
  352. ) -> list[CallableType]:
  353. """Apply any configured filters to the possible guesses.
  354. Currently the only option is filtering based on Any prevalance."""
  355. return [
  356. t
  357. for t in guesses
  358. if self.flex_any is None
  359. or any_score_callable(t, is_method, ignore_return) >= self.flex_any
  360. ]
  361. def find_best(self, func: FuncDef, guesses: list[CallableType]) -> tuple[CallableType, int]:
  362. """From a list of possible function types, find the best one.
  363. For best, we want the fewest errors, then the best "score" from score_callable.
  364. """
  365. if not guesses:
  366. raise SuggestionFailure("No guesses that match criteria!")
  367. errors = {guess: self.try_type(func, guess) for guess in guesses}
  368. best = min(guesses, key=lambda s: (count_errors(errors[s]), self.score_callable(s)))
  369. return best, count_errors(errors[best])
  370. def get_guesses_from_parent(self, node: FuncDef) -> list[CallableType]:
  371. """Try to get a guess of a method type from a parent class."""
  372. if not node.info:
  373. return []
  374. for parent in node.info.mro[1:]:
  375. pnode = parent.names.get(node.name)
  376. if pnode and isinstance(pnode.node, (FuncDef, Decorator)):
  377. typ = get_proper_type(pnode.node.type)
  378. # FIXME: Doesn't work right with generic tyeps
  379. if isinstance(typ, CallableType) and len(typ.arg_types) == len(node.arguments):
  380. # Return the first thing we find, since it probably doesn't make sense
  381. # to grab things further up in the chain if an earlier parent has it.
  382. return [typ]
  383. return []
  384. def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature:
  385. """Compute a suggestion for a function.
  386. Return the type and whether the first argument should be ignored.
  387. """
  388. graph = self.graph
  389. callsites, orig_errors = self.get_callsites(node)
  390. uses = get_arg_uses(self.manager.all_types, node)
  391. if self.no_errors and orig_errors:
  392. raise SuggestionFailure("Function does not typecheck.")
  393. is_method = bool(node.info) and not node.is_static
  394. with state.strict_optional_set(graph[mod].options.strict_optional):
  395. guesses = self.get_guesses(
  396. is_method,
  397. self.get_starting_type(node),
  398. self.get_default_arg_types(node),
  399. callsites,
  400. uses,
  401. )
  402. guesses += self.get_guesses_from_parent(node)
  403. guesses = self.filter_options(guesses, is_method, ignore_return=True)
  404. best, _ = self.find_best(node, guesses)
  405. # Now try to find the return type!
  406. self.try_type(node, best)
  407. returns = get_return_types(self.manager.all_types, node)
  408. with state.strict_optional_set(graph[mod].options.strict_optional):
  409. if returns:
  410. ret_types = generate_type_combinations(returns)
  411. else:
  412. ret_types = [NoneType()]
  413. guesses = [best.copy_modified(ret_type=refine_type(best.ret_type, t)) for t in ret_types]
  414. guesses = self.filter_options(guesses, is_method, ignore_return=False)
  415. best, errors = self.find_best(node, guesses)
  416. if self.no_errors and errors:
  417. raise SuggestionFailure("No annotation without errors")
  418. return self.pyannotate_signature(mod, is_method, best)
  419. def format_args(
  420. self,
  421. arg_kinds: list[list[ArgKind]],
  422. arg_names: list[list[str | None]],
  423. arg_types: list[list[Type]],
  424. ) -> str:
  425. args: list[str] = []
  426. for i in range(len(arg_types)):
  427. for kind, name, typ in zip(arg_kinds[i], arg_names[i], arg_types[i]):
  428. arg = self.format_type(None, typ)
  429. if kind == ARG_STAR:
  430. arg = "*" + arg
  431. elif kind == ARG_STAR2:
  432. arg = "**" + arg
  433. elif kind.is_named():
  434. if name:
  435. arg = f"{name}={arg}"
  436. args.append(arg)
  437. return f"({', '.join(args)})"
  438. def find_node(self, key: str) -> tuple[str, str, FuncDef]:
  439. """From a target name, return module/target names and the func def.
  440. The 'key' argument can be in one of two formats:
  441. * As the function full name, e.g., package.module.Cls.method
  442. * As the function location as file and line separated by column,
  443. e.g., path/to/file.py:42
  444. """
  445. # TODO: Also return OverloadedFuncDef -- currently these are ignored.
  446. node: SymbolNode | None = None
  447. if ":" in key:
  448. if key.count(":") > 1:
  449. raise SuggestionFailure(
  450. "Malformed location for function: {}. Must be either"
  451. " package.module.Class.method or path/to/file.py:line".format(key)
  452. )
  453. file, line = key.split(":")
  454. if not line.isdigit():
  455. raise SuggestionFailure(f"Line number must be a number. Got {line}")
  456. line_number = int(line)
  457. modname, node = self.find_node_by_file_and_line(file, line_number)
  458. tail = node.fullname[len(modname) + 1 :] # add one to account for '.'
  459. else:
  460. target = split_target(self.fgmanager.graph, key)
  461. if not target:
  462. raise SuggestionFailure(f"Cannot find module for {key}")
  463. modname, tail = target
  464. node = self.find_node_by_module_and_name(modname, tail)
  465. if isinstance(node, Decorator):
  466. node = self.extract_from_decorator(node)
  467. if not node:
  468. raise SuggestionFailure(f"Object {key} is a decorator we can't handle")
  469. if not isinstance(node, FuncDef):
  470. raise SuggestionFailure(f"Object {key} is not a function")
  471. return modname, tail, node
  472. def find_node_by_module_and_name(self, modname: str, tail: str) -> SymbolNode | None:
  473. """Find symbol node by module id and qualified name.
  474. Raise SuggestionFailure if can't find one.
  475. """
  476. tree = self.ensure_loaded(self.fgmanager.graph[modname])
  477. # N.B. This is reimplemented from update's lookup_target
  478. # basically just to produce better error messages.
  479. names: SymbolTable = tree.names
  480. # Look through any classes
  481. components = tail.split(".")
  482. for i, component in enumerate(components[:-1]):
  483. if component not in names:
  484. raise SuggestionFailure(
  485. "Unknown class {}.{}".format(modname, ".".join(components[: i + 1]))
  486. )
  487. node: SymbolNode | None = names[component].node
  488. if not isinstance(node, TypeInfo):
  489. raise SuggestionFailure(
  490. "Object {}.{} is not a class".format(modname, ".".join(components[: i + 1]))
  491. )
  492. names = node.names
  493. # Look for the actual function/method
  494. funcname = components[-1]
  495. if funcname not in names:
  496. key = modname + "." + tail
  497. raise SuggestionFailure(
  498. "Unknown {} {}".format("method" if len(components) > 1 else "function", key)
  499. )
  500. return names[funcname].node
  501. def find_node_by_file_and_line(self, file: str, line: int) -> tuple[str, SymbolNode]:
  502. """Find symbol node by path to file and line number.
  503. Find the first function declared *before or on* the line number.
  504. Return module id and the node found. Raise SuggestionFailure if can't find one.
  505. """
  506. if not any(file.endswith(ext) for ext in PYTHON_EXTENSIONS):
  507. raise SuggestionFailure("Source file is not a Python file")
  508. try:
  509. modname, _ = self.finder.crawl_up(os.path.normpath(file))
  510. except InvalidSourceList as e:
  511. raise SuggestionFailure("Invalid source file name: " + file) from e
  512. if modname not in self.graph:
  513. raise SuggestionFailure("Unknown module: " + modname)
  514. # We must be sure about any edits in this file as this might affect the line numbers.
  515. tree = self.ensure_loaded(self.fgmanager.graph[modname], force=True)
  516. node: SymbolNode | None = None
  517. closest_line: int | None = None
  518. # TODO: Handle nested functions.
  519. for _, sym, _ in tree.local_definitions():
  520. if isinstance(sym.node, (FuncDef, Decorator)):
  521. sym_line = sym.node.line
  522. # TODO: add support for OverloadedFuncDef.
  523. else:
  524. continue
  525. # We want the closest function above the specified line
  526. if sym_line <= line and (closest_line is None or sym_line > closest_line):
  527. closest_line = sym_line
  528. node = sym.node
  529. if not node:
  530. raise SuggestionFailure(f"Cannot find a function at line {line}")
  531. return modname, node
  532. def extract_from_decorator(self, node: Decorator) -> FuncDef | None:
  533. for dec in node.decorators:
  534. typ = None
  535. if isinstance(dec, RefExpr) and isinstance(dec.node, FuncDef):
  536. typ = dec.node.type
  537. elif (
  538. isinstance(dec, CallExpr)
  539. and isinstance(dec.callee, RefExpr)
  540. and isinstance(dec.callee.node, FuncDef)
  541. and isinstance(dec.callee.node.type, CallableType)
  542. ):
  543. typ = get_proper_type(dec.callee.node.type.ret_type)
  544. if not isinstance(typ, FunctionLike):
  545. return None
  546. for ct in typ.items:
  547. if not (
  548. len(ct.arg_types) == 1
  549. and isinstance(ct.arg_types[0], TypeVarType)
  550. and ct.arg_types[0] == ct.ret_type
  551. ):
  552. return None
  553. return node.func
  554. def try_type(self, func: FuncDef, typ: ProperType) -> list[str]:
  555. """Recheck a function while assuming it has type typ.
  556. Return all error messages.
  557. """
  558. old = func.unanalyzed_type
  559. # During reprocessing, unanalyzed_type gets copied to type (by aststrip).
  560. # We set type to None to ensure that the type always changes during
  561. # reprocessing.
  562. func.type = None
  563. func.unanalyzed_type = typ
  564. try:
  565. res = self.fgmanager.trigger(func.fullname)
  566. # if res:
  567. # print('===', typ)
  568. # print('\n'.join(res))
  569. return res
  570. finally:
  571. func.unanalyzed_type = old
  572. def reload(self, state: State) -> list[str]:
  573. """Recheck the module given by state."""
  574. assert state.path is not None
  575. self.fgmanager.flush_cache()
  576. return self.fgmanager.update([(state.id, state.path)], [])
  577. def ensure_loaded(self, state: State, force: bool = False) -> MypyFile:
  578. """Make sure that the module represented by state is fully loaded."""
  579. if not state.tree or state.tree.is_cache_skeleton or force:
  580. self.reload(state)
  581. assert state.tree is not None
  582. return state.tree
  583. def named_type(self, s: str) -> Instance:
  584. return self.manager.semantic_analyzer.named_type(s)
  585. def json_suggestion(
  586. self, mod: str, func_name: str, node: FuncDef, suggestion: PyAnnotateSignature
  587. ) -> str:
  588. """Produce a json blob for a suggestion suitable for application by pyannotate."""
  589. # pyannotate irritatingly drops class names for class and static methods
  590. if node.is_class or node.is_static:
  591. func_name = func_name.split(".", 1)[-1]
  592. # pyannotate works with either paths relative to where the
  593. # module is rooted or with absolute paths. We produce absolute
  594. # paths because it is simpler.
  595. path = os.path.abspath(self.graph[mod].xpath)
  596. obj = {
  597. "signature": suggestion,
  598. "line": node.line,
  599. "path": path,
  600. "func_name": func_name,
  601. "samples": 0,
  602. }
  603. return json.dumps([obj], sort_keys=True)
  604. def pyannotate_signature(
  605. self, cur_module: str | None, is_method: bool, typ: CallableType
  606. ) -> PyAnnotateSignature:
  607. """Format a callable type as a pyannotate dict"""
  608. start = int(is_method)
  609. return {
  610. "arg_types": [self.format_type(cur_module, t) for t in typ.arg_types[start:]],
  611. "return_type": self.format_type(cur_module, typ.ret_type),
  612. }
  613. def format_signature(self, sig: PyAnnotateSignature) -> str:
  614. """Format a callable type in a way suitable as an annotation... kind of"""
  615. return f"({', '.join(sig['arg_types'])}) -> {sig['return_type']}"
  616. def format_type(self, cur_module: str | None, typ: Type) -> str:
  617. if self.use_fixme and isinstance(get_proper_type(typ), AnyType):
  618. return self.use_fixme
  619. return typ.accept(TypeFormatter(cur_module, self.graph, self.manager.options))
  620. def score_type(self, t: Type, arg_pos: bool) -> int:
  621. """Generate a score for a type that we use to pick which type to use.
  622. Lower is better, prefer non-union/non-any types. Don't penalize optionals.
  623. """
  624. t = get_proper_type(t)
  625. if isinstance(t, AnyType):
  626. return 20
  627. if arg_pos and isinstance(t, NoneType):
  628. return 20
  629. if isinstance(t, UnionType):
  630. if any(isinstance(get_proper_type(x), AnyType) for x in t.items):
  631. return 20
  632. if any(has_any_type(x) for x in t.items):
  633. return 15
  634. if not is_optional(t):
  635. return 10
  636. if isinstance(t, CallableType) and (has_any_type(t) or is_tricky_callable(t)):
  637. return 10
  638. return 0
  639. def score_callable(self, t: CallableType) -> int:
  640. return sum(self.score_type(x, arg_pos=True) for x in t.arg_types) + self.score_type(
  641. t.ret_type, arg_pos=False
  642. )
  643. def any_score_type(ut: Type, arg_pos: bool) -> float:
  644. """Generate a very made up number representing the Anyness of a type.
  645. Higher is better, 1.0 is max
  646. """
  647. t = get_proper_type(ut)
  648. if isinstance(t, AnyType) and t.type_of_any != TypeOfAny.suggestion_engine:
  649. return 0
  650. if isinstance(t, NoneType) and arg_pos:
  651. return 0.5
  652. if isinstance(t, UnionType):
  653. if any(isinstance(get_proper_type(x), AnyType) for x in t.items):
  654. return 0.5
  655. if any(has_any_type(x) for x in t.items):
  656. return 0.25
  657. if isinstance(t, CallableType) and is_tricky_callable(t):
  658. return 0.5
  659. if has_any_type(t):
  660. return 0.5
  661. return 1.0
  662. def any_score_callable(t: CallableType, is_method: bool, ignore_return: bool) -> float:
  663. # Ignore the first argument of methods
  664. scores = [any_score_type(x, arg_pos=True) for x in t.arg_types[int(is_method) :]]
  665. # Return type counts twice (since it spreads type information), unless it is
  666. # None in which case it does not count at all. (Though it *does* still count
  667. # if there are no arguments.)
  668. if not isinstance(get_proper_type(t.ret_type), NoneType) or not scores:
  669. ret = 1.0 if ignore_return else any_score_type(t.ret_type, arg_pos=False)
  670. scores += [ret, ret]
  671. return sum(scores) / len(scores)
  672. def is_tricky_callable(t: CallableType) -> bool:
  673. """Is t a callable that we need to put a ... in for syntax reasons?"""
  674. return t.is_ellipsis_args or any(k.is_star() or k.is_named() for k in t.arg_kinds)
  675. class TypeFormatter(TypeStrVisitor):
  676. """Visitor used to format types"""
  677. # TODO: Probably a lot
    def __init__(self, module: str | None, graph: Graph, options: Options) -> None:
        """Create a formatter.

        Args:
            module: module the annotation is for; visit_instance rewrites
                references to names present in that module's symbol table to
                point at it. None disables the rewrite.
            graph: build graph used to split names into module and object.
            options: passed through to TypeStrVisitor.
        """
        super().__init__(options=options)
        self.module = module
        self.graph = graph
  682. def visit_any(self, t: AnyType) -> str:
  683. if t.missing_import_name:
  684. return t.missing_import_name
  685. else:
  686. return "Any"
  687. def visit_instance(self, t: Instance) -> str:
  688. s = t.type.fullname or t.type.name or None
  689. if s is None:
  690. return "<???>"
  691. if s in reverse_builtin_aliases:
  692. s = reverse_builtin_aliases[s]
  693. mod_obj = split_target(self.graph, s)
  694. assert mod_obj
  695. mod, obj = mod_obj
  696. # If a class is imported into the current module, rewrite the reference
  697. # to point to the current module. This helps the annotation tool avoid
  698. # inserting redundant imports when a type has been reexported.
  699. if self.module:
  700. parts = obj.split(".") # need to split the object part if it is a nested class
  701. tree = self.graph[self.module].tree
  702. if tree and parts[0] in tree.names:
  703. mod = self.module
  704. if (mod, obj) == ("builtins", "tuple"):
  705. mod, obj = "typing", "Tuple[" + t.args[0].accept(self) + ", ...]"
  706. elif t.args:
  707. obj += f"[{self.list_str(t.args)}]"
  708. if mod_obj == ("builtins", "unicode"):
  709. return "Text"
  710. elif mod == "builtins":
  711. return obj
  712. else:
  713. delim = "." if "." not in obj else ":"
  714. return mod + delim + obj
  715. def visit_tuple_type(self, t: TupleType) -> str:
  716. if t.partial_fallback and t.partial_fallback.type:
  717. fallback_name = t.partial_fallback.type.fullname
  718. if fallback_name != "builtins.tuple":
  719. return t.partial_fallback.accept(self)
  720. s = self.list_str(t.items)
  721. return f"Tuple[{s}]"
  722. def visit_uninhabited_type(self, t: UninhabitedType) -> str:
  723. return "Any"
  724. def visit_typeddict_type(self, t: TypedDictType) -> str:
  725. return t.fallback.accept(self)
  726. def visit_union_type(self, t: UnionType) -> str:
  727. if len(t.items) == 2 and is_optional(t):
  728. return f"Optional[{remove_optional(t).accept(self)}]"
  729. else:
  730. return super().visit_union_type(t)
  731. def visit_callable_type(self, t: CallableType) -> str:
  732. # TODO: use extended callables?
  733. if is_tricky_callable(t):
  734. arg_str = "..."
  735. else:
  736. # Note: for default arguments, we just assume that they
  737. # are required. This isn't right, but neither is the
  738. # other thing, and I suspect this will produce more better
  739. # results than falling back to `...`
  740. args = [typ.accept(self) for typ in t.arg_types]
  741. arg_str = f"[{', '.join(args)}]"
  742. return f"Callable[{arg_str}, {t.ret_type.accept(self)}]"
  743. TType = TypeVar("TType", bound=Type)
  744. def make_suggestion_anys(t: TType) -> TType:
  745. """Make all anys in the type as coming from the suggestion engine.
  746. This keeps those Anys from influencing constraint generation,
  747. which allows us to do better when refining types.
  748. """
  749. return cast(TType, t.accept(MakeSuggestionAny()))
  750. class MakeSuggestionAny(TypeTranslator):
  751. def visit_any(self, t: AnyType) -> Type:
  752. if not t.missing_import_name:
  753. return t.copy_modified(type_of_any=TypeOfAny.suggestion_engine)
  754. else:
  755. return t
  756. def visit_type_alias_type(self, t: TypeAliasType) -> Type:
  757. return t.copy_modified(args=[a.accept(self) for a in t.args])
  758. def generate_type_combinations(types: list[Type]) -> list[Type]:
  759. """Generate possible combinations of a list of types.
  760. mypy essentially supports two different ways to do this: joining the types
  761. and unioning the types. We try both.
  762. """
  763. joined_type = join_type_list(types)
  764. union_type = make_simplified_union(types)
  765. if joined_type == union_type:
  766. return [joined_type]
  767. else:
  768. return [joined_type, union_type]
  769. def count_errors(msgs: list[str]) -> int:
  770. return len([x for x in msgs if " error: " in x])
  771. def refine_type(ti: Type, si: Type) -> Type:
  772. """Refine `ti` by replacing Anys in it with information taken from `si`
  773. This basically works by, when the types have the same structure,
  774. traversing both of them in parallel and replacing Any on the left
  775. with whatever the type on the right is. If the types don't have the
  776. same structure (or aren't supported), the left type is chosen.
  777. For example:
  778. refine(Any, T) = T, for all T
  779. refine(float, int) = float
  780. refine(List[Any], List[int]) = List[int]
  781. refine(Dict[int, Any], Dict[Any, int]) = Dict[int, int]
  782. refine(Tuple[int, Any], Tuple[Any, int]) = Tuple[int, int]
  783. refine(Callable[[Any], Any], Callable[[int], int]) = Callable[[int], int]
  784. refine(Callable[..., int], Callable[[int, float], Any]) = Callable[[int, float], int]
  785. refine(Optional[Any], int) = Optional[int]
  786. refine(Optional[Any], Optional[int]) = Optional[int]
  787. refine(Optional[Any], Union[int, str]) = Optional[Union[int, str]]
  788. refine(Optional[List[Any]], List[int]) = List[int]
  789. """
  790. t = get_proper_type(ti)
  791. s = get_proper_type(si)
  792. if isinstance(t, AnyType):
  793. # If s is also an Any, we return if it is a missing_import Any
  794. return t if isinstance(s, AnyType) and t.missing_import_name else s
  795. if isinstance(t, Instance) and isinstance(s, Instance) and t.type == s.type:
  796. return t.copy_modified(args=[refine_type(ta, sa) for ta, sa in zip(t.args, s.args)])
  797. if (
  798. isinstance(t, TupleType)
  799. and isinstance(s, TupleType)
  800. and t.partial_fallback == s.partial_fallback
  801. and len(t.items) == len(s.items)
  802. ):
  803. return t.copy_modified(items=[refine_type(ta, sa) for ta, sa in zip(t.items, s.items)])
  804. if isinstance(t, CallableType) and isinstance(s, CallableType):
  805. return refine_callable(t, s)
  806. if isinstance(t, UnionType):
  807. return refine_union(t, s)
  808. # TODO: Refining of builtins.tuple, Type?
  809. return t
  810. def refine_union(t: UnionType, s: ProperType) -> Type:
  811. """Refine a union type based on another type.
  812. This is done by refining every component of the union against the
  813. right hand side type (or every component of its union if it is
  814. one). If an element of the union is successfully refined, we drop it
  815. from the union in favor of the refined versions.
  816. """
  817. # Don't try to do any union refining if the types are already the
  818. # same. This prevents things like refining Optional[Any] against
  819. # itself and producing None.
  820. if t == s:
  821. return t
  822. rhs_items = s.items if isinstance(s, UnionType) else [s]
  823. new_items = []
  824. for lhs in t.items:
  825. refined = False
  826. for rhs in rhs_items:
  827. new = refine_type(lhs, rhs)
  828. if new != lhs:
  829. new_items.append(new)
  830. refined = True
  831. if not refined:
  832. new_items.append(lhs)
  833. # Turn strict optional on when simplifying the union since we
  834. # don't want to drop Nones.
  835. with state.strict_optional_set(True):
  836. return make_simplified_union(new_items)
  837. def refine_callable(t: CallableType, s: CallableType) -> CallableType:
  838. """Refine a callable based on another.
  839. See comments for refine_type.
  840. """
  841. if t.fallback != s.fallback:
  842. return t
  843. if t.is_ellipsis_args and not is_tricky_callable(s):
  844. return s.copy_modified(ret_type=refine_type(t.ret_type, s.ret_type))
  845. if is_tricky_callable(t) or t.arg_kinds != s.arg_kinds:
  846. return t
  847. return t.copy_modified(
  848. arg_types=[refine_type(ta, sa) for ta, sa in zip(t.arg_types, s.arg_types)],
  849. ret_type=refine_type(t.ret_type, s.ret_type),
  850. )
  851. T = TypeVar("T")
  852. def dedup(old: list[T]) -> list[T]:
  853. new: list[T] = []
  854. for x in old:
  855. if x not in new:
  856. new.append(x)
  857. return new