DiffuseProbeGridVisualizationPreparePass.cpp 15 KB


  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/Feature/RayTracing/RayTracingFeatureProcessorInterface.h>
  9. #include <Atom/RHI/CommandList.h>
  10. #include <Atom/RHI/RHISystemInterface.h>
  11. #include <Atom/RPI.Public/RenderPipeline.h>
  12. #include <Atom/RPI.Public/RPIUtils.h>
  13. #include <Atom/RPI.Public/Scene.h>
  14. #include <DiffuseProbeGrid_Traits_Platform.h>
  15. #include <Render/DiffuseProbeGridFeatureProcessor.h>
  16. #include <Render/DiffuseProbeGridVisualizationPreparePass.h>
  17. namespace AZ
  18. {
  19. namespace Render
  20. {
  21. RPI::Ptr<DiffuseProbeGridVisualizationPreparePass> DiffuseProbeGridVisualizationPreparePass::Create(const RPI::PassDescriptor& descriptor)
  22. {
  23. RPI::Ptr<DiffuseProbeGridVisualizationPreparePass> diffuseProbeGridVisualizationPreparePass = aznew DiffuseProbeGridVisualizationPreparePass(descriptor);
  24. return AZStd::move(diffuseProbeGridVisualizationPreparePass);
  25. }
  26. DiffuseProbeGridVisualizationPreparePass::DiffuseProbeGridVisualizationPreparePass(const RPI::PassDescriptor& descriptor)
  27. : RenderPass(descriptor)
  28. {
  29. // disable this pass if we're on a platform that doesn't support raytracing
  30. if (RHI::RHISystemInterface::Get()->GetRayTracingSupport() == RHI::MultiDevice::NoDevices || !AZ_TRAIT_DIFFUSE_GI_PASSES_SUPPORTED)
  31. {
  32. SetEnabled(false);
  33. }
  34. else
  35. {
  36. LoadShader();
  37. }
  38. }
  39. void DiffuseProbeGridVisualizationPreparePass::LoadShader()
  40. {
  41. // load shaders
  42. // Note: the shader may not be available on all platforms
  43. AZStd::string shaderFilePath = "Shaders/DiffuseGlobalIllumination/DiffuseProbeGridVisualizationPrepare.azshader";
  44. m_shader = RPI::LoadCriticalShader(shaderFilePath);
  45. if (m_shader == nullptr)
  46. {
  47. return;
  48. }
  49. RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
  50. const auto& shaderVariant = m_shader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  51. shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
  52. m_pipelineState = m_shader->AcquirePipelineState(pipelineStateDescriptor);
  53. AZ_Assert(m_pipelineState, "Failed to acquire pipeline state");
  54. m_srgLayout = m_shader->FindShaderResourceGroupLayout(RPI::SrgBindingSlot::Pass);
  55. AZ_Assert(m_srgLayout.get(), "Failed to find Srg layout");
  56. const auto outcome = RPI::GetComputeShaderNumThreads(m_shader->GetAsset(), m_dispatchArgs);
  57. if (!outcome.IsSuccess())
  58. {
  59. AZ_Error("PassSystem", false, "[DiffuseProbeGridVisualizationPreparePass '%s']: Shader '%s' contains invalid numthreads arguments:\n%s", GetPathName().GetCStr(), shaderFilePath.c_str(), outcome.GetError().c_str());
  60. }
  61. }
  62. bool DiffuseProbeGridVisualizationPreparePass::ShouldUpdate(const AZStd::shared_ptr<DiffuseProbeGrid>& diffuseProbeGrid) const
  63. {
  64. return (diffuseProbeGrid->GetVisualizationEnabled() && diffuseProbeGrid->GetVisualizationTlasUpdateRequired());
  65. }
  66. bool DiffuseProbeGridVisualizationPreparePass::IsEnabled() const
  67. {
  68. if (!RenderPass::IsEnabled())
  69. {
  70. return false;
  71. }
  72. RPI::Scene* scene = m_pipeline->GetScene();
  73. if (!scene)
  74. {
  75. return false;
  76. }
  77. DiffuseProbeGridFeatureProcessor* diffuseProbeGridFeatureProcessor = scene->GetFeatureProcessor<DiffuseProbeGridFeatureProcessor>();
  78. if (diffuseProbeGridFeatureProcessor)
  79. {
  80. for (auto& diffuseProbeGrid : diffuseProbeGridFeatureProcessor->GetVisibleProbeGrids())
  81. {
  82. if (ShouldUpdate(diffuseProbeGrid))
  83. {
  84. return true;
  85. }
  86. }
  87. }
  88. return false;
  89. }
  90. void DiffuseProbeGridVisualizationPreparePass::FrameBeginInternal(FramePrepareParams params)
  91. {
  92. RPI::Scene* scene = m_pipeline->GetScene();
  93. DiffuseProbeGridFeatureProcessor* diffuseProbeGridFeatureProcessor = scene->GetFeatureProcessor<DiffuseProbeGridFeatureProcessor>();
  94. for (auto& diffuseProbeGrid : diffuseProbeGridFeatureProcessor->GetVisibleProbeGrids())
  95. {
  96. if (!ShouldUpdate(diffuseProbeGrid))
  97. {
  98. continue;
  99. }
  100. // create the TLAS descriptor by adding an instance entry for each probe in the grid
  101. RHI::RayTracingTlasDescriptor tlasDescriptor;
  102. RHI::RayTracingTlasDescriptor* tlasDescriptorBuild = tlasDescriptor.Build();
  103. // initialize the transform for each probe to Identity(), they will be updated by the compute shader
  104. AZ::Transform transform = AZ::Transform::Identity();
  105. uint32_t probeCount = diffuseProbeGrid->GetTotalProbeCount();
  106. for (uint32_t index = 0; index < probeCount; ++index)
  107. {
  108. tlasDescriptorBuild->Instance()
  109. ->InstanceID(index)
  110. ->InstanceMask(1)
  111. ->HitGroupIndex(0)
  112. ->Blas(diffuseProbeGridFeatureProcessor->GetVisualizationBlas())
  113. ->Transform(transform)
  114. ;
  115. }
  116. // create the TLAS buffers from on the descriptor
  117. RHI::Ptr<RHI::RayTracingTlas>& visualizationTlas = diffuseProbeGrid->GetVisualizationTlas();
  118. visualizationTlas->CreateBuffers(RHI::MultiDevice::AllDevices, &tlasDescriptor, diffuseProbeGridFeatureProcessor->GetVisualizationBufferPools());
  119. }
  120. RenderPass::FrameBeginInternal(params);
  121. }
  122. void DiffuseProbeGridVisualizationPreparePass::SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph)
  123. {
  124. RenderPass::SetupFrameGraphDependencies(frameGraph);
  125. RPI::Scene* scene = m_pipeline->GetScene();
  126. DiffuseProbeGridFeatureProcessor* diffuseProbeGridFeatureProcessor = scene->GetFeatureProcessor<DiffuseProbeGridFeatureProcessor>();
  127. frameGraph.SetEstimatedItemCount(aznumeric_cast<uint32_t>(diffuseProbeGridFeatureProcessor->GetVisibleProbeGrids().size()));
  128. for (auto& diffuseProbeGrid : diffuseProbeGridFeatureProcessor->GetVisibleProbeGrids())
  129. {
  130. if (!ShouldUpdate(diffuseProbeGrid))
  131. {
  132. continue;
  133. }
  134. // import and attach the visualization TLAS and probe data
  135. RHI::Ptr<RHI::RayTracingTlas>& visualizationTlas = diffuseProbeGrid->GetVisualizationTlas();
  136. const RHI::Ptr<RHI::Buffer>& tlasBuffer = visualizationTlas->GetTlasBuffer();
  137. const RHI::Ptr<RHI::Buffer>& tlasInstancesBuffer = visualizationTlas->GetTlasInstancesBuffer();
  138. if (tlasBuffer && tlasInstancesBuffer)
  139. {
  140. // TLAS buffer
  141. {
  142. AZ::RHI::AttachmentId attachmentId = diffuseProbeGrid->GetProbeVisualizationTlasAttachmentId();
  143. if (frameGraph.GetAttachmentDatabase().IsAttachmentValid(attachmentId) == false)
  144. {
  145. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(attachmentId, tlasBuffer);
  146. AZ_Assert(result == RHI::ResultCode::Success, "Failed to import DiffuseProbeGrid visualization TLAS buffer with error %d", result);
  147. }
  148. uint32_t byteCount = aznumeric_cast<uint32_t>(tlasBuffer->GetDescriptor().m_byteCount);
  149. RHI::BufferViewDescriptor bufferViewDescriptor = RHI::BufferViewDescriptor::CreateRayTracingTLAS(byteCount);
  150. RHI::BufferScopeAttachmentDescriptor desc;
  151. desc.m_attachmentId = attachmentId;
  152. desc.m_bufferViewDescriptor = bufferViewDescriptor;
  153. desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::DontCare;
  154. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::Write, RHI::ScopeAttachmentStage::ComputeShader);
  155. }
  156. // TLAS Instances buffer
  157. {
  158. AZ::RHI::AttachmentId attachmentId = diffuseProbeGrid->GetProbeVisualizationTlasInstancesAttachmentId();
  159. if (frameGraph.GetAttachmentDatabase().IsAttachmentValid(attachmentId) == false)
  160. {
  161. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(attachmentId, tlasInstancesBuffer);
  162. AZ_Assert(result == RHI::ResultCode::Success, "Failed to import DiffuseProbeGrid visualization TLAS Instances buffer with error %d", result);
  163. }
  164. uint32_t byteCount = aznumeric_cast<uint32_t>(tlasInstancesBuffer->GetDescriptor().m_byteCount);
  165. RHI::BufferViewDescriptor bufferViewDescriptor = RHI::BufferViewDescriptor::CreateStructured(0, byteCount / RayTracingTlasInstanceElementSize, RayTracingTlasInstanceElementSize);
  166. RHI::BufferScopeAttachmentDescriptor desc;
  167. desc.m_attachmentId = attachmentId;
  168. desc.m_bufferViewDescriptor = bufferViewDescriptor;
  169. desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::DontCare;
  170. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::Write, RHI::ScopeAttachmentStage::ComputeShader);
  171. }
  172. // grid data
  173. {
  174. RHI::BufferScopeAttachmentDescriptor desc;
  175. desc.m_attachmentId = diffuseProbeGrid->GetGridDataBufferAttachmentId();
  176. desc.m_bufferViewDescriptor = diffuseProbeGrid->GetRenderData()->m_gridDataBufferViewDescriptor;
  177. desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::Load;
  178. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::Read, RHI::ScopeAttachmentStage::ComputeShader);
  179. }
  180. // probe data
  181. {
  182. AZ::RHI::AttachmentId attachmentId = diffuseProbeGrid->GetProbeDataImageAttachmentId();
  183. if (frameGraph.GetAttachmentDatabase().IsAttachmentValid(attachmentId) == false)
  184. {
  185. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportImage(attachmentId, diffuseProbeGrid->GetProbeDataImage());
  186. AZ_Assert(result == RHI::ResultCode::Success, "Failed to import DiffuseProbeGrid probe data buffer with error %d", result);
  187. }
  188. RHI::ImageScopeAttachmentDescriptor desc;
  189. desc.m_attachmentId = attachmentId;
  190. desc.m_imageViewDescriptor = diffuseProbeGrid->GetRenderData()->m_probeDataImageViewDescriptor;
  191. desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::Load;
  192. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::Read, RHI::ScopeAttachmentStage::ComputeShader);
  193. }
  194. }
  195. }
  196. }
  197. void DiffuseProbeGridVisualizationPreparePass::CompileResources([[maybe_unused]] const RHI::FrameGraphCompileContext& context)
  198. {
  199. RPI::Scene* scene = m_pipeline->GetScene();
  200. DiffuseProbeGridFeatureProcessor* diffuseProbeGridFeatureProcessor = scene->GetFeatureProcessor<DiffuseProbeGridFeatureProcessor>();
  201. for (auto& diffuseProbeGrid : diffuseProbeGridFeatureProcessor->GetVisibleProbeGrids())
  202. {
  203. if (!ShouldUpdate(diffuseProbeGrid))
  204. {
  205. continue;
  206. }
  207. // the DiffuseProbeGrid Srg must be updated in the Compile phase in order to successfully bind the ReadWrite shader inputs
  208. // (see ValidateSetImageView() in ShaderResourceGroupData.cpp)
  209. diffuseProbeGrid->UpdateVisualizationPrepareSrg(m_shader, m_srgLayout);
  210. if (!diffuseProbeGrid->GetVisualizationPrepareSrg()->IsQueuedForCompile())
  211. {
  212. diffuseProbeGrid->GetVisualizationPrepareSrg()->Compile();
  213. }
  214. }
  215. }
  216. void DiffuseProbeGridVisualizationPreparePass::BuildCommandListInternal(const RHI::FrameGraphExecuteContext& context)
  217. {
  218. RHI::CommandList* commandList = context.GetCommandList();
  219. RPI::Scene* scene = m_pipeline->GetScene();
  220. DiffuseProbeGridFeatureProcessor* diffuseProbeGridFeatureProcessor = scene->GetFeatureProcessor<DiffuseProbeGridFeatureProcessor>();
  221. // submit the DispatchItems for each DiffuseProbeGrid in this range
  222. for (uint32_t index = context.GetSubmitRange().m_startIndex; index < context.GetSubmitRange().m_endIndex; ++index)
  223. {
  224. AZStd::shared_ptr<DiffuseProbeGrid> diffuseProbeGrid = diffuseProbeGridFeatureProcessor->GetVisibleProbeGrids()[index];
  225. if (!ShouldUpdate(diffuseProbeGrid))
  226. {
  227. continue;
  228. }
  229. const RHI::DeviceShaderResourceGroup* shaderResourceGroup = diffuseProbeGrid->GetVisualizationPrepareSrg()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get();
  230. commandList->SetShaderResourceGroupForDispatch(*shaderResourceGroup);
  231. RHI::DeviceDispatchItem dispatchItem;
  232. dispatchItem.m_arguments = m_dispatchArgs;
  233. dispatchItem.m_pipelineState = m_pipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
  234. dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsX = diffuseProbeGrid->GetTotalProbeCount();
  235. dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsY = 1;
  236. dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsZ = 1;
  237. commandList->Submit(dispatchItem, index);
  238. }
  239. }
  240. } // namespace Render
  241. } // namespace AZ