input_reorder.h 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. // ======================================================================== //
  2. // Copyright 2009-2019 Intel Corporation //
  3. // //
  4. // Licensed under the Apache License, Version 2.0 (the "License"); //
  5. // you may not use this file except in compliance with the License. //
  6. // You may obtain a copy of the License at //
  7. // //
  8. // http://www.apache.org/licenses/LICENSE-2.0 //
  9. // //
  10. // Unless required by applicable law or agreed to in writing, software //
  11. // distributed under the License is distributed on an "AS IS" BASIS, //
  12. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
  13. // See the License for the specific language governing permissions and //
  14. // limitations under the License. //
  15. // ======================================================================== //
  #pragma once

  #include <cstddef>

  #include "node.h"
  #include "image.h"
  19. namespace oidn {
  20. // Input reorder node
  21. template<int K, class TransferFunction>
  22. class InputReorderNode : public Node
  23. {
  24. private:
  25. // Source
  26. Image color;
  27. Image albedo;
  28. Image normal;
  29. // Destination
  30. std::shared_ptr<memory> dst;
  31. float* dstPtr;
  32. int C2;
  33. int H2;
  34. int W2;
  35. // Tile
  36. int h1Begin;
  37. int w1Begin;
  38. int h2Begin;
  39. int w2Begin;
  40. int H;
  41. int W;
  42. std::shared_ptr<TransferFunction> transferFunc;
  43. public:
  44. InputReorderNode(const Image& color,
  45. const Image& albedo,
  46. const Image& normal,
  47. const std::shared_ptr<memory>& dst,
  48. const std::shared_ptr<TransferFunction>& transferFunc)
  49. : color(color), albedo(albedo), normal(normal),
  50. dst(dst),
  51. h1Begin(0), w1Begin(0),
  52. H(color.height), W(color.width),
  53. transferFunc(transferFunc)
  54. {
  55. const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
  56. assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
  57. assert(dstDesc.ndims == 4);
  58. assert(dstDesc.data_type == memory::data_type::f32);
  59. assert(dstDesc.dims[0] == 1);
  60. //assert(dstDesc.dims[1] >= getPadded<K>(C1));
  61. dstPtr = (float*)dst->get_data_handle();
  62. C2 = dstDesc.dims[1];
  63. H2 = dstDesc.dims[2];
  64. W2 = dstDesc.dims[3];
  65. }
  66. void setTile(int h1, int w1, int h2, int w2, int H, int W) override
  67. {
  68. h1Begin = h1;
  69. w1Begin = w1;
  70. h2Begin = h2;
  71. w2Begin = w2;
  72. this->H = H;
  73. this->W = W;
  74. }
  75. void execute(stream& sm) override
  76. {
  77. assert(H + h1Begin <= color.height);
  78. assert(W + w1Begin <= color.width);
  79. assert(H + h2Begin <= H2);
  80. assert(W + w2Begin <= W2);
  81. parallel_nd(H2, [&](int h2)
  82. {
  83. const int h = h2 - h2Begin;
  84. if (h >= 0 && h < H)
  85. {
  86. const int h1 = h + h1Begin;
  87. // Zero pad
  88. for (int w2 = 0; w2 < w2Begin; ++w2)
  89. {
  90. int c = 0;
  91. while (c < C2)
  92. store(h2, w2, c, 0.f);
  93. }
  94. // Reorder
  95. for (int w = 0; w < W; ++w)
  96. {
  97. const int w1 = w + w1Begin;
  98. const int w2 = w + w2Begin;
  99. int c = 0;
  100. storeColor(h2, w2, c, (float*)color.get(h1, w1));
  101. if (albedo)
  102. storeAlbedo(h2, w2, c, (float*)albedo.get(h1, w1));
  103. if (normal)
  104. storeNormal(h2, w2, c, (float*)normal.get(h1, w1));
  105. while (c < C2)
  106. store(h2, w2, c, 0.f);
  107. }
  108. // Zero pad
  109. for (int w2 = W + w2Begin; w2 < W2; ++w2)
  110. {
  111. int c = 0;
  112. while (c < C2)
  113. store(h2, w2, c, 0.f);
  114. }
  115. }
  116. else
  117. {
  118. // Zero pad
  119. for (int w2 = 0; w2 < W2; ++w2)
  120. {
  121. int c = 0;
  122. while (c < C2)
  123. store(h2, w2, c, 0.f);
  124. }
  125. }
  126. });
  127. }
  128. std::shared_ptr<memory> getDst() const override { return dst; }
  129. private:
  130. // Stores a single value
  131. __forceinline void store(int h, int w, int& c, float value)
  132. {
  133. // Destination is in nChwKc format
  134. float* dst_c = dstPtr + (H2*W2*K*(c/K)) + h*W2*K + w*K + (c%K);
  135. *dst_c = value;
  136. c++;
  137. }
  138. // Stores a color
  139. __forceinline void storeColor(int h, int w, int& c, const float* values)
  140. {
  141. #pragma unroll
  142. for (int i = 0; i < 3; ++i)
  143. {
  144. // Load the value
  145. float x = values[i];
  146. // Sanitize the value
  147. x = maxSafe(x, 0.f);
  148. // Apply the transfer function
  149. x = transferFunc->forward(x);
  150. // Store the value
  151. store(h, w, c, x);
  152. }
  153. }
  154. // Stores an albedo
  155. __forceinline void storeAlbedo(int h, int w, int& c, const float* values)
  156. {
  157. #pragma unroll
  158. for (int i = 0; i < 3; ++i)
  159. {
  160. // Load the value
  161. float x = values[i];
  162. // Sanitize the value
  163. x = clampSafe(x, 0.f, 1.f);
  164. // Store the value
  165. store(h, w, c, x);
  166. }
  167. }
  168. // Stores a normal
  169. __forceinline void storeNormal(int h, int w, int& c, const float* values)
  170. {
  171. // Load the normal
  172. float x = values[0];
  173. float y = values[1];
  174. float z = values[2];
  175. // Compute the length of the normal
  176. const float lengthSqr = sqr(x) + sqr(y) + sqr(z);
  177. // Normalize the normal and transform it to [0..1]
  178. if (isfinite(lengthSqr))
  179. {
  180. const float invLength = (lengthSqr > minVectorLengthSqr) ? rsqrt(lengthSqr) : 1.f;
  181. const float scale = invLength * 0.5f;
  182. const float offset = 0.5f;
  183. x = x * scale + offset;
  184. y = y * scale + offset;
  185. z = z * scale + offset;
  186. }
  187. else
  188. {
  189. x = 0.f;
  190. y = 0.f;
  191. z = 0.f;
  192. }
  193. // Store the normal
  194. store(h, w, c, x);
  195. store(h, w, c, y);
  196. store(h, w, c, z);
  197. }
  198. };
  199. } // namespace oidn