基于Python的SM4ECB及CBC加密


最近在学习SM4算法,通过代码实现可以更好地理解算法的具体运算过程。

1.算法简述

SM4分为加解密算法和密钥拓展算法,简单地说就是将128比特数据分为四组,使用轮函数对其进行运算,密钥拓展算法用于生成轮密钥,当使用密钥拓展算法时,轮函数输入为(四组输入数据,固定参数CK),当使用加密算法时轮函数输入为(四组输入数据,轮密钥rk),解密算法与加密算法相同只是使用的轮密钥顺序相反。

在本文中,明文和密钥及IV的输入都为16进制数,如果需要加密字符串类型的明文,需要先将其转为16进制,但在本文中仅针对明文为128bit时进行了实现,如果小于或大于128bit需要进行相应的填充。

(1)异或:

异或可以表示为xor,在python中可以通过^实现,a与b的异或运算可以理解为(假设a和b为16进制),需注意python默认值是字符串比如,a = “01234567”,需要进行转码后进行运算。

a = int('01234567',16)
b = int('30',16)
print(bin(a^b))
#a的二进制值为0b1001000110100010101100111
                   #b的二进制值为0b110000
#输出的结果为:0b1001000110100010101010111

其实就是,首先将进行异或的数转为2进制,从低位开始比较,两者相同则该位改为0,两者不同该位改为1,如果a>b,那么将比较至b的最高位,剩余的未比较的位数值保持不变

(2)轮函数:

轮函数结构GB-T 32907中对轮函数的描述很好理解,其中T为合成置换,包括一次S盒置换后将输出作为线性变换L的输入。

(3)S盒置换:

SBOX = ['d6', '90', 'e9', 'fe', 'cc', 'e1', '3d', 'b7', '16', 'b6', '14', 'c2', '28', 'fb', '2c', '05',
       '2b', '67', '9a', '76', '2a', 'be', '04', 'c3', 'aa', '44', '13', '26', '49', '86', '06', '99',
       '9c', '42', '50', 'f4', '91', 'ef', '98', '7a', '33', '54', '0b', '43', 'ed', 'cf', 'ac', '62',
       'e4', 'b3', '1c', 'a9', 'c9', '08', 'e8', '95', '80', 'df', '94', 'fa', '75', '8f', '3f', 'a6',
       '47', '07', 'a7', 'fc', 'f3', '73', '17', 'ba', '83', '59', '3c', '19', 'e6', '85', '4f', 'a8',
       '68', '6b', '81', 'b2', '71', '64', 'da', '8b', 'f8', 'eb', '0f', '4b', '70', '56', '9d', '35',
       '1e', '24', '0e', '5e', '63', '58', 'd1', 'a2', '25', '22', '7c', '3b', '01', '21', '78', '87',
       'd4', '00', '46', '57', '9f', 'd3', '27', '52', '4c', '36', '02', 'e7', 'a0', 'c4', 'c8', '9e',
       'ea', 'bf', '8a', 'd2', '40', 'c7', '38', 'b5', 'a3', 'f7', 'f2', 'ce', 'f9', '61', '15', 'a1',
       'e0', 'ae', '5d', 'a4', '9b', '34', '1a', '55', 'ad', '93', '32', '30', 'f5', '8c', 'b1', 'e3',
       '1d', 'f6', 'e2', '2e', '82', '66', 'ca', '60', 'c0', '29', '23', 'ab', '0d', '53', '4e', '6f',
       'd5', 'db', '37', '45', 'de', 'fd', '8e', '2f', '03', 'ff', '6a', '72', '6d', '6c', '5b', '51',
       '8d', '1b', 'af', '92', 'bb', 'dd', 'bc', '7f', '11', 'd9', '5c', '41', '1f', '10', '5a', 'd8',
       '0a', 'c1', '31', '88', 'a5', 'cd', '7b', 'bd', '2d', '74', 'd0', '12', 'b8', 'e5', 'b4', 'b0',
       '89', '69', '97', '4a', '0c', '96', '77', '7e', '65', 'b9', 'f1', '09', 'c5', '6e', 'c6', '84',
       '18', 'f0', '7d', 'ec', '3a', 'dc', '4d', '20', '79', 'ee', '5f', '3e', 'd7', 'cb', '39', '48',]

GB-T 32907中给出的S盒数据为16进制,不过我们将在进行S盒置换时将数据转为10进制进行运算,所以不需要使用0X形式描写数据。

