You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
517 lines
14 KiB
517 lines
14 KiB
package main |
|
|
|
import ( |
|
"bytes" |
|
"flag" |
|
"go/ast" |
|
"io/ioutil" |
|
"log" |
|
"os" |
|
"path/filepath" |
|
"regexp" |
|
"strconv" |
|
"strings" |
|
"text/template" |
|
|
|
"go-common/app/tool/cache/common" |
|
) |
|
|
|
// Command-line flags and lookup tables used by the generator.
//
// The flags are not read from the real command line: parse rebuilds os.Args
// from each interface method's "// mc:" option comment and calls flag.Parse
// once per method, so every flag acts as a per-method setting.
var (
	// encode selects the value encoding; parse accepts several names joined with "|".
	encode = flag.String("encode", "", "encode type: json/pb/raw/gob/gzip")
	// mcType overrides the operation type otherwise derived from the method-name prefix.
	mcType = flag.String("type", "", "type: get/set/del/replace/only_add")
	// key overrides the default key-builder name ("key" + method name).
	key = flag.String("key", "", "key name method")
	// expire overrides the default expire expression ("d.mc" + method name + "Expire").
	expire = flag.String("expire", "", "expire time code")
	// batchSize and maxGroup together enable batching; both must be non-zero.
	batchSize = flag.Int("batch", 0, "batch size")
	// batchErr is compared against "break" by parse to set BatchErrBreak.
	batchErr = flag.String("batch_err", "break", "batch err to continue or break")
	maxGroup = flag.Int("max_group", 0, "max group size")

	// mcValidTypes lists the operation types accepted by (*options).Check.
	mcValidTypes = []string{"set", "replace", "del", "get", "only_add"}
	// mcValidPrefix lists the lower-cased method-name prefixes from which parse
	// derives the operation type; "add" maps to set and "cache" maps to get.
	mcValidPrefix = []string{"set", "replace", "del", "get", "cache", "add"}
	// optionNamesMap holds every flag name allowed inside a "// mc:" comment.
	optionNamesMap = map[string]bool{"batch": true, "max_group": true, "encode": true, "type": true, "key": true, "expire": true, "batch_err": true}
	// simpleTypes are the value types stored without a structured encoding.
	simpleTypes = []string{"int", "int8", "int16", "int32", "int64", "float32", "float64", "uint", "uint8", "uint16", "uint32", "uint64", "bool", "string", "[]byte"}
	// lenTypes are value-type prefixes that mark a value as having a length (slice, map).
	lenTypes = []string{"[]", "map"}
)
|
|
|
// _interfaceName is the name of the interface type that parse looks up in the
// source file; its methods describe the cache functions to generate.
const _interfaceName = "_mc"

// Template kinds. parse picks one per method from the shape of its parameter
// list and genBody uses it to select the template family.
const (
	_multiTpl  = iota + 1 // batch form: second parameter is a map (writes) or a slice
	_singleTpl            // one key per call
	_noneTpl              // no key parameter
)

