comparison.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. """Database functions that do comparisons or type conversions."""
  2. from django.db import NotSupportedError
  3. from django.db.models.expressions import Func, Value
  4. from django.db.models.fields.json import JSONField
  5. from django.utils.regex_helper import _lazy_re_compile
  6. class Cast(Func):
  7. """Coerce an expression to a new field type."""
  8. function = 'CAST'
  9. template = '%(function)s(%(expressions)s AS %(db_type)s)'
  10. def __init__(self, expression, output_field):
  11. super().__init__(expression, output_field=output_field)
  12. def as_sql(self, compiler, connection, **extra_context):
  13. extra_context['db_type'] = self.output_field.cast_db_type(connection)
  14. return super().as_sql(compiler, connection, **extra_context)
  15. def as_sqlite(self, compiler, connection, **extra_context):
  16. db_type = self.output_field.db_type(connection)
  17. if db_type in {'datetime', 'time'}:
  18. # Use strftime as datetime/time don't keep fractional seconds.
  19. template = 'strftime(%%s, %(expressions)s)'
  20. sql, params = super().as_sql(compiler, connection, template=template, **extra_context)
  21. format_string = '%H:%M:%f' if db_type == 'time' else '%Y-%m-%d %H:%M:%f'
  22. params.insert(0, format_string)
  23. return sql, params
  24. elif db_type == 'date':
  25. template = 'date(%(expressions)s)'
  26. return super().as_sql(compiler, connection, template=template, **extra_context)
  27. return self.as_sql(compiler, connection, **extra_context)
  28. def as_mysql(self, compiler, connection, **extra_context):
  29. template = None
  30. output_type = self.output_field.get_internal_type()
  31. # MySQL doesn't support explicit cast to float.
  32. if output_type == 'FloatField':
  33. template = '(%(expressions)s + 0.0)'
  34. # MariaDB doesn't support explicit cast to JSON.
  35. elif output_type == 'JSONField' and connection.mysql_is_mariadb:
  36. template = "JSON_EXTRACT(%(expressions)s, '$')"
  37. return self.as_sql(compiler, connection, template=template, **extra_context)
  38. def as_oracle(self, compiler, connection, **extra_context):
  39. if self.output_field.get_internal_type() == 'JSONField':
  40. # Oracle doesn't support explicit cast to JSON.
  41. template = "JSON_QUERY(%(expressions)s, '$')"
  42. return super().as_sql(compiler, connection, template=template, **extra_context)
  43. return self.as_sql(compiler, connection, **extra_context)
  44. class Coalesce(Func):
  45. """Return, from left to right, the first non-null expression."""
  46. function = 'COALESCE'
  47. def __init__(self, *expressions, **extra):
  48. if len(expressions) < 2:
  49. raise ValueError('Coalesce must take at least two expressions')
  50. super().__init__(*expressions, **extra)
  51. def as_oracle(self, compiler, connection, **extra_context):
  52. # Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2),
  53. # so convert all fields to NCLOB when that type is expected.
  54. if self.output_field.get_internal_type() == 'TextField':
  55. clone = self.copy()
  56. clone.set_source_expressions([
  57. Func(expression, function='TO_NCLOB') for expression in self.get_source_expressions()
  58. ])
  59. return super(Coalesce, clone).as_sql(compiler, connection, **extra_context)
  60. return self.as_sql(compiler, connection, **extra_context)
  61. class Collate(Func):
  62. function = 'COLLATE'
  63. template = '%(expressions)s %(function)s %(collation)s'
  64. # Inspired from https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
  65. collation_re = _lazy_re_compile(r'^[\w\-]+$')
  66. def __init__(self, expression, collation):
  67. if not (collation and self.collation_re.match(collation)):
  68. raise ValueError('Invalid collation name: %r.' % collation)
  69. self.collation = collation
  70. super().__init__(expression)
  71. def as_sql(self, compiler, connection, **extra_context):
  72. extra_context.setdefault('collation', connection.ops.quote_name(self.collation))
  73. return super().as_sql(compiler, connection, **extra_context)
  74. class Greatest(Func):
  75. """
  76. Return the maximum expression.
  77. If any expression is null the return value is database-specific:
  78. On PostgreSQL, the maximum not-null expression is returned.
  79. On MySQL, Oracle, and SQLite, if any expression is null, null is returned.
  80. """
  81. function = 'GREATEST'
  82. def __init__(self, *expressions, **extra):
  83. if len(expressions) < 2:
  84. raise ValueError('Greatest must take at least two expressions')
  85. super().__init__(*expressions, **extra)
  86. def as_sqlite(self, compiler, connection, **extra_context):
  87. """Use the MAX function on SQLite."""
  88. return super().as_sqlite(compiler, connection, function='MAX', **extra_context)
  89. class JSONObject(Func):
  90. function = 'JSON_OBJECT'
  91. output_field = JSONField()
  92. def __init__(self, **fields):
  93. expressions = []
  94. for key, value in fields.items():
  95. expressions.extend((Value(key), value))
  96. super().__init__(*expressions)
  97. def as_sql(self, compiler, connection, **extra_context):
  98. if not connection.features.has_json_object_function:
  99. raise NotSupportedError(
  100. 'JSONObject() is not supported on this database backend.'
  101. )
  102. return super().as_sql(compiler, connection, **extra_context)
  103. def as_postgresql(self, compiler, connection, **extra_context):
  104. return self.as_sql(
  105. compiler,
  106. connection,
  107. function='JSONB_BUILD_OBJECT',
  108. **extra_context,
  109. )
  110. def as_oracle(self, compiler, connection, **extra_context):
  111. class ArgJoiner:
  112. def join(self, args):
  113. args = [' VALUE '.join(arg) for arg in zip(args[::2], args[1::2])]
  114. return ', '.join(args)
  115. return self.as_sql(
  116. compiler,
  117. connection,
  118. arg_joiner=ArgJoiner(),
  119. template='%(function)s(%(expressions)s RETURNING CLOB)',
  120. **extra_context,
  121. )
  122. class Least(Func):
  123. """
  124. Return the minimum expression.
  125. If any expression is null the return value is database-specific:
  126. On PostgreSQL, return the minimum not-null expression.
  127. On MySQL, Oracle, and SQLite, if any expression is null, return null.
  128. """
  129. function = 'LEAST'
  130. def __init__(self, *expressions, **extra):
  131. if len(expressions) < 2:
  132. raise ValueError('Least must take at least two expressions')
  133. super().__init__(*expressions, **extra)
  134. def as_sqlite(self, compiler, connection, **extra_context):
  135. """Use the MIN function on SQLite."""
  136. return super().as_sqlite(compiler, connection, function='MIN', **extra_context)
  137. class NullIf(Func):
  138. function = 'NULLIF'
  139. arity = 2
  140. def as_oracle(self, compiler, connection, **extra_context):
  141. expression1 = self.get_source_expressions()[0]
  142. if isinstance(expression1, Value) and expression1.value is None:
  143. raise ValueError('Oracle does not allow Value(None) for expression1.')
  144. return super().as_sql(compiler, connection, **extra_context)