Golang flag源码阅读及自己实现

看了一下flag的实现,其实挺简单的。首先我们从一个使用例子入手:

package main

import (
	"flag"
	"log"
)

var (
	useWorker = flag.Bool("useWorker", false, "blabla")
)

func main() {
	flag.Parse()

	log.Printf("useWorker: %t", *useWorker)
}

可以看到这么几点信息:

  • useWorker的类型是 *bool,而且是flag.Bool返回的
  • 必须要执行 flag.Parse() 才能解析命令行

看到这里其实大家就应该猜想一下,咋实现的?对我来说,从看到它返回的是指针,我的猜测就是,在 flag.Bool 这里会新建一个bool类型的变量,将默认值赋值给它,然后返回这个bool类型的指针。在 Parse函数里,对这个指针所指向的值进行更新。

我们来看看源码:

func Bool(name string, value bool, usage string) *bool {
	return CommandLine.Bool(name, value, usage)
}

func (f *FlagSet) Bool(name string, value bool, usage string) *bool {
	p := new(bool)
	f.BoolVar(p, name, value, usage)
	return p
}

Parse 的代码就比较复杂了:

func Parse() {
	// Ignore errors; CommandLine is set for ExitOnError.
	CommandLine.Parse(os.Args[1:])
}

func (f *FlagSet) Parse(arguments []string) error {
	f.parsed = true
	f.args = arguments
	for {
		seen, err := f.parseOne()
...

func (f *FlagSet) parseOne() (bool, error) {
	if len(f.args) == 0 {
		return false, nil
	}
...
	if fv, ok := flag.Value.(boolFlag); ok && fv.IsBoolFlag() { // special case: doesn't need an arg
		if hasValue {
			if err := fv.Set(value); err != nil {
				return false, f.failf("invalid boolean value %q for -%s: %v", value, name, err)
			}
		} else {
			if err := fv.Set("true"); err != nil {
				return false, f.failf("invalid boolean flag %s: %v", name, err)
			}
		}
    ...

可以看到基本上就是我所猜测的。所以,我自己写了一个简单的flag:

package main

import (
	"log"
	"os"
	"strings"
)

type MyFlagger interface {
	Set(v interface{})
}

type MyFlag struct {
	mapper map[string]MyFlagger
}

var myFlags = MyFlag{mapper: make(map[string]MyFlagger)}

type boolFlag struct {
	p *bool
}

func (b *boolFlag) Set(v interface{}) {
	*(b.p) = v.(bool)
}

func (m *MyFlag) Bool(name string, defaultValue bool) *bool {
	p := new(bool)
	*p = defaultValue
	m.mapper[name] = &boolFlag{p}
	return p
}

func (m *MyFlag) Parse() {
	if len(os.Args) == 1 {
		return
	}

	arg := os.Args[1]
	if !strings.HasPrefix(arg, "--") {
		log.Panicf("bad usage: ./test --blabla")
	}

	if len(arg) < 3 {
		log.Panicf("bad usage: ./test --blabla")
	}

	realArg := arg[2:]
	flag, ok := m.mapper[realArg]
	if !ok {
		log.Panicf("%s not found", realArg)
	}

	flag.Set(true)
}

func main() {
	useWorker := myFlags.Bool("useWorker", false)
	log.Printf("before parse: useWorker: %t", *useWorker)
	myFlags.Parse()
	log.Printf("after parse: useWorker: %t", *useWorker)
}

来,执行一下:

$ go run main.go
2020/04/23 22:39:15 before parse: useWorker: false
2020/04/23 22:39:15 after parse: useWorker: false
$ go run main.go --useWorker
2020/04/23 22:39:20 before parse: useWorker: false
2020/04/23 22:39:20 after parse: useWorker: true

更多文章
  • flutter webview加载时显示进度
  • SQLAlchemy快速更新或插入对象
  • 修复Linux下curl等无法使用letsencrypt证书
  • 欣赏一下K&R两位大神的代码
  • MySQL的ON DUPLICATE KEY UPDATE语句
  • 使用microk8s快速搭建k8s
  • Python中优雅的处理文件路径
  • Go语言MySQL时区问题
  • 我的技术栈选型
  • 为什么我要用Linux作为桌面?
  • disqus获取评论时忽略query string
  • MySQL性能优化指南
  • 网络编程所需要熟悉的那些函数
  • DNSCrypt简明教程
  • SQLAlchemy简明教程