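These are my own ideas and my own implementation. If anything here is wrong, or you know a simpler way to do it, feel free to message me. (The focus is authentication during the WebSocket handshake.)
- Handshake authentication is based on the Sec-WebSocket-Protocol request header sent by the front end.
- The WebSocket API itself offers no way to set custom request headers; the only thing you can customize is the subprotocol carried in Sec-WebSocket-Protocol.
The WebSocket handshake request is plain HTTP; once the handshake succeeds, the connection is upgraded to ws.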
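To make the "HTTP first, then upgrade" step concrete, here is a minimal sketch of the one computation every server performs during that handshake: deriving Sec-WebSocket-Accept from the client's Sec-WebSocket-Key as defined in RFC 6455. The class and method names are made up for illustration and are not part of Netty; only the JDK is used.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;

// Hypothetical helper for illustration; not a Netty class.
final class WebSocketAcceptDemo {

    // Fixed GUID that RFC 6455 appends to the client's key.
    private static final String WS_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    // Sec-WebSocket-Accept = base64( SHA-1( Sec-WebSocket-Key + GUID ) )
    static String acceptFor(String secWebSocketKey) {
        try {
            MessageDigest sha1 = MessageDigest.getInstance("SHA-1");
            byte[] digest = sha1.digest(
                    (secWebSocketKey + WS_GUID).getBytes(StandardCharsets.US_ASCII));
            return Base64.getEncoder().encodeToString(digest);
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("SHA-1 is not available", e);
        }
    }

    public static void main(String[] args) {
        // Sample key taken from RFC 6455, section 1.3.
        System.out.println(acceptFor("dGhlIHNhbXBsZSBub25jZQ=="));
        // prints s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
    }
}
```

The server sends this value back in a 101 Switching Protocols response; after that, both sides stop speaking HTTP and exchange WebSocket frames.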
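The front end sent the token as the value of Sec-WebSocket-Protocol, but the connection was always dropped once the back end received it. The blog posts I found online all said more or less the same thing, so I stepped through the Netty source myself (stubborn, haha) and finally found the cause: the Sec-WebSocket-Protocol value in the server's response did not match the one the front end sent, so the connection was closed. As far as I remember, WebSocket normally does not require you to set this header yourself, but from reading the Netty source it looks like the default handlers leave no hook for setting response headers on the WebSocket handshake (this is only my personal understanding).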
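The mismatch described above is easier to see with a small model of server-side subprotocol selection. This is a simplified sketch, not Netty's actual code: the class name, the method name and the sample values ("chat", "my-token") are assumptions for illustration. The idea is that the server only echoes a subprotocol it was configured with, so a token the server does not know in advance is never selected.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

// Hypothetical model for illustration; not a Netty class.
final class SubprotocolSelectionDemo {

    // offeredHeader: raw Sec-WebSocket-Protocol request value (comma-separated).
    // supported: subprotocols the server was configured with.
    // Returns the first offered value the server supports, if any.
    static Optional<String> select(String offeredHeader, List<String> supported) {
        if (offeredHeader == null || supported.isEmpty()) {
            return Optional.empty();
        }
        for (String raw : offeredHeader.split(",")) {
            String offered = raw.trim();
            if (!offered.isEmpty() && supported.contains(offered)) {
                return Optional.of(offered);
            }
        }
        return Optional.empty();
    }

    public static void main(String[] args) {
        List<String> configured = Arrays.asList("chat");
        // The token is not a configured subprotocol, so nothing is selected.
        System.out.println(select("my-token", configured).isPresent());   // false
        // A known subprotocol is selected even when a token is also offered.
        System.out.println(select("chat, my-token", configured).get());   // chat
    }
}
```

When nothing is selected, the response carries no Sec-WebSocket-Protocol value that matches the request, which is the situation the rest of this post works around by adding the header to the handshake response explicitly.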
Explanation: write a custom replacement for WebSocketProtocolHandler by copying its contents as they are. It is needed mainly because the custom WebSocketServerProtocolHandler below depends on it.
abstract class CustomWebSocketProtocolHandler extends MessageToMessageDecoder<WebSocketFrame> {
    @Override
    protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List
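Explanation: write a custom WebSocketServerProtocolHandler that extends the custom WebSocketProtocolHandler above. Keep its contents identical to WebSocketServerProtocolHandler; the only change is to replace the ProtocolHandler class referenced in handlerAdded with your own.
Note: the custom business handler that reads and writes later in the pipeline must implement the relevant methods itself (exception handling and event listening). Take exceptions as an example: if one is thrown, no other handler will deal with it, because the business handler is now the last one in the pipeline. The default implementation has been replaced with our own, and the other handlers are all built around the default handler, so once you initialize your modified handler instead, yours is the last layer. You therefore have to close the channel manually.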
public class CustomWebSocketServerProtocolHandler extends CustomWebSocketProtocolHandler {

    /**
     * Events that are fired to notify about handshake status
     */
    public enum ServerHandshakeStateEvent {
        /**
         * The Handshake was completed successfully and the channel was upgraded to websockets.
         *
         * @deprecated in favor of {@link WebSocketServerProtocolHandler.HandshakeComplete} class,
         * it provides extra information about the handshake
         */
        @Deprecated
        HANDSHAKE_COMPLETE
    }

    /**
     * The Handshake was completed successfully and the channel was upgraded to websockets.
     */
    public static final class HandshakeComplete {
        private final String requestUri;
        private final HttpHeaders requestHeaders;
        private final String selectedSubprotocol;

        public HandshakeComplete(String requestUri, HttpHeaders requestHeaders, String selectedSubprotocol) {
            this.requestUri = requestUri;
            this.requestHeaders = requestHeaders;
            this.selectedSubprotocol = selectedSubprotocol;
        }

        public String requestUri() {
            return requestUri;
        }

        public HttpHeaders requestHeaders() {
            return requestHeaders;
        }

        public String selectedSubprotocol() {
            return selectedSubprotocol;
        }
    }

    private static final AttributeKey<WebSocketServerHandshaker> HANDSHAKER_ATTR_KEY =
            AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER");

    private final String websocketPath;
    private final String subprotocols;
    private final boolean allowExtensions;
    private final int maxFramePayloadLength;
    private final boolean allowMaskMismatch;
    private final boolean checkStartsWith;

    public CustomWebSocketServerProtocolHandler(String websocketPath) {
        this(websocketPath, null, false);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith) {
        this(websocketPath, null, false, 65536, false, checkStartsWith);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols) {
        this(websocketPath, subprotocols, false);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols,
                                                boolean allowExtensions) {
        this(websocketPath, subprotocols, allowExtensions, 65536);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols,
                                                boolean allowExtensions, int maxFrameSize) {
        this(websocketPath, subprotocols, allowExtensions, maxFrameSize, false);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols,
                                                boolean allowExtensions, int maxFrameSize,
                                                boolean allowMaskMismatch) {
        this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols,
                                                boolean allowExtensions, int maxFrameSize,
                                                boolean allowMaskMismatch, boolean checkStartsWith) {
        this.websocketPath = websocketPath;
        this.subprotocols = subprotocols;
        this.allowExtensions = allowExtensions;
        maxFramePayloadLength = maxFrameSize;
        this.allowMaskMismatch = allowMaskMismatch;
        this.checkStartsWith = checkStartsWith;
    }

    @Override
    public void handlerAdded(ChannelHandlerContext ctx) {
        ChannelPipeline cp = ctx.pipeline();
        if (cp.get(CustomWebSocketServerProtocolHandler.class) == null) {
            // Add the WebSocketHandshakeHandler before this one.
            ctx.pipeline().addBefore(ctx.name(), CustomWebSocketServerProtocolHandler.class.getName(),
                    new CustomWebSocketServerProtocolHandler(websocketPath, subprotocols,
                            allowExtensions, maxFramePayloadLength, allowMaskMismatch, checkStartsWith));
        }
        if (cp.get(Utf8FrameValidator.class) == null) {
            // Add the UTF-8 checking before this one.
            ctx.pipeline().addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
                    new Utf8FrameValidator());
        }
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List
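Replace the default handler WebSocketServerProtocolHandshakeHandler with SecurityServerHandler, a custom inbound handler.
This is the most important step, because it is where the header gets set; the previous two steps only prepare for it. The relevant Netty classes cannot be referenced from outside their package and offer no method for changing this behavior, which is why the custom classes above are needed. In this class, adjust the handshake logic to add the handshake response header, and change the references to WebSocketServerProtocolHandler into CustomWebSocketServerProtocolHandler. Do the same in any other implementation class.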
public class SecurityServerHandler extends ChannelInboundHandlerAdapter {

    private final String websocketPath;
    private final String subprotocols;
    private final boolean allowExtensions;
    private final int maxFramePayloadSize;
    private final boolean allowMaskMismatch;
    private final boolean checkStartsWith;

    /**
     * Custom property: name of the header that carries the token
     */
    private final String tokenHeader;

    /**
     * Custom property: whether token authentication is enabled
     */
    private final boolean hasToken;

    public SecurityServerHandler(String websocketPath, String subprotocols,
                                 boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch,
                                 String tokenHeader, boolean hasToken) {
        this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false,
                tokenHeader, hasToken);
    }

    SecurityServerHandler(String websocketPath, String subprotocols,
                          boolean allowExtensions, int maxFrameSize,
                          boolean allowMaskMismatch,
                          boolean checkStartsWith,
                          String tokenHeader,
                          boolean hasToken) {
        this.websocketPath = websocketPath;
        this.subprotocols = subprotocols;
        this.allowExtensions = allowExtensions;
        maxFramePayloadSize = maxFrameSize;
        this.allowMaskMismatch = allowMaskMismatch;
        this.checkStartsWith = checkStartsWith;
        this.tokenHeader = tokenHeader;
        this.hasToken = hasToken;
    }

    @Override
    public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
        final FullHttpRequest req = (FullHttpRequest) msg;
        if (isNotWebSocketPath(req)) {
            ctx.fireChannelRead(msg);
            return;
        }
        try {
            // The actual authentication logic
            HttpHeaders headers = req.headers();
            String token = Objects.requireNonNull(headers.get(tokenHeader));
            if (hasToken) {
                // Authentication enabled
                // extracts device information headers
                LoginUser loginUser = SecurityUtils.getLoginUser(token);
                if (null == loginUser) {
                    refuseChannel(ctx);
                    return;
                }
                Long userId = loginUser.getUserId();
                // check ......
                SecurityCheckComplete complete =
                        new SecurityCheckComplete(String.valueOf(userId), tokenHeader, hasToken);
                ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).set(complete);
                ctx.fireUserEventTriggered(complete);
            } else {
                // Authentication disabled
                SecurityCheckComplete complete = new SecurityCheckComplete(null, tokenHeader, hasToken);
                ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).set(complete);
            }
            if (req.method() != GET) {
                sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
                return;
            }
            // The last argument is allowMaskMismatch (the original listing passed allowExtensions twice).
            final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                    getWebSocketLocation(ctx.pipeline(), req, websocketPath), subprotocols,
                    allowExtensions, maxFramePayloadSize, allowMaskMismatch);
            final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
            if (handshaker == null) {
                WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
            } else {
                // Put the header into the HTTP headers here. They are passed down to the Netty method
                // that sets the response headers; the default implementation passes null.
                HttpHeaders httpHeaders = new DefaultHttpHeaders().add(tokenHeader, token);
                // This is the key step: the handshake response headers are built here.
                final ChannelFuture handshakeFuture =
                        handshaker.handshake(ctx.channel(), req, httpHeaders, ctx.channel().newPromise());
                handshakeFuture.addListener((ChannelFutureListener) future -> {
                    if (!future.isSuccess()) {
                        ctx.fireExceptionCaught(future.cause());
                    } else {
                        // Kept for compatibility
                        ctx.fireUserEventTriggered(
                                CustomWebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
                        ctx.fireUserEventTriggered(
                                new CustomWebSocketServerProtocolHandler.HandshakeComplete(
                                        req.uri(), req.headers(), handshaker.selectedSubprotocol()));
                    }
                });
                CustomWebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
                ctx.pipeline().replace(this, "WS403Responder",
                        CustomWebSocketServerProtocolHandler.forbiddenHttpRequestResponder());
            }
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            req.release();
        }
    }

    public static final class HandshakeComplete {
        private final String requestUri;
        private final HttpHeaders requestHeaders;
        private final String selectedSubprotocol;

        HandshakeComplete(String requestUri, HttpHeaders requestHeaders, String selectedSubprotocol) {
            this.requestUri = requestUri;
            this.requestHeaders = requestHeaders;
            this.selectedSubprotocol = selectedSubprotocol;
        }

        public String requestUri() {
            return requestUri;
        }

        public HttpHeaders requestHeaders() {
            return requestHeaders;
        }

        public String selectedSubprotocol() {
            return selectedSubprotocol;
        }
    }

    private boolean isNotWebSocketPath(FullHttpRequest req) {
        return checkStartsWith ? !req.uri().startsWith(websocketPath) : !req.uri().equals(websocketPath);
    }

    private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
        ChannelFuture f = ctx.channel().writeAndFlush(res);
        if (!isKeepAlive(req) || res.status().code() != 200) {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }

    private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
        String protocol = "ws";
        if (cp.get(SslHandler.class) != null) {
            // SSL in use so use Secure WebSockets
            protocol = "wss";
        }
        String host = req.headers().get(HttpHeaderNames.HOST);
        return protocol + "://" + host + path;
    }

    private void refuseChannel(ChannelHandlerContext ctx) {
        ctx.channel().writeAndFlush(
                new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED));
        ctx.channel().close();
    }

    private static void send100Continue(ChannelHandlerContext ctx, String tokenHeader, String token) {
        DefaultFullHttpResponse response =
                new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE);
        response.headers().set(tokenHeader, token);
        ctx.writeAndFlush(response);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        System.out.println("channel caught an exception and was closed");
        super.exceptionCaught(ctx, cause);
    }

    @Getter
    @AllArgsConstructor
    public static final class SecurityCheckComplete {
        private String userId;
        private String tokenHeader;
        private Boolean hasToken;
    }
}
Any other classes referenced here you need to implement or import yourself; everything else is incidental and needs no changes.
@Override
protected void initChannel(SocketChannel ch) {
    log.info("New connection");
    // Get the project the worker will work on (the pipeline; each pipeline belongs to one channel)
    ChannelPipeline pipeline = ch.pipeline();
    // Add the steps/materials to the worker's project in order (i.e. register the handlers on the pipeline)
    // 1. Heartbeat mechanism
    pipeline.addLast("idle-state",
            new IdleStateHandler(nettyWebSocketProperties.getReaderIdleTime(), 0, 0, TimeUnit.SECONDS));
    // 2. Inbound/outbound handler, mostly used for the heartbeat mechanism
    pipeline.addLast("change-duple", new WsChannelDupleHandler(nettyWebSocketProperties.getReaderIdleTime()));
    // 3. HTTP encoding/decoding
    pipeline.addLast("http-codec", new HttpServerCodec());
    // 4. Logging handler, so the result of each operation is clearly visible
    pipeline.addLast("logging", new LoggingHandler(LogLevel.INFO));
    pipeline.addLast("aggregator", new HttpObjectAggregator(8192));
    // Use our own authorization handler in place of the default handshake handler
    pipeline.addLast("auth", new SecurityServerHandler(
            // These values come from my YAML configuration; substitute your own
            nettyWebSocketProperties.getWebsocketPath(),
            nettyWebSocketProperties.getSubProtocols(),
            nettyWebSocketProperties.getAllowExtensions(),
            nettyWebSocketProperties.getMaxFrameSize(),
            // todo
            false,
            nettyWebSocketProperties.getTokenHeader(),
            nettyWebSocketProperties.getHasToken()));
    pipeline.addLast("http-chunked", new ChunkedWriteHandler());
    // Use our own protocol handler in place of the default protocol handler
    pipeline.addLast("websocket", new CustomWebSocketServerProtocolHandler(
            nettyWebSocketProperties.getWebsocketPath(),
            nettyWebSocketProperties.getSubProtocols(),
            nettyWebSocketProperties.getAllowExtensions(),
            nettyWebSocketProperties.getMaxFrameSize()));
    // Custom business handler
    pipeline.addLast("chat-handler", new ChatHandler());
}
Alternative: parse the custom request header, but do not replace any of the other handlers.
package com.edu.message.handler.security;

import com.edu.common.utils.SecurityUtils;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.FullHttpMessage;
import io.netty.handler.codec.http.HttpHeaders;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;

import java.util.Objects;

import static com.edu.message.handler.attributeKey.AttributeKeyUtils.SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY;

/**
 * @author Administrator
 */
@Slf4j
public class SecurityServerHandler extends ChannelInboundHandlerAdapter {

    private String tokenHeader;
    private Boolean hasToken;

    public SecurityServerHandler(String tokenHeader, Boolean hasToken) {
        this.tokenHeader = tokenHeader;
        this.hasToken = hasToken;
    }

    private SecurityServerHandler() {
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (msg instanceof FullHttpMessage) {
            FullHttpMessage httpMessage = (FullHttpMessage) msg;
            HttpHeaders headers = httpMessage.headers();
            String token = Objects.requireNonNull(headers.get(tokenHeader));
            if (hasToken) {
                // Authentication enabled
                // extracts device information headers
                Long userId = 12345L; // SecurityUtils.getLoginUser(token).getUserId();
                // check ......
                SecurityCheckComplete complete =
                        new SecurityCheckComplete(String.valueOf(userId), tokenHeader, hasToken);
                ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).set(complete);
                ctx.fireUserEventTriggered(complete);
            } else {
                // Authentication disabled
                SecurityCheckComplete complete = new SecurityCheckComplete(null, tokenHeader, hasToken);
                ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).set(complete);
            }
        }
        // other protocols
        super.channelRead(ctx, msg);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        System.out.println("channel caught an exception and was closed");
        super.exceptionCaught(ctx, cause);
    }

    @Getter
    @AllArgsConstructor
    public static final class SecurityCheckComplete {
        private String userId;
        private String tokenHeader;
        private Boolean hasToken;
    }
}
Switch back to the default implementation.
@Override
protected void initChannel(SocketChannel ch) {
    log.info("New connection");
    // Get the project the worker will work on (the pipeline; each pipeline belongs to one channel)
    ChannelPipeline pipeline = ch.pipeline();
    // Add the steps/materials to the worker's project in order (i.e. register the handlers on the pipeline)
    // 1. Heartbeat mechanism
    pipeline.addLast("idle-state",
            new IdleStateHandler(nettyWebSocketProperties.getReaderIdleTime(), 0, 0, TimeUnit.SECONDS));
    // 2. Inbound/outbound handler, mostly used for the heartbeat mechanism
    pipeline.addLast("change-duple", new WsChannelDupleHandler(nettyWebSocketProperties.getReaderIdleTime()));
    // 3. HTTP encoding/decoding
    pipeline.addLast("http-codec", new HttpServerCodec());
    // 4. Logging handler, so the result of each operation is clearly visible
    pipeline.addLast("logging", new LoggingHandler(LogLevel.INFO));
    pipeline.addLast("aggregator", new HttpObjectAggregator(8192));
    pipeline.addLast("auth", new SecurityServerHandler(
            nettyWebSocketProperties.getTokenHeader(),
            nettyWebSocketProperties.getHasToken()));
    pipeline.addLast("http-chunked", new ChunkedWriteHandler());
    // pipeline.addLast("websocket",
    //         new CustomWebSocketServerProtocolHandler(
    //                 nettyWebSocketProperties.getWebsocketPath(),
    //                 nettyWebSocketProperties.getSubProtocols(),
    //                 nettyWebSocketProperties.getAllowExtensions(),
    //                 nettyWebSocketProperties.getMaxFrameSize())
    // );
    pipeline.addLast("websocket", new WebSocketServerProtocolHandler(
            nettyWebSocketProperties.getWebsocketPath(),
            nettyWebSocketProperties.getSubProtocols(),
            nettyWebSocketProperties.getAllowExtensions(),
            nettyWebSocketProperties.getMaxFrameSize()));
    // Custom business handler
    pipeline.addLast("chat-handler", new ChatHandler());
}
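Step one: the request reaches the custom authentication handler (an inbound handler), and its channelRead method runs.
Next comes the event method in the custom business handler.
Execution then reaches the channelRead method of the default protocol handler. Pay attention to the call handshaker.handshake(ctx.channel(), req): this is the method that handles the handshake, so set a breakpoint and step into it.
You can see that handshake passes null for HttpHeaders. This is the core handshake logic, and it offers no hook for supplying response headers.
newHandshakeResponse(req, responseHeaders) builds the response, and you can see that the headers are null.
Control then returns to the event listener method in the custom business handler.
Stepping past this point prints the response headers to the console. Our own response header has not been set; it is still null.
Finally everything returns and the connection is dropped, because the subprotocol header does not match.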
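To illustrate why the client drops the connection at the end of this trace, here is a minimal model of the check a client performs on the handshake response. It is a sketch under stated assumptions, not browser or Netty code. RFC 6455 requires the client to fail the connection when the response names a subprotocol it did not request; treating a missing response header as a failure when the client did request one is an assumption that matches the behavior observed in this post. The class name, method name and sample values are made up.

```java
// Hypothetical model for illustration; not browser or Netty code.
final class ClientSubprotocolCheckDemo {

    // requestedHeader: Sec-WebSocket-Protocol value the client sent (may be null).
    // responseHeader: Sec-WebSocket-Protocol value in the server's 101 response (may be null).
    static boolean clientAccepts(String requestedHeader, String responseHeader) {
        boolean requestedAny = requestedHeader != null && !requestedHeader.trim().isEmpty();
        if (responseHeader == null) {
            // Assumption: a client that asked for a subprotocol rejects a response without one.
            return !requestedAny;
        }
        if (!requestedAny) {
            // The server named a subprotocol the client never offered.
            return false;
        }
        String selected = responseHeader.trim();
        for (String raw : requestedHeader.split(",")) {
            if (raw.trim().equals(selected)) {
                return true;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        // Default handshake: token sent, nothing echoed back, so the connection fails.
        System.out.println(clientAccepts("my-token", null));        // false
        // Custom handshake: the token is echoed in the response headers.
        System.out.println(clientAccepts("my-token", "my-token"));  // true
    }
}
```

The first call corresponds to the default handler path traced above, where the response headers stay null; the second corresponds to the custom SecurityServerHandler path, where the header is added to the handshake response.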