jprotobuf是针对Java程序开发的一套简易类库,目的是简化Java语言对protobuf类库的使用。
使用jprotobuf无需再去了解.proto文件的操作与语法,直接使用Java注解定义字段类型即可。
jprotobuf官方GitHub地址:https://github.com/jhunters/jprotobuf
下面使用jprotobuf实现Netty的编解码器。
1.引入jprotobuf依赖。这里使用Gradle,Maven的配置与之类似,不再赘述。
dependencies {
// `compile`/`testCompile` were deprecated and then removed in Gradle 7.0;
// `implementation`/`testImplementation` are their replacements (requires the java plugin).
testImplementation group: 'junit', name: 'junit', version: '4.12'
// compile group: 'com.google.protobuf', name: 'protobuf-java'
// https://mvnrepository.com/artifact/com.baidu/jprotobuf
implementation group: 'com.baidu', name: 'jprotobuf', version: '2.2.8'
}
2.定义.proto文件
协议头文件
syntax = "proto2";
option java_package="com.iscas.protobuf";
option java_outer_classname = "Header";
// Frame header written in front of every message body.
// Both fields are fixed32, so each encodes as a 1-byte tag plus 4 bytes and the
// whole header is exactly 10 bytes; the Java decoder reads it with that fixed size.
message Head {
// Length of the message body in bytes
required fixed32 length = 1;
// Type of the body that follows (the decoder maps 0 to MessageBase)
required fixed32 messageType = 2;
}
消息体文件
syntax = "proto2";
option java_package="com.iscas.protobuf";
option java_outer_classname = "Message";
//import "Command.proto";
// Message body exchanged between clients and the server.
message MessageBase {
/**
* Client ID
*/
required string clientId = 1;
/**
* Message (command) type
*/
required CommandType cmd = 2;
/**
* Payload data (JSON)
*/
optional string data = 3;
/**
* Command types
*/
enum CommandType {
/**
* Authentication request
*/
AUTH = 1;
/**
* ping
*/
PING = 2;
/**
* pong
*/
PONG = 3;
/**
* Upload data
*/
UPLOAD_DATA = 4;
/**
* Push data
*/
PUSH_DATA = 5;
/**
* Data emitted by the model
*/
MODEL_DATA=6;
/**
* Authentication response; can also be used as a message receipt if needed
*/
AUTH_BACK = 11;
}
}
3.使用代码生成器生成jprotobuf格式的Bean结构。与protoc生成的Java结构不同,这种结构更接近普通的JavaBean,使用起来更方便。
public class ProtoToBeanGegerator {
/**
* proto文件名称
* */
private static String protoName = "Message.proto";
/**
* 生成的java目标目录路径
* */
private static String targetPath = "D:/test";
public static void main(String[] args) throws IOException {
@Cleanup InputStream fis = ProtoToBeanGegerator.class.getResourceAsStream("/" + protoName);
ProtobufIDLProxy.generateSource(fis, new File(targetPath));
}
}
/**
 * jprotobuf mapping of the proto {@code Head} frame header.
 *
 * <p>Both fields are required {@code fixed32}, so an encoded header is 10 bytes
 * (1-byte tag + 4 bytes per field); CustomProtobufDecoder reads headers with that fixed size.
 */
@Data
public class HeadDTO {
    /** Length of the message body in bytes. */
@Protobuf(fieldType = FieldType.FIXED32, order = 1, required = true)
public Integer length;
    /** Type of the body that follows; CustomProtobufDecoder decodes type 0 as MessageDTO. */
@Protobuf(fieldType = FieldType.FIXED32, order = 2, required = true)
public Integer messageType;
}
/**
 * jprotobuf mapping of the proto {@code MessageBase} message body.
 */
