// Copyright 2017 Dolphin Emulator Project
// SPDX-License-Identifier: GPL-2.0-or-later

#include "Common/Crypto/AES.h"

#include <array>
#include <bit>
#include <cstring>
#include <memory>

#include <mbedtls/aes.h>

#include "Common/Assert.h"
#include "Common/CPUDetect.h"

#ifdef _MSC_VER
#include <intrin.h>
#else
#if defined(_M_X86_64)
#include <x86intrin.h>
#elif defined(_M_ARM_64)
#include <arm_acle.h>
#include <arm_neon.h>
#endif
#endif

#ifdef _MSC_VER
#define ATTRIBUTE_TARGET(x)
#else
#define ATTRIBUTE_TARGET(x) [[gnu::target(x)]]
#endif
namespace Common::AES
{
// For x64 and arm64 it is very unlikely that a user's CPU lacks the accelerated
// instructions; the generic fallback exists just in case.
template <Mode AesMode>
class ContextGeneric final : public Context
{
public:
  ContextGeneric(const u8* key)
  {
    mbedtls_aes_init(&ctx);
    if constexpr (AesMode == Mode::Encrypt)
      ASSERT(!mbedtls_aes_setkey_enc(&ctx, key, 128));
    else
      ASSERT(!mbedtls_aes_setkey_dec(&ctx, key, 128));
  }

  // Zeroize the expanded key material.
  ~ContextGeneric() { mbedtls_aes_free(&ctx); }
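
  // A null iv selects an all-zero IV; when iv_out is non-null it receives the
  // final chaining value, so a long stream can be processed in chunks.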
  virtual bool Crypt(const u8* iv, u8* iv_out, const u8* buf_in, u8* buf_out,
                     size_t len) const override
  {
    std::array<u8, BLOCK_SIZE> iv_tmp{};
    if (iv)
      std::memcpy(&iv_tmp[0], iv, BLOCK_SIZE);

    constexpr int mode = (AesMode == Mode::Encrypt) ? MBEDTLS_AES_ENCRYPT : MBEDTLS_AES_DECRYPT;
    if (mbedtls_aes_crypt_cbc(const_cast<mbedtls_aes_context*>(&ctx), mode, len, &iv_tmp[0], buf_in,
                              buf_out))
      return false;

    if (iv_out)
      std::memcpy(iv_out, &iv_tmp[0], BLOCK_SIZE);
    return true;
  }

private:
  mbedtls_aes_context ctx{};
};
#if defined(_M_X86_64)

// Note that (for instructions with the same data width) the actual instructions
// emitted vary depending on compiler and flags. The naming is somewhat
// confusing, because the VAES cpuid flag was added after VAES(VEX.128):
// clang-format off
// instructions   | cpuid flag      | #define
// AES(128)       | AES             | -
// VAES(VEX.128)  | AES & AVX       | __AVX__
// VAES(VEX.256)  | VAES            | -
// VAES(EVEX.128) | VAES & AVX512VL | __AVX512VL__
// VAES(EVEX.256) | VAES & AVX512VL | __AVX512VL__
// VAES(EVEX.512) | VAES & AVX512F  | __AVX512F__
// clang-format on
template <Mode AesMode>
class ContextAESNI final : public Context
{
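  // AES-128 key schedule (FIPS 197): for the first word of each round key,
  //   w[i] = w[i-4] ^ SubWord(RotWord(w[i-1])) ^ Rcon, and for the rest,
  //   w[i] = w[i-4] ^ w[i-1].
  // _mm_aeskeygenassist_si128 supplies the SubWord(RotWord(...)) ^ Rcon term;
  // the shift/xor chain below accumulates the running xor of the previous
  // round key's words into each new word.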
  static inline __m128i Aes128KeygenAssistFinish(__m128i key, __m128i kga)
  {
    __m128i tmp = _mm_shuffle_epi32(kga, _MM_SHUFFLE(3, 3, 3, 3));
    tmp = _mm_xor_si128(tmp, key);
    key = _mm_slli_si128(key, 4);
    tmp = _mm_xor_si128(tmp, key);
    key = _mm_slli_si128(key, 4);
    tmp = _mm_xor_si128(tmp, key);
    key = _mm_slli_si128(key, 4);
    tmp = _mm_xor_si128(tmp, key);
    return tmp;
  }
  template <size_t RoundIdx>
  ATTRIBUTE_TARGET("aes")
  inline constexpr void StoreRoundKey(__m128i rk)
  {
    if constexpr (AesMode == Mode::Encrypt)
      round_keys[RoundIdx] = rk;
    else
    {
      // Decryption uses the round keys in reverse order, and all but the first
      // and last must pass through InvMixColumns (aesimc) for the equivalent
      // inverse cipher.
      constexpr size_t idx = NUM_ROUND_KEYS - RoundIdx - 1;
      if constexpr (idx == 0 || idx == NUM_ROUND_KEYS - 1)
        round_keys[idx] = rk;
      else
        round_keys[idx] = _mm_aesimc_si128(rk);
    }
  }

  template <size_t RoundIdx, int Rcon>
  ATTRIBUTE_TARGET("aes")
  inline constexpr __m128i Aes128Keygen(__m128i rk)
  {
    rk = Aes128KeygenAssistFinish(rk, _mm_aeskeygenassist_si128(rk, Rcon));
    StoreRoundKey<RoundIdx>(rk);
    return rk;
  }
public:
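  // Expand the 128-bit key into the 11 round keys. The Rcon template arguments
  // 0x01..0x36 are the AES round constants: successive doublings in GF(2^8)
  // (0x80 doubled reduces to 0x1b under the AES polynomial).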
  ContextAESNI(const u8* key)
  {
    __m128i rk = _mm_loadu_si128((const __m128i*)key);
    StoreRoundKey<0>(rk);
    rk = Aes128Keygen<1, 0x01>(rk);
    rk = Aes128Keygen<2, 0x02>(rk);
    rk = Aes128Keygen<3, 0x04>(rk);
    rk = Aes128Keygen<4, 0x08>(rk);
    rk = Aes128Keygen<5, 0x10>(rk);
    rk = Aes128Keygen<6, 0x20>(rk);
    rk = Aes128Keygen<7, 0x40>(rk);
    rk = Aes128Keygen<8, 0x80>(rk);
    rk = Aes128Keygen<9, 0x1b>(rk);
    Aes128Keygen<10, 0x36>(rk);
  }
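
  // CBC recurrences implemented below:
  //   encrypt: C_i = E(P_i ^ C_{i-1}), with the cipher's initial AddRoundKey
  //            folded into the same xor as the IV.
  //   decrypt: P_i = D(C_i) ^ C_{i-1}; the incoming ciphertext block is saved
  //            (iv_next) before being overwritten, since it is the next IV.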
  ATTRIBUTE_TARGET("aes")
  inline void CryptBlock(__m128i* iv, const u8* buf_in, u8* buf_out) const
  {
    __m128i block = _mm_loadu_si128((const __m128i*)buf_in);

    if constexpr (AesMode == Mode::Encrypt)
    {
      block = _mm_xor_si128(_mm_xor_si128(block, *iv), round_keys[0]);
      for (size_t i = 1; i < Nr; ++i)
        block = _mm_aesenc_si128(block, round_keys[i]);
      block = _mm_aesenclast_si128(block, round_keys[Nr]);
      *iv = block;
    }
    else
    {
      __m128i iv_next = block;
      block = _mm_xor_si128(block, round_keys[0]);
      for (size_t i = 1; i < Nr; ++i)
        block = _mm_aesdec_si128(block, round_keys[i]);
      block = _mm_aesdeclast_si128(block, round_keys[Nr]);
      block = _mm_xor_si128(block, *iv);
      *iv = iv_next;
    }

    _mm_storeu_si128((__m128i*)buf_out, block);
  }
  // Processes NumBlocks blocks at once. In CBC decryption the per-block aesdec
  // chains are independent, so interleaving them lets the CPU's pipelined AES
  // units overlap their latencies.
  template <size_t NumBlocks>
  ATTRIBUTE_TARGET("aes")
  inline void DecryptPipelined(__m128i* iv, const u8* buf_in, u8* buf_out) const
  {
    constexpr size_t Depth = NumBlocks;

    __m128i block[Depth];
    for (size_t d = 0; d < Depth; d++)
      block[d] = _mm_loadu_si128(&((const __m128i*)buf_in)[d]);

    __m128i iv_next[1 + Depth];
    iv_next[0] = *iv;
    for (size_t d = 0; d < Depth; d++)
      iv_next[1 + d] = block[d];

    for (size_t d = 0; d < Depth; d++)
      block[d] = _mm_xor_si128(block[d], round_keys[0]);

    // The main speedup: Depth independent aesdec dependency chains in flight.
    for (size_t i = 1; i < Nr; ++i)
      for (size_t d = 0; d < Depth; d++)
        block[d] = _mm_aesdec_si128(block[d], round_keys[i]);

    for (size_t d = 0; d < Depth; d++)
      block[d] = _mm_aesdeclast_si128(block[d], round_keys[Nr]);
    for (size_t d = 0; d < Depth; d++)
      block[d] = _mm_xor_si128(block[d], iv_next[d]);
    *iv = iv_next[1 + Depth - 1];

    for (size_t d = 0; d < Depth; d++)
      _mm_storeu_si128(&((__m128i*)buf_out)[d], block[d]);
  }
  virtual bool Crypt(const u8* iv, u8* iv_out, const u8* buf_in, u8* buf_out,
                     size_t len) const override
  {
    if (len % BLOCK_SIZE)
      return false;

    __m128i iv_block = iv ? _mm_loadu_si128((const __m128i*)iv) : _mm_setzero_si128();

    if constexpr (AesMode == Mode::Decrypt)
    {
      // Benchmarks on AMD Zen 2 (synthetic, not real-world): with AES(128)
      // instructions, pipelining at BLOCK_DEPTH gives the following speedup
      // over the non-pipelined path: 4: 18%, 8: 22%, 9: 26%, 10-15: 31%,
      // 16: 8% (register exhaustion). With VAES(VEX.128), depth 10 gives a 36%
      // speedup over its corresponding baseline, and VAES(VEX.128) is ~4%
      // faster than AES(128). The result is similar on Zen 3, which is in
      // general 20% faster at AES than Zen 2; VAES(VEX.256) is 35% faster than
      // Zen 3 VAES(VEX.128), so it seems like VAES(VEX.256) should win.
      // TODO Choose value at runtime based on some criteria?
      constexpr size_t BLOCK_DEPTH = 10;
      constexpr size_t CHUNK_LEN = BLOCK_DEPTH * BLOCK_SIZE;
      while (len >= CHUNK_LEN)
      {
        DecryptPipelined<BLOCK_DEPTH>(&iv_block, buf_in, buf_out);
        buf_in += CHUNK_LEN;
        buf_out += CHUNK_LEN;
        len -= CHUNK_LEN;
      }
    }

    len /= BLOCK_SIZE;
    while (len--)
    {
      CryptBlock(&iv_block, buf_in, buf_out);
      buf_in += BLOCK_SIZE;
      buf_out += BLOCK_SIZE;
    }

    if (iv_out)
      _mm_storeu_si128((__m128i*)iv_out, iv_block);
    return true;
  }
private:
  // Wrapper so the 16-byte alignment requirement of __m128i is respected
  // within std::array.
  struct XmmReg
  {
    __m128i data;

    XmmReg& operator=(const __m128i& m)
    {
      data = m;
      return *this;
    }
    operator __m128i() const { return data; }
  };
  std::array<XmmReg, NUM_ROUND_KEYS> round_keys;
};
#endif
#if defined(_M_ARM_64)

template <Mode AesMode>
class ContextNeon final : public Context
{
public:
  template <size_t RoundIdx>
  inline constexpr void StoreRoundKey(const u32* rk)
  {
    const uint8x16_t rk_block = vreinterpretq_u8_u32(vld1q_u32(rk));
    if constexpr (AesMode == Mode::Encrypt)
      round_keys[RoundIdx] = rk_block;
    else
    {
      // As in the x86 path: reverse the key order and apply InvMixColumns to
      // the middle round keys for decryption.
      constexpr size_t idx = NUM_ROUND_KEYS - RoundIdx - 1;
      if constexpr (idx == 0 || idx == NUM_ROUND_KEYS - 1)
        round_keys[idx] = rk_block;
      else
        round_keys[idx] = vaesimcq_u8(rk_block);
    }
  }
  ContextNeon(const u8* key)
  {
    constexpr u8 rcon[]{0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36};
    std::array<u32, Nb * NUM_ROUND_KEYS> rk{};

    // This uses a nice trick seen in wolfssl (original author unknown), which
    // uses vaeseq_u8 to assist keygen:
    //   vaeseq_u8: op1 = SubBytes(ShiftRows(AddRoundKey(op1, op2)))
    // given that RotWord equals ShiftRows for row 1 (rol(x, 8)). Probably not
    // super fast (it moves to/from vector regs constantly), but it is nice and
    // simple.
    std::memcpy(&rk[0], key, KEY_SIZE);
    StoreRoundKey<0>(&rk[0]);
    for (size_t i = 0; i < rk.size() - Nk; i += Nk)
    {
      const uint8x16_t enc = vaeseq_u8(vreinterpretq_u8_u32(vmovq_n_u32(rk[i + 3])), vmovq_n_u8(0));
      const u32 temp = vgetq_lane_u32(vreinterpretq_u32_u8(enc), 0);
      rk[i + 4] = rk[i + 0] ^ std::rotr(temp, 8) ^ rcon[i / Nk];
      rk[i + 5] = rk[i + 4] ^ rk[i + 1];
      rk[i + 6] = rk[i + 5] ^ rk[i + 2];
      rk[i + 7] = rk[i + 6] ^ rk[i + 3];

      // Not great: StoreRoundKey's index is a template parameter, so it must
      // be a compile-time constant, forcing this dispatch switch.
      // clang-format off
      const size_t rki = 1 + i / Nk;
      switch (rki)
      {
      case  1: StoreRoundKey< 1>(&rk[Nk * rki]); break;
      case  2: StoreRoundKey< 2>(&rk[Nk * rki]); break;
      case  3: StoreRoundKey< 3>(&rk[Nk * rki]); break;
      case  4: StoreRoundKey< 4>(&rk[Nk * rki]); break;
      case  5: StoreRoundKey< 5>(&rk[Nk * rki]); break;
      case  6: StoreRoundKey< 6>(&rk[Nk * rki]); break;
      case  7: StoreRoundKey< 7>(&rk[Nk * rki]); break;
      case  8: StoreRoundKey< 8>(&rk[Nk * rki]); break;
      case  9: StoreRoundKey< 9>(&rk[Nk * rki]); break;
      case 10: StoreRoundKey<10>(&rk[Nk * rki]); break;
      }
      // clang-format on
    }
  }
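
  // Note the loop structure differs from the x86 version: vaeseq_u8/vaesdq_u8
  // perform AddRoundKey *before* SubBytes/ShiftRows, so the final AddRoundKey
  // is the explicit veorq_u8 with round_keys[Nr].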
  inline void CryptBlock(uint8x16_t* iv, const u8* buf_in, u8* buf_out) const
  {
    uint8x16_t block = vld1q_u8(buf_in);

    if constexpr (AesMode == Mode::Encrypt)
    {
      block = veorq_u8(block, *iv);
      for (size_t i = 0; i < Nr - 1; ++i)
        block = vaesmcq_u8(vaeseq_u8(block, round_keys[i]));
      block = vaeseq_u8(block, round_keys[Nr - 1]);
      block = veorq_u8(block, round_keys[Nr]);
      *iv = block;
    }
    else
    {
      uint8x16_t iv_next = block;
      for (size_t i = 0; i < Nr - 1; ++i)
        block = vaesimcq_u8(vaesdq_u8(block, round_keys[i]));
      block = vaesdq_u8(block, round_keys[Nr - 1]);
      block = veorq_u8(block, round_keys[Nr]);
      block = veorq_u8(block, *iv);
      *iv = iv_next;
    }

    vst1q_u8(buf_out, block);
  }
  virtual bool Crypt(const u8* iv, u8* iv_out, const u8* buf_in, u8* buf_out,
                     size_t len) const override
  {
    if (len % BLOCK_SIZE)
      return false;

    uint8x16_t iv_block = iv ? vld1q_u8(iv) : vmovq_n_u8(0);

    len /= BLOCK_SIZE;
    while (len--)
    {
      CryptBlock(&iv_block, buf_in, buf_out);
      buf_in += BLOCK_SIZE;
      buf_out += BLOCK_SIZE;
    }

    if (iv_out)
      vst1q_u8(iv_out, iv_block);
    return true;
  }

private:
  std::array<uint8x16_t, NUM_ROUND_KEYS> round_keys;
};
#endif
template <Mode AesMode>
std::unique_ptr<Context> CreateContext(const u8* key)
{
  if (cpu_info.bAES)
  {
#if defined(_M_X86_64)
#if defined(__AVX__)
    // If the compiler enables AVX, the intrinsics will generate VAES(VEX.128)
    // instructions. In the future we may want to compile the code twice and
    // explicitly override the compiler flags. There doesn't seem to be much
    // performance difference between AES(128) and VAES(VEX.128) at the moment,
    // though.
    if (cpu_info.bAVX)
#endif
      return std::make_unique<ContextAESNI<AesMode>>(key);
#elif defined(_M_ARM_64)
    return std::make_unique<ContextNeon<AesMode>>(key);
#endif
  }
  return std::make_unique<ContextGeneric<AesMode>>(key);
}
std::unique_ptr<Context> CreateContextEncrypt(const u8* key)
{
  return CreateContext<Mode::Encrypt>(key);
}

std::unique_ptr<Context> CreateContextDecrypt(const u8* key)
{
  return CreateContext<Mode::Decrypt>(key);
}
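
// Usage sketch (hypothetical caller, not part of this file), assuming
// `encrypted` and `decrypted` point at buffers whose length `len` is a
// multiple of BLOCK_SIZE:
//
//   const u8 key[KEY_SIZE] = {/* 16 key bytes */};
//   const u8 iv[BLOCK_SIZE] = {/* 16 IV bytes */};
//   auto ctx = CreateContextDecrypt(key);
//   ctx->Crypt(iv, nullptr, encrypted, decrypted, len);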
// In OFB mode the data is simply XORed with an encrypted keystream, so
// encryption and decryption are the exact same operation; this helper is only
// ever used for decryption, though.
void CryptOFB(const u8* key, const u8* iv, u8* iv_out, const u8* buf_in, u8* buf_out, size_t size)
{
  mbedtls_aes_context aes_ctx;
  size_t iv_offset = 0;

  std::array<u8, 16> iv_tmp{};
  if (iv)
    std::memcpy(&iv_tmp[0], iv, 16);

  mbedtls_aes_init(&aes_ctx);
  ASSERT(!mbedtls_aes_setkey_enc(&aes_ctx, key, 128));
  mbedtls_aes_crypt_ofb(&aes_ctx, size, &iv_offset, &iv_tmp[0], buf_in, buf_out);
  mbedtls_aes_free(&aes_ctx);

  if (iv_out)
    std::memcpy(iv_out, &iv_tmp[0], 16);
}
}  // namespace Common::AES