common.h 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. // ======================================================================== //
  2. // Copyright 2009-2019 Intel Corporation //
  3. // //
  4. // Licensed under the Apache License, Version 2.0 (the "License"); //
  5. // you may not use this file except in compliance with the License. //
  6. // You may obtain a copy of the License at //
  7. // //
  8. // http://www.apache.org/licenses/LICENSE-2.0 //
  9. // //
  10. // Unless required by applicable law or agreed to in writing, software //
  11. // distributed under the License is distributed on an "AS IS" BASIS, //
  12. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
  13. // See the License for the specific language governing permissions and //
  14. // limitations under the License. //
  15. // ======================================================================== //
  16. #pragma once
  17. #include "common/platform.h"
  18. #include "mkl-dnn/include/mkldnn.hpp"
  19. #include "mkl-dnn/include/mkldnn_debug.h"
  20. #include "mkl-dnn/src/common/mkldnn_thread.hpp"
  21. #include "mkl-dnn/src/common/type_helpers.hpp"
  22. #include "mkl-dnn/src/cpu/jit_generator.hpp"
  23. #include "common/ref.h"
  24. #include "common/exception.h"
  25. #include "common/thread.h"
  26. // -- GODOT start --
  27. //#include "common/tasking.h"
  28. // -- GODOT end --
  29. #include "math.h"
  30. namespace oidn {
  31. using namespace mkldnn;
  32. using namespace mkldnn::impl::cpu;
  33. using mkldnn::impl::parallel_nd;
  34. using mkldnn::impl::memory_desc_matches_tag;
  35. inline size_t getFormatBytes(Format format)
  36. {
  37. switch (format)
  38. {
  39. case Format::Undefined: return 1;
  40. case Format::Float: return sizeof(float);
  41. case Format::Float2: return sizeof(float)*2;
  42. case Format::Float3: return sizeof(float)*3;
  43. case Format::Float4: return sizeof(float)*4;
  44. }
  45. assert(0);
  46. return 0;
  47. }
  48. inline memory::dims getTensorDims(const std::shared_ptr<memory>& mem)
  49. {
  50. const mkldnn_memory_desc_t& desc = mem->get_desc().data;
  51. return memory::dims(&desc.dims[0], &desc.dims[desc.ndims]);
  52. }
  53. inline memory::data_type getTensorType(const std::shared_ptr<memory>& mem)
  54. {
  55. const mkldnn_memory_desc_t& desc = mem->get_desc().data;
  56. return memory::data_type(desc.data_type);
  57. }
  58. // Returns the number of values in a tensor
  59. inline size_t getTensorSize(const memory::dims& dims)
  60. {
  61. size_t res = 1;
  62. for (int i = 0; i < (int)dims.size(); ++i)
  63. res *= dims[i];
  64. return res;
  65. }
  66. inline memory::dims getMaxTensorDims(const std::vector<memory::dims>& dims)
  67. {
  68. memory::dims result;
  69. size_t maxSize = 0;
  70. for (const auto& d : dims)
  71. {
  72. const size_t size = getTensorSize(d);
  73. if (size > maxSize)
  74. {
  75. result = d;
  76. maxSize = size;
  77. }
  78. }
  79. return result;
  80. }
  81. inline size_t getTensorSize(const std::shared_ptr<memory>& mem)
  82. {
  83. return getTensorSize(getTensorDims(mem));
  84. }
  // Rounds dim up to the nearest multiple of K; a value that is already a
  // multiple of K is returned unchanged.
  //
  // The rounding is done by masking off the low bits, which is only correct
  // when K is a power of two, so any other K is rejected at compile time
  // instead of silently producing a wrong result.
  template<int K>
  inline int getPadded(int dim)
  {
    static_assert(K > 0 && (K & (K-1)) == 0, "K must be a positive power of two");
    return (dim + (K-1)) & ~(K-1);
  }
  90. template<int K>
  91. inline memory::dims getPadded_nchw(const memory::dims& dims)
  92. {
  93. assert(dims.size() == 4);
  94. memory::dims padDims = dims;
  95. padDims[1] = getPadded<K>(dims[1]); // pad C
  96. return padDims;
  97. }
  98. template<int K>
  99. struct BlockedFormat;
  100. template<>
  101. struct BlockedFormat<8>
  102. {
  103. static constexpr memory::format_tag nChwKc = memory::format_tag::nChw8c;
  104. static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw8i8o;
  105. };
  106. template<>
  107. struct BlockedFormat<16>
  108. {
  109. static constexpr memory::format_tag nChwKc = memory::format_tag::nChw16c;
  110. static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw16i16o;
  111. };
  112. } // namespace oidn