Golang flag源码阅读及自己实现

看了一下flag的实现,其实挺简单的。首先我们从一个使用例子入手:

package main

import (
	"flag"
	"log"
)

var (
	useWorker = flag.Bool("useWorker", false, "blabla")
)

func main() {
	flag.Parse()

	log.Printf("useWorker: %t", *useWorker)
}

可以看到这么几点信息:

  • useWorker的类型是 *bool,而且是flag.Bool返回的
  • 必须要执行 flag.Parse() 才能解析命令行

看到这里其实大家就应该猜想一下,咋实现的?对我来说,从看到它返回的是指针,我的猜测就是,在 flag.Bool 这里会新建一个bool类型的变量,将默认值赋值给它,然后返回这个bool类型的指针。在 Parse函数里,对这个指针所指向的值进行更新。

我们来看看源码:

func Bool(name string, value bool, usage string) *bool {
	return CommandLine.Bool(name, value, usage)
}

func (f *FlagSet) Bool(name string, value bool, usage string) *bool {
	p := new(bool)
	f.BoolVar(p, name, value, usage)
	return p
}

Parse 的代码就比较复杂了:

func Parse() {
	// Ignore errors; CommandLine is set for ExitOnError.
	CommandLine.Parse(os.Args[1:])
}

func (f *FlagSet) Parse(arguments []string) error {
	f.parsed = true
	f.args = arguments
	for {
		seen, err := f.parseOne()
...

func (f *FlagSet) parseOne() (bool, error) {
	if len(f.args) == 0 {
		return false, nil
	}
...
	if fv, ok := flag.Value.(boolFlag); ok && fv.IsBoolFlag() { // special case: doesn't need an arg
		if hasValue {
			if err := fv.Set(value); err != nil {
				return false, f.failf("invalid boolean value %q for -%s: %v", value, name, err)
			}
		} else {
			if err := fv.Set("true"); err != nil {
				return false, f.failf("invalid boolean flag %s: %v", name, err)
			}
		}
    ...

可以看到基本上就是我所猜测的。所以,我自己写了一个简单的flag:

package main

import (
	"log"
	"os"
	"strings"
)

type MyFlagger interface {
	Set(v interface{})
}

type MyFlag struct {
	mapper map[string]MyFlagger
}

var myFlags = MyFlag{mapper: make(map[string]MyFlagger)}

type boolFlag struct {
	p *bool
}

func (b *boolFlag) Set(v interface{}) {
	*(b.p) = v.(bool)
}

func (m *MyFlag) Bool(name string, defaultValue bool) *bool {
	p := new(bool)
	*p = defaultValue
	m.mapper[name] = &boolFlag{p}
	return p
}

func (m *MyFlag) Parse() {
	if len(os.Args) == 1 {
		return
	}

	arg := os.Args[1]
	if !strings.HasPrefix(arg, "--") {
		log.Panicf("bad usage: ./test --blabla")
	}

	if len(arg) < 3 {
		log.Panicf("bad usage: ./test --blabla")
	}

	realArg := arg[2:]
	flag, ok := m.mapper[realArg]
	if !ok {
		log.Panicf("%s not found", realArg)
	}

	flag.Set(true)
}

func main() {
	useWorker := myFlags.Bool("useWorker", false)
	log.Printf("before parse: useWorker: %t", *useWorker)
	myFlags.Parse()
	log.Printf("after parse: useWorker: %t", *useWorker)
}

来,执行一下:

$ go run main.go
2020/04/23 22:39:15 before parse: useWorker: false
2020/04/23 22:39:15 after parse: useWorker: false
$ go run main.go --useWorker
2020/04/23 22:39:20 before parse: useWorker: false
2020/04/23 22:39:20 after parse: useWorker: true

更多文章
  • Greenlet和Stackless Python
  • 写一个简单的ORM
  • 从源码看Python的descriptor
  • Python字符串格式化
  • Gunicorn 简明教程
  • Raft 论文阅读笔记
  • 什么是 Golang Comparable Types
  • GFS 论文阅读
  • MapReduce 论文阅读
  • 一起来做贼:Goroutine原理和Work stealing
  • Go语言的defer, panic和recover
  • 再读vim help:vim小技巧
  • 再读 Python Language Reference
  • 设计模式(2)- 深入浅出设计模式 阅读笔记
  • 设计模式(1)- 深入浅出设计模式 阅读笔记