rpcx是一款非常优秀的golang编写的rpc框架,地址为https://github.com/smallnest/rpcx
下面主要分析一下server端的源码
1.1 初始化server
// Server is rpcx server that use TCP or UDP.
type Server struct {
	ln           net.Listener  // network listener
	readTimeout  time.Duration // per-read deadline applied to connections (0 = none)
	writeTimeout time.Duration // per-write deadline applied to connections (0 = none)

	gatewayHTTPServer *http.Server // HTTP gateway server
	jsonrpcHTTPServer *http.Server // JSON-RPC over HTTP server

	DisableHTTPGateway bool // should disable http invoke or not.
	DisableJSONRPC     bool // should disable json rpc or not.
	AsyncWrite         bool // set true if your server only serves few clients

	serviceMapMu sync.RWMutex        // guards serviceMap for concurrent access
	serviceMap   map[string]*service // registry of services provided by this server

	router map[string]Handler // routing table: "ServicePath.ServiceMethod" -> Handler

	mu         sync.RWMutex          // guards ln and activeConn
	activeConn map[net.Conn]struct{} // set of currently active server connections
	doneChan   chan struct{}         // signals that server shutdown has completed
	seq        uint64                // server-side sequence number

	inShutdown int32             // non-zero once shutdown has started (read atomically)
	onShutdown []func(s *Server) // callbacks invoked on shutdown
	onRestart  []func(s *Server) // callbacks invoked on restart

	// TLSConfig for creating tls tcp connection.
	tlsConfig *tls.Config
	// BlockCrypt for kcp.BlockCrypt
	options map[string]interface{}

	// CORS options
	corsOptions *CORSOptions

	Plugins PluginContainer

	// AuthFunc can be used to auth.
	AuthFunc func(ctx context.Context, req *protocol.Message, token string) error

	handlerMsgNum int32 // number of requests currently being handled (atomic)

	HandleServiceError func(error)
}
// NewServer returns a server.
//
// The given option functions are applied in order after the default
// fields are initialized; a default TCP keep-alive period of three
// minutes is installed when none of the options configured one.
func NewServer(options ...OptionFn) *Server {
	srv := &Server{
		Plugins:    &pluginContainer{},
		options:    make(map[string]interface{}),
		activeConn: make(map[net.Conn]struct{}),
		doneChan:   make(chan struct{}),
		serviceMap: make(map[string]*service),
		router:     make(map[string]Handler),
		// Keep writes synchronous unless you are specifically benchmarking
		// the async-write mode.
		AsyncWrite: false,
	}

	for _, apply := range options {
		apply(srv)
	}

	// Fall back to a 3-minute TCP keep-alive period when not configured.
	if srv.options["TCPKeepAlivePeriod"] == nil {
		srv.options["TCPKeepAlivePeriod"] = 3 * time.Minute
	}

	return srv
}
1.2 启动server
// Serve starts and listens RPC requests.
// It is blocked until receiving connections from clients.
func (s *Server) Serve(network, address string) (err error) {
	ln, err := s.makeListener(network, address)
	if err != nil {
		return err
	}

	// HTTP and WebSocket transports have dedicated serve loops.
	switch network {
	case "http":
		s.serveByHTTP(ln, "")
		return nil
	case "ws", "wss":
		s.serveByWS(ln, "")
		return nil
	}

	// try to start gateway
	ln = s.startGateway(network, ln)

	return s.serveListener(ln)
}
如果是非http请求,则走serveListener方法
// serveListener accepts incoming connections on the Listener ln,
// creating a new service goroutine for each.
// The service goroutines read requests and then call services to reply to them.
func (s *Server) serveListener(ln net.Listener) error {
	var tempDelay time.Duration

	s.mu.Lock()
	s.ln = ln
	s.mu.Unlock()

	for {
		conn, e := ln.Accept()
		if e != nil {
			// If the server has already been shut down, wait for the done
			// signal and return the sentinel error.
			if s.isShutdown() {
				<-s.doneChan
				return ErrServerClosed
			}

			// Temporary accept errors (e.g. out of file descriptors) are
			// retried with exponential backoff, capped at one second.
			if ne, ok := e.(net.Error); ok && ne.Temporary() {
				if tempDelay == 0 {
					tempDelay = 5 * time.Millisecond
				} else {
					tempDelay *= 2
				}

				if max := 1 * time.Second; tempDelay > max {
					tempDelay = max
				}

				log.Errorf("rpcx: Accept error: %v; retrying in %v", e, tempDelay)
				time.Sleep(tempDelay)
				continue
			}

			if errors.Is(e, cmux.ErrListenerClosed) {
				return ErrServerClosed
			}
			return e
		}
		tempDelay = 0

		// Configure keep-alive and linger on plain TCP connections.
		if tc, ok := conn.(*net.TCPConn); ok {
			period := s.options["TCPKeepAlivePeriod"]
			if period != nil {
				tc.SetKeepAlive(true)
				tc.SetKeepAlivePeriod(period.(time.Duration))
				tc.SetLinger(10)
			}
		}

		// Let plugins wrap or reject the freshly accepted connection.
		conn, ok := s.Plugins.DoPostConnAccept(conn)
		if !ok {
			conn.Close()
			continue
		}

		s.mu.Lock()
		s.activeConn[conn] = struct{}{} // track active connections on this server
		s.mu.Unlock()

		if share.Trace {
			log.Debugf("server accepted an conn: %v", conn.RemoteAddr().String())
		}

		go s.serveConn(conn)
	}
}
开启协程去处理请求
// serveConn reads requests from conn in a loop and dispatches each one,
// until the connection closes or the server shuts down. Each non-error
// request is handled in its own goroutine.
//
// Fixes applied in review: the trace guard around the "received request"
// debug log had been commented out (making it log unconditionally), a
// leftover debug fmt.Println after auth was removed, and the handler
// error log was missing a %s verb for the method name.
func (s *Server) serveConn(conn net.Conn) {
	if s.isShutdown() {
		s.closeConn(conn)
		return
	}

	defer func() {
		// Recover from panics so one bad connection cannot kill the server.
		if err := recover(); err != nil {
			const size = 64 << 10
			buf := make([]byte, size)
			ss := runtime.Stack(buf, false)
			if ss > size {
				ss = size
			}
			buf = buf[:ss]
			log.Errorf("serving %s panic error: %s, stack:\n %s", conn.RemoteAddr(), err, buf)
		}

		if share.Trace {
			log.Debugf("server closed conn: %v", conn.RemoteAddr().String())
		}

		// make sure all inflight requests are handled and all drained
		if s.isShutdown() {
			<-s.doneChan
		}

		s.closeConn(conn)
	}()

	// Complete the TLS handshake eagerly so handshake failures are
	// detected and logged here rather than on the first read.
	if tlsConn, ok := conn.(*tls.Conn); ok {
		if d := s.readTimeout; d != 0 {
			conn.SetReadDeadline(time.Now().Add(d))
		}
		if d := s.writeTimeout; d != 0 {
			conn.SetWriteDeadline(time.Now().Add(d))
		}
		if err := tlsConn.Handshake(); err != nil {
			log.Errorf("rpcx: TLS handshake error from %s: %v", conn.RemoteAddr(), err)
			return
		}
	}

	// Buffer reads from the connection (ReaderBuffsize defaults to 1KB).
	r := bufio.NewReaderSize(conn, ReaderBuffsize)

	// In async-write mode responses are funneled through a dedicated
	// writer goroutine via writeCh.
	var writeCh chan *[]byte
	if s.AsyncWrite {
		writeCh = make(chan *[]byte, 1)
		defer close(writeCh)
		go s.serveAsyncWrite(conn, writeCh)
	}

	for {
		if s.isShutdown() {
			return
		}

		t0 := time.Now()
		if s.readTimeout != 0 {
			conn.SetReadDeadline(t0.Add(s.readTimeout))
		}

		ctx := share.WithValue(context.Background(), RemoteConnContextKey, conn)

		// Read the next client request.
		req, err := s.readRequest(ctx, r)
		if err != nil {
			if err == io.EOF {
				log.Infof("client has closed this connection: %s", conn.RemoteAddr().String())
			} else if errors.Is(err, net.ErrClosed) {
				log.Infof("rpcx: connection %s is closed", conn.RemoteAddr().String())
			} else if errors.Is(err, ErrReqReachLimit) {
				if !req.IsOneway() {
					// The client expects a reply: clone the request and
					// send it back with the error attached.
					res := req.Clone()
					res.SetMessageType(protocol.Response)
					handleError(res, err)
					s.sendResponse(ctx, conn, writeCh, err, req, res)
					// Return the response message to the pool for reuse.
					protocol.FreeMsg(res)
				} else {
					// One-way request: only run the pre-write plugin hooks.
					s.Plugins.DoPreWriteResponse(ctx, req, nil, err)
				}
				// Return the request message to the pool for reuse.
				protocol.FreeMsg(req)
				continue
			} else {
				log.Warnf("rpcx: failed to read request: %v", err)
			}
			// Return the request message to the pool for reuse.
			protocol.FreeMsg(req)
			return
		}

		if share.Trace {
			log.Debugf("server received an request %+v from conn: %v", req, conn.RemoteAddr().String())
		}

		ctx = share.WithLocalValue(ctx, StartRequestContextKey, time.Now().UnixNano())
		closeConn := false
		if !req.IsHeartbeat() {
			err = s.auth(ctx, req)
			closeConn = err != nil
		}

		if err != nil {
			if !req.IsOneway() {
				res := req.Clone()
				res.SetMessageType(protocol.Response)
				handleError(res, err)
				s.sendResponse(ctx, conn, writeCh, err, req, res)
				protocol.FreeMsg(res)
			} else {
				s.Plugins.DoPreWriteResponse(ctx, req, nil, err)
			}
			protocol.FreeMsg(req)
			// auth failed, closed the connection
			if closeConn {
				log.Infof("auth failed for conn %s: %v", conn.RemoteAddr().String(), err)
				return
			}
			continue
		}

		// No errors so far: handle the request in its own goroutine.
		go func() {
			defer func() {
				if r := recover(); r != nil {
					// maybe panic because the writeCh is closed.
					log.Errorf("failed to handle request: %v", r)
				}
			}()

			// Track in-flight requests: +1 now, -1 when handling finishes.
			atomic.AddInt32(&s.handlerMsgNum, 1)
			defer atomic.AddInt32(&s.handlerMsgNum, -1)

			// Heartbeats are echoed straight back without service dispatch.
			if req.IsHeartbeat() {
				s.Plugins.DoHeartbeatRequest(ctx, req)
				req.SetMessageType(protocol.Response)
				data := req.EncodeSlicePointer()
				if s.AsyncWrite {
					writeCh <- data
				} else {
					if s.writeTimeout != 0 {
						conn.SetWriteDeadline(time.Now().Add(s.writeTimeout))
					}
					conn.Write(*data)
					protocol.PutData(data)
				}
				protocol.FreeMsg(req)
				return
			}

			resMetadata := make(map[string]string)
			if req.Metadata == nil {
				req.Metadata = make(map[string]string)
			}
			ctx = share.WithLocalValue(share.WithLocalValue(ctx, share.ReqMetaDataKey, req.Metadata),
				share.ResMetaDataKey, resMetadata)

			cancelFunc := parseServerTimeout(ctx, req)
			if cancelFunc != nil {
				defer cancelFunc()
			}

			// Run the pre-handle plugin hooks.
			s.Plugins.DoPreHandleRequest(ctx, req)

			if share.Trace {
				log.Debugf("server handle request %+v from conn: %v", req, conn.RemoteAddr().String())
			}

			// first use handler
			if handler, ok := s.router[req.ServicePath+"."+req.ServiceMethod]; ok {
				sctx := NewContext(ctx, conn, req, writeCh)
				err := handler(sctx)
				if err != nil {
					log.Errorf("[handler internal error]: servicepath: %s, servicemethod: %s, err: %v", req.ServicePath, req.ServiceMethod, err)
				}
				protocol.FreeMsg(req)
				return
			}

			res, err := s.handleRequest(ctx, req)
			if err != nil {
				if s.HandleServiceError != nil {
					s.HandleServiceError(err)
				} else {
					log.Warnf("rpcx: failed to handle request: %v", err)
				}
			}

			if !req.IsOneway() {
				// The client expects a response: merge metadata written by
				// the handler (via ctx) into the response before sending.
				if len(resMetadata) > 0 { // copy meta in context to request
					meta := res.Metadata
					if meta == nil {
						res.Metadata = resMetadata
					} else {
						for k, v := range resMetadata {
							if meta[k] == "" {
								meta[k] = v
							}
						}
					}
				}

				s.sendResponse(ctx, conn, writeCh, err, req, res)
			}

			if share.Trace {
				log.Debugf("server write response %+v for an request %+v from conn: %v", res, req, conn.RemoteAddr().String())
			}

			// Return both messages to the pool for reuse.
			protocol.FreeMsg(req)
			protocol.FreeMsg(res)
		}()
	}
}
首先是s.readRequest,读取请求信息
// readRequest reads a single request message from r, running the
// pre-read and post-read plugin hooks around the decode. On io.EOF the
// pooled message is returned alongside the error so the caller can
// free it.
func (s *Server) readRequest(ctx context.Context, r io.Reader) (req *protocol.Message, err error) {
	if err = s.Plugins.DoPreReadRequest(ctx); err != nil {
		return nil, err
	}

	// pool req?
	// Take a message from the sync.Pool and decode the wire bytes into it.
	req = protocol.GetPooledMsg()
	err = req.Decode(r)
	if err == io.EOF {
		return req, err
	}

	// A post-read hook error only surfaces when decoding itself succeeded.
	if hookErr := s.Plugins.DoPostReadRequest(ctx, req, err); err == nil {
		err = hookErr
	}

	return req, err
}
// msgPool recycles Message values to avoid per-request allocations.
// Each freshly allocated message starts with a 12-byte header whose
// first byte is the protocol magic number.
var msgPool = sync.Pool{
	New: func() interface{} {
		header := Header([12]byte{})
		header[0] = magicNumber

		return &Message{
			Header: &header,
		}
	},
}
// Message is the generic type of Request and Response.
type Message struct {
	*Header
	ServicePath   string
	ServiceMethod string
	Metadata      map[string]string
	Payload       []byte
	data          []byte // decoded body buffer backing the fields above; reused across decodes
}
// GetPooledMsg gets a pooled message.
func GetPooledMsg() *Message {
	msg := msgPool.Get().(*Message)
	return msg
}
比较经典的是 Decode 解码信息这一段
关于数据协议的详细描述,见链接协议详解 · Go RPC编程指南
// Decode decodes a message from reader.
//
// Wire layout: 1 magic byte + 11 more header bytes, a big-endian uint32
// total body length, then the body: four length-prefixed sections
// (servicePath, serviceMethod, metadata, payload).
//
// Review fixes: the servicePath length was read with data[n:4], which is
// only correct because n happens to be 0 there — normalized to the
// data[n:n+4] form used by every other section; the unused payload-length
// read (and its `_ = l` suppression) was replaced by a plain skip.
func (m *Message) Decode(r io.Reader) error {
	// validate rest length for each step?

	// Read the first header byte and verify the protocol magic number.
	_, err := io.ReadFull(r, m.Header[:1])
	if err != nil {
		return err
	}
	if !m.Header.CheckMagicNumber() {
		return fmt.Errorf("wrong magic number: %v", m.Header[0])
	}

	// Read the remaining header bytes.
	_, err = io.ReadFull(r, m.Header[1:])
	if err != nil {
		return err
	}

	// Read the total body length using a pooled 4-byte scratch buffer.
	lenData := poolUint32Data.Get().(*[]byte)
	_, err = io.ReadFull(r, *lenData)
	if err != nil {
		poolUint32Data.Put(lenData)
		return err
	}
	l := binary.BigEndian.Uint32(*lenData)
	poolUint32Data.Put(lenData)

	// MaxMessageLength defaults to 0 (no limit); otherwise reject
	// messages whose declared length exceeds it.
	if MaxMessageLength > 0 && int(l) > MaxMessageLength {
		return ErrMessageTooLong
	}

	// Reuse m.data's backing array when its capacity suffices.
	totalL := int(l)
	if cap(m.data) >= totalL { // reuse data
		m.data = m.data[:totalL]
	} else {
		m.data = make([]byte, totalL)
	}
	data := m.data
	_, err = io.ReadFull(r, data)
	if err != nil {
		return err
	}

	// NOTE(review): the section offsets below are not bounds-checked
	// against len(data); a malformed length prefix would panic — confirm
	// MaxMessageLength / upstream framing protects this path.
	n := 0

	// parse servicePath
	l = binary.BigEndian.Uint32(data[n : n+4])
	n = n + 4
	nEnd := n + int(l)
	m.ServicePath = util.SliceByteToString(data[n:nEnd])
	n = nEnd

	// parse serviceMethod
	l = binary.BigEndian.Uint32(data[n : n+4])
	n = n + 4
	nEnd = n + int(l)
	m.ServiceMethod = util.SliceByteToString(data[n:nEnd])
	n = nEnd

	// parse meta
	l = binary.BigEndian.Uint32(data[n : n+4])
	n = n + 4
	nEnd = n + int(l)
	if l > 0 {
		m.Metadata, err = decodeMetadata(l, data[n:nEnd])
		if err != nil {
			return err
		}
	}
	n = nEnd

	// parse payload: skip the 4-byte length prefix — the payload is
	// simply the remainder of the body.
	n = n + 4
	m.Payload = data[n:]

	// Decompress the payload when the header declares a compression type.
	if m.CompressType() != None {
		compressor := Compressors[m.CompressType()]
		if compressor == nil {
			return ErrUnsupportedCompressor
		}
		m.Payload, err = compressor.Unzip(m.Payload)
		if err != nil {
			return err
		}
	}

	return err
}
读取到请求信息后,针对请求信息中的方法名,调用对应的方法,具体的函数是handleRequest
// handleRequest looks up the requested service and method, decodes the
// argument, invokes the method via reflection, and encodes the reply
// into the returned response message. On failure the error is reported
// both via err and inside res (through handleError).
//
// Review fix: removed a leftover debug fmt.Println of the encoded reply.
func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
	serviceName := req.ServicePath
	methodName := req.ServiceMethod

	res = req.Clone()
	res.SetMessageType(protocol.Response)

	s.serviceMapMu.RLock()
	service := s.serviceMap[serviceName]
	if share.Trace {
		log.Debugf("server get service %+v for an request %+v", service, req)
	}
	s.serviceMapMu.RUnlock()

	if service == nil {
		err = errors.New("rpcx: can't find service " + serviceName)
		return handleError(res, err)
	}

	mtype := service.method[methodName]
	if mtype == nil {
		if service.function[methodName] != nil { // check raw functions
			return s.handleRequestForFunction(ctx, req)
		}
		err = errors.New("rpcx: can't find method " + methodName)
		return handleError(res, err)
	}

	// get a argv object from object pool
	argv := reflectTypePools.Get(mtype.ArgType)

	codec := share.Codecs[req.SerializeType()]
	if codec == nil {
		err = fmt.Errorf("can not find codec for %d", req.SerializeType())
		return handleError(res, err)
	}

	// Decode the request payload with the codec the client chose.
	err = codec.Decode(req.Payload, argv)
	if err != nil {
		return handleError(res, err)
	}

	// and get a reply object from object pool
	replyv := reflectTypePools.Get(mtype.ReplyType)

	// Run the pre-call plugin hooks (they may replace argv).
	argv, err = s.Plugins.DoPreCall(ctx, serviceName, methodName, argv)
	if err != nil {
		// return reply to object pool
		reflectTypePools.Put(mtype.ReplyType, replyv)
		return handleError(res, err)
	}

	// Invoke the registered method via reflection; value-typed arguments
	// are dereferenced before the call.
	if mtype.ArgType.Kind() != reflect.Ptr {
		err = service.call(ctx, mtype, reflect.ValueOf(argv).Elem(), reflect.ValueOf(replyv))
	} else {
		err = service.call(ctx, mtype, reflect.ValueOf(argv), reflect.ValueOf(replyv))
	}

	if err == nil {
		replyv, err = s.Plugins.DoPostCall(ctx, serviceName, methodName, argv, replyv)
	}

	// return argc to object pool
	reflectTypePools.Put(mtype.ArgType, argv)

	if err != nil {
		// Even on error, try to ship the (possibly partial) reply back.
		if replyv != nil {
			data, err := codec.Encode(replyv)
			// return reply to object pool
			reflectTypePools.Put(mtype.ReplyType, replyv)
			if err != nil {
				return handleError(res, err)
			}
			res.Payload = data
		}
		return handleError(res, err)
	}

	if !req.IsOneway() {
		// The client expects a reply: encode it into the response payload.
		data, err := codec.Encode(replyv)
		// return reply to object pool
		reflectTypePools.Put(mtype.ReplyType, replyv)
		if err != nil {
			return handleError(res, err)
		}
		res.Payload = data
	} else if replyv != nil {
		reflectTypePools.Put(mtype.ReplyType, replyv)
	}

	if share.Trace {
		log.Debugf("server called service %+v for an request %+v", service, req)
	}

	return res, nil
}
拿到返回信息后,通过 sendResponse 发送响应,整个调用过程就结束了。
// sendResponse encodes res and delivers it to the client, either
// synchronously on conn or through the async writer channel, running
// the pre/post write plugin hooks around the write.
func (s *Server) sendResponse(ctx *share.Context, conn net.Conn, writeCh chan *[]byte, err error, req, res *protocol.Message) {
	// Payloads larger than 1KB are compressed with whatever compression
	// the client requested.
	if len(res.Payload) > 1024 && req.CompressType() != protocol.None {
		res.SetCompressType(req.CompressType())
	}

	// Encode the response according to the wire protocol.
	payload := res.EncodeSlicePointer()

	s.Plugins.DoPreWriteResponse(ctx, req, res, err)

	switch {
	case s.AsyncWrite:
		// Async mode: hand the encoded bytes to the writer goroutine.
		writeCh <- payload
	default:
		if s.writeTimeout != 0 {
			conn.SetWriteDeadline(time.Now().Add(s.writeTimeout))
		}
		conn.Write(*payload)
		// Return the byte slice to the pool for reuse.
		protocol.PutData(payload)
	}

	// Run the post-write plugin hooks.
	s.Plugins.DoPostWriteResponse(ctx, req, res, err)
}