diff --git a/backend/src/main/java/io/metersphere/commons/utils/ShiroUtils.java b/backend/src/main/java/io/metersphere/commons/utils/ShiroUtils.java index cd5b511398..7c21289b99 100644 --- a/backend/src/main/java/io/metersphere/commons/utils/ShiroUtils.java +++ b/backend/src/main/java/io/metersphere/commons/utils/ShiroUtils.java @@ -1,7 +1,10 @@ package io.metersphere.commons.utils; +import io.metersphere.security.CustomSessionIdGenerator; +import io.metersphere.security.CustomSessionManager; import org.apache.shiro.cache.CacheManager; import org.apache.shiro.session.mgt.SessionManager; +import org.apache.shiro.session.mgt.eis.AbstractSessionDAO; import org.apache.shiro.web.servlet.Cookie; import org.apache.shiro.web.servlet.SimpleCookie; import org.apache.shiro.web.session.mgt.DefaultWebSessionManager; @@ -78,13 +81,15 @@ public class ShiroUtils { } public static SessionManager getSessionManager(Long sessionTimeout, CacheManager cacheManager){ - DefaultWebSessionManager sessionManager = new DefaultWebSessionManager(); + DefaultWebSessionManager sessionManager = new CustomSessionManager(); sessionManager.setSessionIdUrlRewritingEnabled(false); sessionManager.setDeleteInvalidSessions(true); sessionManager.setSessionValidationSchedulerEnabled(true); sessionManager.setSessionIdCookie(ShiroUtils.getSessionIdCookie()); sessionManager.setGlobalSessionTimeout(sessionTimeout * 1000);// 超时时间ms sessionManager.setCacheManager(cacheManager); + AbstractSessionDAO sessionDAO = (AbstractSessionDAO) sessionManager.getSessionDAO(); + sessionDAO.setSessionIdGenerator(new CustomSessionIdGenerator()); //sessionManager.setSessionIdCookieEnabled(true); return sessionManager; diff --git a/backend/src/main/java/io/metersphere/security/ApiKeyFilter.java b/backend/src/main/java/io/metersphere/security/ApiKeyFilter.java index 8cadbac193..c510f4aab7 100644 --- a/backend/src/main/java/io/metersphere/security/ApiKeyFilter.java +++ b/backend/src/main/java/io/metersphere/security/ApiKeyFilter.java @@ -28,6 +28,12 @@ public class ApiKeyFilter extends AnonymousFilter { if (ApiKeyHandler.isApiKeyCall(WebUtils.toHttp(request))) { String userId = ApiKeyHandler.getUser(WebUtils.toHttp(request)); SecurityUtils.getSubject().login(new MsUserToken(userId, ApiKeySessionHandler.random, "LOCAL")); + } else { + String id = (String) SecurityUtils.getSubject().getSession().getId(); + // 防止调用时使用 ak 作为 cookie 跳过登入逻辑 + if (id.length() < 20) { + SecurityUtils.getSubject().logout(); + } } } @@ -43,4 +49,4 @@ public class ApiKeyFilter extends AnonymousFilter { return true; } -} \ No newline at end of file +} diff --git a/backend/src/main/java/io/metersphere/security/CustomSessionIdGenerator.java b/backend/src/main/java/io/metersphere/security/CustomSessionIdGenerator.java new file mode 100644 index 0000000000..769dd8ab42 --- /dev/null +++ b/backend/src/main/java/io/metersphere/security/CustomSessionIdGenerator.java @@ -0,0 +1,19 @@ +package io.metersphere.security; + +import org.apache.commons.lang3.StringUtils; +import org.apache.shiro.session.Session; +import org.apache.shiro.session.mgt.eis.SessionIdGenerator; + +import java.io.Serializable; +import java.util.UUID; + +public class CustomSessionIdGenerator implements SessionIdGenerator { + @Override + public Serializable generateId(Session session) { + String threadSessionId = CustomSessionManager.threadSessionId.get(); + if (StringUtils.isNotBlank(threadSessionId)) { + return threadSessionId; + } + return UUID.randomUUID().toString(); + } +} diff --git a/backend/src/main/java/io/metersphere/security/CustomSessionManager.java b/backend/src/main/java/io/metersphere/security/CustomSessionManager.java new file mode 100644 index 0000000000..e9d9b3be3b --- /dev/null +++ b/backend/src/main/java/io/metersphere/security/CustomSessionManager.java @@ -0,0 +1,31 @@ +package io.metersphere.security; + +import org.apache.shiro.web.session.mgt.DefaultWebSessionManager; +import org.apache.shiro.web.util.WebUtils; + +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import java.io.Serializable; + +public class CustomSessionManager extends DefaultWebSessionManager { + + static final ThreadLocal threadSessionId = new ThreadLocal<>(); + + @Override + protected Serializable getSessionId(ServletRequest request, ServletResponse response) { + String id = null; + HttpServletRequest httpRequest = WebUtils.toHttp(request); + if (ApiKeyHandler.isApiKeyCall(httpRequest)) { + // API调用同一个ak使用同一个session,避免调用频繁,导致session过多,内存泄漏 + id = httpRequest.getHeader(ApiKeyHandler.API_ACCESS_KEY); + setSessionIdCookieEnabled(false); + threadSessionId.set(id); + return id; + } + // 线程池中线程可能会复用,非api删除 + threadSessionId.remove(); + setSessionIdCookieEnabled(true); + return super.getSessionId(request, response); + } +}