json.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537
  1. import json
  2. from django import forms
  3. from django.core import checks, exceptions
  4. from django.db import NotSupportedError, connections, router
  5. from django.db.models import lookups
  6. from django.db.models.lookups import PostgresOperatorLookup, Transform
  7. from django.utils.translation import gettext_lazy as _
  8. from . import Field
  9. from .mixins import CheckFieldDefaultMixin
  __all__ = ['JSONField']


  class JSONField(CheckFieldDefaultMixin, Field):
      """
      Model field that stores a JSON-serializable Python value.

      ``get_prep_value`` serializes with ``json.dumps(cls=encoder)`` and
      ``from_db_value`` deserializes with ``json.loads(cls=decoder)``.
      """
      empty_strings_allowed = False
      description = _('A JSON object')
      default_error_messages = {
          'invalid': _('Value must be valid JSON.'),
      }
      # NOTE(review): presumably consumed by CheckFieldDefaultMixin to suggest
      # a callable default such as ``dict`` — verify against the mixin.
      _default_hint = ('dict', '{}')

      def __init__(
          self, verbose_name=None, name=None, encoder=None, decoder=None,
          **kwargs,
      ):
          """
          :param encoder: optional callable passed as ``cls`` to ``json.dumps``.
          :param decoder: optional callable passed as ``cls`` to ``json.loads``.
          :raises ValueError: if ``encoder`` or ``decoder`` is truthy but not
              callable.
          """
          if encoder and not callable(encoder):
              raise ValueError('The encoder parameter must be a callable object.')
          if decoder and not callable(decoder):
              raise ValueError('The decoder parameter must be a callable object.')
          self.encoder = encoder
          self.decoder = decoder
          super().__init__(verbose_name, name, **kwargs)

      def check(self, **kwargs):
          """Run the inherited checks plus per-database JSON support checks."""
          errors = super().check(**kwargs)
          databases = kwargs.get('databases') or []
          errors.extend(self._check_supported(databases))
          return errors

      def _check_supported(self, databases):
          """
          Return a ``fields.E180`` error for every alias in ``databases`` that
          the router allows this model to migrate to, whose vendor matches the
          model's ``required_db_vendor`` (when set), and whose backend lacks
          ``supports_json_field``. No error is reported if the model lists
          ``'supports_json_field'`` in ``required_db_features``.
          """
          errors = []
          for db in databases:
              if not router.allow_migrate_model(db, self.model):
                  continue
              connection = connections[db]
              if (
                  self.model._meta.required_db_vendor and
                  self.model._meta.required_db_vendor != connection.vendor
              ):
                  continue
              if not (
                  'supports_json_field' in self.model._meta.required_db_features or
                  connection.features.supports_json_field
              ):
                  errors.append(
                      checks.Error(
                          '%s does not support JSONFields.'
                          % connection.display_name,
                          obj=self.model,
                          id='fields.E180',
                      )
                  )
          return errors

      def deconstruct(self):
          """Add ``encoder``/``decoder`` to the kwargs only when they were set."""
          name, path, args, kwargs = super().deconstruct()
          if self.encoder is not None:
              kwargs['encoder'] = self.encoder
          if self.decoder is not None:
              kwargs['decoder'] = self.decoder
          return name, path, args, kwargs

      def from_db_value(self, value, expression, connection):
          """Decode a database value; non-JSON text is returned unchanged."""
          if value is None:
              return value
          # Some backends (SQLite at least) extract non-string values in their
          # SQL datatypes.
          if isinstance(expression, KeyTransform) and not isinstance(value, str):
              return value
          try:
              return json.loads(value, cls=self.decoder)
          except json.JSONDecodeError:
              # Not valid JSON text: hand back the raw database value.
              return value

      def get_internal_type(self):
          return 'JSONField'

      def get_prep_value(self, value):
          """Serialize ``value`` to JSON text; ``None`` is passed through."""
          if value is None:
              return value
          return json.dumps(value, cls=self.encoder)

      def get_transform(self, name):
          # Registered transforms win; any other name is treated as a JSON key.
          transform = super().get_transform(name)
          if transform:
              return transform
          return KeyTransformFactory(name)

      def validate(self, value, model_instance):
          """
          Run the inherited validation, then ensure ``value`` is serializable.

          :raises ValidationError: (code ``'invalid'``) when ``json.dumps``
              with the configured encoder rejects the value.
          """
          super().validate(value, model_instance)
          try:
              json.dumps(value, cls=self.encoder)
          except (TypeError, ValueError):
              # TypeError: object of an unserializable type. ValueError:
              # circular reference, or an out-of-range float rejected by an
              # encoder configured with allow_nan=False.
              raise exceptions.ValidationError(
                  self.error_messages['invalid'],
                  code='invalid',
                  params={'value': value},
              )

      def value_to_string(self, obj):
          return self.value_from_object(obj)

      def formfield(self, **kwargs):
          """Default to ``forms.JSONField`` using this field's encoder/decoder."""
          return super().formfield(**{
              'form_class': forms.JSONField,
              'encoder': self.encoder,
              'decoder': self.decoder,
              **kwargs,
          })
  def compile_json_path(key_transforms, include_root=True):
      """
      Build a JSON path string such as ``$."key"[0]`` from a sequence of keys.

      A key accepted by ``int()`` becomes an array subscript ``[n]``. Any other
      key becomes a member access whose name is JSON-quoted with
      ``json.dumps``. The ``$`` root marker is emitted only when
      ``include_root`` is true.
      """
      segments = ['$'] if include_root else []
      for key in key_transforms:
          try:
              index = int(key)
          except ValueError:  # not an integer: treat as an object member
              segments.extend(('.', json.dumps(key)))
          else:
              segments.append('[%s]' % index)
      return ''.join(segments)
  class DataContains(PostgresOperatorLookup):
      """
      ``contains`` lookup: the left JSON document contains the right one.

      Non-PostgreSQL backends render ``JSON_CONTAINS(lhs, rhs)``.
      ``postgres_operator`` is presumably what the PostgresOperatorLookup base
      uses on PostgreSQL.
      """
      lookup_name = 'contains'
      postgres_operator = '@>'

      def as_sql(self, compiler, connection):
          if connection.features.supports_json_field_contains:
              lhs_sql, lhs_params = self.process_lhs(compiler, connection)
              rhs_sql, rhs_params = self.process_rhs(compiler, connection)
              return (
                  'JSON_CONTAINS(%s, %s)' % (lhs_sql, rhs_sql),
                  (*lhs_params, *rhs_params),
              )
          raise NotSupportedError(
              'contains lookup is not supported on this database backend.'
          )
  class ContainedBy(PostgresOperatorLookup):
      """
      ``contained_by`` lookup: the left JSON document is contained in the
      right one.

      This is the mirror of ``contains``: it renders
      ``JSON_CONTAINS(rhs, lhs)``, with the rhs params placed first to match.
      """
      lookup_name = 'contained_by'
      postgres_operator = '<@'

      def as_sql(self, compiler, connection):
          if connection.features.supports_json_field_contains:
              lhs_sql, lhs_params = self.process_lhs(compiler, connection)
              rhs_sql, rhs_params = self.process_rhs(compiler, connection)
              return (
                  'JSON_CONTAINS(%s, %s)' % (rhs_sql, lhs_sql),
                  (*rhs_params, *lhs_params),
              )
          raise NotSupportedError(
              'contained_by lookup is not supported on this database backend.'
          )
  class HasKeyLookup(PostgresOperatorLookup):
      """
      Base class for lookups that test whether a JSON document contains keys.

      On non-PostgreSQL backends, ``as_sql`` formats ``template`` with the
      compiled left-hand side. It then produces one JSON path parameter per
      requested key. When ``logical_operator`` is set, the condition is
      repeated once per key and joined with it.
      """
      logical_operator = None

      def as_sql(self, compiler, connection, template=None):
          # Process JSON path from the left-hand side.
          if isinstance(self.lhs, KeyTransform):
              lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(compiler, connection)
              lhs_json_path = compile_json_path(lhs_key_transforms)
          else:
              lhs, lhs_params = self.process_lhs(compiler, connection)
              lhs_json_path = '$'
          sql = template % lhs
          # Process JSON path from the right-hand side.
          rhs = self.rhs
          rhs_params = []
          if not isinstance(rhs, (list, tuple)):
              rhs = [rhs]
          for key in rhs:
              if isinstance(key, KeyTransform):
                  *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
              else:
                  rhs_key_transforms = [key]
              rhs_params.append('%s%s' % (
                  lhs_json_path,
                  compile_json_path(rhs_key_transforms, include_root=False),
              ))
          # Add condition for each key.
          if self.logical_operator:
              sql = '(%s)' % self.logical_operator.join([sql] * len(rhs_params))
          return sql, tuple(lhs_params) + tuple(rhs_params)

      def as_mysql(self, compiler, connection):
          return self.as_sql(compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)")

      def as_oracle(self, compiler, connection):
          sql, params = self.as_sql(compiler, connection, template="JSON_EXISTS(%s, '%%s')")
          # Add paths directly into SQL because path expressions cannot be passed
          # as bind variables on Oracle. Each path lands inside a single-quoted
          # SQL literal, and json.dumps() does not escape "'". Double every
          # single quote (standard SQL literal escaping) so a key name cannot
          # terminate the literal and inject SQL.
          escaped = tuple(
              param.replace("'", "''") if isinstance(param, str) else param
              for param in params
          )
          return sql % escaped, []

      def as_postgresql(self, compiler, connection):
          # A KeyTransform rhs is split so that the leading keys become
          # KeyTransforms on the lhs and only the final key is tested.
          if isinstance(self.rhs, KeyTransform):
              *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
              for key in rhs_key_transforms[:-1]:
                  self.lhs = KeyTransform(key, self.lhs)
              self.rhs = rhs_key_transforms[-1]
          return super().as_postgresql(compiler, connection)

      def as_sqlite(self, compiler, connection):
          return self.as_sql(compiler, connection, template='JSON_TYPE(%s, %%s) IS NOT NULL')
  class HasKey(HasKeyLookup):
      """``has_key``: the JSON document contains the single given key."""
      lookup_name = 'has_key'
      postgres_operator = '?'
      prepare_rhs = False


  class HasKeys(HasKeyLookup):
      """``has_keys``: every key in the given list is present (joined by AND)."""
      lookup_name = 'has_keys'
      postgres_operator = '?&'
      logical_operator = ' AND '

      def get_prep_lookup(self):
          # Keys are always looked up by their string form.
          return list(map(str, self.rhs))


  class HasAnyKeys(HasKeys):
      """``has_any_keys``: at least one of the given keys is present (OR)."""
      lookup_name = 'has_any_keys'
      postgres_operator = '?|'
      logical_operator = ' OR '
  class JSONExact(lookups.Exact):
      """
      ``exact`` lookup for JSON values.

      A ``None`` right-hand side is sent as the JSON text ``'null'``. On SQLite
      the lhs is then wrapped in ``JSON_TYPE``. On MySQL every rhs placeholder
      is wrapped in ``JSON_EXTRACT``.
      """
      can_use_none_as_rhs = True

      def process_lhs(self, compiler, connection):
          lhs, lhs_params = super().process_lhs(compiler, connection)
          if connection.vendor != 'sqlite':
              return lhs, lhs_params
          rhs, rhs_params = super().process_rhs(compiler, connection)
          if rhs == '%s' and rhs_params == [None]:
              # Use JSON_TYPE instead of JSON_EXTRACT for NULLs.
              lhs = "JSON_TYPE(%s, '$')" % lhs
          return lhs, lhs_params

      def process_rhs(self, compiler, connection):
          rhs, rhs_params = super().process_rhs(compiler, connection)
          # Treat None lookup values as null.
          rhs_is_none = rhs == '%s' and rhs_params == [None]
          if rhs_is_none:
              rhs_params = ['null']
          if connection.vendor == 'mysql':
              wrappers = ("JSON_EXTRACT(%s, '$')",) * len(rhs_params)
              rhs = rhs % wrappers
          return rhs, rhs_params
  # Attach the document-level lookups to JSONField, in a fixed order.
  for _json_lookup in (
      DataContains,
      ContainedBy,
      HasKey,
      HasKeys,
      HasAnyKeys,
      JSONExact,
  ):
      JSONField.register_lookup(_json_lookup)
  del _json_lookup
  class KeyTransform(Transform):
      """
      Transform that extracts the value stored under ``key_name`` from a JSON
      document. Chained instances address nested keys.
      """
      postgres_operator = '->'
      postgres_nested_operator = '#>'

      def __init__(self, key_name, *args, **kwargs):
          super().__init__(*args, **kwargs)
          self.key_name = str(key_name)

      def preprocess_lhs(self, compiler, connection):
          """
          Walk down chained KeyTransforms and return ``(sql, params, keys)``.

          ``sql`` and ``params`` are the compiled innermost non-KeyTransform
          expression. ``keys`` lists the key names from the outermost level
          down to this transform. On Oracle, ``%`` in key names is doubled.
          """
          key_transforms = [self.key_name]
          previous = self.lhs
          while isinstance(previous, KeyTransform):
              key_transforms.insert(0, previous.key_name)
              previous = previous.lhs
          lhs, params = compiler.compile(previous)
          if connection.vendor == 'oracle':
              # Escape string-formatting.
              key_transforms = [key.replace('%', '%%') for key in key_transforms]
          return lhs, params, key_transforms

      def as_mysql(self, compiler, connection):
          lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
          json_path = compile_json_path(key_transforms)
          return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,)

      def as_oracle(self, compiler, connection):
          lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
          json_path = compile_json_path(key_transforms)
          # The path is inlined into single-quoted SQL literals rather than
          # bound as a parameter, and json.dumps() leaves "'" unescaped. Double
          # single quotes so a key name cannot terminate the literal and inject
          # SQL.
          json_path = json_path.replace("'", "''")
          return (
              "COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))" %
              ((lhs, json_path) * 2)
          ), tuple(params) * 2

      def as_postgresql(self, compiler, connection):
          lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
          if len(key_transforms) > 1:
              # Nested path: bind the whole key list for the #> operator.
              sql = '(%s %s %%s)' % (lhs, self.postgres_nested_operator)
              return sql, tuple(params) + (key_transforms,)
          # Integer-like keys are bound as ints (array index), others as text.
          try:
              lookup = int(self.key_name)
          except ValueError:
              lookup = self.key_name
          return '(%s %s %%s)' % (lhs, self.postgres_operator), tuple(params) + (lookup,)

      def as_sqlite(self, compiler, connection):
          lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
          json_path = compile_json_path(key_transforms)
          return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,)
  class KeyTextTransform(KeyTransform):
      """KeyTransform that uses PostgreSQL's ``->>`` / ``#>>`` text operators."""
      postgres_nested_operator = '#>>'
      postgres_operator = '->>'
  class KeyTransformTextLookupMixin:
      """
      Mixin for lookups that need a text lhs from a JSONField key lookup.

      The incoming KeyTransform is swapped for an equivalent KeyTextTransform.
      On PostgreSQL the lookup therefore runs on the ``->>`` text result
      instead of on the key value cast to text.

      :raises TypeError: if the lhs is not a KeyTransform.
      """
      def __init__(self, key_transform, *args, **kwargs):
          if isinstance(key_transform, KeyTransform):
              text_transform = KeyTextTransform(
                  key_transform.key_name,
                  *key_transform.source_expressions,
                  **key_transform.extra,
              )
              super().__init__(text_transform, *args, **kwargs)
              return
          raise TypeError(
              'Transform should be an instance of KeyTransform in order to '
              'use this lookup.'
          )
  class CaseInsensitiveMixin:
      """
      Make JSON value comparisons case-insensitive on MySQL.

      MySQL compares strings in JSON context with the binary utf8mb4_bin
      collation, which is case-sensitive. On that vendor both sides of the
      lookup are therefore wrapped in ``LOWER()``. Other vendors are left
      unchanged.
      """
      def process_lhs(self, compiler, connection):
          sql, params = super().process_lhs(compiler, connection)
          if connection.vendor != 'mysql':
              return sql, params
          return 'LOWER(%s)' % sql, params

      def process_rhs(self, compiler, connection):
          sql, params = super().process_rhs(compiler, connection)
          if connection.vendor != 'mysql':
              return sql, params
          return 'LOWER(%s)' % sql, params
  class KeyTransformIsNull(lookups.IsNull):
      """
      ``isnull`` on a JSON key. ``key__isnull=False`` means the same as
      ``has_key='key'``. On Oracle and SQLite the SQL is derived from HasKey.
      """
      def as_oracle(self, compiler, connection):
          has_key = HasKey(self.lhs.lhs, self.lhs.key_name)
          sql, params = has_key.as_oracle(compiler, connection)
          if self.rhs:
              # Column doesn't have a key or IS NULL.
              lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
              return '(NOT %s OR %s IS NULL)' % (sql, lhs), tuple(params) + tuple(lhs_params)
          return sql, params

      def as_sqlite(self, compiler, connection):
          if self.rhs:
              template = 'JSON_TYPE(%s, %%s) IS NULL'
          else:
              template = 'JSON_TYPE(%s, %%s) IS NOT NULL'
          has_key = HasKey(self.lhs.lhs, self.lhs.key_name)
          return has_key.as_sql(compiler, connection, template=template)
  class KeyTransformIn(lookups.In):
      """
      ``in`` lookup on a JSON key.

      On backends without a native JSON field, each plain (non-expression)
      parameter is wrapped so it is compared as a JSON value. On MariaDB,
      every parameter is additionally wrapped in ``JSON_UNQUOTE``.
      """
      def resolve_expression_parameter(self, compiler, connection, sql, param):
          sql, params = super().resolve_expression_parameter(
              compiler, connection, sql, param,
          )
          is_plain_value = not hasattr(param, 'as_sql')
          if is_plain_value and not connection.features.has_native_json_field:
              vendor = connection.vendor
              if vendor == 'oracle':
                  decoded = json.loads(param)
                  if isinstance(decoded, (list, dict)):
                      func_name = 'JSON_QUERY'
                  else:
                      func_name = 'JSON_VALUE'
                  sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')" % func_name
              elif vendor in {'sqlite', 'mysql'}:
                  sql = "JSON_EXTRACT(%s, '$')"
          if connection.vendor == 'mysql' and connection.mysql_is_mariadb:
              sql = 'JSON_UNQUOTE(%s)' % sql
          return sql, params
  class KeyTransformExact(JSONExact):
      """
      ``exact`` lookup on a JSON key.

      A JSON ``null`` rhs is handled via JSON_TYPE on SQLite and via
      HasKey + IS NULL on Oracle. Non-null rhs placeholders are wrapped
      per vendor.
      """
      def process_lhs(self, compiler, connection):
          lhs, lhs_params = super().process_lhs(compiler, connection)
          if connection.vendor != 'sqlite':
              return lhs, lhs_params
          rhs, rhs_params = super().process_rhs(compiler, connection)
          if rhs == '%s' and rhs_params == ['null']:
              # JSON null: test the key's JSON type instead of its extracted value.
              inner_sql = self.lhs.preprocess_lhs(compiler, connection)[0]
              lhs = 'JSON_TYPE(%s, %%s)' % inner_sql
          return lhs, lhs_params

      def process_rhs(self, compiler, connection):
          if isinstance(self.rhs, KeyTransform):
              return super(lookups.Exact, self).process_rhs(compiler, connection)
          rhs, rhs_params = super().process_rhs(compiler, connection)
          vendor = connection.vendor
          if vendor == 'oracle':
              template = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"

              def wrap(encoded):
                  # Containers need JSON_QUERY; scalars need JSON_VALUE.
                  decoded = json.loads(encoded)
                  if isinstance(decoded, (list, dict)):
                      return template % 'JSON_QUERY'
                  return template % 'JSON_VALUE'

              rhs = rhs % tuple(wrap(value) for value in rhs_params)
          elif vendor == 'sqlite':
              wrappers = tuple(
                  '%s' if value == 'null' else "JSON_EXTRACT(%s, '$')"
                  for value in rhs_params
              )
              rhs = rhs % wrappers
          return rhs, rhs_params

      def as_oracle(self, compiler, connection):
          _, rhs_params = super().process_rhs(compiler, connection)
          if rhs_params != ['null']:
              return super().as_sql(compiler, connection)
          # Field has key and it's NULL.
          has_key_lookup = HasKey(self.lhs.lhs, self.lhs.key_name)
          has_key_sql, has_key_params = has_key_lookup.as_oracle(compiler, connection)
          is_null_lookup = self.lhs.get_lookup('isnull')(self.lhs, True)
          is_null_sql, is_null_params = is_null_lookup.as_sql(compiler, connection)
          return (
              '%s AND %s' % (has_key_sql, is_null_sql),
              (*has_key_params, *is_null_params),
          )
  class KeyTransformIExact(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact):
      """``iexact`` on a JSON key's text value."""


  class KeyTransformIContains(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains):
      """``icontains`` on a JSON key's text value."""


  class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
      """``startswith`` on a JSON key's text value."""


  class KeyTransformIStartsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith):
      """``istartswith`` on a JSON key's text value."""


  class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
      """``endswith`` on a JSON key's text value."""


  class KeyTransformIEndsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith):
      """``iendswith`` on a JSON key's text value."""


  class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
      """``regex`` on a JSON key's text value."""


  class KeyTransformIRegex(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex):
      """``iregex`` on a JSON key's text value."""
  class KeyTransformNumericLookupMixin:
      """
      Mixin for ordering comparisons on a JSON key.

      When the backend has no native JSON field, the rhs params are
      JSON-decoded so the comparison receives the plain values.
      """
      def process_rhs(self, compiler, connection):
          rhs, rhs_params = super().process_rhs(compiler, connection)
          if connection.features.has_native_json_field:
              return rhs, rhs_params
          return rhs, [json.loads(param) for param in rhs_params]
  class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
      """``lt`` on a JSON key."""


  class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
      """``lte`` on a JSON key."""


  class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
      """``gt`` on a JSON key."""


  class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
      """``gte`` on a JSON key."""
  # Attach the key-level lookups to KeyTransform, in a fixed order.
  for _key_lookup in (
      KeyTransformIn,
      KeyTransformExact,
      KeyTransformIExact,
      KeyTransformIsNull,
      KeyTransformIContains,
      KeyTransformStartsWith,
      KeyTransformIStartsWith,
      KeyTransformEndsWith,
      KeyTransformIEndsWith,
      KeyTransformRegex,
      KeyTransformIRegex,
      KeyTransformLt,
      KeyTransformLte,
      KeyTransformGt,
      KeyTransformGte,
  ):
      KeyTransform.register_lookup(_key_lookup)
  del _key_lookup
  class KeyTransformFactory:
      """Callable that builds a KeyTransform for a fixed key name."""

      def __init__(self, key_name):
          self.key_name = key_name

      def __call__(self, *args, **kwargs):
          # Remaining arguments are forwarded unchanged to KeyTransform.
          key = self.key_name
          return KeyTransform(key, *args, **kwargs)