123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396 |
#include "sm4.hpp"
#include <cstdio>
#include <cstdlib>
#include <sstream>
#include <stdexcept>
#include <string>
- std::string sm4::base64_encode(const std::string& text)
- {
- const char* bytes_to_encode = text.c_str();
- int in_len = text.size();
- std::string ret;
- int i = 0;
- int j = 0;
- unsigned char char_array_3[3];
- unsigned char char_array_4[4];
- while (in_len--)
- {
- char_array_3[i++] = *(bytes_to_encode++);
- if (i == 3)
- {
- char_array_4[0] = (char_array_3[0] & 0xfc) >> 2;
- char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4);
- char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6);
- char_array_4[3] = char_array_3[2] & 0x3f;
- for (i = 0; (i < 4); i++)
- ret += base64_chars[char_array_4[i]];
- i = 0;
- }
- }
- if (i)
- {
- for (j = i; j < 3; j++)
- char_array_3[j] = '\0';
- char_array_4[0] = (char_array_3[0] & 0xfc) >> 2;
- char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4);
- char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6);
- char_array_4[3] = char_array_3[2] & 0x3f;
- for (j = 0; (j < i + 1); j++)
- ret += base64_chars[char_array_4[j]];
- while ((i++ < 3))
- ret += '=';
- }
- return ret;
- }
- std::string sm4::base64_decode(const std::string& encoded_string)
- {
- int in_len = encoded_string.size();
- int i = 0;
- int j = 0;
- int in_ = 0;
- unsigned char char_array_4[4], char_array_3[3];
- std::string ret;
- while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_]))
- {
- char_array_4[i++] = encoded_string[in_];
- in_++;
- if (i == 4)
- {
- for (i = 0; i < 4; i++)
- char_array_4[i] = base64_chars.find(char_array_4[i]);
- char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
- char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
- char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
- for (i = 0; (i < 3); i++)
- ret += char_array_3[i];
- i = 0;
- }
- }
- if (i)
- {
- for (j = i; j < 4; j++)
- char_array_4[j] = 0;
- for (j = 0; j < 4; j++)
- char_array_4[j] = base64_chars.find(char_array_4[j]);
- char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
- char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
- char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
- for (j = 0; (j < i - 1); j++)
- ret += char_array_3[j];
- }
- return ret;
- }
- std::string sm4::HexToStr(std::string str)
- {
- std::string hex = str;
- long len = hex.length();
- std::string newString;
- for (long i = 0; i < len; i += 2)
- {
- std::string byte = hex.substr(i, 2);
- char chr = (char)(int)strtol(byte.c_str(), NULL, 16);
- newString.push_back(chr);
- }
- return newString;
- }
- std::string sm4::StrToHex(std::string str)
- {
- unsigned char c;
- char buf[3];
- std::string result = "";
- std::stringstream ss;
- ss << str;
- while (ss.read((char*)(&c), sizeof(c)))
- {
- snprintf(buf,3, "%02X", c);
- result += buf;
- }
- return result;
- }
- std::string sm4::PKCS7(std::string str)
- {
- if (str.size() < 16)
- {
- char ch = 16 - str.size();
- str.append((size_t)ch, ch);
- }
- return str;
- }
- std::string sm4::BinToHex(std::string str)
- { // 二进制转换为十六进制的函数实现
- std::string hex = "";
- int temp = 0;
- while (str.size() % 4 != 0)
- {
- str = "0" + str;
- }
- for (int i = 0; i < str.size(); i += 4)
- {
- temp = (str[i] - '0') * 8 + (str[i + 1] - '0') * 4 + (str[i + 2] - '0') * 2 + (str[i + 3] - '0') * 1;
- if (temp < 10)
- {
- hex += std::to_string(temp);
- }
- else
- {
- hex += 'A' + (temp - 10);
- }
- }
- return hex;
- }
- std::string sm4::HexToBin(std::string str)
- { // 十六进制转换为二进制的函数实现
- std::string bin = "";
- std::string table[16] = { "0000", "0001", "0010", "0011", "0100", "0101", "0110", "0111", "1000", "1001", "1010", "1011", "1100", "1101", "1110", "1111" };
- for (int i = 0; i < str.size(); i++)
- {
- if (str[i] >= 'A' && str[i] <= 'F')
- {
- bin += table[str[i] - 'A' + 10];
- }
- else if (str[i] >= 'a' && str[i] <= 'f')
- {
- bin += table[str[i] - 'a' + 10];
- }
- else
- {
- bin += table[str[i] - '0'];
- }
- }
- return bin;
- }
- int sm4::HexToDec(char str)
- { // 十六进制转换为十进制的函数实现
- int dec = 0;
- if (str >= 'A' && str <= 'F')
- {
- dec += (str - 'A' + 10);
- }
- else if (str >= 'a' && str <= 'f')
- {
- dec += (str - 'a' + 10);
- }
- else
- {
- dec += (str - '0');
- }
- return dec;
- }
- std::string sm4::LeftShift(std::string str, int len)
- { // 循环左移len位函数实现
- std::string res = HexToBin(str);
- res = res.substr(len) + res.substr(0, len);
- return BinToHex(res);
- }
- std::string sm4::XOR(std::string str1, std::string str2)
- { // 异或函数实现
- std::string res1 = HexToBin(str1);
- std::string res2 = HexToBin(str2);
- std::string res = "";
- for (int i = 0; i < res1.size(); i++)
- {
- if (res1[i] == res2[i])
- {
- res += "0";
- }
- else
- {
- res += "1";
- }
- }
- return BinToHex(res);
- }
- std::string sm4::NLTransform(std::string str)
- { // 非线性变换t函数实现
- std::string Sbox[16][16] = { {"D6", "90", "E9", "FE", "CC", "E1", "3D", "B7", "16", "B6", "14", "C2", "28", "FB", "2C", "05"},
- {"2B", "67", "9A", "76", "2A", "BE", "04", "C3", "AA", "44", "13", "26", "49", "86", "06", "99"},
- {"9C", "42", "50", "F4", "91", "EF", "98", "7A", "33", "54", "0B", "43", "ED", "CF", "AC", "62"},
- {"E4", "B3", "1C", "A9", "C9", "08", "E8", "95", "80", "DF", "94", "FA", "75", "8F", "3F", "A6"},
- {"47", "07", "A7", "FC", "F3", "73", "17", "BA", "83", "59", "3C", "19", "E6", "85", "4F", "A8"},
- {"68", "6B", "81", "B2", "71", "64", "DA", "8B", "F8", "EB", "0F", "4B", "70", "56", "9D", "35"},
- {"1E", "24", "0E", "5E", "63", "58", "D1", "A2", "25", "22", "7C", "3B", "01", "21", "78", "87"},
- {"D4", "00", "46", "57", "9F", "D3", "27", "52", "4C", "36", "02", "E7", "A0", "C4", "C8", "9E"},
- {"EA", "BF", "8A", "D2", "40", "C7", "38", "B5", "A3", "F7", "F2", "CE", "F9", "61", "15", "A1"},
- {"E0", "AE", "5D", "A4", "9B", "34", "1A", "55", "AD", "93", "32", "30", "F5", "8C", "B1", "E3"},
- {"1D", "F6", "E2", "2E", "82", "66", "CA", "60", "C0", "29", "23", "AB", "0D", "53", "4E", "6F"},
- {"D5", "DB", "37", "45", "DE", "FD", "8E", "2F", "03", "FF", "6A", "72", "6D", "6C", "5B", "51"},
- {"8D", "1B", "AF", "92", "BB", "DD", "BC", "7F", "11", "D9", "5C", "41", "1F", "10", "5A", "D8"},
- {"0A", "C1", "31", "88", "A5", "CD", "7B", "BD", "2D", "74", "D0", "12", "B8", "E5", "B4", "B0"},
- {"89", "69", "97", "4A", "0C", "96", "77", "7E", "65", "B9", "F1", "09", "C5", "6E", "C6", "84"},
- {"18", "F0", "7D", "EC", "3A", "DC", "4D", "20", "79", "EE", "5F", "3E", "D7", "CB", "39", "48"} };
- std::string res = "";
- for (int i = 0; i < 4; i++)
- {
- res = res + Sbox[HexToDec(str[2 * i])][HexToDec(str[2 * i + 1])];
- }
- return res;
- }
- std::string sm4::LTransform(std::string str)
- { // 线性变换L函数实现
- return XOR(XOR(XOR(XOR(str, LeftShift(str, 2)), LeftShift(str, 10)), LeftShift(str, 18)), LeftShift(str, 24));
- }
- std::string sm4::L2Transform(std::string str)
- { // 线性变换L'函数实现
- return XOR(XOR(str, LeftShift(str, 13)), LeftShift(str, 23));
- }
- std::string sm4::T(std::string str)
- { // 用于加解密算法中的合成置换T函数实现
- return LTransform(NLTransform(str));
- }
- std::string sm4::T2(std::string str)
- { // 用于密钥扩展算法中的合成置换T函数实现
- return L2Transform(NLTransform(str));
- }
- std::string sm4::KeyExtension(std::string MK)
- { // 密钥扩展函数实现
- std::string FK[4] = { "A3B1BAC6", "56AA3350", "677D9197", "B27022DC" };
- std::string CK[32] = { "00070E15", "1C232A31", "383F464D", "545B6269",
- "70777E85", "8C939AA1", "A8AFB6BD", "C4CBD2D9",
- "E0E7EEF5", "FC030A11", "181F262D", "343B4249",
- "50575E65", "6C737A81", "888F969D", "A4ABB2B9",
- "C0C7CED5", "DCE3EAF1", "F8FF060D", "141B2229",
- "30373E45", "4C535A61", "686F767D", "848B9299",
- "A0A7AEB5", "BCC3CAD1", "D8DFE6ED", "F4FB0209",
- "10171E25", "2C333A41", "484F565D", "646B7279" };
- std::string K[36] = { XOR(MK.substr(0, 8), FK[0]), XOR(MK.substr(8, 8), FK[1]), XOR(MK.substr(16, 8), FK[2]), XOR(MK.substr(24), FK[3]) };
- std::string rks = "";
- for (int i = 0; i < 32; i++)
- {
- K[i + 4] = XOR(K[i], T2(XOR(XOR(XOR(K[i + 1], K[i + 2]), K[i + 3]), CK[i])));
- rks += K[i + 4];
- }
- return rks;
- }
- std::string sm4::sm4_encode(std::string& hex32, std::string& key)
- { // 加密函数实现
- // cout << "轮密钥与每轮输出状态:" << endl;
- // cout << endl;
- std::string cipher[36] = { hex32.substr(0, 8), hex32.substr(8, 8), hex32.substr(16, 8), hex32.substr(24) };
- std::string rks = KeyExtension(key);
- for (int i = 0; i < 32; i++)
- {
- cipher[i + 4] = XOR(cipher[i], T(XOR(XOR(XOR(cipher[i + 1], cipher[i + 2]), cipher[i + 3]), rks.substr(8 * i, 8))));
- // cout << "rk[" + to_string(i) + "] = " + rks.substr(8 * i, 8) + " X[" + to_string(i) + "] = " + cipher[i + 4] << endl;
- }
- // cout << endl;
- return cipher[35] + cipher[34] + cipher[33] + cipher[32];
- }
- std::string sm4::sm4_decode(std::string& hex32, std::string& key)
- { // 解密函数实现
- // cout << "轮密钥与每轮输出状态:" << endl;
- // cout << endl;
- std::string plain[36] = { hex32.substr(0, 8), hex32.substr(8, 8), hex32.substr(16, 8), hex32.substr(24, 8) };
- std::string rks = KeyExtension(key);
- for (int i = 0; i < 32; i++)
- {
- plain[i + 4] = XOR(plain[i], T(XOR(XOR(XOR(plain[i + 1], plain[i + 2]), plain[i + 3]), rks.substr(8 * (31 - i), 8))));
- // cout << "rk[" + to_string(i) + "] = " + rks.substr(8 * (31 - i), 8) + " X[" + to_string(i) + "] = " + plain[i + 4] << endl;
- }
- // cout << endl;
- return plain[35] + plain[34] + plain[33] + plain[32];
- }
- // 针对字符串加密成十六进制数据,使用ZeroPadding策略,不足16字节的补0
- std::string sm4::sm4encodestrhex(std::string text, std::string key)
- {
- int pos = 0;
- std::string all_cipher;
- while (pos < text.size())
- {
- int hasLen = text.size() - pos;
- std::string hex32 = hasLen > 16 ? text.substr(pos, 16) : text.substr(pos);
- if (hex32.size() < 16)
- {
- hex32 = PKCS7(hex32);
- }
- hex32 = StrToHex(hex32);
- all_cipher += sm4_encode(hex32, key);
- if (hasLen == 16)
- {
- // 正好16字节时,要补充16位
- hex32 = PKCS7("");
- hex32 = StrToHex(hex32);
- all_cipher += sm4_encode(hex32, key);
- }
- pos += 16;
- }
- return all_cipher;
- }
- // 针对字符串加密成十六进制后的解密,加密时使用了ZeroPadding策略,不足16字节的补0
- std::string sm4::sm4decodehexstr(std::string cipher, std::string key)
- {
- int pos = 0;
- std::string all_plain;
- while (pos < cipher.size())
- {
- int hasLen = cipher.size() - pos;
- if (hasLen % 32 != 0)
- {
- break; // 忽略无法解码的部分
- }
- std::string text32 = hasLen > 32 ? cipher.substr(pos, 32) : cipher.substr(pos);
- std::string hex32 = sm4_decode(text32, key);
- std::string one_plain = HexToStr(hex32);
- if (hasLen == 32)
- {
- // 去掉填充的数据(有可能整个数据都是填充的(16位整数倍的字符串加密时)
- int size = one_plain.at(one_plain.size() - 1);
- if (size >= 16)
- {
- one_plain.clear();
- }
- else
- {
- one_plain = one_plain.substr(0, 16 - size);
- }
- }
- all_plain.append(one_plain.c_str()); // 按字符串拼接,就可以去掉后面的0
- pos += 32;
- }
- return all_plain;
- }
- std::string sm4::sm4encodestrbase64(std::string text, std::string key)
- {
- std::string hex = sm4encodestrhex(text, key);
- std::string str = HexToStr(hex);
- return base64_encode(str);
- }
- std::string sm4::sm4decodestrbase64(std::string base64_text, std::string key)
- {
- std::string str = base64_decode(base64_text);
- std::string hex = StrToHex(str);
- return sm4decodehexstr(hex, key);
- }
|