当前位置: 首页 > 工具软件 > rpcx > 使用案例 >

rpcx源码解析之服务端

百里金林
2023-12-01

rpcx 是一款用 Go 语言编写的非常优秀的 RPC 框架，项目地址为 https://github.com/smallnest/rpcx 。

下面主要分析一下server端的源码

1.1 初始化server

// Server is an rpcx RPC server. Serve creates its listener and accepts
// connections (plain TCP-style listeners go through serveListener, while
// "http" and "ws"/"wss" networks are served separately); each accepted
// connection is handled by serveConn.
type Server struct {
	ln                 net.Listener  // listener currently being served; set under mu by serveListener
	readTimeout        time.Duration // per-read deadline applied to connections; 0 disables it
	writeTimeout       time.Duration // per-write deadline applied to connections; 0 disables it
	gatewayHTTPServer  *http.Server  // HTTP gateway server
	jsonrpcHTTPServer  *http.Server  // JSON-RPC over HTTP server
	DisableHTTPGateway bool // should disable http invoke or not.
	DisableJSONRPC     bool // should disable json rpc or not.
	AsyncWrite         bool // set true if your server only serves few clients

	serviceMapMu sync.RWMutex          // guards serviceMap for concurrent access
	serviceMap   map[string]*service  // registered services, keyed by service path (req.ServicePath)

	router map[string]Handler    // raw handlers keyed by "ServicePath.ServiceMethod"; consulted before serviceMap

	mu         sync.RWMutex           // guards ln and activeConn
	activeConn map[net.Conn]struct{}  // set of accepted connections that are being served
	doneChan   chan struct{}          // received from by the accept loop and connection goroutines while shutting down
	seq        uint64        // server-side sequence number (its use is not shown in this excerpt)

	inShutdown int32 // presumably the flag read by isShutdown — confirm
	onShutdown []func(s *Server)
	onRestart  []func(s *Server)

	// TLSConfig for creating tls tcp connection.
	tlsConfig *tls.Config
	// options holds free-form named settings, e.g. "TCPKeepAlivePeriod"
	// (a time.Duration, defaulted by NewServer). NOTE(review): the original
	// comment said "BlockCrypt for kcp.BlockCrypt" — KCP settings may also be
	// stored here; confirm.
	options map[string]interface{}

	// CORS options
	corsOptions *CORSOptions

	Plugins PluginContainer

	// AuthFunc can be used to auth.
	AuthFunc func(ctx context.Context, req *protocol.Message, token string) error

	handlerMsgNum int32    // number of request-handling goroutines currently running (updated atomically)

	HandleServiceError func(error) // if set, receives errors from handleRequest instead of them being logged
}

// NewServer builds a Server with empty service, router and connection tables,
// applies every option in the order given, and finally defaults the
// "TCPKeepAlivePeriod" option to three minutes when no option set it (or set
// it to nil).
func NewServer(options ...OptionFn) *Server {
	srv := &Server{
		Plugins:    &pluginContainer{},
		options:    make(map[string]interface{}),
		activeConn: make(map[net.Conn]struct{}),
		doneChan:   make(chan struct{}),
		serviceMap: make(map[string]*service),
		router:     make(map[string]Handler),
		// Keep this false unless you are running further optimization tests.
		AsyncWrite: false,
	}

	for i := range options {
		options[i](srv)
	}

	const keepAliveKey = "TCPKeepAlivePeriod"
	if srv.options[keepAliveKey] == nil {
		srv.options[keepAliveKey] = 3 * time.Minute
	}
	return srv
}

1.2 启动server

// Serve creates a listener for network/address and serves RPC requests on it.
//
// The network "http" is handed to serveByHTTP and "ws"/"wss" to serveByWS, after
// which Serve returns nil. Any other network first gets the HTTP gateway
// attached via startGateway, and the resulting listener is run by
// serveListener, whose error Serve returns.
func (s *Server) Serve(network, address string) (err error) {
	ln, err := s.makeListener(network, address)
	if err != nil {
		return err
	}

	switch network {
	case "http":
		s.serveByHTTP(ln, "")
		return nil
	case "ws", "wss":
		s.serveByWS(ln, "")
		return nil
	default:
		// try to start gateway
		return s.serveListener(s.startGateway(network, ln))
	}
}

如果网络类型既不是 http，也不是 ws/wss，则先尝试启动网关（startGateway），再进入 serveListener 方法：

