当前位置: 首页 > 工具软件 > rpcx > 使用案例 >

rpcx源码解析之client端

夏青青
2023-12-01

之前分析过rpcx的服务端实现,这篇主要分析一下rpcx的client端实现

先通过一个basic的例子看一下rpcx客户端的使用方式:

// main connects to the single server named by the addr flag (declared
// elsewhere) through peer-to-peer discovery and calls Arith.Mul once per
// second, exiting on the first error.
func main() {
	flag.Parse()

	d, err := client.NewPeer2PeerDiscovery("tcp@"+*addr, "")
	if err != nil {
		log.Fatalf("failed to create discovery: %v", err)
	}
	xclient := client.NewXClient("Arith", client.Failtry, client.RandomSelect, d, client.DefaultOption)
	defer xclient.Close()

	args := &example.Args{
		A: 10,
		B: 20,
	}

	for {
		reply := &example.Reply{}
		if err := xclient.Call(context.Background(), "Mul", args, reply); err != nil {
			// log.Fatalf exits the process, so the deferred Close is skipped here.
			log.Fatalf("failed to call: %v", err)
		}

		log.Printf("%d * %d = %d", args.A, args.B, reply.C)
		time.Sleep(time.Second)
	}
}

其中主要涉及的方法是NewXClient以及XClient的Call方法。

1.1 NewXClient

// NewXClient creates a XClient that supports service discovery and service governance.
//
// It snapshots the services currently reported by discovery into
// client.servers (after passing them through filterByStateAndGroup with
// option.Group). It builds a selector for selectMode unless selectMode is
// Closest or SelectByUser. If discovery.WatchService returns a non-nil
// channel, it starts a goroutine running client.watch on that channel.
func NewXClient(servicePath string, failMode FailMode, selectMode SelectMode, discovery ServiceDiscovery, option Option) XClient {
	client := &xClient{
		failMode:     failMode,
		selectMode:   selectMode,
		discovery:    discovery,
		servicePath:  servicePath,
		cachedClient: make(map[string]RPCClient),
		option:       option,
	}

	// Fetch the list of services currently known to the discovery/registry.
	pairs := discovery.GetServices()
	// sort.Slice requires a strict "less" function. The previous "<= 0" also
	// reported equal keys as less-than each other, violating that contract.
	sort.Slice(pairs, func(i, j int) bool {
		return strings.Compare(pairs[i].Key, pairs[j].Key) < 0
	})
	servers := make(map[string]string, len(pairs))
	for _, p := range pairs {
		servers[p.Key] = p.Value
	}
	filterByStateAndGroup(client.option.Group, servers)

	client.servers = servers
	if selectMode != Closest && selectMode != SelectByUser {
		// Client-side load-balancing strategy.
		client.selector = newSelector(selectMode, servers)
	}

	client.Plugins = &pluginContainer{}

	// A non-nil channel means the discovery can report service changes.
	ch := client.discovery.WatchService()
	if ch != nil {
		client.ch = ch
		// Follow service-list changes in the background.
		go client.watch(ch)
	}

	return client
}

XClient是一个接口,提供了丰富的方法可供调用。一个XClient只能调用一个服务,如果需要调用多个服务,则需要创建多个XClient。
// XClient is the client-side handle bound to a single service.
//
// One XClient is used only for one service. You should create multiple
// XClient for multiple services.
type XClient interface {
	// Configuration.
	SetPlugins(plugins PluginContainer)
	GetPlugins() PluginContainer
	SetSelector(s Selector)
	ConfigGeoSelector(latitude, longitude float64)
	Auth(auth string)

	// Invocation.
	Call(ctx context.Context, serviceMethod string, args interface{}, reply interface{}) error
	Go(ctx context.Context, serviceMethod string, args interface{}, reply interface{}, done chan *Call) (*Call, error)
	Broadcast(ctx context.Context, serviceMethod string, args interface{}, reply interface{}) error
	Fork(ctx context.Context, serviceMethod string, args interface{}, reply interface{}) error
	Inform(ctx context.Context, serviceMethod string, args interface{}, reply interface{}) ([]Receipt, error)
	SendRaw(ctx context.Context, r *protocol.Message) (map[string]string, []byte, error)

	// Files and streams.
	SendFile(ctx context.Context, fileName string, rateInBytesPerSecond int64, meta map[string]string) error
	DownloadFile(ctx context.Context, requestFileName string, saveTo io.Writer, meta map[string]string) error
	Stream(ctx context.Context, meta map[string]string) (net.Conn, error)

	// Lifecycle.
	Close() error
}

