indexes.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. from django.db.backends.utils import names_digest, split_identifier
  2. from django.db.models.expressions import Col, ExpressionList, F, Func, OrderBy
  3. from django.db.models.functions import Collate
  4. from django.db.models.query_utils import Q
  5. from django.db.models.sql import Query
  6. from django.utils.functional import partition
  7. __all__ = ['Index']
  8. class Index:
  9. suffix = 'idx'
  10. # The max length of the name of the index (restricted to 30 for
  11. # cross-database compatibility with Oracle)
  12. max_name_length = 30
  13. def __init__(
  14. self,
  15. *expressions,
  16. fields=(),
  17. name=None,
  18. db_tablespace=None,
  19. opclasses=(),
  20. condition=None,
  21. include=None,
  22. ):
  23. if opclasses and not name:
  24. raise ValueError('An index must be named to use opclasses.')
  25. if not isinstance(condition, (type(None), Q)):
  26. raise ValueError('Index.condition must be a Q instance.')
  27. if condition and not name:
  28. raise ValueError('An index must be named to use condition.')
  29. if not isinstance(fields, (list, tuple)):
  30. raise ValueError('Index.fields must be a list or tuple.')
  31. if not isinstance(opclasses, (list, tuple)):
  32. raise ValueError('Index.opclasses must be a list or tuple.')
  33. if not expressions and not fields:
  34. raise ValueError(
  35. 'At least one field or expression is required to define an '
  36. 'index.'
  37. )
  38. if expressions and fields:
  39. raise ValueError(
  40. 'Index.fields and expressions are mutually exclusive.',
  41. )
  42. if expressions and not name:
  43. raise ValueError('An index must be named to use expressions.')
  44. if expressions and opclasses:
  45. raise ValueError(
  46. 'Index.opclasses cannot be used with expressions. Use '
  47. 'django.contrib.postgres.indexes.OpClass() instead.'
  48. )
  49. if opclasses and len(fields) != len(opclasses):
  50. raise ValueError('Index.fields and Index.opclasses must have the same number of elements.')
  51. if fields and not all(isinstance(field, str) for field in fields):
  52. raise ValueError('Index.fields must contain only strings with field names.')
  53. if include and not name:
  54. raise ValueError('A covering index must be named.')
  55. if not isinstance(include, (type(None), list, tuple)):
  56. raise ValueError('Index.include must be a list or tuple.')
  57. self.fields = list(fields)
  58. # A list of 2-tuple with the field name and ordering ('' or 'DESC').
  59. self.fields_orders = [
  60. (field_name[1:], 'DESC') if field_name.startswith('-') else (field_name, '')
  61. for field_name in self.fields
  62. ]
  63. self.name = name or ''
  64. self.db_tablespace = db_tablespace
  65. self.opclasses = opclasses
  66. self.condition = condition
  67. self.include = tuple(include) if include else ()
  68. self.expressions = tuple(
  69. F(expression) if isinstance(expression, str) else expression
  70. for expression in expressions
  71. )
  72. @property
  73. def contains_expressions(self):
  74. return bool(self.expressions)
  75. def _get_condition_sql(self, model, schema_editor):
  76. if self.condition is None:
  77. return None
  78. query = Query(model=model, alias_cols=False)
  79. where = query.build_where(self.condition)
  80. compiler = query.get_compiler(connection=schema_editor.connection)
  81. sql, params = where.as_sql(compiler, schema_editor.connection)
  82. return sql % tuple(schema_editor.quote_value(p) for p in params)
  83. def create_sql(self, model, schema_editor, using='', **kwargs):
  84. include = [model._meta.get_field(field_name).column for field_name in self.include]
  85. condition = self._get_condition_sql(model, schema_editor)
  86. if self.expressions:
  87. index_expressions = []
  88. for expression in self.expressions:
  89. index_expression = IndexExpression(expression)
  90. index_expression.set_wrapper_classes(schema_editor.connection)
  91. index_expressions.append(index_expression)
  92. expressions = ExpressionList(*index_expressions).resolve_expression(
  93. Query(model, alias_cols=False),
  94. )
  95. fields = None
  96. col_suffixes = None
  97. else:
  98. fields = [
  99. model._meta.get_field(field_name)
  100. for field_name, _ in self.fields_orders
  101. ]
  102. col_suffixes = [order[1] for order in self.fields_orders]
  103. expressions = None
  104. return schema_editor._create_index_sql(
  105. model, fields=fields, name=self.name, using=using,
  106. db_tablespace=self.db_tablespace, col_suffixes=col_suffixes,
  107. opclasses=self.opclasses, condition=condition, include=include,
  108. expressions=expressions, **kwargs,
  109. )
  110. def remove_sql(self, model, schema_editor, **kwargs):
  111. return schema_editor._delete_index_sql(model, self.name, **kwargs)
  112. def deconstruct(self):
  113. path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
  114. path = path.replace('django.db.models.indexes', 'django.db.models')
  115. kwargs = {'name': self.name}
  116. if self.fields:
  117. kwargs['fields'] = self.fields
  118. if self.db_tablespace is not None:
  119. kwargs['db_tablespace'] = self.db_tablespace
  120. if self.opclasses:
  121. kwargs['opclasses'] = self.opclasses
  122. if self.condition:
  123. kwargs['condition'] = self.condition
  124. if self.include:
  125. kwargs['include'] = self.include
  126. return (path, self.expressions, kwargs)
  127. def clone(self):
  128. """Create a copy of this Index."""
  129. _, args, kwargs = self.deconstruct()
  130. return self.__class__(*args, **kwargs)
  131. def set_name_with_model(self, model):
  132. """
  133. Generate a unique name for the index.
  134. The name is divided into 3 parts - table name (12 chars), field name
  135. (8 chars) and unique hash + suffix (10 chars). Each part is made to
  136. fit its size by truncating the excess length.
  137. """
  138. _, table_name = split_identifier(model._meta.db_table)
  139. column_names = [model._meta.get_field(field_name).column for field_name, order in self.fields_orders]
  140. column_names_with_order = [
  141. (('-%s' if order else '%s') % column_name)
  142. for column_name, (field_name, order) in zip(column_names, self.fields_orders)
  143. ]
  144. # The length of the parts of the name is based on the default max
  145. # length of 30 characters.
  146. hash_data = [table_name] + column_names_with_order + [self.suffix]
  147. self.name = '%s_%s_%s' % (
  148. table_name[:11],
  149. column_names[0][:7],
  150. '%s_%s' % (names_digest(*hash_data, length=6), self.suffix),
  151. )
  152. assert len(self.name) <= self.max_name_length, (
  153. 'Index too long for multiple database support. Is self.suffix '
  154. 'longer than 3 characters?'
  155. )
  156. if self.name[0] == '_' or self.name[0].isdigit():
  157. self.name = 'D%s' % self.name[1:]
  158. def __repr__(self):
  159. return '<%s:%s%s%s%s%s>' % (
  160. self.__class__.__name__,
  161. '' if not self.fields else " fields='%s'" % ', '.join(self.fields),
  162. '' if not self.expressions else " expressions='%s'" % ', '.join([
  163. str(expression) for expression in self.expressions
  164. ]),
  165. '' if self.condition is None else ' condition=%s' % self.condition,
  166. '' if not self.include else " include='%s'" % ', '.join(self.include),
  167. '' if not self.opclasses else " opclasses='%s'" % ', '.join(self.opclasses),
  168. )
  169. def __eq__(self, other):
  170. if self.__class__ == other.__class__:
  171. return self.deconstruct() == other.deconstruct()
  172. return NotImplemented
  173. class IndexExpression(Func):
  174. """Order and wrap expressions for CREATE INDEX statements."""
  175. template = '%(expressions)s'
  176. wrapper_classes = (OrderBy, Collate)
  177. def set_wrapper_classes(self, connection=None):
  178. # Some databases (e.g. MySQL) treats COLLATE as an indexed expression.
  179. if connection and connection.features.collate_as_index_expression:
  180. self.wrapper_classes = tuple([
  181. wrapper_cls
  182. for wrapper_cls in self.wrapper_classes
  183. if wrapper_cls is not Collate
  184. ])
  185. @classmethod
  186. def register_wrappers(cls, *wrapper_classes):
  187. cls.wrapper_classes = wrapper_classes
  188. def resolve_expression(
  189. self,
  190. query=None,
  191. allow_joins=True,
  192. reuse=None,
  193. summarize=False,
  194. for_save=False,
  195. ):
  196. expressions = list(self.flatten())
  197. # Split expressions and wrappers.
  198. index_expressions, wrappers = partition(
  199. lambda e: isinstance(e, self.wrapper_classes),
  200. expressions,
  201. )
  202. wrapper_types = [type(wrapper) for wrapper in wrappers]
  203. if len(wrapper_types) != len(set(wrapper_types)):
  204. raise ValueError(
  205. "Multiple references to %s can't be used in an indexed "
  206. "expression." % ', '.join([
  207. wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes
  208. ])
  209. )
  210. if expressions[1:len(wrappers) + 1] != wrappers:
  211. raise ValueError(
  212. '%s must be topmost expressions in an indexed expression.'
  213. % ', '.join([
  214. wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes
  215. ])
  216. )
  217. # Wrap expressions in parentheses if they are not column references.
  218. root_expression = index_expressions[1]
  219. resolve_root_expression = root_expression.resolve_expression(
  220. query,
  221. allow_joins,
  222. reuse,
  223. summarize,
  224. for_save,
  225. )
  226. if not isinstance(resolve_root_expression, Col):
  227. root_expression = Func(root_expression, template='(%(expressions)s)')
  228. if wrappers:
  229. # Order wrappers and set their expressions.
  230. wrappers = sorted(
  231. wrappers,
  232. key=lambda w: self.wrapper_classes.index(type(w)),
  233. )
  234. wrappers = [wrapper.copy() for wrapper in wrappers]
  235. for i, wrapper in enumerate(wrappers[:-1]):
  236. wrapper.set_source_expressions([wrappers[i + 1]])
  237. # Set the root expression on the deepest wrapper.
  238. wrappers[-1].set_source_expressions([root_expression])
  239. self.set_source_expressions([wrappers[0]])
  240. else:
  241. # Use the root expression, if there are no wrappers.
  242. self.set_source_expressions([root_expression])
  243. return super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
  244. def as_sqlite(self, compiler, connection, **extra_context):
  245. # Casting to numeric is unnecessary.
  246. return self.as_sql(compiler, connection, **extra_context)