// serveListener runs the accept loop on ln and starts one serveConn goroutine
// per accepted connection.
//
// Temporary accept errors are retried with exponential backoff (5ms doubling,
// capped at 1s). When the server is shutting down, or when cmux reports the
// listener closed, it returns ErrServerClosed; any other accept error is
// returned as-is. Accepted TCP connections get keep-alive and linger settings
// when the "TCPKeepAlivePeriod" option is set, and connections rejected by the
// PostConnAccept plugins are closed and skipped.
func (s *Server) serveListener(ln net.Listener) error {
	s.mu.Lock()
	s.ln = ln
	s.mu.Unlock()

	var backoff time.Duration
	for {
		conn, acceptErr := ln.Accept()
		if acceptErr != nil {
			// Accept failing during shutdown is expected: wait for the shutdown
			// signal, then report the server as closed.
			if s.isShutdown() {
				<-s.doneChan
				return ErrServerClosed
			}

			ne, isNetErr := acceptErr.(net.Error)
			if isNetErr && ne.Temporary() {
				switch {
				case backoff == 0:
					backoff = 5 * time.Millisecond
				default:
					backoff *= 2
				}
				if limit := time.Second; backoff > limit {
					backoff = limit
				}

				log.Errorf("rpcx: Accept error: %v; retrying in %v", acceptErr, backoff)
				time.Sleep(backoff)
				continue
			}

			if errors.Is(acceptErr, cmux.ErrListenerClosed) {
				return ErrServerClosed
			}
			return acceptErr
		}
		backoff = 0

		if tcpConn, isTCP := conn.(*net.TCPConn); isTCP {
			if period := s.options["TCPKeepAlivePeriod"]; period != nil {
				tcpConn.SetKeepAlive(true)
				tcpConn.SetKeepAlivePeriod(period.(time.Duration))
				tcpConn.SetLinger(10)
			}
		}

		conn, accepted := s.Plugins.DoPostConnAccept(conn)
		if !accepted {
			conn.Close()
			continue
		}

		// Track the connection as active.
		s.mu.Lock()
		s.activeConn[conn] = struct{}{}
		s.mu.Unlock()

		if share.Trace {
			log.Debugf("server accepted an conn: %v", conn.RemoteAddr().String())
		}

		go s.serveConn(conn)
	}
}

serveListener 每接受一个连接，就开启一个协程（serveConn）来处理该连接上的所有请求：

