autoencoder.h 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation                                    //
//                                                                          //
// Licensed under the Apache License, Version 2.0 (the "License");          //
// you may not use this file except in compliance with the License.         //
// You may obtain a copy of the License at                                  //
//                                                                          //
//     http://www.apache.org/licenses/LICENSE-2.0                           //
//                                                                          //
// Unless required by applicable law or agreed to in writing, software      //
// distributed under the License is distributed on an "AS IS" BASIS,        //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and      //
// limitations under the License.                                           //
// ======================================================================== //
  16. #pragma once
  17. #include "filter.h"
  18. #include "network.h"
  19. #include "transfer_function.h"
  20. namespace oidn {
  21. // --------------------------------------------------------------------------
  22. // AutoencoderFilter - Direct-predicting autoencoder
  23. // --------------------------------------------------------------------------
  24. class AutoencoderFilter : public Filter
  25. {
  26. protected:
  27. static constexpr int alignment = 32; // required spatial alignment in pixels (padding may be necessary)
  28. static constexpr int receptiveField = 222; // receptive field in pixels
  29. static constexpr int overlap = roundUp(receptiveField / 2, alignment); // required spatial overlap between tiles in pixels
  30. static constexpr int estimatedBytesBase = 16*1024*1024; // estimated base memory usage
  31. static constexpr int estimatedBytesPerPixel8 = 889; // estimated memory usage per pixel for K=8
  32. static constexpr int estimatedBytesPerPixel16 = 2185; // estimated memory usage per pixel for K=16
  33. Image color;
  34. Image albedo;
  35. Image normal;
  36. Image output;
  37. bool hdr = false;
  38. float hdrScale = std::numeric_limits<float>::quiet_NaN();
  39. bool srgb = false;
  40. int maxMemoryMB = 6000; // approximate maximum memory usage in MBs
  41. int H = 0; // image height
  42. int W = 0; // image width
  43. int tileH = 0; // tile height
  44. int tileW = 0; // tile width
  45. int tileCountH = 1; // number of tiles in H dimension
  46. int tileCountW = 1; // number of tiles in W dimension
  47. std::shared_ptr<Executable> net;
  48. std::shared_ptr<Node> inputReorder;
  49. std::shared_ptr<Node> outputReorder;
  50. struct
  51. {
  52. void* ldr = nullptr;
  53. void* ldr_alb = nullptr;
  54. void* ldr_alb_nrm = nullptr;
  55. void* hdr = nullptr;
  56. void* hdr_alb = nullptr;
  57. void* hdr_alb_nrm = nullptr;
  58. } weightData;
  59. explicit AutoencoderFilter(const Ref<Device>& device);
  60. virtual std::shared_ptr<TransferFunction> makeTransferFunc();
  61. public:
  62. void setImage(const std::string& name, const Image& data) override;
  63. void set1i(const std::string& name, int value) override;
  64. int get1i(const std::string& name) override;
  65. void set1f(const std::string& name, float value) override;
  66. float get1f(const std::string& name) override;
  67. void commit() override;
  68. void execute() override;
  69. private:
  70. void computeTileSize();
  71. template<int K>
  72. std::shared_ptr<Executable> buildNet();
  73. bool isCommitted() const { return bool(net); }
  74. };
  75. // --------------------------------------------------------------------------
  76. // RTFilter - Generic ray tracing denoiser
  77. // --------------------------------------------------------------------------
  78. // -- GODOT start --
  79. // Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
  80. #if 0
  81. // -- GODOT end --
  82. class RTFilter : public AutoencoderFilter
  83. {
  84. public:
  85. explicit RTFilter(const Ref<Device>& device);
  86. };
  87. // -- GODOT start --
  88. #endif
  89. // -- GODOT end --
  90. // --------------------------------------------------------------------------
  91. // RTLightmapFilter - Ray traced lightmap denoiser
  92. // --------------------------------------------------------------------------
  93. class RTLightmapFilter : public AutoencoderFilter
  94. {
  95. public:
  96. explicit RTLightmapFilter(const Ref<Device>& device);
  97. std::shared_ptr<TransferFunction> makeTransferFunc() override;
  98. };
  99. } // namespace oidn