1.2 调用服务端的函数:Call

// Call invokes the named function, waits for it to complete, and returns its error status.
// It handles errors based on FailMode:
//
//   - Failtry: up to option.Retries+1 attempts, re-fetching the cached client
//     for the originally selected server key between attempts.
//   - Failover: up to option.Retries+1 attempts, selecting another server
//     between attempts.
//   - Failbackup: issues the call; if it has not completed within
//     option.BackupLatency, issues a second call. Whichever completes first
//     determines the result.
//   - Failfast (default): a single attempt.
//
// In Failtry and Failover, a ServiceError or a canceled context stops the
// retries immediately. If c.auth is set, it is written into the request
// metadata map stored in ctx under share.ReqMetaDataKey (creating the map if
// absent).
func (c *xClient) Call(ctx context.Context, serviceMethod string, args interface{}, reply interface{}) error {
	if c.isShutdown {
		return ErrXClientShutdown
	}

	if c.auth != "" {
		metadata := ctx.Value(share.ReqMetaDataKey)
		if metadata == nil {
			metadata = map[string]string{}
			ctx = context.WithValue(ctx, share.ReqMetaDataKey, metadata)
		}
		m := metadata.(map[string]string)
		m[share.AuthKey] = c.auth
	}
	ctx = setServerTimeout(ctx)

	if share.Trace {
		log.Debugf("select a client for %s.%s, failMode: %v, args: %+v in case of xclient Call", c.servicePath, serviceMethod, c.failMode, args)
	}

	var err error
	// Pick a client according to the configured selector (load balancing).
	k, client, err := c.selectClient(ctx, c.servicePath, serviceMethod, args)
	if err != nil {
		if c.failMode == Failfast || contextCanceled(err) {
			return err
		}
	}

	if share.Trace {
		if client != nil {
			log.Debugf("selected a client %s for %s.%s, failMode: %v, args: %+v in case of xclient Call", client.RemoteAddr(), c.servicePath, serviceMethod, c.failMode, args)
		} else {
			log.Debugf("selected a client %s for %s.%s, failMode: %v, args: %+v in case of xclient Call", "nil", c.servicePath, serviceMethod, c.failMode, args)
		}
	}

	var e error
	// Retry behavior after a failure depends on the configured fail mode.
	switch c.failMode {
	case Failtry:
		retries := c.option.Retries
		for retries >= 0 {
			retries--

			if client != nil {
				err = c.wrapCall(ctx, client, serviceMethod, args, reply)
				// On the happy path this mode returns right after this call.
				if err == nil {
					return nil
				}
				if contextCanceled(err) {
					return err
				}
				if _, ok := err.(ServiceError); ok {
					return err
				}
			}

			if uncoverError(err) {
				c.removeClient(k, c.servicePath, serviceMethod, client)
			}
			client, e = c.getCachedClient(k, c.servicePath, serviceMethod, args)
		}
		if err == nil {
			err = e
		}
		return err
	case Failover:
		retries := c.option.Retries
		for retries >= 0 {
			retries--

			if client != nil {
				err = c.wrapCall(ctx, client, serviceMethod, args, reply)
				if err == nil {
					return nil
				}
				if contextCanceled(err) {
					return err
				}
				if _, ok := err.(ServiceError); ok {
					return err
				}
			}

			if uncoverError(err) {
				c.removeClient(k, c.servicePath, serviceMethod, client)
			}
			// select another server
			k, client, e = c.selectClient(ctx, c.servicePath, serviceMethod, args)
		}

		if err == nil {
			err = e
		}
		return err
	case Failbackup:
		ctx, cancelFn := context.WithCancel(ctx)
		defer cancelFn()
		call1 := make(chan *Call, 10)
		call2 := make(chan *Call, 10)

		var reply1, reply2 interface{}

		// Each attempt decodes into its own fresh reply value; the winner is
		// copied into reply afterwards.
		if reply != nil {
			reply1 = reflect.New(reflect.ValueOf(reply).Elem().Type()).Interface()
			reply2 = reflect.New(reflect.ValueOf(reply).Elem().Type()).Interface()
		}
		// First attempt, issued asynchronously.
		_, err1 := c.Go(ctx, serviceMethod, args, reply1, call1)

		t := time.NewTimer(c.option.BackupLatency)
		// Release the timer on every exit path, including the early returns below.
		defer t.Stop()
		select {
		case <-ctx.Done(): // cancel by context
			err = ctx.Err()
			return err
		case call := <-call1:
			err = call.Error
			if err == nil && reply != nil {
				reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(reply1).Elem())
			}
			return err
		case <-t.C:
			// Backup latency elapsed without a result: fire the backup request.
		}
		_, err2 := c.Go(ctx, serviceMethod, args, reply2, call2)
		if err2 != nil {
			if uncoverError(err2) {
				c.removeClient(k, c.servicePath, serviceMethod, client)
			}
			err = err1
			return err
		}

		select {
		case <-ctx.Done(): // cancel by context
			err = ctx.Err()
		case call := <-call1:
			err = call.Error
			if err == nil && reply != nil && reply1 != nil {
				reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(reply1).Elem())
			}
		case call := <-call2:
			err = call.Error
			if err == nil && reply != nil && reply2 != nil {
				reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(reply2).Elem())
			}
		}

		return err
	default: // Failfast
		err = c.wrapCall(ctx, client, serviceMethod, args, reply)
		if err != nil {
			if uncoverError(err) {
				c.removeClient(k, c.servicePath, serviceMethod, client)
			}
		}

		return err
	}
}