其中如果使用Oxd6形式描写数据,在转为10进制时直接使用int(0xd6)而不需添加进制参数,而使用“d6”描写数据时需要使用int(“d6”,16)。

(4)线性变换L:

线性变换
其中B<<<2 表示为对B进行32位左移2,可以理解位将高位的两位移动到低位,输入和输出都为32位。

2.密钥拓展算法

在一次针对128bit明文的加密过程中,密钥拓展算法将生成32个轮密钥。而密钥拓展算法的输入为128bit的密钥。

首先将16字节也就是128bit的密钥分为一个每4字节为一组的列表,这里可以使用如下方式进行分组,其中n为分组位数。

def group(list, n):
    for i in range(0, len(list), n):
        yield list[i:i + n]

使用方法为:(本文中明文和密钥的输入都是用标准中给出的示例,即’0123456789abcdeffedcba9876543210’)

MK = []
    for i in group('0123456789abcdeffedcba9876543210',8)
        MK.append(i)
>>>['01234567', '89abcdef', 'fedcba98', '76543210']

对MK和系统参数FK针对每个元素进行异或,这里使用了标准中给出的系统参数:

FK = ['a3b1bac6', '56aa3350', '677d9197', 'b27022dc']
def xor(a,b):#异或运算函数,返回8位16进制数
    a1 = int(a,16)
    b1 = int(b,16)
    if a == b:#这里出现的原因是,如果IV和明文完全相同时,异或结果为十进制0,但在运算中需要使用16字节数据,所以使用'{:032x}'.format()方法对数据进行转码
        A = '{:032x}'.format(int(a1^b1))
    else:
        A = '{:08x}'.format(int(a1^b1))
    return A
MK = []
    for i in group(key,8):
        MK.append(i)
    key0 = xor(MK[0],FK[0])
    key1 = xor(MK[1],FK[1])
    key2 = xor(MK[2],FK[2])
    key3 = xor(MK[3],FK[3])
    keylist = [key0,key1,key2,key3]

然后使用轮函数对密钥进行运算,迭代次数为32:

rk = []
    for i in range(32):
        a = round_function(keylist[i],keylist[i+1],keylist[i+2],keylist[i+3],CK[i],mod='extend')
        keylist.append(a)
        rk.append(a)

round_function为自定义的轮函数,将在下文中详细描述。

3.轮函数实现

文中的mod用于控制加密或解密及密钥拓展算法,加密模式为enc,解密模式为dec,密钥拓展算法的mod文中已进行了指定,使用时仅需调整enc或dec。

def round_function(k0,k1,k2,k3,rk,mod):
    k = xor(xor(xor(k1,k2),k3),rk)
    Tr = T(k,mod)
    rki = xor(k0,Tr)
    return rki

(1)T合成置换:

将S盒置换的结果作为线性变换的输入。

def T(A,mod):
    T = linear(S(A),mod)
    return T

(2)S盒置换实现:

S盒置换时需要注意,S盒中部分值高位为0,这里使用’{:02X}’.format()方法保留0,否则将在线性变换时由于位数不足报错,自己写的时候需要特别注意这一点,因为在python中如果十六进制数为03,某些转换进制的函数输出时将自动省略最高位的0,造成输出位数不足。

def S(A):
    A1 = []
    A2 = [0,0,0,0]
    for i in group(A,2):
        A1.append(i)
    for i in range(4):
        l = int(A1[i],16)
        A2[i] = '{:02x}'.format(int(SBOX[l],16)) 
    A2 = ''.join(A2)
    return A2

S盒置换的方法是将4字节输入数据按1字节分为4组,将每个元素转为十进制后提取其对应的S盒元素值,如1字节的30对应十进制值为48,则S盒置换结果为SBOX[48]。

(3)线性变换L实现:

线性变换的方法是将4字节输入分为8组,将每个元素都转为4位的二进制数,而后合并为一个32位的二进制数,然后进行32位循环左移,最后进行异或运算。
其中密钥拓展算法的线性变换方法和加解密算法的线性变换方法是不同的,如下:
加解密算法
密钥拓展算法
代码如下:

def left(list,n):#左移n位
  return list[n:] + list[:n]
