autoencoder.cpp 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. // ======================================================================== //
  2. // Copyright 2009-2019 Intel Corporation //
  3. // //
  4. // Licensed under the Apache License, Version 2.0 (the "License"); //
  5. // you may not use this file except in compliance with the License. //
  6. // You may obtain a copy of the License at //
  7. // //
  8. // http://www.apache.org/licenses/LICENSE-2.0 //
  9. // //
  10. // Unless required by applicable law or agreed to in writing, software //
  11. // distributed under the License is distributed on an "AS IS" BASIS, //
  12. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
  13. // See the License for the specific language governing permissions and //
  14. // limitations under the License. //
  15. // ======================================================================== //
  16. #include "autoencoder.h"
  17. namespace oidn {
  18. // --------------------------------------------------------------------------
  19. // AutoencoderFilter
  20. // --------------------------------------------------------------------------
  21. AutoencoderFilter::AutoencoderFilter(const Ref<Device>& device)
  22. : Filter(device)
  23. {
  24. }
  25. void AutoencoderFilter::setImage(const std::string& name, const Image& data)
  26. {
  27. if (name == "color")
  28. color = data;
  29. else if (name == "albedo")
  30. albedo = data;
  31. else if (name == "normal")
  32. normal = data;
  33. else if (name == "output")
  34. output = data;
  35. dirty = true;
  36. }
  37. void AutoencoderFilter::set1i(const std::string& name, int value)
  38. {
  39. if (name == "hdr")
  40. hdr = value;
  41. else if (name == "srgb")
  42. srgb = value;
  43. else if (name == "maxMemoryMB")
  44. maxMemoryMB = value;
  45. dirty = true;
  46. }
  47. int AutoencoderFilter::get1i(const std::string& name)
  48. {
  49. if (name == "hdr")
  50. return hdr;
  51. else if (name == "srgb")
  52. return srgb;
  53. else if (name == "maxMemoryMB")
  54. return maxMemoryMB;
  55. else if (name == "alignment")
  56. return alignment;
  57. else if (name == "overlap")
  58. return overlap;
  59. else
  60. throw Exception(Error::InvalidArgument, "invalid parameter");
  61. }
  62. void AutoencoderFilter::set1f(const std::string& name, float value)
  63. {
  64. if (name == "hdrScale")
  65. hdrScale = value;
  66. dirty = true;
  67. }
  68. float AutoencoderFilter::get1f(const std::string& name)
  69. {
  70. if (name == "hdrScale")
  71. return hdrScale;
  72. else
  73. throw Exception(Error::InvalidArgument, "invalid parameter");
  74. }
  75. void AutoencoderFilter::commit()
  76. {
  77. if (!dirty)
  78. return;
  79. // -- GODOT start --
  80. //device->executeTask([&]()
  81. //{
  82. // GODOT end --
  83. if (mayiuse(avx512_common))
  84. net = buildNet<16>();
  85. else
  86. net = buildNet<8>();
  87. // GODOT start --
  88. //});
  89. // GODOT end --
  90. dirty = false;
  91. }
  92. void AutoencoderFilter::execute()
  93. {
  94. if (dirty)
  95. throw Exception(Error::InvalidOperation, "changes to the filter are not committed");
  96. if (!net)
  97. return;
  98. // -- GODOT start --
  99. //device->executeTask([&]()
  100. //{
  101. // -- GODOT end --
  102. Progress progress;
  103. progress.func = progressFunc;
  104. progress.userPtr = progressUserPtr;
  105. progress.taskCount = tileCountH * tileCountW;
  106. // Iterate over the tiles
  107. int tileIndex = 0;
  108. for (int i = 0; i < tileCountH; ++i)
  109. {
  110. const int h = i * (tileH - 2*overlap); // input tile position (including overlap)
  111. const int overlapBeginH = i > 0 ? overlap : 0; // overlap on the top
  112. const int overlapEndH = i < tileCountH-1 ? overlap : 0; // overlap on the bottom
  113. const int tileH1 = min(H - h, tileH); // input tile size (including overlap)
  114. const int tileH2 = tileH1 - overlapBeginH - overlapEndH; // output tile size
  115. const int alignOffsetH = tileH - roundUp(tileH1, alignment); // align to the bottom in the tile buffer
  116. for (int j = 0; j < tileCountW; ++j)
  117. {
  118. const int w = j * (tileW - 2*overlap); // input tile position (including overlap)
  119. const int overlapBeginW = j > 0 ? overlap : 0; // overlap on the left
  120. const int overlapEndW = j < tileCountW-1 ? overlap : 0; // overlap on the right
  121. const int tileW1 = min(W - w, tileW); // input tile size (including overlap)
  122. const int tileW2 = tileW1 - overlapBeginW - overlapEndW; // output tile size
  123. const int alignOffsetW = tileW - roundUp(tileW1, alignment); // align to the right in the tile buffer
  124. // Set the input tile
  125. inputReorder->setTile(h, w,
  126. alignOffsetH, alignOffsetW,
  127. tileH1, tileW1);
  128. // Set the output tile
  129. outputReorder->setTile(alignOffsetH + overlapBeginH, alignOffsetW + overlapBeginW,
  130. h + overlapBeginH, w + overlapBeginW,
  131. tileH2, tileW2);
  132. //printf("Tile: %d %d -> %d %d\n", w+overlapBeginW, h+overlapBeginH, w+overlapBeginW+tileW2, h+overlapBeginH+tileH2);
  133. // Denoise the tile
  134. net->execute(progress, tileIndex);
  135. // Next tile
  136. tileIndex++;
  137. }
  138. }
  139. // -- GODOT start --
  140. //});
  141. // -- GODOT end --
  142. }
  143. void AutoencoderFilter::computeTileSize()
  144. {
  145. const int minTileSize = 3*overlap;
  146. const int estimatedBytesPerPixel = mayiuse(avx512_common) ? estimatedBytesPerPixel16 : estimatedBytesPerPixel8;
  147. const int64_t maxTilePixels = (int64_t(maxMemoryMB)*1024*1024 - estimatedBytesBase) / estimatedBytesPerPixel;
  148. tileCountH = 1;
  149. tileCountW = 1;
  150. tileH = roundUp(H, alignment);
  151. tileW = roundUp(W, alignment);
  152. // Divide the image into tiles until the tile size gets below the threshold
  153. while (int64_t(tileH) * tileW > maxTilePixels)
  154. {
  155. if (tileH > minTileSize && tileH > tileW)
  156. {
  157. tileCountH++;
  158. tileH = max(roundUp(ceilDiv(H - 2*overlap, tileCountH), alignment) + 2*overlap, minTileSize);
  159. }
  160. else if (tileW > minTileSize)
  161. {
  162. tileCountW++;
  163. tileW = max(roundUp(ceilDiv(W - 2*overlap, tileCountW), alignment) + 2*overlap, minTileSize);
  164. }
  165. else
  166. break;
  167. }
  168. // Compute the final number of tiles
  169. tileCountH = (H > tileH) ? ceilDiv(H - 2*overlap, tileH - 2*overlap) : 1;
  170. tileCountW = (W > tileW) ? ceilDiv(W - 2*overlap, tileW - 2*overlap) : 1;
  171. if (device->isVerbose(2))
  172. {
  173. std::cout << "Tile size : " << tileW << "x" << tileH << std::endl;
  174. std::cout << "Tile count: " << tileCountW << "x" << tileCountH << std::endl;
  175. }
  176. }
  177. template<int K>
  178. std::shared_ptr<Executable> AutoencoderFilter::buildNet()
  179. {
  180. H = color.height;
  181. W = color.width;
  182. // Configure the network
  183. int inputC;
  184. void* weightPtr;
  185. if (srgb && hdr)
  186. throw Exception(Error::InvalidOperation, "srgb and hdr modes cannot be enabled at the same time");
  187. if (color && !albedo && !normal && weightData.hdr)
  188. {
  189. inputC = 3;
  190. weightPtr = hdr ? weightData.hdr : weightData.ldr;
  191. }
  192. else if (color && albedo && !normal && weightData.hdr_alb)
  193. {
  194. inputC = 6;
  195. weightPtr = hdr ? weightData.hdr_alb : weightData.ldr_alb;
  196. }
  197. else if (color && albedo && normal && weightData.hdr_alb_nrm)
  198. {
  199. inputC = 9;
  200. weightPtr = hdr ? weightData.hdr_alb_nrm : weightData.ldr_alb_nrm;
  201. }
  202. else
  203. {
  204. throw Exception(Error::InvalidOperation, "unsupported combination of input features");
  205. }
  206. if (!output)
  207. throw Exception(Error::InvalidOperation, "output image not specified");
  208. if ((color.format != Format::Float3)
  209. || (albedo && albedo.format != Format::Float3)
  210. || (normal && normal.format != Format::Float3)
  211. || (output.format != Format::Float3))
  212. throw Exception(Error::InvalidOperation, "unsupported image format");
  213. if ((albedo && (albedo.width != W || albedo.height != H))
  214. || (normal && (normal.width != W || normal.height != H))
  215. || (output.width != W || output.height != H))
  216. throw Exception(Error::InvalidOperation, "image size mismatch");
  217. // Compute the tile size
  218. computeTileSize();
  219. // If the image size is zero, there is nothing else to do
  220. if (H <= 0 || W <= 0)
  221. return nullptr;
  222. // Parse the weights
  223. const auto weightMap = parseTensors(weightPtr);
  224. // Create the network
  225. std::shared_ptr<Network<K>> net = std::make_shared<Network<K>>(device, weightMap);
  226. // Compute the tensor sizes
  227. const auto inputDims = memory::dims({1, inputC, tileH, tileW});
  228. const auto inputReorderDims = net->getInputReorderDims(inputDims, alignment); //-> concat0
  229. const auto conv1Dims = net->getConvDims("conv1", inputReorderDims); //-> temp0
  230. const auto conv1bDims = net->getConvDims("conv1b", conv1Dims); //-> temp1
  231. const auto pool1Dims = net->getPoolDims(conv1bDims); //-> concat1
  232. const auto conv2Dims = net->getConvDims("conv2", pool1Dims); //-> temp0
  233. const auto pool2Dims = net->getPoolDims(conv2Dims); //-> concat2
  234. const auto conv3Dims = net->getConvDims("conv3", pool2Dims); //-> temp0
  235. const auto pool3Dims = net->getPoolDims(conv3Dims); //-> concat3
  236. const auto conv4Dims = net->getConvDims("conv4", pool3Dims); //-> temp0
  237. const auto pool4Dims = net->getPoolDims(conv4Dims); //-> concat4
  238. const auto conv5Dims = net->getConvDims("conv5", pool4Dims); //-> temp0
  239. const auto pool5Dims = net->getPoolDims(conv5Dims); //-> temp1
  240. const auto upsample4Dims = net->getUpsampleDims(pool5Dims); //-> concat4
  241. const auto concat4Dims = net->getConcatDims(upsample4Dims, pool4Dims);
  242. const auto conv6Dims = net->getConvDims("conv6", concat4Dims); //-> temp0
  243. const auto conv6bDims = net->getConvDims("conv6b", conv6Dims); //-> temp1
  244. const auto upsample3Dims = net->getUpsampleDims(conv6bDims); //-> concat3
  245. const auto concat3Dims = net->getConcatDims(upsample3Dims, pool3Dims);
  246. const auto conv7Dims = net->getConvDims("conv7", concat3Dims); //-> temp0
  247. const auto conv7bDims = net->getConvDims("conv7b", conv7Dims); //-> temp1
  248. const auto upsample2Dims = net->getUpsampleDims(conv7bDims); //-> concat2
  249. const auto concat2Dims = net->getConcatDims(upsample2Dims, pool2Dims);
  250. const auto conv8Dims = net->getConvDims("conv8", concat2Dims); //-> temp0
  251. const auto conv8bDims = net->getConvDims("conv8b", conv8Dims); //-> temp1
  252. const auto upsample1Dims = net->getUpsampleDims(conv8bDims); //-> concat1
  253. const auto concat1Dims = net->getConcatDims(upsample1Dims, pool1Dims);
  254. const auto conv9Dims = net->getConvDims("conv9", concat1Dims); //-> temp0
  255. const auto conv9bDims = net->getConvDims("conv9b", conv9Dims); //-> temp1
  256. const auto upsample0Dims = net->getUpsampleDims(conv9bDims); //-> concat0
  257. const auto concat0Dims = net->getConcatDims(upsample0Dims, inputReorderDims);
  258. const auto conv10Dims = net->getConvDims("conv10", concat0Dims); //-> temp0
  259. const auto conv10bDims = net->getConvDims("conv10b", conv10Dims); //-> temp1
  260. const auto conv11Dims = net->getConvDims("conv11", conv10bDims); //-> temp0
  261. const auto outputDims = memory::dims({1, 3, tileH, tileW});
  262. // Allocate two temporary ping-pong buffers to decrease memory usage
  263. const auto temp0Dims = getMaxTensorDims({
  264. conv1Dims,
  265. conv2Dims,
  266. conv3Dims,
  267. conv4Dims,
  268. conv5Dims,
  269. conv6Dims,
  270. conv7Dims,
  271. conv8Dims,
  272. conv9Dims,
  273. conv10Dims,
  274. conv11Dims
  275. });
  276. const auto temp1Dims = getMaxTensorDims({
  277. conv1bDims,
  278. pool5Dims,
  279. conv6bDims,
  280. conv7bDims,
  281. conv8bDims,
  282. conv9bDims,
  283. conv10bDims,
  284. });
  285. auto temp0 = net->allocTensor(temp0Dims);
  286. auto temp1 = net->allocTensor(temp1Dims);
  287. // Allocate enough memory to hold the concat outputs. Then use the first
  288. // half to hold the previous conv output and the second half to hold the
  289. // pool/orig image output. This works because everything is C dimension
  290. // outermost, padded to K floats, and all the concats are on the C dimension.
  291. auto concat0Dst = net->allocTensor(concat0Dims);
  292. auto concat1Dst = net->allocTensor(concat1Dims);
  293. auto concat2Dst = net->allocTensor(concat2Dims);
  294. auto concat3Dst = net->allocTensor(concat3Dims);
  295. auto concat4Dst = net->allocTensor(concat4Dims);
  296. // Transfer function
  297. std::shared_ptr<TransferFunction> transferFunc = makeTransferFunc();
  298. // Autoexposure
  299. if (auto tf = std::dynamic_pointer_cast<HDRTransferFunction>(transferFunc))
  300. {
  301. if (isnan(hdrScale))
  302. net->addAutoexposure(color, tf);
  303. else
  304. tf->setExposure(hdrScale);
  305. }
  306. // Input reorder
  307. auto inputReorderDst = net->castTensor(inputReorderDims, concat0Dst, upsample0Dims);
  308. inputReorder = net->addInputReorder(color, albedo, normal,
  309. transferFunc,
  310. alignment, inputReorderDst);
  311. // conv1
  312. auto conv1 = net->addConv("conv1", inputReorder->getDst(), temp0);
  313. // conv1b
  314. auto conv1b = net->addConv("conv1b", conv1->getDst(), temp1);
  315. // pool1
  316. // Adjust pointer for pool1 to eliminate concat1
  317. auto pool1Dst = net->castTensor(pool1Dims, concat1Dst, upsample1Dims);
  318. auto pool1 = net->addPool(conv1b->getDst(), pool1Dst);
  319. // conv2
  320. auto conv2 = net->addConv("conv2", pool1->getDst(), temp0);
  321. // pool2
  322. // Adjust pointer for pool2 to eliminate concat2
  323. auto pool2Dst = net->castTensor(pool2Dims, concat2Dst, upsample2Dims);
  324. auto pool2 = net->addPool(conv2->getDst(), pool2Dst);
  325. // conv3
  326. auto conv3 = net->addConv("conv3", pool2->getDst(), temp0);
  327. // pool3
  328. // Adjust pointer for pool3 to eliminate concat3
  329. auto pool3Dst = net->castTensor(pool3Dims, concat3Dst, upsample3Dims);
  330. auto pool3 = net->addPool(conv3->getDst(), pool3Dst);
  331. // conv4
  332. auto conv4 = net->addConv("conv4", pool3->getDst(), temp0);
  333. // pool4
  334. // Adjust pointer for pool4 to eliminate concat4
  335. auto pool4Dst = net->castTensor(pool4Dims, concat4Dst, upsample4Dims);
  336. auto pool4 = net->addPool(conv4->getDst(), pool4Dst);
  337. // conv5
  338. auto conv5 = net->addConv("conv5", pool4->getDst(), temp0);
  339. // pool5
  340. auto pool5 = net->addPool(conv5->getDst(), temp1);
  341. // upsample4
  342. auto upsample4Dst = net->castTensor(upsample4Dims, concat4Dst);
  343. auto upsample4 = net->addUpsample(pool5->getDst(), upsample4Dst);
  344. // conv6
  345. auto conv6 = net->addConv("conv6", concat4Dst, temp0);
  346. // conv6b
  347. auto conv6b = net->addConv("conv6b", conv6->getDst(), temp1);
  348. // upsample3
  349. auto upsample3Dst = net->castTensor(upsample3Dims, concat3Dst);
  350. auto upsample3 = net->addUpsample(conv6b->getDst(), upsample3Dst);
  351. // conv7
  352. auto conv7 = net->addConv("conv7", concat3Dst, temp0);
  353. // conv7b
  354. auto conv7b = net->addConv("conv7b", conv7->getDst(), temp1);
  355. // upsample2
  356. auto upsample2Dst = net->castTensor(upsample2Dims, concat2Dst);
  357. auto upsample2 = net->addUpsample(conv7b->getDst(), upsample2Dst);
  358. // conv8
  359. auto conv8 = net->addConv("conv8", concat2Dst, temp0);
  360. // conv8b
  361. auto conv8b = net->addConv("conv8b", conv8->getDst(), temp1);
  362. // upsample1
  363. auto upsample1Dst = net->castTensor(upsample1Dims, concat1Dst);
  364. auto upsample1 = net->addUpsample(conv8b->getDst(), upsample1Dst);
  365. // conv9
  366. auto conv9 = net->addConv("conv9", concat1Dst, temp0);
  367. // conv9b
  368. auto conv9b = net->addConv("conv9b", conv9->getDst(), temp1);
  369. // upsample0
  370. auto upsample0Dst = net->castTensor(upsample0Dims, concat0Dst);
  371. auto upsample0 = net->addUpsample(conv9b->getDst(), upsample0Dst);
  372. // conv10
  373. auto conv10 = net->addConv("conv10", concat0Dst, temp0);
  374. // conv10b
  375. auto conv10b = net->addConv("conv10b", conv10->getDst(), temp1);
  376. // conv11
  377. auto conv11 = net->addConv("conv11", conv10b->getDst(), temp0, false /* no relu */);
  378. // Output reorder
  379. outputReorder = net->addOutputReorder(conv11->getDst(), transferFunc, output);
  380. net->finalize();
  381. return net;
  382. }
  383. std::shared_ptr<TransferFunction> AutoencoderFilter::makeTransferFunc()
  384. {
  385. if (hdr)
  386. return std::make_shared<PQXTransferFunction>();
  387. else if (srgb)
  388. return std::make_shared<LinearTransferFunction>();
  389. else
  390. return std::make_shared<GammaTransferFunction>();
  391. }
  392. // -- GODOT start --
  393. // Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
  394. #if 0
  395. // -- GODOT end --
  396. // --------------------------------------------------------------------------
  397. // RTFilter
  398. // --------------------------------------------------------------------------
  // Weight blobs for RTFilter, defined elsewhere. There is one per mode and
  // input-feature combination; RTFilter's constructor installs them in weightData.
  namespace weights
  {
    extern unsigned char rt_ldr[];         // LDR: color
    extern unsigned char rt_ldr_alb[];     // LDR: color, albedo
    extern unsigned char rt_ldr_alb_nrm[]; // LDR: color, albedo, normal

    extern unsigned char rt_hdr[];         // HDR: color
    extern unsigned char rt_hdr_alb[];     // HDR: color, albedo
    extern unsigned char rt_hdr_alb_nrm[]; // HDR: color, albedo, normal
  }
  410. RTFilter::RTFilter(const Ref<Device>& device)
  411. : AutoencoderFilter(device)
  412. {
  413. weightData.ldr = weights::rt_ldr;
  414. weightData.ldr_alb = weights::rt_ldr_alb;
  415. weightData.ldr_alb_nrm = weights::rt_ldr_alb_nrm;
  416. weightData.hdr = weights::rt_hdr;
  417. weightData.hdr_alb = weights::rt_hdr_alb;
  418. weightData.hdr_alb_nrm = weights::rt_hdr_alb_nrm;
  419. }
  420. // -- GODOT start --
  421. #endif
  422. // -- GODOT end --
  423. // --------------------------------------------------------------------------
  424. // RTLightmapFilter
  425. // --------------------------------------------------------------------------
  // Weight blob for RTLightmapFilter, defined elsewhere. Only an HDR,
  // color-only variant is declared.
  namespace weights
  {
    extern unsigned char rtlightmap_hdr[]; // HDR: color
  }
  431. RTLightmapFilter::RTLightmapFilter(const Ref<Device>& device)
  432. : AutoencoderFilter(device)
  433. {
  434. weightData.hdr = weights::rtlightmap_hdr;
  435. hdr = true;
  436. }
  437. std::shared_ptr<TransferFunction> RTLightmapFilter::makeTransferFunc()
  438. {
  439. return std::make_shared<LogTransferFunction>();
  440. }
  441. } // namespace oidn