@Data
public class MessageDTO {
    /** Id of the client that sent (or is addressed by) the message. */
@Protobuf(fieldType = FieldType.STRING, order = 1, required = true)
public String clientId;
    /** Command type of the message. */
@Protobuf(fieldType = FieldType.ENUM, order = 2, required = true)
public CommandType cmd;
    /** Optional payload (JSON in the proto definition). */
@Protobuf(fieldType = FieldType.STRING, order = 3, required = false)
public String data;
    /**
     * Command types; {@link #value()} returns the enum number, which matches
     * MessageBase.CommandType in the .proto file.
     */
public static enum CommandType implements EnumReadable {
AUTH(1), PING(2), PONG(3), UPLOAD_DATA(4), PUSH_DATA(5), MODEL_DATA(6), AUTH_BACK(11);
private final int value;
CommandType(int value) {
this.value = value;
}
        /** Protobuf enum number of this constant. */
public int value() {
return value;
}
}
}
// NOTE(review): no version is given; Gradle cannot resolve this dependency unless a version
// (or a platform/BOM that supplies one) is added. `compile` is removed in Gradle 7 (use `implementation`).
compile group: 'io.netty', name: 'netty-all'
6.编码器(注意:Codec不必在每次编解码时都重新create,这里只是个demo)
/**
 * Outbound handler that serializes every outgoing message into the wire format
 * (frame header followed by the body) by delegating to {@code DataUtils.encodeHeaderAndBody}.
 *
 * <p>Marked {@code @Sharable}: the class declares no fields, so a single instance may sit in many pipelines.
 */
@ChannelHandler.Sharable
public class CustomProtobufEncoder extends MessageToByteEncoder<Object> {
    @Override
    protected void encode(ChannelHandlerContext ctx, Object msg, ByteBuf out) throws Exception {
        // The helper writes both the header and the body into the target buffer.
        DataUtils.encodeHeaderAndBody(msg, out);
    }
}
7.解码器
/**
 * Inbound decoder that splits the byte stream into frames of the form
 * {@code [10-byte HeadDTO][body]} and turns each body into a message object.
 *
 * <p>The header is a jprotobuf-encoded {@link HeadDTO}: {@code length} is the body size in bytes,
 * {@code messageType} selects how the body is decoded (see {@link #decodeBody}).
 * Not sharable: {@link ByteToMessageDecoder} keeps per-channel cumulation state.
 */
public class CustomProtobufDecoder extends ByteToMessageDecoder {
    /** Encoded size of a HeadDTO: two fixed32 fields, each a 1-byte tag plus 4 payload bytes. */
    private static final int HEADER_LENGTH = 10;
    /** messageType value whose body is a {@link MessageDTO}. */
    private static final int MESSAGE_TYPE_MESSAGE = 0;

    // Codecs are created lazily once per decoder instead of once per frame.
    private Codec<HeadDTO> headDTOCodec;
    private Codec<MessageDTO> messageDTOCodec;

    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        // Keep decoding while at least a complete header is buffered.
        while (in.readableBytes() >= HEADER_LENGTH) {
            in.markReaderIndex();
            byte[] headerBytes = new byte[HEADER_LENGTH];
            in.readBytes(headerBytes);
            HeadDTO headDTO = headCodec().decode(headerBytes);
            int length = headDTO.getLength();
            int dataType = headDTO.getMessageType();
            // fixed32 values above Integer.MAX_VALUE show up as negative ints: the stream is corrupt.
            if (length < 0) {
                throw new IllegalStateException("Invalid body length in frame header: " + length);
            }
            if (in.readableBytes() < length) {
                // Body not fully received yet: rewind to the header and wait for more bytes.
                in.resetReaderIndex();
                return;
            }
            // Copy the body straight into a heap array; no intermediate ByteBuf that could leak.
            byte[] body = new byte[length];
            in.readBytes(body);
            Object result = decodeBody(dataType, body, 0, length);
            if (result == null) {
                // The frame is already consumed, so the stream stays aligned; report the bad type.
                throw new IllegalStateException("Unsupported message type: " + dataType);
            }
            out.add(result);
        }
    }

    /**
     * Decodes a frame body according to its message type.
     *
     * @param dataType message type taken from the frame header
     * @param array    buffer holding the body
     * @param offset   start of the body within {@code array}
     * @param length   body length in bytes
     * @return the decoded message, or {@code null} if {@code dataType} is not supported
     * @throws Exception if the body cannot be decoded
     */
    public Object decodeBody(int dataType, byte[] array, int offset, int length) throws Exception {
        if (dataType != MESSAGE_TYPE_MESSAGE) {
            return null;
        }
        byte[] bs = array;
        if (offset != 0 || length != array.length) {
            // The codec decodes a whole array, so cut out the exact body range.
            bs = new byte[length];
            System.arraycopy(array, offset, bs, 0, length);
        }
        return messageCodec().decode(bs);
    }

    /** Returns the HeadDTO codec, creating it on first use. */
    private Codec<HeadDTO> headCodec() throws Exception {
        if (headDTOCodec == null) {
            headDTOCodec = ProtobufProxy.create(HeadDTO.class);
        }
        return headDTOCodec;
    }

    /** Returns the MessageDTO codec, creating it on first use. */
    private Codec<MessageDTO> messageCodec() throws Exception {
        if (messageDTOCodec == null) {
            messageDTOCodec = ProtobufProxy.create(MessageDTO.class);
        }
        return messageDTOCodec;
    }
}
/**
 * Authentication gate in front of the business handlers.
 *
 * <p>AUTH messages are checked against the configured {@code clientIds}. PING is answered directly.
 * Any other MessageDTO is forwarded only if this very channel authenticated under the message's clientId.
 * Non-MessageDTO objects are passed through unchanged.
 */
