fermat.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. package bigfft
  2. import (
  3. "math/big"
  4. )
  5. // Arithmetic modulo 2^n+1.
  6. // A fermat of length w+1 represents a number modulo 2^(w*_W) + 1. The last
  7. // word is zero or one. A number has at most two representatives satisfying the
  8. // 0-1 last word constraint.
  9. type fermat nat
  10. func (n fermat) String() string { return nat(n).String() }
  11. func (z fermat) norm() {
  12. n := len(z) - 1
  13. c := z[n]
  14. if c == 0 {
  15. return
  16. }
  17. if z[0] >= c {
  18. z[n] = 0
  19. z[0] -= c
  20. return
  21. }
  22. // z[0] < z[n].
  23. subVW(z, z, c) // Substract c
  24. if c > 1 {
  25. z[n] -= c - 1
  26. c = 1
  27. }
  28. // Add back c.
  29. if z[n] == 1 {
  30. z[n] = 0
  31. return
  32. } else {
  33. addVW(z, z, 1)
  34. }
  35. }
  36. // Shift computes (x << k) mod (2^n+1).
  37. func (z fermat) Shift(x fermat, k int) {
  38. if len(z) != len(x) {
  39. panic("len(z) != len(x) in Shift")
  40. }
  41. n := len(x) - 1
  42. // Shift by n*_W is taking the opposite.
  43. k %= 2 * n * _W
  44. if k < 0 {
  45. k += 2 * n * _W
  46. }
  47. neg := false
  48. if k >= n*_W {
  49. k -= n * _W
  50. neg = true
  51. }
  52. kw, kb := k/_W, k%_W
  53. z[n] = 1 // Add (-1)
  54. if !neg {
  55. for i := 0; i < kw; i++ {
  56. z[i] = 0
  57. }
  58. // Shift left by kw words.
  59. // x = a·2^(n-k) + b
  60. // x<<k = (b<<k) - a
  61. copy(z[kw:], x[:n-kw])
  62. b := subVV(z[:kw+1], z[:kw+1], x[n-kw:])
  63. if z[kw+1] > 0 {
  64. z[kw+1] -= b
  65. } else {
  66. subVW(z[kw+1:], z[kw+1:], b)
  67. }
  68. } else {
  69. for i := kw + 1; i < n; i++ {
  70. z[i] = 0
  71. }
  72. // Shift left and negate, by kw words.
  73. copy(z[:kw+1], x[n-kw:n+1]) // z_low = x_high
  74. b := subVV(z[kw:n], z[kw:n], x[:n-kw]) // z_high -= x_low
  75. z[n] -= b
  76. }
  77. // Add back 1.
  78. if z[n] > 0 {
  79. z[n]--
  80. } else if z[0] < ^big.Word(0) {
  81. z[0]++
  82. } else {
  83. addVW(z, z, 1)
  84. }
  85. // Shift left by kb bits
  86. shlVU(z, z, uint(kb))
  87. z.norm()
  88. }
  89. // ShiftHalf shifts x by k/2 bits the left. Shifting by 1/2 bit
  90. // is multiplication by sqrt(2) mod 2^n+1 which is 2^(3n/4) - 2^(n/4).
  91. // A temporary buffer must be provided in tmp.
  92. func (z fermat) ShiftHalf(x fermat, k int, tmp fermat) {
  93. n := len(z) - 1
  94. if k%2 == 0 {
  95. z.Shift(x, k/2)
  96. return
  97. }
  98. u := (k - 1) / 2
  99. a := u + (3*_W/4)*n
  100. b := u + (_W/4)*n
  101. z.Shift(x, a)
  102. tmp.Shift(x, b)
  103. z.Sub(z, tmp)
  104. }
  105. // Add computes addition mod 2^n+1.
  106. func (z fermat) Add(x, y fermat) fermat {
  107. if len(z) != len(x) {
  108. panic("Add: len(z) != len(x)")
  109. }
  110. addVV(z, x, y) // there cannot be a carry here.
  111. z.norm()
  112. return z
  113. }
  114. // Sub computes substraction mod 2^n+1.
  115. func (z fermat) Sub(x, y fermat) fermat {
  116. if len(z) != len(x) {
  117. panic("Add: len(z) != len(x)")
  118. }
  119. n := len(y) - 1
  120. b := subVV(z[:n], x[:n], y[:n])
  121. b += y[n]
  122. // If b > 0, we need to subtract b<<n, which is the same as adding b.
  123. z[n] = x[n]
  124. if z[0] <= ^big.Word(0)-b {
  125. z[0] += b
  126. } else {
  127. addVW(z, z, b)
  128. }
  129. z.norm()
  130. return z
  131. }
  132. func (z fermat) Mul(x, y fermat) fermat {
  133. if len(x) != len(y) {
  134. panic("Mul: len(x) != len(y)")
  135. }
  136. n := len(x) - 1
  137. if n < 30 {
  138. z = z[:2*n+2]
  139. basicMul(z, x, y)
  140. z = z[:2*n+1]
  141. } else {
  142. var xi, yi, zi big.Int
  143. xi.SetBits(x)
  144. yi.SetBits(y)
  145. zi.SetBits(z)
  146. zb := zi.Mul(&xi, &yi).Bits()
  147. if len(zb) <= n {
  148. // Short product.
  149. copy(z, zb)
  150. for i := len(zb); i < len(z); i++ {
  151. z[i] = 0
  152. }
  153. return z
  154. }
  155. z = zb
  156. }
  157. // len(z) is at most 2n+1.
  158. if len(z) > 2*n+1 {
  159. panic("len(z) > 2n+1")
  160. }
  161. // We now have
  162. // z = z[:n] + 1<<(n*W) * z[n:2n+1]
  163. // which normalizes to:
  164. // z = z[:n] - z[n:2n] + z[2n]
  165. c1 := big.Word(0)
  166. if len(z) > 2*n {
  167. c1 = addVW(z[:n], z[:n], z[2*n])
  168. }
  169. c2 := big.Word(0)
  170. if len(z) >= 2*n {
  171. c2 = subVV(z[:n], z[:n], z[n:2*n])
  172. } else {
  173. m := len(z) - n
  174. c2 = subVV(z[:m], z[:m], z[n:])
  175. c2 = subVW(z[m:n], z[m:n], c2)
  176. }
  177. // Restore carries.
  178. // Substracting z[n] -= c2 is the same
  179. // as z[0] += c2
  180. z = z[:n+1]
  181. z[n] = c1
  182. c := addVW(z, z, c2)
  183. if c != 0 {
  184. panic("impossible")
  185. }
  186. z.norm()
  187. return z
  188. }
  189. // copied from math/big
  190. //
  191. // basicMul multiplies x and y and leaves the result in z.
  192. // The (non-normalized) result is placed in z[0 : len(x) + len(y)].
  193. func basicMul(z, x, y fermat) {
  194. // initialize z
  195. for i := 0; i < len(z); i++ {
  196. z[i] = 0
  197. }
  198. for i, d := range y {
  199. if d != 0 {
  200. z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
  201. }
  202. }
  203. }