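After working through a WebSocket course:

  • WebSocket入门与案例实战 (WebSocket basics and hands-on cases)

I started wondering how a WebSocketSession could be shared between different JVMs in a multi-instance deployment. After reading a few articles and asking an AI assistant, I put together the implementation below.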

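1. Directory structure

This test was done on top of an existing project of mine, so the code may contain unrelated business logic. The structure below lists only the WebSocket-related files.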

D:.
└─site
    └─lazyking
        ├─filebox
           │  FileBoxApplication.java // application entry point
           │
           ├─config
           │      WebSocketConfig.java // WebSocket configuration
           │
           ├─controller
           │      WsController.java // handles events after the WebSocket connection is set up
           │
           ├─filter
           │      WebSocketInterceptor.java // runs before and after the WebSocket handshake
           │
           ├─listener
           │      UserMessageListener.java // subscribes to redis messages
           │
           ├─service
           │  │  WsMessageService.java
           │  │
           │  └─impl
           │          WsMessageServiceImpl.java // wraps the message-sending logic
           │
           └─util
                   WebSocketManager.java // stores and manages WebSocketSession objects

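2. Implementation

Approach:

  1. Establishing the WebSocket connection

    • A client request is first intercepted by WebSocketInterceptor. Its beforeHandshake method authenticates the user and puts the user id into the session attributes.
    • Once the connection is established, afterConnectionEstablished in WsController stores the session id in redis (keyed by user id, with a 30-minute expiry) and keeps the WebSocketSession itself in the local WebSocketManager.
    • The redis entry exists only to tell whether a user is online: if the key exists, the message is published to a redis channel; otherwise the message is stored in redis.
    • WsController also implements
      • handleTextMessage: handles messages sent by the client
      • handleTransportError: handles errors raised on the connection
      • afterConnectionClosed: cleans up after the connection is closed
  2. Sending messages

    • WsMessageService wraps server-to-client messaging and exposes a sendMessage method.
    • sendMessage first checks whether the recipient is online. If so, it publishes the message to redis; otherwise it saves the message in redis.
    • UserMessageListener handles the subscribed messages. Every instance receives the broadcast and checks whether its own WebSocketManager holds the recipient's session; the instance that does sends the message.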
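The flow above can be modelled without Spring or redis. The following is a minimal in-memory sketch, not the project's code: FakeRedis, Instance and RoutingDemo are hypothetical stand-ins for redis, one application instance and a test driver, and a "session" is reduced to the list of messages delivered to it.

```java
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.BiConsumer;

// In-memory stand-in for redis: an "online" key set, a pub/sub bus and an offline store.
class FakeRedis {
    final Set<String> onlineUsers = ConcurrentHashMap.newKeySet();
    final Map<String, List<String>> offlineMessages = new ConcurrentHashMap<>();
    private final List<BiConsumer<String, String>> subscribers = new CopyOnWriteArrayList<>();

    void subscribe(BiConsumer<String, String> listener) {
        subscribers.add(listener);
    }

    // As with redis pub/sub, every subscriber receives every published message.
    void publish(String userId, String message) {
        subscribers.forEach(s -> s.accept(userId, message));
    }
}

// One application instance (one JVM) with its own local session map.
class Instance {
    // userId -> messages delivered over that user's simulated session
    final Map<String, List<String>> localSessions = new ConcurrentHashMap<>();
    private final FakeRedis redis;

    Instance(FakeRedis redis) {
        this.redis = redis;
        // Deliver only if this instance holds the recipient's session.
        redis.subscribe((userId, message) -> {
            List<String> session = localSessions.get(userId);
            if (session != null) {
                session.add(message);
            }
        });
    }

    void connect(String userId) {
        localSessions.put(userId, new CopyOnWriteArrayList<>());
        redis.onlineUsers.add(userId);
    }

    // Mirrors sendMessage: publish if the recipient is online, otherwise store.
    void sendMessage(String recipientId, String message) {
        if (redis.onlineUsers.contains(recipientId)) {
            redis.publish(recipientId, message);
        } else {
            redis.offlineMessages
                    .computeIfAbsent(recipientId, k -> new CopyOnWriteArrayList<>())
                    .add(message);
        }
    }
}

class RoutingDemo {
    public static void main(String[] args) {
        FakeRedis redis = new FakeRedis();
        Instance a = new Instance(redis);
        Instance b = new Instance(redis);
        b.connect("42");               // user 42 is connected to instance B
        a.sendMessage("42", "hello");  // the request lands on instance A
        a.sendMessage("7", "later");   // user 7 is offline
        System.out.println(b.localSessions.get("42"));      // prints [hello]
        System.out.println(redis.offlineMessages.get("7")); // prints [later]
    }
}
```

The point of the model is that the sending instance never needs to know which instance holds the session: it only consults the shared online marker and publishes.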

3. Code

3.1 WebSocketConfig

@Configuration  
@EnableWebSocket  
@RequiredArgsConstructor  
public class WebSocketConfig implements WebSocketConfigurer {  
  
    private final WsController wsController;  
    private final WebSocketInterceptor webSocketInterceptor;  
  
    @Override  
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {  
        registry.addHandler(wsController, "/message")  
                .addInterceptors(webSocketInterceptor)  
                .setAllowedOrigins("*");  
    }  
    
    // For brevity there is no separate redis configuration class, so the subscription is configured here
    @Bean  
    public RedisMessageListenerContainer container(  
            RedisConnectionFactory connectionFactory,  
            UserMessageListener listener) {  
  
        RedisMessageListenerContainer container = new RedisMessageListenerContainer();  
        container.setConnectionFactory(connectionFactory);  
        container.addMessageListener(listener, new PatternTopic("ws:push:*"));  
        return container;  
    }  
}	

3.2 WsController

@Slf4j  
@Component  
@RequiredArgsConstructor  
public class WsController extends AbstractWebSocketHandler {  
  
    private final StringRedisTemplate redisTemplate;  
  
    @Override  
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {  
        super.afterConnectionEstablished(session);  
        log.info("Entered afterConnectionEstablished");
        // Get the userId that the handshake interceptor put into the attributes
        String userId = (String) session.getAttributes().get("userId");
        // Store the sessionId in redis, keyed by userId
        redisTemplate.opsForValue().set("ws:user:" + userId, session.getId(), Duration.ofMinutes(30));
        // Keep the session itself in this instance's local map
        WebSocketManager.addSession(userId, session);

        log.info("WebSocket connection established, userId={}, sessionId={}, IP={}", userId, session.getId(),
                session.getRemoteAddress());  
    }  
  
    @Override  
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {  
        super.afterConnectionClosed(session, status);  
        // The connection is closed, so clean up this user's state
        String userId = (String) session.getAttributes().get("userId");
        if (userId != null) {
            // Remove the online marker from redis
            redisTemplate.delete("ws:user:" + userId);
            // Remove the locally stored session
            WebSocketManager.removeSession(userId);

            log.info("WebSocket connection closed, userId={}, sessionId={}", userId, session.getId());
        }
        log.info("Connection closed, code={}", status.getCode());
    }  
  
    @Override  
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {  
        super.handleTextMessage(session, message);  
        log.info("{}: {}", session.getId(), message.getPayload());  
    }  
  
    @Override  
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {  
        String userId = (String) session.getAttributes().get("userId");  
        log.error("WebSocket transport error, userId={}, sessionId={}, error={}", userId, session.getId(), exception.getMessage(), exception);
  
        if (userId != null) {  
            redisTemplate.delete("ws:user:" + userId);  
            WebSocketManager.removeSession(userId);  
        }  
  
        session.close(CloseStatus.SERVER_ERROR);  
    }  
}
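One thing to keep in mind with the handler above: the ws:user:{userId} key is written once with a 30-minute expiry, so a connection that stays open longer than that would start to look offline to sendMessage unless the key is refreshed. The sketch below is only an in-memory model of that expiry, not redis code; PresenceModel and touch are hypothetical names. In the real handler the refresh would presumably be something like a redisTemplate.expire(...) call from handleTextMessage or a periodic heartbeat, which is an assumption and not something the project currently does.

```java
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// In-memory model of the "ws:user:{userId}" keys and their expiry; not real redis code.
class PresenceModel {
    private final Map<String, Instant> expiresAt = new ConcurrentHashMap<>();
    private final Duration ttl;

    PresenceModel(Duration ttl) {
        this.ttl = ttl;
    }

    // Called on connect and on every heartbeat: (re)starts the expiry window.
    void touch(String userId, Instant now) {
        expiresAt.put(userId, now.plus(ttl));
    }

    // Called on disconnect.
    void remove(String userId) {
        expiresAt.remove(userId);
    }

    // A user counts as online only while the key has not expired.
    boolean isOnline(String userId, Instant now) {
        Instant deadline = expiresAt.get(userId);
        return deadline != null && now.isBefore(deadline);
    }
}
```

With a 30-minute TTL, a user touched only at connect time is reported offline at minute 31, while a user touched again at minute 20 is still online at minute 31.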

3.3 WebSocketInterceptor

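import java.util.Arrays;
import java.util.List;
import java.util.Map;

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;

@Slf4j
@Component
@RequiredArgsConstructor
public class WebSocketInterceptor extends HttpSessionHandshakeInterceptor {

    private static final String TOKEN_COOKIE_PREFIX = "Authorization=";

    // JWTUtil is the project's own JWT helper
    private final JWTUtil jwtUtil;

    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {

        log.info("Entered beforeHandshake");

        // Read the Cookie header from the handshake request
        List<String> cookies = request.getHeaders().get("Cookie");
        if (cookies == null || cookies.isEmpty()) {
            log.info("Cookie header is missing");
            return false;
        }
        // Take everything after "Authorization=" so that a missing value or a '=' inside the value cannot break the parsing
        String token = Arrays.stream(cookies.get(0).split(";"))
                .map(String::trim)
                .filter(cookie -> cookie.startsWith(TOKEN_COOKIE_PREFIX))
                .map(cookie -> cookie.substring(TOKEN_COOKIE_PREFIX.length()))
                .findFirst()
                .orElse(null);
        if (token == null || token.isEmpty()) {
            log.info("Token not found");
            return false;
        }

        // Check whether the token is valid (the token itself is not logged)
        if (!jwtUtil.verifyToken(token)) {
            log.info("Token is invalid");
            return false;
        }

        // The token is valid: read the user id from it
        String userId = jwtUtil.getUserId(token);

        // Put the user id into the session attributes
        attributes.put("userId", userId);

        log.info("userId: {}", userId);

        // Let the handshake proceed
        return super.beforeHandshake(request, response, wsHandler, attributes);
    }

    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception ex) {
        super.afterHandshake(request, response, wsHandler, ex);
    }
}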
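The cookie parsing in beforeHandshake is the part most sensitive to input format, so it may be worth isolating into a small helper that can be tested without a real handshake. A minimal sketch, assuming a hypothetical CookieTokenExtractor class that is not part of the original project:

```java
import java.util.Arrays;
import java.util.Optional;

class CookieTokenExtractor {

    // Returns the value of the named cookie from a raw Cookie header, if present and non-empty.
    static Optional<String> extract(String cookieHeader, String cookieName) {
        if (cookieHeader == null || cookieHeader.isEmpty()) {
            return Optional.empty();
        }
        String prefix = cookieName + "=";
        return Arrays.stream(cookieHeader.split(";"))
                .map(String::trim)
                // match the full name, so "AuthorizationHint=..." is not picked up
                .filter(pair -> pair.startsWith(prefix))
                // keep everything after the first '=', including any later '='
                .map(pair -> pair.substring(prefix.length()))
                .filter(value -> !value.isEmpty())
                .findFirst();
    }

    public static void main(String[] args) {
        System.out.println(extract("theme=dark; Authorization=abc.def.ghi", "Authorization"));
        System.out.println(extract("AuthorizationHint=x", "Authorization"));
    }
}
```

Matching on the name plus "=" and cutting at the prefix length avoids two edge cases of splitting on "=": a cookie with no value, and a value that itself contains "=".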

3.4 UserMessageListener

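import java.io.IOException;
import java.nio.charset.StandardCharsets;

import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.connection.Message;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;

@Slf4j
@Component
public class UserMessageListener implements MessageListener {

    private static final String CHANNEL_PREFIX = "ws:push:";

    @Override
    public void onMessage(Message message, byte[] pattern) {
        // Decode explicitly as UTF-8 instead of relying on the platform default charset
        String channel = new String(message.getChannel(), StandardCharsets.UTF_8);
        String userId = channel.substring(CHANNEL_PREFIX.length());
        String content = new String(message.getBody(), StandardCharsets.UTF_8);

        // Only the instance that holds this user's session actually sends the message
        WebSocketSession session = WebSocketManager.getSession(userId);
        if (session != null && session.isOpen()) {
            try {
                session.sendMessage(new TextMessage(content));
            } catch (IOException e) {
                log.error("Failed to push message, userId={}", userId, e);
            }
        }
    }
}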

3.5 WsMessageService

public interface WsMessageService {  
    void sendMessage(String recipientId, String message);  
}

3.6 WsMessageServiceImpl

@Service  
@RequiredArgsConstructor  
public class WsMessageServiceImpl implements WsMessageService {  
  
    private final StringRedisTemplate redisTemplate;  
  
    @Override  
    public void sendMessage(String recipientId, String message) {  
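        // Check whether the recipient is online
        String status = redisTemplate.opsForValue().get("ws:user:" + recipientId);
        if (status != null) {
            // Recipient is online: publish the message
            redisTemplate.convertAndSend("ws:push:" + recipientId, message);
        } else {
            // Recipient is offline: store the message in redis
            // UUID strings use '-' as the separator; strip it to get a compact hash key
            String hashKey = UUID.randomUUID().toString().replace("-", "");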
            redisTemplate.opsForHash().put("message:user:" + recipientId, hashKey, message);  
        }  
    }  
}

3.7 WebSocketManager

public class WebSocketManager {  
    private static final ConcurrentHashMap<String, WebSocketSession> sessionMap  
            = new ConcurrentHashMap<>();  
  
    public static void addSession(String userId, WebSocketSession session) {  
        sessionMap.put(userId, session);  
    }  
  
    public static WebSocketSession getSession(String userId) {  
        return sessionMap.get(userId);  
    }  
  
    public static void removeSession(String userId) {  
        sessionMap.remove(userId);  
    }  
}
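A possible refinement, not part of the original project: removeSession(userId) removes whatever session is currently stored for the user. If a user reconnects before the old connection's afterConnectionClosed fires, the close callback of the old connection would evict the new session. A minimal sketch of a compare-and-remove variant, using a generic SessionRegistry<S> as a hypothetical stand-in for WebSocketManager (S plays the role of WebSocketSession):

```java
import java.util.concurrent.ConcurrentHashMap;

// Generic stand-in for WebSocketManager; S plays the role of WebSocketSession.
class SessionRegistry<S> {
    private final ConcurrentHashMap<String, S> sessions = new ConcurrentHashMap<>();

    // Registers the session and returns the one it replaced, or null if there was none.
    S add(String userId, S session) {
        return sessions.put(userId, session);
    }

    S get(String userId) {
        return sessions.get(userId);
    }

    // Removes the entry only if it still maps to this exact session.
    boolean remove(String userId, S session) {
        return sessions.remove(userId, session);
    }
}
```

Usage would look like registry.remove(userId, session) inside the close callback, so that a stale connection can only remove itself. The session returned by add could also be closed explicitly if only one connection per user is wanted.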

4. Testing

To test this, I wrote an API endpoint that calls WsMessageService.sendMessage().

MessageController

@RestController  
@RequiredArgsConstructor  
@RequestMapping("/send-message")  
public class MessageController {  
  
    private final WsMessageService wsMessageService;  
  
    @GetMapping("/send")  
    public R<Object> send(@RequestParam("userId") String userId,  
                          @RequestParam("message") String message) {  
        wsMessageService.sendMessage(userId, message);  
        return R.success();  
    }  
  
}

For the multi-instance test, I started two instances on ports 8081 and 8082 and put nginx in front of them as a load balancer. The nginx configuration is:

nginx.conf

# nginx.conf

worker_processes 1;

events {
    worker_connections 1024;
}

http {
    include       mime.types;
    default_type  application/octet-stream;

    sendfile        on;
    keepalive_timeout  65;

    # Upstream group used for load balancing
    upstream springboot_cluster {
        server 127.0.0.1:8081;
        server 127.0.0.1:8082;
    }

    server {
        listen 8080;
        server_name localhost;

        location / {
            proxy_pass http://springboot_cluster;
            proxy_http_version 1.1;
            proxy_set_header Upgrade $http_upgrade;
            proxy_set_header Connection "upgrade";
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
            proxy_set_header X-Forwarded-Proto $scheme;
        }

        error_page 500 502 503 504 /50x.html;
        location = /50x.html {
            root html;
        }
    }
}