array.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. import json
  2. from django.contrib.postgres import lookups
  3. from django.contrib.postgres.forms import SimpleArrayField
  4. from django.contrib.postgres.validators import ArrayMaxLengthValidator
  5. from django.core import checks, exceptions
  6. from django.db.models import Field, Func, IntegerField, Transform, Value
  7. from django.db.models.fields.mixins import CheckFieldDefaultMixin
  8. from django.db.models.lookups import Exact, In
  9. from django.utils.translation import gettext_lazy as _
  10. from ..utils import prefix_validation_error
  11. from .utils import AttributeSetter
  12. __all__ = ['ArrayField']
  class ArrayField(CheckFieldDefaultMixin, Field):
      """
      Model field holding a list of values, each handled by ``base_field``.

      ``size`` optionally caps the number of items; it is used both for
      validation and in the generated database column type.
      """

      empty_strings_allowed = False
      default_error_messages = {
          'item_invalid': _('Item %(nth)s in the array did not validate:'),
          'nested_array_mismatch': _('Nested arrays must have the same length.'),
      }
      _default_hint = ('list', '[]')

      def __init__(self, base_field, size=None, **kwargs):
          """
          Store the element field and optional size, then defer to ``Field``.

          A truthy ``size`` appends an ``ArrayMaxLengthValidator`` to this
          instance's default validators.
          """
          self.base_field = base_field
          self.size = size
          if self.size:
              length_cap = ArrayMaxLengthValidator(self.size)
              self.default_validators = [*self.default_validators, length_cap]
          # Only expose from_db_value() when the element field defines one, so
          # fields without a converter do not pay for a per-row call.
          if hasattr(self.base_field, 'from_db_value'):
              self.from_db_value = self._from_db_value
          super().__init__(**kwargs)

      @property
      def model(self):
          """Return the model stored by the setter, or raise AttributeError."""
          try:
              return self.__dict__['model']
          except KeyError:
              # Surface a missing value as AttributeError so hasattr() and
              # getattr() treat it like any other absent attribute.
              raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__)

      @model.setter
      def model(self, model):
          # Keep the element field attached to the same model as the array.
          self.__dict__['model'] = model
          self.base_field.model = model

      @classmethod
      def _choices_is_value(cls, value):
          """Treat lists and tuples as single choice values, else defer to super."""
          if isinstance(value, (list, tuple)):
              return True
          return super()._choices_is_value(value)

      def check(self, **kwargs):
          """
          Run the standard field checks plus checks on the element field.

          A related element field yields ``postgres.E002``; any errors reported
          by the element field's own ``check()`` are folded into a single
          ``postgres.E001`` error.
          """
          errors = super().check(**kwargs)
          if self.base_field.remote_field:
              errors.append(
                  checks.Error(
                      'Base field for array cannot be a related field.',
                      obj=self,
                      id='postgres.E002'
                  )
              )
              return errors
          # Remove the field name checks as they are not needed here.
          base_errors = self.base_field.check()
          if base_errors:
              messages = '\n '.join(
                  '%s (%s)' % (error.msg, error.id) for error in base_errors
              )
              errors.append(
                  checks.Error(
                      'Base field for array has errors:\n %s' % messages,
                      obj=self,
                      id='postgres.E001'
                  )
              )
          return errors

      def set_attributes_from_name(self, name):
          """Apply ``name`` to this field and to the element field."""
          super().set_attributes_from_name(name)
          self.base_field.set_attributes_from_name(name)

      @property
      def description(self):
          """Human-readable description built from the element field's."""
          return 'Array of %s' % self.base_field.description

      def db_type(self, connection):
          """Return the element column type followed by ``[size]`` (or ``[]``)."""
          return '%s[%s]' % (self.base_field.db_type(connection), self.size or '')

      def cast_db_type(self, connection):
          """Return the element cast type followed by ``[size]`` (or ``[]``)."""
          return '%s[%s]' % (self.base_field.cast_db_type(connection), self.size or '')

      def get_placeholder(self, value, compiler, connection):
          """Return a ``%s`` placeholder cast to this field's database type."""
          column_type = self.db_type(connection)
          return '%s::{}'.format(column_type)

      def get_db_prep_value(self, value, connection, prepared=False):
          """
          Prepare each item of a list/tuple with the element field.

          Any other value is returned unchanged. Items are always prepared
          with ``prepared=False``, whatever this method itself was given.
          """
          if not isinstance(value, (list, tuple)):
              return value
          prep_item = self.base_field.get_db_prep_value
          return [prep_item(item, connection, prepared=False) for item in value]

      def deconstruct(self):
          """
          Return the ``(name, path, args, kwargs)`` deconstruction.

          The module path is shortened to the package-level import path and
          a clone of the element field plus the size are added to ``kwargs``.
          """
          name, path, args, kwargs = super().deconstruct()
          if path == 'django.contrib.postgres.fields.array.ArrayField':
              path = 'django.contrib.postgres.fields.ArrayField'
          kwargs['base_field'] = self.base_field.clone()
          kwargs['size'] = self.size
          return name, path, args, kwargs

      def to_python(self, value):
          """
          Decode a JSON string into a list of converted items.

          Non-string values are returned unchanged.
          """
          if not isinstance(value, str):
              return value
          # A string is assumed to come from deserialization.
          decoded = json.loads(value)
          return [self.base_field.to_python(item) for item in decoded]

      def _from_db_value(self, value, expression, connection):
          """Run the element field's ``from_db_value()`` over each item."""
          if value is None:
              return value
          convert = self.base_field.from_db_value
          return [convert(item, expression, connection) for item in value]

      def value_to_string(self, obj):
          """
          Serialize the array stored on ``obj`` to a JSON string.

          ``None`` items stay ``None``; every other item is serialized by the
          element field's ``value_to_string()``.
          """
          items = self.value_from_object(obj)
          element_field = self.base_field
          # AttributeSetter wraps each item under the element field's attname;
          # presumably so value_to_string() can read it like a model instance.
          serialized = [
              None if item is None else element_field.value_to_string(
                  AttributeSetter(element_field.attname, item)
              )
              for item in items
          ]
          return json.dumps(serialized)

      def get_transform(self, name):
          """
          Resolve ``name`` to a transform.

          Transforms found by the parent class win. Otherwise ``"<int>"``
          gives an index transform and ``"<int>_<int>"`` a slice transform.
          Returns ``None`` when ``name`` matches none of these.
          """
          found = super().get_transform(name)
          if found:
              return found
          if '_' not in name:
              try:
                  position = int(name)
              except ValueError:
                  return None
              # postgres uses 1-indexing
              return IndexTransformFactory(position + 1, self.base_field)
          try:
              lower, upper = name.split('_')
              # Only the start is shifted: the end is passed through as-is
              # because postgres slice bounds differ from Python's.
              lower, upper = int(lower) + 1, int(upper)
          except ValueError:
              return None
          return SliceTransformFactory(lower, upper)

      def validate(self, value, model_instance):
          """
          Validate the array, then each item with the element field.

          Raises ``ValidationError`` with code ``item_invalid`` (prefixed with
          the 1-based position) for the first failing item, or with code
          ``nested_array_mismatch`` when nested arrays differ in length.
          """
          super().validate(value, model_instance)
          for nth, part in enumerate(value, start=1):
              try:
                  self.base_field.validate(part, model_instance)
              except exceptions.ValidationError as error:
                  raise prefix_validation_error(
                      error,
                      prefix=self.error_messages['item_invalid'],
                      code='item_invalid',
                      params={'nth': nth},
                  )
          if isinstance(self.base_field, ArrayField) and len({len(i) for i in value}) > 1:
              raise exceptions.ValidationError(
                  self.error_messages['nested_array_mismatch'],
                  code='nested_array_mismatch',
              )

      def run_validators(self, value):
          """
          Run this field's validators, then the element field's on each item.

          The first failing item raises ``ValidationError`` with code
          ``item_invalid``, prefixed with its 1-based position.
          """
          super().run_validators(value)
          for nth, part in enumerate(value, start=1):
              try:
                  self.base_field.run_validators(part)
              except exceptions.ValidationError as error:
                  raise prefix_validation_error(
                      error,
                      prefix=self.error_messages['item_invalid'],
                      code='item_invalid',
                      params={'nth': nth},
                  )

      def formfield(self, **kwargs):
          """Build a ``SimpleArrayField`` by default; ``kwargs`` override defaults."""
          defaults = {
              'form_class': SimpleArrayField,
              'base_field': self.base_field.formfield(),
              'max_length': self.size,
          }
          defaults.update(kwargs)
          return super().formfield(**defaults)
  class ArrayRHSMixin:
      """
      Lookup mixin that builds and casts array right-hand sides.

      A list or tuple right-hand side becomes an ``ARRAY[...]`` expression,
      and the compiled right-hand side is cast to the left-hand side's
      array type.
      """

      def __init__(self, lhs, rhs):
          if isinstance(rhs, (tuple, list)):
              def as_expression(item):
                  # Expressions pass through; plain values are prepared by the
                  # element field and wrapped in Value().
                  if hasattr(item, 'resolve_expression'):
                      return item
                  field = lhs.output_field
                  return Value(field.base_field.get_prep_value(item))

              rhs = Func(
                  *[as_expression(item) for item in rhs],
                  function='ARRAY',
                  template='%(function)s[%(expressions)s]',
              )
          super().__init__(lhs, rhs)

      def process_rhs(self, compiler, connection):
          """Return the parent's right-hand SQL with a ``::<array type>`` cast."""
          sql, sql_params = super().process_rhs(compiler, connection)
          array_type = self.lhs.output_field.cast_db_type(connection)
          return '%s::%s' % (sql, array_type), sql_params
  @ArrayField.register_lookup
  class ArrayContains(ArrayRHSMixin, lookups.DataContains):
      """``lookups.DataContains`` with ArrayRHSMixin's array right-hand side."""
  @ArrayField.register_lookup
  class ArrayContainedBy(ArrayRHSMixin, lookups.ContainedBy):
      """``lookups.ContainedBy`` with ArrayRHSMixin's array right-hand side."""
  @ArrayField.register_lookup
  class ArrayExact(ArrayRHSMixin, Exact):
      """``Exact`` lookup with ArrayRHSMixin's array right-hand side."""
  @ArrayField.register_lookup
  class ArrayOverlap(ArrayRHSMixin, lookups.Overlap):
      """``lookups.Overlap`` with ArrayRHSMixin's array right-hand side."""
  @ArrayField.register_lookup
  class ArrayLenTransform(Transform):
      """``len`` transform: the length of an array's first dimension."""

      lookup_name = 'len'
      output_field = IntegerField()

      def as_sql(self, compiler, connection):
          """
          Return ``(sql, params)`` computing the array length.

          The SQL yields NULL for a NULL array and 0 for an empty one.
          """
          lhs, params = compiler.compile(self.lhs)
          # Distinguish NULL and empty arrays
          sql = (
              'CASE WHEN %(lhs)s IS NULL THEN NULL ELSE '
              'coalesce(array_length(%(lhs)s, 1), 0) END'
          ) % {'lhs': lhs}
          # The compiled lhs appears twice in the SQL, so its parameters must
          # be supplied twice to keep placeholders and parameters aligned.
          return sql, params * 2
  @ArrayField.register_lookup
  class ArrayInLookup(In):
      """``In`` lookup whose candidate values are themselves arrays."""

      def get_prep_lookup(self):
          """
          Return the prepared values with plain sequences turned into tuples.

          An expression right-hand side, and expression items inside the
          list, are returned untouched.
          """
          values = super().get_prep_lookup()
          if hasattr(values, 'resolve_expression'):
              return values
          # In.process_rhs() expects values to be hashable, so convert lists
          # to tuples.
          return [
              item if hasattr(item, 'resolve_expression') else tuple(item)
              for item in values
          ]
  class IndexTransform(Transform):
      """Transform selecting one array element by its database index."""

      def __init__(self, index, base_field, *args, **kwargs):
          """
          ``index`` is passed to the database unchanged. ``base_field`` is
          the field describing the selected element. Remaining arguments go
          to ``Transform``.
          """
          super().__init__(*args, **kwargs)
          self.index = index
          self.base_field = base_field

      def as_sql(self, compiler, connection):
          """Return ``(sql, params)`` subscripting the compiled lhs."""
          lhs, params = compiler.compile(self.lhs)
          # Unpack rather than concatenate: the compiled params may be a
          # tuple, and ``tuple + list`` raises TypeError. The result is
          # always a list, as before.
          return '%s[%%s]' % lhs, [*params, self.index]

      @property
      def output_field(self):
          """The element field: indexing an array yields one element."""
          return self.base_field
  class IndexTransformFactory:
      """Callable that builds ``IndexTransform`` objects for a fixed index."""

      def __init__(self, index, base_field):
          self.index, self.base_field = index, base_field

      def __call__(self, *args, **kwargs):
          """Create an ``IndexTransform``, forwarding any extra arguments."""
          return IndexTransform(self.index, self.base_field, *args, **kwargs)
  class SliceTransform(Transform):
      """Transform selecting an array slice between two database bounds."""

      def __init__(self, start, end, *args, **kwargs):
          """
          ``start`` and ``end`` are passed to the database unchanged.
          Remaining arguments go to ``Transform``.
          """
          super().__init__(*args, **kwargs)
          self.start = start
          self.end = end

      def as_sql(self, compiler, connection):
          """Return ``(sql, params)`` slicing the compiled lhs."""
          lhs, params = compiler.compile(self.lhs)
          # Unpack rather than concatenate: the compiled params may be a
          # tuple, and ``tuple + list`` raises TypeError. The result is
          # always a list, as before.
          return '%s[%%s:%%s]' % lhs, [*params, self.start, self.end]
  class SliceTransformFactory:
      """Callable that builds ``SliceTransform`` objects for fixed bounds."""

      def __init__(self, start, end):
          self.start, self.end = start, end

      def __call__(self, *args, **kwargs):
          """Create a ``SliceTransform``, forwarding any extra arguments."""
          return SliceTransform(self.start, self.end, *args, **kwargs)