api.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. // ======================================================================== //
  2. // Copyright 2009-2019 Intel Corporation //
  3. // //
  4. // Licensed under the Apache License, Version 2.0 (the "License"); //
  5. // you may not use this file except in compliance with the License. //
  6. // You may obtain a copy of the License at //
  7. // //
  8. // http://www.apache.org/licenses/LICENSE-2.0 //
  9. // //
  10. // Unless required by applicable law or agreed to in writing, software //
  11. // distributed under the License is distributed on an "AS IS" BASIS, //
  12. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
  13. // See the License for the specific language governing permissions and //
  14. // limitations under the License. //
  15. // ======================================================================== //
  // OIDN_API marks a function as part of the exported C interface: C linkage on
  // every platform, plus dllexport on Windows or default symbol visibility on
  // other compilers.
  #if defined(_WIN32)
  #  define OIDN_API extern "C" __declspec(dllexport)
  #else
  #  define OIDN_API extern "C" __attribute__((visibility("default")))
  #endif
  // Locks the mutex of the device that owns `obj` until the end of the
  // enclosing scope (via a std::lock_guard named oidnDeviceLock).
  // Use *only* inside OIDN_TRY/OIDN_CATCH: the guard is a local of the try
  // block, so it is released during unwinding before OIDN_CATCH reports the
  // error. The call site supplies the terminating semicolon: OIDN_LOCK(obj);
  // `obj` is parenthesized so that expressions such as OIDN_LOCK(&x) expand
  // correctly.
  #define OIDN_LOCK(obj) \
    std::lock_guard<std::mutex> oidnDeviceLock((obj)->getDevice()->getMutex())
  // Try/catch pair that converts any exception escaping an API call into an
  // error reported through Device::setError. OIDN_TRY opens the try block and
  // OIDN_CATCH(obj) closes it. The error is attributed to obj->getDevice(), or
  // to no device (nullptr) when `obj` is null.
  //   - oidn Exception      -> its own code and message
  //   - std::bad_alloc      -> Error::OutOfMemory
  //   - mkldnn::error       -> Error::OutOfMemory if its status is
  //                            mkldnn_out_of_memory, else Error::Unknown
  //   - std::exception      -> Error::Unknown with what()
  //   - anything else       -> Error::Unknown
  // `obj` is parenthesized everywhere so non-trivial arguments expand safely;
  // exceptions are caught by const reference.
  #define OIDN_TRY \
    try {

  #define OIDN_CATCH(obj) \
    } catch (const Exception& e) { \
      Device::setError((obj) ? (obj)->getDevice() : nullptr, e.code(), e.what()); \
    } catch (const std::bad_alloc&) { \
      Device::setError((obj) ? (obj)->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \
    } catch (const mkldnn::error& e) { \
      if (e.status == mkldnn_out_of_memory) \
        Device::setError((obj) ? (obj)->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \
      else \
        Device::setError((obj) ? (obj)->getDevice() : nullptr, Error::Unknown, e.message); \
    } catch (const std::exception& e) { \
      Device::setError((obj) ? (obj)->getDevice() : nullptr, Error::Unknown, e.what()); \
    } catch (...) { \
      Device::setError((obj) ? (obj)->getDevice() : nullptr, Error::Unknown, "unknown exception caught"); \
    }
  #include "device.h"
  #include "filter.h"
  #include <climits>
  #include <mutex>
  46. namespace oidn {
  47. namespace
  48. {
  49. __forceinline void checkHandle(void* handle)
  50. {
  51. if (handle == nullptr)
  52. throw Exception(Error::InvalidArgument, "invalid handle");
  53. }
  54. template<typename T>
  55. __forceinline void retainObject(T* obj)
  56. {
  57. if (obj)
  58. {
  59. obj->incRef();
  60. }
  61. else
  62. {
  63. OIDN_TRY
  64. checkHandle(obj);
  65. OIDN_CATCH(obj)
  66. }
  67. }
  68. template<typename T>
  69. __forceinline void releaseObject(T* obj)
  70. {
  71. if (obj == nullptr || obj->decRefKeep() == 0)
  72. {
  73. OIDN_TRY
  74. checkHandle(obj);
  75. OIDN_LOCK(obj);
  76. obj->destroy();
  77. OIDN_CATCH(obj)
  78. }
  79. }
  80. template<>
  81. __forceinline void releaseObject(Device* obj)
  82. {
  83. if (obj == nullptr || obj->decRefKeep() == 0)
  84. {
  85. OIDN_TRY
  86. checkHandle(obj);
  87. // Do NOT lock the device because it owns the mutex
  88. obj->destroy();
  89. OIDN_CATCH(obj)
  90. }
  91. }
  92. }
  93. OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type)
  94. {
  95. Ref<Device> device = nullptr;
  96. OIDN_TRY
  97. if (type == OIDN_DEVICE_TYPE_CPU || type == OIDN_DEVICE_TYPE_DEFAULT)
  98. device = makeRef<Device>();
  99. else
  100. throw Exception(Error::InvalidArgument, "invalid device type");
  101. OIDN_CATCH(device)
  102. return (OIDNDevice)device.detach();
  103. }
  104. OIDN_API void oidnRetainDevice(OIDNDevice hDevice)
  105. {
  106. Device* device = (Device*)hDevice;
  107. retainObject(device);
  108. }
  109. OIDN_API void oidnReleaseDevice(OIDNDevice hDevice)
  110. {
  111. Device* device = (Device*)hDevice;
  112. releaseObject(device);
  113. }
  114. OIDN_API void oidnSetDevice1b(OIDNDevice hDevice, const char* name, bool value)
  115. {
  116. Device* device = (Device*)hDevice;
  117. OIDN_TRY
  118. checkHandle(hDevice);
  119. OIDN_LOCK(device);
  120. device->set1i(name, value);
  121. OIDN_CATCH(device)
  122. }
  123. OIDN_API void oidnSetDevice1i(OIDNDevice hDevice, const char* name, int value)
  124. {
  125. Device* device = (Device*)hDevice;
  126. OIDN_TRY
  127. checkHandle(hDevice);
  128. OIDN_LOCK(device);
  129. device->set1i(name, value);
  130. OIDN_CATCH(device)
  131. }
  132. OIDN_API bool oidnGetDevice1b(OIDNDevice hDevice, const char* name)
  133. {
  134. Device* device = (Device*)hDevice;
  135. OIDN_TRY
  136. checkHandle(hDevice);
  137. OIDN_LOCK(device);
  138. return device->get1i(name);
  139. OIDN_CATCH(device)
  140. return false;
  141. }
  142. OIDN_API int oidnGetDevice1i(OIDNDevice hDevice, const char* name)
  143. {
  144. Device* device = (Device*)hDevice;
  145. OIDN_TRY
  146. checkHandle(hDevice);
  147. OIDN_LOCK(device);
  148. return device->get1i(name);
  149. OIDN_CATCH(device)
  150. return 0;
  151. }
  152. OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice hDevice, OIDNErrorFunction func, void* userPtr)
  153. {
  154. Device* device = (Device*)hDevice;
  155. OIDN_TRY
  156. checkHandle(hDevice);
  157. OIDN_LOCK(device);
  158. device->setErrorFunction((ErrorFunction)func, userPtr);
  159. OIDN_CATCH(device)
  160. }
  161. OIDN_API OIDNError oidnGetDeviceError(OIDNDevice hDevice, const char** outMessage)
  162. {
  163. Device* device = (Device*)hDevice;
  164. OIDN_TRY
  165. return (OIDNError)Device::getError(device, outMessage);
  166. OIDN_CATCH(device)
  167. if (outMessage) *outMessage = "";
  168. return OIDN_ERROR_UNKNOWN;
  169. }
  170. OIDN_API void oidnCommitDevice(OIDNDevice hDevice)
  171. {
  172. Device* device = (Device*)hDevice;
  173. OIDN_TRY
  174. checkHandle(hDevice);
  175. OIDN_LOCK(device);
  176. device->commit();
  177. OIDN_CATCH(device)
  178. }
  179. OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice hDevice, size_t byteSize)
  180. {
  181. Device* device = (Device*)hDevice;
  182. OIDN_TRY
  183. checkHandle(hDevice);
  184. OIDN_LOCK(device);
  185. Ref<Buffer> buffer = device->newBuffer(byteSize);
  186. return (OIDNBuffer)buffer.detach();
  187. OIDN_CATCH(device)
  188. return nullptr;
  189. }
  190. OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice hDevice, void* ptr, size_t byteSize)
  191. {
  192. Device* device = (Device*)hDevice;
  193. OIDN_TRY
  194. checkHandle(hDevice);
  195. OIDN_LOCK(device);
  196. Ref<Buffer> buffer = device->newBuffer(ptr, byteSize);
  197. return (OIDNBuffer)buffer.detach();
  198. OIDN_CATCH(device)
  199. return nullptr;
  200. }
  201. OIDN_API void oidnRetainBuffer(OIDNBuffer hBuffer)
  202. {
  203. Buffer* buffer = (Buffer*)hBuffer;
  204. retainObject(buffer);
  205. }
  206. OIDN_API void oidnReleaseBuffer(OIDNBuffer hBuffer)
  207. {
  208. Buffer* buffer = (Buffer*)hBuffer;
  209. releaseObject(buffer);
  210. }
  211. OIDN_API void* oidnMapBuffer(OIDNBuffer hBuffer, OIDNAccess access, size_t byteOffset, size_t byteSize)
  212. {
  213. Buffer* buffer = (Buffer*)hBuffer;
  214. OIDN_TRY
  215. checkHandle(hBuffer);
  216. OIDN_LOCK(buffer);
  217. return buffer->map(byteOffset, byteSize);
  218. OIDN_CATCH(buffer)
  219. return nullptr;
  220. }
  221. OIDN_API void oidnUnmapBuffer(OIDNBuffer hBuffer, void* mappedPtr)
  222. {
  223. Buffer* buffer = (Buffer*)hBuffer;
  224. OIDN_TRY
  225. checkHandle(hBuffer);
  226. OIDN_LOCK(buffer);
  227. return buffer->unmap(mappedPtr);
  228. OIDN_CATCH(buffer)
  229. }
  230. OIDN_API OIDNFilter oidnNewFilter(OIDNDevice hDevice, const char* type)
  231. {
  232. Device* device = (Device*)hDevice;
  233. OIDN_TRY
  234. checkHandle(hDevice);
  235. OIDN_LOCK(device);
  236. Ref<Filter> filter = device->newFilter(type);
  237. return (OIDNFilter)filter.detach();
  238. OIDN_CATCH(device)
  239. return nullptr;
  240. }
  241. OIDN_API void oidnRetainFilter(OIDNFilter hFilter)
  242. {
  243. Filter* filter = (Filter*)hFilter;
  244. retainObject(filter);
  245. }
  246. OIDN_API void oidnReleaseFilter(OIDNFilter hFilter)
  247. {
  248. Filter* filter = (Filter*)hFilter;
  249. releaseObject(filter);
  250. }
  251. OIDN_API void oidnSetFilterImage(OIDNFilter hFilter, const char* name,
  252. OIDNBuffer hBuffer, OIDNFormat format,
  253. size_t width, size_t height,
  254. size_t byteOffset,
  255. size_t bytePixelStride, size_t byteRowStride)
  256. {
  257. Filter* filter = (Filter*)hFilter;
  258. OIDN_TRY
  259. checkHandle(hFilter);
  260. checkHandle(hBuffer);
  261. OIDN_LOCK(filter);
  262. Ref<Buffer> buffer = (Buffer*)hBuffer;
  263. if (buffer->getDevice() != filter->getDevice())
  264. throw Exception(Error::InvalidArgument, "the specified objects are bound to different devices");
  265. Image data(buffer, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride);
  266. filter->setImage(name, data);
  267. OIDN_CATCH(filter)
  268. }
  269. OIDN_API void oidnSetSharedFilterImage(OIDNFilter hFilter, const char* name,
  270. void* ptr, OIDNFormat format,
  271. size_t width, size_t height,
  272. size_t byteOffset,
  273. size_t bytePixelStride, size_t byteRowStride)
  274. {
  275. Filter* filter = (Filter*)hFilter;
  276. OIDN_TRY
  277. checkHandle(hFilter);
  278. OIDN_LOCK(filter);
  279. Image data(ptr, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride);
  280. filter->setImage(name, data);
  281. OIDN_CATCH(filter)
  282. }
  283. OIDN_API void oidnSetFilter1b(OIDNFilter hFilter, const char* name, bool value)
  284. {
  285. Filter* filter = (Filter*)hFilter;
  286. OIDN_TRY
  287. checkHandle(hFilter);
  288. OIDN_LOCK(filter);
  289. filter->set1i(name, int(value));
  290. OIDN_CATCH(filter)
  291. }
  292. OIDN_API bool oidnGetFilter1b(OIDNFilter hFilter, const char* name)
  293. {
  294. Filter* filter = (Filter*)hFilter;
  295. OIDN_TRY
  296. checkHandle(hFilter);
  297. OIDN_LOCK(filter);
  298. return filter->get1i(name);
  299. OIDN_CATCH(filter)
  300. return false;
  301. }
  302. OIDN_API void oidnSetFilter1i(OIDNFilter hFilter, const char* name, int value)
  303. {
  304. Filter* filter = (Filter*)hFilter;
  305. OIDN_TRY
  306. checkHandle(hFilter);
  307. OIDN_LOCK(filter);
  308. filter->set1i(name, value);
  309. OIDN_CATCH(filter)
  310. }
  311. OIDN_API int oidnGetFilter1i(OIDNFilter hFilter, const char* name)
  312. {
  313. Filter* filter = (Filter*)hFilter;
  314. OIDN_TRY
  315. checkHandle(hFilter);
  316. OIDN_LOCK(filter);
  317. return filter->get1i(name);
  318. OIDN_CATCH(filter)
  319. return 0;
  320. }
  321. OIDN_API void oidnSetFilter1f(OIDNFilter hFilter, const char* name, float value)
  322. {
  323. Filter* filter = (Filter*)hFilter;
  324. OIDN_TRY
  325. checkHandle(hFilter);
  326. OIDN_LOCK(filter);
  327. filter->set1f(name, value);
  328. OIDN_CATCH(filter)
  329. }
  330. OIDN_API float oidnGetFilter1f(OIDNFilter hFilter, const char* name)
  331. {
  332. Filter* filter = (Filter*)hFilter;
  333. OIDN_TRY
  334. checkHandle(hFilter);
  335. OIDN_LOCK(filter);
  336. return filter->get1f(name);
  337. OIDN_CATCH(filter)
  338. return 0;
  339. }
  340. OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter hFilter, OIDNProgressMonitorFunction func, void* userPtr)
  341. {
  342. Filter* filter = (Filter*)hFilter;
  343. OIDN_TRY
  344. checkHandle(hFilter);
  345. OIDN_LOCK(filter);
  346. filter->setProgressMonitorFunction(func, userPtr);
  347. OIDN_CATCH(filter)
  348. }
  349. OIDN_API void oidnCommitFilter(OIDNFilter hFilter)
  350. {
  351. Filter* filter = (Filter*)hFilter;
  352. OIDN_TRY
  353. checkHandle(hFilter);
  354. OIDN_LOCK(filter);
  355. filter->commit();
  356. OIDN_CATCH(filter)
  357. }
  358. OIDN_API void oidnExecuteFilter(OIDNFilter hFilter)
  359. {
  360. Filter* filter = (Filter*)hFilter;
  361. OIDN_TRY
  362. checkHandle(hFilter);
  363. OIDN_LOCK(filter);
  364. filter->execute();
  365. OIDN_CATCH(filter)
  366. }
  367. } // namespace oidn