GradientSignalTestHelpers.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Tests/GradientSignalTestHelpers.h>
  9. #include <Atom/RPI.Reflect/Image/ImageMipChainAssetCreator.h>
  10. #include <Atom/RPI.Reflect/Image/StreamingImageAssetCreator.h>
  11. #include <AzCore/Math/Aabb.h>
  12. #include <GradientSignal/GradientSampler.h>
  13. namespace UnitTest
  14. {
  15. AZ::RHI::DeviceImageSubresourceLayout BuildSubImageLayout(AZ::u32 width, AZ::u32 height, AZ::u32 pixelSize)
  16. {
  17. AZ::RHI::DeviceImageSubresourceLayout layout;
  18. layout.m_size = AZ::RHI::Size{ width, height, 1 };
  19. layout.m_rowCount = width;
  20. layout.m_bytesPerRow = width * pixelSize;
  21. layout.m_bytesPerImage = width * height * pixelSize;
  22. return layout;
  23. }
  24. AZStd::vector<uint8_t> BuildBasicImageData(AZ::u32 width, AZ::u32 height, AZ::u32 pixelSize, AZ::s32 seed)
  25. {
  26. const size_t imageSize = width * height * pixelSize;
  27. AZStd::vector<uint8_t> image;
  28. image.reserve(imageSize);
  29. size_t value = 0;
  30. AZStd::hash_combine(value, seed);
  31. for (AZ::u32 x = 0; x < width; ++x)
  32. {
  33. for (AZ::u32 y = 0; y < height; ++y)
  34. {
  35. AZStd::hash_combine(value, x);
  36. AZStd::hash_combine(value, y);
  37. image.push_back(static_cast<AZ::u8>(value));
  38. }
  39. }
  40. EXPECT_EQ(image.size(), imageSize);
  41. return image;
  42. }
  43. AZ::Data::Asset<AZ::RPI::ImageMipChainAsset> BuildBasicMipChainAsset(
  44. AZ::u16 mipLevels, AZ::u32 width, AZ::u32 height, AZ::u32 pixelSize, AZStd::span<const uint8_t> data)
  45. {
  46. using namespace AZ;
  47. RPI::ImageMipChainAssetCreator assetCreator;
  48. const uint16_t arraySize = 1;
  49. assetCreator.Begin(Data::AssetId(AZ::Uuid::CreateRandom()), mipLevels, arraySize);
  50. RHI::DeviceImageSubresourceLayout layout = BuildSubImageLayout(width, height, pixelSize);
  51. assetCreator.BeginMip(layout);
  52. assetCreator.AddSubImage(data.data(), data.size());
  53. assetCreator.EndMip();
  54. Data::Asset<RPI::ImageMipChainAsset> asset;
  55. EXPECT_TRUE(assetCreator.End(asset));
  56. EXPECT_TRUE(asset.IsReady());
  57. EXPECT_NE(asset.Get(), nullptr);
  58. return asset;
  59. }
  60. AZStd::vector<uint8_t> BuildSpecificPixelImageData(
  61. AZ::u32 width, AZ::u32 height, AZ::u32 pixelSize, AZ::u32 pixelX, AZ::u32 pixelY, AZStd::span<const AZ::u8> setPixelValues)
  62. {
  63. AZ_Assert(setPixelValues.size() == pixelSize, "Wrong number of pixel channel values passed in");
  64. const size_t imageSize = width * height * pixelSize;
  65. AZStd::vector<uint8_t> image;
  66. image.reserve(imageSize);
  67. // Image data should be stored inverted on the y axis relative to our engine, so loop backwards through y.
  68. for (int y = static_cast<int>(height) - 1; y >= 0; --y)
  69. {
  70. for (AZ::u32 x = 0; x < width; ++x)
  71. {
  72. for (AZ::u32 component = 0; component < pixelSize; ++component)
  73. {
  74. if ((x == static_cast<int>(pixelX)) && (y == static_cast<int>(pixelY)))
  75. {
  76. image.push_back(setPixelValues[component]);
  77. }
  78. else
  79. {
  80. image.push_back(0);
  81. }
  82. }
  83. }
  84. }
  85. EXPECT_EQ(image.size(), imageSize);
  86. return image;
  87. }
  88. AZ::Data::Asset<AZ::RPI::StreamingImageAsset> CreateImageAssetFromPixelData(
  89. AZ::u32 width, AZ::u32 height, AZ::RHI::Format format, AZStd::span<const uint8_t> data)
  90. {
  91. auto randomAssetId = AZ::Data::AssetId(AZ::Uuid::CreateRandom());
  92. auto imageAsset = AZ::Data::AssetManager::Instance().CreateAsset<AZ::RPI::StreamingImageAsset>(
  93. randomAssetId, AZ::Data::AssetLoadBehavior::Default);
  94. const AZ::u32 mipCountTotal = 1;
  95. const AZ::u32 pixelSize = AZ::RHI::GetFormatComponentCount(format);
  96. AZ::Data::Asset<AZ::RPI::ImageMipChainAsset> mipChain = BuildBasicMipChainAsset(mipCountTotal, width, height, pixelSize, data);
  97. AZ::RPI::StreamingImageAssetCreator assetCreator;
  98. assetCreator.Begin(randomAssetId);
  99. AZ::RHI::ImageDescriptor imageDesc = AZ::RHI::ImageDescriptor::Create2D(AZ::RHI::ImageBindFlags::ShaderRead, width, height, format);
  100. assetCreator.SetImageDescriptor(imageDesc);
  101. assetCreator.AddMipChainAsset(*mipChain.Get());
  102. EXPECT_TRUE(assetCreator.End(imageAsset));
  103. EXPECT_TRUE(imageAsset.IsReady());
  104. EXPECT_NE(imageAsset.Get(), nullptr);
  105. return imageAsset;
  106. }
  107. AZ::Data::Asset<AZ::RPI::StreamingImageAsset> CreateImageAsset(AZ::u32 width, AZ::u32 height, AZ::s32 seed)
  108. {
  109. const auto format = AZ::RHI::Format::R8_UNORM;
  110. const AZ::u32 pixelSize = AZ::RHI::GetFormatComponentCount(format);
  111. AZStd::vector<uint8_t> data = BuildBasicImageData(width, height, pixelSize, seed);
  112. return CreateImageAssetFromPixelData(width, height, format, data);
  113. }
  114. AZ::Data::Asset<AZ::RPI::StreamingImageAsset> CreateSpecificPixelImageAsset(
  115. AZ::u32 width, AZ::u32 height, AZ::u32 pixelX, AZ::u32 pixelY, AZStd::span<const AZ::u8> setPixelValues)
  116. {
  117. const auto format = AZ::RHI::Format::R8G8B8A8_UNORM;
  118. const AZ::u32 pixelSize = AZ::RHI::GetFormatComponentCount(format);
  119. AZStd::vector<uint8_t> data = BuildSpecificPixelImageData(width, height, pixelSize, pixelX, pixelY, setPixelValues);
  120. return CreateImageAssetFromPixelData(width, height, format, data);
  121. }
  122. AZ::Vector3 PixelCoordinatesToWorldSpace(uint32_t pixelX, uint32_t pixelY, const AZ::Aabb& bounds, uint32_t width, uint32_t height)
  123. {
  124. AZ::Vector2 pixelSize(bounds.GetXExtent() / aznumeric_cast<float>(width), bounds.GetYExtent() / aznumeric_cast<float>(height));
  125. // Return the center point of the pixel in world space.
  126. // Note that Y gets flipped because of the way images map into world space. (0,0) is the lower left corner in world space,
  127. // but the upper left corner in image space.
  128. return AZ::Vector3(
  129. bounds.GetMin().GetX() + ((pixelX + 0.5f) * pixelSize.GetX()),
  130. bounds.GetMin().GetY() + ((height - (pixelY + 0.5f)) * pixelSize.GetY()),
  131. 0.0f);
  132. }
  133. void GradientSignalTestHelpers::CompareGetValueAndGetValues(AZ::EntityId gradientEntityId, float queryMin, float queryMax)
  134. {
  135. // Create a gradient sampler and run through a series of points to see if they match expectations.
  136. const AZ::Aabb queryRegion = AZ::Aabb::CreateFromMinMax(AZ::Vector3(queryMin), AZ::Vector3(queryMax));
  137. const AZ::Vector2 stepSize(1.0f, 1.0f);
  138. GradientSignal::GradientSampler gradientSampler;
  139. gradientSampler.m_gradientId = gradientEntityId;
  140. const size_t numSamplesX = aznumeric_cast<size_t>(ceil(queryRegion.GetExtents().GetX() / stepSize.GetX()));
  141. const size_t numSamplesY = aznumeric_cast<size_t>(ceil(queryRegion.GetExtents().GetY() / stepSize.GetY()));
  142. // Build up the list of positions to query.
  143. AZStd::vector<AZ::Vector3> positions(numSamplesX * numSamplesY);
  144. size_t index = 0;
  145. for (size_t yIndex = 0; yIndex < numSamplesY; yIndex++)
  146. {
  147. float y = queryRegion.GetMin().GetY() + (stepSize.GetY() * yIndex);
  148. for (size_t xIndex = 0; xIndex < numSamplesX; xIndex++)
  149. {
  150. float x = queryRegion.GetMin().GetX() + (stepSize.GetX() * xIndex);
  151. positions[index++] = AZ::Vector3(x, y, 0.0f);
  152. }
  153. }
  154. // Get the results from GetValues
  155. AZStd::vector<float> results(numSamplesX * numSamplesY);
  156. gradientSampler.GetValues(positions, results);
  157. // For each position, call GetValue and verify that the values match.
  158. for (size_t positionIndex = 0; positionIndex < positions.size(); positionIndex++)
  159. {
  160. GradientSignal::GradientSampleParams params;
  161. params.m_position = positions[positionIndex];
  162. float value = gradientSampler.GetValue(params);
  163. // We use ASSERT_NEAR instead of EXPECT_NEAR because if one value doesn't match, they probably all won't, so there's no
  164. // reason to keep running and printing failures for every value.
  165. ASSERT_NEAR(value, results[positionIndex], 0.000001f);
  166. }
  167. }
  168. #ifdef HAVE_BENCHMARK
  169. void GradientSignalTestHelpers::FillQueryPositions(AZStd::vector<AZ::Vector3>& positions, float height, float width)
  170. {
  171. size_t index = 0;
  172. for (float y = 0.0f; y < height; y += 1.0f)
  173. {
  174. for (float x = 0.0f; x < width; x += 1.0f)
  175. {
  176. positions[index++] = AZ::Vector3(x, y, 0.0f);
  177. }
  178. }
  179. }
  180. void GradientSignalTestHelpers::RunEBusGetValueBenchmark(benchmark::State& state, const AZ::EntityId& gradientId, int64_t queryRange)
  181. {
  182. AZ_PROFILE_FUNCTION(Entity);
  183. GradientSignal::GradientSampleParams params;
  184. // Get the height and width ranges for querying from our benchmark parameters
  185. const float height = aznumeric_cast<float>(queryRange);
  186. const float width = aznumeric_cast<float>(queryRange);
  187. // Call GetValue() on the EBus for every height and width in our ranges.
  188. for ([[maybe_unused]] auto _ : state)
  189. {
  190. for (float y = 0.0f; y < height; y += 1.0f)
  191. {
  192. for (float x = 0.0f; x < width; x += 1.0f)
  193. {
  194. float value = 0.0f;
  195. params.m_position = AZ::Vector3(x, y, 0.0f);
  196. GradientSignal::GradientRequestBus::EventResult(
  197. value, gradientId, &GradientSignal::GradientRequestBus::Events::GetValue, params);
  198. benchmark::DoNotOptimize(value);
  199. }
  200. }
  201. }
  202. }
  203. void GradientSignalTestHelpers::RunEBusGetValuesBenchmark(benchmark::State& state, const AZ::EntityId& gradientId, int64_t queryRange)
  204. {
  205. AZ_PROFILE_FUNCTION(Entity);
  206. // Get the height and width ranges for querying from our benchmark parameters
  207. float height = aznumeric_cast<float>(queryRange);
  208. float width = aznumeric_cast<float>(queryRange);
  209. int64_t totalQueryPoints = queryRange * queryRange;
  210. // Call GetValues() for every height and width in our ranges.
  211. for ([[maybe_unused]] auto _ : state)
  212. {
  213. // Set up our vector of query positions. This is done inside the benchmark timing since we're counting the work to create
  214. // each query position in the single GetValue() call benchmarks, and will make the timing more directly comparable.
  215. AZStd::vector<AZ::Vector3> positions(totalQueryPoints);
  216. FillQueryPositions(positions, height, width);
  217. // Query and get the results.
  218. AZStd::vector<float> results(totalQueryPoints);
  219. GradientSignal::GradientRequestBus::Event(
  220. gradientId, &GradientSignal::GradientRequestBus::Events::GetValues, positions, results);
  221. benchmark::DoNotOptimize(results);
  222. }
  223. }
  224. void GradientSignalTestHelpers::RunSamplerGetValueBenchmark(benchmark::State& state, const AZ::EntityId& gradientId, int64_t queryRange)
  225. {
  226. AZ_PROFILE_FUNCTION(Entity);
  227. // Create a gradient sampler to use for querying our gradient.
  228. GradientSignal::GradientSampler gradientSampler;
  229. gradientSampler.m_gradientId = gradientId;
  230. // Get the height and width ranges for querying from our benchmark parameters
  231. const float height = aznumeric_cast<float>(queryRange);
  232. const float width = aznumeric_cast<float>(queryRange);
  233. // Call GetValue() through the GradientSampler for every height and width in our ranges.
  234. for ([[maybe_unused]] auto _ : state)
  235. {
  236. for (float y = 0.0f; y < height; y += 1.0f)
  237. {
  238. for (float x = 0.0f; x < width; x += 1.0f)
  239. {
  240. GradientSignal::GradientSampleParams params;
  241. params.m_position = AZ::Vector3(x, y, 0.0f);
  242. float value = gradientSampler.GetValue(params);
  243. benchmark::DoNotOptimize(value);
  244. }
  245. }
  246. }
  247. }
  248. void GradientSignalTestHelpers::RunSamplerGetValuesBenchmark(
  249. benchmark::State& state, const AZ::EntityId& gradientId, int64_t queryRange)
  250. {
  251. AZ_PROFILE_FUNCTION(Entity);
  252. // Create a gradient sampler to use for querying our gradient.
  253. GradientSignal::GradientSampler gradientSampler;
  254. gradientSampler.m_gradientId = gradientId;
  255. // Get the height and width ranges for querying from our benchmark parameters
  256. const float height = aznumeric_cast<float>(queryRange);
  257. const float width = aznumeric_cast<float>(queryRange);
  258. const int64_t totalQueryPoints = queryRange * queryRange;
  259. // Call GetValues() through the GradientSampler for every height and width in our ranges.
  260. for ([[maybe_unused]] auto _ : state)
  261. {
  262. // Set up our vector of query positions. This is done inside the benchmark timing since we're counting the work to create
  263. // each query position in the single GetValue() call benchmarks, and will make the timing more directly comparable.
  264. AZStd::vector<AZ::Vector3> positions(totalQueryPoints);
  265. FillQueryPositions(positions, height, width);
  266. // Query and get the results.
  267. AZStd::vector<float> results(totalQueryPoints);
  268. gradientSampler.GetValues(positions, results);
  269. benchmark::DoNotOptimize(results);
  270. }
  271. }
  272. void GradientSignalTestHelpers::RunGetValueOrGetValuesBenchmark(benchmark::State& state, const AZ::EntityId& gradientId)
  273. {
  274. switch (state.range(0))
  275. {
  276. case GetValuePermutation::EBUS_GET_VALUE:
  277. RunEBusGetValueBenchmark(state, gradientId, state.range(1));
  278. break;
  279. case GetValuePermutation::EBUS_GET_VALUES:
  280. RunEBusGetValuesBenchmark(state, gradientId, state.range(1));
  281. break;
  282. case GetValuePermutation::SAMPLER_GET_VALUE:
  283. RunSamplerGetValueBenchmark(state, gradientId, state.range(1));
  284. break;
  285. case GetValuePermutation::SAMPLER_GET_VALUES:
  286. RunSamplerGetValuesBenchmark(state, gradientId, state.range(1));
  287. break;
  288. default:
  289. AZ_Assert(false, "Benchmark permutation type not supported.");
  290. }
  291. }
  292. #endif
  293. }