def linear(B,mod):
    B1 = list(B)
    for i in range(8):
        B1[i] = '{:04b}'.format(int(B1[i],16))
    B1 = ''.join(B1)
    B1_2= left(B1,2)
    B1_10 = left(B1,10)
    B1_18 = left(B1,18)
    B1_24 = left(B1,24)
    B1_13 = left(B1,13)
    B1_23 = left(B1,23)
    if mod == 'enc' or mod ==  'dec':#加解密算法
        BX = xor(xor(xor(xor(B1,B1_2),B1_10),B1_18),B1_24)
    elif mod == 'extend':#密钥拓展算法
        BX = xor(xor(B1,B1_13),B1_23)
    else:
        return "模式输入错误"
    #本文的异或函数中,当输入为32位二进制数时,输出也为32位二进制数,所以进行16进制转码
    BX = '%x'%int(BX, 2)
    return BX

4.ECB加解密算法实现:

加解密算法包括32次迭代的轮函数运算,最后进行一次反序变换。
首先将16字节输入按4字节分为4组,将结果和使用密钥拓展算法生成的轮密钥作为输入使用轮函数进行运算,因为前文中已经给出了轮函数的方法,此处直接调用即可。

def get_sm4_ecb(key,input_data,mod):
    data = []
    rk = get_key(key)
    for i in group(input_data,8):
        data.append(i)
    for i in range(32):
        if mod == 'enc':
            ldata = round_function(data[i],data[i+1],data[i+2],data[i+3],rk[i],mod)
        else:
            ldata = round_function(data[i],data[i+1],data[i+2],data[i+3],rk[31-i],mod)
        data.append(ldata)
    out_data = [data[35],data[34],data[33],data[32]]#反序变换
    out_data = ''.join(out_data)
    return out_data

5.CBC加解密算法实现:

CBC加密和ECB加密的区别在于引入了初始化向量IV,使用方法是再进行运算时先使用IV和明文进行异或运算,然后将结果作为新的明文输入,为了保证安全IV应为随机数,此处仅为示例故未使用随机数,如果需要使用随机数可以使用random模块生成,方法为:

import random
str = ''
a=str.join(random.choice("0123456789abcdef") for i in range(32))
print(a)

此处需要注意的是,IV应为16字节,也就是与明文长度一致,如果长度不足,应通过补0的方式补足16字节,原因是当IV长度大于明文时,异或的结果跟IV长度一致,将会导致后续的分组函数报错。

def get_sm4_cbc(key,input_data,iv,mod):
    rk = get_key(key)
    if mod == 'enc':
        input_data = xor(input_data,iv)
        data = []
        for i in group(input_data,8):
            data.append(i)
        for i in range(32):
            ldata = round_function(data[i],data[i+1],data[i+2],data[i+3],rk[i],mod)
            data.append(ldata)
        out_data = [data[35],data[34],data[33],data[32]]
        out_data = ''.join(out_data)
    else:
        data = []
        for i in group(input_data,8):
            data.append(i)
        for i in range(32):
            ldata = round_function(data[i],data[i+1],data[i+2],data[i+3],rk[31-i],mod)
            data.append(ldata)
        out_data = [data[35],data[34],data[33],data[32]]
        out_data = ''.join(out_data)
        out_data = xor(out_data,iv)
    return out_data

CBC解密的方法是将密文输入进行解密运算,将最终结果和IV进行异或运算,输出即为原始明文。

6.完整代码:

在这里增加了一些数据循环处理以及填充的方法以适应需要进行多组分组的数据和长度不足的密钥。

