flags.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. package z
  2. import (
  3. "fmt"
  4. "log"
  5. "os"
  6. "os/user"
  7. "path/filepath"
  8. "sort"
  9. "strconv"
  10. "strings"
  11. "time"
  12. "github.com/pkg/errors"
  13. )
// SuperFlagHelp makes it really easy to generate command line `--help` output for a SuperFlag. For
// example:
//
//	const flagDefaults = `enabled=true; path=some/path;`
//
//	var help string = z.NewSuperFlagHelp(flagDefaults).
//		Flag("enabled", "Turns on <something>.").
//		Flag("path", "The path to <something>.").
//		Flag("another", "Not present in defaults, but still included.").
//		String()
//
// The `help` string would then contain (after the Head text and its newline, which String always
// emits first):
//
//	enabled=true; Turns on <something>.
//	path=some/path; The path to <something>.
//	another=; Not present in defaults, but still included.
//
// All flags are sorted alphabetically for consistent `--help` output. Flags with default values are
// placed at the top, and everything else goes under. Only names registered via Flag are printed;
// a default with no matching Flag call does not appear.
type SuperFlagHelp struct {
	// head is the text emitted before the flag lines; set via Head.
	head string
	// defaults holds the parsed defaults string; a registered flag found here is listed in the
	// first group together with its default value.
	defaults *SuperFlag
	// flags maps a flag name to its description; populated via Flag.
	flags map[string]string
}
  38. func NewSuperFlagHelp(defaults string) *SuperFlagHelp {
  39. return &SuperFlagHelp{
  40. defaults: NewSuperFlag(defaults),
  41. flags: make(map[string]string, 0),
  42. }
  43. }
  44. func (h *SuperFlagHelp) Head(head string) *SuperFlagHelp {
  45. h.head = head
  46. return h
  47. }
  48. func (h *SuperFlagHelp) Flag(name, description string) *SuperFlagHelp {
  49. h.flags[name] = description
  50. return h
  51. }
  52. func (h *SuperFlagHelp) String() string {
  53. defaultLines := make([]string, 0)
  54. otherLines := make([]string, 0)
  55. for name, help := range h.flags {
  56. val, found := h.defaults.m[name]
  57. line := fmt.Sprintf(" %s=%s; %s\n", name, val, help)
  58. if found {
  59. defaultLines = append(defaultLines, line)
  60. } else {
  61. otherLines = append(otherLines, line)
  62. }
  63. }
  64. sort.Strings(defaultLines)
  65. sort.Strings(otherLines)
  66. dls := strings.Join(defaultLines, "")
  67. ols := strings.Join(otherLines, "")
  68. if len(h.defaults.m) == 0 && len(ols) == 0 {
  69. // remove last newline
  70. dls = dls[:len(dls)-1]
  71. }
  72. // remove last newline
  73. if len(h.defaults.m) == 0 && len(ols) > 1 {
  74. ols = ols[:len(ols)-1]
  75. }
  76. return h.head + "\n" + dls + ols
  77. }
  78. func parseFlag(flag string) (map[string]string, error) {
  79. kvm := make(map[string]string)
  80. for _, kv := range strings.Split(flag, ";") {
  81. if strings.TrimSpace(kv) == "" {
  82. continue
  83. }
  84. // For a non-empty separator, 0 < len(splits) ≤ 2.
  85. splits := strings.SplitN(kv, "=", 2)
  86. k := strings.TrimSpace(splits[0])
  87. if len(splits) < 2 {
  88. return nil, fmt.Errorf("superflag: missing value for '%s' in flag: %s", k, flag)
  89. }
  90. k = strings.ToLower(k)
  91. k = strings.ReplaceAll(k, "_", "-")
  92. kvm[k] = strings.TrimSpace(splits[1])
  93. }
  94. return kvm, nil
  95. }
