ddl_references.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. """
  2. Helpers to manipulate deferred DDL statements that might need to be adjusted or
  3. discarded when executing a migration.
  4. """
  5. from copy import deepcopy
  class Reference:
      """Base class defining the interface shared by every DDL reference."""

      def references_table(self, table):
          """
          Return whether this reference points at the given table.

          The base implementation never matches.
          """
          return False

      def references_column(self, table, column):
          """
          Return whether this reference points at the given column of table.

          The base implementation never matches.
          """
          return False

      def rename_table_references(self, old_table, new_table):
          """
          Replace every reference to old_table with new_table.

          The base implementation holds nothing to rename.
          """

      def rename_column_references(self, table, old_column, new_column):
          """
          Replace every reference to old_column (of table) with new_column.

          The base implementation holds nothing to rename.
          """

      def __repr__(self):
          # Show the class name alongside the rendered (str) form.
          return '<{} {!r}>'.format(self.__class__.__name__, str(self))

      def __str__(self):
          raise NotImplementedError('Subclasses must define how they should be converted to string.')
  class Table(Reference):
      """Reference to a single table, rendered through the quote_name callable."""

      def __init__(self, table, quote_name):
          self.table = table
          # Callable used by __str__() to quote the table name.
          self.quote_name = quote_name

      def references_table(self, table):
          return self.table == table

      def rename_table_references(self, old_table, new_table):
          # Only the matching table is renamed; anything else is left alone.
          if self.table != old_table:
              return
          self.table = new_table

      def __str__(self):
          quote = self.quote_name
          return quote(self.table)
  class TableColumns(Table):
      """
      Base class for references to multiple columns of a table.

      No quote_name is stored here, so the inherited Table.__str__() cannot
      render a bare TableColumns. Subclasses that are rendered either set
      quote_name themselves or override __str__().
      """

      def __init__(self, table, columns):
          self.table = table
          self.columns = columns

      def references_column(self, table, column):
          return self.table == table and column in self.columns

      def rename_column_references(self, table, old_column, new_column):
          """
          Rename every occurrence of old_column to new_column when this
          reference points at table.
          """
          if self.table == table:
              # Build a new list instead of assigning into self.columns in
              # place. The supplied sequence may be an immutable tuple, where
              # item assignment raises TypeError. It may also be a list owned
              # by the caller, which must not be mutated behind its back.
              self.columns = [
                  new_column if column == old_column else column
                  for column in self.columns
              ]
  class Columns(TableColumns):
      """
      Reference to one or more columns of a table, rendered as a
      comma-separated list of quoted names, each optionally followed by the
      matching entry of col_suffixes.
      """

      def __init__(self, table, columns, quote_name, col_suffixes=()):
          self.quote_name = quote_name
          # Per-column suffixes, matched to columns by position. A missing or
          # empty entry means "no suffix".
          self.col_suffixes = col_suffixes
          super().__init__(table, columns)

      def __str__(self):
          rendered = []
          for position, name in enumerate(self.columns):
              text = self.quote_name(name)
              try:
                  extra = self.col_suffixes[position]
              except IndexError:
                  extra = None
              if extra:
                  text = '{} {}'.format(text, extra)
              rendered.append(text)
          return ', '.join(rendered)
  class IndexName(TableColumns):
      """
      Reference to an index name.

      The name is produced by calling create_index_name(table, columns, suffix)
      each time the reference is rendered, so it reflects any renames applied
      beforehand.
      """

      def __init__(self, table, columns, suffix, create_index_name):
          super().__init__(table, columns)
          self.create_index_name = create_index_name
          self.suffix = suffix

      def __str__(self):
          make_name = self.create_index_name
          return make_name(self.table, self.columns, self.suffix)
  class IndexColumns(Columns):
      """
      Reference to the columns of an index. Each column is rendered as its
      quoted name, followed by its operator class (from opclasses) and then
      its non-empty suffix (from col_suffixes), both matched by position.
      """

      def __init__(self, table, columns, quote_name, col_suffixes=(), opclasses=()):
          self.opclasses = opclasses
          super().__init__(table, columns, quote_name, col_suffixes)

      def __str__(self):
          def col_str(column, idx):
              col = self.quote_name(column)
              # Index.__init__() guarantees that opclasses matches columns in
              # length, but this class defaults opclasses to (). A missing
              # entry is skipped, just like a missing col_suffixes entry, so
              # str() no longer raises IndexError for such instances.
              try:
                  opclass = self.opclasses[idx]
              except IndexError:
                  pass
              else:
                  col = '{} {}'.format(col, opclass)
              try:
                  suffix = self.col_suffixes[idx]
              except IndexError:
                  pass
              else:
                  if suffix:
                      col = '{} {}'.format(col, suffix)
              return col

          return ', '.join(col_str(column, idx) for idx, column in enumerate(self.columns))
  class ForeignKeyName(TableColumns):
      """
      Reference to a foreign key constraint name.

      The inherited table/columns describe the referencing ("from") side.
      to_reference tracks the referenced ("to") table and columns, so lookups
      and renames consider both sides.
      """

      def __init__(self, from_table, from_columns, to_table, to_columns, suffix_template, create_fk_name):
          super().__init__(from_table, from_columns)
          self.to_reference = TableColumns(to_table, to_columns)
          self.suffix_template = suffix_template
          self.create_fk_name = create_fk_name

      def references_table(self, table):
          from_side = super().references_table(table)
          return from_side or self.to_reference.references_table(table)

      def references_column(self, table, column):
          from_side = super().references_column(table, column)
          return from_side or self.to_reference.references_column(table, column)

      def rename_table_references(self, old_table, new_table):
          # Referencing side first, then the referenced side.
          super().rename_table_references(old_table, new_table)
          self.to_reference.rename_table_references(old_table, new_table)

      def rename_column_references(self, table, old_column, new_column):
          # Referencing side first, then the referenced side.
          super().rename_column_references(table, old_column, new_column)
          self.to_reference.rename_column_references(table, old_column, new_column)

      def __str__(self):
          # Only the first referenced column feeds into the suffix template.
          target = self.to_reference
          suffix = self.suffix_template % {
              'to_table': target.table,
              'to_column': target.columns[0],
          }
          return self.create_fk_name(self.table, self.columns, suffix)
  class Statement(Reference):
      """
      Container for a statement template and its formatting parts.

      The parts are kept un-interpolated so that any table or column
      references among them can still be queried, renamed, or detected as
      dangling (e.g. when the referenced table or column is removed) before
      the statement is rendered with str().
      """

      def __init__(self, template, **parts):
          self.template = template
          self.parts = parts

      def references_table(self, table):
          # Parts without a references_table() method (e.g. plain strings)
          # never match.
          for part in self.parts.values():
              if hasattr(part, 'references_table') and part.references_table(table):
                  return True
          return False

      def references_column(self, table, column):
          for part in self.parts.values():
              if hasattr(part, 'references_column') and part.references_column(table, column):
                  return True
          return False

      def rename_table_references(self, old_table, new_table):
          for part in self.parts.values():
              if not hasattr(part, 'rename_table_references'):
                  continue
              part.rename_table_references(old_table, new_table)

      def rename_column_references(self, table, old_column, new_column):
          for part in self.parts.values():
              if not hasattr(part, 'rename_column_references'):
                  continue
              part.rename_column_references(table, old_column, new_column)

      def __str__(self):
          # %-style named interpolation; each part is rendered via its str().
          return self.template % self.parts
  class Expressions(TableColumns):
      """
      Hold a reference to the columns used by a set of expressions, rendered
      to SQL through a compiler.

      NOTE(review): expressions are presumably query expressions. They are
      passed to compiler.query._gen_cols() and support relabeled_clone().
      Confirm against callers.
      """

      def __init__(self, table, expressions, compiler, quote_value):
          self.compiler = compiler
          self.expressions = expressions
          # Callable applied to each compiled parameter in __str__().
          self.quote_value = quote_value
          # Column names are collected from the column references yielded by
          # the (private) _gen_cols() helper; each exposes .target.column.
          columns = [col.target.column for col in self.compiler.query._gen_cols([self.expressions])]
          super().__init__(table, columns)

      def rename_table_references(self, old_table, new_table):
          """Relabel the expressions and the table name when old_table matches."""
          if self.table != old_table:
              return
          self.expressions = self.expressions.relabeled_clone({old_table: new_table})
          super().rename_table_references(old_table, new_table)

      def rename_column_references(self, table, old_column, new_column):
          """
          Rename old_column to new_column within the expressions and rebuild
          self.columns from them, when table matches.
          """
          if self.table != table:
              return
          # Work on a deep copy and only swap it in afterwards.
          # NOTE(review): this relies on deepcopy also copying each col.target,
          # so the original target objects are not mutated. Confirm.
          expressions = deepcopy(self.expressions)
          self.columns = []
          for col in self.compiler.query._gen_cols([expressions]):
              if col.target.column == old_column:
                  col.target.column = new_column
              # Every column is recorded, renamed or not, in traversal order.
              self.columns.append(col.target.column)
          self.expressions = expressions

      def __str__(self):
          """Compile the expressions and inline their quoted parameters."""
          sql, params = self.compiler.compile(self.expressions)
          # Parameters are interpolated into the SQL text rather than bound.
          # Presumably this is because the result is a DDL string. Confirm.
          params = map(self.quote_value, params)
          return sql % tuple(params)