str_ops.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. """Primitive str ops."""
  2. from __future__ import annotations
  3. from mypyc.ir.ops import ERR_MAGIC, ERR_NEVER
  4. from mypyc.ir.rtypes import (
  5. RType,
  6. bit_rprimitive,
  7. bool_rprimitive,
  8. bytes_rprimitive,
  9. c_int_rprimitive,
  10. c_pyssize_t_rprimitive,
  11. int_rprimitive,
  12. list_rprimitive,
  13. object_rprimitive,
  14. pointer_rprimitive,
  15. str_rprimitive,
  16. )
  17. from mypyc.primitives.registry import (
  18. ERR_NEG_INT,
  19. binary_op,
  20. custom_op,
  21. function_op,
  22. load_address_op,
  23. method_op,
  24. )
  25. # Get the 'str' type object.
  26. load_address_op(name="builtins.str", type=object_rprimitive, src="PyUnicode_Type")
  27. # str(obj)
  28. str_op = function_op(
  29. name="builtins.str",
  30. arg_types=[object_rprimitive],
  31. return_type=str_rprimitive,
  32. c_function_name="PyObject_Str",
  33. error_kind=ERR_MAGIC,
  34. )
  35. # str1 + str2
  36. binary_op(
  37. name="+",
  38. arg_types=[str_rprimitive, str_rprimitive],
  39. return_type=str_rprimitive,
  40. c_function_name="PyUnicode_Concat",
  41. error_kind=ERR_MAGIC,
  42. )
  43. # str1 += str2
  44. #
  45. # PyUnicode_Append makes an effort to reuse the LHS when the refcount
  46. # is 1. This is super dodgy but oh well, the interpreter does it.
  47. binary_op(
  48. name="+=",
  49. arg_types=[str_rprimitive, str_rprimitive],
  50. return_type=str_rprimitive,
  51. c_function_name="CPyStr_Append",
  52. error_kind=ERR_MAGIC,
  53. steals=[True, False],
  54. )
  55. unicode_compare = custom_op(
  56. arg_types=[str_rprimitive, str_rprimitive],
  57. return_type=c_int_rprimitive,
  58. c_function_name="PyUnicode_Compare",
  59. error_kind=ERR_NEVER,
  60. )
  61. # str[index] (for an int index)
  62. method_op(
  63. name="__getitem__",
  64. arg_types=[str_rprimitive, int_rprimitive],
  65. return_type=str_rprimitive,
  66. c_function_name="CPyStr_GetItem",
  67. error_kind=ERR_MAGIC,
  68. )
  69. # str[begin:end]
  70. str_slice_op = custom_op(
  71. arg_types=[str_rprimitive, int_rprimitive, int_rprimitive],
  72. return_type=object_rprimitive,
  73. c_function_name="CPyStr_GetSlice",
  74. error_kind=ERR_MAGIC,
  75. )
  76. # str.join(obj)
  77. method_op(
  78. name="join",
  79. arg_types=[str_rprimitive, object_rprimitive],
  80. return_type=str_rprimitive,
  81. c_function_name="PyUnicode_Join",
  82. error_kind=ERR_MAGIC,
  83. )
  84. str_build_op = custom_op(
  85. arg_types=[c_pyssize_t_rprimitive],
  86. return_type=str_rprimitive,
  87. c_function_name="CPyStr_Build",
  88. error_kind=ERR_MAGIC,
  89. var_arg_type=str_rprimitive,
  90. )
  91. # str.startswith(str)
  92. method_op(
  93. name="startswith",
  94. arg_types=[str_rprimitive, str_rprimitive],
  95. return_type=bool_rprimitive,
  96. c_function_name="CPyStr_Startswith",
  97. error_kind=ERR_NEVER,
  98. )
  99. # str.endswith(str)
  100. method_op(
  101. name="endswith",
  102. arg_types=[str_rprimitive, str_rprimitive],
  103. return_type=bool_rprimitive,
  104. c_function_name="CPyStr_Endswith",
  105. error_kind=ERR_NEVER,
  106. )
  107. # str.split(...)
  108. str_split_types: list[RType] = [str_rprimitive, str_rprimitive, int_rprimitive]
  109. str_split_functions = ["PyUnicode_Split", "PyUnicode_Split", "CPyStr_Split"]
  110. str_split_constants: list[list[tuple[int, RType]]] = [
  111. [(0, pointer_rprimitive), (-1, c_int_rprimitive)],
  112. [(-1, c_int_rprimitive)],
  113. [],
  114. ]
  115. for i in range(len(str_split_types)):
  116. method_op(
  117. name="split",
  118. arg_types=str_split_types[0 : i + 1],
  119. return_type=list_rprimitive,
  120. c_function_name=str_split_functions[i],
  121. extra_int_constants=str_split_constants[i],
  122. error_kind=ERR_MAGIC,
  123. )
  124. # str.replace(old, new)
  125. method_op(
  126. name="replace",
  127. arg_types=[str_rprimitive, str_rprimitive, str_rprimitive],
  128. return_type=str_rprimitive,
  129. c_function_name="PyUnicode_Replace",
  130. error_kind=ERR_MAGIC,
  131. extra_int_constants=[(-1, c_int_rprimitive)],
  132. )
  133. # str.replace(old, new, count)
  134. method_op(
  135. name="replace",
  136. arg_types=[str_rprimitive, str_rprimitive, str_rprimitive, int_rprimitive],
  137. return_type=str_rprimitive,
  138. c_function_name="CPyStr_Replace",
  139. error_kind=ERR_MAGIC,
  140. )
  141. # check if a string is true (isn't an empty string)
  142. str_check_if_true = custom_op(
  143. arg_types=[str_rprimitive],
  144. return_type=bit_rprimitive,
  145. c_function_name="CPyStr_IsTrue",
  146. error_kind=ERR_NEVER,
  147. )
  148. str_ssize_t_size_op = custom_op(
  149. arg_types=[str_rprimitive],
  150. return_type=c_pyssize_t_rprimitive,
  151. c_function_name="CPyStr_Size_size_t",
  152. error_kind=ERR_NEG_INT,
  153. )
  154. # obj.decode()
  155. method_op(
  156. name="decode",
  157. arg_types=[bytes_rprimitive],
  158. return_type=str_rprimitive,
  159. c_function_name="CPy_Decode",
  160. error_kind=ERR_MAGIC,
  161. extra_int_constants=[(0, pointer_rprimitive), (0, pointer_rprimitive)],
  162. )
  163. # obj.decode(encoding)
  164. method_op(
  165. name="decode",
  166. arg_types=[bytes_rprimitive, str_rprimitive],
  167. return_type=str_rprimitive,
  168. c_function_name="CPy_Decode",
  169. error_kind=ERR_MAGIC,
  170. extra_int_constants=[(0, pointer_rprimitive)],
  171. )
  172. # obj.decode(encoding, errors)
  173. method_op(
  174. name="decode",
  175. arg_types=[bytes_rprimitive, str_rprimitive, str_rprimitive],
  176. return_type=str_rprimitive,
  177. c_function_name="CPy_Decode",
  178. error_kind=ERR_MAGIC,
  179. )
  180. # str.encode()
  181. method_op(
  182. name="encode",
  183. arg_types=[str_rprimitive],
  184. return_type=bytes_rprimitive,
  185. c_function_name="CPy_Encode",
  186. error_kind=ERR_MAGIC,
  187. extra_int_constants=[(0, pointer_rprimitive), (0, pointer_rprimitive)],
  188. )
  189. # str.encode(encoding)
  190. method_op(
  191. name="encode",
  192. arg_types=[str_rprimitive, str_rprimitive],
  193. return_type=bytes_rprimitive,
  194. c_function_name="CPy_Encode",
  195. error_kind=ERR_MAGIC,
  196. extra_int_constants=[(0, pointer_rprimitive)],
  197. )
  198. # str.encode(encoding, errors)
  199. method_op(
  200. name="encode",
  201. arg_types=[str_rprimitive, str_rprimitive, str_rprimitive],
  202. return_type=bytes_rprimitive,
  203. c_function_name="CPy_Encode",
  204. error_kind=ERR_MAGIC,
  205. )