fix: SSO登录退出相关问题

This commit is contained in:
CaptainB 2022-10-21 12:01:40 +08:00 committed by 刘瑞斌
parent f5a2b5dd14
commit d8d9378c4f
9 changed files with 81 additions and 30 deletions

View File

@ -46,10 +46,12 @@ public class LoginController {
private ReactiveRedisSessionRepository reactiveRedisSessionRepository; private ReactiveRedisSessionRepository reactiveRedisSessionRepository;
@GetMapping(value = "/is-login") @GetMapping(value = "/is-login")
public Mono<ResultHolder> isLogin(@RequestHeader(name = SessionConstants.HEADER_TOKEN, required = false) String sessionId) throws Exception { public Mono<ResultHolder> isLogin(@RequestHeader(name = SessionConstants.HEADER_TOKEN, required = false) String sessionId,
@RequestHeader(name = SessionConstants.CSRF_TOKEN, required = false) String csrfToken) throws Exception {
RsaKey rsaKey = RsaUtil.getRsaKey(); RsaKey rsaKey = RsaUtil.getRsaKey();
if (StringUtils.isNotBlank(sessionId)) { if (StringUtils.isNotBlank(sessionId) && StringUtils.isNotBlank(csrfToken)) {
userLoginService.validateCsrfToken(sessionId, csrfToken);
return reactiveRedisSessionRepository.getSessionRedisOperations().opsForHash().get("spring:session:sessions:" + sessionId, "sessionAttr:user") return reactiveRedisSessionRepository.getSessionRedisOperations().opsForHash().get("spring:session:sessions:" + sessionId, "sessionAttr:user")
.switchIfEmpty(Mono.just(rsaKey)) .switchIfEmpty(Mono.just(rsaKey))
.map(r -> { .map(r -> {

View File

@ -2,6 +2,7 @@ package io.metersphere.gateway.controller;
import io.metersphere.commons.constants.OperLogConstants; import io.metersphere.commons.constants.OperLogConstants;
import io.metersphere.commons.constants.OperLogModule; import io.metersphere.commons.constants.OperLogModule;
import io.metersphere.commons.user.SessionUser;
import io.metersphere.commons.utils.CodingUtil; import io.metersphere.commons.utils.CodingUtil;
import io.metersphere.gateway.service.SSOService; import io.metersphere.gateway.service.SSOService;
import io.metersphere.log.annotation.MsAuditLog; import io.metersphere.log.annotation.MsAuditLog;
@ -13,6 +14,7 @@ import reactor.core.publisher.Mono;
import javax.annotation.Resource; import javax.annotation.Resource;
import java.util.Locale; import java.util.Locale;
import java.util.Optional;
@Controller @Controller
@RequestMapping("sso") @RequestMapping("sso")
@ -23,24 +25,24 @@ public class SSOController {
@GetMapping("callback/{authId}") @GetMapping("callback/{authId}")
@MsAuditLog(module = OperLogModule.AUTH_TITLE, type = OperLogConstants.LOGIN, title = "登录") @MsAuditLog(module = OperLogModule.AUTH_TITLE, type = OperLogConstants.LOGIN, title = "登录")
public Rendering callbackWithAuthId(@RequestParam("code") String code, @PathVariable("authId") String authId, WebSession session, Locale locale) throws Exception { public Rendering callbackWithAuthId(@RequestParam("code") String code, @PathVariable("authId") String authId, WebSession session, Locale locale) throws Exception {
ssoService.exchangeToken(code, authId, session, locale); Optional<SessionUser> sessionUser = ssoService.exchangeToken(code, authId, session, locale);
return Rendering.redirectTo("/#/?_token=" + CodingUtil.base64Encoding(session.getId())) return Rendering.redirectTo("/#/?_token=" + CodingUtil.base64Encoding(session.getId()) + "&_csrf=" + sessionUser.get().getCsrfToken())
.build(); .build();
} }
@GetMapping("callback") @GetMapping("callback")
@MsAuditLog(module = OperLogModule.AUTH_TITLE, type = OperLogConstants.LOGIN, title = "登录") @MsAuditLog(module = OperLogModule.AUTH_TITLE, type = OperLogConstants.LOGIN, title = "登录")
public Rendering callback(@RequestParam("code") String code, @RequestParam("state") String authId, WebSession session, Locale locale) throws Exception { public Rendering callback(@RequestParam("code") String code, @RequestParam("state") String authId, WebSession session, Locale locale) throws Exception {
ssoService.exchangeToken(code, authId, session, locale); Optional<SessionUser> sessionUser = ssoService.exchangeToken(code, authId, session, locale);
return Rendering.redirectTo("/#/?_token=" + CodingUtil.base64Encoding(session.getId())) return Rendering.redirectTo("/#/?_token=" + CodingUtil.base64Encoding(session.getId()) + "&_csrf=" + sessionUser.get().getCsrfToken())
.build(); .build();
} }
@GetMapping("/callback/cas/{authId}") @GetMapping("/callback/cas/{authId}")
@MsAuditLog(module = OperLogModule.AUTH_TITLE, type = OperLogConstants.LOGIN, title = "登录") @MsAuditLog(module = OperLogModule.AUTH_TITLE, type = OperLogConstants.LOGIN, title = "登录")
public Rendering casCallback(@RequestParam("ticket") String ticket, @PathVariable("authId") String authId, WebSession session, Locale locale) throws Exception { public Rendering casCallback(@RequestParam("ticket") String ticket, @PathVariable("authId") String authId, WebSession session, Locale locale) throws Exception {
ssoService.serviceValidate(ticket, authId, session, locale); Optional<SessionUser> sessionUser = ssoService.serviceValidate(ticket, authId, session, locale);
return Rendering.redirectTo("/#/?_token=" + CodingUtil.base64Encoding(session.getId())) return Rendering.redirectTo("/#/?_token=" + CodingUtil.base64Encoding(session.getId()) + "&_csrf=" + sessionUser.get().getCsrfToken())
.build(); .build();
} }

View File

@ -2,7 +2,6 @@ package io.metersphere.gateway.service;
import io.metersphere.base.domain.AuthSource; import io.metersphere.base.domain.AuthSource;
import io.metersphere.base.domain.User; import io.metersphere.base.domain.User;
import io.metersphere.commons.constants.UserSource;
import io.metersphere.commons.exception.MSException; import io.metersphere.commons.exception.MSException;
import io.metersphere.commons.user.SessionUser; import io.metersphere.commons.user.SessionUser;
import io.metersphere.commons.utils.CodingUtil; import io.metersphere.commons.utils.CodingUtil;
@ -61,7 +60,7 @@ public class SSOService {
@Resource @Resource
private UserLoginService userLoginService; private UserLoginService userLoginService;
public void exchangeToken(String code, String authId, WebSession session, Locale locale) throws Exception { public Optional<SessionUser> exchangeToken(String code, String authId, WebSession session, Locale locale) throws Exception {
AuthSource authSource = authSourceService.getAuthSource(authId); AuthSource authSource = authSourceService.getAuthSource(authId);
Map config = JSON.parseObject(authSource.getConfiguration(), Map.class); Map config = JSON.parseObject(authSource.getConfiguration(), Map.class);
String tokenUrl = (String) config.get("tokenUrl"); String tokenUrl = (String) config.get("tokenUrl");
@ -92,7 +91,7 @@ public class SSOService {
MSException.throwException(content); MSException.throwException(content);
} }
doOICDLogin(authSource, accessToken, session, locale); return doOICDLogin(authSource, accessToken, session, locale);
} }
private RestTemplate getRestTemplateIgnoreSSL() throws NoSuchAlgorithmException, KeyManagementException, KeyStoreException { private RestTemplate getRestTemplateIgnoreSSL() throws NoSuchAlgorithmException, KeyManagementException, KeyStoreException {
@ -118,7 +117,7 @@ public class SSOService {
return new RestTemplate(requestFactory); return new RestTemplate(requestFactory);
} }
private void doOICDLogin(AuthSource authSource, String accessToken, WebSession session, Locale locale) throws Exception { private Optional<SessionUser> doOICDLogin(AuthSource authSource, String accessToken, WebSession session, Locale locale) throws Exception {
Map config = JSON.parseObject(authSource.getConfiguration(), Map.class); Map config = JSON.parseObject(authSource.getConfiguration(), Map.class);
String userInfoUrl = (String) config.get("userInfoUrl"); String userInfoUrl = (String) config.get("userInfoUrl");
HttpHeaders headers = new HttpHeaders(); HttpHeaders headers = new HttpHeaders();
@ -157,12 +156,13 @@ public class SSOService {
session.getAttributes().put("authenticate", authSource.getType()); session.getAttributes().put("authenticate", authSource.getType());
session.getAttributes().put("authId", authSource.getId()); session.getAttributes().put("authId", authSource.getId());
session.getAttributes().put("user", userOptional.get()); session.getAttributes().put("user", userOptional.get());
return userOptional;
} }
/** /**
* cas callback * cas callback
*/ */
public void serviceValidate(String ticket, String authId, WebSession session, Locale locale) throws Exception { public Optional<SessionUser> serviceValidate(String ticket, String authId, WebSession session, Locale locale) throws Exception {
AuthSource authSource = authSourceService.getAuthSource(authId); AuthSource authSource = authSourceService.getAuthSource(authId);
Map config = JSON.parseObject(authSource.getConfiguration(), Map.class); Map config = JSON.parseObject(authSource.getConfiguration(), Map.class);
String redirectUrl = ((String) config.get("redirectUrl")).replace("${authId}", authId); String redirectUrl = ((String) config.get("redirectUrl")).replace("${authId}", authId);
@ -196,9 +196,12 @@ public class SSOService {
session.getAttributes().put("authenticate", authSource.getType()); session.getAttributes().put("authenticate", authSource.getType());
session.getAttributes().put("authId", authSource.getId()); session.getAttributes().put("authId", authSource.getId());
session.getAttributes().put("user", userOptional.get()); session.getAttributes().put("user", userOptional.get());
session.getAttributes().put("casTicket", ticket);
// 记录cas对应关系 // 记录cas对应关系
Long timeout = env.getProperty("spring.session.timeout", Long.class); Long timeout = env.getProperty("spring.session.timeout", Long.class);
stringRedisTemplate.opsForValue().set(ticket, name, timeout, TimeUnit.SECONDS); stringRedisTemplate.opsForValue().set(ticket, name, timeout, TimeUnit.SECONDS);
return userOptional;
} }
public void kickOutUser(String logoutToken) { public void kickOutUser(String logoutToken) {

View File

@ -8,7 +8,6 @@ import io.metersphere.commons.constants.UserStatus;
import io.metersphere.commons.exception.MSException; import io.metersphere.commons.exception.MSException;
import io.metersphere.commons.user.SessionUser; import io.metersphere.commons.user.SessionUser;
import io.metersphere.commons.utils.CodingUtil; import io.metersphere.commons.utils.CodingUtil;
import io.metersphere.commons.utils.SessionUtils;
import io.metersphere.dto.GroupResourceDTO; import io.metersphere.dto.GroupResourceDTO;
import io.metersphere.dto.UserDTO; import io.metersphere.dto.UserDTO;
import io.metersphere.dto.UserGroupPermissionDTO; import io.metersphere.dto.UserGroupPermissionDTO;
@ -18,6 +17,7 @@ import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.context.i18n.LocaleContextHolder; import org.springframework.context.i18n.LocaleContextHolder;
import org.springframework.session.FindByIndexNameSessionRepository;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.web.server.WebSession; import org.springframework.web.server.WebSession;
@ -55,7 +55,7 @@ public class UserLoginService {
userDTO = loginLocalMode(request.getUsername(), request.getPassword()); userDTO = loginLocalMode(request.getUsername(), request.getPassword());
break; break;
} }
autoSwitch(userDTO); autoSwitch(session, userDTO);
return Optional.of(SessionUser.fromUser(userDTO, session.getId())); return Optional.of(SessionUser.fromUser(userDTO, session.getId()));
} }
@ -84,7 +84,7 @@ public class UserLoginService {
return user; return user;
} }
public void autoSwitch(UserDTO user) { public void autoSwitch(WebSession session, UserDTO user) {
// 用户有 last_project_id 权限 // 用户有 last_project_id 权限
if (hasLastProjectPermission(user)) { if (hasLastProjectPermission(user)) {
return; return;
@ -94,7 +94,7 @@ public class UserLoginService {
return; return;
} }
// 判断其他权限 // 判断其他权限
checkNewWorkspaceAndProject(user); checkNewWorkspaceAndProject(session, user);
} }
private boolean hasLastProjectPermission(UserDTO user) { private boolean hasLastProjectPermission(UserDTO user) {
@ -162,7 +162,7 @@ public class UserLoginService {
return false; return false;
} }
private void checkNewWorkspaceAndProject(UserDTO user) { private void checkNewWorkspaceAndProject(WebSession session, UserDTO user) {
List<UserGroup> userGroups = user.getUserGroups(); List<UserGroup> userGroups = user.getUserGroups();
List<String> projectGroupIds = user.getGroups() List<String> projectGroupIds = user.getGroups()
.stream().filter(ug -> StringUtils.equals(ug.getType(), UserGroupType.PROJECT)) .stream().filter(ug -> StringUtils.equals(ug.getType(), UserGroupType.PROJECT))
@ -180,7 +180,7 @@ public class UserLoginService {
.collect(Collectors.toList()); .collect(Collectors.toList());
if (workspaces.size() > 0) { if (workspaces.size() > 0) {
String wsId = workspaces.get(0).getSourceId(); String wsId = workspaces.get(0).getSourceId();
switchUserResource("workspace", wsId, user); switchUserResource(session, "workspace", wsId, user);
} else { } else {
// 用户登录之后没有项目和工作空间的权限就把值清空 // 用户登录之后没有项目和工作空间的权限就把值清空
user.setLastWorkspaceId(""); user.setLastWorkspaceId("");
@ -200,7 +200,7 @@ public class UserLoginService {
} }
} }
public void switchUserResource(String sign, String sourceId, UserDTO sessionUser) { public void switchUserResource(WebSession session, String sign, String sourceId, UserDTO sessionUser) {
// 获取最新UserDTO // 获取最新UserDTO
UserDTO user = getUserDTO(sessionUser.getId()); UserDTO user = getUserDTO(sessionUser.getId());
User newUser = new User(); User newUser = new User();
@ -217,9 +217,11 @@ public class UserLoginService {
} }
BeanUtils.copyProperties(user, newUser); BeanUtils.copyProperties(user, newUser);
// 切换工作空间或组织之后更新 session 里的 user // 切换工作空间或组织之后更新 session 里的 user
SessionUtils.putUser(SessionUser.fromUser(user, SessionUtils.getSessionId())); session.getAttributes().put("user", SessionUser.fromUser(user, session.getId()));
session.getAttributes().put(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME, sessionUser.getId());
userMapper.updateByPrimaryKeySelective(newUser); userMapper.updateByPrimaryKeySelective(newUser);
} }
public UserDTO getLoginUser(String userId, List<String> list) { public UserDTO getLoginUser(String userId, List<String> list) {
UserExample example = new UserExample(); UserExample example = new UserExample();
example.createCriteria().andIdEqualTo(userId).andSourceIn(list); example.createCriteria().andIdEqualTo(userId).andSourceIn(list);
@ -381,9 +383,6 @@ public class UserLoginService {
} }
// 执行变更 // 执行变更
userMapper.updateByPrimaryKeySelective(user); userMapper.updateByPrimaryKeySelective(user);
if (StringUtils.equals(user.getStatus(), UserStatus.DISABLED)) {
SessionUtils.kickOutUser(user.getId());
}
} }
private List<Project> getProjectListByWsAndUserId(String userId, String workspaceId) { private List<Project> getProjectListByWsAndUserId(String userId, String workspaceId) {
@ -404,4 +403,20 @@ public class UserLoginService {
})); }));
return projectList; return projectList;
} }
public void validateCsrfToken(String sessionId, String csrfToken) {
if (StringUtils.isBlank(csrfToken)) {
throw new RuntimeException("csrf token is empty");
}
csrfToken = CodingUtil.aesDecrypt(csrfToken, SessionUser.secret, SessionUser.iv);
String[] signatureArray = StringUtils.split(StringUtils.trimToNull(csrfToken), "|");
if (signatureArray.length != 4) {
throw new RuntimeException("invalid token");
}
if (!StringUtils.equals(sessionId, signatureArray[2])) {
throw new RuntimeException("Please check csrf token.");
}
}
} }

View File

@ -3,7 +3,7 @@ import {$error} from "./message"
import {getCurrentProjectID, getCurrentWorkspaceId} from "../utils/token"; import {getCurrentProjectID, getCurrentWorkspaceId} from "../utils/token";
import {PROJECT_ID, TokenKey, WORKSPACE_ID} from "../utils/constants"; import {PROJECT_ID, TokenKey, WORKSPACE_ID} from "../utils/constants";
import packageJSON from '@/../package.json' import packageJSON from '@/../package.json'
import {getUUID} from "../utils"; import {getUrlParams, getUUID} from "../utils";
import {Base64} from "js-base64"; import {Base64} from "js-base64";
// baseURL 根据是否是独立运行修改 // baseURL 根据是否是独立运行修改
@ -21,6 +21,8 @@ if (window.location.pathname.startsWith('/' + packageJSON.name)) {
} }
} }
let urlParams = getUrlParams(window.location.href);
const instance = axios.create({ const instance = axios.create({
baseURL, // url = base url + request url baseURL, // url = base url + request url
withCredentials: true, withCredentials: true,
@ -39,12 +41,18 @@ instance.interceptors.request.use(
} }
// sso callback 过来的会有sessionId在url上 // sso callback 过来的会有sessionId在url上
if (!config.headers['X-AUTH-TOKEN']) { if (!config.headers['X-AUTH-TOKEN']) {
const paramsStr = window.location.href let sessionId = urlParams['_token']
let sessionId = paramsStr.split('_token=')[1]
if (sessionId) { if (sessionId) {
config.headers['X-AUTH-TOKEN'] = Base64.decode(sessionId); config.headers['X-AUTH-TOKEN'] = Base64.decode(sessionId);
} }
} }
// sso callback 过来的会有csrf在url上
if (!config.headers['CSRF-TOKEN']) {
let csrf = urlParams['_csrf']
if (csrf) {
config.headers['CSRF-TOKEN'] = csrf;
}
}
// 包含 工作空间 项目的标识 // 包含 工作空间 项目的标识
config.headers['WORKSPACE'] = getCurrentWorkspaceId(); config.headers['WORKSPACE'] = getCurrentWorkspaceId();
config.headers['PROJECT'] = getCurrentProjectID(); config.headers['PROJECT'] = getCurrentProjectID();

View File

@ -328,3 +328,14 @@ function _resizeTextarea(i, size, textareaList) {
export function checkMicroMode() { export function checkMicroMode() {
return sessionStorage.getItem("MICRO_MODE"); return sessionStorage.getItem("MICRO_MODE");
} }
export function getUrlParams(url) {
const arrSearch = url.split('?').pop().split('#').shift().split('&');
let obj = {};
arrSearch.forEach((item) => {
const [k, v] = item.split('=');
obj[k] = v;
return obj;
});
return obj;
}

View File

@ -59,7 +59,10 @@ public class SessionUtils {
} }
Map<String, ?> users = sessionRepository.findByPrincipalName(username); Map<String, ?> users = sessionRepository.findByPrincipalName(username);
if (MapUtils.isNotEmpty(users)) { if (MapUtils.isNotEmpty(users)) {
users.keySet().forEach(sessionRepository::deleteById); users.keySet().forEach(k -> {
sessionRepository.deleteById(k);
sessionRepository.getSessionRedisOperations().delete("spring:session:sessions:" + k);
});
} }
} }

View File

@ -87,7 +87,7 @@ public class LoginController {
@GetMapping(value = "/signout") @GetMapping(value = "/signout")
@MsAuditLog(module = OperLogModule.AUTH_TITLE, beforeEvent = "#msClass.getUserId(id)", type = OperLogConstants.LOGIN, title = "登出", msClass = SessionUtils.class) @MsAuditLog(module = OperLogModule.AUTH_TITLE, beforeEvent = "#msClass.getUserId(id)", type = OperLogConstants.LOGIN, title = "登出", msClass = SessionUtils.class)
public ResultHolder logout() throws Exception { public ResultHolder logout() throws Exception {
ssoLogoutService.logout(); ssoLogoutService.logout(SecurityUtils.getSubject().getSession());
SecurityUtils.getSubject().logout(); SecurityUtils.getSubject().logout();
return ResultHolder.success(StringUtils.EMPTY); return ResultHolder.success(StringUtils.EMPTY);
} }

View File

@ -8,6 +8,7 @@ import io.metersphere.commons.utils.JSON;
import io.metersphere.commons.utils.SessionUtils; import io.metersphere.commons.utils.SessionUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.shiro.SecurityUtils; import org.apache.shiro.SecurityUtils;
import org.apache.shiro.session.Session;
import org.springframework.data.redis.core.StringRedisTemplate; import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate; import org.springframework.web.client.RestTemplate;
@ -27,7 +28,7 @@ public class SSOLogoutService {
/** /**
* oidc logout * oidc logout
*/ */
public void logout() throws Exception { public void logout(Session session) throws Exception {
String authId = (String) SecurityUtils.getSubject().getSession().getAttribute("authId"); String authId = (String) SecurityUtils.getSubject().getSession().getAttribute("authId");
AuthSource authSource = authSourceMapper.selectByPrimaryKey(authId); AuthSource authSource = authSourceMapper.selectByPrimaryKey(authId);
if (authSource != null) { if (authSource != null) {
@ -38,6 +39,12 @@ public class SSOLogoutService {
restTemplate.getForEntity(logoutUrl + "?id_token_hint=" + idToken, String.class); restTemplate.getForEntity(logoutUrl + "?id_token_hint=" + idToken, String.class);
} }
if (StringUtils.equals(UserSource.CAS.name(), authSource.getType())) {
String casTicket = (String) session.getAttribute("casTicket");
if (StringUtils.isNotEmpty(casTicket)) {
stringRedisTemplate.delete(casTicket);
}
}
} }
} }