// Operation types, as accepted by the "type" option and listed in mcValidTypes.
const (
	_typeGet     = "get"
	_typeSet     = "set"
	_typeDel     = "del"
	_typeReplace = "replace"
	_typeAdd     = "only_add"
)
|
|
|
func resetFlag() { |
|
*encode = "" |
|
*mcType = "" |
|
*batchSize = 0 |
|
*maxGroup = 0 |
|
*batchErr = "break" |
|
} |
|
|
|
// options describes one method of the cache interface. parse fills one value
// per method; genHeader and genBody pass it to the code templates as data.
type options struct {
	// name is the interface method name.
	name string
	// keyType and ValueType are the source text of the cache key and value types.
	keyType   string
	ValueType string
	// template is one of _multiTpl, _singleTpl or _noneTpl.
	template int

	// SimpleValue reports whether ValueType is one of simpleTypes.
	SimpleValue bool
	// GetSimpleValue: simple value that is not string or []byte (int and float types).
	GetSimpleValue bool
	// GetDirectValue: string or []byte value.
	GetDirectValue bool
	// ConvertValue2Bytes and ConvertBytes2Value hold the Go expressions returned
	// by convertValue2Bytes and convertBytes2Value for a simple value type.
	ConvertValue2Bytes string
	ConvertBytes2Value string
	GoValue            bool

	// ImportPackage is the newline-joined import list computed by genHeader;
	// importPackages are the imports reported for this method by the source helper.
	ImportPackage  string
	importPackages []string
	// Args replaces the "ARGS" placeholder in the generated header.
	Args string
	// PkgName is read from the GOPACKAGE environment variable by genHeader.
	PkgName string
	// ExtraArgsType and ExtraArgs are comma-prefixed lists of the additional
	// method parameters, with and without their types.
	ExtraArgsType string
	ExtraArgs     string

	// MCType is the operation type (get/set/del/replace/only_add).
	MCType string
	// KeyMethod defaults to "key" + name and ExpireCode to "d.mc" + name + "Expire";
	// both can be overridden by the key and expire options.
	KeyMethod  string
	ExpireCode string
	// Encode is the memcache flag expression, e.g. "memcache.FlagJSON".
	Encode string
	// UseMemcached is true for every method except batch gets.
	UseMemcached bool
	// InitValue is set when ValueType is a pointer type.
	InitValue bool
	// OriginValueType is ValueType with the first "*" removed.
	OriginValueType string
	// UseStrConv is set by genHeader when any method has a simple, non-direct value.
	UseStrConv bool
	// Comment is the text of the comment found two lines above the method.
	Comment string

	// GroupSize and MaxGroup come from the batch and max_group options;
	// EnableBatch is true only when both are non-zero.
	GroupSize   int
	MaxGroup    int
	EnableBatch bool
	// BatchErrBreak is true when the batch_err option equals "break".
	BatchErrBreak bool
	// LenType is set for string, slice and map values.
	LenType bool
	// PointType is set when ValueType is a pointer type.
	PointType bool
}
|
|
|
func parse(s *common.Source) (opts []*options) { |
|
f := s.F |
|
fset := s.Fset |
|
src := s.Src |
|
c := f.Scope.Lookup(_interfaceName) |
|
if (c == nil) || (c.Kind != ast.Typ) { |
|
log.Fatalln("无法找到缓存声明") |
|
} |
|
lines := strings.Split(src, "\n") |
|
lists := c.Decl.(*ast.TypeSpec).Type.(*ast.InterfaceType).Methods.List |
|
for _, list := range lists { |
|
opt := options{Args: s.GetDef(_interfaceName), UseMemcached: true, importPackages: s.Packages(list)} |
|
opt.name = list.Names[0].Name |
|
opt.KeyMethod = "key" + opt.name |
|
opt.ExpireCode = "d.mc" + opt.name + "Expire" |
|
// get comment |
|
line := fset.Position(list.Pos()).Line - 3 |
|
if len(lines)-1 >= line { |
|
comment := lines[line] |
|
opt.Comment = common.RegexpReplace(`\s+//(?P<name>.+)`, comment, "$name") |
|
opt.Comment = strings.TrimSpace(opt.Comment) |
|
} |
|
// get options |
|
line = fset.Position(list.Pos()).Line - 2 |
|
comment := lines[line] |
|
os.Args = []string{os.Args[0]} |
|
if regexp.MustCompile(`\s+//\s*mc:.+`).Match([]byte(comment)) { |
|
args := strings.Split(common.RegexpReplace(`//\s*mc:(?P<arg>.+)`, comment, "$arg"), " ") |
|
for _, arg := range args { |
|
arg = strings.TrimSpace(arg) |
|
if arg != "" { |
|
// validate option name |
|
argName := common.RegexpReplace(`-(?P<name>[\w_-]+)=.+`, arg, "$name") |
|
if !optionNamesMap[argName] { |
|
log.Fatalf("选项:%s 不存在 请检查拼写\n", argName) |
|
} |
|
os.Args = append(os.Args, arg) |
|
} |
|
} |
|
} |
|
resetFlag() |
|
flag.Parse() |
|
if *mcType != "" { |
|
opt.MCType = *mcType |
|
} |
|
if *key != "" { |
|
opt.KeyMethod = *key |
|
} |
|
if *expire != "" { |
|
opt.ExpireCode = *expire |
|
} |
|
opt.EnableBatch = (*batchSize != 0) && (*maxGroup != 0) |
|
opt.BatchErrBreak = *batchErr == "break" |
|
opt.GroupSize = *batchSize |
|
opt.MaxGroup = *maxGroup |
|
// get type from prefix |
|
if opt.MCType == "" { |
|
for _, t := range mcValidPrefix { |
|
if strings.HasPrefix(strings.ToLower(opt.name), t) { |
|
if t == "add" { |
|
t = _typeSet |
|
} |
|
opt.MCType = t |
|
break |
|
} |
|
} |
|
if opt.MCType == "" { |
|
log.Fatalln(opt.name + "请指定方法类型(type=get/set/del...)") |
|
} |
|
} |
|
if opt.MCType == "cache" { |
|
opt.MCType = _typeGet |
|
} |
|
params := list.Type.(*ast.FuncType).Params.List |
|
if len(params) == 0 { |
|
log.Fatalln(opt.name + "参数不足") |
|
} |
|
if s.ExprString(params[0].Type) != "context.Context" { |
|
log.Fatalln(opt.name + "第一个参数必须为context") |
|
} |
|
for _, param := range params { |
|
if len(param.Names) > 1 { |
|
log.Fatalln(opt.name + "不支持省略类型") |
|
} |
|
} |
|
// get template |
|
if len(params) == 1 { |
|
opt.template = _noneTpl |
|
} else if (len(params) == 2) && (opt.MCType == _typeSet || opt.MCType == _typeAdd || opt.MCType == _typeReplace) { |
|
if _, ok := params[1].Type.(*ast.MapType); ok { |
|
opt.template = _multiTpl |
|
} else { |
|
opt.template = _noneTpl |
|
} |
|
} else { |
|
if _, ok := params[1].Type.(*ast.ArrayType); ok { |
|
opt.template = _multiTpl |
|
} else { |
|
opt.template = _singleTpl |
|
} |
|
} |
|
// extra args |
|
if len(params) > 2 { |
|
args := []string{""} |
|
allArgs := []string{""} |
|
var pos = 2 |
|
if (opt.MCType == _typeAdd) || (opt.MCType == _typeSet) || (opt.MCType == _typeReplace) { |
|
pos = 3 |
|
} |
|
for _, pa := range params[pos:] { |
|
paType := s.ExprString(pa.Type) |
|
if len(pa.Names) == 0 { |
|
args = append(args, paType) |
|
allArgs = append(allArgs, paType) |
|
continue |
|
} |
|
var names []string |
|
for _, name := range pa.Names { |
|
names = append(names, name.Name) |
|
} |
|
allArgs = append(allArgs, strings.Join(names, ",")+" "+paType) |
|
args = append(args, strings.Join(names, ",")) |
|
} |
|
if len(args) > 1 { |
|
opt.ExtraArgs = strings.Join(args, ",") |
|
opt.ExtraArgsType = strings.Join(allArgs, ",") |
|
} |
|
} |
|
// get k v from results |
|
results := list.Type.(*ast.FuncType).Results.List |
|
if s.ExprString(results[len(results)-1].Type) != "error" { |
|
log.Fatalln("最后返回值参数需为error") |
|
} |
|
for _, res := range results { |
|
if len(res.Names) > 1 { |
|
log.Fatalln(opt.name + "返回值不支持省略类型") |
|
} |
|
} |
|
if opt.MCType == _typeGet { |
|
if len(results) != 2 { |
|
log.Fatalln("参数个数不对") |
|
} |
|
} |
|
// get key type and value type |
|
if (opt.MCType == _typeAdd) || (opt.MCType == _typeSet) || (opt.MCType == _typeReplace) { |
|
if opt.template == _multiTpl { |
|
p, ok := params[1].Type.(*ast.MapType) |
|
if !ok { |
|
log.Fatalf("%s: 参数类型错误 批量设置数据时类型需为map类型\n", opt.name) |
|
} |
|
opt.keyType = s.ExprString(p.Key) |
|
opt.ValueType = s.ExprString(p.Value) |
|
} else if opt.template == _singleTpl { |
|
opt.keyType = s.ExprString(params[1].Type) |
|
opt.ValueType = s.ExprString(params[2].Type) |
|
} else { |
|
opt.ValueType = s.ExprString(params[1].Type) |
|
} |
|
} |
|
if opt.MCType == _typeGet { |
|
if opt.template == _multiTpl { |
|
if p, ok := results[0].Type.(*ast.MapType); ok { |
|
opt.keyType = s.ExprString(p.Key) |
|
opt.ValueType = s.ExprString(p.Value) |
|
} else { |
|
log.Fatalf("%s: 返回值类型错误 批量获取数据时返回值需为map类型\n", opt.name) |
|
} |
|
} else if opt.template == _singleTpl { |
|
opt.keyType = s.ExprString(params[1].Type) |
|
opt.ValueType = s.ExprString(results[0].Type) |
|
} else { |
|
opt.ValueType = s.ExprString(results[0].Type) |
|
} |
|
} |
|
if opt.MCType == _typeDel { |
|
if opt.template == _multiTpl { |
|
p, ok := params[1].Type.(*ast.ArrayType) |
|
if !ok { |
|
log.Fatalf("%s: 类型错误 参数需为[]类型\n", opt.name) |
|
} |
|
opt.keyType = s.ExprString(p.Elt) |
|
} else if opt.template == _singleTpl { |
|
opt.keyType = s.ExprString(params[1].Type) |
|
} |
|
} |
|
for _, t := range simpleTypes { |
|
if t == opt.ValueType { |
|
opt.SimpleValue = true |
|
opt.GetSimpleValue = true |
|
opt.ConvertValue2Bytes = convertValue2Bytes(t) |
|
opt.ConvertBytes2Value = convertBytes2Value(t) |
|
break |
|
} |
|
} |
|
if opt.ValueType == "string" { |
|
opt.LenType = true |
|
} else { |
|
for _, t := range lenTypes { |
|
if strings.HasPrefix(opt.ValueType, t) { |
|
opt.LenType = true |
|
break |
|
} |
|
} |
|
} |
|
if opt.SimpleValue && (opt.ValueType == "[]byte" || opt.ValueType == "string") { |
|
opt.GetSimpleValue = false |
|
opt.GetDirectValue = true |
|
} |
|
if opt.MCType == _typeGet && opt.template == _multiTpl { |
|
opt.UseMemcached = false |
|
} |
|
if strings.HasPrefix(opt.ValueType, "*") { |
|
opt.InitValue = true |
|
opt.PointType = true |
|
opt.OriginValueType = strings.Replace(opt.ValueType, "*", "", 1) |
|
} else { |
|
opt.OriginValueType = opt.ValueType |
|
} |
|
if *encode != "" { |
|
var flags []string |
|
for _, f := range strings.Split(*encode, "|") { |
|
switch f { |
|
case "gob": |
|
flags = append(flags, "memcache.FlagGOB") |
|
case "json": |
|
flags = append(flags, "memcache.FlagJSON") |
|
case "raw": |
|
flags = append(flags, "memcache.FlagRAW") |
|
case "pb": |
|
flags = append(flags, "memcache.FlagProtobuf") |
|
case "gzip": |
|
flags = append(flags, "memcache.FlagGzip") |
|
default: |
|
log.Fatalf("%s: encode类型无效\n", opt.name) |
|
} |
|
} |
|
opt.Encode = strings.Join(flags, " | ") |
|
} else { |
|
if opt.SimpleValue { |
|
opt.Encode = "memcache.FlagRAW" |
|
} else { |
|
opt.Encode = "memcache.FlagJSON" |
|
} |
|
} |
|
opt.Check() |
|
opts = append(opts, &opt) |
|
} |
|
return |
|
} |
|
|
|
func (option *options) Check() { |
|
var valid bool |
|
for _, x := range mcValidTypes { |
|
if x == option.MCType { |
|
valid = true |
|
break |
|
} |
|
} |
|
if !valid { |
|
log.Fatalf("%s: 类型错误 不支持%s类型\n", option.name, option.MCType) |
|
} |
|
if (option.MCType != _typeDel) && !option.SimpleValue && !strings.Contains(option.ValueType, "*") && !strings.Contains(option.ValueType, "[]") && !strings.Contains(option.ValueType, "map") { |
|
log.Fatalf("%s: 值类型只能为基本类型/slice/map/指针类型\n", option.name) |
|
} |
|
} |
|
|
|
func genHeader(opts []*options) (src string) { |
|
option := options{PkgName: os.Getenv("GOPACKAGE"), UseMemcached: false} |
|
var packages []string |
|
packagesMap := map[string]bool{`"context"`: true} |
|
for _, opt := range opts { |
|
if len(opt.importPackages) > 0 { |
|
for _, pkg := range opt.importPackages { |
|
if !packagesMap[pkg] { |
|
packages = append(packages, pkg) |
|
packagesMap[pkg] = true |
|
} |
|
} |
|
} |
|
if opt.Args != "" { |
|
option.Args = opt.Args |
|
} |
|
if opt.UseMemcached { |
|
option.UseMemcached = true |
|
} |
|
if opt.SimpleValue && !opt.GetDirectValue { |
|
option.UseStrConv = true |
|
} |
|
if opt.EnableBatch { |
|
option.EnableBatch = true |
|
} |
|
} |
|
option.ImportPackage = strings.Join(packages, "\n") |
|
src = _headerTemplate |
|
t := template.Must(template.New("header").Parse(src)) |
|
var buffer bytes.Buffer |
|
err := t.Execute(&buffer, option) |
|
if err != nil { |
|
log.Fatalf("execute template: %s", err) |
|
} |
|
// Format the output. |
|
src = strings.Replace(buffer.String(), "\t", "", -1) |
|
src = regexp.MustCompile("\n+").ReplaceAllString(src, "\n") |
|
src = strings.Replace(src, "NEWLINE", "", -1) |
|
src = strings.Replace(src, "ARGS", option.Args, -1) |
|
return |
|
} |
|
|
|
func genBody(opts []*options) (res string) { |
|
for _, option := range opts { |
|
var src string |
|
if option.template == _multiTpl { |
|
switch option.MCType { |
|
case _typeGet: |
|
src = _multiGetTemplate |
|
case _typeSet: |
|
src = _multiSetTemplate |
|
case _typeReplace: |
|
src = _multiReplaceTemplate |
|
case _typeDel: |
|
src = _multiDelTemplate |
|
case _typeAdd: |
|
src = _multiAddTemplate |
|
} |
|
} else if option.template == _singleTpl { |
|
switch option.MCType { |
|
case _typeGet: |
|
src = _singleGetTemplate |
|
case _typeSet: |
|
src = _singleSetTemplate |
|
case _typeReplace: |
|
src = _singleReplaceTemplate |
|
case _typeDel: |
|
src = _singleDelTemplate |
|
case _typeAdd: |
|
src = _singleAddTemplate |
|
} |
|
} else { |
|
switch option.MCType { |
|
case _typeGet: |
|
src = _noneGetTemplate |
|
case _typeSet: |
|
src = _noneSetTemplate |
|
case _typeReplace: |
|
src = _noneReplaceTemplate |
|
case _typeDel: |
|
src = _noneDelTemplate |
|
case _typeAdd: |
|
src = _noneAddTemplate |
|
} |
|
} |
|
src = strings.Replace(src, "KEY", option.keyType, -1) |
|
src = strings.Replace(src, "NAME", option.name, -1) |
|
src = strings.Replace(src, "VALUE", option.ValueType, -1) |
|
src = strings.Replace(src, "GROUPSIZE", strconv.Itoa(option.GroupSize), -1) |
|
src = strings.Replace(src, "MAXGROUP", strconv.Itoa(option.MaxGroup), -1) |
|
t := template.Must(template.New("cache").Parse(src)) |
|
var buffer bytes.Buffer |
|
err := t.Execute(&buffer, option) |
|
if err != nil { |
|
log.Fatalf("execute template: %s", err) |
|
} |
|
// Format the output. |
|
src = strings.Replace(buffer.String(), "\t", "", -1) |
|
src = regexp.MustCompile("\n+").ReplaceAllString(src, "\n") |
|
res = res + "\n" + src |
|
} |
|
return |
|
} |
|
|
|
func main() { |
|
log.SetFlags(0) |
|
defer func() { |
|
if err := recover(); err != nil { |
|
log.Fatalf("程序解析失败, err: %+v 请企业微信联系 @wangxu01", err) |
|
} |
|
}() |
|
options := parse(common.NewSource(common.SourceText())) |
|
header := genHeader(options) |
|
body := genBody(options) |
|
code := common.FormatCode(header + "\n" + body) |
|
// Write to file. |
|
dir := filepath.Dir(".") |
|
outputName := filepath.Join(dir, "mc.cache.go") |
|
err := ioutil.WriteFile(outputName, []byte(code), 0644) |
|
if err != nil { |
|
log.Fatalf("写入文件失败: %s", err) |
|
} |
|
log.Println("mc.cache.go: 生成成功") |
|
} |
|
|
|
// convertValue2Bytes returns the Go expression, as source text, that converts
// a variable named val of simple type t into a []byte. It returns "" for a
// type it does not handle.
//
// The emitted expressions must compile in the generated file: strconv exposes
// FormatUint (not FormatUInt), FormatUint takes a uint64 and FormatFloat takes
// a float64, so narrower operands are converted explicitly.
func convertValue2Bytes(t string) string {
	switch t {
	case "int", "int8", "int16", "int32", "int64":
		return "[]byte(strconv.FormatInt(int64(val), 10))"
	case "uint", "uint8", "uint16", "uint32", "uint64":
		return "[]byte(strconv.FormatUint(uint64(val), 10))"
	case "bool":
		return "[]byte(strconv.FormatBool(val))"
	case "float32":
		return "[]byte(strconv.FormatFloat(float64(val), 'E', -1, 32))"
	case "float64":
		return "[]byte(strconv.FormatFloat(val, 'E', -1, 64))"
	case "string":
		return "[]byte(val)"
	case "[]byte":
		return "val"
	}
	return ""
}
|
|
|
// convertBytes2Value returns the Go expression, as source text, that parses a
// string variable named v into a value for simple type t. It returns "" for
// types that need no parsing (string, []byte) or that it does not handle.
//
// Every emitted expression is a plain strconv call yielding (value, error):
// strconv exposes ParseUint (not ParseUInt), and a two-value call cannot be
// wrapped in a float32(...) conversion, so float32 uses ParseFloat with a
// 32-bit size.
// NOTE(review): this assumes the template converts the parsed value to the
// declared value type, as it must already do for the int cases — confirm.
func convertBytes2Value(t string) string {
	switch t {
	case "int", "int8", "int16", "int32", "int64":
		return "strconv.ParseInt(v, 10, 64)"
	case "uint", "uint8", "uint16", "uint32", "uint64":
		return "strconv.ParseUint(v, 10, 64)"
	case "bool":
		return "strconv.ParseBool(v)"
	case "float32":
		return "strconv.ParseFloat(v, 32)"
	case "float64":
		return "strconv.ParseFloat(v, 64)"
	}
	return ""
}
|
|
|