tensor.cpp 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#include "exception.h"
#include "tensor.h"
#include <cstring>
  18. namespace oidn {
  19. std::map<std::string, Tensor> parseTensors(void* buffer)
  20. {
  21. char* input = (char*)buffer;
  22. // Parse the magic value
  23. const int magic = *(unsigned short*)input;
  24. if (magic != 0x41D7)
  25. throw Exception(Error::InvalidOperation, "invalid tensor archive");
  26. input += sizeof(unsigned short);
  27. // Parse the version
  28. const int majorVersion = *(unsigned char*)input++;
  29. const int minorVersion = *(unsigned char*)input++;
  30. UNUSED(minorVersion);
  31. if (majorVersion > 1)
  32. throw Exception(Error::InvalidOperation, "unsupported tensor archive version");
  33. // Parse the number of tensors
  34. const int numTensors = *(int*)input;
  35. input += sizeof(int);
  36. // Parse the tensors
  37. std::map<std::string, Tensor> tensorMap;
  38. for (int i = 0; i < numTensors; ++i)
  39. {
  40. Tensor tensor;
  41. // Parse the name
  42. const int nameLen = *(unsigned char*)input++;
  43. std::string name(input, nameLen);
  44. input += nameLen;
  45. // Parse the number of dimensions
  46. const int ndims = *(unsigned char*)input++;
  47. // Parse the shape of the tensor
  48. tensor.dims.resize(ndims);
  49. for (int i = 0; i < ndims; ++i)
  50. tensor.dims[i] = ((int*)input)[i];
  51. input += ndims * sizeof(int);
  52. // Parse the format of the tensor
  53. tensor.format = std::string(input, input + ndims);
  54. input += ndims;
  55. // Parse the data type of the tensor
  56. const char type = *(unsigned char*)input++;
  57. if (type != 'f') // only float32 is supported
  58. throw Exception(Error::InvalidOperation, "unsupported tensor data type");
  59. // Skip the data
  60. tensor.data = (float*)input;
  61. input += tensor.size() * sizeof(float);
  62. // Add the tensor to the map
  63. tensorMap.emplace(name, std::move(tensor));
  64. }
  65. return tensorMap;
  66. }
  67. } // namespace oidn