// serveConn serves every request that arrives on conn until the client goes
// away, a read fails, authentication fails, or the server shuts down. The
// connection is closed via s.closeConn on every exit path.
//
// Requests are read one after another, but each one (after authentication) is
// handled on its own goroutine, so responses are not necessarily written in
// request order. With AsyncWrite, encoded responses are sent to writeCh and
// written by serveAsyncWrite; otherwise they are written directly to conn.
func (s *Server) serveConn(conn net.Conn) {
	if s.isShutdown() {
		s.closeConn(conn)
		return
	}

	defer func() {
		// A panic in the read loop must not take down the server: log it with
		// the goroutine's stack (truncated to 64 KiB) and fall through to cleanup.
		if err := recover(); err != nil {
			const size = 64 << 10
			buf := make([]byte, size)
			ss := runtime.Stack(buf, false)
			if ss > size {
				ss = size
			}
			buf = buf[:ss]
			log.Errorf("serving %s panic error: %s, stack:\n %s", conn.RemoteAddr(), err, buf)
		}

		if share.Trace {
			log.Debugf("server closed conn: %v", conn.RemoteAddr().String())
		}

		// make sure all inflight requests are handled and all drained
		if s.isShutdown() {
			<-s.doneChan
		}

		s.closeConn(conn)
	}()

	// For TLS connections, complete the handshake up front; when read/write
	// timeouts are configured they bound how long the handshake may take.
	if tlsConn, ok := conn.(*tls.Conn); ok {
		if d := s.readTimeout; d != 0 {
			conn.SetReadDeadline(time.Now().Add(d))
		}
		if d := s.writeTimeout; d != 0 {
			conn.SetWriteDeadline(time.Now().Add(d))
		}
		if err := tlsConn.Handshake(); err != nil {
			log.Errorf("rpcx: TLS handshake error from %s: %v", conn.RemoteAddr(), err)
			return
		}
	}

	// Buffered reader for decoding requests, sized by ReaderBuffsize
	// (presumably 1 KiB by default — confirm).
	r := bufio.NewReaderSize(conn, ReaderBuffsize)

	var writeCh chan *[]byte
	if s.AsyncWrite {
		writeCh = make(chan *[]byte, 1)
		defer close(writeCh)
		go s.serveAsyncWrite(conn, writeCh)
	}

	for {
		if s.isShutdown() {
			return
		}

		t0 := time.Now()
		if s.readTimeout != 0 {
			conn.SetReadDeadline(t0.Add(s.readTimeout))
		}

		ctx := share.WithValue(context.Background(), RemoteConnContextKey, conn)
		req, err := s.readRequest(ctx, r)
		if err != nil {
			if err == io.EOF {
				log.Infof("client has closed this connection: %s", conn.RemoteAddr().String())
			} else if errors.Is(err, net.ErrClosed) {
				log.Infof("rpcx: connection %s is closed", conn.RemoteAddr().String())
			} else if errors.Is(err, ErrReqReachLimit) {
				// Request limit reached: answer with the error (unless the
				// request is oneway) and keep serving this connection.
				if !req.IsOneway() {
					// Reply with a clone of the request, marked as a response.
					res := req.Clone()
					res.SetMessageType(protocol.Response)
					handleError(res, err)
					s.sendResponse(ctx, conn, writeCh, err, req, res)
					protocol.FreeMsg(res)
				} else {
					s.Plugins.DoPreWriteResponse(ctx, req, nil, err)
				}
				protocol.FreeMsg(req)
				continue
			} else {
				log.Warnf("rpcx: failed to read request: %v", err)
			}
			// Any other read error ends the connection.
			protocol.FreeMsg(req)

			return
		}

		// Guarded like the other per-request trace logs so the whole request
		// (payload included) is not formatted on every call.
		if share.Trace {
			log.Debugf("server received an request %+v from conn: %v", req, conn.RemoteAddr().String())
		}

		ctx = share.WithLocalValue(ctx, StartRequestContextKey, time.Now().UnixNano())
		closeConn := false

		if !req.IsHeartbeat() {
			err = s.auth(ctx, req)
			closeConn = err != nil
		}

		if err != nil {
			if !req.IsOneway() {
				res := req.Clone()
				res.SetMessageType(protocol.Response)
				handleError(res, err)
				s.sendResponse(ctx, conn, writeCh, err, req, res)
				protocol.FreeMsg(res)
			} else {
				s.Plugins.DoPreWriteResponse(ctx, req, nil, err)
			}
			protocol.FreeMsg(req)
			// auth failed, closed the connection
			if closeConn {
				log.Infof("auth failed for conn %s: %v", conn.RemoteAddr().String(), err)
				return
			}
			continue
		}

		// The request passed all checks: handle it on its own goroutine so the
		// loop can go back to reading the next request immediately.
		go func() {
			defer func() {
				if r := recover(); r != nil {
					// maybe panic because the writeCh is closed.
					log.Errorf("failed to handle request: %v", r)
				}
			}()
			// Count in-flight handlers for the lifetime of this goroutine.
			atomic.AddInt32(&s.handlerMsgNum, 1)
			defer atomic.AddInt32(&s.handlerMsgNum, -1)

			// Heartbeats are answered by echoing the request back as a response.
			if req.IsHeartbeat() {
				s.Plugins.DoHeartbeatRequest(ctx, req)
				req.SetMessageType(protocol.Response)
				data := req.EncodeSlicePointer()
				if s.AsyncWrite {
					writeCh <- data
				} else {
					if s.writeTimeout != 0 {
						conn.SetWriteDeadline(time.Now().Add(s.writeTimeout))
					}
					conn.Write(*data)
					protocol.PutData(data)
				}
				protocol.FreeMsg(req)
				return
			}

			// Expose the request metadata and a response-metadata map through
			// ctx; entries placed in resMetadata are merged into the reply below.
			resMetadata := make(map[string]string)
			if req.Metadata == nil {
				req.Metadata = make(map[string]string)
			}
			ctx = share.WithLocalValue(share.WithLocalValue(ctx, share.ReqMetaDataKey, req.Metadata),
				share.ResMetaDataKey, resMetadata)

			cancelFunc := parseServerTimeout(ctx, req)
			if cancelFunc != nil {
				defer cancelFunc()
			}

			s.Plugins.DoPreHandleRequest(ctx, req)

			if share.Trace {
				log.Debugf("server handle request %+v from conn: %v", req, conn.RemoteAddr().String())
			}

			// first use handler: a raw handler registered in s.router wins over
			// reflection-based services. It receives conn and writeCh via
			// NewContext; no response is sent from this function in that case.
			if handler, ok := s.router[req.ServicePath+"."+req.ServiceMethod]; ok {
				sctx := NewContext(ctx, conn, req, writeCh)
				err := handler(sctx)
				if err != nil {
					log.Errorf("[handler internal error]: servicepath: %s, servicemethod: %s, err: %v", req.ServicePath, req.ServiceMethod, err)
				}

				protocol.FreeMsg(req)
				return
			}

			res, err := s.handleRequest(ctx, req)
			if err != nil {
				if s.HandleServiceError != nil {
					s.HandleServiceError(err)
				} else {
					log.Warnf("rpcx: failed to handle request: %v", err)
				}
			}

			if !req.IsOneway() {
				// Copy metadata set through ctx into the response, without
				// overriding keys that already have a non-empty value.
				if len(resMetadata) > 0 {
					meta := res.Metadata
					if meta == nil {
						res.Metadata = resMetadata
					} else {
						for k, v := range resMetadata {
							if meta[k] == "" {
								meta[k] = v
							}
						}
					}
				}
				s.sendResponse(ctx, conn, writeCh, err, req, res)
			}

			if share.Trace {
				log.Debugf("server write response %+v for an request %+v from conn: %v", res, req, conn.RemoteAddr().String())
			}
			// Return both messages to the pool for reuse.
			protocol.FreeMsg(req)
			protocol.FreeMsg(res)
		}()
	}
}

