// FeatureTrajectory.h (6.9 KB)

/*
 * Copyright (c) Contributors to the Open 3D Engine Project.
 * For complete copyright and license terms please see the LICENSE at the root of this distribution.
 *
 * SPDX-License-Identifier: Apache-2.0 OR MIT
 *
 */
  8. #pragma once
  9. #include <AzCore/Math/Vector3.h>
  10. #include <AzCore/Memory/Memory.h>
  11. #include <AzCore/RTTI/RTTI.h>
  12. #include <AzCore/RTTI/TypeInfo.h>
  13. #include <AzCore/std/containers/vector.h>
  14. #include <EMotionFX/Source/EMotionFXConfig.h>
  15. #include <EMotionFX/Source/Transform.h>
  16. #include <MotionMatchingInstance.h>
  17. #include <FeatureTrajectory.h>
  18. #include <Feature.h>
  19. namespace AZ
  20. {
  21. class ReflectContext;
  22. }
  23. namespace EMotionFX::MotionMatching
  24. {
  25. class FrameDatabase;
  26. //! Matches the root joint past and future trajectory.
  27. //! For each frame in the motion database, the position and facing direction relative to the current frame of the joint will be evaluated for a past and future time window.
  28. //! The past and future samples together form the trajectory of the current frame within the time window. This basically describes where the character came from to reach the
  29. //! current frame and where it will go when continuing to play the animation.
  30. class EMFX_API FeatureTrajectory
  31. : public Feature
  32. {
  33. public:
  34. AZ_RTTI(FeatureTrajectory, "{0451E95B-A452-439A-81ED-3962A06A3992}", Feature)
  35. AZ_CLASS_ALLOCATOR_DECL
  36. enum class Axis
  37. {
  38. X = 0,
  39. Y = 1,
  40. X_NEGATIVE = 2,
  41. Y_NEGATIVE = 3,
  42. };
  43. struct EMFX_API Sample
  44. {
  45. AZ::Vector2 m_position; //! Position in the space relative to the extracted frame.
  46. AZ::Vector2 m_facingDirection; //! Facing direction in the space relative to the extracted frame.
  47. static constexpr size_t s_componentsPerSample = 4;
  48. };
  49. FeatureTrajectory() = default;
  50. ~FeatureTrajectory() override = default;
  51. bool Init(const InitSettings& settings) override;
  52. void ExtractFeatureValues(const ExtractFeatureContext& context) override;
  53. void FillQueryVector(QueryVector& queryVector, const QueryVectorContext& context) override;
  54. float CalculateFutureFrameCost(size_t frameIndex, const FrameCostContext& context) const;
  55. float CalculatePastFrameCost(size_t frameIndex, const FrameCostContext& context) const;
  56. void DebugDraw(AzFramework::DebugDisplayRequests& debugDisplay,
  57. const Pose& currentPose,
  58. const FeatureMatrix& featureMatrix,
  59. const FeatureMatrixTransformer* featureTransformer,
  60. size_t frameIndex) override;
  61. void SetNumPastSamplesPerFrame(size_t numHistorySamples);
  62. void SetNumFutureSamplesPerFrame(size_t numFutureSamples);
  63. void SetPastTimeRange(float timeInSeconds);
  64. void SetFutureTimeRange(float timeInSeconds);
  65. void SetFacingAxis(const Axis axis);
  66. void UpdateFacingAxis();
  67. float GetPastTimeRange() const { return m_pastTimeRange; }
  68. size_t GetNumPastSamples() const { return m_numPastSamples; }
  69. float GetPastCostFactor() const { return m_pastCostFactor; }
  70. float GetFutureTimeRange() const { return m_futureTimeRange; }
  71. size_t GetNumFutureSamples() const { return m_numFutureSamples; }
  72. float GetFutureCostFactor() const { return m_futureCostFactor; }
  73. AZ::Vector2 CalculateFacingDirection(const Pose& pose, const Transform& invRootTransform) const;
  74. AZ::Vector3 GetFacingAxisDir() const { return m_facingAxisDir; }
  75. static void Reflect(AZ::ReflectContext* context);
  76. size_t GetNumDimensions() const override;
  77. AZStd::string GetDimensionName(size_t index) const override;
  78. // Shared helper function to draw a facing direction.
  79. static void DebugDrawFacingDirection(AzFramework::DebugDisplayRequests& debugDisplay,
  80. const AZ::Vector3& positionWorldSpace,
  81. const AZ::Vector3& facingDirectionWorldSpace);
  82. private:
  83. size_t CalcMidFrameIndex() const;
  84. size_t CalcPastFrameIndex(size_t historyFrameIndex) const;
  85. size_t CalcFutureFrameIndex(size_t futureFrameIndex) const;
  86. size_t CalcNumSamplesPerFrame() const;
  87. using SplineToFeatureMatrixIndex = AZStd::function<size_t(size_t)>;
  88. float CalculateCost(const FeatureMatrix& featureMatrix,
  89. size_t frameIndex,
  90. size_t numControlPoints,
  91. const SplineToFeatureMatrixIndex& splineToFeatureMatrixIndex,
  92. const FrameCostContext& context) const;
  93. //! Called for every sample in the past or future range to extract its information.
  94. //! @param[in] pose The sampled pose within the trajectory range [m_pastTimeRange, m_futureTimeRange].
  95. //! @param[in] invRootTransform The inverse of the world space transform of the joint at frame time that the feature is extracted for.
  96. Sample GetSampleFromPose(const Pose& pose, const Transform& invRootTransform) const;
  97. Sample GetFeatureData(const FeatureMatrix& featureMatrix, size_t frameIndex, size_t sampleIndex) const;
  98. void SetFeatureData(FeatureMatrix& featureMatrix, size_t frameIndex, size_t sampleIndex, const Sample& sample);
  99. Sample GetFeatureDataInverseTransformed(const FeatureMatrix& featureMatrix,
  100. const FeatureMatrixTransformer* featureTransformer,
  101. size_t frameIndex,
  102. size_t sampleIndex) const;
  103. void DebugDrawTrajectory(AzFramework::DebugDisplayRequests& debugDisplay,
  104. const FeatureMatrix& featureMatrix,
  105. const FeatureMatrixTransformer* featureTransformer,
  106. size_t frameIndex,
  107. const Transform& transform,
  108. const AZ::Color& color,
  109. size_t numSamples,
  110. const SplineToFeatureMatrixIndex& splineToFeatureMatrixIndex) const;
  111. void DebugDrawFacingDirection(AzFramework::DebugDisplayRequests& debugDisplay,
  112. const Transform& worldSpaceTransform,
  113. const Sample& sample,
  114. const AZ::Vector3& samplePosWorldSpace) const;
  115. AZ::Crc32 GetCostFactorVisibility() const override;
  116. float m_pastTimeRange = 0.7f; //< The time window the samples are distributed along for the past trajectory.
  117. size_t m_numPastSamples = 4; //< The number of samples stored per frame for the past (history) trajectory.
  118. float m_pastCostFactor = 0.5f; //< Normalized value to weight or scale the future trajectory cost.
  119. float m_futureTimeRange = 1.2f; //< The time window the samples are distributed along for the future trajectory.
  120. size_t m_numFutureSamples = 6; //< The number of samples stored per frame for the future trajectory.
  121. float m_futureCostFactor = 0.75f; //< Normalized value to weight or scale the future trajectory cost.
  122. Axis m_facingAxis = Axis::Y; //< Which axis of the joint transform is facing forward?
  123. AZ::Vector3 m_facingAxisDir = AZ::Vector3::CreateAxisY();
  124. };
  125. } // namespace EMotionFX::MotionMatching