Golang flag源码阅读及自己实现

看了一下flag的实现,其实挺简单的。首先我们从一个使用例子入手:

package main

import (
	"flag"
	"log"
)

var (
	useWorker = flag.Bool("useWorker", false, "blabla")
)

func main() {
	flag.Parse()

	log.Printf("useWorker: %t", *useWorker)
}

可以看到这么几点信息:

  • useWorker的类型是 *bool,而且是flag.Bool返回的
  • 必须要执行 flag.Parse() 才能解析命令行

看到这里其实大家就应该猜想一下,咋实现的?对我来说,从看到它返回的是指针,我的猜测就是,在 flag.Bool 这里会新建一个bool类型的变量,将默认值赋值给它,然后返回这个bool类型的指针。在 Parse函数里,对这个指针所指向的值进行更新。

我们来看看源码:

func Bool(name string, value bool, usage string) *bool {
	return CommandLine.Bool(name, value, usage)
}

func (f *FlagSet) Bool(name string, value bool, usage string) *bool {
	p := new(bool)
	f.BoolVar(p, name, value, usage)
	return p
}

Parse 的代码就比较复杂了:

func Parse() {
	// Ignore errors; CommandLine is set for ExitOnError.
	CommandLine.Parse(os.Args[1:])
}

func (f *FlagSet) Parse(arguments []string) error {
	f.parsed = true
	f.args = arguments
	for {
		seen, err := f.parseOne()
...

func (f *FlagSet) parseOne() (bool, error) {
	if len(f.args) == 0 {
		return false, nil
	}
...
	if fv, ok := flag.Value.(boolFlag); ok && fv.IsBoolFlag() { // special case: doesn't need an arg
		if hasValue {
			if err := fv.Set(value); err != nil {
				return false, f.failf("invalid boolean value %q for -%s: %v", value, name, err)
			}
		} else {
			if err := fv.Set("true"); err != nil {
				return false, f.failf("invalid boolean flag %s: %v", name, err)
			}
		}
    ...

可以看到基本上就是我所猜测的。所以,我自己写了一个简单的flag:

package main

import (
	"log"
	"os"
	"strings"
)

type MyFlagger interface {
	Set(v interface{})
}

type MyFlag struct {
	mapper map[string]MyFlagger
}

var myFlags = MyFlag{mapper: make(map[string]MyFlagger)}

type boolFlag struct {
	p *bool
}

func (b *boolFlag) Set(v interface{}) {
	*(b.p) = v.(bool)
}

func (m *MyFlag) Bool(name string, defaultValue bool) *bool {
	p := new(bool)
	*p = defaultValue
	m.mapper[name] = &boolFlag{p}
	return p
}

func (m *MyFlag) Parse() {
	if len(os.Args) == 1 {
		return
	}

	arg := os.Args[1]
	if !strings.HasPrefix(arg, "--") {
		log.Panicf("bad usage: ./test --blabla")
	}

	if len(arg) < 3 {
		log.Panicf("bad usage: ./test --blabla")
	}

	realArg := arg[2:]
	flag, ok := m.mapper[realArg]
	if !ok {
		log.Panicf("%s not found", realArg)
	}

	flag.Set(true)
}

func main() {
	useWorker := myFlags.Bool("useWorker", false)
	log.Printf("before parse: useWorker: %t", *useWorker)
	myFlags.Parse()
	log.Printf("after parse: useWorker: %t", *useWorker)
}

来,执行一下:

$ go run main.go
2020/04/23 22:39:15 before parse: useWorker: false
2020/04/23 22:39:15 after parse: useWorker: false
$ go run main.go --useWorker
2020/04/23 22:39:20 before parse: useWorker: false
2020/04/23 22:39:20 after parse: useWorker: true

更多文章
  • 面试的一些技巧
  • HTTP/2 简介
  • 独立运营博客一年的一些数据分享
  • To B(usiness) 和 To C(ustomer)
  • 常见的软件架构套路
  • Cookie 中的secure和httponly属性
  • Google Ads使用体验
  • Go的custom import path
  • 如何挖掘二级子域名?
  • Go Module 简明教程
  • 写了一个Telegram Bot:自动化分享高质量内容
  • Vim打开很慢,怎么找出最慢的插件?怎么解决?
  • 为什么我选择放弃运营微信公众号?
  • ArchLinux 怎么降级 package ?
  • Web后端工程师进阶指南(2018)