rec.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
  1. // Copyright 2023 The Rec Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package rec // import "modernc.org/rec/lib"
  5. import (
  6. "bytes"
  7. "flag"
  8. "fmt"
  9. "io"
  10. "os"
  11. "path/filepath"
  12. "strconv"
  13. "strings"
  14. "sync"
  15. "unicode"
  16. "modernc.org/gc/v3"
  17. )
// pkg is the placeholder package clause. Main writes it to the output buffer
// when no -pkg flag is given and records it in post.pkg; post.source later
// drops len(pkg) bytes from the front of the re-rendered source. Presumably
// it exists so the buffered output parses as a complete Go file.
const pkg = "package x\n"
// cfg carries the state of one Main invocation: the raw arguments, the parsed
// flag values, the output sinks and the sticky write error.
type cfg struct {
	// args is the argument list exactly as passed to Main; nonFlagArgs is
	// what remains after flag parsing.
	args        []string
	nonFlagArgs []string

	// The o* fields point at the flag values registered in Main.
	oCaseInsensitive *bool
	oDbg             *bool
	oImport          *string
	oLex             *string
	oLexString       *string
	oLexStringUtf8   *string
	oLexUtf8         *string
	oMatch           *string
	oMatchString     *string
	oMatchStringUtf8 *string
	oMatchUtf8       *string
	oPkg             *string
	oRx              *string
	oTrc             *bool

	// post buffers everything written through w until Main flushes it.
	post *post

	// stderr and stdout are never nil once Main has set up the cfg.
	stderr io.Writer
	stdout io.Writer

	// lcase is a Go code fragment set by Main when -i is given; it is empty
	// otherwise.
	lcase string

	// werr is the result of the most recent write performed by w. closed is
	// set after the first failed write and turns w into a no-op.
	werr   error
	closed bool
}
var (
	// tables is the package-level unicodeTables instance. Its zero value is
	// usable: find initialises the lookup maps on first use via boot.
	tables unicodeTables
)
  46. // Main executes the CLI rec command. The 'args' argument should not include
  47. // the command name.
  48. func Main(args []string, stdout, stderr io.Writer) (rc int, err error) {
  49. c := &cfg{
  50. args: args,
  51. post: newPost(),
  52. stdout: stdout,
  53. stderr: stderr,
  54. }
  55. if c.stdout == nil {
  56. c.stdout = io.Discard
  57. }
  58. if c.stderr == nil {
  59. c.stderr = io.Discard
  60. }
  61. defer func() {
  62. if perr := c.post.post(); perr != nil && err == nil {
  63. rc = 1
  64. err = perr
  65. }
  66. if _, cerr := io.Copy(c.stdout, c.post.b); cerr != nil && err == nil {
  67. rc = 1
  68. err = cerr
  69. }
  70. }()
  71. set := flag.NewFlagSet("rec", flag.ContinueOnError)
  72. c.oCaseInsensitive = set.Bool("i", false, "case insensitive")
  73. c.oDbg = set.Bool("dbg", false, "enable debug output")
  74. c.oImport = set.String("import", "", "comma separated list if imports")
  75. c.oLex = set.String("lex", "", "name of the lex function")
  76. c.oLexString = set.String("lexstring", "", "name of the lexstring function")
  77. c.oLexStringUtf8 = set.String("lexstringutf8", "", "name of the lexstring utf8 function")
  78. c.oLexUtf8 = set.String("lexutf8", "", "name of the lexutf8 function")
  79. c.oMatch = set.String("match", "", "name of the match function")
  80. c.oMatchString = set.String("matchstring", "", "name of the matchstring function")
  81. c.oMatchStringUtf8 = set.String("matchstringutf8", "", "name of the matchstring utf8 function")
  82. c.oMatchUtf8 = set.String("matchutf8", "", "name of the matchutf8 function")
  83. c.oPkg = set.String("pkg", "", "wrap output in package <argument>")
  84. c.oRx = set.String("rx", "", "make the machine a method of <argument>")
  85. c.oTrc = set.Bool("trc", false, "inject tracing")
  86. set.SetOutput(stderr)
  87. set.Parse(args)
  88. c.nonFlagArgs = set.Args()
  89. if *c.oCaseInsensitive {
  90. for i, v := range c.nonFlagArgs {
  91. c.nonFlagArgs[i] = strings.ToLower(v)
  92. }
  93. c.lcase = "defer func() { r = unicode.ToLower(r) }();"
  94. }
  95. if *c.oPkg == "" {
  96. c.post.pkg = pkg
  97. c.post.b.WriteString(pkg)
  98. }
  99. switch {
  100. case *c.oLex != "":
  101. err = c.lex(*c.oLex, set.Args(), false, false)
  102. case *c.oLexUtf8 != "":
  103. err = c.lex(*c.oLexUtf8, set.Args(), true, false)
  104. case *c.oLexString != "":
  105. err = c.lex(*c.oLexString, set.Args(), false, true)
  106. case *c.oLexStringUtf8 != "":
  107. err = c.lex(*c.oLexStringUtf8, set.Args(), true, true)
  108. case *c.oMatch != "":
  109. err = c.match(*c.oMatch, set.Args(), false, false)
  110. case *c.oMatchUtf8 != "":
  111. err = c.match(*c.oMatchUtf8, set.Args(), true, false)
  112. case *c.oMatchString != "":
  113. err = c.match(*c.oMatchString, set.Args(), false, true)
  114. case *c.oMatchStringUtf8 != "":
  115. err = c.match(*c.oMatchStringUtf8, set.Args(), true, true)
  116. default:
  117. return 2, fmt.Errorf("no machine selected, use '-match=Match', for example")
  118. }
  119. if err != nil {
  120. return 1, fmt.Errorf("%s", err)
  121. }
  122. if c.werr != nil {
  123. return 1, c.werr
  124. }
  125. return 0, nil
  126. }
  127. func (c *cfg) rxString() (r string) {
  128. if s := *c.oRx; s != "" {
  129. r = fmt.Sprintf("(%s) ", s)
  130. }
  131. return r
  132. }
  133. func (c *cfg) prolog(hideNonFlagArgs bool) {
  134. s := *c.oPkg
  135. if s == "" {
  136. return
  137. }
  138. *c.oPkg = ""
  139. args := append([]string(nil), c.args...)
  140. if hideNonFlagArgs {
  141. m := map[string]struct{}{}
  142. for _, v := range c.nonFlagArgs {
  143. m[v] = struct{}{}
  144. }
  145. args = args[:0]
  146. for _, v := range c.args {
  147. if _, ok := m[v]; !ok {
  148. args = append(args, v)
  149. }
  150. }
  151. args = append(args, "<...>")
  152. }
  153. c.w(`// Code generated by %v, DO NOT EDIT.
  154. package %s
  155. `, append([]string{filepath.Base(os.Args[0])}, args...), s)
  156. if s = *c.oImport; s == "" {
  157. return
  158. }
  159. c.w(`
  160. import (
  161. `)
  162. for _, v := range strings.Split(s, ",") {
  163. c.w("\t%q\n", v)
  164. }
  165. c.w(")\n\n")
  166. }
  167. func (c *cfg) w(s string, args ...interface{}) {
  168. if c.closed {
  169. return
  170. }
  171. _, c.werr = fmt.Fprintf(c.post, s, args...)
  172. if c.werr != nil {
  173. c.closed = true
  174. }
  175. }
  176. func (c *cfg) trc(s, args string) {
  177. c.w("\tfmt.Fprintf(os.Stderr, %q, %s);", "TRC "+s, args)
  178. }
  179. func pretty(s string) string {
  180. for i := 0; i < len(s); i++ {
  181. if s[i] < ' ' || s[i] > 0x7e {
  182. return strconv.QuoteToASCII(s)
  183. }
  184. }
  185. return s
  186. }
