SAAS 系统按租户分库实现

SAAS 按租户分库方案

 

在 SaaS 系统中，按商户（租户）分库存储业务数据是一种比较简单、安全的方案：不同商户的数据分库隔离后，单个请求不存在跨表、跨库访问数据的问题；同时可以根据各商户的单量灵活配置——单量少的多个商户可以共用一个库，单量大的商户可以使用独立集群。

 

第一步

继承 Spring 的 AbstractRoutingDataSource 抽象类，根据当前线程绑定的数据源标识进行路由：

import org.springframework.jdbc.datasource.lookup.AbstractRoutingDataSource;

/**
 * Created by chenwenshun on 2018/12/12.
 *
 * DataSource that chooses its target for each connection request from the routing key
 * currently bound to the calling thread in RoutingDataSourceContextHolder.
 * When no key is bound (null), AbstractRoutingDataSource falls back to its configured
 * default target data source.
 */
public class RoutingDataSource extends AbstractRoutingDataSource {

    /**
     * Supplies the lookup key used to select the target data source.
     *
     * @return the key bound to the current thread, or null if none has been set
     */
    @Override
    protected Object determineCurrentLookupKey() {
        final Object currentKey = RoutingDataSourceContextHolder.get();
        return currentKey;
    }
}

用 ThreadLocal 实现一个线程上下文，保存当前请求所使用的数据源标识：

/**
 * Created by chenwenshun on 2018/12/12.
 *
 * Per-thread holder for the data source key of the work currently running on a thread.
 * RoutingDataSource reads the key through get(). Callers should call clear() when the
 * work is finished, so that a reused thread does not keep a stale key.
 */
public class RoutingDataSourceContextHolder {

    /** Data source key bound to the current thread; null until set() is called. */
    private static final ThreadLocal<DataSourceEnum> CURRENT_KEY = new ThreadLocal<>();

    /**
     * Binds a data source key to the calling thread, replacing any previous binding.
     *
     * @param key the data source to route to
     */
    public static void set(DataSourceEnum key) {
        CURRENT_KEY.set(key);
    }

    /**
     * Returns the key bound to the calling thread.
     *
     * @return the bound key, or null if none is bound
     */
    public static DataSourceEnum get() {
        final DataSourceEnum bound = CURRENT_KEY.get();
        return bound;
    }

    /** Removes the calling thread's binding entirely. */
    public static void clear() {
        CURRENT_KEY.remove();
    }
}

第二步

在 application.yml 中为每个库分别配置一个数据源：

# Connection pool for the first tenant database. DataSourceConfig binds it with
# @ConfigurationProperties(prefix = "datasource.druid") and uses it as the default target.
datasource:
  druid:
    # {ip} is a placeholder for the MySQL host; this pool points at schema "dict".
    url: jdbc:mysql://{ip}:6630/dict?useUnicode=true&characterEncoding=utf8
    # NOTE(review): DruidDataSource exposes a "driverClassName" property; confirm that
    # "driver-class" is actually bound (Druid can also infer the driver from the URL).
    # com.mysql.jdbc.Driver is the legacy Connector/J 5.x class name.
    driver-class: com.mysql.jdbc.Driver
    username: xxxxxx
    password: xxxxxx
    # Pool sizing.
    initial-size: 1
    min-idle: 1
    max-active: 20
    # Validate each connection when it is borrowed from the pool.
    test-on-borrow: true
    # Enable Druid's "stat" filter.
    filters: stat

# Connection pool for the second tenant database (schema "dict2"), bound with
# @ConfigurationProperties(prefix = "datasource2.druid"). Same settings as above.
datasource2:
  druid:
    url: jdbc:mysql://{ip}:6630/dict2?useUnicode=true&characterEncoding=utf8
    driver-class: com.mysql.jdbc.Driver
    username: xxxxxx
    password: xxxxxx
    initial-size: 1
    min-idle: 1
    max-active: 20
    test-on-borrow: true
    filters: stat

对应的 Java 配置（每个库一个 Druid 连接池，再用 RoutingDataSource 包装成主数据源）：

    

/**
 * Created by chenwenshun on 2018/8/31.
 */
@Configuration
public class DataSourceConfig {

    @Bean
    @ConfigurationProperties(prefix = "datasource.druid")
    public DataSource dataSource_0(){
        return DataSourceBuilder.create()
                .type(DruidDataSource.class)
                .build();
    }

    @Bean
    @ConfigurationProperties(prefix = "datasource2.druid")
    public DataSource dataSource_1(){
        return DataSourceBuilder.create()
                .type(DruidDataSource.class)
                .build();
    }