@Component
@Qualifier("authServerHandler")
@ChannelHandler.Sharable
@Slf4j
public class AuthServerHandler extends ChannelInboundHandlerAdapter {
    /** Channel attribute storing the authenticated client id of a connection. */
    private static final AttributeKey<String> CLIENT_INFO = AttributeKey.valueOf("clientInfo");

    @Autowired
    @Qualifier("channelRepository")
    private ChannelRepository channelRepository;

    /** Comma-separated list of client ids that may authenticate. */
    @Value("${clientIds}")
    private String clientIds;

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (!(msg instanceof MessageDTO)) {
            // Once forwarded, the message belongs to the next handler; it must not be released here.
            ctx.fireChannelRead(msg);
            log.info("----将进入业务入处理逻辑-----");
            return;
        }
        MessageDTO msgBase = (MessageDTO) msg;
        String clientId = msgBase.getClientId();
        MessageDTO.CommandType cmd = msgBase.getCmd();
        if (cmd == MessageDTO.CommandType.AUTH) {
            handleAuth(ctx, msgBase, clientId);
        } else if (cmd == MessageDTO.CommandType.PING) {
            // Heartbeat: answered on the same connection without requiring authentication.
            ctx.writeAndFlush(DataUtils.createData(clientId, MessageDTO.CommandType.PONG, "This is pong data"));
        } else {
            forwardIfAuthenticated(ctx, msgBase, clientId);
        }
    }

    /** Registers the channel for an allowed client id, otherwise replies "fail" and closes. */
    private void handleAuth(ChannelHandlerContext ctx, MessageDTO msgBase, String clientId) {
        log.info("----验证处理----");
        if (!isAllowed(clientId)) {
            if (clientId != null) {
                // ConcurrentHashMap rejects null keys.
                channelRepository.remove(clientId);
            }
            log.info("----验证失败即将断开连接----");
            replyFailAndClose(ctx, clientId);
            return;
        }
        log.info("---验证成功---");
        // Channel-level attribute, so every handler in the pipeline sees the same value.
        ctx.channel().attr(CLIENT_INFO).set(clientId);
        channelRepository.put(clientId, ctx.channel());
        msgBase.setData("success");
        msgBase.setCmd(MessageDTO.CommandType.AUTH_BACK);
        ctx.writeAndFlush(msgBase);
    }

    /**
     * Forwards the message only if {@code clientId} is registered to this very channel.
     * Merely knowing a registered id is not enough.
     */
    private void forwardIfAuthenticated(ChannelHandlerContext ctx, MessageDTO msgBase, String clientId) {
        Channel registered = clientId == null ? null : channelRepository.get(clientId);
        if (registered == ctx.channel() && registered.isOpen()) {
            ctx.fireChannelRead(msgBase);
            log.info("----将进入业务入处理逻辑-----");
            return;
        }
        // Drop a stale (closed) registration, but never evict another live connection's entry.
        if (registered != null && !registered.isOpen()) {
            channelRepository.remove(clientId);
        }
        log.info("----没有权限,即将断开连接----");
        replyFailAndClose(ctx, clientId);
    }

    /** Returns true if {@code clientId} appears in the configured list (entries are trimmed). */
    private boolean isAllowed(String clientId) {
        if (clientId == null || clientIds == null) {
            return false;
        }
        for (String id : clientIds.split(",")) {
            if (id.trim().equals(clientId)) {
                return true;
            }
        }
        return false;
    }

    /** Sends an AUTH_BACK "fail" reply and closes the channel once the write completes. */
    private static void replyFailAndClose(ChannelHandlerContext ctx, String clientId) {
        ctx.writeAndFlush(DataUtils.createData(clientId, MessageDTO.CommandType.AUTH_BACK, "fail"))
                .addListener(ChannelFutureListener.CLOSE);
    }
}
/**
 * Logs idle-state events raised by the IdleStateHandler earlier in the pipeline.
 * Every other user event is handed to the default implementation unchanged.
 */
