阅读protobuf Go代码
我想通过学习protobuf-go更深入了解go反射的应用,就看了下它的代码。
所看版本是
go 1.14
require (
github.com/golang/protobuf v1.4.3
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.0.1 // indirect
google.golang.org/protobuf v1.23.0
)
粗略看一下
先瞄一眼它的生成文件*.pb.go
在proto中定义的一个message
message TestMsg {
string str = 1;
int32 vint32 = 2;
double vdouble = 3;
repeated string arrayStr = 4;
}
生成的.pb.go中有下面一个struct
type TestMsg struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Str string `protobuf:"bytes,1,opt,name=str,proto3" json:"str,omitempty"`
Vint32 int32 `protobuf:"varint,2,opt,name=vint32,proto3" json:"vint32,omitempty"`
Vdouble float64 `protobuf:"fixed64,3,opt,name=vdouble,proto3" json:"vdouble,omitempty"`
ArrayStr []string `protobuf:"bytes,4,rep,name=arrayStr,proto3" json:"arrayStr,omitempty"`
}
除了proto里对应字段和struct tags外,开始处有三个字段。
1、有个变量var file_xxxx_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
并且设置了里面的Exporter的值,file_msgid_proto_msgTypes[0].Exporter
2、TestMsg的ProtoReflect函数,
func (x *MsgIdManager) ProtoReflect() protoreflect.Message {
mi := &file_msgid_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
先是将x指针转成MessageSate指针,这个特性和C一样。
然后分别调用LoadMessageInfo和StoreMessageInfo。如果ms中的atomicMessageInfo指针为空,则将atomicMessageInfo指向mi。也就是所有的同一类型message实例共享一个MessageInfo实例。
3、TestMsg还有个Reset函数。
首先protobuf-go内部还有几个特别的类型,在pragma.go里
package pragma
import "sync"
// NoUnkeyedLiterals can be embedded in a struct to prevent unkeyed literals.
type NoUnkeyedLiterals struct{}
// DoNotImplement can be embedded in an interface to prevent trivial
// implementations of the interface.
//
// This is useful to prevent unauthorized implementations of an interface
// so that it can be extended in the future for any protobuf language changes.
type DoNotImplement interface{ ProtoInternal(DoNotImplement) }
// DoNotCompare can be embedded in a struct to prevent comparability.
type DoNotCompare [0]func()
// DoNotCopy can be embedded in a struct to help prevent shallow copies.
// This does not rely on a Go language feature, but rather a special case
// within the vet checker.
//
// See https://golang.org/issues/8005.
type DoNotCopy [0]sync.Mutex
MessageState在internal\impl\message_reflect.go
// MessageState is a data structure that is nested as the first field in a
// concrete message. It provides a way to implement the ProtoReflect method
// in an allocation-free way without needing to have a shadow Go type generated
// for every message type. This technique only works using unsafe.
//
//
// Example generated code:
//
// type M struct {
// state protoimpl.MessageState
//
// Field1 int32
// Field2 string
// Field3 *BarMessage
// ...
// }
//
// func (m *M) ProtoReflect() protoreflect.Message {
// ...
// }
//
// The MessageState type holds a *MessageInfo, which must be atomically set to
// the message info associated with a given message instance.
// By unsafely converting a *M into a *MessageState, the MessageState object
// has access to all the information needed to implement protobuf reflection.
// It has access to the message info as its first field, and a pointer to the
// MessageState is identical to a pointer to the concrete message value.
//
//
// Requirements:
// • The type M must implement protoreflect.ProtoMessage.
// • The address of m must not be nil.
// • The address of m and the address of m.state must be equal,
// even though they are different Go types.
type MessageState struct {
pragma.NoUnkeyedLiterals
pragma.DoNotCompare
pragma.DoNotCopy
atomicMessageInfo *MessageInfo
}
MessageState是作为具体消息中的第一个字段嵌套的数据结构。它提供了一种无需分配的方式来实现ProtoReflect方法,而不需要为每个消息类型生成一个shadow Go类型。这种技术只在不安全的情况下有效。
MessageState类型包含一个MessageInfo,它必须原子性地设置为与给定消息实例关联的消息信息。 通过不安全地将M转换为*MessageState,MessageState对象可以访问实现protobuf反射所需的所有信息。 它可以访问消息信息作为第一个字段,指向MessageState的指针与指向具体消息值的指针相同。
MessageInfo在internal\impl\message.go
// MessageInfo provides protobuf related functionality for a given Go type
// that represents a message. A given instance of MessageInfo is tied to
// exactly one Go type, which must be a pointer to a struct type.
//
// The exported fields must be populated before any methods are called
// and cannot be mutated after set.
type MessageInfo struct {
// GoReflectType is the underlying message Go type and must be populated.
GoReflectType reflect.Type // pointer to struct
// Desc is the underlying message descriptor type and must be populated.
Desc pref.MessageDescriptor
// Exporter must be provided in a purego environment in order to provide
// access to unexported fields.
Exporter exporter
// OneofWrappers is list of pointers to oneof wrapper struct types.
OneofWrappers []interface{}
initMu sync.Mutex // protects all unexported fields
initDone uint32
reflectMessageInfo // for reflection implementation
coderMessageInfo // for fast-path method implementations
}
Unmarshal的调用堆栈
下面是一些关键点的快照
x := &TestMsg{}
out, err := proto.Marshal(x)
if err != nil {
t.Error("Marshal", err)
}
x2 := &TestMsg{}
proto.Unmarshal(out, x2)
- Unmarshal(b []byte, m Message)
- _, err := UnmarshalOptions{}.unmarshal(b, m.ProtoReflect())
- methods := protoMethods(m); out, err = methods.Unmarshal(in)
- (mi *MessageInfo) unmarshal(in piface.UnmarshalInput); out,err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{ flags: in.Flags, resolver: in.Resolver})
- (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions); f = mi.denseCoderFields[num]; o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts) 在这里会buff进行decode,解出所有字段。这里p.Apply(f.offset)对指针偏移操作,到这里,可以想明白一点,pd.go里生成的message所有字段数据都是用指针保存,这样可以统一使用指针偏移来取到对应字段的指针。不过有点需要注意的是offset不是线性计算的,TestMsg的4个字段分别是40,56 ,64,72。有个奇怪的是string字段占用了16bits?有兴趣的可以看下offset的具体计算。unmarshalPointer里每次decode数据时,先取出tag,tag是一个uint62类型,包含了field number和wire type,低3位是wire type,采用protobuf整形数据的压缩算法。
- f.funcs.unmarshal 会根据变量的不同,会对应不同的函数。比如string的会是consumeStringValidateUTF8,在文件里internal\impl\codec_gen.go,可以搜索func consume开始的函数。如果字段类型是个message的话,对应的unmarshal函数是consumeMessageInfo。整个结构有点像棵树,所有的基本类型是叶子结点,由对应的consumeXX来decode,如果是message,则表现是个父母结点。p.SetPointer(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem()))) 这里应该是new这个message的对象,并取它的指针。
.pb.go里var file_xxx_proto_rawDesc是用来做什么。是FileDescriptorProto Marshal后的数据
descProto := proto.Clone(f.Proto).(*descriptorpb.FileDescriptorProto)
descProto.SourceCodeInfo = nil // drop source code information
b, err := proto.MarshalOptions{AllowPartial: true, Deterministic: true}.Marshal(descProto)
运行时用unmarshalSeed解释file_xxx_proto_rawDesc,里面包含文件名,package名,message,service等信息,也就是整个文件的描述信息。
里面大量用到指针,主要为了减少内存拷贝,提高效率。
反射结构的初始化
message结构反射信息初始化,Marshal和Unmarshal都要用到orderedCoderFields来遍历所有字段。
internal\impl\message.go里
func (mi *MessageInfo) initOnce() {
mi.initMu.Lock()
defer mi.initMu.Unlock()
if mi.initDone == 1 {
return
}
t := mi.GoReflectType
if t.Kind() != reflect.Ptr && t.Elem().Kind() != reflect.Struct {
panic(fmt.Sprintf("got %v, want *struct kind", t))
}
t = t.Elem()
si := mi.makeStructInfo(t)
mi.makeReflectFuncs(t, si) // 函数在internal\impl\message_reflect.go
mi.makeCoderMethods(t, si) // orderedCoderFields的初始化,记录字段该用consumeXXX,appendXX
atomic.StoreUint32(&mi.initDone, 1)
}
makeStructInfo会利用反射和struct tags(也就是.pb.go文件中,message struct字段末尾处的字符串)获得字段信息。下面是Vint32字段的struct tag。
Vint32 int32 `protobuf:"varint,1,opt,name=vint32,proto3" json:"vint32,omitempty"`
下面是makeStructInfo函数,在internal\impl\message.go里
func (mi *MessageInfo) makeStructInfo(t reflect.Type) structInfo {
si := structInfo{
sizecacheOffset: invalidOffset,
weakOffset: invalidOffset,
unknownOffset: invalidOffset,
extensionOffset: invalidOffset,
fieldsByNumber: map[pref.FieldNumber]reflect.StructField{},
oneofsByName: map[pref.Name]reflect.StructField{},
oneofWrappersByType: map[reflect.Type]pref.FieldNumber{},
oneofWrappersByNumber: map[pref.FieldNumber]reflect.Type{},
}
fieldLoop:
for i := 0; i < t.NumField(); i++ {
switch f := t.Field(i); f.Name {
case genname.SizeCache, genname.SizeCacheA:
if f.Type == sizecacheType {
si.sizecacheOffset = offsetOf(f, mi.Exporter)
}
case genname.WeakFields, genname.WeakFieldsA:
if f.Type == weakFieldsType {
si.weakOffset = offsetOf(f, mi.Exporter)
}
case genname.UnknownFields, genname.UnknownFieldsA:
if f.Type == unknownFieldsType {
si.unknownOffset = offsetOf(f, mi.Exporter)
}
case genname.ExtensionFields, genname.ExtensionFieldsA, genname.ExtensionFieldsB:
if f.Type == extensionFieldsType {
si.extensionOffset = offsetOf(f, mi.Exporter)
}
default: // 这里
for _, s := range strings.Split(f.Tag.Get("protobuf"), ",") {
if len(s) > 0 && strings.Trim(s, "0123456789") == "" {
n, _ := strconv.ParseUint(s, 10, 64)
si.fieldsByNumber[pref.FieldNumber(n)] = f
continue fieldLoop
}
}
if s := f.Tag.Get("protobuf_oneof"); len(s) > 0 {
si.oneofsByName[pref.Name(s)] = f
continue fieldLoop
}
}
}
// Derive a mapping of oneof wrappers to fields.
oneofWrappers := mi.OneofWrappers
for _, method := range []string{"XXX_OneofFuncs", "XXX_OneofWrappers"} {
if fn, ok := reflect.PtrTo(t).MethodByName(method); ok {
for _, v := range fn.Func.Call([]reflect.Value{reflect.Zero(fn.Type.In(0))}) {
if vs, ok := v.Interface().([]interface{}); ok {
oneofWrappers = vs
}
}
}
}
for _, v := range oneofWrappers {
...
}
return si
}
在internal\impl\codec_message.go里,makeCoderMethods函数中
fieldOffset = offsetOf(fs, mi.Exporter)
childMessage, funcs = fieldCoder(fd, ft) // *
}
cf := &preallocFields[i]
*cf = coderFieldInfo{
num: fd.Number(),
offset: fieldOffset,
wiretag: wiretag,
ft: ft,
tagsize: protowire.SizeVarint(wiretag),
funcs: funcs, // *
mi: childMessage,
validation: newFieldValidationInfo(mi, si, fd, ft),
isPointer: fd.Cardinality() == pref.Repeated || fd.HasPresence(),
isRequired: fd.Cardinality() == pref.Required,
}
mi.orderedCoderFields = append(mi.orderedCoderFields, cf) // *
mi.coderFields[cf.num] = cf
函数fieldCoder会根据字段类型选择不同的pointerCoderFuncs,下面是uint32字段的。
var coderUint32 = pointerCoderFuncs{
size: sizeUint32, //计算字段大小
marshal: appendUint32, //序列化字段
unmarshal: consumeUint32,//反序列号字段
merge: mergeUint32,//合并吧
}
根据一个proto message和field id得到它field的类型
m := TestMsg{}
md := m.ProtoReflect().Descriptor()
fd := md.Fields().ByNumber(protoreflect.FieldNumber(fieldid))
if fd == nil {
return nil
}
v := m.Mutable(fd)
lmi, ok := v.Interface().(interface{ LoadMessageInfo() *protoimpl.MessageInfo })
if !ok {
return nil
}
ms := lmi.LoadMessageInfo()
fmt.Println("GoReflectType", ms.GoReflectType)
加菜,getMessageInfo里判断类型和是否有某个接口的写法:
func getMessageInfo(mt reflect.Type) *MessageInfo {
m, ok := reflect.Zero(mt).Interface().(pref.ProtoMessage)
if !ok {
return nil
}
mr, ok := m.ProtoReflect().(interface{ ProtoMessageInfo() *MessageInfo })
if !ok {
return nil
}
return mr.ProtoMessageInfo()
}