fields.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. from django.core.exceptions import FieldDoesNotExist
  2. from django.db.models import NOT_PROVIDED
  3. from django.utils.functional import cached_property
  4. from .base import Operation
  5. from .utils import field_is_referenced, field_references, get_references
  6. class FieldOperation(Operation):
  7. def __init__(self, model_name, name, field=None):
  8. self.model_name = model_name
  9. self.name = name
  10. self.field = field
  11. @cached_property
  12. def model_name_lower(self):
  13. return self.model_name.lower()
  14. @cached_property
  15. def name_lower(self):
  16. return self.name.lower()
  17. def is_same_model_operation(self, operation):
  18. return self.model_name_lower == operation.model_name_lower
  19. def is_same_field_operation(self, operation):
  20. return self.is_same_model_operation(operation) and self.name_lower == operation.name_lower
  21. def references_model(self, name, app_label):
  22. name_lower = name.lower()
  23. if name_lower == self.model_name_lower:
  24. return True
  25. if self.field:
  26. return bool(field_references(
  27. (app_label, self.model_name_lower), self.field, (app_label, name_lower)
  28. ))
  29. return False
  30. def references_field(self, model_name, name, app_label):
  31. model_name_lower = model_name.lower()
  32. # Check if this operation locally references the field.
  33. if model_name_lower == self.model_name_lower:
  34. if name == self.name:
  35. return True
  36. elif self.field and hasattr(self.field, 'from_fields') and name in self.field.from_fields:
  37. return True
  38. # Check if this operation remotely references the field.
  39. if self.field is None:
  40. return False
  41. return bool(field_references(
  42. (app_label, self.model_name_lower),
  43. self.field,
  44. (app_label, model_name_lower),
  45. name,
  46. ))
  47. def reduce(self, operation, app_label):
  48. return (
  49. super().reduce(operation, app_label) or
  50. not operation.references_field(self.model_name, self.name, app_label)
  51. )
  52. class AddField(FieldOperation):
  53. """Add a field to a model."""
  54. def __init__(self, model_name, name, field, preserve_default=True):
  55. self.preserve_default = preserve_default
  56. super().__init__(model_name, name, field)
  57. def deconstruct(self):
  58. kwargs = {
  59. 'model_name': self.model_name,
  60. 'name': self.name,
  61. 'field': self.field,
  62. }
  63. if self.preserve_default is not True:
  64. kwargs['preserve_default'] = self.preserve_default
  65. return (
  66. self.__class__.__name__,
  67. [],
  68. kwargs
  69. )
  70. def state_forwards(self, app_label, state):
  71. # If preserve default is off, don't use the default for future state
  72. if not self.preserve_default:
  73. field = self.field.clone()
  74. field.default = NOT_PROVIDED
  75. else:
  76. field = self.field
  77. state.models[app_label, self.model_name_lower].fields[self.name] = field
  78. # Delay rendering of relationships if it's not a relational field
  79. delay = not field.is_relation
  80. state.reload_model(app_label, self.model_name_lower, delay=delay)
  81. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  82. to_model = to_state.apps.get_model(app_label, self.model_name)
  83. if self.allow_migrate_model(schema_editor.connection.alias, to_model):
  84. from_model = from_state.apps.get_model(app_label, self.model_name)
  85. field = to_model._meta.get_field(self.name)
  86. if not self.preserve_default:
  87. field.default = self.field.default
  88. schema_editor.add_field(
  89. from_model,
  90. field,
  91. )
  92. if not self.preserve_default:
  93. field.default = NOT_PROVIDED
  94. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  95. from_model = from_state.apps.get_model(app_label, self.model_name)
  96. if self.allow_migrate_model(schema_editor.connection.alias, from_model):
  97. schema_editor.remove_field(from_model, from_model._meta.get_field(self.name))
  98. def describe(self):
  99. return "Add field %s to %s" % (self.name, self.model_name)
  100. @property
  101. def migration_name_fragment(self):
  102. return '%s_%s' % (self.model_name_lower, self.name_lower)
  103. def reduce(self, operation, app_label):
  104. if isinstance(operation, FieldOperation) and self.is_same_field_operation(operation):
  105. if isinstance(operation, AlterField):
  106. return [
  107. AddField(
  108. model_name=self.model_name,
  109. name=operation.name,
  110. field=operation.field,
  111. ),
  112. ]
  113. elif isinstance(operation, RemoveField):
  114. return []
  115. elif isinstance(operation, RenameField):
  116. return [
  117. AddField(
  118. model_name=self.model_name,
  119. name=operation.new_name,
  120. field=self.field,
  121. ),
  122. ]
  123. return super().reduce(operation, app_label)
  124. class RemoveField(FieldOperation):
  125. """Remove a field from a model."""
  126. def deconstruct(self):
  127. kwargs = {
  128. 'model_name': self.model_name,
  129. 'name': self.name,
  130. }
  131. return (
  132. self.__class__.__name__,
  133. [],
  134. kwargs
  135. )
  136. def state_forwards(self, app_label, state):
  137. model_state = state.models[app_label, self.model_name_lower]
  138. old_field = model_state.fields.pop(self.name)
  139. # Delay rendering of relationships if it's not a relational field
  140. delay = not old_field.is_relation
  141. state.reload_model(app_label, self.model_name_lower, delay=delay)
  142. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  143. from_model = from_state.apps.get_model(app_label, self.model_name)
  144. if self.allow_migrate_model(schema_editor.connection.alias, from_model):
  145. schema_editor.remove_field(from_model, from_model._meta.get_field(self.name))
  146. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  147. to_model = to_state.apps.get_model(app_label, self.model_name)
  148. if self.allow_migrate_model(schema_editor.connection.alias, to_model):
  149. from_model = from_state.apps.get_model(app_label, self.model_name)
  150. schema_editor.add_field(from_model, to_model._meta.get_field(self.name))
  151. def describe(self):
  152. return "Remove field %s from %s" % (self.name, self.model_name)
  153. @property
  154. def migration_name_fragment(self):
  155. return 'remove_%s_%s' % (self.model_name_lower, self.name_lower)
  156. def reduce(self, operation, app_label):
  157. from .models import DeleteModel
  158. if isinstance(operation, DeleteModel) and operation.name_lower == self.model_name_lower:
  159. return [operation]
  160. return super().reduce(operation, app_label)
  161. class AlterField(FieldOperation):
  162. """
  163. Alter a field's database column (e.g. null, max_length) to the provided
  164. new field.
  165. """
  166. def __init__(self, model_name, name, field, preserve_default=True):
  167. self.preserve_default = preserve_default
  168. super().__init__(model_name, name, field)
  169. def deconstruct(self):
  170. kwargs = {
  171. 'model_name': self.model_name,
  172. 'name': self.name,
  173. 'field': self.field,
  174. }
  175. if self.preserve_default is not True:
  176. kwargs['preserve_default'] = self.preserve_default
  177. return (
  178. self.__class__.__name__,
  179. [],
  180. kwargs
  181. )
  182. def state_forwards(self, app_label, state):
  183. if not self.preserve_default:
  184. field = self.field.clone()
  185. field.default = NOT_PROVIDED
  186. else:
  187. field = self.field
  188. model_state = state.models[app_label, self.model_name_lower]
  189. model_state.fields[self.name] = field
  190. # TODO: investigate if old relational fields must be reloaded or if it's
  191. # sufficient if the new field is (#27737).
  192. # Delay rendering of relationships if it's not a relational field and
  193. # not referenced by a foreign key.
  194. delay = (
  195. not field.is_relation and
  196. not field_is_referenced(
  197. state, (app_label, self.model_name_lower), (self.name, field),
  198. )
  199. )
  200. state.reload_model(app_label, self.model_name_lower, delay=delay)
  201. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  202. to_model = to_state.apps.get_model(app_label, self.model_name)
  203. if self.allow_migrate_model(schema_editor.connection.alias, to_model):
  204. from_model = from_state.apps.get_model(app_label, self.model_name)
  205. from_field = from_model._meta.get_field(self.name)
  206. to_field = to_model._meta.get_field(self.name)
  207. if not self.preserve_default:
  208. to_field.default = self.field.default
  209. schema_editor.alter_field(from_model, from_field, to_field)
  210. if not self.preserve_default:
  211. to_field.default = NOT_PROVIDED
  212. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  213. self.database_forwards(app_label, schema_editor, from_state, to_state)
  214. def describe(self):
  215. return "Alter field %s on %s" % (self.name, self.model_name)
  216. @property
  217. def migration_name_fragment(self):
  218. return 'alter_%s_%s' % (self.model_name_lower, self.name_lower)
  219. def reduce(self, operation, app_label):
  220. if isinstance(operation, RemoveField) and self.is_same_field_operation(operation):
  221. return [operation]
  222. elif isinstance(operation, RenameField) and self.is_same_field_operation(operation):
  223. return [
  224. operation,
  225. AlterField(
  226. model_name=self.model_name,
  227. name=operation.new_name,
  228. field=self.field,
  229. ),
  230. ]
  231. return super().reduce(operation, app_label)
  232. class RenameField(FieldOperation):
  233. """Rename a field on the model. Might affect db_column too."""
  234. def __init__(self, model_name, old_name, new_name):
  235. self.old_name = old_name
  236. self.new_name = new_name
  237. super().__init__(model_name, old_name)
  238. @cached_property
  239. def old_name_lower(self):
  240. return self.old_name.lower()
  241. @cached_property
  242. def new_name_lower(self):
  243. return self.new_name.lower()
  244. def deconstruct(self):
  245. kwargs = {
  246. 'model_name': self.model_name,
  247. 'old_name': self.old_name,
  248. 'new_name': self.new_name,
  249. }
  250. return (
  251. self.__class__.__name__,
  252. [],
  253. kwargs
  254. )
  255. def state_forwards(self, app_label, state):
  256. model_state = state.models[app_label, self.model_name_lower]
  257. # Rename the field
  258. fields = model_state.fields
  259. try:
  260. found = fields.pop(self.old_name)
  261. except KeyError:
  262. raise FieldDoesNotExist(
  263. "%s.%s has no field named '%s'" % (app_label, self.model_name, self.old_name)
  264. )
  265. fields[self.new_name] = found
  266. for field in fields.values():
  267. # Fix from_fields to refer to the new field.
  268. from_fields = getattr(field, 'from_fields', None)
  269. if from_fields:
  270. field.from_fields = tuple([
  271. self.new_name if from_field_name == self.old_name else from_field_name
  272. for from_field_name in from_fields
  273. ])
  274. # Fix index/unique_together to refer to the new field
  275. options = model_state.options
  276. for option in ('index_together', 'unique_together'):
  277. if option in options:
  278. options[option] = [
  279. [self.new_name if n == self.old_name else n for n in together]
  280. for together in options[option]
  281. ]
  282. # Fix to_fields to refer to the new field.
  283. delay = True
  284. references = get_references(
  285. state, (app_label, self.model_name_lower), (self.old_name, found),
  286. )
  287. for *_, field, reference in references:
  288. delay = False
  289. if reference.to:
  290. remote_field, to_fields = reference.to
  291. if getattr(remote_field, 'field_name', None) == self.old_name:
  292. remote_field.field_name = self.new_name
  293. if to_fields:
  294. field.to_fields = tuple([
  295. self.new_name if to_field_name == self.old_name else to_field_name
  296. for to_field_name in to_fields
  297. ])
  298. state.reload_model(app_label, self.model_name_lower, delay=delay)
  299. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  300. to_model = to_state.apps.get_model(app_label, self.model_name)
  301. if self.allow_migrate_model(schema_editor.connection.alias, to_model):
  302. from_model = from_state.apps.get_model(app_label, self.model_name)
  303. schema_editor.alter_field(
  304. from_model,
  305. from_model._meta.get_field(self.old_name),
  306. to_model._meta.get_field(self.new_name),
  307. )
  308. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  309. to_model = to_state.apps.get_model(app_label, self.model_name)
  310. if self.allow_migrate_model(schema_editor.connection.alias, to_model):
  311. from_model = from_state.apps.get_model(app_label, self.model_name)
  312. schema_editor.alter_field(
  313. from_model,
  314. from_model._meta.get_field(self.new_name),
  315. to_model._meta.get_field(self.old_name),
  316. )
  317. def describe(self):
  318. return "Rename field %s on %s to %s" % (self.old_name, self.model_name, self.new_name)
  319. @property
  320. def migration_name_fragment(self):
  321. return 'rename_%s_%s_%s' % (
  322. self.old_name_lower,
  323. self.model_name_lower,
  324. self.new_name_lower,
  325. )
  326. def references_field(self, model_name, name, app_label):
  327. return self.references_model(model_name, app_label) and (
  328. name.lower() == self.old_name_lower or
  329. name.lower() == self.new_name_lower
  330. )
  331. def reduce(self, operation, app_label):
  332. if (isinstance(operation, RenameField) and
  333. self.is_same_model_operation(operation) and
  334. self.new_name_lower == operation.old_name_lower):
  335. return [
  336. RenameField(
  337. self.model_name,
  338. self.old_name,
  339. operation.new_name,
  340. ),
  341. ]
  342. # Skip `FieldOperation.reduce` as we want to run `references_field`
  343. # against self.new_name.
  344. return (
  345. super(FieldOperation, self).reduce(operation, app_label) or
  346. not operation.references_field(self.model_name, self.new_name, app_label)
  347. )