@Slf4j
@Component
public class IdleServerHandler extends ChannelInboundHandlerAdapter {
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (!(evt instanceof IdleStateEvent)) {
            super.userEventTriggered(ctx, evt);
            return;
        }
        String label = describe(((IdleStateEvent) evt).state());
        log.info(ctx.channel().remoteAddress() + "超时类型:" + label);
    }

    /** Maps an idle state to the label used in the log line; unknown or null states map to "". */
    private static String describe(IdleState state) {
        if (state == null) {
            return "";
        }
        switch (state) {
            case READER_IDLE:
                return "read idle";
            case WRITER_IDLE:
                return "write idle";
            case ALL_IDLE:
                return "all idle";
            default:
                return "";
        }
    }
}
public class LogicServerHandler extends ChannelInboundHandlerAdapter {
private AtomicInteger integer = new AtomicInteger(0);
private final AttributeKey<String> clientInfo = AttributeKey.valueOf("clientInfo");
@Value("${modelId}")
private String modelId; //模型的ID
@Qualifier("channelRepository")
@Autowired
private ChannelRepository channelRepository;
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg != null) {
if (msg instanceof MessageDTO) {
log.info("-----:" + msg);
MessageDTO msgBase = (MessageDTO) msg;
MessageDTO.CommandType cmd = msgBase.getCmd();
if(cmd.equals(MessageDTO.CommandType.MODEL_DATA)){
//如果是MODEL_DATA,认为是C++模型推送过来的数据。
//除了自己的ClientId, 分别分发给其他客户端
String clientId = msgBase.getClientId();
if(clientId != null){
Set<String> keys = channelRepository.keys();
msgBase.setCmd(MessageDTO.CommandType.PUSH_DATA);
channelRepository.get("web-app").writeAndFlush(msgBase);
}
}
ReferenceCountUtil.release(msg);
}
}
}
@Override
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
}
@SuppressWarnings("deprecation")
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
Attribute<String> attr = ctx.attr(clientInfo);
String clientId = attr.get();
log.error("Connection closed, client is " + clientId);
cause.printStackTrace();
ctx.channel().close();
channelRepository.remove(clientId);
}
}
/**
 * Registry of client channels keyed by client id.
 *
 * <p>The backing map is static, so all instances of this class share the same entries. It is a
 * {@link ConcurrentHashMap}, so it tolerates concurrent access and rejects {@code null} keys and values.
 */
public class ChannelRepository {
    private static final Map<String, Channel> CHANNELS = new ConcurrentHashMap<>();

