sm4.cpp 11 KB


#include "sm4.hpp"

#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <sstream>
#include <stdexcept>
#include <string>
  3. std::string sm4::base64_encode(const std::string& text)
  4. {
  5. const char* bytes_to_encode = text.c_str();
  6. int in_len = text.size();
  7. std::string ret;
  8. int i = 0;
  9. int j = 0;
  10. unsigned char char_array_3[3];
  11. unsigned char char_array_4[4];
  12. while (in_len--)
  13. {
  14. char_array_3[i++] = *(bytes_to_encode++);
  15. if (i == 3)
  16. {
  17. char_array_4[0] = (char_array_3[0] & 0xfc) >> 2;
  18. char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4);
  19. char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6);
  20. char_array_4[3] = char_array_3[2] & 0x3f;
  21. for (i = 0; (i < 4); i++)
  22. ret += base64_chars[char_array_4[i]];
  23. i = 0;
  24. }
  25. }
  26. if (i)
  27. {
  28. for (j = i; j < 3; j++)
  29. char_array_3[j] = '\0';
  30. char_array_4[0] = (char_array_3[0] & 0xfc) >> 2;
  31. char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4);
  32. char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6);
  33. char_array_4[3] = char_array_3[2] & 0x3f;
  34. for (j = 0; (j < i + 1); j++)
  35. ret += base64_chars[char_array_4[j]];
  36. while ((i++ < 3))
  37. ret += '=';
  38. }
  39. return ret;
  40. }
  41. std::string sm4::base64_decode(const std::string& encoded_string)
  42. {
  43. int in_len = encoded_string.size();
  44. int i = 0;
  45. int j = 0;
  46. int in_ = 0;
  47. unsigned char char_array_4[4], char_array_3[3];
  48. std::string ret;
  49. while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_]))
  50. {
  51. char_array_4[i++] = encoded_string[in_];
  52. in_++;
  53. if (i == 4)
  54. {
  55. for (i = 0; i < 4; i++)
  56. char_array_4[i] = base64_chars.find(char_array_4[i]);
  57. char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
  58. char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
  59. char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
  60. for (i = 0; (i < 3); i++)
  61. ret += char_array_3[i];
  62. i = 0;
  63. }
  64. }
  65. if (i)
  66. {
  67. for (j = i; j < 4; j++)
  68. char_array_4[j] = 0;
  69. for (j = 0; j < 4; j++)
  70. char_array_4[j] = base64_chars.find(char_array_4[j]);
  71. char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
  72. char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
  73. char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
  74. for (j = 0; (j < i - 1); j++)
  75. ret += char_array_3[j];
  76. }
  77. return ret;
  78. }
  79. std::string sm4::HexToStr(std::string str)
  80. {
  81. std::string hex = str;
  82. long len = hex.length();
  83. std::string newString;
  84. for (long i = 0; i < len; i += 2)
  85. {
  86. std::string byte = hex.substr(i, 2);
  87. char chr = (char)(int)strtol(byte.c_str(), NULL, 16);
  88. newString.push_back(chr);
  89. }
  90. return newString;
  91. }
  92. std::string sm4::StrToHex(std::string str)
  93. {
  94. unsigned char c;
  95. char buf[3];
  96. std::string result = "";
  97. std::stringstream ss;
  98. ss << str;
  99. while (ss.read((char*)(&c), sizeof(c)))
  100. {
  101. snprintf(buf,3, "%02X", c);
  102. result += buf;
  103. }
  104. return result;
  105. }
  106. std::string sm4::PKCS7(std::string str)
  107. {
  108. if (str.size() < 16)
  109. {
  110. char ch = 16 - str.size();
  111. str.append((size_t)ch, ch);
  112. }
  113. return str;
  114. }
  115. std::string sm4::BinToHex(std::string str)
  116. { // 二进制转换为十六进制的函数实现
  117. std::string hex = "";
  118. int temp = 0;
  119. while (str.size() % 4 != 0)
  120. {
  121. str = "0" + str;
  122. }
  123. for (int i = 0; i < str.size(); i += 4)
  124. {
  125. temp = (str[i] - '0') * 8 + (str[i + 1] - '0') * 4 + (str[i + 2] - '0') * 2 + (str[i + 3] - '0') * 1;
  126. if (temp < 10)
  127. {
  128. hex += std::to_string(temp);
  129. }
  130. else
  131. {
  132. hex += 'A' + (temp - 10);
  133. }
  134. }
  135. return hex;
  136. }
  137. std::string sm4::HexToBin(std::string str)
  138. { // 十六进制转换为二进制的函数实现
  139. std::string bin = "";
  140. std::string table[16] = { "0000", "0001", "0010", "0011", "0100", "0101", "0110", "0111", "1000", "1001", "1010", "1011", "1100", "1101", "1110", "1111" };
  141. for (int i = 0; i < str.size(); i++)
  142. {
  143. if (str[i] >= 'A' && str[i] <= 'F')
  144. {
  145. bin += table[str[i] - 'A' + 10];
  146. }
  147. else if (str[i] >= 'a' && str[i] <= 'f')
  148. {
  149. bin += table[str[i] - 'a' + 10];
  150. }
  151. else
  152. {
  153. bin += table[str[i] - '0'];
  154. }
  155. }
  156. return bin;
  157. }
  158. int sm4::HexToDec(char str)
  159. { // 十六进制转换为十进制的函数实现
  160. int dec = 0;
  161. if (str >= 'A' && str <= 'F')
  162. {
  163. dec += (str - 'A' + 10);
  164. }
  165. else if (str >= 'a' && str <= 'f')
  166. {
  167. dec += (str - 'a' + 10);
  168. }
  169. else
  170. {
  171. dec += (str - '0');
  172. }
  173. return dec;
  174. }
  175. std::string sm4::LeftShift(std::string str, int len)
  176. { // 循环左移len位函数实现
  177. std::string res = HexToBin(str);
  178. res = res.substr(len) + res.substr(0, len);
  179. return BinToHex(res);
  180. }
  181. std::string sm4::XOR(std::string str1, std::string str2)
  182. { // 异或函数实现
  183. std::string res1 = HexToBin(str1);
  184. std::string res2 = HexToBin(str2);
  185. std::string res = "";
  186. for (int i = 0; i < res1.size(); i++)
  187. {
  188. if (res1[i] == res2[i])
  189. {
  190. res += "0";
  191. }
  192. else
  193. {
  194. res += "1";
  195. }
  196. }
  197. return BinToHex(res);
  198. }
  199. std::string sm4::NLTransform(std::string str)
  200. { // 非线性变换t函数实现
  201. std::string Sbox[16][16] = { {"D6", "90", "E9", "FE", "CC", "E1", "3D", "B7", "16", "B6", "14", "C2", "28", "FB", "2C", "05"},
  202. {"2B", "67", "9A", "76", "2A", "BE", "04", "C3", "AA", "44", "13", "26", "49", "86", "06", "99"},
  203. {"9C", "42", "50", "F4", "91", "EF", "98", "7A", "33", "54", "0B", "43", "ED", "CF", "AC", "62"},
  204. {"E4", "B3", "1C", "A9", "C9", "08", "E8", "95", "80", "DF", "94", "FA", "75", "8F", "3F", "A6"},
  205. {"47", "07", "A7", "FC", "F3", "73", "17", "BA", "83", "59", "3C", "19", "E6", "85", "4F", "A8"},
  206. {"68", "6B", "81", "B2", "71", "64", "DA", "8B", "F8", "EB", "0F", "4B", "70", "56", "9D", "35"},
  207. {"1E", "24", "0E", "5E", "63", "58", "D1", "A2", "25", "22", "7C", "3B", "01", "21", "78", "87"},
  208. {"D4", "00", "46", "57", "9F", "D3", "27", "52", "4C", "36", "02", "E7", "A0", "C4", "C8", "9E"},
  209. {"EA", "BF", "8A", "D2", "40", "C7", "38", "B5", "A3", "F7", "F2", "CE", "F9", "61", "15", "A1"},
  210. {"E0", "AE", "5D", "A4", "9B", "34", "1A", "55", "AD", "93", "32", "30", "F5", "8C", "B1", "E3"},
  211. {"1D", "F6", "E2", "2E", "82", "66", "CA", "60", "C0", "29", "23", "AB", "0D", "53", "4E", "6F"},
  212. {"D5", "DB", "37", "45", "DE", "FD", "8E", "2F", "03", "FF", "6A", "72", "6D", "6C", "5B", "51"},
  213. {"8D", "1B", "AF", "92", "BB", "DD", "BC", "7F", "11", "D9", "5C", "41", "1F", "10", "5A", "D8"},
  214. {"0A", "C1", "31", "88", "A5", "CD", "7B", "BD", "2D", "74", "D0", "12", "B8", "E5", "B4", "B0"},
  215. {"89", "69", "97", "4A", "0C", "96", "77", "7E", "65", "B9", "F1", "09", "C5", "6E", "C6", "84"},
  216. {"18", "F0", "7D", "EC", "3A", "DC", "4D", "20", "79", "EE", "5F", "3E", "D7", "CB", "39", "48"} };
  217. std::string res = "";
  218. for (int i = 0; i < 4; i++)
  219. {
  220. res = res + Sbox[HexToDec(str[2 * i])][HexToDec(str[2 * i + 1])];
  221. }
  222. return res;
  223. }
  224. std::string sm4::LTransform(std::string str)
  225. { // 线性变换L函数实现
  226. return XOR(XOR(XOR(XOR(str, LeftShift(str, 2)), LeftShift(str, 10)), LeftShift(str, 18)), LeftShift(str, 24));
  227. }
  228. std::string sm4::L2Transform(std::string str)
  229. { // 线性变换L'函数实现
  230. return XOR(XOR(str, LeftShift(str, 13)), LeftShift(str, 23));
  231. }
  232. std::string sm4::T(std::string str)
  233. { // 用于加解密算法中的合成置换T函数实现
  234. return LTransform(NLTransform(str));
  235. }
  236. std::string sm4::T2(std::string str)
  237. { // 用于密钥扩展算法中的合成置换T函数实现
  238. return L2Transform(NLTransform(str));
  239. }
  240. std::string sm4::KeyExtension(std::string MK)
  241. { // 密钥扩展函数实现
  242. std::string FK[4] = { "A3B1BAC6", "56AA3350", "677D9197", "B27022DC" };
  243. std::string CK[32] = { "00070E15", "1C232A31", "383F464D", "545B6269",
  244. "70777E85", "8C939AA1", "A8AFB6BD", "C4CBD2D9",
  245. "E0E7EEF5", "FC030A11", "181F262D", "343B4249",
  246. "50575E65", "6C737A81", "888F969D", "A4ABB2B9",
  247. "C0C7CED5", "DCE3EAF1", "F8FF060D", "141B2229",
  248. "30373E45", "4C535A61", "686F767D", "848B9299",
  249. "A0A7AEB5", "BCC3CAD1", "D8DFE6ED", "F4FB0209",
  250. "10171E25", "2C333A41", "484F565D", "646B7279" };
  251. std::string K[36] = { XOR(MK.substr(0, 8), FK[0]), XOR(MK.substr(8, 8), FK[1]), XOR(MK.substr(16, 8), FK[2]), XOR(MK.substr(24), FK[3]) };
  252. std::string rks = "";
  253. for (int i = 0; i < 32; i++)
  254. {
  255. K[i + 4] = XOR(K[i], T2(XOR(XOR(XOR(K[i + 1], K[i + 2]), K[i + 3]), CK[i])));
  256. rks += K[i + 4];
  257. }
  258. return rks;
  259. }
  260. std::string sm4::sm4_encode(std::string& hex32, std::string& key)
  261. { // 加密函数实现
  262. // cout << "轮密钥与每轮输出状态:" << endl;
  263. // cout << endl;
  264. std::string cipher[36] = { hex32.substr(0, 8), hex32.substr(8, 8), hex32.substr(16, 8), hex32.substr(24) };
  265. std::string rks = KeyExtension(key);
  266. for (int i = 0; i < 32; i++)
  267. {
  268. cipher[i + 4] = XOR(cipher[i], T(XOR(XOR(XOR(cipher[i + 1], cipher[i + 2]), cipher[i + 3]), rks.substr(8 * i, 8))));
  269. // cout << "rk[" + to_string(i) + "] = " + rks.substr(8 * i, 8) + " X[" + to_string(i) + "] = " + cipher[i + 4] << endl;
  270. }
  271. // cout << endl;
  272. return cipher[35] + cipher[34] + cipher[33] + cipher[32];
  273. }
  274. std::string sm4::sm4_decode(std::string& hex32, std::string& key)
  275. { // 解密函数实现
  276. // cout << "轮密钥与每轮输出状态:" << endl;
  277. // cout << endl;
  278. std::string plain[36] = { hex32.substr(0, 8), hex32.substr(8, 8), hex32.substr(16, 8), hex32.substr(24, 8) };
  279. std::string rks = KeyExtension(key);
  280. for (int i = 0; i < 32; i++)
  281. {
  282. plain[i + 4] = XOR(plain[i], T(XOR(XOR(XOR(plain[i + 1], plain[i + 2]), plain[i + 3]), rks.substr(8 * (31 - i), 8))));
  283. // cout << "rk[" + to_string(i) + "] = " + rks.substr(8 * (31 - i), 8) + " X[" + to_string(i) + "] = " + plain[i + 4] << endl;
  284. }
  285. // cout << endl;
  286. return plain[35] + plain[34] + plain[33] + plain[32];
  287. }
  288. // 针对字符串加密成十六进制数据,使用ZeroPadding策略,不足16字节的补0
  289. std::string sm4::sm4encodestrhex(std::string text, std::string key)
  290. {
  291. int pos = 0;
  292. std::string all_cipher;
  293. while (pos < text.size())
  294. {
  295. int hasLen = text.size() - pos;
  296. std::string hex32 = hasLen > 16 ? text.substr(pos, 16) : text.substr(pos);
  297. if (hex32.size() < 16)
  298. {
  299. hex32 = PKCS7(hex32);
  300. }
  301. hex32 = StrToHex(hex32);
  302. all_cipher += sm4_encode(hex32, key);
  303. if (hasLen == 16)
  304. {
  305. // 正好16字节时,要补充16位
  306. hex32 = PKCS7("");
  307. hex32 = StrToHex(hex32);
  308. all_cipher += sm4_encode(hex32, key);
  309. }
  310. pos += 16;
  311. }
  312. return all_cipher;
  313. }
  314. // 针对字符串加密成十六进制后的解密,加密时使用了ZeroPadding策略,不足16字节的补0
  315. std::string sm4::sm4decodehexstr(std::string cipher, std::string key)
  316. {
  317. int pos = 0;
  318. std::string all_plain;
  319. while (pos < cipher.size())
  320. {
  321. int hasLen = cipher.size() - pos;
  322. if (hasLen % 32 != 0)
  323. {
  324. break; // 忽略无法解码的部分
  325. }
  326. std::string text32 = hasLen > 32 ? cipher.substr(pos, 32) : cipher.substr(pos);
  327. std::string hex32 = sm4_decode(text32, key);
  328. std::string one_plain = HexToStr(hex32);
  329. if (hasLen == 32)
  330. {
  331. // 去掉填充的数据(有可能整个数据都是填充的(16位整数倍的字符串加密时)
  332. int size = one_plain.at(one_plain.size() - 1);
  333. if (size >= 16)
  334. {
  335. one_plain.clear();
  336. }
  337. else
  338. {
  339. one_plain = one_plain.substr(0, 16 - size);
  340. }
  341. }
  342. all_plain.append(one_plain.c_str()); // 按字符串拼接,就可以去掉后面的0
  343. pos += 32;
  344. }
  345. return all_plain;
  346. }
  347. std::string sm4::sm4encodestrbase64(std::string text, std::string key)
  348. {
  349. std::string hex = sm4encodestrhex(text, key);
  350. std::string str = HexToStr(hex);
  351. return base64_encode(str);
  352. }
  353. std::string sm4::sm4decodestrbase64(std::string base64_text, std::string key)
  354. {
  355. std::string str = base64_decode(base64_text);
  356. std::string hex = StrToHex(str);
  357. return sm4decodehexstr(hex, key);
  358. }