serveConn 中首先调用 s.readRequest 读取请求信息：

// readRequest decodes the next request message from r.
//
// The PreReadRequest hook runs first; if it fails, no message is taken and
// (nil, hookErr) is returned. Otherwise a message is taken from the protocol
// pool and decoded. io.EOF is returned straight away without running the
// PostReadRequest hook. In every other case that hook runs, and its error is
// reported only when decoding itself succeeded. The message is returned even
// alongside an error so the caller can inspect and free it.
func (s *Server) readRequest(ctx context.Context, r io.Reader) (req *protocol.Message, err error) {
	if hookErr := s.Plugins.DoPreReadRequest(ctx); hookErr != nil {
		return nil, hookErr
	}

	msg := protocol.GetPooledMsg()
	decodeErr := msg.Decode(r)
	if decodeErr == io.EOF {
		return msg, decodeErr
	}

	postErr := s.Plugins.DoPostReadRequest(ctx, msg, decodeErr)
	if decodeErr != nil {
		return msg, decodeErr
	}
	return msg, postErr
}

// msgPool recycles *Message values. A freshly allocated message gets a zeroed
// header whose first byte is pre-set to magicNumber.
var msgPool = sync.Pool{
	New: func() interface{} {
		var h Header
		h[0] = magicNumber
		return &Message{Header: &h}
	},
}

// Message is the generic type of Request and Response.
// The same layout is used in both directions; the header's message type
// distinguishes a request from a response.
type Message struct {
	*Header                         // fixed-size header (a [12]byte); byte 0 holds the magic number
	ServicePath   string            // name of the target service
	ServiceMethod string            // name of the method on that service
	Metadata      map[string]string // key/value metadata carried with the message
	Payload       []byte            // body; after Decode it is a sub-slice of data unless it was decompressed
	data          []byte            // buffer holding the raw message body read by Decode; reused when large enough
}
// GetPooledMsg takes a *Message from msgPool. When the pool is empty a new
// message is allocated whose header already carries the magic number.
func GetPooledMsg() *Message {
	msg := msgPool.Get().(*Message)
	return msg
}

其中比较经典的是 Decode 解码消息这一段：

关于数据协议的详细描述，参见 rpcx 官方文档《Go RPC编程指南》中的“协议详解”一节。

