You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
515 lines
15 KiB
515 lines
15 KiB
package gen |
|
|
|
import ( |
|
"encoding" |
|
"encoding/json" |
|
"fmt" |
|
"reflect" |
|
"strings" |
|
"unicode" |
|
|
|
"github.com/mailru/easyjson" |
|
) |
|
|
|
// Target this byte size for initial slice allocation to reduce garbage collection. |
|
const minSliceBytes = 64 |
|
|
|
func (g *Generator) getDecoderName(t reflect.Type) string { |
|
return g.functionName("decode", t) |
|
} |
|
|
|
var primitiveDecoders = map[reflect.Kind]string{ |
|
reflect.String: "in.String()", |
|
reflect.Bool: "in.Bool()", |
|
reflect.Int: "in.Int()", |
|
reflect.Int8: "in.Int8()", |
|
reflect.Int16: "in.Int16()", |
|
reflect.Int32: "in.Int32()", |
|
reflect.Int64: "in.Int64()", |
|
reflect.Uint: "in.Uint()", |
|
reflect.Uint8: "in.Uint8()", |
|
reflect.Uint16: "in.Uint16()", |
|
reflect.Uint32: "in.Uint32()", |
|
reflect.Uint64: "in.Uint64()", |
|
reflect.Float32: "in.Float32()", |
|
reflect.Float64: "in.Float64()", |
|
} |
|
|
|
var primitiveStringDecoders = map[reflect.Kind]string{ |
|
reflect.String: "in.String()", |
|
reflect.Int: "in.IntStr()", |
|
reflect.Int8: "in.Int8Str()", |
|
reflect.Int16: "in.Int16Str()", |
|
reflect.Int32: "in.Int32Str()", |
|
reflect.Int64: "in.Int64Str()", |
|
reflect.Uint: "in.UintStr()", |
|
reflect.Uint8: "in.Uint8Str()", |
|
reflect.Uint16: "in.Uint16Str()", |
|
reflect.Uint32: "in.Uint32Str()", |
|
reflect.Uint64: "in.Uint64Str()", |
|
reflect.Uintptr: "in.UintptrStr()", |
|
reflect.Float32: "in.Float32Str()", |
|
reflect.Float64: "in.Float64Str()", |
|
} |
|
|
|
var customDecoders = map[string]string{ |
|
"json.Number": "in.JsonNumber()", |
|
} |
|
|
|
// genTypeDecoder generates decoding code for the type t, but uses unmarshaler interface if implemented by t. |
|
func (g *Generator) genTypeDecoder(t reflect.Type, out string, tags fieldTags, indent int) error { |
|
ws := strings.Repeat(" ", indent) |
|
|
|
unmarshalerIface := reflect.TypeOf((*easyjson.Unmarshaler)(nil)).Elem() |
|
if reflect.PtrTo(t).Implements(unmarshalerIface) { |
|
fmt.Fprintln(g.out, ws+"("+out+").UnmarshalEasyJSON(in)") |
|
return nil |
|
} |
|
|
|
unmarshalerIface = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() |
|
if reflect.PtrTo(t).Implements(unmarshalerIface) { |
|
fmt.Fprintln(g.out, ws+"if data := in.Raw(); in.Ok() {") |
|
fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalJSON(data) )") |
|
fmt.Fprintln(g.out, ws+"}") |
|
return nil |
|
} |
|
|
|
unmarshalerIface = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() |
|
if reflect.PtrTo(t).Implements(unmarshalerIface) { |
|
fmt.Fprintln(g.out, ws+"if data := in.UnsafeBytes(); in.Ok() {") |
|
fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalText(data) )") |
|
fmt.Fprintln(g.out, ws+"}") |
|
return nil |
|
} |
|
|
|
err := g.genTypeDecoderNoCheck(t, out, tags, indent) |
|
return err |
|
} |
|
|
|
// returns true of the type t implements one of the custom unmarshaler interfaces |
|
func hasCustomUnmarshaler(t reflect.Type) bool { |
|
t = reflect.PtrTo(t) |
|
return t.Implements(reflect.TypeOf((*easyjson.Unmarshaler)(nil)).Elem()) || |
|
t.Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) || |
|
t.Implements(reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()) |
|
} |
|
|
|
// genTypeDecoderNoCheck generates decoding code for the type t. |
|
func (g *Generator) genTypeDecoderNoCheck(t reflect.Type, out string, tags fieldTags, indent int) error { |
|
ws := strings.Repeat(" ", indent) |
|
// Check whether type is primitive, needs to be done after interface check. |
|
if dec := customDecoders[t.String()]; dec != "" { |
|
fmt.Fprintln(g.out, ws+out+" = "+dec) |
|
return nil |
|
} else if dec := primitiveStringDecoders[t.Kind()]; dec != "" && tags.asString { |
|
fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")") |
|
return nil |
|
} else if dec := primitiveDecoders[t.Kind()]; dec != "" { |
|
fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")") |
|
return nil |
|
} |
|
|
|
switch t.Kind() { |
|
case reflect.Slice: |
|
tmpVar := g.uniqueVarName() |
|
elem := t.Elem() |
|
|
|
if elem.Kind() == reflect.Uint8 && elem.Name() == "uint8" { |
|
fmt.Fprintln(g.out, ws+"if in.IsNull() {") |
|
fmt.Fprintln(g.out, ws+" in.Skip()") |
|
fmt.Fprintln(g.out, ws+" "+out+" = nil") |
|
fmt.Fprintln(g.out, ws+"} else {") |
|
fmt.Fprintln(g.out, ws+" "+out+" = in.Bytes()") |
|
fmt.Fprintln(g.out, ws+"}") |
|
|
|
} else { |
|
|
|
capacity := minSliceBytes / elem.Size() |
|
if capacity == 0 { |
|
capacity = 1 |
|
} |
|
|
|
fmt.Fprintln(g.out, ws+"if in.IsNull() {") |
|
fmt.Fprintln(g.out, ws+" in.Skip()") |
|
fmt.Fprintln(g.out, ws+" "+out+" = nil") |
|
fmt.Fprintln(g.out, ws+"} else {") |
|
fmt.Fprintln(g.out, ws+" in.Delim('[')") |
|
fmt.Fprintln(g.out, ws+" if "+out+" == nil {") |
|
fmt.Fprintln(g.out, ws+" if !in.IsDelim(']') {") |
|
fmt.Fprintln(g.out, ws+" "+out+" = make("+g.getType(t)+", 0, "+fmt.Sprint(capacity)+")") |
|
fmt.Fprintln(g.out, ws+" } else {") |
|
fmt.Fprintln(g.out, ws+" "+out+" = "+g.getType(t)+"{}") |
|
fmt.Fprintln(g.out, ws+" }") |
|
fmt.Fprintln(g.out, ws+" } else { ") |
|
fmt.Fprintln(g.out, ws+" "+out+" = ("+out+")[:0]") |
|
fmt.Fprintln(g.out, ws+" }") |
|
fmt.Fprintln(g.out, ws+" for !in.IsDelim(']') {") |
|
fmt.Fprintln(g.out, ws+" var "+tmpVar+" "+g.getType(elem)) |
|
|
|
if err := g.genTypeDecoder(elem, tmpVar, tags, indent+2); err != nil { |
|
return err |
|
} |
|
|
|
fmt.Fprintln(g.out, ws+" "+out+" = append("+out+", "+tmpVar+")") |
|
fmt.Fprintln(g.out, ws+" in.WantComma()") |
|
fmt.Fprintln(g.out, ws+" }") |
|
fmt.Fprintln(g.out, ws+" in.Delim(']')") |
|
fmt.Fprintln(g.out, ws+"}") |
|
} |
|
|
|
case reflect.Array: |
|
iterVar := g.uniqueVarName() |
|
elem := t.Elem() |
|
|
|
if elem.Kind() == reflect.Uint8 && elem.Name() == "uint8" { |
|
fmt.Fprintln(g.out, ws+"if in.IsNull() {") |
|
fmt.Fprintln(g.out, ws+" in.Skip()") |
|
fmt.Fprintln(g.out, ws+"} else {") |
|
fmt.Fprintln(g.out, ws+" copy("+out+"[:], in.Bytes())") |
|
fmt.Fprintln(g.out, ws+"}") |
|
|
|
} else { |
|
|
|
length := t.Len() |
|
|
|
fmt.Fprintln(g.out, ws+"if in.IsNull() {") |
|
fmt.Fprintln(g.out, ws+" in.Skip()") |
|
fmt.Fprintln(g.out, ws+"} else {") |
|
fmt.Fprintln(g.out, ws+" in.Delim('[')") |
|
fmt.Fprintln(g.out, ws+" "+iterVar+" := 0") |
|
fmt.Fprintln(g.out, ws+" for !in.IsDelim(']') {") |
|
fmt.Fprintln(g.out, ws+" if "+iterVar+" < "+fmt.Sprint(length)+" {") |
|
|
|
if err := g.genTypeDecoder(elem, "("+out+")["+iterVar+"]", tags, indent+3); err != nil { |
|
return err |
|
} |
|
|
|
fmt.Fprintln(g.out, ws+" "+iterVar+"++") |
|
fmt.Fprintln(g.out, ws+" } else {") |
|
fmt.Fprintln(g.out, ws+" in.SkipRecursive()") |
|
fmt.Fprintln(g.out, ws+" }") |
|
fmt.Fprintln(g.out, ws+" in.WantComma()") |
|
fmt.Fprintln(g.out, ws+" }") |
|
fmt.Fprintln(g.out, ws+" in.Delim(']')") |
|
fmt.Fprintln(g.out, ws+"}") |
|
} |
|
|
|
case reflect.Struct: |
|
dec := g.getDecoderName(t) |
|
g.addType(t) |
|
|
|
fmt.Fprintln(g.out, ws+dec+"(in, &"+out+")") |
|
|
|
case reflect.Ptr: |
|
fmt.Fprintln(g.out, ws+"if in.IsNull() {") |
|
fmt.Fprintln(g.out, ws+" in.Skip()") |
|
fmt.Fprintln(g.out, ws+" "+out+" = nil") |
|
fmt.Fprintln(g.out, ws+"} else {") |
|
fmt.Fprintln(g.out, ws+" if "+out+" == nil {") |
|
fmt.Fprintln(g.out, ws+" "+out+" = new("+g.getType(t.Elem())+")") |
|
fmt.Fprintln(g.out, ws+" }") |
|
|
|
if err := g.genTypeDecoder(t.Elem(), "*"+out, tags, indent+1); err != nil { |
|
return err |
|
} |
|
|
|
fmt.Fprintln(g.out, ws+"}") |
|
|
|
case reflect.Map: |
|
key := t.Key() |
|
keyDec, ok := primitiveStringDecoders[key.Kind()] |
|
if !ok && !hasCustomUnmarshaler(key) { |
|
return fmt.Errorf("map type %v not supported: only string and integer keys and types implementing json.Unmarshaler are allowed", key) |
|
} // else assume the caller knows what they are doing and that the custom unmarshaler performs the translation from string or integer keys to the key type |
|
elem := t.Elem() |
|
tmpVar := g.uniqueVarName() |
|
|
|
fmt.Fprintln(g.out, ws+"if in.IsNull() {") |
|
fmt.Fprintln(g.out, ws+" in.Skip()") |
|
fmt.Fprintln(g.out, ws+"} else {") |
|
fmt.Fprintln(g.out, ws+" in.Delim('{')") |
|
fmt.Fprintln(g.out, ws+" if !in.IsDelim('}') {") |
|
fmt.Fprintln(g.out, ws+" "+out+" = make("+g.getType(t)+")") |
|
fmt.Fprintln(g.out, ws+" } else {") |
|
fmt.Fprintln(g.out, ws+" "+out+" = nil") |
|
fmt.Fprintln(g.out, ws+" }") |
|
|
|
fmt.Fprintln(g.out, ws+" for !in.IsDelim('}') {") |
|
if keyDec != "" { |
|
fmt.Fprintln(g.out, ws+" key := "+g.getType(key)+"("+keyDec+")") |
|
} else { |
|
fmt.Fprintln(g.out, ws+" var key "+g.getType(key)) |
|
if err := g.genTypeDecoder(key, "key", tags, indent+2); err != nil { |
|
return err |
|
} |
|
} |
|
|
|
fmt.Fprintln(g.out, ws+" in.WantColon()") |
|
fmt.Fprintln(g.out, ws+" var "+tmpVar+" "+g.getType(elem)) |
|
|
|
if err := g.genTypeDecoder(elem, tmpVar, tags, indent+2); err != nil { |
|
return err |
|
} |
|
|
|
fmt.Fprintln(g.out, ws+" ("+out+")[key] = "+tmpVar) |
|
fmt.Fprintln(g.out, ws+" in.WantComma()") |
|
fmt.Fprintln(g.out, ws+" }") |
|
fmt.Fprintln(g.out, ws+" in.Delim('}')") |
|
fmt.Fprintln(g.out, ws+"}") |
|
|
|
case reflect.Interface: |
|
if t.NumMethod() != 0 { |
|
return fmt.Errorf("interface type %v not supported: only interface{} is allowed", t) |
|
} |
|
fmt.Fprintln(g.out, ws+"if m, ok := "+out+".(easyjson.Unmarshaler); ok {") |
|
fmt.Fprintln(g.out, ws+"m.UnmarshalEasyJSON(in)") |
|
fmt.Fprintln(g.out, ws+"} else if m, ok := "+out+".(json.Unmarshaler); ok {") |
|
fmt.Fprintln(g.out, ws+"_ = m.UnmarshalJSON(in.Raw())") |
|
fmt.Fprintln(g.out, ws+"} else {") |
|
fmt.Fprintln(g.out, ws+" "+out+" = in.Interface()") |
|
fmt.Fprintln(g.out, ws+"}") |
|
default: |
|
return fmt.Errorf("don't know how to decode %v", t) |
|
} |
|
return nil |
|
|
|
} |
|
|
|
func (g *Generator) genStructFieldDecoder(t reflect.Type, f reflect.StructField) error { |
|
jsonName := g.fieldNamer.GetJSONFieldName(t, f) |
|
tags := parseFieldTags(f) |
|
|
|
if tags.omit { |
|
return nil |
|
} |
|
|
|
fmt.Fprintf(g.out, " case %q:\n", jsonName) |
|
if err := g.genTypeDecoder(f.Type, "out."+f.Name, tags, 3); err != nil { |
|
return err |
|
} |
|
|
|
if tags.required { |
|
fmt.Fprintf(g.out, "%sSet = true\n", f.Name) |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (g *Generator) genRequiredFieldSet(t reflect.Type, f reflect.StructField) { |
|
tags := parseFieldTags(f) |
|
|
|
if !tags.required { |
|
return |
|
} |
|
|
|
fmt.Fprintf(g.out, "var %sSet bool\n", f.Name) |
|
} |
|
|
|
func (g *Generator) genRequiredFieldCheck(t reflect.Type, f reflect.StructField) { |
|
jsonName := g.fieldNamer.GetJSONFieldName(t, f) |
|
tags := parseFieldTags(f) |
|
|
|
if !tags.required { |
|
return |
|
} |
|
|
|
g.imports["fmt"] = "fmt" |
|
|
|
fmt.Fprintf(g.out, "if !%sSet {\n", f.Name) |
|
fmt.Fprintf(g.out, " in.AddError(fmt.Errorf(\"key '%s' is required\"))\n", jsonName) |
|
fmt.Fprintf(g.out, "}\n") |
|
} |
|
|
|
func mergeStructFields(fields1, fields2 []reflect.StructField) (fields []reflect.StructField) { |
|
used := map[string]bool{} |
|
for _, f := range fields2 { |
|
used[f.Name] = true |
|
fields = append(fields, f) |
|
} |
|
|
|
for _, f := range fields1 { |
|
if !used[f.Name] { |
|
fields = append(fields, f) |
|
} |
|
} |
|
return |
|
} |
|
|
|
func getStructFields(t reflect.Type) ([]reflect.StructField, error) { |
|
if t.Kind() != reflect.Struct { |
|
return nil, fmt.Errorf("got %v; expected a struct", t) |
|
} |
|
|
|
var efields []reflect.StructField |
|
for i := 0; i < t.NumField(); i++ { |
|
f := t.Field(i) |
|
if !f.Anonymous { |
|
continue |
|
} |
|
|
|
t1 := f.Type |
|
if t1.Kind() == reflect.Ptr { |
|
t1 = t1.Elem() |
|
} |
|
|
|
fs, err := getStructFields(t1) |
|
if err != nil { |
|
return nil, fmt.Errorf("error processing embedded field: %v", err) |
|
} |
|
efields = mergeStructFields(efields, fs) |
|
} |
|
|
|
var fields []reflect.StructField |
|
for i := 0; i < t.NumField(); i++ { |
|
f := t.Field(i) |
|
if f.Anonymous { |
|
continue |
|
} |
|
|
|
c := []rune(f.Name)[0] |
|
if unicode.IsUpper(c) { |
|
fields = append(fields, f) |
|
} |
|
} |
|
return mergeStructFields(efields, fields), nil |
|
} |
|
|
|
func (g *Generator) genDecoder(t reflect.Type) error { |
|
switch t.Kind() { |
|
case reflect.Slice, reflect.Array, reflect.Map: |
|
return g.genSliceArrayDecoder(t) |
|
default: |
|
return g.genStructDecoder(t) |
|
} |
|
} |
|
|
|
func (g *Generator) genSliceArrayDecoder(t reflect.Type) error { |
|
switch t.Kind() { |
|
case reflect.Slice, reflect.Array, reflect.Map: |
|
default: |
|
return fmt.Errorf("cannot generate encoder/decoder for %v, not a slice/array/map type", t) |
|
} |
|
|
|
fname := g.getDecoderName(t) |
|
typ := g.getType(t) |
|
|
|
fmt.Fprintln(g.out, "func "+fname+"(in *jlexer.Lexer, out *"+typ+") {") |
|
fmt.Fprintln(g.out, " isTopLevel := in.IsStart()") |
|
err := g.genTypeDecoderNoCheck(t, "*out", fieldTags{}, 1) |
|
if err != nil { |
|
return err |
|
} |
|
fmt.Fprintln(g.out, " if isTopLevel {") |
|
fmt.Fprintln(g.out, " in.Consumed()") |
|
fmt.Fprintln(g.out, " }") |
|
fmt.Fprintln(g.out, "}") |
|
|
|
return nil |
|
} |
|
|
|
func (g *Generator) genStructDecoder(t reflect.Type) error { |
|
if t.Kind() != reflect.Struct { |
|
return fmt.Errorf("cannot generate encoder/decoder for %v, not a struct type", t) |
|
} |
|
|
|
fname := g.getDecoderName(t) |
|
typ := g.getType(t) |
|
|
|
fmt.Fprintln(g.out, "func "+fname+"(in *jlexer.Lexer, out *"+typ+") {") |
|
fmt.Fprintln(g.out, " isTopLevel := in.IsStart()") |
|
fmt.Fprintln(g.out, " if in.IsNull() {") |
|
fmt.Fprintln(g.out, " if isTopLevel {") |
|
fmt.Fprintln(g.out, " in.Consumed()") |
|
fmt.Fprintln(g.out, " }") |
|
fmt.Fprintln(g.out, " in.Skip()") |
|
fmt.Fprintln(g.out, " return") |
|
fmt.Fprintln(g.out, " }") |
|
|
|
// Init embedded pointer fields. |
|
for i := 0; i < t.NumField(); i++ { |
|
f := t.Field(i) |
|
if !f.Anonymous || f.Type.Kind() != reflect.Ptr { |
|
continue |
|
} |
|
fmt.Fprintln(g.out, " out."+f.Name+" = new("+g.getType(f.Type.Elem())+")") |
|
} |
|
|
|
fs, err := getStructFields(t) |
|
if err != nil { |
|
return fmt.Errorf("cannot generate decoder for %v: %v", t, err) |
|
} |
|
|
|
for _, f := range fs { |
|
g.genRequiredFieldSet(t, f) |
|
} |
|
|
|
fmt.Fprintln(g.out, " in.Delim('{')") |
|
fmt.Fprintln(g.out, " for !in.IsDelim('}') {") |
|
fmt.Fprintln(g.out, " key := in.UnsafeString()") |
|
fmt.Fprintln(g.out, " in.WantColon()") |
|
fmt.Fprintln(g.out, " if in.IsNull() {") |
|
fmt.Fprintln(g.out, " in.Skip()") |
|
fmt.Fprintln(g.out, " in.WantComma()") |
|
fmt.Fprintln(g.out, " continue") |
|
fmt.Fprintln(g.out, " }") |
|
|
|
fmt.Fprintln(g.out, " switch key {") |
|
for _, f := range fs { |
|
if err := g.genStructFieldDecoder(t, f); err != nil { |
|
return err |
|
} |
|
} |
|
|
|
fmt.Fprintln(g.out, " default:") |
|
if g.disallowUnknownFields { |
|
fmt.Fprintln(g.out, ` in.AddError(&jlexer.LexerError{ |
|
Offset: in.GetPos(), |
|
Reason: "unknown field", |
|
Data: key, |
|
})`) |
|
} else { |
|
fmt.Fprintln(g.out, " in.SkipRecursive()") |
|
} |
|
fmt.Fprintln(g.out, " }") |
|
fmt.Fprintln(g.out, " in.WantComma()") |
|
fmt.Fprintln(g.out, " }") |
|
fmt.Fprintln(g.out, " in.Delim('}')") |
|
fmt.Fprintln(g.out, " if isTopLevel {") |
|
fmt.Fprintln(g.out, " in.Consumed()") |
|
fmt.Fprintln(g.out, " }") |
|
|
|
for _, f := range fs { |
|
g.genRequiredFieldCheck(t, f) |
|
} |
|
|
|
fmt.Fprintln(g.out, "}") |
|
|
|
return nil |
|
} |
|
|
|
func (g *Generator) genStructUnmarshaler(t reflect.Type) error { |
|
switch t.Kind() { |
|
case reflect.Slice, reflect.Array, reflect.Map, reflect.Struct: |
|
default: |
|
return fmt.Errorf("cannot generate encoder/decoder for %v, not a struct/slice/array/map type", t) |
|
} |
|
|
|
fname := g.getDecoderName(t) |
|
typ := g.getType(t) |
|
|
|
if !g.noStdMarshalers { |
|
fmt.Fprintln(g.out, "// UnmarshalJSON supports json.Unmarshaler interface") |
|
fmt.Fprintln(g.out, "func (v *"+typ+") UnmarshalJSON(data []byte) error {") |
|
fmt.Fprintln(g.out, " r := jlexer.Lexer{Data: data}") |
|
fmt.Fprintln(g.out, " "+fname+"(&r, v)") |
|
fmt.Fprintln(g.out, " return r.Error()") |
|
fmt.Fprintln(g.out, "}") |
|
} |
|
|
|
fmt.Fprintln(g.out, "// UnmarshalEasyJSON supports easyjson.Unmarshaler interface") |
|
fmt.Fprintln(g.out, "func (v *"+typ+") UnmarshalEasyJSON(l *jlexer.Lexer) {") |
|
fmt.Fprintln(g.out, " "+fname+"(l, v)") |
|
fmt.Fprintln(g.out, "}") |
|
|
|
return nil |
|
}
|
|
|