// SuperFlag holds the options of one "key=value; key2=value2" flag string. Its map is always
// built by parseFlag, so keys are lower-cased with '_' replaced by '-', and values are
// whitespace-trimmed strings that the typed Get* methods convert on demand.
type SuperFlag struct {
	// m maps normalized option names to their raw string values.
	m map[string]string
}
  99. func NewSuperFlag(flag string) *SuperFlag {
  100. sf, err := newSuperFlagImpl(flag)
  101. if err != nil {
  102. log.Fatal(err)
  103. }
  104. return sf
  105. }
  106. func newSuperFlagImpl(flag string) (*SuperFlag, error) {
  107. m, err := parseFlag(flag)
  108. if err != nil {
  109. return nil, err
  110. }
  111. return &SuperFlag{m}, nil
  112. }
  113. func (sf *SuperFlag) String() string {
  114. if sf == nil {
  115. return ""
  116. }
  117. kvs := make([]string, 0, len(sf.m))
  118. for k, v := range sf.m {
  119. kvs = append(kvs, fmt.Sprintf("%s=%s", k, v))
  120. }
  121. return strings.Join(kvs, "; ")
  122. }
  123. func (sf *SuperFlag) MergeAndCheckDefault(flag string) *SuperFlag {
  124. sf, err := sf.mergeAndCheckDefaultImpl(flag)
  125. if err != nil {
  126. log.Fatal(err)
  127. }
  128. return sf
  129. }
  130. func (sf *SuperFlag) mergeAndCheckDefaultImpl(flag string) (*SuperFlag, error) {
  131. if sf == nil {
  132. m, err := parseFlag(flag)
  133. if err != nil {
  134. return nil, err
  135. }
  136. return &SuperFlag{m}, nil
  137. }
  138. src, err := parseFlag(flag)
  139. if err != nil {
  140. return nil, err
  141. }
  142. numKeys := len(sf.m)
  143. for k := range src {
  144. if _, ok := sf.m[k]; ok {
  145. numKeys--
  146. }
  147. }
  148. if numKeys != 0 {
  149. return nil, fmt.Errorf("superflag: found invalid options: %s.\nvalid options: %v", sf, flag)
  150. }
  151. for k, v := range src {
  152. if _, ok := sf.m[k]; !ok {
  153. sf.m[k] = v
  154. }
  155. }
  156. return sf, nil
  157. }
  158. func (sf *SuperFlag) Has(opt string) bool {
  159. val := sf.GetString(opt)
  160. return val != ""
  161. }
  162. func (sf *SuperFlag) GetDuration(opt string) time.Duration {
  163. val := sf.GetString(opt)
  164. if val == "" {
  165. return time.Duration(0)
  166. }
  167. if strings.Contains(val, "d") {
  168. val = strings.Replace(val, "d", "", 1)
  169. days, err := strconv.ParseInt(val, 0, 64)
  170. if err != nil {
  171. return time.Duration(0)
  172. }
  173. return time.Hour * 24 * time.Duration(days)
  174. }
  175. d, err := time.ParseDuration(val)
  176. if err != nil {
  177. return time.Duration(0)
  178. }
  179. return d
  180. }
  181. func (sf *SuperFlag) GetBool(opt string) bool {
  182. val := sf.GetString(opt)
  183. if val == "" {
  184. return false
  185. }
  186. b, err := strconv.ParseBool(val)
  187. if err != nil {
  188. err = errors.Wrapf(err,
  189. "Unable to parse %s as bool for key: %s. Options: %s\n",
  190. val, opt, sf)
  191. log.Fatalf("%+v", err)
  192. }
  193. return b
  194. }
  195. func (sf *SuperFlag) GetFloat64(opt string) float64 {
  196. val := sf.GetString(opt)
  197. if val == "" {
  198. return 0
  199. }
  200. f, err := strconv.ParseFloat(val, 64)
  201. if err != nil {
  202. err = errors.Wrapf(err,
  203. "Unable to parse %s as float64 for key: %s. Options: %s\n",
  204. val, opt, sf)
  205. log.Fatalf("%+v", err)
  206. }
  207. return f
  208. }
  209. func (sf *SuperFlag) GetInt64(opt string) int64 {
  210. val := sf.GetString(opt)
  211. if val == "" {
  212. return 0
  213. }
  214. i, err := strconv.ParseInt(val, 0, 64)
  215. if err != nil {
  216. err = errors.Wrapf(err,
  217. "Unable to parse %s as int64 for key: %s. Options: %s\n",
  218. val, opt, sf)
  219. log.Fatalf("%+v", err)
  220. }
  221. return i
  222. }
  223. func (sf *SuperFlag) GetUint64(opt string) uint64 {
  224. val := sf.GetString(opt)
  225. if val == "" {
  226. return 0
  227. }
  228. u, err := strconv.ParseUint(val, 0, 64)
  229. if err != nil {
  230. err = errors.Wrapf(err,
  231. "Unable to parse %s as uint64 for key: %s. Options: %s\n",
  232. val, opt, sf)
  233. log.Fatalf("%+v", err)
  234. }
  235. return u
  236. }
  237. func (sf *SuperFlag) GetUint32(opt string) uint32 {
  238. val := sf.GetString(opt)
  239. if val == "" {
  240. return 0
  241. }
  242. u, err := strconv.ParseUint(val, 0, 32)
  243. if err != nil {
  244. err = errors.Wrapf(err,
  245. "Unable to parse %s as uint32 for key: %s. Options: %s\n",
  246. val, opt, sf)
  247. log.Fatalf("%+v", err)
  248. }
  249. return uint32(u)
  250. }
  251. func (sf *SuperFlag) GetString(opt string) string {
  252. if sf == nil {
  253. return ""
  254. }
  255. return sf.m[opt]
  256. }
  257. func (sf *SuperFlag) GetPath(opt string) string {
  258. p := sf.GetString(opt)
  259. path, err := expandPath(p)
  260. if err != nil {
  261. log.Fatalf("Failed to get path: %+v", err)
  262. }
  263. return path
  264. }
  265. // expandPath expands the paths containing ~ to /home/user. It also computes the absolute path
  266. // from the relative paths. For example: ~/abc/../cef will be transformed to /home/user/cef.
  267. func expandPath(path string) (string, error) {
  268. if len(path) == 0 {
  269. return "", nil
  270. }
  271. if path[0] == '~' && (len(path) == 1 || os.IsPathSeparator(path[1])) {
  272. usr, err := user.Current()
  273. if err != nil {
  274. return "", errors.Wrap(err, "Failed to get the home directory of the user")
  275. }
  276. path = filepath.Join(usr.HomeDir, path[1:])
  277. }
  278. var err error
  279. path, err = filepath.Abs(path)
  280. if err != nil {
  281. return "", errors.Wrap(err, "Failed to generate absolute path")
  282. }
  283. return path, nil
  284. }