ranges.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. import datetime
  2. import json
  3. from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange, Range
  4. from django.contrib.postgres import forms, lookups
  5. from django.db import models
  6. from django.db.models.lookups import PostgresOperatorLookup
  7. from .utils import AttributeSetter
  8. __all__ = [
  9. 'RangeField', 'IntegerRangeField', 'BigIntegerRangeField',
  10. 'DecimalRangeField', 'DateTimeRangeField', 'DateRangeField',
  11. 'RangeBoundary', 'RangeOperators',
  12. ]
  13. class RangeBoundary(models.Expression):
  14. """A class that represents range boundaries."""
  15. def __init__(self, inclusive_lower=True, inclusive_upper=False):
  16. self.lower = '[' if inclusive_lower else '('
  17. self.upper = ']' if inclusive_upper else ')'
  18. def as_sql(self, compiler, connection):
  19. return "'%s%s'" % (self.lower, self.upper), []
  20. class RangeOperators:
  21. # https://www.postgresql.org/docs/current/functions-range.html#RANGE-OPERATORS-TABLE
  22. EQUAL = '='
  23. NOT_EQUAL = '<>'
  24. CONTAINS = '@>'
  25. CONTAINED_BY = '<@'
  26. OVERLAPS = '&&'
  27. FULLY_LT = '<<'
  28. FULLY_GT = '>>'
  29. NOT_LT = '&>'
  30. NOT_GT = '&<'
  31. ADJACENT_TO = '-|-'
  32. class RangeField(models.Field):
  33. empty_strings_allowed = False
  34. def __init__(self, *args, **kwargs):
  35. # Initializing base_field here ensures that its model matches the model for self.
  36. if hasattr(self, 'base_field'):
  37. self.base_field = self.base_field()
  38. super().__init__(*args, **kwargs)
  39. @property
  40. def model(self):
  41. try:
  42. return self.__dict__['model']
  43. except KeyError:
  44. raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__)
  45. @model.setter
  46. def model(self, model):
  47. self.__dict__['model'] = model
  48. self.base_field.model = model
  49. @classmethod
  50. def _choices_is_value(cls, value):
  51. return isinstance(value, (list, tuple)) or super()._choices_is_value(value)
  52. def get_prep_value(self, value):
  53. if value is None:
  54. return None
  55. elif isinstance(value, Range):
  56. return value
  57. elif isinstance(value, (list, tuple)):
  58. return self.range_type(value[0], value[1])
  59. return value
  60. def to_python(self, value):
  61. if isinstance(value, str):
  62. # Assume we're deserializing
  63. vals = json.loads(value)
  64. for end in ('lower', 'upper'):
  65. if end in vals:
  66. vals[end] = self.base_field.to_python(vals[end])
  67. value = self.range_type(**vals)
  68. elif isinstance(value, (list, tuple)):
  69. value = self.range_type(value[0], value[1])
  70. return value
  71. def set_attributes_from_name(self, name):
  72. super().set_attributes_from_name(name)
  73. self.base_field.set_attributes_from_name(name)
  74. def value_to_string(self, obj):
  75. value = self.value_from_object(obj)
  76. if value is None:
  77. return None
  78. if value.isempty:
  79. return json.dumps({"empty": True})
  80. base_field = self.base_field
  81. result = {"bounds": value._bounds}
  82. for end in ('lower', 'upper'):
  83. val = getattr(value, end)
  84. if val is None:
  85. result[end] = None
  86. else:
  87. obj = AttributeSetter(base_field.attname, val)
  88. result[end] = base_field.value_to_string(obj)
  89. return json.dumps(result)
  90. def formfield(self, **kwargs):
  91. kwargs.setdefault('form_class', self.form_field)
  92. return super().formfield(**kwargs)
  93. class IntegerRangeField(RangeField):
  94. base_field = models.IntegerField
  95. range_type = NumericRange
  96. form_field = forms.IntegerRangeField
  97. def db_type(self, connection):
  98. return 'int4range'
  99. class BigIntegerRangeField(RangeField):
  100. base_field = models.BigIntegerField
  101. range_type = NumericRange
  102. form_field = forms.IntegerRangeField
  103. def db_type(self, connection):
  104. return 'int8range'
  105. class DecimalRangeField(RangeField):
  106. base_field = models.DecimalField
  107. range_type = NumericRange
  108. form_field = forms.DecimalRangeField
  109. def db_type(self, connection):
  110. return 'numrange'
  111. class DateTimeRangeField(RangeField):
  112. base_field = models.DateTimeField
  113. range_type = DateTimeTZRange
  114. form_field = forms.DateTimeRangeField
  115. def db_type(self, connection):
  116. return 'tstzrange'
  117. class DateRangeField(RangeField):
  118. base_field = models.DateField
  119. range_type = DateRange
  120. form_field = forms.DateRangeField
  121. def db_type(self, connection):
  122. return 'daterange'
  123. RangeField.register_lookup(lookups.DataContains)
  124. RangeField.register_lookup(lookups.ContainedBy)
  125. RangeField.register_lookup(lookups.Overlap)
  126. class DateTimeRangeContains(PostgresOperatorLookup):
  127. """
  128. Lookup for Date/DateTimeRange containment to cast the rhs to the correct
  129. type.
  130. """
  131. lookup_name = 'contains'
  132. postgres_operator = RangeOperators.CONTAINS
  133. def process_rhs(self, compiler, connection):
  134. # Transform rhs value for db lookup.
  135. if isinstance(self.rhs, datetime.date):
  136. value = models.Value(self.rhs)
  137. self.rhs = value.resolve_expression(compiler.query)
  138. return super().process_rhs(compiler, connection)
  139. def as_postgresql(self, compiler, connection):
  140. sql, params = super().as_postgresql(compiler, connection)
  141. # Cast the rhs if needed.
  142. cast_sql = ''
  143. if (
  144. isinstance(self.rhs, models.Expression) and
  145. self.rhs._output_field_or_none and
  146. # Skip cast if rhs has a matching range type.
  147. not isinstance(self.rhs._output_field_or_none, self.lhs.output_field.__class__)
  148. ):
  149. cast_internal_type = self.lhs.output_field.base_field.get_internal_type()
  150. cast_sql = '::{}'.format(connection.data_types.get(cast_internal_type))
  151. return '%s%s' % (sql, cast_sql), params
  152. DateRangeField.register_lookup(DateTimeRangeContains)
  153. DateTimeRangeField.register_lookup(DateTimeRangeContains)
  154. class RangeContainedBy(PostgresOperatorLookup):
  155. lookup_name = 'contained_by'
  156. type_mapping = {
  157. 'smallint': 'int4range',
  158. 'integer': 'int4range',
  159. 'bigint': 'int8range',
  160. 'double precision': 'numrange',
  161. 'numeric': 'numrange',
  162. 'date': 'daterange',
  163. 'timestamp with time zone': 'tstzrange',
  164. }
  165. postgres_operator = RangeOperators.CONTAINED_BY
  166. def process_rhs(self, compiler, connection):
  167. rhs, rhs_params = super().process_rhs(compiler, connection)
  168. # Ignore precision for DecimalFields.
  169. db_type = self.lhs.output_field.cast_db_type(connection).split('(')[0]
  170. cast_type = self.type_mapping[db_type]
  171. return '%s::%s' % (rhs, cast_type), rhs_params
  172. def process_lhs(self, compiler, connection):
  173. lhs, lhs_params = super().process_lhs(compiler, connection)
  174. if isinstance(self.lhs.output_field, models.FloatField):
  175. lhs = '%s::numeric' % lhs
  176. elif isinstance(self.lhs.output_field, models.SmallIntegerField):
  177. lhs = '%s::integer' % lhs
  178. return lhs, lhs_params
  179. def get_prep_lookup(self):
  180. return RangeField().get_prep_value(self.rhs)
  181. models.DateField.register_lookup(RangeContainedBy)
  182. models.DateTimeField.register_lookup(RangeContainedBy)
  183. models.IntegerField.register_lookup(RangeContainedBy)
  184. models.FloatField.register_lookup(RangeContainedBy)
  185. models.DecimalField.register_lookup(RangeContainedBy)
  186. @RangeField.register_lookup
  187. class FullyLessThan(PostgresOperatorLookup):
  188. lookup_name = 'fully_lt'
  189. postgres_operator = RangeOperators.FULLY_LT
  190. @RangeField.register_lookup
  191. class FullGreaterThan(PostgresOperatorLookup):
  192. lookup_name = 'fully_gt'
  193. postgres_operator = RangeOperators.FULLY_GT
  194. @RangeField.register_lookup
  195. class NotLessThan(PostgresOperatorLookup):
  196. lookup_name = 'not_lt'
  197. postgres_operator = RangeOperators.NOT_LT
  198. @RangeField.register_lookup
  199. class NotGreaterThan(PostgresOperatorLookup):
  200. lookup_name = 'not_gt'
  201. postgres_operator = RangeOperators.NOT_GT
  202. @RangeField.register_lookup
  203. class AdjacentToLookup(PostgresOperatorLookup):
  204. lookup_name = 'adjacent_to'
  205. postgres_operator = RangeOperators.ADJACENT_TO
  206. @RangeField.register_lookup
  207. class RangeStartsWith(models.Transform):
  208. lookup_name = 'startswith'
  209. function = 'lower'
  210. @property
  211. def output_field(self):
  212. return self.lhs.output_field.base_field
  213. @RangeField.register_lookup
  214. class RangeEndsWith(models.Transform):
  215. lookup_name = 'endswith'
  216. function = 'upper'
  217. @property
  218. def output_field(self):
  219. return self.lhs.output_field.base_field
  220. @RangeField.register_lookup
  221. class IsEmpty(models.Transform):
  222. lookup_name = 'isempty'
  223. function = 'isempty'
  224. output_field = models.BooleanField()
  225. @RangeField.register_lookup
  226. class LowerInclusive(models.Transform):
  227. lookup_name = 'lower_inc'
  228. function = 'LOWER_INC'
  229. output_field = models.BooleanField()
  230. @RangeField.register_lookup
  231. class LowerInfinite(models.Transform):
  232. lookup_name = 'lower_inf'
  233. function = 'LOWER_INF'
  234. output_field = models.BooleanField()
  235. @RangeField.register_lookup
  236. class UpperInclusive(models.Transform):
  237. lookup_name = 'upper_inc'
  238. function = 'UPPER_INC'
  239. output_field = models.BooleanField()
  240. @RangeField.register_lookup
  241. class UpperInfinite(models.Transform):
  242. lookup_name = 'upper_inf'
  243. function = 'UPPER_INF'
  244. output_field = models.BooleanField()