4

前言

最近项目中有个私信功能,需要用到websocket,于是在网上找找资料并在实践中总结了一点经验分享给大家。

问题

在操作之前我先抛一个问题:

SpringBoot项目集成 webSocket,当客户端与服务器端建立连接的时候,发现 server对象并未注入而是为 null。
产生原因:spring管理的都是单例(singleton),和 websocket (多对象)相冲突。
详细解释:项目启动时初始化,会初始化 websocket (非用户连接的),spring 同时会为其注入 service,该对象的 service 不是 null,被成功注入。但是,由于 spring 默认管理的是单例,所以只会注入一次 service。当客户端与服务器端进行连接时,服务器端又会创建一个新的 websocket 对象,这时问题出现了:spring 管理的都是单例,不会给第二个 websocket 对象注入 service,所以导致只要是用户连接创建的 websocket 对象,都不能再注入了。
像 controller 里面有 service, service 里面有 dao。因为 controller,service ,dao 都有是单例,所以注入时不会报 null。但是 websocket 不是单例,所以使用spring注入一次后,后面的对象就不会再注入了,会报NullException。

下面会讲解决办法。

操作

1、引入websocket依赖包

<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-websocket</artifactId>
</dependency>

2、配置websocket

@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry webSocketHandlerRegistry) {
        // webSocket通道
        // 指定处理器和路径,如:http://www.baidu.com/service-name/websocket?uid=xxxx
        webSocketHandlerRegistry.addHandler(new WebSocketHandler(), "/websocket")
//                // 指定自定义拦截器
                .addInterceptors(new WebSocketInterceptor())
                // 允许跨域
                .setAllowedOrigins("*");
    }
}

3、添加获取websocket地址中的参数类

public class WebSocketInterceptor implements HandshakeInterceptor {

    /**
     * handler处理前调用,attributes属性最终在WebSocketSession里,可能通过webSocketSession.getAttributes().get(key值)获得
     */
    @Override
    public boolean beforeHandshake(org.springframework.http.server.ServerHttpRequest request, ServerHttpResponse serverHttpResponse, org.springframework.web.socket.WebSocketHandler webSocketHandler, Map<String, Object> map) throws Exception {
        if (request instanceof ServletServerHttpRequest) {
            ServletServerHttpRequest serverHttpRequest = (ServletServerHttpRequest) request;
            // 获取请求路径携带的参数
            String uid = serverHttpRequest.getServletRequest().getParameter("uid");
            map.put("uid", uid);
            return true;
        } else {
            return false;
        }
    }

    @Override
    public void afterHandshake(org.springframework.http.server.ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse, org.springframework.web.socket.WebSocketHandler webSocketHandler, Exception e) {

    }
}

4、添加解决server对象@Autowired注入为null的类

@Component
public class SpringContext implements ApplicationContextAware {

    /**
     * 打印日志
     */
    private Logger logger = LoggerFactory.getLogger(getClass());

    /**
     * 获取上下文对象
     */
    private static ApplicationContext applicationContext;


    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        SpringContext.applicationContext = applicationContext;
        logger.info("set applicationContext");
    }

    /**
     * 获取 applicationContext
     *
     * @return
     */
    public static ApplicationContext getApplicationContext() {
        return applicationContext;
    }

    /**
     * 通过 name 获取 bean 对象
     *
     * @param name
     * @return
     */
    public static Object getBean(String name) {

        return getApplicationContext().getBean(name);
    }

    /**
     * 通过 class 获取 bean 对象
     *
     * @param clazz
     * @param <T>
     * @return
     */
    public static <T> T getBean(Class<T> clazz) {
        return getApplicationContext().getBean(clazz);
    }

    /**
     * 通过 name,clazz  获取指定的 bean 对象
     *
     * @param name
     * @param clazz
     * @param <T>
     * @return
     */
    public static <T> T getBean(String name, Class<T> clazz) {
        return getApplicationContext().getBean(name, clazz);
    }

}

5、添加websocket接收发送消息类

@Component
public class WebSocketHandler extends AbstractWebSocketHandler {
    private static Logger log = LoggerFactory.getLogger(WebSocketHandler.class);

    public AccountFeignClient getAccountFeignClient() {
        return SpringContext.getBean(AccountFeignClient.class);
    }

    public NotifyMailboxService getNotifyMailboxService() {
        return SpringContext.getBean(NotifyMailboxService.class);
    }

    public NotifyMailboxMessageService getNotifyMailboxMessageService() {
        return SpringContext.getBean(NotifyMailboxMessageService.class);
    }

    /**
     * 存储sessionId和webSocketSession
     * 需要注意的是,webSocketSession没有提供无参构造,不能进行序列化,也就不能通过redis存储
     * 在分布式系统中,要想别的办法实现webSocketSession共享
     */
    private static Map<String, WebSocketSession> sessionMap = new ConcurrentHashMap<>();
    private static Map<String, String> userMap = new ConcurrentHashMap<>();

