java实现websocket server/client

田志
2023-12-01

最近在项目中有一个场景,在内网的应用需要接受外网应用的指令。有两种解决方案:

1.内网应用轮询外网应用,http请求指令

2.内网应用与外网应用之间建立websocket长连接

记录一下websocket server/client的java实现

一、websocket server

@Configuration
public class WebSocketConfig extends Configurator
{
    /**
     * 日志
     */
    private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketConfig.class);

    /**
     * 首先要注入ServerEndpointExporter,这个bean会自动注册使用了@ServerEndpoint注解声明的Websocket endpoint。
     * 要注意,如果使用独立的servlet容器,而不是直接使用springboot的内置容器,就不要注入ServerEndpointExporter,因为它将由容器自己提供和管理。
     *
     * @return ServerEndpointExporter
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter()
    {
        return new ServerEndpointExporter();
    }

    @Override
    public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request,
                                HandshakeResponse response)
    {
        /* 如果没有监听器,那么这里获取到的HttpSession是null */
        StandardSessionFacade ssf = (StandardSessionFacade)request.getHttpSession();
        if (ssf != null)
        {
            HttpSession httpSession = (HttpSession)request.getHttpSession();
            // 关键操作
            sec.getUserProperties().put("sessionId", httpSession.getId());
            LOGGER.debug("获取到的SessionID:" + httpSession.getId());
        }
    }

}
@Component
@ServerEndpoint(value = "/ws/{userId}")
public class WebSocketServer
{
    /**
     * 日志
     */
    private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketServer.class);

    // 静态变量,用来记录当前在线连接数。应该把它设计成线程安全的。
    private static int onlineCount = 0;

    // concurrent包的线程安全Set,用来存放每个客户端对应的MyWebSocket对象。
    private static ConcurrentHashMap<String, WebSocketServer> webSocketMap = new ConcurrentHashMap<>();

    // 保存允许建立连接的id
    private static List<String> idList = Lists.newArrayList();

    private String id = "";

    // 与某个客户端的连接会话,需要通过它来给客户端发送数据
    private Session session;

    /**
     * <关闭连接>
     *
     * @param userId userId
     * @throws
     */
    public void closeConn(String userId)
    {
        // 关闭连接
        try
        {
            WebSocketServer socket = webSocketMap.get(userId);
            if (null != socket)
            {
                if (socket.session.isOpen())
                {
                    socket.session.close();
                }
            }
        }
        catch (IOException e)
        {
            LOGGER.error("error, cause: ", e);
        }
        webSocketMap.remove(userId);
        idList.remove(userId);
    }

    /**
     * <连接/注册时去重>
     *
     * @param userId userId
     * @throws
     */
    public void conn(String userId)
    {
        // 去重
        if (!idList.contains(userId))
        {
            idList.add(userId);
        }
    }

    /**
     * <获取注册在websocket进行连接的id>
     *
     * @return 结果
     * @throws
     */
    public static List<String> getIdList()
    {
        return idList;
    }

    /**
     * <初始化方法>
     *
     * @throws
     */
    @PostConstruct
    public void init()
    {
        try
        {
            /**
             * TODO 这里的设计是在项目启动时从DB或者缓存中获取注册了允许建立连接的id
             */
            // TODO 初始化时将刚注入的对象进行静态保存

        }
        catch (Exception e)
        {
            // TODO 项目启动错误信息
            LOGGER.error("error, cause: ", e);
        }
    }

