项目中使用Spring Security,但是还想通过统一认证平台实现单点登录

最近接到一个新的需求,就是将原本的项目集成到公司的统一认证平台,实现单点登录。相比于单系统登录,SSO需要一个独立的认证中心,只有认证中心能接受用户的用户名密码等安全信息,其他系统不提供登录入口,只接受认证中心的间接授权。间接授权通过令牌实现,SSO认证中心验证用户的用户名密码没问题,创建授权令牌,在接下来的跳转过程中,授权令牌作为参数发送给各个子系统,子系统拿到令牌,即得到了授权,可以借此创建局部会话,局部会话登录方式与单系统的登录方式相同。这个过程,也就是单点登录的原理。

上面我大致介绍了一下单点登录原理,但是我们系统再实现单点登录的前提下,原本的登录方式也要保留。按照常理来说也不麻烦,将原本的登录页面保留,然后在统一认证平台注册应用id、应用secret、应用OAuth回调地址等(我们公司是基于OAuth的方式完成应用的注册),我们将OAuth回调地址配置为我们的登录接口,在登录接口做判断,只要是SSO跳转过来的帮用户自动登录,登录成功以后再重定向我们系统的首页。

但因为我们系统是使用Spring Security做的权限认证,因此导致最初是现实走了点弯路,特此记录一下。

首先我先讲一下我在实现该功能遇到的阻塞

1、Spring Security登录方式只能通过POST方式,但OAuth回调地址都是GET请求。

2、如果我通过SSO获取到用户信息以后怎样登录授权。

3、通过SSO进入我们系统的用户需要保存用户信息(因为我们系统的登录页面也是支持账号密码登录)。

目前我的解决方案

1、新增一个OAuth回调地址,并且不需要鉴权

 @Override
    public void configure(WebSecurity web) throws Exception {
        // 该接口就是我配置的OAuth回调地址、不需要授权
        web.ignoring().antMatchers("/configplatform/admin/sso/login");
        web.ignoring().antMatchers("/swagger-ui.html");
        web.ignoring().antMatchers("/swagger-resources");
        web.ignoring().antMatchers("/swagger-resources/configuration/ui");
        web.ignoring().antMatchers("/favicon.ico");
    }

2、就是通过code获取accessToken

 /**
     * 通过sso返回的code获取token
     *
     * @param code code
     * @return accessToken
     */
    private String getTokenByCode(String code) {
        Map<String, String> params = new HashMap<>(8);
        params.put(CODE, code);
        params.put(CLIENT_ID, ssoConfig.getClientId());
        params.put(CLIENT_SECRET, ssoConfig.getClientSecret());
        params.put(REDIRECT_URI, ssoConfig.getRedirectUrlPrefix() + SSO_LOGIN);
        params.put(OAUTH_TIME_STAMP, String.valueOf(System.currentTimeMillis()));
        params.put(GRANT_TYPE, GRANT_TYPE_VALUE);
        log.info("通过code获取access_token 请求参数params{}", params);
        String url = SSO_PRE_URL + TOKEN;
        String accessToken = null;
        try {
            String response = HttpClientUtils.httpPostStr(url, 30000, params);
            JSONObject jsonObject = JSONObject.parseObject(response);
            accessToken = String.valueOf(jsonObject.get(ACCESS_TOKEN));
        } catch (Exception ex) {
            log.info("SSO获取access_token失败 ex:{}", ex);
            throw new BusinessException(ExceptionConstants.SYSTEM_ERROR, "SSO获取access_token失败");
        }
        if (StringUtils.isEmpty(accessToken)) {
            throw new BusinessException(ExceptionConstants.SYSTEM_ERROR, "SSO获取access_token失败");
        }
        return accessToken;
    }

3、通过accessToken获取用户信息

