@ -7,19 +7,15 @@ import org.springframework.beans.factory.annotation.Autowired;
@@ -7,19 +7,15 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cloud.gateway.filter.GatewayFilterChain ;
import org.springframework.cloud.gateway.filter.GlobalFilter ;
import org.springframework.core.Ordered ;
import org.springframework.core.io.buffer.DataBufferFactory ;
import org.springframework.data.redis.core.ValueOperations ;
import org.springframework.http.HttpStatus ;
import org.springframework.http.MediaType ;
import org.springframework.http.server.reactive.ServerHttpRequest ;
import org.springframework.http.server.reactive.ServerHttpResponse ;
import org.springframework.stereotype.Component ;
import org.springframework.web.server.ServerWebExchange ;
import com.alibaba.fastjson.JSON ;
import com.alibaba.fastjson.JSONObject ;
import com.ruoyi.common.core.constant.CacheConstants ;
import com.ruoyi.common.core.constant.Constants ;
import com.ruoyi.common.core.domain.R ;
import com.ruoyi.common.core.constant.HttpStatus ;
import com.ruoyi.common.core.utils.SecurityUtils ;
import com.ruoyi.common.core.utils.ServletUtils ;
import com.ruoyi.common.core.utils.StringUtils ;
import com.ruoyi.common.redis.service.RedisService ;
@ -35,7 +31,7 @@ import reactor.core.publisher.Mono;
@@ -35,7 +31,7 @@ import reactor.core.publisher.Mono;
public class AuthFilter implements GlobalFilter , Ordered
{
private static final Logger log = LoggerFactory . getLogger ( AuthFilter . class ) ;
private final static long EXPIRE_TIME = Constants . TOKEN_EXPIRE * 60 ;
// 排除过滤的 uri 地址,nacos自行添加
@ -44,61 +40,68 @@ public class AuthFilter implements GlobalFilter, Ordered
@@ -44,61 +40,68 @@ public class AuthFilter implements GlobalFilter, Ordered
@Resource ( name = "stringRedisTemplate" )
private ValueOperations < String , String > sops ;
@Autowired
private RedisService redisService ;
@Override
public Mono < Void > filter ( ServerWebExchange exchange , GatewayFilterChain chain )
{
String url = exchange . getRequest ( ) . getURI ( ) . getPath ( ) ;
ServerHttpRequest request = exchange . getRequest ( ) ;
ServerHttpRequest . Builder mutate = request . mutate ( ) ;
String url = request . getURI ( ) . getPath ( ) ;
// 跳过不需要验证的路径
if ( StringUtils . matches ( url , ignoreWhite . getWhites ( ) ) )
{
return chain . filter ( exchange ) ;
}
String token = getToken ( exchange . getRequest ( ) ) ;
if ( StringUtils . isBlank ( token ) )
String token = getToken ( request ) ;
if ( StringUtils . isEmpty ( token ) )
{
return setU nauthorizedResponse( exchange , "令牌不能为空" ) ;
return u nauthorizedResponse( exchange , "令牌不能为空" ) ;
}
String userStr = sops . get ( getTokenKey ( token ) ) ;
if ( StringUtils . isNull ( userStr ) )
if ( StringUtils . isEmpty ( userStr ) )
{
return setU nauthorizedResponse( exchange , "登录状态已过期" ) ;
return u nauthorizedResponse( exchange , "登录状态已过期" ) ;
}
JSONObject o bj = JSONObject . parseObject ( userStr ) ;
String userid = o bj. getString ( "userid" ) ;
String username = o bj. getString ( "username" ) ;
if ( StringUtils . isBlank ( userid ) | | StringUtils . isBlank ( username ) )
JSONObject cacheO bj = JSONObject . parseObject ( userStr ) ;
String userid = cacheO bj. getString ( "userid" ) ;
String username = cacheO bj. getString ( "username" ) ;
if ( StringUtils . isEmpty ( userid ) | | StringUtils . isEmpty ( username ) )
{
return setU nauthorizedResponse( exchange , "令牌验证失败" ) ;
return u nauthorizedResponse( exchange , "令牌验证失败" ) ;
}
// 设置过期时间
redisService . expire ( getTokenKey ( token ) , EXPIRE_TIME ) ;
// 设置用户信息到请求
ServerHttpRequest mutableReq = exchange . getRequest ( ) . mutate ( ) . header ( CacheConstants . DETAILS_USER_ID , userid )
. header ( CacheConstants . DETAILS_USERNAME , ServletUtils . urlEncode ( username ) ) . build ( ) ;
ServerWebExchange mutableExchange = exchange . mutate ( ) . request ( mutableReq ) . build ( ) ;
return chain . filter ( mutableExchange ) ;
addHeader ( mutate , CacheConstants . DETAILS_USER_ID , userid ) ;
addHeader ( mutate , CacheConstants . DETAILS_USERNAME , username ) ;
return chain . filter ( exchange . mutate ( ) . request ( mutate . build ( ) ) . build ( ) ) ;
}
private Mono < Void > setUnauthorizedResponse ( ServerWebExchange exchange , String msg )
private void addHeader ( ServerHttpRequest . Builder mutate , String name , Object value )
{
ServerHttpResponse response = exchange . getResponse ( ) ;
response . getHeaders ( ) . setContentType ( MediaType . APPLICATION_JSON ) ;
response . setStatusCode ( HttpStatus . OK ) ;
if ( value = = null )
{
return ;
}
String valueStr = value . toString ( ) ;
String valueEncode = ServletUtils . urlEncode ( valueStr ) ;
mutate . header ( name , valueEncode ) ;
}
private Mono < Void > unauthorizedResponse ( ServerWebExchange exchange , String msg )
{
log . error ( "[鉴权异常处理]请求路径:{}" , exchange . getRequest ( ) . getPath ( ) ) ;
return response . writeWith ( Mono . fromSupplier ( ( ) - > {
DataBufferFactory bufferFactory = response . bufferFactory ( ) ;
return bufferFactory . wrap ( JSON . toJSONBytes ( R . fail ( HttpStatus . UNAUTHORIZED . value ( ) , msg ) ) ) ;
} ) ) ;
return ServletUtils . webFluxResponseWriter ( exchange . getResponse ( ) , msg , HttpStatus . UNAUTHORIZED ) ;
}
/ * *
* 获取缓存key
* /
private String getTokenKey ( String token )
{
return CacheConstants . LOGIN_TOKEN_KEY + token ;
@ -109,12 +112,8 @@ public class AuthFilter implements GlobalFilter, Ordered
@@ -109,12 +112,8 @@ public class AuthFilter implements GlobalFilter, Ordered
* /
private String getToken ( ServerHttpRequest request )
{
String token = request . getHeaders ( ) . getFirst ( CacheConstants . HEADER ) ;
if ( StringUtils . isNotEmpty ( token ) & & token . startsWith ( CacheConstants . TOKEN_PREFIX ) )
{
token = token . replace ( CacheConstants . TOKEN_PREFIX , "" ) ;
}
return token ;
String token = request . getHeaders ( ) . getFirst ( CacheConstants . TOKEN_AUTHENTICATION ) ;
return SecurityUtils . replaceTokenPrefix ( token ) ;
}
@Override