接下来看一下c.wrapCall这个方法,主要的调用逻辑封装

// wrapCall performs one RPC on an already-selected client, bracketed by the
// plugin container's DoPreCall and DoPostCall hooks.
//
// It returns ErrServerUnavailable when client is nil. Otherwise it returns
// the error from client.Call, which is also handed to DoPostCall.
func (c *xClient) wrapCall(ctx context.Context, client RPCClient, serviceMethod string, args interface{}, reply interface{}) error {
	// No client means there is no server to talk to.
	if client == nil {
		return ErrServerUnavailable
	}

	if share.Trace {
		log.Debugf("call a client for %s.%s, args: %+v in case of xclient wrapCall", c.servicePath, serviceMethod, args)
	}

	// Make sure ctx is a *share.Context before it reaches plugins and the client.
	_, alreadyShared := ctx.(*share.Context)
	if !alreadyShared {
		ctx = share.NewContext(ctx)
	}

	c.Plugins.DoPreCall(ctx, c.servicePath, serviceMethod, args)
	callErr := client.Call(ctx, c.servicePath, serviceMethod, args, reply)
	c.Plugins.DoPostCall(ctx, c.servicePath, serviceMethod, args, reply, callErr)

	if share.Trace {
		log.Debugf("called a client for %s.%s, args: %+v, err: %v in case of xclient wrapCall", c.servicePath, serviceMethod, args, callErr)
	}

	return callErr
}

里面的client.Call就是核心方法了

// Call is the synchronous entry point: it forwards every argument unchanged
// to the unexported call method and returns the error that call reports.
func (client *Client) Call(ctx context.Context, servicePath, serviceMethod string, args interface{}, reply interface{}) error {
	err := client.call(ctx, servicePath, serviceMethod, args, reply)
	return err
}

// call sends the request via Go and blocks until either the call completes or
// ctx is done.
//
// A fresh *uint64 is stored in ctx under seqKey{} so that the sequence
// number assigned while sending can be read back here. On cancellation, that
// number is used to remove the call from client.pending and complete it with
// ctx.Err().
//
// On completion, if ctx holds a map[string]string under share.ResMetaDataKey
// and the call carries response metadata, that metadata plus the remote
// address (under share.ServerAddress) is merged into the map. The merge
// happens while holding the *sync.Mutex stored at share.ContextTagsLock, when
// one is present.
func (client *Client) call(ctx context.Context, servicePath, serviceMethod string, args interface{}, reply interface{}) error {
	seq := new(uint64)
	ctx = context.WithValue(ctx, seqKey{}, seq)

	if share.Trace {
		log.Debugf("client.call for %s.%s, args: %+v in case of client call", servicePath, serviceMethod, args)
		defer func() {
			log.Debugf("client.call done for %s.%s, args: %+v in case of client call", servicePath, serviceMethod, args)
		}()
	}

	// Issue the request asynchronously; the completed call is delivered on done.
	done := client.Go(ctx, servicePath, serviceMethod, args, reply, make(chan *Call, 1)).Done

	select {
	case <-ctx.Done(): // cancel by context
		client.mutex.Lock()
		call := client.pending[*seq]
		delete(client.pending, *seq)
		client.mutex.Unlock()
		if call != nil {
			call.Error = ctx.Err()
			call.done()
		}

		return ctx.Err()
	case call := <-done:
		// comma-ok: a foreign value under ResMetaDataKey is ignored instead of panicking.
		if resMeta, ok := ctx.Value(share.ResMetaDataKey).(map[string]string); ok && len(call.ResMetadata) > 0 {
			locker, locked := ctx.Value(share.ContextTagsLock).(*sync.Mutex)
			if locked {
				locker.Lock()
			}
			for k, v := range call.ResMetadata {
				resMeta[k] = v
			}
			resMeta[share.ServerAddress] = client.Conn.RemoteAddr().String()
			if locked {
				locker.Unlock()
			}
		}

		return call.Error
	}
}