private SSOUserInfoDO getUserInfoByToken(String token) {
        String url = SSO_PRE_URL + USER_INFO + "?" + ACCESS_TOKEN + "=" + token;
        log.info("通过access_token获取用户信息 请求参数params{}", url);
        SSOUserInfoDO ssoUserInfoDO = null;
        try {
            String response = HttpClientUtils.httpGet(url, 30000, null);
            ssoUserInfoDO = JSONObject.parseObject(response, SSOUserInfoDO.class);
        } catch (Exception ex) {
            log.info("SSO获取用户信息失败 ex:{}", ex);
            throw new BusinessException(ExceptionConstants.SYSTEM_ERROR, "SSO获取用户信息失败");
        }
        if (ssoUserInfoDO == null) {
            throw new BusinessException(ExceptionConstants.SYSTEM_ERROR, "SSO获取用户信息失败");
        }
        return ssoUserInfoDO;
    }

4、判断用户是否存在我们系统

private SysUserVO isUserExist(SSOUserInfoDO ssoUserInfoDO) {
        SysUserVO sysUserVO = sysUserService.getUserByAccount(ssoUserInfoDO.getSsoaccount());
           // 如果不存在新增用户
        if (sysUserVO == null) {
            sysUserVO = new SysUserVO();
            sysUserVO.setId(KeyUtil.getUUIDKey());
            sysUserVO.setAccountName(ssoUserInfoDO.getSsoaccount());
            sysUserVO.setUserName(ssoUserInfoDO.getYumLocalName());
            sysUserVO.setPassWord(ssoUserInfoDO.getSsoaccount());
            sysUserVO.setGender(" ");
            sysUserVO.setRoleId(" ");
            sysUserVO.setUserOnOffFlag(0);
             //新增用户
            sysUserService.addUser(sysUserVO);
        }
        return sysUserVO;
    }

5、通过用户名密码登陆

  @GetMapping("/sso/login")
    public void login(String code, HttpServletResponse response) throws IOException {
        // 获取code换token
        String accessToken = getTokenByCode(code);
        // 通过token获取 用户信息
        SSOUserInfoDO ssoUserInfoDO = getUserInfoByToken(accessToken);
        // 通过账号查询该用户是否存在我们系统
        SysUserVO sysUserVO = dealWithUserInfo(ssoUserInfoDO);
        String username = sysUserVO.getAccountName();
        String password = sysUserVO.getPassWord();

        Map<String, String> header = new HashMap<>(1);
        header.put(HTTP.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE);
        Map<String, String> params = new HashMap<>(3);
        params.put(USER_NAME, username);
        params.put(PASS_WORD, password);
        params.put(CODE, SSO);
        //该地址就是Spring Security配置的登陆地址,这样只要登陆成功,还是Spring Security帮我设置权限
        String urlLogin = ssoConfig.getRedirectUrlPrefix() + HttpConstants.FULL_LOGIN;
        log.info("SSO 登陆内部系统请求参数 params:{},url:{}", params, urlLogin);
        CloseableHttpResponse loginRes = HttpClientUtils.httpPost(urlLogin, 30000, params, header);
        //登陆成功以后设置cookies
        String jSessionId = getJSessionId(loginRes);
        log.info("sso认证完成,登陆系统获取的JSESSIONID{}", jSessionId);
        Cookie sessionId = getCookie(J_SESSION_ID, jSessionId);
        Cookie user = getCookie(USER_NAME, username);
        response.addCookie(sessionId);
        response.addCookie(user);
        //重定向地址 我们系统的首页
        response.sendRedirect(ssoConfig.getSsoSuccessPage());
    }

总结:上面代码都是伪代码,因为本人代码注释也是挺多的,大家应该可以看懂,如果还是存在疑问,可以把疑问留在评论中,我也会及时解答。

HttpClientUtils工具类代码

package com.yum.ec3.configplatform.management.util;

import java.io.IOException;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import javax.net.ssl.SSLContext;

import org.apache.http.HttpEntity;
import org.apache.http.NameValuePair;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.entity.UrlEncodedFormEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.config.Registry;
import org.apache.http.config.RegistryBuilder;
import org.apache.http.conn.socket.ConnectionSocketFactory;
import org.apache.http.conn.socket.LayeredConnectionSocketFactory;
import org.apache.http.conn.socket.PlainConnectionSocketFactory;
import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.impl.conn.PoolingHttpClientConnectionManager;
import org.apache.http.message.BasicNameValuePair;
import org.apache.http.protocol.HTTP;
import org.apache.http.util.EntityUtils;
import org.springframework.http.MediaType;