    @Bean
    @Primary
    public DataSource RoutingDataSource(
            @Autowired @Qualifier("dataSource_0") DataSource dataSource_0,
            @Autowired @Qualifier("dataSource_1") DataSource dataSource_1
    ){
        Map<Object, Object> map = new HashMap<>();
        map.put(DataSourceEnum.DS_0, dataSource_0);
        map.put(DataSourceEnum.DS_1, dataSource_1);

        RoutingDataSource routingDataSource = new RoutingDataSource();
        routingDataSource.setTargetDataSources(map);
        routingDataSource.setDefaultTargetDataSource(dataSource_0);
        return routingDataSource;
    }

第三步

实现商户与数据源的映射逻辑，接口定义大致如下：


/**
 * Created by chenwenshun on 2018/12/14.
 *
 * Maps a tenant (merchant) to the data source that holds its data.
 * Each project supplies its own implementation, for example backed by a database table
 * or by Apollo configuration.
 */
public interface DataSourceMapping {

    /**
     * Returns the data source that serves the given tenant. An implementation may either
     * look the tenant up in explicit configuration, or compute the target with its own
     * routing algorithm over {@code availableDataSources}.
     *
     * @param availableDataSources names of the data sources that may be chosen
     * @param shareValue business key used for routing, e.g. the merchant id
     * @return identifier of the selected data source
     */
    String getDataSource(List<String> availableDataSources, String shareValue);
}

第四步

通过切面拦截所有 controller 请求，从请求参数中获取商户 ID，然后根据商户 ID 调用 DataSourceMapping.getDataSource 获取该商户对应的数据源，并设置到 RoutingDataSourceContextHolder 线程变量中；请求结束后再清除该线程变量：

/**
 * Created by chenwenshun on 2018/12/12.
 */
@Aspect
@Component
public class RoutingDataSourceAspect {


    @Autowired
    private DataSourceMapping dataSourceMapping;

    @Pointcut("execution(public * com.freemud.springbootdemo.controller.*.*(..))")
    public void point() {
    }




    @Before("point()")
    public void doBefore(JoinPoint joinPoint) throws ClassNotFoundException, NotFoundException {
        String classType = joinPoint.getTarget().getClass().getName();
//        Class<?> clazz = Class.forName(classType);
//        String clazzName = clazz.getName();
        String methodName = joinPoint.getSignature().getName(); //获取方法名称
        Object[] args = joinPoint.getArgs();//参数
        Map<String,Object> nameAndArgs = this.getFieldsName(this.getClass(),classType,methodName,args);

        String companyId = null;

        if (nameAndArgs.containsKey("companyId")){
            companyId = (String)nameAndArgs.get("companyId");
        }else if (nameAndArgs.containsKey("requestBody")){

             BaseRequest request = (BaseRequest)nameAndArgs.get("requestBody");
             companyId = request.getCompanyId();

        }else {
            BaseRequest request = (BaseRequest)Lists.newArrayList(nameAndArgs.values()).get(0);
            companyId = request.getCompanyId();
        }

        if (StringUtils.isBlank(companyId)){
            throw new  UnsupportedOperationException("companyId can not be null!");
        }

        String ds = DataSourceEnum.DS_0.name();
        if (dataSourceMapping != null){
            List<String> dataSourceList = Lists.newArrayList();
            for ( DataSourceEnum dataSourceEnum :DataSourceEnum.values() ) {
                dataSourceList.add( dataSourceEnum.name() );
            }

            ds = dataSourceMapping.getDataSource(dataSourceList , companyId) ;
        }

        RoutingDataSourceContextHolder.set(DataSourceEnum.valueOf(ds));
    }


    @After("point()")
    public void doAfter(){
        RoutingDataSourceContextHolder.clear();
    }


    private Map<String,Object> getFieldsName(Class cls, String clazzName, String methodName, Object[] args) throws NotFoundException {
        Map<String,Object > map=new LinkedHashMap<>();
        ClassPool pool = ClassPool.getDefault();
        ClassClassPath classPath = new ClassClassPath(cls);
        pool.insertClassPath(classPath);

        CtClass cc = pool.get(clazzName);
        CtMethod cm = cc.getDeclaredMethod(methodName);
        MethodInfo methodInfo = cm.getMethodInfo();
        CodeAttribute codeAttribute = methodInfo.getCodeAttribute();
        LocalVariableAttribute attr = (LocalVariableAttribute) codeAttribute.getAttribute(LocalVariableAttribute.tag);
        int pos = Modifier.isStatic(cm.getModifiers()) ? 0 : 1;
        for (int i = 0; i < cm.getParameterTypes().length; i++){
            map.put( attr.variableName(i + pos),args[i]);//paramNames即参数名
        }
        return map;
    }


}

over ^0^