//    /**
//     * 连接启动时查询是否有滞留的新邮件提醒
//     *
//     * @param id
//     * @throws IOException
//     * @author caoting
//     * @date 2019年2月28日
//     */
//    private void selectOfflineMail(String id)
//        throws IOException
//    {
//        // 查询缓存中是否存在离线邮件消息
//        RedisProperties.Jedis jedis = redis.getConnection();
//        try
//        {
//            List<String> mails = jedis.lrange(Constant.MAIL_OFFLINE + id, 0, -1);
//            if (!StringUtils.isEmpty(mails))
//            {
//                for (String mailuuid : mails)
//                {
//                    String mail = jedis.get(mailuuid);
//                    if (!StringUtils.isEmpty(mail))
//                    {
//                        sendToUser(Constant.MESSAGE_MAIL + mail, id);
//                    }
//                    Thread.sleep(1000);
//                }
//                // 发送完成从缓存中移除
//                jedis.del(Constant.MAIL_OFFLINE + id);
//            }
//        }
//        catch (InterruptedException e)
//        {
//            e.printStackTrace();
//        }
//        finally
//        {
//            jedis.close();
//        }
//    }

    /**
     * <连接建立成功调用的方法>
     *
     * @param userId  userId
     * @param session session
     * @throws
     */
    @OnOpen
    public void onOpen(@PathParam(value = "userId") String userId, Session session)
    {
        try
        {
            // 注:admin是管理员内部使用通道  不受监控  谨慎使用
//            if (!id.contains("admin"))
//            {
//                this.session = session;
//                this.id = id;//接收到发送消息的人员编号
//                // 验证id是否在允许
//                if (idList.contains(id))
//                {
//                    // 判断是否已存在相同id
//                    WebSocketServer socket = webSocketSet.get(id);
//                    if (socket == null)
//                    {
//                        webSocketSet.put(id, this);     //加入set中
//                        addOnlineCount(); // 在线数加1
//
//                        this.sendMessage("Hello:::" + id);
//                        System.out.println("用户" + id + "加入!当前在线人数为" + getOnlineCount());
//
//                        // 检查是否存在离线推送消息
//                        selectOfflineMail(id);
//                    }
//                    else
//                    {
//                        this.sendMessage("连接id重复--连接即将关闭");
//                        this.session.close();
//                    }
//                }
//                else
//                {
//                    // 查询数据库中是否存在数据
//                    WsIds wsIds = wsIdsService.selectByAppId(id);
//                    if (null != wsIds)
//                    {
//                        idList.add(id);
//                        webSocketSet.put(id, this);     //加入set中
//
//                        addOnlineCount(); // 在线数加1
//                        this.sendMessage("Hello:::" + id);
//                        LOGGER.debug("用户" + id + "加入!当前在线人数为" + getOnlineCount());
//
//                        // 检查是否存在离线推送消息
//                        selectOfflineMail(id);
//
//                    }
//                    else
//                    {
//                        // 关闭
//                        this.sendMessage("暂无连接权限,连接即将关闭,请确认连接申请是否过期!");
//                        this.session.close();
//                        LOGGER.warn("有异常应用尝试与服务器进行长连接  使用id为:" + id);
//                    }
//                }
//            }
//            else
//            {
            this.session = session;
            this.id = userId;//接收到发送消息的人员编号

            webSocketMap.put(userId, this);     //加入set中
            addOnlineCount(); // 在线数加1

            this.sendMessage("Hello:::" + userId);
            LOGGER.debug("用户" + userId + "加入!当前在线人数为" + getOnlineCount());
//            }
        }
        catch (IOException e)
        {
            LOGGER.error("error, cause: ", e);
        }
    }

    /**
     * 连接关闭调用的方法
     */
    @OnClose
    public void onClose()
    {
        webSocketMap.remove(this.id); // 从set中删除
        idList.remove(this.id);
        subOnlineCount(); // 在线数减1
        LOGGER.debug("有一连接关闭!当前在线人数为" + getOnlineCount());
    }

    /**
     * 收到客户端消息后调用的方法
     *
     * @param message 客户端发送过来的消息
     */
    @OnMessage
    public void onMessage(String message, Session session)
    {
        LOGGER.debug("来自客户端的消息:" + message);
        // TODO 收到客户端消息后的操作
    }

    /**
     * 发生错误时调用
     */
    @OnError
    public void onError(Session session, Throwable error)
    {
        LOGGER.debug("发生错误");
        LOGGER.error("error, cause: ", error);
    }

    /**
     * <发送message>
     *
     * @param message message
     * @throws
     */
    public synchronized void sendMessage(String message)
        throws IOException
    {
        this.session.getBasicRemote().sendText(message);
    }

    /**
     * 发送信息给指定ID用户,如果用户不在线则返回不在线信息给自己
     *
     * @param message
     * @param sendUserId
     * @throws IOException
     */
    public Boolean sendToUser(String message, String sendUserId)
        throws IOException
    {
        Boolean flag = true;
        WebSocketServer socket = webSocketMap.get(sendUserId);
        if (socket != null)
        {
            try
            {
                if (socket.session.isOpen())
                {
                    socket.sendMessage(message);
                }
                else
                {
                    flag = false;
                }
            }
            catch (Exception e)
            {
                flag = false;
                LOGGER.error("error, cause: ", e);
            }
        }
        else
        {
            flag = false;
            LOGGER.warn("【" + sendUserId + "】 该用户不在线");
        }
        return flag;
    }

    /**
     * <群发自定义消息>
     *
     * @param message message
     * @throws
     */
    public void sendToAll(String message)
    {
        List<String> userIdList = new ArrayList<>(webSocketMap.keySet());
        for (int i = 0; i < userIdList.size(); i++)
        {
            try
            {
                WebSocketServer socket = webSocketMap.get(userIdList.get(i));
                if (socket.session.isOpen())
                {
                    socket.sendMessage(message);
                }
            }
            catch (Exception e)
            {
                LOGGER.error("sendToAll error. cause: ", e);
                //异常重传
                i--;
                continue;
            }
        }
    }

    /**
     * <获取在线人数>
     *
     * @return 在线人数
     * @throws
     */
    public static synchronized int getOnlineCount()
    {
        return onlineCount;
    }

    /**
     * <增加在线人数>
     *
     * @throws
     */
    public static synchronized void addOnlineCount()
    {
        WebSocketServer.onlineCount++;
    }

    /**
     * <减少在线人数>
     *
     * @throws
     */
    public static synchronized void subOnlineCount()
    {
        if (WebSocketServer.onlineCount > 0)
        {
            WebSocketServer.onlineCount--;
        }
    }
}
@Component
public class RequestListener implements ServletRequestListener
{
    /**
     * 日志
     */
    private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketConfig.class);

    @Override
    public void requestInitialized(ServletRequestEvent sre)
    {
        // 将所有request请求都携带上httpSession
        HttpSession httpSession = ((HttpServletRequest)sre.getServletRequest()).getSession();
        LOGGER.debug("将所有request请求都携带上httpSession " + httpSession.getId());
    }

    public RequestListener()
    {
    }

    @Override
    public void requestDestroyed(ServletRequestEvent arg0)
    {
    }

}
@Component
public class WebSocketTask
{
    private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketTask.class);

    @Autowired
    private WebSocketServer webSocketServer;

    /**
     * 服务器主动向客户端推送消息
     */
    @Scheduled(cron = "0/30 * * * * ?")
    public void initiativeSendMsg()
    {
        LOGGER.info("服务器主动向客户端推送消息.");
        webSocketServer.sendToAll("test");
    }
}