import com.yum.ec3.base.exception.BusinessException;
import com.yum.ec3.platform.common.util.LogUtil;

/**
 * <httpClient 工具类> <功能详细描述>
 *
 * @see [相关类/方法]
 * @since [产品/模块版本]
 */

public final class HttpClientUtils {

    private static LogUtil log = LogUtil.getLogger(HttpClientUtils.class);

    private static final String UTF_8 = "UTF-8";

    private static final String HTTPS = "https";

    private static final String HTTP_T = "http";

    /**
     * 连接池初始化
     */
    private static PoolingHttpClientConnectionManager cm = null;

    /**
     * 初始化httpclient 连接池
     */
    static {
        LayeredConnectionSocketFactory sslsf = null;
        try {
            sslsf = new SSLConnectionSocketFactory(SSLContext.getDefault());
        } catch (NoSuchAlgorithmException e) {
            log.info(e.getMessage());
        }
        Registry<ConnectionSocketFactory> socketFactoryRegistry = RegistryBuilder.<ConnectionSocketFactory> create()
                .register(HTTPS, sslsf).register(HTTP_T, new PlainConnectionSocketFactory()).build();
        cm = new PoolingHttpClientConnectionManager(socketFactoryRegistry);
        cm.setMaxTotal(200);
        cm.setDefaultMaxPerRoute(2);
    }

    private HttpClientUtils() {
    }

    /**
     * <配置httpclient> <功能详细描述>
     *
     * @param timeout 请求超时时间
     * @return CloseableHttpClient
     * @see [类、类#方法、类#成员]
     */

    private static CloseableHttpClient getHttpClient(Integer timeout) {
        RequestConfig requestConfig = RequestConfig.custom().
        // 设置连接超时时间
                setConnectionRequestTimeout(timeout).
                // 设置请求超时时间
                setConnectTimeout(timeout).
                // 设置响应超时时间
                setSocketTimeout(timeout).build();
        // 超时重试,服务器丢失连接重试
        HttpRequestRetryHandlerImpl retry = new HttpRequestRetryHandlerImpl();
        return HttpClients.custom().setDefaultRequestConfig(requestConfig).setRetryHandler(retry)
                .setConnectionManager(cm).setConnectionManagerShared(true).build();
    }

    /**
     * http post <功能详细描述>
     *
     * @param url 请求的url
     * @param timeout 请求响应的时间
     * @param param 请求的参数
     * @param header 请求头
     * @return string
     * @see [类、类#方法、类#成员]
     */
    public static CloseableHttpResponse httpPost(String url, Integer timeout, Map<String, String> param,
            Map<String, String> header) {
        CloseableHttpResponse response = null;
        // 获取客户端连接对象
        CloseableHttpClient httpClient = getHttpClient(timeout);
        List<NameValuePair> nameValuePairs = new ArrayList<>();
        HttpPost httpPost = new HttpPost(url);
        if (null != param) {
            for (Map.Entry<String, String> entry : param.entrySet()) {
                nameValuePairs.add(new BasicNameValuePair(entry.getKey(), entry.getValue()));
            }
        }
        try {
            httpPost.setEntity(new UrlEncodedFormEntity(nameValuePairs, UTF_8));
            for (Map.Entry<String, String> entry : header.entrySet()) {
                httpPost.addHeader(entry.getKey(), entry.getValue());
            }
            // 执行请求
            response = httpClient.execute(httpPost);
        } catch (Exception e) {
            log.info("user httpPost request error", e);
            throw new BusinessException(500, e.getMessage());
        } finally {
            try {
                doClose(httpClient, response);
            } catch (IOException e) {
                log.error("httpPost close error", e);
            }
        }
        return response;
    }

