You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
454 lines
13 KiB
454 lines
13 KiB
/*
 *
 * Copyright 2016 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */
|
|
|
//go:generate protoc --go_out=plugins=grpc:. grpc_reflection_v1alpha/reflection.proto |
|
|
|
/*
Package reflection implements server reflection service.

The service implemented is defined in:
https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto.

To register server reflection on a gRPC server:
	import "google.golang.org/grpc/reflection"

	s := grpc.NewServer()
	pb.RegisterYourOwnServer(s, &server{})

	// Register reflection service on gRPC server.
	reflection.Register(s)

	s.Serve(lis)

*/
|
package reflection // import "google.golang.org/grpc/reflection" |
|
|
|
import ( |
|
"bytes" |
|
"compress/gzip" |
|
"fmt" |
|
"io" |
|
"io/ioutil" |
|
"reflect" |
|
"sort" |
|
"sync" |
|
|
|
"github.com/golang/protobuf/proto" |
|
dpb "github.com/golang/protobuf/protoc-gen-go/descriptor" |
|
"google.golang.org/grpc" |
|
"google.golang.org/grpc/codes" |
|
rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" |
|
"google.golang.org/grpc/status" |
|
) |
|
|
|
type serverReflectionServer struct { |
|
s *grpc.Server |
|
|
|
initSymbols sync.Once |
|
serviceNames []string |
|
symbols map[string]*dpb.FileDescriptorProto // map of fully-qualified names to files |
|
} |
|
|
|
// Register registers the server reflection service on the given gRPC server. |
|
func Register(s *grpc.Server) { |
|
rpb.RegisterServerReflectionServer(s, &serverReflectionServer{ |
|
s: s, |
|
}) |
|
} |
|
|
|
// protoMessage is the interface used to type-assert proto messages so that
// their Descriptor method can be called.
//
// Generated proto messages implement Descriptor(), but the method is not
// part of proto.Message, hence this dedicated interface. The first result is
// passed to decodeFileDesc by fileDescForType; the second result is not used
// in this file.
type protoMessage interface {
	Descriptor() ([]byte, []int)
}
|
|
|
func (s *serverReflectionServer) getSymbols() (svcNames []string, symbolIndex map[string]*dpb.FileDescriptorProto) { |
|
s.initSymbols.Do(func() { |
|
serviceInfo := s.s.GetServiceInfo() |
|
|
|
s.symbols = map[string]*dpb.FileDescriptorProto{} |
|
s.serviceNames = make([]string, 0, len(serviceInfo)) |
|
processed := map[string]struct{}{} |
|
for svc, info := range serviceInfo { |
|
s.serviceNames = append(s.serviceNames, svc) |
|
fdenc, ok := parseMetadata(info.Metadata) |
|
if !ok { |
|
continue |
|
} |
|
fd, err := decodeFileDesc(fdenc) |
|
if err != nil { |
|
continue |
|
} |
|
s.processFile(fd, processed) |
|
} |
|
sort.Strings(s.serviceNames) |
|
}) |
|
|
|
return s.serviceNames, s.symbols |
|
} |
|
|
|
func (s *serverReflectionServer) processFile(fd *dpb.FileDescriptorProto, processed map[string]struct{}) { |
|
filename := fd.GetName() |
|
if _, ok := processed[filename]; ok { |
|
return |
|
} |
|
processed[filename] = struct{}{} |
|
|
|
prefix := fd.GetPackage() |
|
|
|
for _, msg := range fd.MessageType { |
|
s.processMessage(fd, prefix, msg) |
|
} |
|
for _, en := range fd.EnumType { |
|
s.processEnum(fd, prefix, en) |
|
} |
|
for _, ext := range fd.Extension { |
|
s.processField(fd, prefix, ext) |
|
} |
|
for _, svc := range fd.Service { |
|
svcName := fqn(prefix, svc.GetName()) |
|
s.symbols[svcName] = fd |
|
for _, meth := range svc.Method { |
|
name := fqn(svcName, meth.GetName()) |
|
s.symbols[name] = fd |
|
} |
|
} |
|
|
|
for _, dep := range fd.Dependency { |
|
fdenc := proto.FileDescriptor(dep) |
|
fdDep, err := decodeFileDesc(fdenc) |
|
if err != nil { |
|
continue |
|
} |
|
s.processFile(fdDep, processed) |
|
} |
|
} |
|
|
|
func (s *serverReflectionServer) processMessage(fd *dpb.FileDescriptorProto, prefix string, msg *dpb.DescriptorProto) { |
|
msgName := fqn(prefix, msg.GetName()) |
|
s.symbols[msgName] = fd |
|
|
|
for _, nested := range msg.NestedType { |
|
s.processMessage(fd, msgName, nested) |
|
} |
|
for _, en := range msg.EnumType { |
|
s.processEnum(fd, msgName, en) |
|
} |
|
for _, ext := range msg.Extension { |
|
s.processField(fd, msgName, ext) |
|
} |
|
for _, fld := range msg.Field { |
|
s.processField(fd, msgName, fld) |
|
} |
|
for _, oneof := range msg.OneofDecl { |
|
oneofName := fqn(msgName, oneof.GetName()) |
|
s.symbols[oneofName] = fd |
|
} |
|
} |
|
|
|
func (s *serverReflectionServer) processEnum(fd *dpb.FileDescriptorProto, prefix string, en *dpb.EnumDescriptorProto) { |
|
enName := fqn(prefix, en.GetName()) |
|
s.symbols[enName] = fd |
|
|
|
for _, val := range en.Value { |
|
valName := fqn(enName, val.GetName()) |
|
s.symbols[valName] = fd |
|
} |
|
} |
|
|
|
func (s *serverReflectionServer) processField(fd *dpb.FileDescriptorProto, prefix string, fld *dpb.FieldDescriptorProto) { |
|
fldName := fqn(prefix, fld.GetName()) |
|
s.symbols[fldName] = fd |
|
} |
|
|
|
// fqn builds a fully-qualified name by joining prefix and name with a dot.
// When prefix is empty the name is returned on its own, without a leading
// dot.
func fqn(prefix, name string) string {
	if prefix != "" {
		return prefix + "." + name
	}
	return name
}
|
|
|
// fileDescForType gets the file descriptor for the given type. |
|
// The given type should be a proto message. |
|
func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDescriptorProto, error) { |
|
m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(protoMessage) |
|
if !ok { |
|
return nil, fmt.Errorf("failed to create message from type: %v", st) |
|
} |
|
enc, _ := m.Descriptor() |
|
|
|
return decodeFileDesc(enc) |
|
} |
|
|
|
// decodeFileDesc does decompression and unmarshalling on the given |
|
// file descriptor byte slice. |
|
func decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) { |
|
raw, err := decompress(enc) |
|
if err != nil { |
|
return nil, fmt.Errorf("failed to decompress enc: %v", err) |
|
} |
|
|
|
fd := new(dpb.FileDescriptorProto) |
|
if err := proto.Unmarshal(raw, fd); err != nil { |
|
return nil, fmt.Errorf("bad descriptor: %v", err) |
|
} |
|
return fd, nil |
|
} |
|
|
|
// decompress does gzip decompression.
//
// It returns the full decompressed content of b. Any failure, whether while
// reading the gzip header or while inflating the payload, is reported as a
// "bad gzipped descriptor" error.
func decompress(b []byte) ([]byte, error) {
	// Both failure points share one message, so build it in one place.
	wrap := func(err error) error {
		return fmt.Errorf("bad gzipped descriptor: %v", err)
	}

	zr, err := gzip.NewReader(bytes.NewReader(b))
	if err != nil {
		return nil, wrap(err)
	}
	plain, err := ioutil.ReadAll(zr)
	if err != nil {
		return nil, wrap(err)
	}
	return plain, nil
}
|
|
|
func typeForName(name string) (reflect.Type, error) { |
|
pt := proto.MessageType(name) |
|
if pt == nil { |
|
return nil, fmt.Errorf("unknown type: %q", name) |
|
} |
|
st := pt.Elem() |
|
|
|
return st, nil |
|
} |
|
|
|
func fileDescContainingExtension(st reflect.Type, ext int32) (*dpb.FileDescriptorProto, error) { |
|
m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message) |
|
if !ok { |
|
return nil, fmt.Errorf("failed to create message from type: %v", st) |
|
} |
|
|
|
var extDesc *proto.ExtensionDesc |
|
for id, desc := range proto.RegisteredExtensions(m) { |
|
if id == ext { |
|
extDesc = desc |
|
break |
|
} |
|
} |
|
|
|
if extDesc == nil { |
|
return nil, fmt.Errorf("failed to find registered extension for extension number %v", ext) |
|
} |
|
|
|
return decodeFileDesc(proto.FileDescriptor(extDesc.Filename)) |
|
} |
|
|
|
func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) { |
|
m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message) |
|
if !ok { |
|
return nil, fmt.Errorf("failed to create message from type: %v", st) |
|
} |
|
|
|
exts := proto.RegisteredExtensions(m) |
|
out := make([]int32, 0, len(exts)) |
|
for id := range exts { |
|
out = append(out, id) |
|
} |
|
return out, nil |
|
} |
|
|
|
// fileDescEncodingByFilename finds the file descriptor for given filename, |
|
// does marshalling on it and returns the marshalled result. |
|
func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte, error) { |
|
enc := proto.FileDescriptor(name) |
|
if enc == nil { |
|
return nil, fmt.Errorf("unknown file: %v", name) |
|
} |
|
fd, err := decodeFileDesc(enc) |
|
if err != nil { |
|
return nil, err |
|
} |
|
return proto.Marshal(fd) |
|
} |
|
|
|
// parseMetadata finds the file descriptor bytes specified meta. |
|
// For SupportPackageIsVersion4, m is the name of the proto file, we |
|
// call proto.FileDescriptor to get the byte slice. |
|
// For SupportPackageIsVersion3, m is a byte slice itself. |
|
func parseMetadata(meta interface{}) ([]byte, bool) { |
|
// Check if meta is the file name. |
|
if fileNameForMeta, ok := meta.(string); ok { |
|
return proto.FileDescriptor(fileNameForMeta), true |
|
} |
|
|
|
// Check if meta is the byte slice. |
|
if enc, ok := meta.([]byte); ok { |
|
return enc, true |
|
} |
|
|
|
return nil, false |
|
} |
|
|
|
// fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol, |
|
// does marshalling on it and returns the marshalled result. |
|
// The given symbol can be a type, a service or a method. |
|
func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ([]byte, error) { |
|
_, symbols := s.getSymbols() |
|
fd := symbols[name] |
|
if fd == nil { |
|
// Check if it's a type name that was not present in the |
|
// transitive dependencies of the registered services. |
|
if st, err := typeForName(name); err == nil { |
|
fd, err = s.fileDescForType(st) |
|
if err != nil { |
|
return nil, err |
|
} |
|
} |
|
} |
|
|
|
if fd == nil { |
|
return nil, fmt.Errorf("unknown symbol: %v", name) |
|
} |
|
|
|
return proto.Marshal(fd) |
|
} |
|
|
|
// fileDescEncodingContainingExtension finds the file descriptor containing given extension, |
|
// does marshalling on it and returns the marshalled result. |
|
func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32) ([]byte, error) { |
|
st, err := typeForName(typeName) |
|
if err != nil { |
|
return nil, err |
|
} |
|
fd, err := fileDescContainingExtension(st, extNum) |
|
if err != nil { |
|
return nil, err |
|
} |
|
return proto.Marshal(fd) |
|
} |
|
|
|
// allExtensionNumbersForTypeName returns all extension numbers for the given type. |
|
func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) { |
|
st, err := typeForName(name) |
|
if err != nil { |
|
return nil, err |
|
} |
|
extNums, err := s.allExtensionNumbersForType(st) |
|
if err != nil { |
|
return nil, err |
|
} |
|
return extNums, nil |
|
} |
|
|
|
// ServerReflectionInfo is the reflection service handler. |
|
func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error { |
|
for { |
|
in, err := stream.Recv() |
|
if err == io.EOF { |
|
return nil |
|
} |
|
if err != nil { |
|
return err |
|
} |
|
|
|
out := &rpb.ServerReflectionResponse{ |
|
ValidHost: in.Host, |
|
OriginalRequest: in, |
|
} |
|
switch req := in.MessageRequest.(type) { |
|
case *rpb.ServerReflectionRequest_FileByFilename: |
|
b, err := s.fileDescEncodingByFilename(req.FileByFilename) |
|
if err != nil { |
|
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ |
|
ErrorResponse: &rpb.ErrorResponse{ |
|
ErrorCode: int32(codes.NotFound), |
|
ErrorMessage: err.Error(), |
|
}, |
|
} |
|
} else { |
|
out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ |
|
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}}, |
|
} |
|
} |
|
case *rpb.ServerReflectionRequest_FileContainingSymbol: |
|
b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol) |
|
if err != nil { |
|
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ |
|
ErrorResponse: &rpb.ErrorResponse{ |
|
ErrorCode: int32(codes.NotFound), |
|
ErrorMessage: err.Error(), |
|
}, |
|
} |
|
} else { |
|
out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ |
|
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}}, |
|
} |
|
} |
|
case *rpb.ServerReflectionRequest_FileContainingExtension: |
|
typeName := req.FileContainingExtension.ContainingType |
|
extNum := req.FileContainingExtension.ExtensionNumber |
|
b, err := s.fileDescEncodingContainingExtension(typeName, extNum) |
|
if err != nil { |
|
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ |
|
ErrorResponse: &rpb.ErrorResponse{ |
|
ErrorCode: int32(codes.NotFound), |
|
ErrorMessage: err.Error(), |
|
}, |
|
} |
|
} else { |
|
out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ |
|
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}}, |
|
} |
|
} |
|
case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType: |
|
extNums, err := s.allExtensionNumbersForTypeName(req.AllExtensionNumbersOfType) |
|
if err != nil { |
|
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ |
|
ErrorResponse: &rpb.ErrorResponse{ |
|
ErrorCode: int32(codes.NotFound), |
|
ErrorMessage: err.Error(), |
|
}, |
|
} |
|
} else { |
|
out.MessageResponse = &rpb.ServerReflectionResponse_AllExtensionNumbersResponse{ |
|
AllExtensionNumbersResponse: &rpb.ExtensionNumberResponse{ |
|
BaseTypeName: req.AllExtensionNumbersOfType, |
|
ExtensionNumber: extNums, |
|
}, |
|
} |
|
} |
|
case *rpb.ServerReflectionRequest_ListServices: |
|
svcNames, _ := s.getSymbols() |
|
serviceResponses := make([]*rpb.ServiceResponse, len(svcNames)) |
|
for i, n := range svcNames { |
|
serviceResponses[i] = &rpb.ServiceResponse{ |
|
Name: n, |
|
} |
|
} |
|
out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{ |
|
ListServicesResponse: &rpb.ListServiceResponse{ |
|
Service: serviceResponses, |
|
}, |
|
} |
|
default: |
|
return status.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest) |
|
} |
|
|
|
if err := stream.Send(out); err != nil { |
|
return err |
|
} |
|
} |
|
}
|
|
|