RayTracingAccelerationStructurePass.cpp 16 KB


  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/RHI/BufferFrameAttachment.h>
  9. #include <Atom/RHI/BufferScopeAttachment.h>
  10. #include <Atom/RHI/CommandList.h>
  11. #include <Atom/RHI/FrameScheduler.h>
  12. #include <Atom/RHI/RHISystemInterface.h>
  13. #include <Atom/RHI/ScopeProducerFunction.h>
  14. #include <Atom/RPI.Public/Buffer/Buffer.h>
  15. #include <Atom/RPI.Public/Buffer/BufferSystemInterface.h>
  16. #include <Atom/RPI.Public/RenderPipeline.h>
  17. #include <Atom/RPI.Public/Scene.h>
  18. #include <Mesh/MeshFeatureProcessor.h>
  19. #include <RayTracing/RayTracingAccelerationStructurePass.h>
  20. #include <RayTracing/RayTracingFeatureProcessor.h>
  21. namespace AZ
  22. {
  23. namespace Render
  24. {
  25. RPI::Ptr<RayTracingAccelerationStructurePass> RayTracingAccelerationStructurePass::Create(const RPI::PassDescriptor& descriptor)
  26. {
  27. RPI::Ptr<RayTracingAccelerationStructurePass> rayTracingAccelerationStructurePass = aznew RayTracingAccelerationStructurePass(descriptor);
  28. return AZStd::move(rayTracingAccelerationStructurePass);
  29. }
  30. RayTracingAccelerationStructurePass::RayTracingAccelerationStructurePass(const RPI::PassDescriptor& descriptor)
  31. : Pass(descriptor)
  32. {
  33. // disable this pass if we're on a platform that doesn't support raytracing
  34. if (RHI::RHISystemInterface::Get()->GetRayTracingSupport() == RHI::MultiDevice::NoDevices)
  35. {
  36. SetEnabled(false);
  37. }
  38. }
  39. void RayTracingAccelerationStructurePass::BuildInternal()
  40. {
  41. // [GFX TODO][ATOM-18111] Ideally, this would be done on the Compute queue, but that has multiple issues (see also 18305).
  42. auto deviceIndex = Pass::GetDeviceIndex();
  43. InitScope(
  44. RHI::ScopeId(AZStd::string(GetPathName().GetCStr() + AZStd::to_string(deviceIndex))),
  45. AZ::RHI::HardwareQueueClass::Graphics,
  46. deviceIndex);
  47. }
  48. void RayTracingAccelerationStructurePass::FrameBeginInternal(FramePrepareParams params)
  49. {
  50. if (IsTimestampQueryEnabled())
  51. {
  52. m_timestampResult = AZ::RPI::TimestampResult();
  53. }
  54. if (GetScopeId().IsEmpty())
  55. {
  56. InitScope(RHI::ScopeId(GetPathName()), RHI::HardwareQueueClass::Graphics, Pass::GetDeviceIndex());
  57. }
  58. params.m_frameGraphBuilder->ImportScopeProducer(*this);
  59. RPI::Scene* scene = m_pipeline->GetScene();
  60. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  61. if (rayTracingFeatureProcessor)
  62. {
  63. auto revision = rayTracingFeatureProcessor->BeginFrame();
  64. m_rayTracingRevisionOutDated = revision != m_rayTracingRevision;
  65. if (m_rayTracingRevisionOutDated)
  66. {
  67. m_rayTracingRevision = revision;
  68. ReadbackScopeQueryResults();
  69. }
  70. }
  71. }
  72. RHI::Ptr<RPI::Query> RayTracingAccelerationStructurePass::GetQuery(RPI::ScopeQueryType queryType)
  73. {
  74. auto typeIndex{ static_cast<uint32_t>(queryType) };
  75. if (!m_scopeQueries[typeIndex])
  76. {
  77. RHI::Ptr<RPI::Query> query;
  78. switch (queryType)
  79. {
  80. case RPI::ScopeQueryType::Timestamp:
  81. query = RPI::GpuQuerySystemInterface::Get()->CreateQuery(
  82. RHI::QueryType::Timestamp, RHI::QueryPoolScopeAttachmentType::Global, RHI::ScopeAttachmentAccess::Write);
  83. break;
  84. case RPI::ScopeQueryType::PipelineStatistics:
  85. query = RPI::GpuQuerySystemInterface::Get()->CreateQuery(
  86. RHI::QueryType::PipelineStatistics, RHI::QueryPoolScopeAttachmentType::Global,
  87. RHI::ScopeAttachmentAccess::Write);
  88. break;
  89. }
  90. m_scopeQueries[typeIndex] = query;
  91. }
  92. return m_scopeQueries[typeIndex];
  93. }
  94. template<typename Func>
  95. inline void RayTracingAccelerationStructurePass::ExecuteOnTimestampQuery(Func&& func)
  96. {
  97. if (IsTimestampQueryEnabled())
  98. {
  99. auto query{ GetQuery(RPI::ScopeQueryType::Timestamp) };
  100. if (query)
  101. {
  102. func(query);
  103. }
  104. }
  105. }
  106. template<typename Func>
  107. inline void RayTracingAccelerationStructurePass::ExecuteOnPipelineStatisticsQuery(Func&& func)
  108. {
  109. if (IsPipelineStatisticsQueryEnabled())
  110. {
  111. auto query{ GetQuery(RPI::ScopeQueryType::PipelineStatistics) };
  112. if (query)
  113. {
  114. func(query);
  115. }
  116. }
  117. }
  118. RPI::TimestampResult RayTracingAccelerationStructurePass::GetTimestampResultInternal() const
  119. {
  120. return m_timestampResult;
  121. }
  122. RPI::PipelineStatisticsResult RayTracingAccelerationStructurePass::GetPipelineStatisticsResultInternal() const
  123. {
  124. return m_statisticsResult;
  125. }
  126. void RayTracingAccelerationStructurePass::SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph)
  127. {
  128. RPI::Scene* scene = m_pipeline->GetScene();
  129. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  130. if (rayTracingFeatureProcessor)
  131. {
  132. if (m_rayTracingRevisionOutDated)
  133. {
  134. // create the TLAS buffers based on the descriptor
  135. RHI::Ptr<RHI::RayTracingTlas>& rayTracingTlas = rayTracingFeatureProcessor->GetTlas();
  136. // import and attach the TLAS buffer
  137. const RHI::Ptr<RHI::Buffer>& rayTracingTlasBuffer = rayTracingTlas->GetTlasBuffer();
  138. if (rayTracingTlasBuffer && rayTracingFeatureProcessor->HasGeometry())
  139. {
  140. AZ::RHI::AttachmentId tlasAttachmentId = rayTracingFeatureProcessor->GetTlasAttachmentId();
  141. if (frameGraph.GetAttachmentDatabase().IsAttachmentValid(tlasAttachmentId) == false)
  142. {
  143. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(tlasAttachmentId, rayTracingTlasBuffer);
  144. AZ_Assert(result == RHI::ResultCode::Success, "Failed to import ray tracing TLAS buffer with error %d", result);
  145. }
  146. uint32_t tlasBufferByteCount = aznumeric_cast<uint32_t>(rayTracingTlasBuffer->GetDescriptor().m_byteCount);
  147. RHI::BufferViewDescriptor tlasBufferViewDescriptor = RHI::BufferViewDescriptor::CreateRayTracingTLAS(tlasBufferByteCount);
  148. RHI::BufferScopeAttachmentDescriptor desc;
  149. desc.m_attachmentId = tlasAttachmentId;
  150. desc.m_bufferViewDescriptor = tlasBufferViewDescriptor;
  151. desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::DontCare;
  152. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::Write, RHI::ScopeAttachmentStage::RayTracingShader);
  153. }
  154. }
  155. // Attach output data from the skinning pass. This is needed to ensure that this pass is executed after
  156. // the skinning pass has finished. We assume that the pipeline has a skinning pass with this output available.
  157. if (rayTracingFeatureProcessor->GetSkinnedMeshCount() > 0)
  158. {
  159. auto skinningPassPtr = FindAdjacentPass(AZ::Name("SkinningPass"));
  160. auto skinnedMeshOutputStreamBindingPtr = skinningPassPtr->FindAttachmentBinding(AZ::Name("SkinnedMeshOutputStream"));
  161. [[maybe_unused]] auto result = frameGraph.UseShaderAttachment(skinnedMeshOutputStreamBindingPtr->m_unifiedScopeDesc.GetAsBuffer(), RHI::ScopeAttachmentAccess::Read, RHI::ScopeAttachmentStage::RayTracingShader);
  162. AZ_Assert(result == AZ::RHI::ResultCode::Success, "Failed to attach SkinnedMeshOutputStream buffer with error %d", result);
  163. }
  164. AddScopeQueryToFrameGraph(frameGraph);
  165. }
  166. }
  167. void RayTracingAccelerationStructurePass::BuildCommandList(const RHI::FrameGraphExecuteContext& context)
  168. {
  169. RPI::Scene* scene = m_pipeline->GetScene();
  170. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  171. if (!rayTracingFeatureProcessor)
  172. {
  173. return;
  174. }
  175. if (!rayTracingFeatureProcessor->GetTlas()->GetTlasBuffer())
  176. {
  177. return;
  178. }
  179. if (!m_rayTracingRevisionOutDated && rayTracingFeatureProcessor->GetSkinnedMeshCount() == 0)
  180. {
  181. // TLAS is up to date
  182. return;
  183. }
  184. if (!rayTracingFeatureProcessor->HasGeometry())
  185. {
  186. // no ray tracing meshes in the scene
  187. return;
  188. }
  189. BeginScopeQuery(context);
  190. // build newly added or skinned BLAS objects
  191. AZStd::vector<const AZ::RHI::DeviceRayTracingBlas*> changedBlasList;
  192. RayTracingFeatureProcessor::BlasInstanceMap& blasInstances = rayTracingFeatureProcessor->GetBlasInstances();
  193. for (auto& blasInstance : blasInstances)
  194. {
  195. const bool isSkinnedMesh = blasInstance.second.m_isSkinnedMesh;
  196. const bool buildBlas = (blasInstance.second.m_blasBuilt & RHI::MultiDevice::DeviceMask(1 << context.GetDeviceIndex())) ==
  197. RHI::MultiDevice::NoDevices;
  198. if (buildBlas || isSkinnedMesh)
  199. {
  200. for (auto submeshIndex = 0; submeshIndex < blasInstance.second.m_subMeshes.size(); ++submeshIndex)
  201. {
  202. auto& submeshBlasInstance = blasInstance.second.m_subMeshes[submeshIndex];
  203. changedBlasList.push_back(submeshBlasInstance.m_blas->GetDeviceRayTracingBlas(context.GetDeviceIndex()).get());
  204. if (buildBlas)
  205. {
  206. // Always build the BLAS, if it has not previously been built
  207. context.GetCommandList()->BuildBottomLevelAccelerationStructure(*submeshBlasInstance.m_blas->GetDeviceRayTracingBlas(context.GetDeviceIndex()));
  208. continue;
  209. }
  210. // Determine if a skinned mesh BLAS needs to be updated or completely rebuilt. For now, we want to rebuild a BLAS every
  211. // SKINNED_BLAS_REBUILD_FRAME_INTERVAL frames, while updating it all other frames. This is based on the assumption that
  212. // by adding together the asset ID hash, submesh index, and frame count, we get a value that allows us to uniformly
  213. // distribute rebuilding all skinned mesh BLASs over all frames.
  214. auto assetGuid = blasInstance.first.m_guid.GetHash();
  215. if (isSkinnedMesh && (assetGuid + submeshIndex + m_frameCount) % SKINNED_BLAS_REBUILD_FRAME_INTERVAL != 0)
  216. {
  217. // Skinned mesh that simply needs an update
  218. context.GetCommandList()->UpdateBottomLevelAccelerationStructure(
  219. *submeshBlasInstance.m_blas->GetDeviceRayTracingBlas(context.GetDeviceIndex()));
  220. }
  221. else
  222. {
  223. // Fall back to building the BLAS in any case
  224. context.GetCommandList()->BuildBottomLevelAccelerationStructure(
  225. *submeshBlasInstance.m_blas->GetDeviceRayTracingBlas(context.GetDeviceIndex()));
  226. }
  227. }
  228. AZStd::lock_guard lock(rayTracingFeatureProcessor->GetBlasBuiltMutex());
  229. blasInstance.second.m_blasBuilt |= RHI::MultiDevice::DeviceMask(1 << context.GetDeviceIndex());
  230. }
  231. }
  232. // build the TLAS object
  233. context.GetCommandList()->BuildTopLevelAccelerationStructure(*rayTracingFeatureProcessor->GetTlas()->GetDeviceRayTracingTlas(context.GetDeviceIndex()), changedBlasList);
  234. ++m_frameCount;
  235. EndScopeQuery(context);
  236. }
  237. void RayTracingAccelerationStructurePass::AddScopeQueryToFrameGraph(RHI::FrameGraphInterface frameGraph)
  238. {
  239. const auto addToFrameGraph = [&frameGraph](RHI::Ptr<RPI::Query> query)
  240. {
  241. query->AddToFrameGraph(frameGraph);
  242. };
  243. ExecuteOnTimestampQuery(addToFrameGraph);
  244. ExecuteOnPipelineStatisticsQuery(addToFrameGraph);
  245. }
  246. void RayTracingAccelerationStructurePass::BeginScopeQuery(const RHI::FrameGraphExecuteContext& context)
  247. {
  248. const auto beginQuery = [&context, this](RHI::Ptr<RPI::Query> query)
  249. {
  250. if (query->BeginQuery(context) == RPI::QueryResultCode::Fail)
  251. {
  252. AZ_UNUSED(this); // Prevent unused warning in release builds
  253. AZ_WarningOnce(
  254. "RayTracingAccelerationStructurePass", false,
  255. "BeginScopeQuery failed. Make sure AddScopeQueryToFrameGraph was called in SetupFrameGraphDependencies"
  256. " for this pass: %s",
  257. this->RTTI_GetTypeName());
  258. }
  259. };
  260. ExecuteOnTimestampQuery(beginQuery);
  261. ExecuteOnPipelineStatisticsQuery(beginQuery);
  262. }
  263. void RayTracingAccelerationStructurePass::EndScopeQuery(const RHI::FrameGraphExecuteContext& context)
  264. {
  265. const auto endQuery = [&context](const RHI::Ptr<RPI::Query>& query)
  266. {
  267. query->EndQuery(context);
  268. };
  269. // This scope query implementation should be replaced by the feature linked below on GitHub:
  270. // [GHI-16945] Feature Request - Add GPU timestamp and pipeline statistic support for scopes
  271. ExecuteOnTimestampQuery(endQuery);
  272. ExecuteOnPipelineStatisticsQuery(endQuery);
  273. m_lastDeviceIndex = context.GetDeviceIndex();
  274. }
  275. void RayTracingAccelerationStructurePass::ReadbackScopeQueryResults()
  276. {
  277. ExecuteOnTimestampQuery(
  278. [this](const RHI::Ptr<RPI::Query>& query)
  279. {
  280. const uint32_t TimestampResultQueryCount{ 2u };
  281. uint64_t timestampResult[TimestampResultQueryCount] = { 0 };
  282. query->GetLatestResult(&timestampResult, sizeof(uint64_t) * TimestampResultQueryCount, m_lastDeviceIndex);
  283. m_timestampResult = RPI::TimestampResult(timestampResult[0], timestampResult[1], RHI::HardwareQueueClass::Graphics);
  284. });
  285. ExecuteOnPipelineStatisticsQuery(
  286. [this](const RHI::Ptr<RPI::Query>& query)
  287. {
  288. query->GetLatestResult(&m_statisticsResult, sizeof(RPI::PipelineStatisticsResult), m_lastDeviceIndex);
  289. });
  290. }
  291. } // namespace Render
  292. } // namespace AZ