// unicodeTables maps rune range lists back to the names of the Unicode
// category tables in package unicode. The zero value is ready for use.
type unicodeTables struct {
	// once guards the one-time initialisation performed by boot.
	once sync.Once
	// byFirstRune groups the known tables by the first rune they contain, as
	// reported by firstRune.
	byFirstRune map[rune][]*unicode.RangeTable
	// tab2Hash caches the result of the hashString method per table. It is
	// allocated on first use, not by boot.
	tab2Hash map[*unicode.RangeTable]string
	// tab2Name maps each known table to its category name, e.g. "Lu".
	tab2Name map[*unicode.RangeTable]string
}
  193. func (t *unicodeTables) boot() {
  194. t.once.Do(func() {
  195. t.tab2Name = map[*unicode.RangeTable]string{
  196. unicode.Cc: "Cc", // Cc is the set of Unicode characters in category Cc (Other, control).
  197. unicode.Cf: "Cf", // Cf is the set of Unicode characters in category Cf (Other, format).
  198. unicode.Co: "Co", // Co is the set of Unicode characters in category Co (Other, private use).
  199. unicode.Cs: "Cs", // Cs is the set of Unicode characters in category Cs (Other, surrogate).
  200. unicode.Digit: "Nd", // Digit is the set of Unicode characters with the "decimal digit" property.
  201. unicode.Nd: "Nd", // Nd is the set of Unicode characters in category Nd (Number, decimal digit).
  202. unicode.Letter: "L", // Letter/L is the set of Unicode letters, category L.
  203. unicode.L: "L",
  204. unicode.Lm: "Lm", // Lm is the set of Unicode characters in category Lm (Letter, modifier).
  205. unicode.Lo: "Lo", // Lo is the set of Unicode characters in category Lo (Letter, other).
  206. unicode.Lower: "Ll", // Lower is the set of Unicode lower case letters.
  207. unicode.Ll: "Ll", // Ll is the set of Unicode characters in category Ll (Letter, lowercase).
  208. unicode.Mark: "M", // Mark/M is the set of Unicode mark characters, category M.
  209. unicode.M: "M",
  210. unicode.Mc: "Mc", // Mc is the set of Unicode characters in category Mc (Mark, spacing combining).
  211. unicode.Me: "Me", // Me is the set of Unicode characters in category Me (Mark, enclosing).
  212. unicode.Mn: "Mn", // Mn is the set of Unicode characters in category Mn (Mark, nonspacing).
  213. unicode.Nl: "Nl", // Nl is the set of Unicode characters in category Nl (Number, letter).
  214. unicode.No: "No", // No is the set of Unicode characters in category No (Number, other).
  215. unicode.Number: "N", // Number/N is the set of Unicode number characters, category N.
  216. unicode.N: "N",
  217. unicode.Other: "C", // Other/C is the set of Unicode control and special characters, category C.
  218. unicode.C: "C",
  219. unicode.Pc: "Pc", // Pc is the set of Unicode characters in category Pc (Punctuation, connector).
  220. unicode.Pd: "Pd", // Pd is the set of Unicode characters in category Pd (Punctuation, dash).
  221. unicode.Pe: "Pe", // Pe is the set of Unicode characters in category Pe (Punctuation, close).
  222. unicode.Pf: "Pf", // Pf is the set of Unicode characters in category Pf (Punctuation, final quote).
  223. unicode.Pi: "Pi", // Pi is the set of Unicode characters in category Pi (Punctuation, initial quote).
  224. unicode.Po: "Po", // Po is the set of Unicode characters in category Po (Punctuation, other).
  225. unicode.Ps: "Ps", // Ps is the set of Unicode characters in category Ps (Punctuation, open).
  226. unicode.Punct: "P", // Punct/P is the set of Unicode punctuation characters, category P.
  227. unicode.P: "P",
  228. unicode.Sc: "Sc", // Sc is the set of Unicode characters in category Sc (Symbol, currency).
  229. unicode.Sk: "Sk", // Sk is the set of Unicode characters in category Sk (Symbol, modifier).
  230. unicode.Sm: "Sm", // Sm is the set of Unicode characters in category Sm (Symbol, math).
  231. unicode.So: "So", // So is the set of Unicode characters in category So (Symbol, other).
  232. unicode.Space: "Z", // Space/Z is the set of Unicode space characters, category Z.
  233. unicode.Z: "Z",
  234. unicode.Symbol: "S", // Symbol/S is the set of Unicode symbol characters, category S.
  235. unicode.S: "S",
  236. unicode.Title: "Lt", // Title is the set of Unicode title case letters.
  237. unicode.Lt: "Lt", // Lt is the set of Unicode characters in category Lt (Letter, titlecase).
  238. unicode.Upper: "Lu", // Upper is the set of Unicode upper case letters.
  239. unicode.Lu: "Lu", // Lu is the set of Unicode characters in category Lu (Letter, uppercase).
  240. unicode.Zl: "Zl", // Zl is the set of Unicode characters in category Zl (Separator, line).
  241. unicode.Zp: "Zp", // Zp is the set of Unicode characters in category Zp (Separator, paragraph).
  242. unicode.Zs: "Zs", // Zs is the set of Unicode characters in category Zs (Separator, space).
  243. }
  244. t.byFirstRune = map[rune][]*unicode.RangeTable{}
  245. for k := range t.tab2Name {
  246. r := firstRune(k)
  247. t.byFirstRune[r] = append(t.byFirstRune[r], k)
  248. }
  249. })
  250. }
  251. func (t *unicodeTables) find(s []rune) string {
  252. if len(s) == 0 {
  253. return ""
  254. }
  255. t.boot()
  256. sid := hashString(s)
  257. for _, rt := range t.byFirstRune[s[0]] {
  258. if t.hashString(rt) == sid {
  259. return t.tab2Name[rt]
  260. }
  261. }
  262. return ""
  263. }
  264. func (t *unicodeTables) hashString(rt *unicode.RangeTable) (s string) {
  265. if s, ok := t.tab2Hash[rt]; ok {
  266. return s
  267. }
  268. var ts []rune
  269. for _, v := range rt.R16 {
  270. switch {
  271. case v.Stride == 1:
  272. ts = append(ts, rune(v.Lo), rune(v.Hi))
  273. default:
  274. for r := v.Lo; r <= v.Hi; r += v.Stride {
  275. ts = append(ts, rune(r), rune(r))
  276. }
  277. }
  278. }
  279. for _, v := range rt.R32 {
  280. switch {
  281. case v.Stride == 1:
  282. ts = append(ts, rune(v.Lo), rune(v.Hi))
  283. default:
  284. for r := v.Lo; r <= v.Hi; r += v.Stride {
  285. ts = append(ts, rune(r), rune(r))
  286. }
  287. }
  288. }
  289. ts = normalizeRuneSlice(ts)
  290. s = hashString(ts)
  291. if t.tab2Hash == nil {
  292. t.tab2Hash = map[*unicode.RangeTable]string{}
  293. }
  294. t.tab2Hash[rt] = s
  295. return s
  296. }
  297. func firstRune(t *unicode.RangeTable) rune {
  298. if len(t.R16) != 0 {
  299. return rune(t.R16[0].Lo)
  300. }
  301. return rune(t.R32[0].Lo)
  302. }
  303. func hashString(s []rune) string {
  304. var b strings.Builder
  305. for _, v := range s {
  306. for i := 0; i < 4; i++ {
  307. b.WriteByte(byte(v))
  308. v >>= 8
  309. }
  310. }
  311. return b.String()
  312. }
  313. func normalizeRuneSlice(s []rune) (t []rune) {
  314. for i := 0; i < len(s); i += 2 {
  315. lo, hi := s[i], s[i+1]
  316. tl := len(t)
  317. if tl == 0 {
  318. t = append(t, lo, hi)
  319. continue
  320. }
  321. tlo, thi := t[tl-2], t[tl-1]
  322. if tlo == thi && lo == tlo+1 {
  323. t[tl-1] = hi
  324. continue
  325. }
  326. t = append(t, lo, hi)
  327. }
  328. return t
  329. }
