# search.py (10 KB)
  1. import psycopg2
  2. from django.db.models import (
  3. CharField, Expression, Field, FloatField, Func, Lookup, TextField, Value,
  4. )
  5. from django.db.models.expressions import CombinedExpression
  6. from django.db.models.functions import Cast, Coalesce
  7. class SearchVectorExact(Lookup):
  8. lookup_name = 'exact'
  9. def process_rhs(self, qn, connection):
  10. if not isinstance(self.rhs, (SearchQuery, CombinedSearchQuery)):
  11. config = getattr(self.lhs, 'config', None)
  12. self.rhs = SearchQuery(self.rhs, config=config)
  13. rhs, rhs_params = super().process_rhs(qn, connection)
  14. return rhs, rhs_params
  15. def as_sql(self, qn, connection):
  16. lhs, lhs_params = self.process_lhs(qn, connection)
  17. rhs, rhs_params = self.process_rhs(qn, connection)
  18. params = lhs_params + rhs_params
  19. return '%s @@ %s' % (lhs, rhs), params
  20. class SearchVectorField(Field):
  21. def db_type(self, connection):
  22. return 'tsvector'
  23. class SearchQueryField(Field):
  24. def db_type(self, connection):
  25. return 'tsquery'
  26. class SearchConfig(Expression):
  27. def __init__(self, config):
  28. super().__init__()
  29. if not hasattr(config, 'resolve_expression'):
  30. config = Value(config)
  31. self.config = config
  32. @classmethod
  33. def from_parameter(cls, config):
  34. if config is None or isinstance(config, cls):
  35. return config
  36. return cls(config)
  37. def get_source_expressions(self):
  38. return [self.config]
  39. def set_source_expressions(self, exprs):
  40. self.config, = exprs
  41. def as_sql(self, compiler, connection):
  42. sql, params = compiler.compile(self.config)
  43. return '%s::regconfig' % sql, params
  44. class SearchVectorCombinable:
  45. ADD = '||'
  46. def _combine(self, other, connector, reversed):
  47. if not isinstance(other, SearchVectorCombinable):
  48. raise TypeError(
  49. 'SearchVector can only be combined with other SearchVector '
  50. 'instances, got %s.' % type(other).__name__
  51. )
  52. if reversed:
  53. return CombinedSearchVector(other, connector, self, self.config)
  54. return CombinedSearchVector(self, connector, other, self.config)
  55. class SearchVector(SearchVectorCombinable, Func):
  56. function = 'to_tsvector'
  57. arg_joiner = " || ' ' || "
  58. output_field = SearchVectorField()
  59. def __init__(self, *expressions, config=None, weight=None):
  60. super().__init__(*expressions)
  61. self.config = SearchConfig.from_parameter(config)
  62. if weight is not None and not hasattr(weight, 'resolve_expression'):
  63. weight = Value(weight)
  64. self.weight = weight
  65. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  66. resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
  67. if self.config:
  68. resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
  69. return resolved
  70. def as_sql(self, compiler, connection, function=None, template=None):
  71. clone = self.copy()
  72. clone.set_source_expressions([
  73. Coalesce(
  74. expression
  75. if isinstance(expression.output_field, (CharField, TextField))
  76. else Cast(expression, TextField()),
  77. Value('')
  78. ) for expression in clone.get_source_expressions()
  79. ])
  80. config_sql = None
  81. config_params = []
  82. if template is None:
  83. if clone.config:
  84. config_sql, config_params = compiler.compile(clone.config)
  85. template = '%(function)s(%(config)s, %(expressions)s)'
  86. else:
  87. template = clone.template
  88. sql, params = super(SearchVector, clone).as_sql(
  89. compiler, connection, function=function, template=template,
  90. config=config_sql,
  91. )
  92. extra_params = []
  93. if clone.weight:
  94. weight_sql, extra_params = compiler.compile(clone.weight)
  95. sql = 'setweight({}, {})'.format(sql, weight_sql)
  96. return sql, config_params + params + extra_params
  97. class CombinedSearchVector(SearchVectorCombinable, CombinedExpression):
  98. def __init__(self, lhs, connector, rhs, config, output_field=None):
  99. self.config = config
  100. super().__init__(lhs, connector, rhs, output_field)
  101. class SearchQueryCombinable:
  102. BITAND = '&&'
  103. BITOR = '||'
  104. def _combine(self, other, connector, reversed):
  105. if not isinstance(other, SearchQueryCombinable):
  106. raise TypeError(
  107. 'SearchQuery can only be combined with other SearchQuery '
  108. 'instances, got %s.' % type(other).__name__
  109. )
  110. if reversed:
  111. return CombinedSearchQuery(other, connector, self, self.config)
  112. return CombinedSearchQuery(self, connector, other, self.config)
  113. # On Combinable, these are not implemented to reduce confusion with Q. In
  114. # this case we are actually (ab)using them to do logical combination so
  115. # it's consistent with other usage in Django.
  116. def __or__(self, other):
  117. return self._combine(other, self.BITOR, False)
  118. def __ror__(self, other):
  119. return self._combine(other, self.BITOR, True)
  120. def __and__(self, other):
  121. return self._combine(other, self.BITAND, False)
  122. def __rand__(self, other):
  123. return self._combine(other, self.BITAND, True)
  124. class SearchQuery(SearchQueryCombinable, Func):
  125. output_field = SearchQueryField()
  126. SEARCH_TYPES = {
  127. 'plain': 'plainto_tsquery',
  128. 'phrase': 'phraseto_tsquery',
  129. 'raw': 'to_tsquery',
  130. 'websearch': 'websearch_to_tsquery',
  131. }
  132. def __init__(self, value, output_field=None, *, config=None, invert=False, search_type='plain'):
  133. self.function = self.SEARCH_TYPES.get(search_type)
  134. if self.function is None:
  135. raise ValueError("Unknown search_type argument '%s'." % search_type)
  136. if not hasattr(value, 'resolve_expression'):
  137. value = Value(value)
  138. expressions = (value,)
  139. self.config = SearchConfig.from_parameter(config)
  140. if self.config is not None:
  141. expressions = (self.config,) + expressions
  142. self.invert = invert
  143. super().__init__(*expressions, output_field=output_field)
  144. def as_sql(self, compiler, connection, function=None, template=None):
  145. sql, params = super().as_sql(compiler, connection, function, template)
  146. if self.invert:
  147. sql = '!!(%s)' % sql
  148. return sql, params
  149. def __invert__(self):
  150. clone = self.copy()
  151. clone.invert = not self.invert
  152. return clone
  153. def __str__(self):
  154. result = super().__str__()
  155. return ('~%s' % result) if self.invert else result
  156. class CombinedSearchQuery(SearchQueryCombinable, CombinedExpression):
  157. def __init__(self, lhs, connector, rhs, config, output_field=None):
  158. self.config = config
  159. super().__init__(lhs, connector, rhs, output_field)
  160. def __str__(self):
  161. return '(%s)' % super().__str__()
  162. class SearchRank(Func):
  163. function = 'ts_rank'
  164. output_field = FloatField()
  165. def __init__(
  166. self, vector, query, weights=None, normalization=None,
  167. cover_density=False,
  168. ):
  169. if not hasattr(vector, 'resolve_expression'):
  170. vector = SearchVector(vector)
  171. if not hasattr(query, 'resolve_expression'):
  172. query = SearchQuery(query)
  173. expressions = (vector, query)
  174. if weights is not None:
  175. if not hasattr(weights, 'resolve_expression'):
  176. weights = Value(weights)
  177. expressions = (weights,) + expressions
  178. if normalization is not None:
  179. if not hasattr(normalization, 'resolve_expression'):
  180. normalization = Value(normalization)
  181. expressions += (normalization,)
  182. if cover_density:
  183. self.function = 'ts_rank_cd'
  184. super().__init__(*expressions)
  185. class SearchHeadline(Func):
  186. function = 'ts_headline'
  187. template = '%(function)s(%(expressions)s%(options)s)'
  188. output_field = TextField()
  189. def __init__(
  190. self, expression, query, *, config=None, start_sel=None, stop_sel=None,
  191. max_words=None, min_words=None, short_word=None, highlight_all=None,
  192. max_fragments=None, fragment_delimiter=None,
  193. ):
  194. if not hasattr(query, 'resolve_expression'):
  195. query = SearchQuery(query)
  196. options = {
  197. 'StartSel': start_sel,
  198. 'StopSel': stop_sel,
  199. 'MaxWords': max_words,
  200. 'MinWords': min_words,
  201. 'ShortWord': short_word,
  202. 'HighlightAll': highlight_all,
  203. 'MaxFragments': max_fragments,
  204. 'FragmentDelimiter': fragment_delimiter,
  205. }
  206. self.options = {
  207. option: value
  208. for option, value in options.items() if value is not None
  209. }
  210. expressions = (expression, query)
  211. if config is not None:
  212. config = SearchConfig.from_parameter(config)
  213. expressions = (config,) + expressions
  214. super().__init__(*expressions)
  215. def as_sql(self, compiler, connection, function=None, template=None):
  216. options_sql = ''
  217. options_params = []
  218. if self.options:
  219. # getquoted() returns a quoted bytestring of the adapted value.
  220. options_params.append(', '.join(
  221. '%s=%s' % (
  222. option,
  223. psycopg2.extensions.adapt(value).getquoted().decode(),
  224. ) for option, value in self.options.items()
  225. ))
  226. options_sql = ', %s'
  227. sql, params = super().as_sql(
  228. compiler, connection, function=function, template=template,
  229. options=options_sql,
  230. )
  231. return sql, params + options_params
# Register SearchVectorExact (lookup_name 'exact') as a lookup on
# SearchVectorField.
SearchVectorField.register_lookup(SearchVectorExact)
  233. class TrigramBase(Func):
  234. output_field = FloatField()
  235. def __init__(self, expression, string, **extra):
  236. if not hasattr(string, 'resolve_expression'):
  237. string = Value(string)
  238. super().__init__(expression, string, **extra)
  239. class TrigramSimilarity(TrigramBase):
  240. function = 'SIMILARITY'
  241. class TrigramDistance(TrigramBase):
  242. function = ''
  243. arg_joiner = ' <-> '