bytes_ops.c 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  // Bytes primitive operations
  //
  // These are registered in mypyc.primitives.bytes_ops.
  4. #include <Python.h>
  5. #include "CPy.h"
  6. // Returns -1 on error, 0 on inequality, 1 on equality.
  7. //
  8. // Falls back to PyObject_RichCompareBool.
  9. int CPyBytes_Compare(PyObject *left, PyObject *right) {
  10. if (PyBytes_CheckExact(left) && PyBytes_CheckExact(right)) {
  11. if (left == right) {
  12. return 1;
  13. }
  14. // Adapted from cpython internal implementation of bytes_compare.
  15. Py_ssize_t len = Py_SIZE(left);
  16. if (Py_SIZE(right) != len) {
  17. return 0;
  18. }
  19. PyBytesObject *left_b = (PyBytesObject *)left;
  20. PyBytesObject *right_b = (PyBytesObject *)right;
  21. if (left_b->ob_sval[0] != right_b->ob_sval[0]) {
  22. return 0;
  23. }
  24. return memcmp(left_b->ob_sval, right_b->ob_sval, len) == 0;
  25. }
  26. return PyObject_RichCompareBool(left, right, Py_EQ);
  27. }
  28. CPyTagged CPyBytes_GetItem(PyObject *o, CPyTagged index) {
  29. if (CPyTagged_CheckShort(index)) {
  30. Py_ssize_t n = CPyTagged_ShortAsSsize_t(index);
  31. Py_ssize_t size = ((PyVarObject *)o)->ob_size;
  32. if (n < 0)
  33. n += size;
  34. if (n < 0 || n >= size) {
  35. PyErr_SetString(PyExc_IndexError, "index out of range");
  36. return CPY_INT_TAG;
  37. }
  38. unsigned char num = PyBytes_Check(o) ? ((PyBytesObject *)o)->ob_sval[n]
  39. : ((PyByteArrayObject *)o)->ob_bytes[n];
  40. return num << 1;
  41. } else {
  42. PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG);
  43. return CPY_INT_TAG;
  44. }
  45. }
  46. PyObject *CPyBytes_Concat(PyObject *a, PyObject *b) {
  47. if (PyBytes_Check(a) && PyBytes_Check(b)) {
  48. Py_ssize_t a_len = ((PyVarObject *)a)->ob_size;
  49. Py_ssize_t b_len = ((PyVarObject *)b)->ob_size;
  50. PyBytesObject *ret = (PyBytesObject *)PyBytes_FromStringAndSize(NULL, a_len + b_len);
  51. if (ret != NULL) {
  52. memcpy(ret->ob_sval, ((PyBytesObject *)a)->ob_sval, a_len);
  53. memcpy(ret->ob_sval + a_len, ((PyBytesObject *)b)->ob_sval, b_len);
  54. }
  55. return (PyObject *)ret;
  56. } else if (PyByteArray_Check(a)) {
  57. return PyByteArray_Concat(a, b);
  58. } else {
  59. PyBytes_Concat(&a, b);
  60. return a;
  61. }
  62. }
  63. static inline Py_ssize_t Clamp(Py_ssize_t a, Py_ssize_t b, Py_ssize_t c) {
  64. return a < b ? b : (a >= c ? c : a);
  65. }
  66. PyObject *CPyBytes_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end) {
  67. if ((PyBytes_Check(obj) || PyByteArray_Check(obj))
  68. && CPyTagged_CheckShort(start) && CPyTagged_CheckShort(end)) {
  69. Py_ssize_t startn = CPyTagged_ShortAsSsize_t(start);
  70. Py_ssize_t endn = CPyTagged_ShortAsSsize_t(end);
  71. Py_ssize_t len = ((PyVarObject *)obj)->ob_size;
  72. if (startn < 0) {
  73. startn += len;
  74. }
  75. if (endn < 0) {
  76. endn += len;
  77. }
  78. startn = Clamp(startn, 0, len);
  79. endn = Clamp(endn, 0, len);
  80. Py_ssize_t slice_len = endn - startn;
  81. if (PyBytes_Check(obj)) {
  82. return PyBytes_FromStringAndSize(PyBytes_AS_STRING(obj) + startn, slice_len);
  83. } else {
  84. return PyByteArray_FromStringAndSize(PyByteArray_AS_STRING(obj) + startn, slice_len);
  85. }
  86. }
  87. return CPyObject_GetSlice(obj, start, end);
  88. }
  89. // Like _PyBytes_Join but fallback to dynamic call if 'sep' is not bytes
  90. // (mostly commonly, for bytearrays)
  91. PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter) {
  92. if (PyBytes_CheckExact(sep)) {
  93. return _PyBytes_Join(sep, iter);
  94. } else {
  95. _Py_IDENTIFIER(join);
  96. return _PyObject_CallMethodIdOneArg(sep, &PyId_join, iter);
  97. }
  98. }
  99. PyObject *CPyBytes_Build(Py_ssize_t len, ...) {
  100. Py_ssize_t i;
  101. Py_ssize_t sz = 0;
  102. va_list args;
  103. va_start(args, len);
  104. for (i = 0; i < len; i++) {
  105. PyObject *item = va_arg(args, PyObject *);
  106. size_t add_sz = ((PyVarObject *)item)->ob_size;
  107. // Using size_t to avoid overflow during arithmetic calculation
  108. if (add_sz > (size_t)(PY_SSIZE_T_MAX - sz)) {
  109. PyErr_SetString(PyExc_OverflowError,
  110. "join() result is too long for a Python bytes");
  111. return NULL;
  112. }
  113. sz += add_sz;
  114. }
  115. va_end(args);
  116. PyBytesObject *ret = (PyBytesObject *)PyBytes_FromStringAndSize(NULL, sz);
  117. if (ret != NULL) {
  118. char *res_data = ret->ob_sval;
  119. va_start(args, len);
  120. for (i = 0; i < len; i++) {
  121. PyObject *item = va_arg(args, PyObject *);
  122. Py_ssize_t item_sz = ((PyVarObject *)item)->ob_size;
  123. memcpy(res_data, ((PyBytesObject *)item)->ob_sval, item_sz);
  124. res_data += item_sz;
  125. }
  126. va_end(args);
  127. assert(res_data == ret->ob_sval + ((PyVarObject *)ret)->ob_size);
  128. }
  129. return (PyObject *)ret;
  130. }