    /** Registers (or replaces) the channel for {@code key}. */
    public void put(String key, Channel value) {
        CHANNELS.put(key, value);
    }

    /** Returns the channel registered for {@code key}, or {@code null} if there is none. */
    public Channel get(String key) {
        return CHANNELS.get(key);
    }

    /** Removes the registration for {@code key}, if any. */
    public void remove(String key) {
        CHANNELS.remove(key);
    }

    /** Number of registered channels. */
    public int size() {
        return CHANNELS.size();
    }

    /** Live view of the registered client ids, backed by the registry. */
    public Set<String> keys() {
        return CHANNELS.keySet();
    }
}
/**
 * Builds the pipeline for every accepted server connection, in this order:
 * idle detection, frame decoder, encoder, authentication, business logic, idle-event logging.
 */
@Component
@Qualifier("serverChannelInitializer")
public class ServerChannelInitializer extends ChannelInitializer<SocketChannel> {
    /** Reader idle timeout: 30 seconds without inbound data. */
    private static final int READER_IDLE_TIME_SECONDS = 30;
    /** Writer idle timeout: 30 seconds without outbound data. */
    private static final int WRITER_IDLE_TIME_SECONDS = 30;
    /** All-idle timeout: 60 seconds with neither reads nor writes. */
    private static final int ALL_IDLE_TIME_SECONDS = 60;

    @Autowired
    @Qualifier("authServerHandler")
    private ChannelInboundHandlerAdapter authServerHandler;

    @Autowired
    @Qualifier("logicServerHandler")
    private ChannelInboundHandlerAdapter logicServerHandler;

    @Override
    protected void initChannel(SocketChannel socketChannel) throws Exception {
        IdleStateHandler idleDetector = new IdleStateHandler(
                READER_IDLE_TIME_SECONDS, WRITER_IDLE_TIME_SECONDS, ALL_IDLE_TIME_SECONDS, TimeUnit.SECONDS);
        socketChannel.pipeline()
                .addLast("idleStateHandler", idleDetector)
                .addLast("decoder", new CustomProtobufDecoder())
                .addLast("encoder", new CustomProtobufEncoder())
                // The two Spring singletons below are shared by every connection's pipeline.
                .addLast("authServerHandler", authServerHandler)
                // Pipeline name kept as-is in case other code looks the handler up by it.
                .addLast("hearableServerHandler", logicServerHandler)
                .addLast("idleTimeoutHandler", new IdleServerHandler());
    }
}
/**
 * Binds the Netty server to the configured address and keeps it running until it is closed.
 */
@Component
public class TCPServer {
    @Autowired
    @Qualifier("serverBootstrap")
    private ServerBootstrap serverBootstrap;

    @Autowired
    @Qualifier("tcpSocketAddress")
    private InetSocketAddress tcpPort;

    /** Bound listening channel. Written by start(); read by stop(), possibly on the shutdown thread. */
    private volatile Channel serverChannel;

    /**
     * Binds the server and blocks the calling thread until the listening channel is closed.
     *
     * @throws Exception if binding fails or the wait is interrupted
     */
    public void start() throws Exception {
        // Keep the reference as soon as the bind completes so stop() can close it. Previously
        // the field was only assigned after the channel had already been closed.
        Channel bound = serverBootstrap.bind(tcpPort).sync().channel();
        serverChannel = bound;
        bound.closeFuture().sync();
    }

    /**
     * Closes the listening channel if start() has bound one. The event loop groups are shut
     * down separately by their bean destroy methods.
     */
    @PreDestroy
    public void stop() throws Exception {
        Channel channel = serverChannel;
        if (channel != null) {
            // A server channel has no parent (parent() returns null), so closing it is enough.
            channel.close();
        }
    }

    public ServerBootstrap getServerBootstrap() {
        return serverBootstrap;
    }

    public void setServerBootstrap(ServerBootstrap serverBootstrap) {
        this.serverBootstrap = serverBootstrap;
    }

    public InetSocketAddress getTcpPort() {
        return tcpPort;
    }

