network.h 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation                                    //
//                                                                          //
// Licensed under the Apache License, Version 2.0 (the "License");          //
// you may not use this file except in compliance with the License.         //
// You may obtain a copy of the License at                                  //
//                                                                          //
//     http://www.apache.org/licenses/LICENSE-2.0                           //
//                                                                          //
// Unless required by applicable law or agreed to in writing, software      //
// distributed under the License is distributed on an "AS IS" BASIS,        //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and      //
// limitations under the License.                                           //
// ======================================================================== //
#pragma once

#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "common/tensor.h"
#include "image.h"
#include "node.h"
#include "input_reorder.h"
#include "output_reorder.h"
#include "transfer_function.h"
  23. namespace oidn {
  24. // Progress state
  25. struct Progress
  26. {
  27. ProgressMonitorFunction func;
  28. void* userPtr;
  29. int taskCount;
  30. };
  31. class Executable
  32. {
  33. public:
  34. virtual ~Executable() {}
  35. virtual void execute(const Progress& progress, int taskIndex) = 0;
  36. };
  37. template<int K>
  38. class Network : public Executable
  39. {
  40. public:
  41. Network(const Ref<Device>& device, const std::map<std::string, Tensor>& weightMap);
  42. void execute(const Progress& progress, int taskIndex) override;
  43. std::shared_ptr<memory> allocTensor(const memory::dims& dims,
  44. memory::format_tag format = memory::format_tag::any,
  45. void* data = nullptr);
  46. std::shared_ptr<memory> castTensor(const memory::dims& dims,
  47. const std::shared_ptr<memory>& src,
  48. size_t srcOffset = 0,
  49. memory::format_tag format = memory::format_tag::any);
  50. std::shared_ptr<memory> castTensor(const memory::dims& dims,
  51. const std::shared_ptr<memory>& src,
  52. const memory::dims& srcOffset);
  53. void zeroTensor(const std::shared_ptr<memory>& dst);
  54. memory::dims getInputReorderDims(const memory::dims& srcDims, int alignment);
  55. std::shared_ptr<Node> addInputReorder(const Image& color,
  56. const Image& albedo,
  57. const Image& normal,
  58. const std::shared_ptr<TransferFunction>& transferFunc,
  59. int alignment,
  60. const std::shared_ptr<memory>& userDst = nullptr);
  61. std::shared_ptr<Node> addOutputReorder(const std::shared_ptr<memory>& src,
  62. const std::shared_ptr<TransferFunction>& transferFunc,
  63. const Image& output);
  64. memory::dims getConvDims(const std::string& name, const memory::dims& srcDims);
  65. std::shared_ptr<Node> addConv(const std::string& name,
  66. const std::shared_ptr<memory>& src,
  67. const std::shared_ptr<memory>& userDst = nullptr,
  68. bool relu = true);
  69. memory::dims getPoolDims(const memory::dims& srcDims);
  70. std::shared_ptr<Node> addPool(const std::shared_ptr<memory>& src,
  71. const std::shared_ptr<memory>& userDst = nullptr);
  72. memory::dims getUpsampleDims(const memory::dims& srcDims);
  73. std::shared_ptr<Node> addUpsample(const std::shared_ptr<memory>& src,
  74. const std::shared_ptr<memory>& userDst = nullptr);
  75. memory::dims getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims);
  76. std::shared_ptr<Node> addAutoexposure(const Image& color,
  77. const std::shared_ptr<HDRTransferFunction>& transferFunc);
  78. void finalize();
  79. private:
  80. Ref<Device> device;
  81. engine eng;
  82. stream sm;
  83. std::vector<std::shared_ptr<Node>> nodes;
  84. std::map<std::string, Tensor> weightMap;
  85. // Memory allocation statistics
  86. size_t activationAllocBytes = 0; // number of allocated activation bytes
  87. size_t totalAllocBytes = 0; // total number of allocated bytes
  88. };
  89. } // namespace oidn