SBOX = ['d6', '90', 'e9', 'fe', 'cc', 'e1', '3d', 'b7', '16', 'b6', '14', 'c2', '28', 'fb', '2c', '05',
        '2b', '67', '9a', '76', '2a', 'be', '04', 'c3', 'aa', '44', '13', '26', '49', '86', '06', '99',
        '9c', '42', '50', 'f4', '91', 'ef', '98', '7a', '33', '54', '0b', '43', 'ed', 'cf', 'ac', '62',
        'e4', 'b3', '1c', 'a9', 'c9', '08', 'e8', '95', '80', 'df', '94', 'fa', '75', '8f', '3f', 'a6',
        '47', '07', 'a7', 'fc', 'f3', '73', '17', 'ba', '83', '59', '3c', '19', 'e6', '85', '4f', 'a8',
        '68', '6b', '81', 'b2', '71', '64', 'da', '8b', 'f8', 'eb', '0f', '4b', '70', '56', '9d', '35',
        '1e', '24', '0e', '5e', '63', '58', 'd1', 'a2', '25', '22', '7c', '3b', '01', '21', '78', '87',
        'd4', '00', '46', '57', '9f', 'd3', '27', '52', '4c', '36', '02', 'e7', 'a0', 'c4', 'c8', '9e',
        'ea', 'bf', '8a', 'd2', '40', 'c7', '38', 'b5', 'a3', 'f7', 'f2', 'ce', 'f9', '61', '15', 'a1',
        'e0', 'ae', '5d', 'a4', '9b', '34', '1a', '55', 'ad', '93', '32', '30', 'f5', '8c', 'b1', 'e3',
        '1d', 'f6', 'e2', '2e', '82', '66', 'ca', '60', 'c0', '29', '23', 'ab', '0d', '53', '4e', '6f',
        'd5', 'db', '37', '45', 'de', 'fd', '8e', '2f', '03', 'ff', '6a', '72', '6d', '6c', '5b', '51',
        '8d', '1b', 'af', '92', 'bb', 'dd', 'bc', '7f', '11', 'd9', '5c', '41', '1f', '10', '5a', 'd8',
        '0a', 'c1', '31', '88', 'a5', 'cd', '7b', 'bd', '2d', '74', 'd0', '12', 'b8', 'e5', 'b4', 'b0',
        '89', '69', '97', '4a', '0c', '96', '77', '7e', '65', 'b9', 'f1', '09', 'c5', '6e', 'c6', '84',
        '18', 'f0', '7d', 'ec', '3a', 'dc', '4d', '20', '79', 'ee', '5f', '3e', 'd7', 'cb', '39', '48',]
FK = ['a3b1bac6', '56aa3350', '677d9197', 'b27022dc']
CK = ['00070e15', '1c232a31', '383f464d', '545b6269',
      '70777e85', '8c939aa1', 'a8afb6bd', 'c4cbd2d9',
      'e0e7eef5', 'fc030a11', '181f262d', '343b4249',
      '50575e65', '6c737a81', '888f969d', 'a4abb2b9',
      'c0c7ced5', 'dce3eaf1', 'f8ff060d', '141b2229',
      '30373e45', '4c535a61', '686f767d', '848b9299',
      'a0a7aeb5', 'bcc3cad1', 'd8dfe6ed', 'f4fb0209',
      '10171e25', '2c333a41', '484f565d', '646b7279']