    public void setTcpPort(InetSocketAddress tcpPort) {
        this.tcpPort = tcpPort;
    }
}
/**
 * Spring Boot entry point of the distribution server. It defines the Netty bootstrap, event loop
 * groups, socket options and the shared {@link ChannelRepository}, then starts {@link TCPServer}.
 */
@SpringBootApplication
@ComponentScan(value = {"com.iscas.distribution.server"})
@PropertySource(value= "classpath:/application.properties")
public class DistributeWebApp {
    @Configuration
    @Profile("production")
    @PropertySource("classpath:/application.properties")
    static class Production { }

    @Configuration
    @Profile("local")
    @PropertySource({"classpath:/application.properties"})
    static class Local { }

    public static void main(String[] args) throws Exception {
        ConfigurableApplicationContext context = SpringApplication.run(DistributeWebApp.class, args);
        TCPServer tcpServer = context.getBean(TCPServer.class);
        // Blocks until the server channel is closed.
        tcpServer.start();
    }

    @Value("${tcp.port}")
    private int tcpPort;
    @Value("${boss.thread.count}")
    private int bossCount;
    @Value("${worker.thread.count}")
    private int workerCount;
    @Value("${so.keepalive}")
    private boolean keepAlive;
    @Value("${so.backlog}")
    private int backlog;

    @Autowired
    @Qualifier("serverChannelInitializer")
    private ServerChannelInitializer serverChannelInitializer;

    /** Server bootstrap wired with the boss/worker groups, channel initializer and socket options. */
    @Bean(name = "serverBootstrap")
    public ServerBootstrap bootstrap() {
        ServerBootstrap b = new ServerBootstrap();
        b.group(bossGroup(), workerGroup())
                .channel(NioServerSocketChannel.class)
                .handler(new LoggingHandler(LogLevel.DEBUG))
                .childHandler(serverChannelInitializer);
        for (Map.Entry<ChannelOption<?>, Object> entry : tcpChannelOptions().entrySet()) {
            applyOption(b, entry.getKey(), entry.getValue());
        }
        return b;
    }

    /**
     * Routes a socket option to the right channel. SO_BACKLOG configures the listening socket.
     * Keep-alive, no-delay and the receive buffer allocator only make sense on accepted
     * connections. Set on the server channel, Netty ignores them (logging an unknown-option
     * warning) or applies them to accepts rather than client reads.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static void applyOption(ServerBootstrap b, ChannelOption option, Object value) {
        if (option == ChannelOption.SO_BACKLOG) {
            b.option(option, value);
        } else {
            b.childOption(option, value);
        }
    }

    /** Socket options applied by {@link #bootstrap()}. */
    @Bean(name = "tcpChannelOptions")
    public Map<ChannelOption<?>, Object> tcpChannelOptions() {
        Map<ChannelOption<?>, Object> options = new HashMap<ChannelOption<?>, Object>();
        options.put(ChannelOption.SO_KEEPALIVE, keepAlive);
        options.put(ChannelOption.SO_BACKLOG, backlog);
        options.put(ChannelOption.TCP_NODELAY, true);
        // Adaptive receive buffer: minimum 64 B, initial 1 KiB, maximum 1 MiB.
        options.put(ChannelOption.RCVBUF_ALLOCATOR, new AdaptiveRecvByteBufAllocator(64, 1024, 1024 * 1024));
        return options;
    }

    @Bean(name = "bossGroup", destroyMethod = "shutdownGracefully")
    public NioEventLoopGroup bossGroup() {
        return new NioEventLoopGroup(bossCount);
    }

    @Bean(name = "workerGroup", destroyMethod = "shutdownGracefully")
    public NioEventLoopGroup workerGroup() {
        return new NioEventLoopGroup(workerCount);
    }

    @Bean(name = "tcpSocketAddress")
    public InetSocketAddress tcpPort() {
        return new InetSocketAddress(tcpPort);
    }