    /**
     * webSocket连接创建后调用
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) {
        // 获取参数
        String uid = String.valueOf(session.getAttributes().get("uid"));
        String sessionId = session.getId();
        log.info("init websocket uid={},sessionId={}", uid, sessionId);
        userMap.put(uid, sessionId);
        sessionMap.put(sessionId, session);
    }

    /**
     * 前端发送消息到后台
     * 接收到消息会调用
     */
    @Override
    public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
        //A用户发送前端消息到后台,后台要保存A消息,并且向B用户推送消息
        if (message instanceof TextMessage) {
            log.info("message={}", message);
        } else if (message instanceof BinaryMessage) {

        } else if (message instanceof PongMessage) {

        } else {
            log.info("Unexpected WebSocket message type: " + message);
        }
        String uid = String.valueOf(session.getAttributes().get("uid"));
        String messages = (String) message.getPayload();
        ObjectMapper mapper = new ObjectMapper();
        HashMap<String, Object> map = mapper.readValue(messages, HashMap.class);
        String _uid = (String) map.get("uid");
//        String _dialogId = (String) map.get("dialogId");
        String _friendId = (String) map.get("friendId");
        String _message = (String) map.get("message");

        String sessionId = session.getId();
        log.info("sessionId={},uid={},_uid={},_friendId={},_message={}", sessionId, uid, _uid, _friendId, _message);

        if (!StringUtils.hasLength(sessionId) || !StringUtils.hasLength(_uid) || !StringUtils.hasLength(_friendId)) {
            log.info("sessionId&_uid&_friendId不能为空");
            session.sendMessage(new TextMessage("error:sessionId&_uid&_friendId不能为空"));
            return;
        }
        String dialogId = pushMessage(_uid, _friendId, _message);
        if (dialogId != null) {
            TextMessage textMessage = new TextMessage("dialogId:" + dialogId);
            // 向自己的ws推送消息
            session.sendMessage(textMessage);
            String sessionIdForFriend = userMap.get(_friendId);
            log.info("sessionIdForFriend={}", sessionIdForFriend);
            if (StringUtils.hasLength(sessionIdForFriend)) {
                WebSocketSession friendSession = sessionMap.get(sessionIdForFriend);
                if (friendSession != null && friendSession.isOpen())
                    // 向朋友推送消息
                    friendSession.sendMessage(textMessage);
            }
        }
    }

    /**
     * 连接出错会调用
     */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) {
        String uid = String.valueOf(session.getAttributes().get("uid"));
        String sessionId = session.getId();
        log.info("CLOSED uid= ={},sessionId={}", uid, sessionId);
        sessionMap.remove(sessionId);
    }

    /**
     * 连接关闭会调用
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
        String uid = String.valueOf(session.getAttributes().get("uid"));
        String sessionId = session.getId();
        log.info("CLOSED uid= ={},sessionId={}", uid, sessionId);
        sessionMap.remove(sessionId);
    }

    @Override
    public boolean supportsPartialMessages() {
        return false;
    }

    /**
     * 后台发送消息到前端
     * 封装方法发送消息到客户端
     */
    public static void sendMessage(String uid, String dialogId) {
        log.info("发送消息到:");
    }

    /**
     * @param uid
     * @param friendId
     * @param message
     * @return
     */
    private String pushMessage(String uid, String friendId, String message) {
        log.info("uid={},friendId={},message={}", uid, friendId, message);
        NotifyMailboxService notifyMailboxService = getNotifyMailboxService();
        NotifyMailboxMessageService notifyMailboxMessageService = getNotifyMailboxMessageService();
        try {
            NotifyMailbox notifyMailbox = notifyMailboxService.queryBy(uid, friendId);
            
        } catch (Exception e) {
            log.info("exception msg={}", e.getMessage());
            return null;
        }

    }
}

websocket前端状态码readyState

0        CONNECTING        连接尚未建立
1        OPEN            WebSocket的链接已经建立
2        CLOSING            连接正在关闭
3        CLOSED            连接已经关闭或不可用

image.png
image.png

总结

1、server对象并未注入而是为 null,所以要通过添加上面的SpringContext类,并通过下面这种方式引用

public AccountFeignClient getAccountFeignClient() {
    return SpringContext.getBean(AccountFeignClient.class);
}

public NotifyMailboxService getNotifyMailboxService() {
    return SpringContext.getBean(NotifyMailboxService.class);
}

public NotifyMailboxMessageService getNotifyMailboxMessageService() {
    return SpringContext.getBean(NotifyMailboxMessageService.class);
}

2、websocket的连接和关闭的session对应问题,使用上面代码就没问题,否则会出现连接上的问题。

3、WebSocketInterceptor会获取https://www.baidu.com/service-name/websocket?uid=xxx中的uid并注入到session中,因此WebSocketHandler类才能获取到session中的uid参数。

引用

WebSocket 教程
webSocket 中使用 @Autowired 注入对应为null


Awbeci
3.1k 声望215 粉丝

Awbeci