KdTree.h 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #pragma once
  9. #include <AzCore/Memory/Memory.h>
  10. #include <AzCore/std/containers/vector.h>
  11. #include <EMotionFX/Source/EMotionFXConfig.h>
  12. #include <Feature.h>
  13. #include <FeatureMatrix.h>
  14. #include <FrameDatabase.h>
  15. namespace EMotionFX::MotionMatching
  16. {
  17. class KdTree
  18. {
  19. public:
  20. AZ_RTTI(KdTree, "{CDA707EC-4150-463B-8157-90D98351ACED}");
  21. AZ_CLASS_ALLOCATOR_DECL;
  22. KdTree() = default;
  23. virtual ~KdTree();
  24. bool Init(const FrameDatabase& frameDatabase,
  25. const FeatureMatrix& featureMatrix,
  26. const AZStd::vector<Feature*>& features,
  27. size_t maxDepth=10,
  28. size_t minFramesPerLeaf=1000);
  29. //! Calculate the number of dimensions or values for the given feature set.
  30. //! Each feature might store one or multiple values inside the feature matrix and the number of
  31. //! values each feature holds varies with the feature type. This calculates the sum of the number of
  32. //! values of the given feature set.
  33. static size_t CalcNumDimensions(const AZStd::vector<Feature*>& features);
  34. void Clear();
  35. void PrintStats();
  36. size_t GetNumNodes() const;
  37. size_t GetNumDimensions() const;
  38. size_t CalcMemoryUsageInBytes() const;
  39. bool IsInitialized() const;
  40. void FindNearestNeighbors(const AZStd::vector<float>& frameFloats, AZStd::vector<size_t>& resultFrameIndices) const;
  41. private:
  42. struct Node
  43. {
  44. AZ_RTTI(KdTree::Node, "{8A7944B3-86F1-4A33-84BC-A3B6D599E0C9}");
  45. AZ_CLASS_ALLOCATOR_DECL;
  46. virtual ~Node() = default;
  47. Node* m_leftNode = nullptr;
  48. Node* m_rightNode = nullptr;
  49. Node* m_parent = nullptr;
  50. float m_median = 0.0f;
  51. size_t m_dimension = 0;
  52. AZStd::vector<size_t> m_frames;
  53. };
  54. void BuildTreeNodes(const FrameDatabase& frameDatabase,
  55. const FeatureMatrix& featureMatrix,
  56. const AZStd::vector<size_t>& localToSchemaFeatureColumns,
  57. Node* node,
  58. Node* parent,
  59. size_t dimension = 0,
  60. bool leftSide = true);
  61. void FillFramesForNode(Node* node,
  62. const FrameDatabase& frameDatabase,
  63. const FeatureMatrix& featureMatrix,
  64. const AZStd::vector<size_t>& localToSchemaFeatureColumns,
  65. AZStd::vector<float>& frameFeatureValues,
  66. Node* parent,
  67. bool leftSide);
  68. void RecursiveCalcNumFrames(Node* node, size_t& outNumFrames) const;
  69. void ClearFramesForNonEssentialNodes();
  70. void MergeSmallLeafNodesToParents();
  71. void RemoveZeroFrameLeafNodes();
  72. void RemoveLeafNode(Node* node);
  73. void FindNearestNeighbors(Node* node, const AZStd::vector<float>& frameFloats, AZStd::vector<size_t>& resultFrameIndices) const;
  74. AZStd::vector<size_t> CalcLocalToSchemaFeatureColumns(const AZStd::vector<Feature*>& features) const;
  75. private:
  76. AZStd::vector<Node*> m_nodes;
  77. size_t m_numDimensions = 0;
  78. size_t m_maxDepth = 20;
  79. size_t m_minFramesPerLeaf = 1000;
  80. };
  81. } // namespace EMotionFX::MotionMatching