| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477 |
- // Copyright 2023 The Rec Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package rec // import "modernc.org/rec/lib"
- import (
- "bytes"
- "flag"
- "fmt"
- "io"
- "os"
- "path/filepath"
- "strconv"
- "strings"
- "sync"
- "unicode"
- "modernc.org/gc/v3"
- )
// pkg is a placeholder package clause. When -pkg is not given, Main writes it
// ahead of the generated code so that the output can be parsed as a Go file
// during post-processing; post.source strips exactly this prefix again before
// the final output is produced.
const (
	pkg = "package x\n"
)
// cfg holds the parsed command-line options and the output state of a single
// Main invocation.
type cfg struct {
	args             []string  // arguments passed to Main (without the command name)
	nonFlagArgs      []string  // arguments left after flag parsing; lowercased in place when -i is set
	oCaseInsensitive *bool     // -i: case insensitive matching
	oDbg             *bool     // -dbg: enable debug output
	oImport          *string   // -import: comma separated import paths emitted by prolog
	oLex             *string   // -lex: name of the generated lex function
	oLexString       *string   // -lexstring
	oLexStringUtf8   *string   // -lexstringutf8
	oLexUtf8         *string   // -lexutf8
	oMatch           *string   // -match: name of the generated match function
	oMatchString     *string   // -matchstring
	oMatchStringUtf8 *string   // -matchstringutf8
	oMatchUtf8       *string   // -matchutf8
	oPkg             *string   // -pkg: package name; prolog clears it after writing the header
	oRx              *string   // -rx: receiver type, rendered by rxString as "(<value>) "
	oTrc             *bool     // -trc: inject tracing
	post             *post     // sink for all generated text (see w); post-processed when Main returns
	stderr           io.Writer // never nil after Main's setup (io.Discard fallback)
	stdout           io.Writer // never nil after Main's setup (io.Discard fallback)
	lcase            string    // code snippet set by -i that lowercases r in generated code
	werr             error     // first error returned by a write in w
	closed           bool      // set once werr is non-nil; later w calls are no-ops
}
- var (
- tables unicodeTables
- )
- // Main executes the CLI rec command. The 'args' argument should not include
- // the command name.
- func Main(args []string, stdout, stderr io.Writer) (rc int, err error) {
- c := &cfg{
- args: args,
- post: newPost(),
- stdout: stdout,
- stderr: stderr,
- }
- if c.stdout == nil {
- c.stdout = io.Discard
- }
- if c.stderr == nil {
- c.stderr = io.Discard
- }
- defer func() {
- if perr := c.post.post(); perr != nil && err == nil {
- rc = 1
- err = perr
- }
- if _, cerr := io.Copy(c.stdout, c.post.b); cerr != nil && err == nil {
- rc = 1
- err = cerr
- }
- }()
- set := flag.NewFlagSet("rec", flag.ContinueOnError)
- c.oCaseInsensitive = set.Bool("i", false, "case insensitive")
- c.oDbg = set.Bool("dbg", false, "enable debug output")
- c.oImport = set.String("import", "", "comma separated list if imports")
- c.oLex = set.String("lex", "", "name of the lex function")
- c.oLexString = set.String("lexstring", "", "name of the lexstring function")
- c.oLexStringUtf8 = set.String("lexstringutf8", "", "name of the lexstring utf8 function")
- c.oLexUtf8 = set.String("lexutf8", "", "name of the lexutf8 function")
- c.oMatch = set.String("match", "", "name of the match function")
- c.oMatchString = set.String("matchstring", "", "name of the matchstring function")
- c.oMatchStringUtf8 = set.String("matchstringutf8", "", "name of the matchstring utf8 function")
- c.oMatchUtf8 = set.String("matchutf8", "", "name of the matchutf8 function")
- c.oPkg = set.String("pkg", "", "wrap output in package <argument>")
- c.oRx = set.String("rx", "", "make the machine a method of <argument>")
- c.oTrc = set.Bool("trc", false, "inject tracing")
- set.SetOutput(stderr)
- set.Parse(args)
- c.nonFlagArgs = set.Args()
- if *c.oCaseInsensitive {
- for i, v := range c.nonFlagArgs {
- c.nonFlagArgs[i] = strings.ToLower(v)
- }
- c.lcase = "defer func() { r = unicode.ToLower(r) }();"
- }
- if *c.oPkg == "" {
- c.post.pkg = pkg
- c.post.b.WriteString(pkg)
- }
- switch {
- case *c.oLex != "":
- err = c.lex(*c.oLex, set.Args(), false, false)
- case *c.oLexUtf8 != "":
- err = c.lex(*c.oLexUtf8, set.Args(), true, false)
- case *c.oLexString != "":
- err = c.lex(*c.oLexString, set.Args(), false, true)
- case *c.oLexStringUtf8 != "":
- err = c.lex(*c.oLexStringUtf8, set.Args(), true, true)
- case *c.oMatch != "":
- err = c.match(*c.oMatch, set.Args(), false, false)
- case *c.oMatchUtf8 != "":
- err = c.match(*c.oMatchUtf8, set.Args(), true, false)
- case *c.oMatchString != "":
- err = c.match(*c.oMatchString, set.Args(), false, true)
- case *c.oMatchStringUtf8 != "":
- err = c.match(*c.oMatchStringUtf8, set.Args(), true, true)
- default:
- return 2, fmt.Errorf("no machine selected, use '-match=Match', for example")
- }
- if err != nil {
- return 1, fmt.Errorf("%s", err)
- }
- if c.werr != nil {
- return 1, c.werr
- }
- return 0, nil
- }
- func (c *cfg) rxString() (r string) {
- if s := *c.oRx; s != "" {
- r = fmt.Sprintf("(%s) ", s)
- }
- return r
- }
- func (c *cfg) prolog(hideNonFlagArgs bool) {
- s := *c.oPkg
- if s == "" {
- return
- }
- *c.oPkg = ""
- args := append([]string(nil), c.args...)
- if hideNonFlagArgs {
- m := map[string]struct{}{}
- for _, v := range c.nonFlagArgs {
- m[v] = struct{}{}
- }
- args = args[:0]
- for _, v := range c.args {
- if _, ok := m[v]; !ok {
- args = append(args, v)
- }
- }
- args = append(args, "<...>")
- }
- c.w(`// Code generated by %v, DO NOT EDIT.
- package %s
- `, append([]string{filepath.Base(os.Args[0])}, args...), s)
- if s = *c.oImport; s == "" {
- return
- }
- c.w(`
- import (
- `)
- for _, v := range strings.Split(s, ",") {
- c.w("\t%q\n", v)
- }
- c.w(")\n\n")
- }
- func (c *cfg) w(s string, args ...interface{}) {
- if c.closed {
- return
- }
- _, c.werr = fmt.Fprintf(c.post, s, args...)
- if c.werr != nil {
- c.closed = true
- }
- }
- func (c *cfg) trc(s, args string) {
- c.w("\tfmt.Fprintf(os.Stderr, %q, %s);", "TRC "+s, args)
- }
// pretty returns s unchanged when it consists only of printable ASCII
// (space through '~'). Otherwise it returns s as a double-quoted Go literal in
// which every non-ASCII or non-printable character is escaped.
func pretty(s string) string {
	needsQuoting := strings.IndexFunc(s, func(r rune) bool {
		return r < ' ' || r > 0x7e
	}) >= 0
	if !needsQuoting {
		return s
	}
	return strconv.QuoteToASCII(s)
}
// unicodeTables maps the category tables of package unicode to their short
// category names. It lets find recognize when a normalized list of rune
// ranges is exactly one of those tables.
type unicodeTables struct {
	once        sync.Once                         // guards the one-time initialization in boot
	byFirstRune map[rune][]*unicode.RangeTable    // tables grouped by firstRune of each table
	tab2Hash    map[*unicode.RangeTable]string    // memoized results of the hashString method
	tab2Name    map[*unicode.RangeTable]string    // table -> short category name, e.g. "Lu"
}
- func (t *unicodeTables) boot() {
- t.once.Do(func() {
- t.tab2Name = map[*unicode.RangeTable]string{
- unicode.Cc: "Cc", // Cc is the set of Unicode characters in category Cc (Other, control).
- unicode.Cf: "Cf", // Cf is the set of Unicode characters in category Cf (Other, format).
- unicode.Co: "Co", // Co is the set of Unicode characters in category Co (Other, private use).
- unicode.Cs: "Cs", // Cs is the set of Unicode characters in category Cs (Other, surrogate).
- unicode.Digit: "Nd", // Digit is the set of Unicode characters with the "decimal digit" property.
- unicode.Nd: "Nd", // Nd is the set of Unicode characters in category Nd (Number, decimal digit).
- unicode.Letter: "L", // Letter/L is the set of Unicode letters, category L.
- unicode.L: "L",
- unicode.Lm: "Lm", // Lm is the set of Unicode characters in category Lm (Letter, modifier).
- unicode.Lo: "Lo", // Lo is the set of Unicode characters in category Lo (Letter, other).
- unicode.Lower: "Ll", // Lower is the set of Unicode lower case letters.
- unicode.Ll: "Ll", // Ll is the set of Unicode characters in category Ll (Letter, lowercase).
- unicode.Mark: "M", // Mark/M is the set of Unicode mark characters, category M.
- unicode.M: "M",
- unicode.Mc: "Mc", // Mc is the set of Unicode characters in category Mc (Mark, spacing combining).
- unicode.Me: "Me", // Me is the set of Unicode characters in category Me (Mark, enclosing).
- unicode.Mn: "Mn", // Mn is the set of Unicode characters in category Mn (Mark, nonspacing).
- unicode.Nl: "Nl", // Nl is the set of Unicode characters in category Nl (Number, letter).
- unicode.No: "No", // No is the set of Unicode characters in category No (Number, other).
- unicode.Number: "N", // Number/N is the set of Unicode number characters, category N.
- unicode.N: "N",
- unicode.Other: "C", // Other/C is the set of Unicode control and special characters, category C.
- unicode.C: "C",
- unicode.Pc: "Pc", // Pc is the set of Unicode characters in category Pc (Punctuation, connector).
- unicode.Pd: "Pd", // Pd is the set of Unicode characters in category Pd (Punctuation, dash).
- unicode.Pe: "Pe", // Pe is the set of Unicode characters in category Pe (Punctuation, close).
- unicode.Pf: "Pf", // Pf is the set of Unicode characters in category Pf (Punctuation, final quote).
- unicode.Pi: "Pi", // Pi is the set of Unicode characters in category Pi (Punctuation, initial quote).
- unicode.Po: "Po", // Po is the set of Unicode characters in category Po (Punctuation, other).
- unicode.Ps: "Ps", // Ps is the set of Unicode characters in category Ps (Punctuation, open).
- unicode.Punct: "P", // Punct/P is the set of Unicode punctuation characters, category P.
- unicode.P: "P",
- unicode.Sc: "Sc", // Sc is the set of Unicode characters in category Sc (Symbol, currency).
- unicode.Sk: "Sk", // Sk is the set of Unicode characters in category Sk (Symbol, modifier).
- unicode.Sm: "Sm", // Sm is the set of Unicode characters in category Sm (Symbol, math).
- unicode.So: "So", // So is the set of Unicode characters in category So (Symbol, other).
- unicode.Space: "Z", // Space/Z is the set of Unicode space characters, category Z.
- unicode.Z: "Z",
- unicode.Symbol: "S", // Symbol/S is the set of Unicode symbol characters, category S.
- unicode.S: "S",
- unicode.Title: "Lt", // Title is the set of Unicode title case letters.
- unicode.Lt: "Lt", // Lt is the set of Unicode characters in category Lt (Letter, titlecase).
- unicode.Upper: "Lu", // Upper is the set of Unicode upper case letters.
- unicode.Lu: "Lu", // Lu is the set of Unicode characters in category Lu (Letter, uppercase).
- unicode.Zl: "Zl", // Zl is the set of Unicode characters in category Zl (Separator, line).
- unicode.Zp: "Zp", // Zp is the set of Unicode characters in category Zp (Separator, paragraph).
- unicode.Zs: "Zs", // Zs is the set of Unicode characters in category Zs (Separator, space).
- }
- t.byFirstRune = map[rune][]*unicode.RangeTable{}
- for k := range t.tab2Name {
- r := firstRune(k)
- t.byFirstRune[r] = append(t.byFirstRune[r], k)
- }
- })
- }
- func (t *unicodeTables) find(s []rune) string {
- if len(s) == 0 {
- return ""
- }
- t.boot()
- sid := hashString(s)
- for _, rt := range t.byFirstRune[s[0]] {
- if t.hashString(rt) == sid {
- return t.tab2Name[rt]
- }
- }
- return ""
- }
- func (t *unicodeTables) hashString(rt *unicode.RangeTable) (s string) {
- if s, ok := t.tab2Hash[rt]; ok {
- return s
- }
- var ts []rune
- for _, v := range rt.R16 {
- switch {
- case v.Stride == 1:
- ts = append(ts, rune(v.Lo), rune(v.Hi))
- default:
- for r := v.Lo; r <= v.Hi; r += v.Stride {
- ts = append(ts, rune(r), rune(r))
- }
- }
- }
- for _, v := range rt.R32 {
- switch {
- case v.Stride == 1:
- ts = append(ts, rune(v.Lo), rune(v.Hi))
- default:
- for r := v.Lo; r <= v.Hi; r += v.Stride {
- ts = append(ts, rune(r), rune(r))
- }
- }
- }
- ts = normalizeRuneSlice(ts)
- s = hashString(ts)
- if t.tab2Hash == nil {
- t.tab2Hash = map[*unicode.RangeTable]string{}
- }
- t.tab2Hash[rt] = s
- return s
- }
// firstRune returns the Lo bound of t's first Range16 entry. If t has no
// Range16 entries, it returns the Lo bound of the first Range32 entry instead.
// It panics if t has no ranges at all.
func firstRune(t *unicode.RangeTable) rune {
	if len(t.R16) == 0 {
		return rune(t.R32[0].Lo)
	}
	return rune(t.R16[0].Lo)
}
// hashString encodes s as a string of four bytes per rune, least significant
// byte first. Equal slices give equal strings, so the result serves as a
// comparable key for a rune sequence.
func hashString(s []rune) string {
	buf := make([]byte, 0, 4*len(s))
	for _, r := range s {
		buf = append(buf, byte(r), byte(r>>8), byte(r>>16), byte(r>>24))
	}
	return string(buf)
}
// normalizeRuneSlice reads s as consecutive (lo, hi) pairs and returns them in
// a freshly allocated slice (s is not modified). There is one exception: when
// the previously emitted pair is a single rune x (lo == hi) and the next pair
// starts at x+1, the two are merged into one pair. s must have even length.
func normalizeRuneSlice(s []rune) (t []rune) {
	for len(s) > 0 {
		lo, hi := s[0], s[1]
		s = s[2:]
		n := len(t)
		if n != 0 && t[n-2] == t[n-1] && lo == t[n-2]+1 {
			// Previous pair is the single rune just below lo: extend it.
			t[n-1] = hi
			continue
		}
		t = append(t, lo, hi)
	}
	return t
}
// fn collects, for one function body, the labeled statements and the goto
// statements that reference them, both keyed by label name.
type fn struct {
	gotoStatements    map[string]map[gc.Node]struct{} // label name -> set of goto statements naming it
	labeledStatements map[string]*gc.LabeledStmtNode  // label name -> labeled statement
}
- func newFn() *fn {
- return &fn{
- gotoStatements: map[string]map[gc.Node]struct{}{},
- labeledStatements: map[string]*gc.LabeledStmtNode{},
- }
- }
// post accumulates the generated source and post-processes it. Redundant
// gotos and unused labels found by functionBody are recorded in kill, and the
// file is then re-rendered.
type post struct {
	b    *bytes.Buffer        // generated source; written through Write
	kill map[gc.Node]struct{} // nodes handed to gc.NodeSource by source, to be dropped from the output
	pkg  string               // placeholder package clause prefix that source strips ("" when -pkg is set)
}
- func newPost() *post {
- return &post{
- b: bytes.NewBuffer(nil),
- kill: map[gc.Node]struct{}{},
- }
- }
- func (p *post) Write(b []byte) (int, error) { return p.b.Write(b) }
- func (p *post) source(ast *gc.AST) {
- s := gc.NodeSource(ast.SourceFile, true, p.kill)
- for {
- n := len(s)
- s = strings.ReplaceAll(s, "\n\n\tif", "\n\tif")
- s = strings.ReplaceAll(s, "\n\nl", "\nl")
- s = strings.ReplaceAll(s, "\n\n\treturn", "\n\treturn")
- if len(s) == n {
- break
- }
- }
- p.b.WriteString(s[len(p.pkg):])
- }
- func (p *post) post() (err error) {
- ast, err := gc.ParseFile("<output>", p.b.Bytes())
- if err != nil {
- return err
- }
- defer func() {
- p.b = bytes.NewBuffer(nil)
- p.source(ast)
- }()
- for n := ast.SourceFile.TopLevelDeclList; n != nil; n = n.List {
- switch x := n.TopLevelDecl.(type) {
- case *gc.FunctionDeclNode:
- p.functionBody(x.FunctionBody)
- }
- }
- return nil
- }
- func (p *post) functionBody(n *gc.FunctionBodyNode) {
- fn := newFn()
- p.block(fn, n.Block)
- for label, labelNode := range fn.labeledStatements {
- labelLine := labelNode.Position().Line
- gotos := fn.gotoStatements[label]
- for gotoNode := range gotos {
- if line := gotoNode.Position().Line; line == labelLine || line == labelLine-1 {
- p.kill[gotoNode] = struct{}{}
- delete(gotos, gotoNode)
- }
- }
- if len(gotos) == 0 {
- p.kill[labelNode.Label] = struct{}{}
- p.kill[labelNode.COLON] = struct{}{}
- }
- }
- }
- func (p *post) block(fn *fn, n *gc.BlockNode) {
- for l := n.StatementList; l != nil; l = l.List {
- p.statement(fn, l.Statement)
- }
- }
- func (p *post) statement(fn *fn, n gc.Node) {
- switch x := n.(type) {
- case *gc.LabeledStmtNode:
- fn.labeledStatements[x.Label.IDENT.Src()] = x
- p.statement(fn, x.Statement)
- case *gc.GotoStmtNode:
- m := fn.gotoStatements[x.Label.IDENT.Src()]
- if m == nil {
- m = map[gc.Node]struct{}{}
- fn.gotoStatements[x.Label.IDENT.Src()] = m
- }
- m[x] = struct{}{}
- case *gc.IfStmtNode:
- p.block(fn, x.Block)
- case *gc.IfElseStmtNode:
- p.block(fn, x.Block)
- switch y := x.ElseClause.(type) {
- case *gc.IfStmtNode, *gc.IfElseStmtNode:
- p.statement(fn, y)
- case *gc.BlockNode:
- p.block(fn, y)
- }
- }
- }
|