二、websocket client

@Configuration
@Order(1)
public class WebSocketConfig implements ApplicationRunner
{
    /**
     * 日志
     */
    private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketConfig.class);

    private static Boolean isOk;

    private static WebSocketContainer container = ContainerProvider.getWebSocketContainer();

    private WebSocketClient client;

    /**
     * 定义定时任务线程
     */
    private ScheduledExecutorService scheduledExecutorService = Executors.newScheduledThreadPool(
        1);

    @Value("${springboot.websocket.uri}")
    private String uri;

    @Value("${springboot.websocket.ip}")
    private String ip;

    @Autowired
    public void setMessageService(RedisUtils redisUtils)
    {
        WebSocketClient.redisUtils = redisUtils;
    }

    /**
     * <run>
     *
     * @param args args
     * @throws
     */
    @Override
    public void run(ApplicationArguments args)
        throws Exception
    {
        LOGGER.info("[WebSocketConfig] web socket init start.");

        // websocket客户端初始化
        wsClientInit();
    }

    /**
     * <websocket客户端初始化>
     *
     * @throws
     */
    public void wsClientInit()
    {
        LOGGER.info("[WebSocketConfig] start to wsClientInit");
        try
        {
            client = new WebSocketClient();
            WebSocketClient.beforeInit();
            container.connectToServer(client,
                new URI(uri + UUID.randomUUID().toString().replaceAll("-", "")));

            isOk = true;
        }
        catch (Exception e)
        {
            isOk = false;
            LOGGER.error("error, cause: ", e);
        }

        // 参数:1、任务体 2、首次执行的延时时间
        //      3、任务执行间隔 4、间隔时间单位
        scheduledExecutorService.scheduleAtFixedRate(new Runnable()
        {
            @Override
            public void run()
            {
                //心跳检测 断线重连
                heartbeatCheck();
            }
        }, 1, 30, TimeUnit.SECONDS);

        LOGGER.info("[WebSocketConfig] end to wsClientInit");
    }

    /**
     * <心跳检测 断线重连>
     *
     * @throws
     */
    private void heartbeatCheck()
    {
        LOGGER.info("[WebSocketConfig] start to heartbeatCheck");
        if (isOk != null && isOk)
        {
            try
            {
                client.send("ping " + ip);
            }
            catch (Exception e)
            {
                isOk = false;
            }
        }
        else
        {
            // 系统连接失败进行重试
            LOGGER.warn("系统连接失败,正在重连...");
            try
            {
                client.send("ping " + ip);
                LOGGER.warn("系统重连成功!");
                isOk = true;
            }
            catch (Exception e)
            {
                try
                {
                    client = new WebSocketClient();
                    container.connectToServer(client,
                        new URI(uri + UUID.randomUUID().toString().replaceAll("-", "")));

                    isOk = true;
                }
                catch (Exception e1)
                {
                    isOk = false;
                }

                if (isOk != null && isOk)
                {
                    LOGGER.warn("系统重连成功!");
                }
            }
        }
    }
}
@ClientEndpoint
public class WebSocketClient
{
    /**
     * 日志
     */
    private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketClient.class);

    /**
     * 这里使用静态,让 RedisUtils 属于类
     */
    public static RedisUtils redisUtils;

    /**
     * session
     */
    private Session session;

    /**
     * <beforeInit>
     *
     * @throws
     */
    public static void beforeInit()
    {
        // 在socket配置类中调用此方法可以完成一些需要初始化注入的操作
    }

    @OnOpen
    public void onOpen(Session session)
    {
        LOGGER.info("连接开启...");
        this.session = session;
    }

    @OnMessage
    public void onMessage(String message, Session session)
    {
        redisUtils.set(UUID.randomUUID().toString().replaceAll("-", ""), message);
        LOGGER.info(message);
    }

    @OnClose
    public void onClose()
    {
        LOGGER.info("长连接关闭...");
    }

    @OnError
    public void onError(Session session, Throwable t)
    {
        LOGGER.error("error, cause: ", t);
    }

    /**
     * <异步发送message>
     *
     * @param message message
     * @throws
     */
    public void send(String message)
    {
        this.session.getAsyncRemote().sendText(message);
    }

    /**
     * <发送message>
     *
     * @param message message
     * @throws
     */
    public void sendMessage(String message)
    {
        try
        {
            session.getBasicRemote().sendText(message);
        }
        catch (IOException ex)
        {
            LOGGER.error("error, cause: ", ex);
        }
    }

    /**
     * <关闭连接>
     *
     * @throws
     */
    public void close()
        throws IOException
    {
        if (this.session.isOpen())
        {
            this.session.close();
        }
    }
}
#websocket对应地址
springboot.websocket.uri=ws://localhost:81/ws/
springboot.websocket.ip=localhost:81

三、对应依赖

<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-websocket</artifactId>
		</dependency>

 

 类似资料: