datetime.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. from datetime import datetime
  2. from django.conf import settings
  3. from django.db.models.expressions import Func
  4. from django.db.models.fields import (
  5. DateField, DateTimeField, DurationField, Field, IntegerField, TimeField,
  6. )
  7. from django.db.models.lookups import (
  8. Transform, YearExact, YearGt, YearGte, YearLt, YearLte,
  9. )
  10. from django.utils import timezone
  11. class TimezoneMixin:
  12. tzinfo = None
  13. def get_tzname(self):
  14. # Timezone conversions must happen to the input datetime *before*
  15. # applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
  16. # database as 2016-01-01 01:00:00 +00:00. Any results should be
  17. # based on the input datetime not the stored datetime.
  18. tzname = None
  19. if settings.USE_TZ:
  20. if self.tzinfo is None:
  21. tzname = timezone.get_current_timezone_name()
  22. else:
  23. tzname = timezone._get_timezone_name(self.tzinfo)
  24. return tzname
  25. class Extract(TimezoneMixin, Transform):
  26. lookup_name = None
  27. output_field = IntegerField()
  28. def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
  29. if self.lookup_name is None:
  30. self.lookup_name = lookup_name
  31. if self.lookup_name is None:
  32. raise ValueError('lookup_name must be provided')
  33. self.tzinfo = tzinfo
  34. super().__init__(expression, **extra)
  35. def as_sql(self, compiler, connection):
  36. if not connection.ops.extract_trunc_lookup_pattern.fullmatch(self.lookup_name):
  37. raise ValueError("Invalid lookup_name: %s" % self.lookup_name)
  38. sql, params = compiler.compile(self.lhs)
  39. lhs_output_field = self.lhs.output_field
  40. if isinstance(lhs_output_field, DateTimeField):
  41. tzname = self.get_tzname()
  42. sql = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
  43. elif self.tzinfo is not None:
  44. raise ValueError('tzinfo can only be used with DateTimeField.')
  45. elif isinstance(lhs_output_field, DateField):
  46. sql = connection.ops.date_extract_sql(self.lookup_name, sql)
  47. elif isinstance(lhs_output_field, TimeField):
  48. sql = connection.ops.time_extract_sql(self.lookup_name, sql)
  49. elif isinstance(lhs_output_field, DurationField):
  50. if not connection.features.has_native_duration_field:
  51. raise ValueError('Extract requires native DurationField database support.')
  52. sql = connection.ops.time_extract_sql(self.lookup_name, sql)
  53. else:
  54. # resolve_expression has already validated the output_field so this
  55. # assert should never be hit.
  56. assert False, "Tried to Extract from an invalid type."
  57. return sql, params
  58. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  59. copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
  60. field = copy.lhs.output_field
  61. if not isinstance(field, (DateField, DateTimeField, TimeField, DurationField)):
  62. raise ValueError(
  63. 'Extract input expression must be DateField, DateTimeField, '
  64. 'TimeField, or DurationField.'
  65. )
  66. # Passing dates to functions expecting datetimes is most likely a mistake.
  67. if type(field) == DateField and copy.lookup_name in ('hour', 'minute', 'second'):
  68. raise ValueError(
  69. "Cannot extract time component '%s' from DateField '%s'. " % (copy.lookup_name, field.name)
  70. )
  71. if (
  72. isinstance(field, DurationField) and
  73. copy.lookup_name in ('year', 'iso_year', 'month', 'week', 'week_day', 'iso_week_day', 'quarter')
  74. ):
  75. raise ValueError(
  76. "Cannot extract component '%s' from DurationField '%s'."
  77. % (copy.lookup_name, field.name)
  78. )
  79. return copy
  80. class ExtractYear(Extract):
  81. lookup_name = 'year'
  82. class ExtractIsoYear(Extract):
  83. """Return the ISO-8601 week-numbering year."""
  84. lookup_name = 'iso_year'
  85. class ExtractMonth(Extract):
  86. lookup_name = 'month'
  87. class ExtractDay(Extract):
  88. lookup_name = 'day'
  89. class ExtractWeek(Extract):
  90. """
  91. Return 1-52 or 53, based on ISO-8601, i.e., Monday is the first of the
  92. week.
  93. """
  94. lookup_name = 'week'
  95. class ExtractWeekDay(Extract):
  96. """
  97. Return Sunday=1 through Saturday=7.
  98. To replicate this in Python: (mydatetime.isoweekday() % 7) + 1
  99. """
  100. lookup_name = 'week_day'
  101. class ExtractIsoWeekDay(Extract):
  102. """Return Monday=1 through Sunday=7, based on ISO-8601."""
  103. lookup_name = 'iso_week_day'
  104. class ExtractQuarter(Extract):
  105. lookup_name = 'quarter'
  106. class ExtractHour(Extract):
  107. lookup_name = 'hour'
  108. class ExtractMinute(Extract):
  109. lookup_name = 'minute'
  110. class ExtractSecond(Extract):
  111. lookup_name = 'second'
  112. DateField.register_lookup(ExtractYear)
  113. DateField.register_lookup(ExtractMonth)
  114. DateField.register_lookup(ExtractDay)
  115. DateField.register_lookup(ExtractWeekDay)
  116. DateField.register_lookup(ExtractIsoWeekDay)
  117. DateField.register_lookup(ExtractWeek)
  118. DateField.register_lookup(ExtractIsoYear)
  119. DateField.register_lookup(ExtractQuarter)
  120. TimeField.register_lookup(ExtractHour)
  121. TimeField.register_lookup(ExtractMinute)
  122. TimeField.register_lookup(ExtractSecond)
  123. DateTimeField.register_lookup(ExtractHour)
  124. DateTimeField.register_lookup(ExtractMinute)
  125. DateTimeField.register_lookup(ExtractSecond)
  126. ExtractYear.register_lookup(YearExact)
  127. ExtractYear.register_lookup(YearGt)
  128. ExtractYear.register_lookup(YearGte)
  129. ExtractYear.register_lookup(YearLt)
  130. ExtractYear.register_lookup(YearLte)
  131. ExtractIsoYear.register_lookup(YearExact)
  132. ExtractIsoYear.register_lookup(YearGt)
  133. ExtractIsoYear.register_lookup(YearGte)
  134. ExtractIsoYear.register_lookup(YearLt)
  135. ExtractIsoYear.register_lookup(YearLte)
  136. class Now(Func):
  137. template = 'CURRENT_TIMESTAMP'
  138. output_field = DateTimeField()
  139. def as_postgresql(self, compiler, connection, **extra_context):
  140. # PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
  141. # transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
  142. # other databases.
  143. return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()', **extra_context)
  144. class TruncBase(TimezoneMixin, Transform):
  145. kind = None
  146. tzinfo = None
  147. def __init__(self, expression, output_field=None, tzinfo=None, is_dst=None, **extra):
  148. self.tzinfo = tzinfo
  149. self.is_dst = is_dst
  150. super().__init__(expression, output_field=output_field, **extra)
  151. def as_sql(self, compiler, connection):
  152. if not connection.ops.extract_trunc_lookup_pattern.fullmatch(self.kind):
  153. raise ValueError("Invalid kind: %s" % self.kind)
  154. inner_sql, inner_params = compiler.compile(self.lhs)
  155. tzname = None
  156. if isinstance(self.lhs.output_field, DateTimeField):
  157. tzname = self.get_tzname()
  158. elif self.tzinfo is not None:
  159. raise ValueError('tzinfo can only be used with DateTimeField.')
  160. if isinstance(self.output_field, DateTimeField):
  161. sql = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname)
  162. elif isinstance(self.output_field, DateField):
  163. sql = connection.ops.date_trunc_sql(self.kind, inner_sql, tzname)
  164. elif isinstance(self.output_field, TimeField):
  165. sql = connection.ops.time_trunc_sql(self.kind, inner_sql, tzname)
  166. else:
  167. raise ValueError('Trunc only valid on DateField, TimeField, or DateTimeField.')
  168. return sql, inner_params
  169. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  170. copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
  171. field = copy.lhs.output_field
  172. # DateTimeField is a subclass of DateField so this works for both.
  173. assert isinstance(field, (DateField, TimeField)), (
  174. "%r isn't a DateField, TimeField, or DateTimeField." % field.name
  175. )
  176. # If self.output_field was None, then accessing the field will trigger
  177. # the resolver to assign it to self.lhs.output_field.
  178. if not isinstance(copy.output_field, (DateField, DateTimeField, TimeField)):
  179. raise ValueError('output_field must be either DateField, TimeField, or DateTimeField')
  180. # Passing dates or times to functions expecting datetimes is most
  181. # likely a mistake.
  182. class_output_field = self.__class__.output_field if isinstance(self.__class__.output_field, Field) else None
  183. output_field = class_output_field or copy.output_field
  184. has_explicit_output_field = class_output_field or field.__class__ is not copy.output_field.__class__
  185. if type(field) == DateField and (
  186. isinstance(output_field, DateTimeField) or copy.kind in ('hour', 'minute', 'second', 'time')):
  187. raise ValueError("Cannot truncate DateField '%s' to %s. " % (
  188. field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
  189. ))
  190. elif isinstance(field, TimeField) and (
  191. isinstance(output_field, DateTimeField) or
  192. copy.kind in ('year', 'quarter', 'month', 'week', 'day', 'date')):
  193. raise ValueError("Cannot truncate TimeField '%s' to %s. " % (
  194. field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
  195. ))
  196. return copy
  197. def convert_value(self, value, expression, connection):
  198. if isinstance(self.output_field, DateTimeField):
  199. if not settings.USE_TZ:
  200. pass
  201. elif value is not None:
  202. value = value.replace(tzinfo=None)
  203. value = timezone.make_aware(value, self.tzinfo, is_dst=self.is_dst)
  204. elif not connection.features.has_zoneinfo_database:
  205. raise ValueError(
  206. 'Database returned an invalid datetime value. Are time '
  207. 'zone definitions for your database installed?'
  208. )
  209. elif isinstance(value, datetime):
  210. if value is None:
  211. pass
  212. elif isinstance(self.output_field, DateField):
  213. value = value.date()
  214. elif isinstance(self.output_field, TimeField):
  215. value = value.time()
  216. return value
  217. class Trunc(TruncBase):
  218. def __init__(self, expression, kind, output_field=None, tzinfo=None, is_dst=None, **extra):
  219. self.kind = kind
  220. super().__init__(
  221. expression, output_field=output_field, tzinfo=tzinfo,
  222. is_dst=is_dst, **extra
  223. )
  224. class TruncYear(TruncBase):
  225. kind = 'year'
  226. class TruncQuarter(TruncBase):
  227. kind = 'quarter'
  228. class TruncMonth(TruncBase):
  229. kind = 'month'
  230. class TruncWeek(TruncBase):
  231. """Truncate to midnight on the Monday of the week."""
  232. kind = 'week'
  233. class TruncDay(TruncBase):
  234. kind = 'day'
  235. class TruncDate(TruncBase):
  236. kind = 'date'
  237. lookup_name = 'date'
  238. output_field = DateField()
  239. def as_sql(self, compiler, connection):
  240. # Cast to date rather than truncate to date.
  241. lhs, lhs_params = compiler.compile(self.lhs)
  242. tzname = self.get_tzname()
  243. sql = connection.ops.datetime_cast_date_sql(lhs, tzname)
  244. return sql, lhs_params
  245. class TruncTime(TruncBase):
  246. kind = 'time'
  247. lookup_name = 'time'
  248. output_field = TimeField()
  249. def as_sql(self, compiler, connection):
  250. # Cast to time rather than truncate to time.
  251. lhs, lhs_params = compiler.compile(self.lhs)
  252. tzname = self.get_tzname()
  253. sql = connection.ops.datetime_cast_time_sql(lhs, tzname)
  254. return sql, lhs_params
  255. class TruncHour(TruncBase):
  256. kind = 'hour'
  257. class TruncMinute(TruncBase):
  258. kind = 'minute'
  259. class TruncSecond(TruncBase):
  260. kind = 'second'
  261. DateTimeField.register_lookup(TruncDate)
  262. DateTimeField.register_lookup(TruncTime)