tensor.h 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation                                    //
//                                                                          //
// Licensed under the Apache License, Version 2.0 (the "License");          //
// you may not use this file except in compliance with the License.         //
// You may obtain a copy of the License at                                  //
//                                                                          //
//     http://www.apache.org/licenses/LICENSE-2.0                           //
//                                                                          //
// Unless required by applicable law or agreed to in writing, software      //
// distributed under the License is distributed on an "AS IS" BASIS,        //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and      //
// limitations under the License.                                           //
// ======================================================================== //
  16. #pragma once
#include "platform.h"

#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>
  20. namespace oidn {
  // Shared, reference-counted handle to a std::vector<T>: an alias for
  // std::shared_ptr<std::vector<T>>.
  template<typename T>
  using shared_vector = std::shared_ptr<std::vector<T>>;
  23. // Generic tensor
  24. struct Tensor
  25. {
  26. float* data;
  27. std::vector<int64_t> dims;
  28. std::string format;
  29. shared_vector<char> buffer; // optional, only for reference counting
  30. __forceinline Tensor() : data(nullptr) {}
  31. __forceinline Tensor(const std::vector<int64_t>& dims, const std::string& format)
  32. : dims(dims),
  33. format(format)
  34. {
  35. buffer = std::make_shared<std::vector<char>>(size() * sizeof(float));
  36. data = (float*)buffer->data();
  37. }
  38. __forceinline operator bool() const { return data != nullptr; }
  39. __forceinline int ndims() const { return (int)dims.size(); }
  40. // Returns the number of values
  41. __forceinline size_t size() const
  42. {
  43. size_t size = 1;
  44. for (int i = 0; i < ndims(); ++i)
  45. size *= dims[i];
  46. return size;
  47. }
  48. __forceinline float& operator [](size_t i) { return data[i]; }
  49. __forceinline const float& operator [](size_t i) const { return data[i]; }
  50. };
  51. // Parses tensors from a buffer
  52. std::map<std::string, Tensor> parseTensors(void* buffer);
  53. } // namespace oidn