MotionMatchingInstance.cpp 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <AzCore/Console/IConsole.h>
  9. #include <AzCore/Debug/Timer.h>
  10. #include <AzCore/Component/ComponentApplicationBus.h>
  11. #include <AzCore/Serialization/EditContext.h>
  12. #include <AzCore/Serialization/SerializeContext.h>
  13. #include <EMotionFX/Source/ActorInstance.h>
  14. #include <EMotionFX/Source/EMotionFXManager.h>
  15. #include <EMotionFX/Source/Motion.h>
  16. #include <EMotionFX/Source/MotionInstance.h>
  17. #include <EMotionFX/Source/MotionInstancePool.h>
  18. #include <EMotionFX/Source/Pose.h>
  19. #include <EMotionFX/Source/TransformData.h>
  20. #include <Allocators.h>
  21. #include <Feature.h>
  22. #include <FeatureSchema.h>
  23. #include <FeatureTrajectory.h>
  24. #include <FeatureVelocity.h>
  25. #include <ImGuiMonitorBus.h>
  26. #include <KdTree.h>
  27. #include <MotionMatchingData.h>
  28. #include <MotionMatchingInstance.h>
  29. #include <PoseDataJointVelocities.h>
  30. namespace EMotionFX::MotionMatching
  31. {
// Console variables controlling the motion matching debug visualization and the
// kd-tree broad-phase search. They are defined elsewhere in this gem; only the
// extern declarations live here.
AZ_CVAR_EXTERNED(bool, mm_debugDraw);
AZ_CVAR_EXTERNED(float, mm_debugDrawVelocityScale);
AZ_CVAR_EXTERNED(bool, mm_debugDrawQueryPose);
AZ_CVAR_EXTERNED(bool, mm_debugDrawQueryVelocities);
AZ_CVAR_EXTERNED(bool, mm_useKdTree);

AZ_CLASS_ALLOCATOR_IMPL(MotionMatchingInstance, MotionMatchAllocator)
  38. MotionMatchingInstance::~MotionMatchingInstance()
  39. {
  40. DebugDrawRequestBus::Handler::BusDisconnect();
  41. if (m_motionInstance)
  42. {
  43. GetMotionInstancePool().Free(m_motionInstance);
  44. }
  45. if (m_prevMotionInstance)
  46. {
  47. GetMotionInstancePool().Free(m_prevMotionInstance);
  48. }
  49. }
  50. MotionInstance* MotionMatchingInstance::CreateMotionInstance() const
  51. {
  52. MotionInstance* result = GetMotionInstancePool().RequestNew(m_data->GetFrameDatabase().GetFrame(0).GetSourceMotion(), m_actorInstance);
  53. return result;
  54. }
  55. void MotionMatchingInstance::Init(const InitSettings& settings)
  56. {
  57. AZ_Assert(settings.m_actorInstance, "The actor instance cannot be a nullptr.");
  58. AZ_Assert(settings.m_data, "The motion match data cannot be nullptr.");
  59. DebugDrawRequestBus::Handler::BusConnect();
  60. // Update the cached pointer to the trajectory feature.
  61. const FeatureSchema& featureSchema = settings.m_data->GetFeatureSchema();
  62. for (Feature* feature : featureSchema.GetFeatures())
  63. {
  64. if (feature->RTTI_GetType() == azrtti_typeid<FeatureTrajectory>())
  65. {
  66. m_cachedTrajectoryFeature = static_cast<FeatureTrajectory*>(feature);
  67. break;
  68. }
  69. }
  70. m_actorInstance = settings.m_actorInstance;
  71. m_data = settings.m_data;
  72. if (settings.m_data->GetFrameDatabase().GetNumFrames() == 0)
  73. {
  74. return;
  75. }
  76. if (!m_motionInstance)
  77. {
  78. m_motionInstance = CreateMotionInstance();
  79. }
  80. if (!m_prevMotionInstance)
  81. {
  82. m_prevMotionInstance = CreateMotionInstance();
  83. }
  84. m_blendSourcePose.LinkToActorInstance(m_actorInstance);
  85. m_blendSourcePose.InitFromBindPose(m_actorInstance);
  86. m_blendTargetPose.LinkToActorInstance(m_actorInstance);
  87. m_blendTargetPose.InitFromBindPose(m_actorInstance);
  88. m_queryPose.LinkToActorInstance(m_actorInstance);
  89. m_queryPose.InitFromBindPose(m_actorInstance);
  90. // Make sure we have enough space inside the frame floats array, which is used to search the kdTree.
  91. const size_t numValuesInKdTree = m_data->GetKdTree().GetNumDimensions();
  92. m_kdTreeQueryVector.Resize(numValuesInKdTree);
  93. m_queryVector.Resize(m_data->GetFeatureMatrix().cols());
  94. // Initialize the trajectory history.
  95. if (m_cachedTrajectoryFeature)
  96. {
  97. size_t rootJointIndex = m_actorInstance->GetActor()->GetMotionExtractionNodeIndex();
  98. if (rootJointIndex == InvalidIndex32)
  99. {
  100. rootJointIndex = 0;
  101. }
  102. m_trajectoryHistory.Init(*m_actorInstance->GetTransformData()->GetCurrentPose(),
  103. rootJointIndex,
  104. m_cachedTrajectoryFeature->GetFacingAxisDir(),
  105. m_trajectorySecsToTrack);
  106. }
  107. }
  108. void MotionMatchingInstance::DebugDraw(AzFramework::DebugDisplayRequests& debugDisplay)
  109. {
  110. if (!mm_debugDraw)
  111. {
  112. return;
  113. }
  114. AZ_PROFILE_SCOPE(Animation, "MotionMatchingInstance::DebugDraw");
  115. // Get the lowest cost frame index from the last search. As we're searching the feature database with a much lower
  116. // frequency and sample the animation onwards from this, the resulting frame index does not represent the current
  117. // feature values from the shown pose.
  118. const size_t curFrameIndex = GetLowestCostFrameIndex();
  119. if (curFrameIndex == InvalidIndex)
  120. {
  121. return;
  122. }
  123. const FrameDatabase& frameDatabase = m_data->GetFrameDatabase();
  124. const FeatureSchema& featureSchema = m_data->GetFeatureSchema();
  125. const FeatureMatrix& featureMatrix = m_data->GetFeatureMatrix();
  126. // Find the frame index in the frame database that belongs to the currently used pose.
  127. const size_t currentFrame = frameDatabase.FindFrameIndex(m_motionInstance->GetMotion(), m_motionInstance->GetCurrentTime());
  128. // Render the feature debug visualizations for the current frame.
  129. if (currentFrame != InvalidIndex)
  130. {
  131. const Pose& currentPose = *m_actorInstance->GetTransformData()->GetCurrentPose();
  132. for (Feature* feature: featureSchema.GetFeatures())
  133. {
  134. if (feature->GetDebugDrawEnabled())
  135. {
  136. feature->DebugDraw(debugDisplay, currentPose, featureMatrix, m_data->GetFeatureTransformer(), currentFrame);
  137. }
  138. }
  139. }
  140. // Draw the desired future trajectory and the sampled version of the past trajectory.
  141. const AZ::Color trajectoryQueryColor = AZ::Color::CreateFromRgba(90,219,64,255);
  142. m_trajectoryQuery.DebugDraw(debugDisplay, trajectoryQueryColor);
  143. // Draw the trajectory history starting after the sampled version of the past trajectory.
  144. m_trajectoryHistory.DebugDraw(debugDisplay, trajectoryQueryColor, m_cachedTrajectoryFeature->GetPastTimeRange());
  145. // Draw the input for the motion matching search.
  146. DebugDrawQueryPose(debugDisplay, mm_debugDrawQueryPose, mm_debugDrawQueryVelocities);
  147. }
  148. void MotionMatchingInstance::DebugDrawQueryPose(AzFramework::DebugDisplayRequests& debugDisplay, bool drawPose, bool drawVelocities) const
  149. {
  150. const AZ::Color color = AZ::Color::CreateOne();
  151. if (drawPose)
  152. {
  153. m_queryPose.DebugDraw(debugDisplay, color);
  154. }
  155. if (drawVelocities)
  156. {
  157. PoseDataJointVelocities* velocityPoseData = m_queryPose.GetPoseData<PoseDataJointVelocities>();
  158. if (velocityPoseData)
  159. {
  160. const Skeleton* skeleton = m_actorInstance->GetActor()->GetSkeleton();
  161. for (const Feature* feature : m_data->GetFeatureSchema().GetFeatures())
  162. {
  163. if (const FeatureVelocity* velocityFeature = azdynamic_cast<const FeatureVelocity*>(feature))
  164. {
  165. Node* joint = skeleton->FindNodeByName(velocityFeature->GetJointName());
  166. if (joint)
  167. {
  168. const size_t jointIndex = joint->GetNodeIndex();
  169. const size_t relativeToJointIndex = feature->GetRelativeToNodeIndex();
  170. const AZ::Vector3& velocity = velocityPoseData->GetVelocities()[jointIndex];
  171. velocityFeature->DebugDraw(debugDisplay, m_queryPose, velocity, jointIndex, relativeToJointIndex, color);
  172. }
  173. }
  174. }
  175. }
  176. }
  177. }
  178. void MotionMatchingInstance::SamplePose(MotionInstance* motionInstance, Pose& outputPose)
  179. {
  180. const Pose* bindPose = m_actorInstance->GetTransformData()->GetBindPose();
  181. motionInstance->GetMotion()->Update(bindPose, &outputPose, motionInstance);
  182. if (m_actorInstance->GetActor()->GetMotionExtractionNode() && m_actorInstance->GetMotionExtractionEnabled())
  183. {
  184. outputPose.CompensateForMotionExtraction();
  185. }
  186. }
  187. void MotionMatchingInstance::SamplePose(Motion* motion, Pose& outputPose, float sampleTime) const
  188. {
  189. MotionDataSampleSettings sampleSettings;
  190. sampleSettings.m_actorInstance = outputPose.GetActorInstance();
  191. sampleSettings.m_inPlace = false;
  192. sampleSettings.m_mirror = false;
  193. sampleSettings.m_retarget = false;
  194. sampleSettings.m_inputPose = sampleSettings.m_actorInstance->GetTransformData()->GetBindPose();
  195. sampleSettings.m_sampleTime = sampleTime;
  196. sampleSettings.m_sampleTime = AZ::GetClamp(sampleTime, 0.0f, motion->GetDuration());
  197. motion->SamplePose(&outputPose, sampleSettings);
  198. }
  199. void MotionMatchingInstance::PostUpdate([[maybe_unused]] float timeDelta)
  200. {
  201. if (!m_data)
  202. {
  203. m_motionExtractionDelta.Identity();
  204. return;
  205. }
  206. const size_t lowestCostFrame = GetLowestCostFrameIndex();
  207. if (m_data->GetFrameDatabase().GetNumFrames() == 0 || lowestCostFrame == InvalidIndex)
  208. {
  209. m_motionExtractionDelta.Identity();
  210. return;
  211. }
  212. // Blend the motion extraction deltas.
  213. // Note: Make sure to update the previous as well as the current/target motion instances.
  214. if (m_blendWeight >= 1.0f - AZ::Constants::FloatEpsilon)
  215. {
  216. m_motionInstance->ExtractMotion(m_motionExtractionDelta);
  217. }
  218. else if (m_blendWeight > AZ::Constants::FloatEpsilon && m_blendWeight < 1.0f - AZ::Constants::FloatEpsilon)
  219. {
  220. Transform targetMotionExtractionDelta;
  221. m_motionInstance->ExtractMotion(m_motionExtractionDelta);
  222. m_prevMotionInstance->ExtractMotion(targetMotionExtractionDelta);
  223. m_motionExtractionDelta.Blend(targetMotionExtractionDelta, m_blendWeight);
  224. }
  225. else
  226. {
  227. m_prevMotionInstance->ExtractMotion(m_motionExtractionDelta);
  228. }
  229. }
  230. void MotionMatchingInstance::Output(Pose& outputPose)
  231. {
  232. AZ_PROFILE_SCOPE(Animation, "MotionMatchingInstance::Output");
  233. if (!m_data)
  234. {
  235. outputPose.InitFromBindPose(m_actorInstance);
  236. return;
  237. }
  238. const size_t lowestCostFrame = GetLowestCostFrameIndex();
  239. if (m_data->GetFrameDatabase().GetNumFrames() == 0 || lowestCostFrame == InvalidIndex)
  240. {
  241. outputPose.InitFromBindPose(m_actorInstance);
  242. return;
  243. }
  244. // Sample the motions and blend the results when needed.
  245. if (m_blendWeight >= 1.0f - AZ::Constants::FloatEpsilon)
  246. {
  247. m_blendTargetPose.InitFromBindPose(m_actorInstance);
  248. if (m_motionInstance)
  249. {
  250. SamplePose(m_motionInstance, m_blendTargetPose);
  251. }
  252. outputPose = m_blendTargetPose;
  253. }
  254. else if (m_blendWeight > AZ::Constants::FloatEpsilon && m_blendWeight < 1.0f - AZ::Constants::FloatEpsilon)
  255. {
  256. m_blendSourcePose.InitFromBindPose(m_actorInstance);
  257. m_blendTargetPose.InitFromBindPose(m_actorInstance);
  258. if (m_motionInstance)
  259. {
  260. SamplePose(m_motionInstance, m_blendTargetPose);
  261. }
  262. if (m_prevMotionInstance)
  263. {
  264. SamplePose(m_prevMotionInstance, m_blendSourcePose);
  265. }
  266. outputPose = m_blendSourcePose;
  267. outputPose.Blend(&m_blendTargetPose, m_blendWeight);
  268. }
  269. else
  270. {
  271. m_blendSourcePose.InitFromBindPose(m_actorInstance);
  272. if (m_prevMotionInstance)
  273. {
  274. SamplePose(m_prevMotionInstance, m_blendSourcePose);
  275. }
  276. outputPose = m_blendSourcePose;
  277. }
  278. }
// Per-tick driver of the motion matching search:
// 1. Records the current pose into the trajectory history and updates the trajectory query.
// 2. Advances the play time of the current and previous motion instances.
// 3. Advances the blend weight of an in-progress transition.
// 4. At the search frequency, builds the query pose/vector, finds the lowest cost frame
//    and, when it differs from the current location, starts a blend into it.
// The statement order is load-bearing: the history sample must be taken before the
// history time update, and the new motion time is computed before it is applied so the
// search can use it.
void MotionMatchingInstance::Update(float timePassedInSeconds,
    const AZ::Vector3& targetPos,
    const AZ::Vector3& targetFacingDir,
    bool useTargetFacingDir,
    TrajectoryQuery::EMode mode,
    float pathRadius,
    float pathSpeed)
{
    AZ_PROFILE_SCOPE(Animation, "MotionMatchingInstance::Update");
    if (!m_data || !m_motionInstance)
    {
        return;
    }

    // Fall back to frame 0 before the first successful search.
    size_t currentFrameIndex = GetLowestCostFrameIndex();
    if (currentFrameIndex == InvalidIndex)
    {
        currentFrameIndex = 0;
    }

    // Add the sample from the last frame (post-motion extraction)
    m_trajectoryHistory.AddSample(*m_actorInstance->GetTransformData()->GetCurrentPose());
    // Update the time. After this there is no sample for the updated time in the history as we're about to prepare this with the current update.
    m_trajectoryHistory.Update(timePassedInSeconds);

    // Update the trajectory query control points.
    // NOTE(review): m_cachedTrajectoryFeature may be nullptr when the schema holds no
    // trajectory feature (see Init()); confirm TrajectoryQuery::Update tolerates that.
    m_trajectoryQuery.Update(*m_actorInstance,
        m_cachedTrajectoryFeature,
        m_trajectoryHistory,
        mode,
        targetPos,
        targetFacingDir,
        useTargetFacingDir,
        timePassedInSeconds,
        pathRadius,
        pathSpeed);

    // Calculate the new time value of the motion, but don't set it yet (the syncing might adjust this again)
    m_motionInstance->SetFreezeAtLastFrame(true);
    m_motionInstance->SetMaxLoops(1);
    const float newMotionTime = m_motionInstance->CalcPlayStateAfterUpdate(timePassedInSeconds).m_currentTime;
    m_newMotionTime = newMotionTime;

    // Keep on playing the previous instance as we're blending the poses and motion extraction deltas.
    m_prevMotionInstance->Update(timePassedInSeconds);

    m_timeSinceLastFrameSwitch += timePassedInSeconds;

    // The blend duration is tied to the search interval so a transition completes
    // before the next search can start another one.
    const float lowestCostSearchTimeInterval = 1.0f / m_lowestCostSearchFrequency;
    if (m_blending)
    {
        const float maxBlendTime = lowestCostSearchTimeInterval;
        m_blendProgressTime += timePassedInSeconds;
        if (m_blendProgressTime > maxBlendTime)
        {
            m_blendWeight = 1.0f;
            m_blendProgressTime = maxBlendTime;
            m_blending = false;
        }
        else
        {
            m_blendWeight = AZ::GetClamp(m_blendProgressTime / maxBlendTime, 0.0f, 1.0f);
        }
    }

    const bool searchLowestCostFrame = m_timeSinceLastFrameSwitch >= lowestCostSearchTimeInterval;
    if (searchLowestCostFrame)
    {
        // Calculate the input query pose for the motion matching search algorithm.
        {
            AZ_PROFILE_SCOPE(Animation, "MM::EvaluateQueryPose");
            // Sample the pose for the new motion time as the motion instance has not been updated with the timeDelta from this frame yet.
            SamplePose(m_motionInstance->GetMotion(), m_queryPose, newMotionTime);
            // Copy over the motion extraction joint transform from the current pose to the newly sampled pose.
            // When sampling a motion, the motion extraction joint is in animation space, while we need the query pose to be in world space.
            // Note: This does not yet take the extraction delta from the current tick into account.
            if (m_actorInstance->GetActor()->GetMotionExtractionNode())
            {
                const Pose* currentPose = m_actorInstance->GetTransformData()->GetCurrentPose();
                const size_t motionExtractionJointIndex = m_actorInstance->GetActor()->GetMotionExtractionNodeIndex();
                m_queryPose.SetWorldSpaceTransform(motionExtractionJointIndex,
                    currentPose->GetWorldSpaceTransform(motionExtractionJointIndex));
            }
            // Calculate the joint velocities for the sampled pose using the same method as we do for the frame database.
            // NOTE(review): m_cachedTrajectoryFeature is dereferenced here — assumes the
            // schema contains a trajectory feature; confirm for schemas without one.
            PoseDataJointVelocities* velocityPoseData = m_queryPose.GetAndPreparePoseData<PoseDataJointVelocities>(m_actorInstance);
            AnimGraphPosePool& posePool = GetEMotionFX().GetThreadData(m_actorInstance->GetThreadIndex())->GetPosePool();
            velocityPoseData->CalculateVelocity(m_actorInstance, posePool, m_motionInstance->GetMotion(), newMotionTime, m_cachedTrajectoryFeature->GetRelativeToNodeIndex());
        }

        const FeatureMatrix& featureMatrix = m_data->GetFeatureMatrix();
        const FrameDatabase& frameDatabase = m_data->GetFrameDatabase();

        // Run the actual search for the best matching frame.
        Feature::QueryVectorContext queryVectorContext(m_queryPose, m_trajectoryQuery);
        queryVectorContext.m_featureTransformer = m_data->GetFeatureTransformer();
        Feature::FrameCostContext frameCostContext(m_queryVector, featureMatrix);
        const size_t lowestCostFrameIndex = FindLowestCostFrameIndex(queryVectorContext, frameCostContext);

        // Only jump when the best match is in a different motion or further than 0.1
        // seconds away from where we already are; otherwise keep playing.
        const Frame& currentFrame = frameDatabase.GetFrame(currentFrameIndex);
        const Frame& lowestCostFrame = frameDatabase.GetFrame(lowestCostFrameIndex);
        const bool sameMotion = (currentFrame.GetSourceMotion() == lowestCostFrame.GetSourceMotion());
        const float timeBetweenFrames = newMotionTime - lowestCostFrame.GetSampleTime();
        const bool sameLocation = sameMotion && (AZ::GetAbs(timeBetweenFrames) < 0.1f);

        if (lowestCostFrameIndex != currentFrameIndex && !sameLocation)
        {
            // Start a blend.
            m_blending = true;
            m_blendWeight = 0.0f;
            m_blendProgressTime = 0.0f;

            // Store the current motion instance state, so we can sample this as source pose.
            m_prevMotionInstance->SetMotion(m_motionInstance->GetMotion());
            m_prevMotionInstance->SetMirrorMotion(m_motionInstance->GetMirrorMotion());
            m_prevMotionInstance->SetCurrentTime(newMotionTime);
            m_prevMotionInstance->SetLastCurrentTime(m_prevMotionInstance->GetCurrentTime() - timePassedInSeconds);

            m_lowestCostFrameIndex = lowestCostFrameIndex;

            m_motionInstance->SetMotion(lowestCostFrame.GetSourceMotion());
            m_motionInstance->SetMirrorMotion(lowestCostFrame.GetMirrored());
            // The new motion time will become the current time after this frame while the current time
            // becomes the last current time. As we just start playing at the search frame, calculate
            // the last time based on the time delta.
            m_newMotionTime = lowestCostFrame.GetSampleTime();
            m_motionInstance->SetCurrentTime(m_newMotionTime - timePassedInSeconds);
        }

        // Do this always, else wise we search for the lowest cost frame index too many times.
        m_timeSinceLastFrameSwitch = 0.0f;
    }

    // ImGui monitor
    {
#ifdef IMGUI_ENABLED
        const FrameDatabase& frameDatabase = m_data->GetFrameDatabase();
        ImGuiMonitorRequests::FrameDatabaseInfo frameDatabaseInfo{frameDatabase.CalcMemoryUsageInBytes(), frameDatabase.GetNumFrames(), frameDatabase.GetNumUsedMotions(), frameDatabase.GetNumFrames() / (float)frameDatabase.GetSampleRate()};
        ImGuiMonitorRequestBus::Broadcast(&ImGuiMonitorRequests::SetFrameDatabaseInfo, frameDatabaseInfo);

        const KdTree& kdTree = m_data->GetKdTree();
        ImGuiMonitorRequests::KdTreeInfo kdTreeInfo{kdTree.CalcMemoryUsageInBytes(), kdTree.GetNumNodes(), kdTree.GetNumDimensions()};
        ImGuiMonitorRequestBus::Broadcast(&ImGuiMonitorRequests::SetKdTreeInfo, kdTreeInfo);

        const FeatureMatrix& featureMatrix = m_data->GetFeatureMatrix();
        ImGuiMonitorRequests::FeatureMatrixInfo featureMatrixInfo{featureMatrix.CalcMemoryUsageInBytes(), static_cast<size_t>(featureMatrix.rows()), static_cast<size_t>(featureMatrix.cols())};
        ImGuiMonitorRequestBus::Broadcast(&ImGuiMonitorRequests::SetFeatureMatrixInfo, featureMatrixInfo);
#endif
    }
}
// Run the motion matching search and return the index of the frame with the minimal
// accumulated feature cost. Phases:
//   1. Fill the query vector from all features (and optionally transform it).
//   2. Broad phase: kd-tree nearest-neighbor filter (when mm_useKdTree is set).
//   3. Narrow phase: brute-force cost evaluation over the candidate frames.
//   4. Report timing and per-feature minimum costs to the ImGui monitor.
size_t MotionMatchingInstance::FindLowestCostFrameIndex(const Feature::QueryVectorContext& queryVectorContext, const Feature::FrameCostContext& frameCostContext)
{
    AZ::Debug::Timer timer;
    timer.Stamp();
    AZ_PROFILE_SCOPE(Animation, "MotionMatchingInstance::FindLowestCostFrameIndex");

    const FrameDatabase& frameDatabase = m_data->GetFrameDatabase();
    const FeatureSchema& featureSchema = m_data->GetFeatureSchema();
    const FeatureTrajectory* trajectoryFeature = m_cachedTrajectoryFeature;

    // 1. Build query vector
    {
        AZ_PROFILE_SCOPE(Animation, "MM::BuildQueryVector");

        // Build the input query features that will be compared to every entry in the feature database in the motion matching search.
        AZ_Assert(m_queryVector.GetSize() == aznumeric_cast<size_t>(m_data->GetFeatureMatrix().cols()),
            "The query vector should have the same number of elements as the feature matrix has columns.");
        for (Feature* feature : featureSchema.GetFeatures())
        {
            feature->FillQueryVector(m_queryVector, queryVectorContext);
        }

        // Apply the same transformation (e.g. normalization) that was applied to the feature matrix.
        if (FeatureMatrixTransformer* transformer = queryVectorContext.m_featureTransformer)
        {
            transformer->Transform(m_queryVector.GetData());
        }
    }

    // 2. Broad-phase search using KD-tree
    if (mm_useKdTree)
    {
        AZ_PROFILE_SCOPE(Animation, "MM::BroadPhaseKDTree");

        AZStd::vector<float>& kdTreeQueryVector = m_kdTreeQueryVector.GetData();
        const AZStd::vector<float>& queryVectorData = m_queryVector.GetData();

        // Copy only the columns of the features that participate in the kd-tree into the kd-tree query vector.
        size_t startOffset = 0;
        for (Feature* feature : m_data->GetFeaturesInKdTree())
        {
            memcpy(&kdTreeQueryVector[startOffset], &queryVectorData[feature->GetColumnOffset()], feature->GetNumDimensions() * sizeof(float));
            startOffset += feature->GetNumDimensions();
        }
        AZ_Assert(startOffset == kdTreeQueryVector.size(), "Frame float vector is not the expected size.");

        // Find our nearest frames.
        m_data->GetKdTree().FindNearestNeighbors(kdTreeQueryVector, m_nearestFrames);
    }

    // 3. Narrow-phase, brute force find the actual best matching frame (frame with the minimal cost).
    float minCost = FLT_MAX;
    size_t minCostFrameIndex = 0;

    m_tempCosts.resize(featureSchema.GetNumFeatures());
    m_minCosts.resize(featureSchema.GetNumFeatures());
    float minTrajectoryPastCost = 0.0f;
    float minTrajectoryFutureCost = 0.0f;

    // Iterate through the frames filtered by the broad-phase search.
    const size_t numFrames = mm_useKdTree ? m_nearestFrames.size() : frameDatabase.GetNumFrames();
    for (size_t i = 0; i < numFrames; ++i)
    {
        const size_t frameIndex = mm_useKdTree ? m_nearestFrames[i] : i;
        const Frame& frame = frameDatabase.GetFrame(frameIndex);

        // Skip frames within the last second of their source motion, so that there is room to play onwards.
        // TODO: This shouldn't be there, we should be discarding the frames when extracting the features and not at runtime when checking the cost.
        if (frame.GetSampleTime() >= frame.GetSourceMotion()->GetDuration() - 1.0f)
        {
            continue;
        }

        float frameCost = 0.0f;

        // Calculate the frame cost by accumulating the weighted feature costs.
        // The trajectory feature is excluded here and handled separately below,
        // as its past and future costs carry individual cost factors.
        for (size_t featureIndex = 0; featureIndex < featureSchema.GetNumFeatures(); ++featureIndex)
        {
            Feature* feature = featureSchema.GetFeature(featureIndex);
            if (feature->RTTI_GetType() != azrtti_typeid<FeatureTrajectory>())
            {
                const float featureCost = feature->CalculateFrameCost(frameIndex, frameCostContext);
                const float featureCostFactor = feature->GetCostFactor();
                const float featureFinalCost = featureCost * featureCostFactor;
                frameCost += featureFinalCost;
                m_tempCosts[featureIndex] = featureFinalCost;
            }
        }

        // Manually add the trajectory cost.
        float trajectoryPastCost = 0.0f;
        float trajectoryFutureCost = 0.0f;
        if (trajectoryFeature)
        {
            trajectoryPastCost = trajectoryFeature->CalculatePastFrameCost(frameIndex, frameCostContext) * trajectoryFeature->GetPastCostFactor();
            trajectoryFutureCost = trajectoryFeature->CalculateFutureFrameCost(frameIndex, frameCostContext) * trajectoryFeature->GetFutureCostFactor();
            frameCost += trajectoryPastCost;
            frameCost += trajectoryFutureCost;
        }

        // Track the minimum feature and frame costs.
        if (frameCost < minCost)
        {
            minCost = frameCost;
            minCostFrameIndex = frameIndex;

            // Keep the per-feature cost breakdown of the current best frame for the monitor below.
            for (size_t featureIndex = 0; featureIndex < featureSchema.GetNumFeatures(); ++featureIndex)
            {
                Feature* feature = featureSchema.GetFeature(featureIndex);
                if (feature->RTTI_GetType() != azrtti_typeid<FeatureTrajectory>())
                {
                    m_minCosts[featureIndex] = m_tempCosts[featureIndex];
                }
            }
            minTrajectoryPastCost = trajectoryPastCost;
            minTrajectoryFutureCost = trajectoryFutureCost;
        }
    }

    // 4. ImGui debug visualization
    {
        const float time = timer.GetDeltaTimeInSeconds();
        ImGuiMonitorRequestBus::Broadcast(&ImGuiMonitorRequests::PushPerformanceHistogramValue, "FindLowestCostFrameIndex", time * 1000.0f);

        for (size_t featureIndex = 0; featureIndex < featureSchema.GetNumFeatures(); ++featureIndex)
        {
            Feature* feature = featureSchema.GetFeature(featureIndex);
            if (feature->RTTI_GetType() != azrtti_typeid<FeatureTrajectory>())
            {
                ImGuiMonitorRequestBus::Broadcast(&ImGuiMonitorRequests::PushCostHistogramValue,
                    feature->GetName().c_str(),
                    m_minCosts[featureIndex],
                    feature->GetDebugDrawColor());
            }
        }
        if (trajectoryFeature)
        {
            ImGuiMonitorRequestBus::Broadcast(&ImGuiMonitorRequests::PushCostHistogramValue, "Future Trajectory", minTrajectoryFutureCost, trajectoryFeature->GetDebugDrawColor());
            ImGuiMonitorRequestBus::Broadcast(&ImGuiMonitorRequests::PushCostHistogramValue, "Past Trajectory", minTrajectoryPastCost, trajectoryFeature->GetDebugDrawColor());
        }
        ImGuiMonitorRequestBus::Broadcast(&ImGuiMonitorRequests::PushCostHistogramValue, "Total Cost", minCost, AZ::Color::CreateFromRgba(202,255,191,255));
    }

    return minCostFrameIndex;
}
  531. } // namespace EMotionFX::MotionMatching