// Decode reads one complete message from r into m.
//
// Layout consumed, in order:
//   - the fixed header (m.Header); byte 0 must be the magic number
//   - a 4-byte big-endian length of the body that follows
//   - the body: four sections, each a 4-byte big-endian length followed by
//     that many bytes — service path, service method, metadata, payload
//
// The body buffer (m.data) is reused across calls when its capacity suffices,
// and Payload is a sub-slice of it unless it had to be decompressed.
//
// The body comes off the wire, so every section length is checked against the
// bytes actually received. A truncated or inconsistent body yields an error
// instead of a slice-bounds panic.
func (m *Message) Decode(r io.Reader) error {
	// parse header: check the magic byte before reading the rest so that
	// garbage input is rejected after a single byte
	if _, err := io.ReadFull(r, m.Header[:1]); err != nil {
		return err
	}
	if !m.Header.CheckMagicNumber() {
		return fmt.Errorf("wrong magic number: %v", m.Header[0])
	}
	if _, err := io.ReadFull(r, m.Header[1:]); err != nil {
		return err
	}

	// total body length
	lenData := poolUint32Data.Get().(*[]byte)
	if _, err := io.ReadFull(r, *lenData); err != nil {
		poolUint32Data.Put(lenData)
		return err
	}
	l := binary.BigEndian.Uint32(*lenData)
	poolUint32Data.Put(lenData)

	// MaxMessageLength <= 0 means unlimited. Compare in uint64 so a huge length
	// cannot wrap to a negative int on 32-bit platforms and slip past the limit.
	if MaxMessageLength > 0 && uint64(l) > uint64(MaxMessageLength) {
		return ErrMessageTooLong
	}
	totalL := int(l)
	if totalL < 0 { // length does not fit in int (32-bit platforms only)
		return ErrMessageTooLong
	}

	if cap(m.data) >= totalL { // reuse data
		m.data = m.data[:totalL]
	} else {
		m.data = make([]byte, totalL)
	}
	data := m.data
	if _, err := io.ReadFull(r, data); err != nil {
		return err
	}

	malformed := func(section string) error {
		return fmt.Errorf("rpcx: malformed message: %s section exceeds body length %d", section, totalL)
	}

	// parse servicePath
	field, n, ok := splitLengthPrefixed(data, 0)
	if !ok {
		return malformed("service path")
	}
	m.ServicePath = util.SliceByteToString(field)

	// parse serviceMethod
	if field, n, ok = splitLengthPrefixed(data, n); !ok {
		return malformed("service method")
	}
	m.ServiceMethod = util.SliceByteToString(field)

	// parse meta
	if field, n, ok = splitLengthPrefixed(data, n); !ok {
		return malformed("metadata")
	}
	if len(field) > 0 {
		var err error
		if m.Metadata, err = decodeMetadata(uint32(len(field)), field); err != nil {
			return err
		}
	}

	// parse payload: its 4-byte length prefix must be present, but (as before)
	// the payload is taken to be the rest of the body rather than trusting it.
	if len(data)-n < 4 {
		return malformed("payload")
	}
	m.Payload = data[n+4:]

	if ct := m.CompressType(); ct != None {
		compressor := Compressors[ct]
		if compressor == nil {
			return ErrUnsupportedCompressor
		}
		var err error
		if m.Payload, err = compressor.Unzip(m.Payload); err != nil {
			return err
		}
	}

	return nil
}

// splitLengthPrefixed reads the 4-byte big-endian length stored at data[n:]
// and returns the field of that length that follows it, plus the offset just
// past the field. ok is false when the prefix or the field would run past
// len(data).
func splitLengthPrefixed(data []byte, n int) (field []byte, next int, ok bool) {
	if n < 0 || len(data)-n < 4 {
		return nil, n, false
	}
	l := binary.BigEndian.Uint32(data[n : n+4])
	n += 4
	if uint64(l) > uint64(len(data)-n) {
		return nil, n, false
	}
	next = n + int(l)
	return data[n:next], next, true
}

读取到请求信息后，根据请求中的服务名和方法名调用对应的方法，具体的函数是 handleRequest：

