Having covered the rpcx server-side implementation in a previous post, this one looks at the rpcx client side.
Client usage is straightforward; here is a basic example:
func main() {
    flag.Parse()

    d, _ := client.NewPeer2PeerDiscovery("tcp@"+*addr, "")
    xclient := client.NewXClient("Arith", client.Failtry, client.RandomSelect, d, client.DefaultOption)
    defer xclient.Close()

    args := &example.Args{
        A: 10,
        B: 20,
    }

    for {
        reply := &example.Reply{}
        err := xclient.Call(context.Background(), "Mul", args, reply)
        if err != nil {
            log.Fatalf("failed to call: %v", err)
        }

        log.Printf("%d * %d = %d", args.A, args.B, reply.C)
        time.Sleep(1e9) // 1e9 ns = 1s
    }
}
The two entry points to look at are NewXClient and the corresponding Call method.
1.1 NewXClient
// NewXClient creates a XClient that supports service discovery and service governance.
func NewXClient(servicePath string, failMode FailMode, selectMode SelectMode, discovery ServiceDiscovery, option Option) XClient {
    client := &xClient{
        failMode:     failMode,
        selectMode:   selectMode,
        discovery:    discovery,
        servicePath:  servicePath,
        cachedClient: make(map[string]RPCClient),
        option:       option,
    }

    // fetch the service list from the registry
    pairs := discovery.GetServices()
    sort.Slice(pairs, func(i, j int) bool {
        return strings.Compare(pairs[i].Key, pairs[j].Key) <= 0
    })
    servers := make(map[string]string, len(pairs))
    for _, p := range pairs {
        servers[p.Key] = p.Value
    }

    filterByStateAndGroup(client.option.Group, servers)

    client.servers = servers
    if selectMode != Closest && selectMode != SelectByUser {
        // client-side load-balancing strategy
        client.selector = newSelector(selectMode, servers)
    }

    client.Plugins = &pluginContainer{}

    // returns a channel that delivers service-list changes
    ch := client.discovery.WatchService()
    if ch != nil {
        client.ch = ch
        // watch for changes to the service list
        go client.watch(ch)
    }

    return client
}
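The SelectByUser branch above is the escape hatch for custom load balancing: NewXClient skips newSelector, and the caller supplies one via SetSelector. As a minimal sketch, assuming rpcx's Selector interface (a Select method plus UpdateServer, as used internally by xClient), a selector that always prefers the lexicographically first server:

// firstSelector is a hypothetical custom selector that always picks the
// lexicographically first server key. (imports: context, sort)
type firstSelector struct {
    servers []string
}

func (s *firstSelector) Select(ctx context.Context, servicePath, serviceMethod string, args interface{}) string {
    if len(s.servers) == 0 {
        return ""
    }
    return s.servers[0]
}

// UpdateServer is invoked when the watched service list changes.
func (s *firstSelector) UpdateServer(servers map[string]string) {
    ss := make([]string, 0, len(servers))
    for k := range servers {
        ss = append(ss, k)
    }
    sort.Strings(ss)
    s.servers = ss
}

It would be wired up with client.NewXClient("Arith", client.Failtry, client.SelectByUser, d, client.DefaultOption) followed by xclient.SetSelector(&firstSelector{}).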
XClient is an interface with a rich set of methods. One XClient can call only one service; to call multiple services you create multiple XClient instances (see the sketch after the interface below).
// One XClient is used only for one service. You should create multiple XClient for multiple services.
type XClient interface {
    SetPlugins(plugins PluginContainer)
    GetPlugins() PluginContainer
    SetSelector(s Selector)
    ConfigGeoSelector(latitude, longitude float64)
    Auth(auth string)
    Go(ctx context.Context, serviceMethod string, args interface{}, reply interface{}, done chan *Call) (*Call, error)
    Call(ctx context.Context, serviceMethod string, args interface{}, reply interface{}) error
    Broadcast(ctx context.Context, serviceMethod string, args interface{}, reply interface{}) error
    Fork(ctx context.Context, serviceMethod string, args interface{}, reply interface{}) error
    Inform(ctx context.Context, serviceMethod string, args interface{}, reply interface{}) ([]Receipt, error)
    SendRaw(ctx context.Context, r *protocol.Message) (map[string]string, []byte, error)
    SendFile(ctx context.Context, fileName string, rateInBytesPerSecond int64, meta map[string]string) error
    DownloadFile(ctx context.Context, requestFileName string, saveTo io.Writer, meta map[string]string) error
    Stream(ctx context.Context, meta map[string]string) (net.Conn, error)
    Close() error
}
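Since one XClient binds to exactly one servicePath, calling two services means creating two clients. A minimal sketch reusing the Peer2PeerDiscovery from the opening example (the "Echo" service name here is hypothetical):

// one XClient per service: "Arith" and "Echo" each get their own client and discovery
d1, _ := client.NewPeer2PeerDiscovery("tcp@"+*addr, "")
arith := client.NewXClient("Arith", client.Failtry, client.RandomSelect, d1, client.DefaultOption)
defer arith.Close()

d2, _ := client.NewPeer2PeerDiscovery("tcp@"+*addr, "")
echo := client.NewXClient("Echo", client.Failtry, client.RandomSelect, d2, client.DefaultOption)
defer echo.Close()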
1.2 Invoking the server: Call
// Call invokes the named function, waits for it to complete, and returns its error status.
// It handles errors based on FailMode.
func (c *xClient) Call(ctx context.Context, serviceMethod string, args interface{}, reply interface{}) error {
    if c.isShutdown {
        return ErrXClientShutdown
    }

    if c.auth != "" {
        metadata := ctx.Value(share.ReqMetaDataKey)
        if metadata == nil {
            metadata = map[string]string{}
            ctx = context.WithValue(ctx, share.ReqMetaDataKey, metadata)
        }
        m := metadata.(map[string]string)
        m[share.AuthKey] = c.auth
    }

    ctx = setServerTimeout(ctx)

    if share.Trace {
        log.Debugf("select a client for %s.%s, failMode: %v, args: %+v in case of xclient Call", c.servicePath, serviceMethod, c.failMode, args)
    }

    var err error
    // pick an available client according to the load-balancing algorithm
    k, client, err := c.selectClient(ctx, c.servicePath, serviceMethod, args)
    if err != nil {
        if c.failMode == Failfast || contextCanceled(err) {
            return err
        }
    }

    if share.Trace {
        if client != nil {
            log.Debugf("selected a client %s for %s.%s, failMode: %v, args: %+v in case of xclient Call", client.RemoteAddr(), c.servicePath, serviceMethod, c.failMode, args)
        } else {
            log.Debugf("selected a client %s for %s.%s, failMode: %v, args: %+v in case of xclient Call", "nil", c.servicePath, serviceMethod, c.failMode, args)
        }
    }

    var e error
    switch c.failMode {
    // retry after failure, according to the configured fail mode
    case Failtry:
        retries := c.option.Retries
        for retries >= 0 {
            retries--

            if client != nil {
                err = c.wrapCall(ctx, client, serviceMethod, args, reply)
                // in the normal case this mode returns right after this call
                if err == nil {
                    return nil
                }
                if contextCanceled(err) {
                    return err
                }
                if _, ok := err.(ServiceError); ok {
                    return err
                }
            }

            if uncoverError(err) {
                c.removeClient(k, c.servicePath, serviceMethod, client)
            }
            client, e = c.getCachedClient(k, c.servicePath, serviceMethod, args)
        }
        if err == nil {
            err = e
        }
        return err
    case Failover:
        retries := c.option.Retries
        for retries >= 0 {
            retries--

            if client != nil {
                err = c.wrapCall(ctx, client, serviceMethod, args, reply)
                if err == nil {
                    return nil
                }
                if contextCanceled(err) {
                    return err
                }
                if _, ok := err.(ServiceError); ok {
                    return err
                }
            }

            if uncoverError(err) {
                c.removeClient(k, c.servicePath, serviceMethod, client)
            }
            // select another server
            k, client, e = c.selectClient(ctx, c.servicePath, serviceMethod, args)
        }

        if err == nil {
            err = e
        }
        return err
    case Failbackup:
        ctx, cancelFn := context.WithCancel(ctx)
        defer cancelFn()
        call1 := make(chan *Call, 10)
        call2 := make(chan *Call, 10)

        var reply1, reply2 interface{}

        if reply != nil {
            reply1 = reflect.New(reflect.ValueOf(reply).Elem().Type()).Interface()
            reply2 = reflect.New(reflect.ValueOf(reply).Elem().Type()).Interface()
        }

        // asynchronous call
        _, err1 := c.Go(ctx, serviceMethod, args, reply1, call1)

        t := time.NewTimer(c.option.BackupLatency)
        select {
        case <-ctx.Done(): // cancel by context
            err = ctx.Err()
            return err
        case call := <-call1:
            err = call.Error
            if err == nil && reply != nil {
                reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(reply1).Elem())
            }
            return err
        case <-t.C:
        }

        _, err2 := c.Go(ctx, serviceMethod, args, reply2, call2)
        if err2 != nil {
            if uncoverError(err2) {
                c.removeClient(k, c.servicePath, serviceMethod, client)
            }
            err = err1
            return err
        }

        select {
        case <-ctx.Done(): // cancel by context
            err = ctx.Err()
        case call := <-call1:
            err = call.Error
            if err == nil && reply != nil && reply1 != nil {
                reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(reply1).Elem())
            }
        case call := <-call2:
            err = call.Error
            if err == nil && reply != nil && reply2 != nil {
                reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(reply2).Elem())
            }
        }

        return err
    default: // Failfast
        err = c.wrapCall(ctx, client, serviceMethod, args, reply)
        if err != nil {
            if uncoverError(err) {
                c.removeClient(k, c.servicePath, serviceMethod, client)
            }
        }

        return err
    }
}
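All four branches above are driven by the FailMode plus two Option fields that appear in the code: Retries (Failtry and Failover) and BackupLatency (Failbackup). A minimal configuration sketch, reusing the discovery d from the opening example:

opt := client.DefaultOption
opt.Retries = 3                           // retry budget for Failtry / Failover
opt.BackupLatency = 50 * time.Millisecond // Failbackup fires a second request after this delay

// Failover: on error, selectClient picks another server for each retry
xclient := client.NewXClient("Arith", client.Failover, client.RandomSelect, d, opt)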
Next, let's look at c.wrapCall, which wraps the core invocation logic:
func (c *xClient) wrapCall(ctx context.Context, client RPCClient, serviceMethod string, args interface{}, reply interface{}) error {
    if client == nil {
        return ErrServerUnavailable
    }

    if share.Trace {
        log.Debugf("call a client for %s.%s, args: %+v in case of xclient wrapCall", c.servicePath, serviceMethod, args)
    }

    if _, ok := ctx.(*share.Context); !ok {
        ctx = share.NewContext(ctx)
    }
    // pre-call plugin hook
    c.Plugins.DoPreCall(ctx, c.servicePath, serviceMethod, args)
    // the actual call
    err := client.Call(ctx, c.servicePath, serviceMethod, args, reply)
    // post-call plugin hook
    c.Plugins.DoPostCall(ctx, c.servicePath, serviceMethod, args, reply, err)

    if share.Trace {
        log.Debugf("called a client for %s.%s, args: %+v, err: %v in case of xclient wrapCall", c.servicePath, serviceMethod, args, err)
    }

    return err
}
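DoPreCall and DoPostCall are the per-call plugin hooks. A minimal logging-plugin sketch; the PreCall/PostCall signatures are an assumption mirroring the two hook call sites above, so check rpcx's client plugin interfaces before relying on them:

// logPlugin is a hypothetical plugin; its method signatures are assumed to
// mirror the DoPreCall/DoPostCall invocations in wrapCall
type logPlugin struct{}

func (p *logPlugin) PreCall(ctx context.Context, servicePath, serviceMethod string, args interface{}) error {
    log.Printf("about to call %s.%s", servicePath, serviceMethod)
    return nil
}

func (p *logPlugin) PostCall(ctx context.Context, servicePath, serviceMethod string, args interface{}, reply interface{}, err error) error {
    log.Printf("called %s.%s, err: %v", servicePath, serviceMethod, err)
    return err
}

Registering it would look like xclient.GetPlugins().Add(&logPlugin{}), assuming the PluginContainer exposes an Add method.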
The client.Call inside is the core method:
// Call invokes the named function, waits for it to complete, and returns its error status.
func (client *Client) Call(ctx context.Context, servicePath, serviceMethod string, args interface{}, reply interface{}) error {
    return client.call(ctx, servicePath, serviceMethod, args, reply)
}

func (client *Client) call(ctx context.Context, servicePath, serviceMethod string, args interface{}, reply interface{}) error {
    seq := new(uint64)
    ctx = context.WithValue(ctx, seqKey{}, seq)

    if share.Trace {
        log.Debugf("client.call for %s.%s, args: %+v in case of client call", servicePath, serviceMethod, args)
        defer func() {
            log.Debugf("client.call done for %s.%s, args: %+v in case of client call", servicePath, serviceMethod, args)
        }()
    }

    // asynchronous call
    Done := client.Go(ctx, servicePath, serviceMethod, args, reply, make(chan *Call, 1)).Done

    var err error
    select {
    case <-ctx.Done(): // cancel by context
        // the call was canceled
        client.mutex.Lock()
        call := client.pending[*seq]
        delete(client.pending, *seq)
        client.mutex.Unlock()
        if call != nil {
            call.Error = ctx.Err()
            call.done()
        }

        return ctx.Err()
    case call := <-Done:
        // the result has arrived
        err = call.Error
        meta := ctx.Value(share.ResMetaDataKey)
        if meta != nil && len(call.ResMetadata) > 0 {
            resMeta := meta.(map[string]string)
            locker, ok := ctx.Value(share.ContextTagsLock).(*sync.Mutex)
            if ok {
                locker.Lock()
                for k, v := range call.ResMetadata {
                    resMeta[k] = v
                }
                resMeta[share.ServerAddress] = client.Conn.RemoteAddr().String()
                locker.Unlock()
            } else {
                for k, v := range call.ResMetadata {
                    resMeta[k] = v
                }
                resMeta[share.ServerAddress] = client.Conn.RemoteAddr().String()
            }
        }
    }

    return err
}
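The ctx.Done() branch is what makes per-call timeouts work: cancel the context and the pending call is removed and completed with ctx.Err(). For example:

// cancel the call if the server does not answer within one second;
// Call then returns context.DeadlineExceeded via the ctx.Done() branch
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
err := xclient.Call(ctx, "Mul", args, reply)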
call relies on client.Go, an asynchronous method whose result is delivered on the Done channel:
func (client *Client) Go(ctx context.Context, servicePath, serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call {
    call := new(Call)
    call.ServicePath = servicePath
    call.ServiceMethod = serviceMethod
    meta := ctx.Value(share.ReqMetaDataKey)
    if meta != nil { // copy meta in context to meta in requests
        call.Metadata = meta.(map[string]string)
    }

    if !share.IsShareContext(ctx) {
        ctx = share.NewContext(ctx)
    }

    call.Args = args
    call.Reply = reply
    if done == nil {
        done = make(chan *Call, 10) // buffered.
    } else {
        // If caller passes done != nil, it must arrange that
        // done has enough buffer for the number of simultaneous
        // RPCs that will be using that channel. If the channel
        // is totally unbuffered, it's best not to run at all.
        if cap(done) == 0 {
            log.Panic("rpc: done channel is unbuffered")
        }
    }
    call.Done = done

    if share.Trace {
        log.Debugf("client.Go send request for %s.%s, args: %+v in case of client call", servicePath, serviceMethod, args)
    }

    // send the request
    client.send(ctx, call)
    return call
}
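From the caller's point of view the async API looks just like net/rpc: fire the request with Go, do other work, then receive the completed *Call from Done. A sketch using the XClient wrapper from the opening example:

call, err := xclient.Go(context.Background(), "Mul", args, reply, nil) // nil done: a buffered channel is created
if err != nil {
    log.Fatalf("failed to start call: %v", err)
}
// ... do other work while the RPC is in flight ...
finished := <-call.Done // the same *Call, now with Error and Reply filled in
if finished.Error != nil {
    log.Fatalf("call failed: %v", finished.Error)
}
log.Printf("%d * %d = %d", args.A, args.B, reply.C)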
And finally the send method, where the request actually goes out on the wire (we made it at last):
func (client *Client) send(ctx context.Context, call *Call) {
    // Register this call.
    client.mutex.Lock()
    if client.shutdown || client.closing {
        call.Error = ErrShutdown
        client.mutex.Unlock()
        call.done()
        return
    }

    isHeartbeat := call.ServicePath == "" && call.ServiceMethod == ""
    serializeType := client.option.SerializeType
    if isHeartbeat {
        serializeType = protocol.MsgPack
    }
    codec := share.Codecs[serializeType]
    if codec == nil {
        call.Error = ErrUnsupportedCodec
        client.mutex.Unlock()
        call.done()
        return
    }

    if client.pending == nil {
        client.pending = make(map[uint64]*Call)
    }

    seq := client.seq
    client.seq++
    client.pending[seq] = call
    client.mutex.Unlock()

    if cseq, ok := ctx.Value(seqKey{}).(*uint64); ok {
        *cseq = seq
    }

    // req := protocol.NewMessage()
    req := protocol.GetPooledMsg()
    req.SetMessageType(protocol.Request)
    req.SetSeq(seq)
    if call.Reply == nil {
        req.SetOneway(true)
    }

    // heartbeat, and use default SerializeType (msgpack)
    if isHeartbeat {
        req.SetHeartbeat(true)
        req.SetSerializeType(protocol.MsgPack)
    } else {
        req.SetSerializeType(client.option.SerializeType)
    }

    if call.Metadata != nil {
        req.Metadata = call.Metadata
    }

    req.ServicePath = call.ServicePath
    req.ServiceMethod = call.ServiceMethod

    // encode the request arguments with the chosen codec
    data, err := codec.Encode(call.Args)
    if err != nil {
        client.mutex.Lock()
        delete(client.pending, seq)
        client.mutex.Unlock()
        call.Error = err
        call.done()
        return
    }

    // compress only if the payload exceeds 1024 bytes and a compression algorithm is configured
    if len(data) > 1024 && client.option.CompressType != protocol.None {
        req.SetCompressType(client.option.CompressType)
    }
    req.Payload = data

    if client.Plugins != nil {
        _ = client.Plugins.DoClientBeforeEncode(req)
    }

    if share.Trace {
        log.Debugf("client.send for %s.%s, args: %+v in case of client call", call.ServicePath, call.ServiceMethod, call.Args)
    }

    // encode the whole message into a byte slice; unlike codec.Encode above
    // (which serializes the payload with e.g. JSON or protobuf), this is the
    // generic wire-protocol framing
    allData := req.EncodeSlicePointer()
    // write to the connection, using the standard library's send path
    _, err = client.Conn.Write(*allData)
    // return the []byte buffer to the pool after sending
    protocol.PutData(allData)

    if share.Trace {
        log.Debugf("client.sent for %s.%s, args: %+v in case of client call", call.ServicePath, call.ServiceMethod, call.Args)
    }

    if err != nil {
        // on a send error, drop this call from pending and recycle the request message
        client.mutex.Lock()
        call = client.pending[seq]
        delete(client.pending, seq)
        client.mutex.Unlock()
        if call != nil {
            call.Error = err
            call.done()
        }
        protocol.FreeMsg(req)
        return
    }

    isOneway := req.IsOneway()
    // recycle the request message
    protocol.FreeMsg(req)

    if isOneway {
        // no reply is expected, so drop this call from pending
        client.mutex.Lock()
        call = client.pending[seq]
        delete(client.pending, seq)
        client.mutex.Unlock()
        if call != nil {
            call.done()
        }
    }

    if client.option.IdleTimeout != 0 {
        _ = client.Conn.SetDeadline(time.Now().Add(client.option.IdleTimeout))
    }
}
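Two details in send are worth highlighting: a message with empty ServicePath and ServiceMethod is treated as a heartbeat and always MsgPack-encoded, and a non-zero IdleTimeout pushes the connection deadline forward on every send. Heartbeats are switched on through the client options; a sketch, assuming the Heartbeat and HeartbeatInterval fields of Option as in the rpcx source:

opt := client.DefaultOption
opt.Heartbeat = true                // periodically send empty-path heartbeat messages
opt.HeartbeatInterval = time.Second // how often to send them
opt.IdleTimeout = 10 * time.Second  // connection deadline, refreshed on every send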
A quick detour into EncodeSlicePointer, the method that implements rpcx's underlying wire protocol:
// EncodeSlicePointer encodes messages as a byte slice pointer we can use pool to improve.
func (m Message) EncodeSlicePointer() *[]byte {
    // reuse a byte buffer from bytebufferpool for performance
    bb := bytebufferpool.Get()
    // encode the metadata
    encodeMetadata(m.Metadata, bb)
    // grab the encoded bytes (bytes.Buffer compatibility)
    meta := bb.Bytes()

    spL := len(m.ServicePath)
    smL := len(m.ServiceMethod)

    var err error
    payload := m.Payload
    if m.CompressType() != None {
        compressor := Compressors[m.CompressType()]
        if compressor == nil {
            m.SetCompressType(None)
        } else {
            payload, err = compressor.Zip(m.Payload)
            if err != nil {
                m.SetCompressType(None)
                payload = m.Payload
            }
        }
    }

    totalL := (4 + spL) + (4 + smL) + (4 + len(meta)) + (4 + len(payload))

    // start offset of the metadata section
    // header + dataLen + spLen + sp + smLen + sm + metaL + meta + payloadLen + payload
    metaStart := 12 + 4 + (4 + spL) + (4 + smL)

    payLoadStart := metaStart + (4 + len(meta))
    l := 12 + 4 + totalL

    data := bufferPool.Get(l)
    copy(*data, m.Header[:])

    // totalLen
    // written big-endian; the protocol fixes the start and end offsets of each field in the frame
    binary.BigEndian.PutUint32((*data)[12:16], uint32(totalL))

    binary.BigEndian.PutUint32((*data)[16:20], uint32(spL))
    copy((*data)[20:20+spL], util.StringToSliceByte(m.ServicePath))

    binary.BigEndian.PutUint32((*data)[20+spL:24+spL], uint32(smL))
    copy((*data)[24+spL:metaStart], util.StringToSliceByte(m.ServiceMethod))

    binary.BigEndian.PutUint32((*data)[metaStart:metaStart+4], uint32(len(meta)))
    copy((*data)[metaStart+4:], meta)
    // return the metadata buffer to the pool
    bytebufferpool.Put(bb)

    binary.BigEndian.PutUint32((*data)[payLoadStart:payLoadStart+4], uint32(len(payload)))
    copy((*data)[payLoadStart+4:], payload)

    return data
}
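The layout comment above fully determines the frame: a 12-byte header, a 4-byte big-endian total length, then four length-prefixed sections (servicePath, serviceMethod, metadata, payload). A minimal sketch that reads the lengths back out, mirroring the offsets EncodeSlicePointer writes (illustrative only; the real decoding lives in protocol's decode path):

// parseFrame extracts the four sections from an encoded frame by walking the
// same offsets EncodeSlicePointer uses. (import: encoding/binary)
// Bytes 12..16 hold totalL, the length of everything after byte 16.
func parseFrame(data []byte) (sp, sm string, meta, payload []byte) {
    spL := binary.BigEndian.Uint32(data[16:20])
    sp = string(data[20 : 20+spL])

    off := 20 + spL
    smL := binary.BigEndian.Uint32(data[off : off+4])
    sm = string(data[off+4 : off+4+smL])

    off += 4 + smL
    metaL := binary.BigEndian.Uint32(data[off : off+4])
    meta = data[off+4 : off+4+metaL]

    off += 4 + metaL
    payloadL := binary.BigEndian.Uint32(data[off : off+4])
    payload = data[off+4 : off+4+payloadL]
    return
}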
That more or less wraps up the plain client path, without any extra plugins. Follow-up posts will dig into the various load-balancing modes and other features.