    public static String httpPostStr(String url, Integer timeout, Map<String, String> param) {
        String msg = "";
        // 获取客户端连接对象
        CloseableHttpClient httpClient = getHttpClient(timeout);
        HttpPost httpPost = new HttpPost(url);
        // 设置提交方式
        httpPost.addHeader(HTTP.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE);
        List<NameValuePair> nameValuePairs = new ArrayList<>();
        if (null != param) {
            for (Map.Entry<String, String> entry : param.entrySet()) {
                nameValuePairs.add(new BasicNameValuePair(entry.getKey(), entry.getValue()));
            }
        }
        CloseableHttpResponse response = null;
        try {
            httpPost.setEntity(new UrlEncodedFormEntity(nameValuePairs, UTF_8));
            // 执行请求
            response = httpClient.execute(httpPost);
            // 获得响应的实体对象
            HttpEntity entity = response.getEntity();
            msg = EntityUtils.toString(entity, UTF_8);

        } catch (Exception e) {
            log.info("user httpPost request error", e);
            throw new BusinessException(500, e.getMessage());
        } finally {
            try {
                doClose(httpClient, response);
            } catch (IOException e) {
                log.error("httpPost close error", e);
            }
        }
        return msg;
    }

    /**
     * get请求
     *
     * @param url 请求地址
     * @param timeout 超时时间 单位秒
     * @param map map
     * @return string
     */
    public static String httpGet(String url, Integer timeout, Map<String, String> map) {
        String msg = "";
        CloseableHttpClient httpClient = getHttpClient(timeout);
        CloseableHttpResponse response = null;
        try {
            // 声明URIBuilder
            URIBuilder uriBuilder = new URIBuilder(url);
            // 判断参数map是否为非空
            if (map != null) {
                // 遍历参数
                for (Map.Entry<String, String> entry : map.entrySet()) {
                    // 设置参数
                    uriBuilder.setParameter(entry.getKey(), entry.getValue());
                }
            }
            // 2 创建httpGet对象,相当于设置url请求地址
            HttpGet httpGet = new HttpGet(uriBuilder.build());

            // 3 使用HttpClient执行httpGet,相当于按回车,发起请求
            response = httpClient.execute(httpGet);
            // 4 解析结果,封装返回对象httpResult,相当于显示相应的结果
            HttpEntity entity = response.getEntity();
            msg = EntityUtils.toString(entity, UTF_8);
        } catch (Exception e) {
            log.error("user httpGet request error", e);
            throw new BusinessException(500, e.getMessage());
        } finally {
            try {
                doClose(httpClient, response);
            } catch (IOException e) {
                log.error("httpGet close error", e);
            }
        }
        // 返回
        return msg;
    }

    /**
     * 对资源进行关闭
     *
     * @param httpClient httpClient
     * @param response response
     */
    private static void doClose(CloseableHttpClient httpClient, CloseableHttpResponse response) throws IOException {
        if (response != null) {
            response.close();
        }
        if (httpClient != null) {
            httpClient.close();
        }
    }
}
package com.yum.ec3.configplatform.management.util;

import java.io.IOException;
import java.net.SocketTimeoutException;

import com.yum.ec3.platform.common.util.LogUtil;
import org.apache.http.NoHttpResponseException;
import org.apache.http.client.HttpRequestRetryHandler;
import org.apache.http.conn.ConnectTimeoutException;
import org.apache.http.protocol.HttpContext;

/**
 * httpclient请求重试策略 httpclient请求重试策略配置类
 *
 * @see [相关类/方法]
 * @since [产品/模块版本]
 */

public class HttpRequestRetryHandlerImpl implements HttpRequestRetryHandler {

    private LogUtil log = LogUtil.getLogger(HttpRequestRetryHandlerImpl.class);

    private static final Integer COUNT = 2;

    /**
     * 实现HttpRequestRetryHandler {@inheritDoc}
     */
    @Override
    public boolean retryRequest(IOException exception, int i, HttpContext httpContext) {
        // 重试次数最大为3次
        if (i >= COUNT) {
            return false;
        }
        // 没有响应,重试
        if (exception instanceof NoHttpResponseException || exception instanceof ConnectTimeoutException
                || exception instanceof SocketTimeoutException) {
            return true;
            // 连接超时,重试
        }  else {
            // 自己新增log
            log.error("HttpRequestRetryHandlerImpl", exception);
            return false;
        }
    }

}


版权声明:本文为qqqqqqhhhhhh原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。