output_reorder.h 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. // ======================================================================== //
  2. // Copyright 2009-2019 Intel Corporation //
  3. // //
  4. // Licensed under the Apache License, Version 2.0 (the "License"); //
  5. // you may not use this file except in compliance with the License. //
  6. // You may obtain a copy of the License at //
  7. // //
  8. // http://www.apache.org/licenses/LICENSE-2.0 //
  9. // //
  10. // Unless required by applicable law or agreed to in writing, software //
  11. // distributed under the License is distributed on an "AS IS" BASIS, //
  12. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
  13. // See the License for the specific language governing permissions and //
  14. // limitations under the License. //
  15. // ======================================================================== //
  16. #pragma once
  #include <cstddef>

  #include "node.h"
  #include "image.h"
  19. namespace oidn {
  20. // Output reorder node
  21. template<int K, class TransferFunction>
  22. class OutputReorderNode : public Node
  23. {
  24. private:
  25. // Source
  26. std::shared_ptr<memory> src;
  27. const float* srcPtr;
  28. int H1;
  29. int W1;
  30. // Destination
  31. Image output;
  32. // Tile
  33. int h1Begin;
  34. int w1Begin;
  35. int h2Begin;
  36. int w2Begin;
  37. int H;
  38. int W;
  39. std::shared_ptr<TransferFunction> transferFunc;
  40. public:
  41. OutputReorderNode(const std::shared_ptr<memory>& src,
  42. const Image& output,
  43. const std::shared_ptr<TransferFunction>& transferFunc)
  44. : src(src),
  45. output(output),
  46. h1Begin(0), w1Begin(0),
  47. h2Begin(0), w2Begin(0),
  48. H(output.height), W(output.width),
  49. transferFunc(transferFunc)
  50. {
  51. const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
  52. MAYBE_UNUSED(srcDesc);
  53. assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
  54. assert(srcDesc.ndims == 4);
  55. assert(srcDesc.data_type == memory::data_type::f32);
  56. assert(srcDesc.dims[0] == 1);
  57. // We assume output data is <= K OC
  58. assert(srcDesc.dims[1] == K);
  59. srcPtr = (float*)src->get_data_handle();
  60. H1 = srcDesc.dims[2];
  61. W1 = srcDesc.dims[3];
  62. }
  63. void setTile(int h1, int w1, int h2, int w2, int H, int W) override
  64. {
  65. h1Begin = h1;
  66. w1Begin = w1;
  67. h2Begin = h2;
  68. w2Begin = w2;
  69. this->H = H;
  70. this->W = W;
  71. }
  72. void execute(stream& sm) override
  73. {
  74. assert(h1Begin + H <= H1);
  75. assert(w1Begin + W <= W1);
  76. assert(h2Begin + H <= output.height);
  77. assert(w2Begin + W <= output.width);
  78. const int C1 = K;
  79. parallel_nd(H, [&](int h)
  80. {
  81. const int h1 = h + h1Begin;
  82. const int h2 = h + h2Begin;
  83. for (int w = 0; w < W; ++w)
  84. {
  85. const int w1 = w + w1Begin;
  86. const int w2 = w + w2Begin;
  87. float* dstPtr_C = (float*)output.get(h2, w2);
  88. // Source is in nChwKc format. In this case C is 1 so this is really nhwc
  89. const float* srcPtr_C = srcPtr + h1*W1*C1 + w1*C1;
  90. #pragma unroll
  91. for (int i = 0; i < 3; ++i)
  92. {
  93. // Load the value
  94. float x = srcPtr_C[i];
  95. // The CNN output may contain negative values or even NaNs, so it must be sanitized
  96. x = maxSafe(x, 0.f);
  97. // Apply the inverse transfer function
  98. x = transferFunc->inverse(x);
  99. // Sanitize and store the final value
  100. dstPtr_C[i] = max(x, 0.f);
  101. }
  102. }
  103. });
  104. }
  105. };
  106. } // namespace oidn