    @Bean(name = "channelRepository")
    public ChannelRepository channelRepository() {
        return new ChannelRepository();
    }
}
9.netty客户端
/**
 * Per-connection client handler. It sends a ping when the connection is all-idle and schedules
 * a reconnect after the channel becomes inactive.
 */
@Slf4j
public class IdleClientHandler extends SimpleChannelInboundHandler<MessageDTO> {
    /** Seconds to wait before reconnecting after the channel goes inactive. */
    private static final long RECONNECT_DELAY_SECONDS = 10L;

    private final DistributeClient nettyClient;
    /** Pings sent on this connection (a new handler instance is created for every channel). */
    private int heartbeatCount = 0;
    /** Client id reported in outgoing ping messages. */
    private final String clientId;

    /**
     * @param nettyClient client used to reconnect once this channel is inactive
     * @param clientId    id this client identifies itself with in ping messages
     */
    public IdleClientHandler(DistributeClient nettyClient, String clientId) {
        this.nettyClient = nettyClient;
        this.clientId = clientId;
    }

    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof IdleStateEvent) {
            IdleStateEvent event = (IdleStateEvent) evt;
            String type = "";
            if (event.state() == IdleState.READER_IDLE) {
                type = "read idle";
            } else if (event.state() == IdleState.WRITER_IDLE) {
                type = "write idle";
            } else if (event.state() == IdleState.ALL_IDLE) {
                type = "all idle";
                sendPingMsg(ctx);
            }
            log.info(ctx.channel().remoteAddress() + "超时类型:" + type);
        } else {
            super.userEventTriggered(ctx, evt);
        }
    }

    /**
     * Sends a PING message identified by this client's id.
     *
     * @param context context of the idle channel
     */
    protected void sendPingMsg(ChannelHandlerContext context) {
        MessageDTO messageDTO = new MessageDTO();
        messageDTO.setCmd(MessageDTO.CommandType.PING);
        messageDTO.setData("This is a ping msg");
        // Previously hard-coded to "web-app", which made every client ping as the web app.
        messageDTO.setClientId(clientId);
        context.writeAndFlush(messageDTO);
        heartbeatCount++;
        log.info("Client sent ping msg to " + context.channel().remoteAddress() + ", count: " + heartbeatCount);
    }

    /**
     * Schedules a reconnect on this channel's event loop, then propagates the event.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        final EventLoop eventLoop = ctx.channel().eventLoop();
        eventLoop.schedule(() -> nettyClient.doConnect(new Bootstrap(), eventLoop), RECONNECT_DELAY_SECONDS, TimeUnit.SECONDS);
        super.channelInactive(ctx);
    }

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, MessageDTO msg) throws Exception {
        // Intentionally empty: inbound messages are handled by the client handler placed before
        // this one in the pipeline.
    }
}
@Slf4j
@Component
public class LogicClientHandler extends SimpleChannelInboundHandler<Object> {
private AtomicInteger integer = new AtomicInteger(0);
@Autowired
private DistributeClientProps distributeClientProps;
// 连接成功后,向server发送消息
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
// Message.MessageBase.Builder authMsg = Message.MessageBase.newBuilder();
// authMsg.setClientId(distributeClientProps.getClientId());
// authMsg.setCmd(Command.CommandType.AUTH);
// authMsg.setData("This is auth data");
// ctx.writeAndFlush(authMsg.build());
MessageDTO authMsg = new MessageDTO();
authMsg.setClientId(distributeClientProps.getClientId());
authMsg.setCmd(MessageDTO.CommandType.AUTH);
authMsg.setData("This is auth data");
ctx.writeAndFlush(authMsg);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
log.debug("连接断开 ");
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, Object obj) throws Exception {
try {
if (obj instanceof MessageDTO) {
MessageDTO msg = (MessageDTO) obj;
if(msg.getCmd().equals(MessageDTO.CommandType.AUTH_BACK)){
String data = msg.getData();
if("success".equals(data)){
log.info("数据分发验证成功");
}else{
log.error("数据分发验证失败,连接即将断开");
}
} else if(msg.getCmd().equals(MessageDTO.CommandType.PING)){
//接收到server发送的ping指令
log.debug("PING Netty:" + msg.getData());
} else if(msg.getCmd().equals(MessageDTO.CommandType.PONG)){
//接收到server发送的pong指令
log.debug("Netty PONG:" + msg.getData());
} else if(msg.getCmd().equals(MessageDTO.CommandType.PUSH_DATA)){
log.info("Netty xinxi:" + msg.getData());
TimeUnit.SECONDS.sleep(1);
}else {
//将接收的数据根据不同的想定放入不同的队列
//TODO 将真实的ID与队列绑定
BlockingQueue<Object> blockingQueue = StaticData.QUEUE_MAP.get("这是一个想定或其他的标识");
if (blockingQueue != null) {
blockingQueue.put(msg);
}
}
}
} catch (Exception e) {
e.printStackTrace();
throw e;
}
}
}
@Slf4j
public class DistributeClient {
public static Channel clientChannel;
private LogicClientHandler logicClientHandler;
private DistributeClientProps distributeClientProps;
public DistributeClient(LogicClientHandler logicClientHandler,
DistributeClientProps distributeClientProps) {
this.logicClientHandler = logicClientHandler;
this.distributeClientProps = distributeClientProps;
}
// private final static String HOST = "127.0.0.1";
// private final static int PORT = 8090;
// private final static int READER_IDLE_TIME_SECONDS = 20;//读操作空闲20秒
// private final static int WRITER_IDLE_TIME_SECONDS = 20;//写操作空闲20秒
// private final static int ALL_IDLE_TIME_SECONDS = 40;//读写全部空闲40秒
private EventLoopGroup loop = new NioEventLoopGroup();
// public static void main(String[] args) throws Exception {
// DistributeClient client = new DistributeClient();
// client.run();
// }
public void run() throws Exception {
try {
doConnect(new Bootstrap(), loop);
}catch (Exception e) {
e.printStackTrace();
}
}
/**
* netty client 连接,连接失败10秒后重试连接
*/
public Bootstrap doConnect(Bootstrap bootstrap, EventLoopGroup eventLoopGroup) {
try {
if (bootstrap != null) {
bootstrap.group(eventLoopGroup);
bootstrap.channel(NioSocketChannel.class);
bootstrap.option(ChannelOption.SO_KEEPALIVE, true);
bootstrap.handler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
p.addLast("idleStateHandler", new IdleStateHandler(distributeClientProps.getReadIdleTimeSeconds()
, distributeClientProps.getWriteIdleTimeSeconds(), distributeClientProps.getAllIdleTimeSeconds(), TimeUnit.SECONDS));
// p.addLast(new ProtobufVarint32FrameDecoder());
// p.addLast(new ProtobufDecoder(Message.MessageBase.getDefaultInstance()));
//
// p.addLast(new ProtobufVarint32LengthFieldPrepender());
// p.addLast(new ProtobufEncoder());
p.addLast("decoder",new CustomProtobufDecoder());
p.addLast("encoder",new CustomProtobufEncoder());
p.addLast("clientHandler", logicClientHandler);
p.addLast("idleTimeoutHandler", new IdleClientHandler(DistributeClient.this,
distributeClientProps.getClientId()));
}
});
bootstrap.remoteAddress(distributeClientProps.getRemoteHost(), distributeClientProps.getRemotePort());
ChannelFuture f = bootstrap.connect().addListener((ChannelFuture futureListener)->{
final EventLoop eventLoop = futureListener.channel().eventLoop();
if (!futureListener.isSuccess()) {
log.warn("Failed to connect to server, try connect after 10s");
futureListener.channel().eventLoop().schedule(() -> doConnect(new Bootstrap(), eventLoop), 10, TimeUnit.SECONDS);
}
});
clientChannel = f.channel();
f.channel().closeFuture().sync();
eventLoopGroup.shutdownGracefully();
}
} catch (InterruptedException e) {
e.printStackTrace();
}
return bootstrap;
}
}
demo GitHub地址:https://github.com/zhuquanwen/netty-jprotobuf