在call中用到了client.Go,这是一个异步调用方法,调用完成后结果会写入Done channel中。

// Go starts an asynchronous call of servicePath.serviceMethod and returns the
// *Call describing it; the request is handed to client.send before returning.
//
// The request metadata is the map stored in ctx under share.ReqMetaDataKey.
// It is shared with ctx, not copied. If done is nil, a new channel with
// buffer 10 is created. A non-nil done channel must be buffered, otherwise Go
// panics.
func (client *Client) Go(ctx context.Context, servicePath, serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call {
	call := &Call{
		ServicePath:   servicePath,
		ServiceMethod: serviceMethod,
	}
	if meta := ctx.Value(share.ReqMetaDataKey); meta != nil {
		call.Metadata = meta.(map[string]string)
	}

	if !share.IsShareContext(ctx) {
		ctx = share.NewContext(ctx)
	}

	call.Args, call.Reply = args, reply

	switch {
	case done == nil:
		done = make(chan *Call, 10) // buffered.
	case cap(done) == 0:
		// A caller-supplied channel must be buffered enough for every RPC
		// that shares it; a fully unbuffered one is rejected outright.
		log.Panic("rpc: done channel is unbuffered")
	}
	call.Done = done

	if share.Trace {
		log.Debugf("client.Go send request for %s.%s, args: %+v in case of client call", servicePath, serviceMethod, args)
	}

	client.send(ctx, call)
	return call
}

最后是真正把请求发送出去的send方法(总算到了):

// send registers call in client.pending under a new sequence number, encodes
// it into a pooled protocol.Message, and writes it to client.Conn.
//
// Failures are reported by setting call.Error and completing the call with
// call.done(). This covers a closing or shut-down client, an unregistered
// codec, an argument-encoding error, and a write error. One-way requests
// (call.Reply == nil) are completed right after a successful write. If ctx
// carries a *uint64 under seqKey{}, the assigned sequence number is stored
// there.
func (client *Client) send(ctx context.Context, call *Call) {
	// Register this call.
	client.mutex.Lock()
	if client.shutdown || client.closing {
		call.Error = ErrShutdown
		client.mutex.Unlock()
		call.done()
		return
	}

	// A heartbeat has neither service path nor method and is always MsgPack-encoded.
	isHeartbeat := call.ServicePath == "" && call.ServiceMethod == ""
	serializeType := client.option.SerializeType
	if isHeartbeat {
		serializeType = protocol.MsgPack
	}
	codec := share.Codecs[serializeType]
	if codec == nil {
		call.Error = ErrUnsupportedCodec
		client.mutex.Unlock()
		call.done()
		return
	}

	if client.pending == nil {
		client.pending = make(map[uint64]*Call)
	}

	seq := client.seq
	client.seq++
	client.pending[seq] = call
	client.mutex.Unlock()

	// Report the sequence number back to the caller (Client.call uses it on cancellation).
	if cseq, ok := ctx.Value(seqKey{}).(*uint64); ok {
		*cseq = seq
	}

	req := protocol.GetPooledMsg()
	req.SetMessageType(protocol.Request)
	req.SetSeq(seq)
	if call.Reply == nil {
		req.SetOneway(true)
	}

	if isHeartbeat {
		req.SetHeartbeat(true)
	}
	// serializeType already falls back to MsgPack for heartbeats.
	req.SetSerializeType(serializeType)

	if call.Metadata != nil {
		req.Metadata = call.Metadata
	}

	req.ServicePath = call.ServicePath
	req.ServiceMethod = call.ServiceMethod

	// Encode only the arguments, using the selected codec.
	data, err := codec.Encode(call.Args)
	if err != nil {
		client.mutex.Lock()
		delete(client.pending, seq)
		client.mutex.Unlock()
		call.Error = err
		call.done()
		// Return the pooled message, as the other exit paths below do.
		protocol.FreeMsg(req)
		return
	}

	// Compress only payloads over 1024 bytes, and only when a compress type is configured.
	if len(data) > 1024 && client.option.CompressType != protocol.None {
		req.SetCompressType(client.option.CompressType)
	}

	req.Payload = data

	if client.Plugins != nil {
		// Best effort: a plugin error does not abort the send.
		_ = client.Plugins.DoClientBeforeEncode(req)
	}

	if share.Trace {
		log.Debugf("client.send for %s.%s, args: %+v in case of client call", call.ServicePath, call.ServiceMethod, call.Args)
	}

	// Serialize the whole message into the rpcx wire format. This is separate
	// from codec.Encode above, which only encoded the arguments.
	allData := req.EncodeSlicePointer()
	_, err = client.Conn.Write(*allData)
	// The encoded bytes are no longer needed; return the buffer to its pool.
	protocol.PutData(allData)
	if share.Trace {
		log.Debugf("client.sent for %s.%s, args: %+v in case of client call", call.ServicePath, call.ServiceMethod, call.Args)
	}

	if err != nil {
		// Write failed: if the call is still pending, remove and fail it; recycle req.
		client.mutex.Lock()
		call = client.pending[seq]
		delete(client.pending, seq)
		client.mutex.Unlock()
		if call != nil {
			call.Error = err
			call.done()
		}
		protocol.FreeMsg(req)
		return
	}

	isOneway := req.IsOneway()
	protocol.FreeMsg(req)

	if isOneway {
		// No reply is expected: if the call is still pending, remove and complete it now.
		client.mutex.Lock()
		call = client.pending[seq]
		delete(client.pending, seq)
		client.mutex.Unlock()
		if call != nil {
			call.done()
		}
	}

	if client.option.IdleTimeout != 0 {
		_ = client.Conn.SetDeadline(time.Now().Add(client.option.IdleTimeout))
	}
}

这里插入介绍一下EncodeSlicePointer方法,它按照rpcx底层的数据传输协议把整条消息编码成字节数组:

// EncodeSlicePointer encodes messages as a byte slice pointer we can use pool to improve.
//
// Layout written (every length is a big-endian uint32):
//
//	header(12) | totalLen(4) | spLen(4) sp | smLen(4) sm | metaLen(4) meta | payloadLen(4) payload
//
// totalLen counts everything after the first 16 bytes. When a compress type
// is set, the payload is zipped with Compressors[m.CompressType()]. If no
// compressor is registered, or zipping fails, the header's compress type is
// reset to None and the raw payload is written instead. The returned buffer
// comes from bufferPool.
func (m Message) EncodeSlicePointer() *[]byte {
	// Encode the metadata into a pooled scratch buffer first.
	metaBuf := bytebufferpool.Get()
	encodeMetadata(m.Metadata, metaBuf)
	meta := metaBuf.Bytes()

	payload := m.Payload
	if ct := m.CompressType(); ct != None {
		if compressor := Compressors[ct]; compressor == nil {
			m.SetCompressType(None)
		} else if zipped, err := compressor.Zip(m.Payload); err != nil {
			m.SetCompressType(None)
		} else {
			payload = zipped
		}
	}

	totalL := (4 + len(m.ServicePath)) + (4 + len(m.ServiceMethod)) + (4 + len(meta)) + (4 + len(payload))

	data := bufferPool.Get(12 + 4 + totalL)
	buf := *data
	copy(buf, m.Header[:])
	binary.BigEndian.PutUint32(buf[12:16], uint32(totalL))

	// putField writes a length prefix followed by b at off, then advances off.
	off := 16
	putField := func(b []byte) {
		binary.BigEndian.PutUint32(buf[off:off+4], uint32(len(b)))
		copy(buf[off+4:], b)
		off += 4 + len(b)
	}

	putField(util.StringToSliceByte(m.ServicePath))
	putField(util.StringToSliceByte(m.ServiceMethod))
	putField(meta)
	// meta has been copied into buf, so the scratch buffer can be recycled.
	bytebufferpool.Put(metaBuf)
	putField(payload)

	return data
}

至此,一个不加其他插件的client端的调用流程差不多就分析完了,后续会继续分析各种负载均衡模式以及其他一些功能。

 类似资料: