// Copyright 2023 The Rec Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package rec // import "modernc.org/rec/lib" import ( "bytes" "flag" "fmt" "io" "os" "path/filepath" "strconv" "strings" "sync" "unicode" "modernc.org/gc/v3" ) const pkg = "package x\n" type cfg struct { args []string nonFlagArgs []string oCaseInsensitive *bool oDbg *bool oImport *string oLex *string oLexString *string oLexStringUtf8 *string oLexUtf8 *string oMatch *string oMatchString *string oMatchStringUtf8 *string oMatchUtf8 *string oPkg *string oRx *string oTrc *bool post *post stderr io.Writer stdout io.Writer lcase string werr error closed bool } var ( tables unicodeTables ) // Main executes the CLI rec command. The 'args' argument should not include // the command name. func Main(args []string, stdout, stderr io.Writer) (rc int, err error) { c := &cfg{ args: args, post: newPost(), stdout: stdout, stderr: stderr, } if c.stdout == nil { c.stdout = io.Discard } if c.stderr == nil { c.stderr = io.Discard } defer func() { if perr := c.post.post(); perr != nil && err == nil { rc = 1 err = perr } if _, cerr := io.Copy(c.stdout, c.post.b); cerr != nil && err == nil { rc = 1 err = cerr } }() set := flag.NewFlagSet("rec", flag.ContinueOnError) c.oCaseInsensitive = set.Bool("i", false, "case insensitive") c.oDbg = set.Bool("dbg", false, "enable debug output") c.oImport = set.String("import", "", "comma separated list if imports") c.oLex = set.String("lex", "", "name of the lex function") c.oLexString = set.String("lexstring", "", "name of the lexstring function") c.oLexStringUtf8 = set.String("lexstringutf8", "", "name of the lexstring utf8 function") c.oLexUtf8 = set.String("lexutf8", "", "name of the lexutf8 function") c.oMatch = set.String("match", "", "name of the match function") c.oMatchString = set.String("matchstring", "", "name of the matchstring function") c.oMatchStringUtf8 = set.String("matchstringutf8", "", "name of the matchstring utf8 function") c.oMatchUtf8 = set.String("matchutf8", "", "name of the matchutf8 function") c.oPkg = set.String("pkg", "", "wrap output in package ") c.oRx = set.String("rx", "", "make the machine a method of ") c.oTrc = set.Bool("trc", false, "inject tracing") set.SetOutput(stderr) set.Parse(args) c.nonFlagArgs = set.Args() if *c.oCaseInsensitive { for i, v := range c.nonFlagArgs { c.nonFlagArgs[i] = strings.ToLower(v) } c.lcase = "defer func() { r = unicode.ToLower(r) }();" } if *c.oPkg == "" { c.post.pkg = pkg c.post.b.WriteString(pkg) } switch { case *c.oLex != "": err = c.lex(*c.oLex, set.Args(), false, false) case *c.oLexUtf8 != "": err = c.lex(*c.oLexUtf8, set.Args(), true, false) case *c.oLexString != "": err = c.lex(*c.oLexString, set.Args(), false, true) case *c.oLexStringUtf8 != "": err = c.lex(*c.oLexStringUtf8, set.Args(), true, true) case *c.oMatch != "": err = c.match(*c.oMatch, set.Args(), false, false) case *c.oMatchUtf8 != "": err = c.match(*c.oMatchUtf8, set.Args(), true, false) case *c.oMatchString != "": err = c.match(*c.oMatchString, set.Args(), false, true) case *c.oMatchStringUtf8 != "": err = c.match(*c.oMatchStringUtf8, set.Args(), true, true) default: return 2, fmt.Errorf("no machine selected, use '-match=Match', for example") } if err != nil { return 1, fmt.Errorf("%s", err) } if c.werr != nil { return 1, c.werr } return 0, nil } func (c *cfg) rxString() (r string) { if s := *c.oRx; s != "" { r = fmt.Sprintf("(%s) ", s) } return r } func (c *cfg) prolog(hideNonFlagArgs bool) { s := *c.oPkg if s == "" { return } *c.oPkg = "" args := append([]string(nil), c.args...) if hideNonFlagArgs { m := map[string]struct{}{} for _, v := range c.nonFlagArgs { m[v] = struct{}{} } args = args[:0] for _, v := range c.args { if _, ok := m[v]; !ok { args = append(args, v) } } args = append(args, "<...>") } c.w(`// Code generated by %v, DO NOT EDIT. package %s `, append([]string{filepath.Base(os.Args[0])}, args...), s) if s = *c.oImport; s == "" { return } c.w(` import ( `) for _, v := range strings.Split(s, ",") { c.w("\t%q\n", v) } c.w(")\n\n") } func (c *cfg) w(s string, args ...interface{}) { if c.closed { return } _, c.werr = fmt.Fprintf(c.post, s, args...) if c.werr != nil { c.closed = true } } func (c *cfg) trc(s, args string) { c.w("\tfmt.Fprintf(os.Stderr, %q, %s);", "TRC "+s, args) } func pretty(s string) string { for i := 0; i < len(s); i++ { if s[i] < ' ' || s[i] > 0x7e { return strconv.QuoteToASCII(s) } } return s } type unicodeTables struct { once sync.Once byFirstRune map[rune][]*unicode.RangeTable tab2Hash map[*unicode.RangeTable]string tab2Name map[*unicode.RangeTable]string } func (t *unicodeTables) boot() { t.once.Do(func() { t.tab2Name = map[*unicode.RangeTable]string{ unicode.Cc: "Cc", // Cc is the set of Unicode characters in category Cc (Other, control). unicode.Cf: "Cf", // Cf is the set of Unicode characters in category Cf (Other, format). unicode.Co: "Co", // Co is the set of Unicode characters in category Co (Other, private use). unicode.Cs: "Cs", // Cs is the set of Unicode characters in category Cs (Other, surrogate). unicode.Digit: "Nd", // Digit is the set of Unicode characters with the "decimal digit" property. unicode.Nd: "Nd", // Nd is the set of Unicode characters in category Nd (Number, decimal digit). unicode.Letter: "L", // Letter/L is the set of Unicode letters, category L. unicode.L: "L", unicode.Lm: "Lm", // Lm is the set of Unicode characters in category Lm (Letter, modifier). unicode.Lo: "Lo", // Lo is the set of Unicode characters in category Lo (Letter, other). unicode.Lower: "Ll", // Lower is the set of Unicode lower case letters. unicode.Ll: "Ll", // Ll is the set of Unicode characters in category Ll (Letter, lowercase). unicode.Mark: "M", // Mark/M is the set of Unicode mark characters, category M. unicode.M: "M", unicode.Mc: "Mc", // Mc is the set of Unicode characters in category Mc (Mark, spacing combining). unicode.Me: "Me", // Me is the set of Unicode characters in category Me (Mark, enclosing). unicode.Mn: "Mn", // Mn is the set of Unicode characters in category Mn (Mark, nonspacing). unicode.Nl: "Nl", // Nl is the set of Unicode characters in category Nl (Number, letter). unicode.No: "No", // No is the set of Unicode characters in category No (Number, other). unicode.Number: "N", // Number/N is the set of Unicode number characters, category N. unicode.N: "N", unicode.Other: "C", // Other/C is the set of Unicode control and special characters, category C. unicode.C: "C", unicode.Pc: "Pc", // Pc is the set of Unicode characters in category Pc (Punctuation, connector). unicode.Pd: "Pd", // Pd is the set of Unicode characters in category Pd (Punctuation, dash). unicode.Pe: "Pe", // Pe is the set of Unicode characters in category Pe (Punctuation, close). unicode.Pf: "Pf", // Pf is the set of Unicode characters in category Pf (Punctuation, final quote). unicode.Pi: "Pi", // Pi is the set of Unicode characters in category Pi (Punctuation, initial quote). unicode.Po: "Po", // Po is the set of Unicode characters in category Po (Punctuation, other). unicode.Ps: "Ps", // Ps is the set of Unicode characters in category Ps (Punctuation, open). unicode.Punct: "P", // Punct/P is the set of Unicode punctuation characters, category P. unicode.P: "P", unicode.Sc: "Sc", // Sc is the set of Unicode characters in category Sc (Symbol, currency). unicode.Sk: "Sk", // Sk is the set of Unicode characters in category Sk (Symbol, modifier). unicode.Sm: "Sm", // Sm is the set of Unicode characters in category Sm (Symbol, math). unicode.So: "So", // So is the set of Unicode characters in category So (Symbol, other). unicode.Space: "Z", // Space/Z is the set of Unicode space characters, category Z. unicode.Z: "Z", unicode.Symbol: "S", // Symbol/S is the set of Unicode symbol characters, category S. unicode.S: "S", unicode.Title: "Lt", // Title is the set of Unicode title case letters. unicode.Lt: "Lt", // Lt is the set of Unicode characters in category Lt (Letter, titlecase). unicode.Upper: "Lu", // Upper is the set of Unicode upper case letters. unicode.Lu: "Lu", // Lu is the set of Unicode characters in category Lu (Letter, uppercase). unicode.Zl: "Zl", // Zl is the set of Unicode characters in category Zl (Separator, line). unicode.Zp: "Zp", // Zp is the set of Unicode characters in category Zp (Separator, paragraph). unicode.Zs: "Zs", // Zs is the set of Unicode characters in category Zs (Separator, space). } t.byFirstRune = map[rune][]*unicode.RangeTable{} for k := range t.tab2Name { r := firstRune(k) t.byFirstRune[r] = append(t.byFirstRune[r], k) } }) } func (t *unicodeTables) find(s []rune) string { if len(s) == 0 { return "" } t.boot() sid := hashString(s) for _, rt := range t.byFirstRune[s[0]] { if t.hashString(rt) == sid { return t.tab2Name[rt] } } return "" } func (t *unicodeTables) hashString(rt *unicode.RangeTable) (s string) { if s, ok := t.tab2Hash[rt]; ok { return s } var ts []rune for _, v := range rt.R16 { switch { case v.Stride == 1: ts = append(ts, rune(v.Lo), rune(v.Hi)) default: for r := v.Lo; r <= v.Hi; r += v.Stride { ts = append(ts, rune(r), rune(r)) } } } for _, v := range rt.R32 { switch { case v.Stride == 1: ts = append(ts, rune(v.Lo), rune(v.Hi)) default: for r := v.Lo; r <= v.Hi; r += v.Stride { ts = append(ts, rune(r), rune(r)) } } } ts = normalizeRuneSlice(ts) s = hashString(ts) if t.tab2Hash == nil { t.tab2Hash = map[*unicode.RangeTable]string{} } t.tab2Hash[rt] = s return s } func firstRune(t *unicode.RangeTable) rune { if len(t.R16) != 0 { return rune(t.R16[0].Lo) } return rune(t.R32[0].Lo) } func hashString(s []rune) string { var b strings.Builder for _, v := range s { for i := 0; i < 4; i++ { b.WriteByte(byte(v)) v >>= 8 } } return b.String() } func normalizeRuneSlice(s []rune) (t []rune) { for i := 0; i < len(s); i += 2 { lo, hi := s[i], s[i+1] tl := len(t) if tl == 0 { t = append(t, lo, hi) continue } tlo, thi := t[tl-2], t[tl-1] if tlo == thi && lo == tlo+1 { t[tl-1] = hi continue } t = append(t, lo, hi) } return t } type fn struct { gotoStatements map[string]map[gc.Node]struct{} labeledStatements map[string]*gc.LabeledStmtNode } func newFn() *fn { return &fn{ gotoStatements: map[string]map[gc.Node]struct{}{}, labeledStatements: map[string]*gc.LabeledStmtNode{}, } } type post struct { b *bytes.Buffer kill map[gc.Node]struct{} pkg string } func newPost() *post { return &post{ b: bytes.NewBuffer(nil), kill: map[gc.Node]struct{}{}, } } func (p *post) Write(b []byte) (int, error) { return p.b.Write(b) } func (p *post) source(ast *gc.AST) { s := gc.NodeSource(ast.SourceFile, true, p.kill) for { n := len(s) s = strings.ReplaceAll(s, "\n\n\tif", "\n\tif") s = strings.ReplaceAll(s, "\n\nl", "\nl") s = strings.ReplaceAll(s, "\n\n\treturn", "\n\treturn") if len(s) == n { break } } p.b.WriteString(s[len(p.pkg):]) } func (p *post) post() (err error) { ast, err := gc.ParseFile("", p.b.Bytes()) if err != nil { return err } defer func() { p.b = bytes.NewBuffer(nil) p.source(ast) }() for n := ast.SourceFile.TopLevelDeclList; n != nil; n = n.List { switch x := n.TopLevelDecl.(type) { case *gc.FunctionDeclNode: p.functionBody(x.FunctionBody) } } return nil } func (p *post) functionBody(n *gc.FunctionBodyNode) { fn := newFn() p.block(fn, n.Block) for label, labelNode := range fn.labeledStatements { labelLine := labelNode.Position().Line gotos := fn.gotoStatements[label] for gotoNode := range gotos { if line := gotoNode.Position().Line; line == labelLine || line == labelLine-1 { p.kill[gotoNode] = struct{}{} delete(gotos, gotoNode) } } if len(gotos) == 0 { p.kill[labelNode.Label] = struct{}{} p.kill[labelNode.COLON] = struct{}{} } } } func (p *post) block(fn *fn, n *gc.BlockNode) { for l := n.StatementList; l != nil; l = l.List { p.statement(fn, l.Statement) } } func (p *post) statement(fn *fn, n gc.Node) { switch x := n.(type) { case *gc.LabeledStmtNode: fn.labeledStatements[x.Label.IDENT.Src()] = x p.statement(fn, x.Statement) case *gc.GotoStmtNode: m := fn.gotoStatements[x.Label.IDENT.Src()] if m == nil { m = map[gc.Node]struct{}{} fn.gotoStatements[x.Label.IDENT.Src()] = m } m[x] = struct{}{} case *gc.IfStmtNode: p.block(fn, x.Block) case *gc.IfElseStmtNode: p.block(fn, x.Block) switch y := x.ElseClause.(type) { case *gc.IfStmtNode, *gc.IfElseStmtNode: p.statement(fn, y) case *gc.BlockNode: p.block(fn, y) } } }