class sm4(object):
    def __init__(self,mod,key,input_data,iv=None,padding=None):
        self.mod = mod
        if len(key) < 32:
            self.key = str(key).ljust(32,'0')
        else:self.key = key
        self.iv = iv
        self.input_data = input_data
        self.padding = padding
        self.input_data = self.p7p(self.input_data) if self.padding else input_data
    def left(self,list,n):
        return list[n:] + list[:n]
    def group(self,list, n):
        for i in range(0, len(list), n):
            yield list[i:i + n]
    def p7p(self,text):
        l = []
        [l.append(i) for i in self.group(text,2)]
        return text+"{:02x}".format(16-len(l)%16)*(16-len(l)%16)
    def xor(self,a,b):
        a1 = int(a,16)
        b1 = int(b,16)
        if  a == b:#len(a)>8 or len(b)>8: #a == b:
            A = '{:032x}'.format(int(a1^b1))
        else:
            A = '{:08x}'.format(int(a1^b1))
        return A
    def round_function(self,k0,k1,k2,k3,rk,mod):
        k = self.xor(self.xor(self.xor(k1,k2),k3),rk)
        Tr = self.T(k,mod)
        rki = self.xor(k0,Tr)
        return rki
    def T(self,A,mod):
        T = self.linear(self.S(A),mod)
        return T
    def S(self,A):
        A1 = []
        A2 = [0,0,0,0]
        for i in self.group(A,2):
            A1.append(i)
        for i in range(4):
            l = int(A1[i],16)
            A2[i] = '{:02x}'.format(int(SBOX[l],16))
        A2 = ''.join(A2)
        return A2
    def linear(self,B,mod):
        B1 = list(B)
        for i in range(8):
            B1[i] = '{:04b}'.format(int(B1[i],16))
        B1 = ''.join(B1)
        B1_2= self.left(B1,2)
        B1_10 = self.left(B1,10)
        B1_18 = self.left(B1,18)
        B1_24 = self.left(B1,24)
        B1_13 = self.left(B1,13)
        B1_23 = self.left(B1,23)
        if mod == 'enc' or mod ==  'dec':
            BX = self.xor(self.xor(self.xor(self.xor(B1,B1_2),B1_10),B1_18),B1_24)
        elif mod == 'extend':
            BX = self.xor(self.xor(B1,B1_13),B1_23)
        else:
            return "模式输入错误"
        BX = '%08x'%int(BX, 2)
        return BX
    def get_key(self,key):
        MK = []
        for i in self.group(key,8):
            MK.append(i)
        key0 = self.xor(MK[0],FK[0])
        key1 = self.xor(MK[1],FK[1])
        key2 = self.xor(MK[2],FK[2])
        key3 = self.xor(MK[3],FK[3])
        keylist = [key0,key1,key2,key3]
        rk = []
        for i in range(32):
            a = self.round_function(keylist[i],keylist[i+1],keylist[i+2],keylist[i+3],CK[i],mod='extend')
            keylist.append(a)
            rk.append(a)
        return rk
    def get_sm4_ecb(self):
        rk = self.get_key(self.key)
        l = []
        [l.append(i) for i in self.group(self.input_data, 32)]
        def fillecb():
            for i in range(len(l)):
                l[i] = ecb(l[i])
            out_data = ''.join(l)
            return out_data
        def ecb(input_data):
            data = []
            for i in self.group(input_data,8):
                data.append(i)
            for i in range(32):
                if self.mod == 'enc':
                    ldata = self.round_function(data[i],data[i+1],data[i+2],data[i+3],rk[i],self.mod)
                else:
                    ldata = self.round_function(data[i],data[i+1],data[i+2],data[i+3],rk[31-i],self.mod)
                data.append(ldata)
            out_data = [data[35],data[34],data[33],data[32]]
            out_data = ''.join(out_data)
            return out_data
        return fillecb()
    def get_sm4_cbc(self):
        rk = self.get_key(self.key)
        l = []
        l1 = []
        for i in self.group(self.input_data, 32):
            l.append(i)
            l1.append(i)
        def fillcbc():
            if self.mod == 'enc':
                for i in range(len(l)):
                    if i == 0:
                        l[i] = '{:032x}'.format(int(self.xor(l[i], self.iv), 16))
                    l[i] = cbc(l[i])
                    if len(l)-i > 1:
                        l[i+1] = '{:032x}'.format(int(self.xor(l[i],l[i+1]),16))
                out_data = ''.join(l)
            else:
                for i in range(len(l)):
                    if len(l) - i > 1:
                        l[i+1] = cbc(l[i+1])
                        l[i+1] = '{:032x}'.format(int(self.xor(l1[i],l[i+1]),16))
                    if i == 0:
                        l[i] = cbc(l[i])
                        l[i] = '{:032x}'.format(int(self.xor(l[i], self.iv),16))
                out_data = ''.join(l)
            return out_data
        def cbc(input_data):
            if self.mod == 'enc':
                data = []
                for i in self.group(input_data,8):
                    data.append(i)
                for i in range(32):
                    ldata = self.round_function(data[i],data[i+1],data[i+2],data[i+3],rk[i],self.mod)
                    data.append(ldata)
                out_data = [data[35],data[34],data[33],data[32]]
                out_data = ''.join(out_data)
            else:
                data = []
                for i in self.group(input_data,8):
                    data.append(i)
                for i in range(32):
                    ldata = self.round_function(data[i],data[i+1],data[i+2],data[i+3],rk[31-i],self.mod)
                    data.append(ldata)
                out_data = [data[35],data[34],data[33],data[32]]
                out_data = ''.join(out_data)
                out_data = '{:032x}'.format(int(out_data, 16))
            return out_data
        return fillcbc()

iv='0123456789abcdeffedcba9876543210'
plain = '681edf34d206965e86b3e94f536e4246'
key='0123456789abcdeffedcba9876543210'
en_data = sm4(mod='enc',key=key,input_data=plain,iv=iv,padding=None).get_sm4_cbc()
plain1 = sm4(mod='dec',key=key,input_data=en_data,iv=iv,padding=None).get_sm4_cbc()
print('en_data:%s'%en_data)
print('plain1:%s'%plain1)





版权声明:本文为weixin_46491183原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。