FrameSchedulerTests.cpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include "RHITestFixture.h"
  9. #include <Tests/Factory.h>
  10. #include <Tests/Device.h>
  11. #include <Atom/RHI/ScopeProducer.h>
  12. #include <Atom/RHI/FrameScheduler.h>
  13. #include <AzCore/Math/Random.h>
  14. #include <Atom/RHI/BufferPool.h>
  15. #include <Atom/RHI/ImagePool.h>
  16. #include <Atom/RHI/RHISystemInterface.h>
  17. namespace UnitTest
  18. {
  19. using namespace AZ;
  20. struct ImportedImage
  21. {
  22. RHI::AttachmentId m_id;
  23. RHI::Ptr<RHI::Image> m_image;
  24. };
  25. struct ImportedBuffer
  26. {
  27. RHI::AttachmentId m_id;
  28. RHI::Ptr<RHI::Buffer> m_buffer;
  29. };
  30. struct TransientImage
  31. {
  32. RHI::AttachmentId m_id;
  33. RHI::ImageDescriptor m_descriptor;
  34. };
  35. struct TransientBuffer
  36. {
  37. RHI::AttachmentId m_id;
  38. RHI::BufferDescriptor m_descriptor;
  39. };
  40. class ScopeProducer
  41. : public RHI::ScopeProducer
  42. {
  43. public:
  44. AZ_CLASS_ALLOCATOR(ScopeProducer, SystemAllocator);
  45. ScopeProducer(const RHI::ScopeId& scopeId)
  46. : RHI::ScopeProducer(scopeId)
  47. {}
  48. void SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph) override
  49. {
  50. RHI::FrameGraphAttachmentInterface attachmentDatabase = frameGraph.GetAttachmentDatabase();
  51. for (ImportedImage& image : m_imageImports)
  52. {
  53. ASSERT_FALSE(attachmentDatabase.IsAttachmentValid(image.m_id));
  54. attachmentDatabase.ImportImage(image.m_id, image.m_image);
  55. ASSERT_TRUE(attachmentDatabase.IsAttachmentValid(image.m_id));
  56. }
  57. for (ImportedBuffer& buffer : m_bufferImports)
  58. {
  59. ASSERT_FALSE(attachmentDatabase.IsAttachmentValid(buffer.m_id));
  60. attachmentDatabase.ImportBuffer(buffer.m_id, buffer.m_buffer);
  61. ASSERT_TRUE(attachmentDatabase.IsAttachmentValid(buffer.m_id));
  62. }
  63. for (const TransientImage& image : m_transientImages)
  64. {
  65. ASSERT_FALSE(attachmentDatabase.IsAttachmentValid(image.m_id));
  66. attachmentDatabase.CreateTransientImage(RHI::TransientImageDescriptor{image.m_id, image.m_descriptor});
  67. ASSERT_TRUE(attachmentDatabase.IsAttachmentValid(image.m_id));
  68. }
  69. for (const TransientBuffer& buffer : m_transientBuffers)
  70. {
  71. ASSERT_FALSE(attachmentDatabase.IsAttachmentValid(buffer.m_id));
  72. attachmentDatabase.CreateTransientBuffer(RHI::TransientBufferDescriptor{buffer.m_id, buffer.m_descriptor});
  73. ASSERT_TRUE(attachmentDatabase.IsAttachmentValid(buffer.m_id));
  74. }
  75. for (const ImageUsage& usage : m_imageUsages)
  76. {
  77. frameGraph.UseShaderAttachment(usage.m_descriptor, usage.m_access, RHI::ScopeAttachmentStage::AnyGraphics);
  78. }
  79. for (const BufferUsage& usage : m_bufferUsages)
  80. {
  81. frameGraph.UseShaderAttachment(usage.m_descriptor, usage.m_access, RHI::ScopeAttachmentStage::AnyGraphics);
  82. }
  83. frameGraph.SetEstimatedItemCount(0);
  84. }
  85. void CompileResources(const RHI::FrameGraphCompileContext& context) override
  86. {
  87. ASSERT_TRUE(context.GetScopeId() == GetScopeId());
  88. for (const ImageUsage& usage : m_imageUsages)
  89. {
  90. ASSERT_TRUE(context.GetImageView(usage.m_descriptor.m_attachmentId) != nullptr);
  91. }
  92. for (const BufferUsage& usage : m_bufferUsages)
  93. {
  94. ASSERT_TRUE(context.GetBufferView(usage.m_descriptor.m_attachmentId) != nullptr);
  95. }
  96. }
  97. void BuildCommandList(const RHI::FrameGraphExecuteContext& context) override
  98. {
  99. ASSERT_TRUE(context.GetScopeId() == GetScopeId());
  100. ASSERT_TRUE(context.GetCommandListIndex() == 0);
  101. ASSERT_TRUE(context.GetCommandListCount() == 1);
  102. }
  103. AZStd::vector<ImportedImage> m_imageImports;
  104. AZStd::vector<ImportedBuffer> m_bufferImports;
  105. AZStd::vector<TransientImage> m_transientImages;
  106. AZStd::vector<TransientBuffer> m_transientBuffers;
  107. struct ImageUsage
  108. {
  109. RHI::ImageScopeAttachmentDescriptor m_descriptor;
  110. RHI::ScopeAttachmentAccess m_access;
  111. };
  112. struct BufferUsage
  113. {
  114. RHI::BufferScopeAttachmentDescriptor m_descriptor;
  115. RHI::ScopeAttachmentAccess m_access;
  116. };
  117. AZStd::vector<ImageUsage> m_imageUsages;
  118. AZStd::vector<BufferUsage> m_bufferUsages;
  119. };
  120. class FrameSchedulerTests
  121. : public RHITestFixture
  122. {
  123. public:
  124. FrameSchedulerTests()
  125. : RHITestFixture()
  126. {
  127. }
  128. void SetUp() override
  129. {
  130. UnitTest::RHITestFixture::SetUp();
  131. m_rootFactory.reset(aznew Factory());
  132. m_rhiSystem.reset(aznew AZ::RHI::RHISystem);
  133. m_rhiSystem->InitDevices();
  134. m_rhiSystem->Init();
  135. m_device = AZ::RHI::RHISystemInterface::Get()->GetDevice(RHI::MultiDevice::DefaultDeviceIndex);
  136. m_state.reset(new State);
  137. {
  138. m_state->m_bufferPool = aznew RHI::BufferPool;
  139. RHI::BufferPoolDescriptor desc;
  140. desc.m_bindFlags = RHI::BufferBindFlags::ShaderReadWrite;
  141. desc.m_deviceMask = RHI::MultiDevice::DefaultDevice;
  142. m_state->m_bufferPool->Init(desc);
  143. }
  144. for (uint32_t i = 0; i < ImportedBufferCount; ++i)
  145. {
  146. RHI::Ptr<RHI::Buffer> buffer;
  147. buffer = aznew RHI::Buffer;
  148. RHI::BufferDescriptor desc;
  149. desc.m_bindFlags = RHI::BufferBindFlags::ShaderReadWrite;
  150. desc.m_byteCount = BufferSize;
  151. RHI::BufferInitRequest request;
  152. request.m_descriptor = desc;
  153. request.m_buffer = buffer.get();
  154. m_state->m_bufferPool->InitBuffer(request);
  155. m_state->m_bufferAttachments[i].m_id = RHI::AttachmentId{AZStd::string::format("B%d", i)};
  156. m_state->m_bufferAttachments[i].m_buffer = AZStd::move(buffer);
  157. }
  158. {
  159. m_state->m_imagePool = aznew RHI::ImagePool();
  160. RHI::ImagePoolDescriptor desc;
  161. desc.m_bindFlags = RHI::ImageBindFlags::ShaderReadWrite;
  162. m_state->m_imagePool->Init(desc);
  163. }
  164. for (uint32_t i = 0; i < ImportedImageCount; ++i)
  165. {
  166. RHI::Ptr<RHI::Image> image;
  167. image = aznew RHI::Image();
  168. RHI::ImageDescriptor desc = RHI::ImageDescriptor::Create2D(
  169. RHI::ImageBindFlags::ShaderReadWrite,
  170. ImageSize,
  171. ImageSize,
  172. RHI::Format::R8G8B8A8_UNORM);
  173. RHI::ImageInitRequest request;
  174. request.m_descriptor = desc;
  175. request.m_image = image.get();
  176. m_state->m_imagePool->InitImage(request);
  177. m_state->m_imageAttachments[i].m_id = RHI::AttachmentId{AZStd::string::format("I%d", i)};
  178. m_state->m_imageAttachments[i].m_image = AZStd::move(image);
  179. }
  180. for (uint32_t i = 0; i < ScopeCount; ++i)
  181. {
  182. m_state->m_producers.emplace_back(aznew ScopeProducer(RHI::ScopeId{AZStd::string::format("S%d", i)}));
  183. }
  184. }
  185. void TearDown() override
  186. {
  187. m_state.reset();
  188. m_device = nullptr;
  189. m_rhiSystem->Shutdown();
  190. m_rhiSystem.reset();
  191. m_rootFactory.reset();
  192. RHITestFixture::TearDown();
  193. }
  194. void Test()
  195. {
  196. RHI::FrameScheduler frameScheduler;
  197. RHI::FrameSchedulerDescriptor descriptor;
  198. descriptor.m_transientAttachmentPoolDescriptors[RHI::MultiDevice::DefaultDeviceIndex].m_bufferBudgetInBytes = 80 * 1024 * 1024;
  199. frameScheduler.Init(RHI::MultiDevice::DefaultDevice, descriptor);
  200. RHI::ImageScopeAttachmentDescriptor imageBindingDescs[2];
  201. imageBindingDescs[0].m_imageViewDescriptor = RHI::ImageViewDescriptor();
  202. imageBindingDescs[0].m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Clear;
  203. imageBindingDescs[0].m_loadStoreAction.m_clearValue = RHI::ClearValue::CreateVector4Float(1.0f, 0.0, 0.0, 0.0);
  204. imageBindingDescs[1] = imageBindingDescs[0];
  205. imageBindingDescs[1].m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
  206. RHI::BufferScopeAttachmentDescriptor bufferBindingDescs[2];
  207. bufferBindingDescs[0].m_bufferViewDescriptor = RHI::BufferViewDescriptor::CreateRaw(0, BufferSize);
  208. bufferBindingDescs[0].m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Clear;
  209. bufferBindingDescs[0].m_loadStoreAction.m_clearValue = RHI::ClearValue::CreateVector4Float(1.0f, 0.0, 0.0, 0.0);
  210. bufferBindingDescs[1] = bufferBindingDescs[0];
  211. bufferBindingDescs[1].m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
  212. AZ::SimpleLcgRandom random;
  213. struct Interval
  214. {
  215. uint32_t m_begin;
  216. uint32_t m_end;
  217. };
  218. Interval bufferScopeIntervals[BufferCount];
  219. for (uint32_t i = 0; i < BufferCount; ++i)
  220. {
  221. uint32_t b = random.GetRandom() % ScopeCount;
  222. uint32_t e = random.GetRandom() % ScopeCount;
  223. if (b > e)
  224. {
  225. AZStd::swap(b, e);
  226. }
  227. bufferScopeIntervals[i].m_begin = b;
  228. bufferScopeIntervals[i].m_end = e;
  229. }
  230. Interval imageScopeIntervals[ImageCount];
  231. for (uint32_t i = 0; i < ImageCount; ++i)
  232. {
  233. uint32_t b = random.GetRandom() % ScopeCount;
  234. uint32_t e = random.GetRandom() % ScopeCount;
  235. if (b > e)
  236. {
  237. AZStd::swap(b, e);
  238. }
  239. imageScopeIntervals[i].m_begin = b;
  240. imageScopeIntervals[i].m_end = e;
  241. }
  242. for (uint32_t scopeIdx = 0; scopeIdx < ScopeCount; ++scopeIdx)
  243. {
  244. ScopeProducer& producer = *m_state->m_producers[scopeIdx];
  245. //
  246. // IMPORTS
  247. //
  248. for (uint32_t i = 0; i < ImportedBufferCount; ++i)
  249. {
  250. if (scopeIdx == bufferScopeIntervals[i].m_begin)
  251. {
  252. producer.m_bufferImports.push_back(m_state->m_bufferAttachments[i]);
  253. bufferBindingDescs[0].m_attachmentId = m_state->m_bufferAttachments[i].m_id;
  254. producer.m_bufferUsages.push_back(ScopeProducer::BufferUsage{ bufferBindingDescs[0], RHI::ScopeAttachmentAccess::ReadWrite });
  255. }
  256. else if (scopeIdx == bufferScopeIntervals[i].m_end)
  257. {
  258. bufferBindingDescs[1].m_attachmentId = m_state->m_bufferAttachments[i].m_id;
  259. producer.m_bufferUsages.push_back(ScopeProducer::BufferUsage{ bufferBindingDescs[1], RHI::ScopeAttachmentAccess::Read });
  260. }
  261. }
  262. for (uint32_t i = 0; i < ImportedImageCount; ++i)
  263. {
  264. if (scopeIdx == imageScopeIntervals[i].m_begin)
  265. {
  266. producer.m_imageImports.push_back(m_state->m_imageAttachments[i]);
  267. imageBindingDescs[0].m_attachmentId = m_state->m_imageAttachments[i].m_id;
  268. producer.m_imageUsages.push_back(ScopeProducer::ImageUsage{ imageBindingDescs[0], RHI::ScopeAttachmentAccess::ReadWrite });
  269. }
  270. else if (scopeIdx == imageScopeIntervals[i].m_end)
  271. {
  272. imageBindingDescs[1].m_attachmentId = m_state->m_imageAttachments[i].m_id;
  273. producer.m_imageUsages.push_back(ScopeProducer::ImageUsage{ imageBindingDescs[1], RHI::ScopeAttachmentAccess::Read });
  274. }
  275. }
  276. //
  277. // TRANSIENTS
  278. //
  279. for (uint32_t i = 0; i < TransientBufferCount; ++i)
  280. {
  281. const uint32_t adjustedIndex = i + ImportedBufferCount;
  282. TransientBuffer transientBuffer =
  283. {
  284. RHI::AttachmentId{AZStd::string::format("B%d", adjustedIndex)},
  285. RHI::BufferDescriptor(RHI::BufferBindFlags::ShaderReadWrite, BufferSize)
  286. };
  287. bufferBindingDescs[0].m_attachmentId = transientBuffer.m_id;
  288. bufferBindingDescs[1].m_attachmentId = transientBuffer.m_id;
  289. if (scopeIdx == bufferScopeIntervals[adjustedIndex].m_begin)
  290. {
  291. producer.m_transientBuffers.push_back(transientBuffer);
  292. producer.m_bufferUsages.push_back(ScopeProducer::BufferUsage{ bufferBindingDescs[0], RHI::ScopeAttachmentAccess::ReadWrite });
  293. }
  294. else if (scopeIdx == bufferScopeIntervals[adjustedIndex].m_end)
  295. {
  296. producer.m_bufferUsages.push_back(ScopeProducer::BufferUsage{ bufferBindingDescs[1], RHI::ScopeAttachmentAccess::Read });
  297. }
  298. }
  299. for (uint32_t i = 0; i < TransientImageCount; ++i)
  300. {
  301. const uint32_t adjustedIndex = i + ImportedImageCount;
  302. TransientImage transientImage =
  303. {
  304. RHI::AttachmentId{AZStd::string::format("I%d", adjustedIndex)},
  305. RHI::ImageDescriptor::Create2D(RHI::ImageBindFlags::ShaderReadWrite, ImageSize, ImageSize, RHI::Format::R8G8B8A8_UNORM)
  306. };
  307. imageBindingDescs[0].m_attachmentId = transientImage.m_id;
  308. imageBindingDescs[1].m_attachmentId = transientImage.m_id;
  309. if (scopeIdx == imageScopeIntervals[adjustedIndex].m_begin)
  310. {
  311. producer.m_transientImages.push_back(transientImage);
  312. producer.m_imageUsages.push_back(ScopeProducer::ImageUsage{ imageBindingDescs[0], RHI::ScopeAttachmentAccess::ReadWrite });
  313. }
  314. else if (scopeIdx == imageScopeIntervals[adjustedIndex].m_end)
  315. {
  316. producer.m_imageUsages.push_back(ScopeProducer::ImageUsage{ imageBindingDescs[1], RHI::ScopeAttachmentAccess::Read });
  317. }
  318. }
  319. }
  320. for (uint32_t frameIdx = 0; frameIdx < FrameIterationCount; ++frameIdx)
  321. {
  322. frameScheduler.BeginFrame();
  323. for (AZStd::unique_ptr<ScopeProducer>& producer : m_state->m_producers)
  324. {
  325. frameScheduler.ImportScopeProducer(*producer);
  326. }
  327. RHI::FrameSchedulerCompileRequest compileRequest;
  328. compileRequest.m_jobPolicy = RHI::JobPolicy::Serial;
  329. frameScheduler.Compile(compileRequest);
  330. frameScheduler.Execute(RHI::JobPolicy::Serial);
  331. frameScheduler.EndFrame();
  332. }
  333. frameScheduler.Shutdown();
  334. }
  335. private:
  336. static const uint32_t FrameIterationCount = 128;
  337. static const uint32_t ImportedImageCount = 16;
  338. static const uint32_t ImportedBufferCount = 16;
  339. static const uint32_t TransientBufferCount = 16;
  340. static const uint32_t TransientImageCount = 16;
  341. static const uint32_t BufferCount = ImportedBufferCount + TransientBufferCount;
  342. static const uint32_t ImageCount = ImportedImageCount + TransientImageCount;
  343. static const uint32_t BufferSize = 64;
  344. static const uint32_t ImageSize = 16;
  345. static const uint32_t ScopeCount = 16;
  346. AZStd::unique_ptr<Factory> m_rootFactory;
  347. AZStd::unique_ptr<AZ::RHI::RHISystem> m_rhiSystem; //! Needed for the TransientAttachmentPool in the FrameScheduler
  348. RHI::Ptr<RHI::Device> m_device;
  349. struct State
  350. {
  351. RHI::Ptr<RHI::BufferPool> m_bufferPool;
  352. RHI::Ptr<RHI::ImagePool> m_imagePool;
  353. ImportedImage m_imageAttachments[ImportedImageCount];
  354. ImportedBuffer m_bufferAttachments[ImportedBufferCount];
  355. AZStd::vector<AZStd::unique_ptr<ScopeProducer>> m_producers;
  356. };
  357. AZStd::unique_ptr<State> m_state;
  358. };
  359. TEST_F(FrameSchedulerTests, Test)
  360. {
  361. Test();
  362. }
  363. }