operations.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. from django.contrib.postgres.signals import (
  2. get_citext_oids, get_hstore_oids, register_type_handlers,
  3. )
  4. from django.db import NotSupportedError, router
  5. from django.db.migrations import AddIndex, RemoveIndex
  6. from django.db.migrations.operations.base import Operation
  class CreateExtension(Operation):
      """
      Install a PostgreSQL extension with CREATE EXTENSION.

      Applying the operation creates the extension when it is not already
      present in ``pg_extension``; reversing it drops the extension when it is
      present. Both directions do nothing on non-PostgreSQL connections and
      when the database router does not allow migrating ``app_label`` on the
      connection's alias. The project state is not modified.
      """
      reversible = True

      def __init__(self, name):
          # Extension name; always passed through schema_editor.quote_name()
          # before being interpolated into SQL.
          self.name = name

      def state_forwards(self, app_label, state):
          # Nothing to record in the project state.
          pass

      def database_forwards(self, app_label, schema_editor, from_state, to_state):
          """Create the extension if it is missing, then refresh type handlers."""
          if (
              schema_editor.connection.vendor != 'postgresql' or
              not router.allow_migrate(schema_editor.connection.alias, app_label)
          ):
              return
          if not self.extension_exists(schema_editor, self.name):
              schema_editor.execute(
                  'CREATE EXTENSION IF NOT EXISTS %s' % schema_editor.quote_name(self.name)
              )
          # Clear cached, stale oids.
          get_hstore_oids.cache_clear()
          get_citext_oids.cache_clear()
          # Registering new type handlers cannot be done before the extension is
          # installed, otherwise a subsequent data migration would use the same
          # connection.
          register_type_handlers(schema_editor.connection)

      def database_backwards(self, app_label, schema_editor, from_state, to_state):
          """Drop the extension if it is installed."""
          # Same guard as database_forwards: on a non-PostgreSQL connection the
          # forwards step was skipped, so there is nothing to undo, and
          # extension_exists() queries the PostgreSQL catalog pg_extension.
          if (
              schema_editor.connection.vendor != 'postgresql' or
              not router.allow_migrate(schema_editor.connection.alias, app_label)
          ):
              return
          if self.extension_exists(schema_editor, self.name):
              schema_editor.execute(
                  'DROP EXTENSION IF EXISTS %s' % schema_editor.quote_name(self.name)
              )
          # Clear cached, stale oids.
          get_hstore_oids.cache_clear()
          get_citext_oids.cache_clear()

      def extension_exists(self, schema_editor, extension):
          """Return True if ``pg_extension`` has a row whose extname is ``extension``."""
          with schema_editor.connection.cursor() as cursor:
              # Parameterized query: the name is passed as a bound parameter.
              cursor.execute(
                  'SELECT 1 FROM pg_extension WHERE extname = %s',
                  [extension],
              )
              return bool(cursor.fetchone())

      def describe(self):
          """Return a human-readable description of the operation."""
          return "Creates extension %s" % self.name

      @property
      def migration_name_fragment(self):
          """Name fragment used when auto-naming a migration with this operation."""
          return 'create_extension_%s' % self.name
  class BloomExtension(CreateExtension):
      """CreateExtension preset for the ``bloom`` extension."""

      def __init__(self):
          super().__init__('bloom')
  class BtreeGinExtension(CreateExtension):
      """CreateExtension preset for the ``btree_gin`` extension."""

      def __init__(self):
          super().__init__('btree_gin')
  class BtreeGistExtension(CreateExtension):
      """CreateExtension preset for the ``btree_gist`` extension."""

      def __init__(self):
          super().__init__('btree_gist')
  class CITextExtension(CreateExtension):
      """CreateExtension preset for the ``citext`` extension."""

      def __init__(self):
          super().__init__('citext')
  class CryptoExtension(CreateExtension):
      """CreateExtension preset for the ``pgcrypto`` extension."""

      def __init__(self):
          super().__init__('pgcrypto')
  class HStoreExtension(CreateExtension):
      """CreateExtension preset for the ``hstore`` extension."""

      def __init__(self):
          super().__init__('hstore')
  class TrigramExtension(CreateExtension):
      """CreateExtension preset for the ``pg_trgm`` extension."""

      def __init__(self):
          super().__init__('pg_trgm')
  class UnaccentExtension(CreateExtension):
      """CreateExtension preset for the ``unaccent`` extension."""

      def __init__(self):
          super().__init__('unaccent')
  class NotInTransactionMixin:
      """Mixin for operations that refuse to run inside an atomic block."""

      def _ensure_not_in_transaction(self, schema_editor):
          """
          Raise NotSupportedError when ``schema_editor.connection`` reports
          that it is inside an atomic block; otherwise return None.
          """
          if not schema_editor.connection.in_atomic_block:
              return
          operation_name = self.__class__.__name__
          message = (
              'The %s operation cannot be executed inside a transaction '
              '(set atomic = False on the migration).'
          ) % operation_name
          raise NotSupportedError(message)
  class AddIndexConcurrently(NotInTransactionMixin, AddIndex):
      """Create an index using PostgreSQL's CREATE INDEX CONCURRENTLY syntax."""
      atomic = False

      def describe(self):
          """
          Return a human-readable description of the operation.

          An index built from expressions has no ``fields``, so listing fields
          alone would print an empty "field(s)" clause; the expressions are
          described instead when present.
          """
          # getattr() so an index object without an ``expressions`` attribute
          # still falls back to the field list.
          expressions = getattr(self.index, 'expressions', None)
          if expressions:
              return 'Concurrently create index %s on %s on model %s' % (
                  self.index.name,
                  ', '.join(str(expression) for expression in expressions),
                  self.model_name,
              )
          return 'Concurrently create index %s on field(s) %s of model %s' % (
              self.index.name,
              ', '.join(self.index.fields),
              self.model_name,
          )

      def database_forwards(self, app_label, schema_editor, from_state, to_state):
          """
          Add ``self.index`` concurrently to the model taken from ``to_state``.

          Raises NotSupportedError when run inside an atomic block.
          """
          self._ensure_not_in_transaction(schema_editor)
          model = to_state.apps.get_model(app_label, self.model_name)
          if self.allow_migrate_model(schema_editor.connection.alias, model):
              schema_editor.add_index(model, self.index, concurrently=True)

      def database_backwards(self, app_label, schema_editor, from_state, to_state):
          """
          Remove ``self.index`` concurrently from the model taken from ``from_state``.

          Raises NotSupportedError when run inside an atomic block.
          """
          self._ensure_not_in_transaction(schema_editor)
          model = from_state.apps.get_model(app_label, self.model_name)
          if self.allow_migrate_model(schema_editor.connection.alias, model):
              schema_editor.remove_index(model, self.index, concurrently=True)
  class RemoveIndexConcurrently(NotInTransactionMixin, RemoveIndex):
      """Remove an index using PostgreSQL's DROP INDEX CONCURRENTLY syntax."""
      atomic = False

      def describe(self):
          """Return a human-readable description of the operation."""
          return f'Concurrently remove index {self.name} from {self.model_name}'

      def database_forwards(self, app_label, schema_editor, from_state, to_state):
          """
          Drop, concurrently, the index named ``self.name`` as it is defined in
          ``from_state``. Raises NotSupportedError inside an atomic block.
          """
          self._ensure_not_in_transaction(schema_editor)
          model = from_state.apps.get_model(app_label, self.model_name)
          if not self.allow_migrate_model(schema_editor.connection.alias, model):
              return
          old_model_state = from_state.models[app_label, self.model_name_lower]
          doomed_index = old_model_state.get_index_by_name(self.name)
          schema_editor.remove_index(model, doomed_index, concurrently=True)

      def database_backwards(self, app_label, schema_editor, from_state, to_state):
          """
          Recreate, concurrently, the index named ``self.name`` as it is defined
          in ``to_state``. Raises NotSupportedError inside an atomic block.
          """
          self._ensure_not_in_transaction(schema_editor)
          model = to_state.apps.get_model(app_label, self.model_name)
          if not self.allow_migrate_model(schema_editor.connection.alias, model):
              return
          restored_model_state = to_state.models[app_label, self.model_name_lower]
          restored_index = restored_model_state.get_index_by_name(self.name)
          schema_editor.add_index(model, restored_index, concurrently=True)
  class CollationOperation(Operation):
      """
      Shared base for operations that create or drop a database collation.

      Subclasses choose which migration direction calls create_collation()
      and which calls remove_collation(). The project state is not modified.
      """

      def __init__(self, name, locale, *, provider='libc', deterministic=True):
          self.name = name
          self.locale = locale
          self.provider = provider
          self.deterministic = deterministic

      def state_forwards(self, app_label, state):
          # Nothing to record in the project state.
          pass

      def deconstruct(self):
          """
          Return ``(qualname, [], kwargs)`` for serialization.

          ``provider`` is included only when it is truthy and not 'libc';
          ``deterministic`` is included only when it is False.
          """
          kwargs = {'name': self.name, 'locale': self.locale}
          custom_provider = bool(self.provider) and self.provider != 'libc'
          if custom_provider:
              kwargs['provider'] = self.provider
          if self.deterministic is False:
              kwargs['deterministic'] = False
          return self.__class__.__qualname__, [], kwargs

      def create_collation(self, schema_editor):
          """
          Execute CREATE COLLATION built from this operation's options.

          Raises NotSupportedError when a non-deterministic collation or a
          non-libc provider is requested and the connection's features report
          no support for it.
          """
          wants_nondeterministic = self.deterministic is False
          if wants_nondeterministic and not schema_editor.connection.features.supports_non_deterministic_collations:
              raise NotSupportedError(
                  'Non-deterministic collations require PostgreSQL 12+.'
              )
          wants_other_provider = self.provider != 'libc'
          if wants_other_provider and not schema_editor.connection.features.supports_alternate_collation_providers:
              raise NotSupportedError('Non-libc providers require PostgreSQL 10+.')
          # (option, value) pairs rendered as "option=value" in the statement.
          options = [('locale', schema_editor.quote_name(self.locale))]
          if wants_other_provider:
              options.append(('provider', schema_editor.quote_name(self.provider)))
          if wants_nondeterministic:
              options.append(('deterministic', 'false'))
          schema_editor.execute('CREATE COLLATION %(name)s (%(args)s)' % {
              'name': schema_editor.quote_name(self.name),
              'args': ', '.join(f'{option}={value}' for option, value in options),
          })

      def remove_collation(self, schema_editor):
          """Execute DROP COLLATION for this operation's collation name."""
          quoted_name = schema_editor.quote_name(self.name)
          schema_editor.execute('DROP COLLATION %s' % quoted_name)
  class CreateCollation(CollationOperation):
      """
      Create a collation.

      Both directions do nothing on non-PostgreSQL connections and when the
      database router does not allow migrating ``app_label`` on the
      connection's alias.
      """

      def database_forwards(self, app_label, schema_editor, from_state, to_state):
          """Run CREATE COLLATION."""
          if (
              schema_editor.connection.vendor != 'postgresql' or
              not router.allow_migrate(schema_editor.connection.alias, app_label)
          ):
              return
          self.create_collation(schema_editor)

      def database_backwards(self, app_label, schema_editor, from_state, to_state):
          """Run DROP COLLATION."""
          # Same guard as database_forwards: on a non-PostgreSQL connection the
          # collation was never created, so there is nothing to drop.
          if (
              schema_editor.connection.vendor != 'postgresql' or
              not router.allow_migrate(schema_editor.connection.alias, app_label)
          ):
              return
          self.remove_collation(schema_editor)

      def describe(self):
          """Return a human-readable description of the operation."""
          return f'Create collation {self.name}'

      @property
      def migration_name_fragment(self):
          """Name fragment used when auto-naming a migration with this operation."""
          return 'create_collation_%s' % self.name.lower()
  class RemoveCollation(CollationOperation):
      """
      Remove a collation.

      Both directions do nothing on non-PostgreSQL connections and when the
      database router does not allow migrating ``app_label`` on the
      connection's alias.
      """

      def database_forwards(self, app_label, schema_editor, from_state, to_state):
          """Run DROP COLLATION."""
          if (
              schema_editor.connection.vendor != 'postgresql' or
              not router.allow_migrate(schema_editor.connection.alias, app_label)
          ):
              return
          self.remove_collation(schema_editor)

      def database_backwards(self, app_label, schema_editor, from_state, to_state):
          """Recreate the collation with CREATE COLLATION."""
          # Same guard as database_forwards: on a non-PostgreSQL connection the
          # forwards step dropped nothing, so nothing should be recreated (and
          # create_collation() reads PostgreSQL-specific feature flags).
          if (
              schema_editor.connection.vendor != 'postgresql' or
              not router.allow_migrate(schema_editor.connection.alias, app_label)
          ):
              return
          self.create_collation(schema_editor)

      def describe(self):
          """Return a human-readable description of the operation."""
          return f'Remove collation {self.name}'

      @property
      def migration_name_fragment(self):
          """Name fragment used when auto-naming a migration with this operation."""
          return 'remove_collation_%s' % self.name.lower()