// fn collects the labels and goto statements found in one function body.
type fn struct {
	// gotoStatements maps a label name to the set of goto statements
	// targeting it.
	gotoStatements map[string]map[gc.Node]struct{}
	// labeledStatements maps a label name to the statement carrying it.
	labeledStatements map[string]*gc.LabeledStmtNode
}
  334. func newFn() *fn {
  335. return &fn{
  336. gotoStatements: map[string]map[gc.Node]struct{}{},
  337. labeledStatements: map[string]*gc.LabeledStmtNode{},
  338. }
  339. }
// post buffers the generated source and post-processes it before it is
// written out.
type post struct {
	// b holds the output; Write appends to it and post replaces it with the
	// re-rendered source.
	b *bytes.Buffer
	// kill is the node set handed to gc.NodeSource when re-rendering;
	// functionBody adds redundant gotos and unused labels to it. Presumably
	// NodeSource leaves these nodes out of its result.
	kill map[gc.Node]struct{}
	// pkg is the placeholder package clause, if any, whose length source
	// strips from the front of the re-rendered output.
	pkg string
}
  345. func newPost() *post {
  346. return &post{
  347. b: bytes.NewBuffer(nil),
  348. kill: map[gc.Node]struct{}{},
  349. }
  350. }
  351. func (p *post) Write(b []byte) (int, error) { return p.b.Write(b) }
  352. func (p *post) source(ast *gc.AST) {
  353. s := gc.NodeSource(ast.SourceFile, true, p.kill)
  354. for {
  355. n := len(s)
  356. s = strings.ReplaceAll(s, "\n\n\tif", "\n\tif")
  357. s = strings.ReplaceAll(s, "\n\nl", "\nl")
  358. s = strings.ReplaceAll(s, "\n\n\treturn", "\n\treturn")
  359. if len(s) == n {
  360. break
  361. }
  362. }
  363. p.b.WriteString(s[len(p.pkg):])
  364. }
  365. func (p *post) post() (err error) {
  366. ast, err := gc.ParseFile("<output>", p.b.Bytes())
  367. if err != nil {
  368. return err
  369. }
  370. defer func() {
  371. p.b = bytes.NewBuffer(nil)
  372. p.source(ast)
  373. }()
  374. for n := ast.SourceFile.TopLevelDeclList; n != nil; n = n.List {
  375. switch x := n.TopLevelDecl.(type) {
  376. case *gc.FunctionDeclNode:
  377. p.functionBody(x.FunctionBody)
  378. }
  379. }
  380. return nil
  381. }
  382. func (p *post) functionBody(n *gc.FunctionBodyNode) {
  383. fn := newFn()
  384. p.block(fn, n.Block)
  385. for label, labelNode := range fn.labeledStatements {
  386. labelLine := labelNode.Position().Line
  387. gotos := fn.gotoStatements[label]
  388. for gotoNode := range gotos {
  389. if line := gotoNode.Position().Line; line == labelLine || line == labelLine-1 {
  390. p.kill[gotoNode] = struct{}{}
  391. delete(gotos, gotoNode)
  392. }
  393. }
  394. if len(gotos) == 0 {
  395. p.kill[labelNode.Label] = struct{}{}
  396. p.kill[labelNode.COLON] = struct{}{}
  397. }
  398. }
  399. }
  400. func (p *post) block(fn *fn, n *gc.BlockNode) {
  401. for l := n.StatementList; l != nil; l = l.List {
  402. p.statement(fn, l.Statement)
  403. }
  404. }
  405. func (p *post) statement(fn *fn, n gc.Node) {
  406. switch x := n.(type) {
  407. case *gc.LabeledStmtNode:
  408. fn.labeledStatements[x.Label.IDENT.Src()] = x
  409. p.statement(fn, x.Statement)
  410. case *gc.GotoStmtNode:
  411. m := fn.gotoStatements[x.Label.IDENT.Src()]
  412. if m == nil {
  413. m = map[gc.Node]struct{}{}
  414. fn.gotoStatements[x.Label.IDENT.Src()] = m
  415. }
  416. m[x] = struct{}{}
  417. case *gc.IfStmtNode:
  418. p.block(fn, x.Block)
  419. case *gc.IfElseStmtNode:
  420. p.block(fn, x.Block)
  421. switch y := x.ElseClause.(type) {
  422. case *gc.IfStmtNode, *gc.IfElseStmtNode:
  423. p.statement(fn, y)
  424. case *gc.BlockNode:
  425. p.block(fn, y)
  426. }
  427. }
  428. }