// handleRequest dispatches req to the registered service method named by
// req.ServicePath / req.ServiceMethod and builds the response message.
//
// The response starts as a clone of req with its message type set to
// Response; failures are attached to it through handleError. Names found in
// service.function rather than service.method are delegated to
// handleRequestForFunction. Argument and reply values are borrowed from
// reflectTypePools and put back after use. The method is invoked via
// reflection (service.call), wrapped by the PreCall/PostCall plugin hooks.
func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
	serviceName := req.ServicePath
	methodName := req.ServiceMethod

	res = req.Clone()
	res.SetMessageType(protocol.Response)

	s.serviceMapMu.RLock()
	service := s.serviceMap[serviceName]
	if share.Trace {
		log.Debugf("server get service %+v for an request %+v", service, req)
	}
	s.serviceMapMu.RUnlock()

	if service == nil {
		err = errors.New("rpcx: can't find service " + serviceName)
		return handleError(res, err)
	}
	mtype := service.method[methodName]
	if mtype == nil {
		if service.function[methodName] != nil { // check raw functions
			return s.handleRequestForFunction(ctx, req)
		}
		err = errors.New("rpcx: can't find method " + methodName)
		return handleError(res, err)
	}

	// Resolve the codec before borrowing from the pool, so an unknown
	// serialize type does not take an argv that is never put back.
	codec := share.Codecs[req.SerializeType()]
	if codec == nil {
		err = fmt.Errorf("can not find codec for %d", req.SerializeType())
		return handleError(res, err)
	}

	// get a argv object from object pool and decode the payload into it
	argv := reflectTypePools.Get(mtype.ArgType)
	if err = codec.Decode(req.Payload, argv); err != nil {
		return handleError(res, err)
	}

	// and get a reply object from object pool
	replyv := reflectTypePools.Get(mtype.ReplyType)

	argv, err = s.Plugins.DoPreCall(ctx, serviceName, methodName, argv)
	if err != nil {
		// return reply to object pool
		reflectTypePools.Put(mtype.ReplyType, replyv)
		return handleError(res, err)
	}

	// Invoke the service method via reflection. When the method takes its
	// argument by value, argv (presumably a pointer from the pool) is
	// dereferenced with Elem().
	if mtype.ArgType.Kind() != reflect.Ptr {
		err = service.call(ctx, mtype, reflect.ValueOf(argv).Elem(), reflect.ValueOf(replyv))
	} else {
		err = service.call(ctx, mtype, reflect.ValueOf(argv), reflect.ValueOf(replyv))
	}

	if err == nil {
		replyv, err = s.Plugins.DoPostCall(ctx, serviceName, methodName, argv, replyv)
	}

	// return argv to object pool
	reflectTypePools.Put(mtype.ArgType, argv)

	if err != nil {
		// Even when the call failed, a non-nil reply is encoded into the
		// payload before the call error is recorded on the response.
		if replyv != nil {
			data, encErr := codec.Encode(replyv)
			// return reply to object pool
			reflectTypePools.Put(mtype.ReplyType, replyv)
			if encErr != nil {
				return handleError(res, encErr)
			}
			res.Payload = data
		}
		return handleError(res, err)
	}

	if !req.IsOneway() {
		data, encErr := codec.Encode(replyv)
		// return reply to object pool
		reflectTypePools.Put(mtype.ReplyType, replyv)
		if encErr != nil {
			return handleError(res, encErr)
		}
		res.Payload = data
	} else if replyv != nil {
		reflectTypePools.Put(mtype.ReplyType, replyv)
	}

	if share.Trace {
		log.Debugf("server called service %+v for an request %+v", service, req)
	}

	return res, nil
}

拿到返回信息后，调用 sendResponse 将响应写回客户端，整个调用过程就结束了：

// sendResponse encodes res and delivers it on conn. The PreWriteResponse hook
// runs after encoding but before delivery, and the PostWriteResponse hook runs
// after delivery; both receive err.
//
// A payload larger than 1024 bytes adopts the request's compress type, if the
// request used one. With AsyncWrite the encoded bytes are sent to writeCh.
// Otherwise they are written to conn directly, under the write timeout when one
// is configured, and the buffer is handed back with protocol.PutData.
func (s *Server) sendResponse(ctx *share.Context, conn net.Conn, writeCh chan *[]byte, err error, req, res *protocol.Message) {
	const compressThreshold = 1024
	if len(res.Payload) > compressThreshold {
		if ct := req.CompressType(); ct != protocol.None {
			res.SetCompressType(ct)
		}
	}

	payload := res.EncodeSlicePointer()
	s.Plugins.DoPreWriteResponse(ctx, req, res, err)

	if !s.AsyncWrite {
		if timeout := s.writeTimeout; timeout != 0 {
			conn.SetWriteDeadline(time.Now().Add(timeout))
		}
		conn.Write(*payload)
		protocol.PutData(payload)
	} else {
		writeCh <- payload
	}

	s.Plugins.DoPostWriteResponse(ctx, req, res, err)
}

 类似资料: