fields.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. from astroid import MANAGER, AstroidImportError, inference_tip, nodes
  2. from astroid.nodes import scoped_nodes
  3. from pylint_django import utils
  # Field class names that apply_type_shim shims with the builtin ``str``.
  # (Each name listed once; membership is all that matters.)
  _STR_FIELDS = (
      "CharField",
      "SlugField",
      "URLField",
      "TextField",
      "EmailField",
      "CommaSeparatedIntegerField",
      "FilePathField",
      "GenericIPAddressField",
      "IPAddressField",
      "RegexField",
  )
  # Field class names that apply_type_shim shims with the builtin ``int``.
  _INT_FIELDS = (
      "IntegerField",
      "SmallIntegerField",
      "BigIntegerField",
      "PositiveIntegerField",
      "PositiveSmallIntegerField",
  )
  # Field class names that apply_type_shim shims with the builtin ``bool``.
  _BOOL_FIELDS = ("BooleanField", "NullBooleanField")
  # Range field class names that apply_type_shim shims with
  # ``psycopg2._range.Range``.
  _RANGE_FIELDS = (
      "RangeField",
      "IntegerRangeField",
      "BigIntegerRangeField",
      "FloatRangeField",
      "DateTimeRangeField",
      "DateRangeField",
  )
  def is_model_field(cls):
      """Return True if ``cls`` looks like a Django model field class.

      The check is purely name-based: the class's qualified name must start
      with ``django.db.models.fields`` or ``django.contrib.postgres.fields``.

      Args:
          cls: a class node exposing ``qname()``.

      Returns:
          bool
      """
      # str.startswith accepts a tuple of prefixes, so qname() runs only once.
      return cls.qname().startswith(
          ("django.db.models.fields", "django.contrib.postgres.fields")
      )
  def is_form_field(cls):
      """Tell whether ``cls`` has a qualified name under ``django.forms.fields``.

      Args:
          cls: a class node exposing ``qname()``.

      Returns:
          True when the qualified name starts with ``django.forms.fields``.
      """
      qualified_name = cls.qname()
      return qualified_name.startswith("django.forms.fields")
  def is_model_or_form_field(cls):
      """Return True if ``cls`` passes ``is_model_field`` or ``is_form_field``.

      ``is_form_field`` is only consulted when ``is_model_field`` is false.
      """
      if is_model_field(cls):
          return True
      return is_form_field(cls)
  def apply_type_shim(cls, _context=None):  # noqa
      """Inference tip: infer a Django field class as itself plus its value type.

      For a recognised field class name (e.g. ``CharField`` -> ``str``,
      ``DateField`` -> ``datetime.date``), inference of ``cls`` yields ``cls``
      followed by the node(s) defining the corresponding value type.

      Args:
          cls: the field ``ClassDef`` being inferred.
          _context: inference context, forwarded when resolving base nodes.

      Returns:
          An iterator over ``cls`` followed by the value-type nodes found. It is
          just ``cls`` when the field name is not recognised, or when astroid
          cannot load the module that defines the value type.
      """
      try:
          lookup_result = _lookup_value_type(cls.name)
      except AstroidImportError:
          # e.g. range fields when psycopg2 is not installed: degrade to "no
          # shim" instead of letting the import error escape the inference tip.
          return iter([cls])
      if lookup_result is None:
          return iter([cls])

      # The lookup() result is a pair; its second item holds the nodes bound
      # to the looked-up name.
      assigned = lookup_result[1]

      # XXX: for some reason, with python3, using these nodes directly triggers
      # a check in the StdlibChecker for deprecated methods; one of them can be
      # an ImportFrom, which has no qname() method, causing the checker to die.
      # _valid_base_node converts or drops such nodes.
      if utils.PY3:
          base_nodes = [_valid_base_node(node, _context) for node in assigned]
          base_nodes = [node for node in base_nodes if node]
      else:
          base_nodes = list(assigned)
      return iter([cls] + base_nodes)


  def _lookup_value_type(field_name):
      """Return the astroid ``lookup()`` result for the value type of a field.

      Returns None when ``field_name`` is not a field this shim knows about.
      May raise ``AstroidImportError`` if the module defining the value type
      cannot be loaded by astroid.
      """
      if field_name in _STR_FIELDS:
          return scoped_nodes.builtin_lookup("str")
      if field_name in _INT_FIELDS:
          return scoped_nodes.builtin_lookup("int")
      if field_name in _BOOL_FIELDS:
          return scoped_nodes.builtin_lookup("bool")
      if field_name == "FloatField":
          return scoped_nodes.builtin_lookup("float")
      if field_name == "DecimalField":
          # Prefer the C implementation; fall back to the pure-Python one when
          # astroid cannot load ``_decimal``.
          try:
              return MANAGER.ast_from_module_name("_decimal").lookup("Decimal")
          except AstroidImportError:
              return MANAGER.ast_from_module_name("_pydecimal").lookup("Decimal")
      if field_name in ("SplitDateTimeField", "DateTimeField"):
          return MANAGER.ast_from_module_name("datetime").lookup("datetime")
      if field_name == "TimeField":
          return MANAGER.ast_from_module_name("datetime").lookup("time")
      if field_name == "DateField":
          return MANAGER.ast_from_module_name("datetime").lookup("date")
      if field_name == "DurationField":
          return MANAGER.ast_from_module_name("datetime").lookup("timedelta")
      if field_name == "UUIDField":
          return MANAGER.ast_from_module_name("uuid").lookup("UUID")
      if field_name == "ManyToManyField":
          return MANAGER.ast_from_module_name("django.db.models.query").lookup("QuerySet")
      if field_name in ("ImageField", "FileField"):
          return MANAGER.ast_from_module_name("django.core.files.base").lookup("File")
      if field_name == "ArrayField":
          return scoped_nodes.builtin_lookup("list")
      if field_name in ("HStoreField", "JSONField"):
          return scoped_nodes.builtin_lookup("dict")
      if field_name in _RANGE_FIELDS:
          return MANAGER.ast_from_module_name("psycopg2._range").lookup("Range")
      return None
  def _valid_base_node(node, context):
      """Attempts to convert `node` to a valid base node, returns None if it cannot.

      - ``AssignAttr``: the assigned value is inferred; the result is returned
        only if it is a ``ClassDef``.
      - ``ImportFrom``: dropped (None).
      - any other node is returned unchanged.

      Args:
          node: a node bound to the looked-up name.
          context: inference context passed to ``infer()``.
      """
      from astroid.exceptions import InferenceError  # local: keeps this edit self-contained

      if isinstance(node, nodes.AssignAttr):
          # The parent is not guaranteed to expose the assigned expression as
          # ``.value`` (presumably e.g. tuple-unpacking targets; confirm), and
          # it may be None, so treat that as "cannot convert".
          value = getattr(node.parent, "value", None)
          if value is None:
              return None
          try:
              inferred = next(value.infer(context), None)
          except InferenceError:
              return None
          # Check the inferred node, not ``node`` (always an AssignAttr here).
          if isinstance(inferred, nodes.ClassDef):
              return inferred
          return None
      if isinstance(node, nodes.ImportFrom):
          return None
      return node
  def add_transforms(manager):
      """Hook the field type shim into ``manager``.

      Registers ``apply_type_shim`` (wrapped with ``inference_tip``) as a
      transform on ``ClassDef`` nodes accepted by ``is_model_or_form_field``.
      """
      type_shim_tip = inference_tip(apply_type_shim)
      manager.register_transform(
          nodes.ClassDef,
          type_shim_tip,
          is_model_or_form_field,
      )