coxpcall.lua 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. -------------------------------------------------------------------------------
  2. -- Coroutine safe xpcall and pcall versions
  3. --
  4. -- Encapsulates the protected calls in a coroutine-based loop, so errors can
  5. -- be dealt with without the usual Lua 5.x pcall/xpcall issues with coroutines
  6. -- yielding inside the call to pcall or xpcall.
  7. --
  8. -- Authors: Roberto Ierusalimschy and Andre Carregal
  9. -- Contributors: Thomas Harning Jr., Ignacio Burgueño, Fabio Mascarenhas
  10. --
  11. -- Copyright 2005 - Kepler Project
  12. --
  13. -- $Id: coxpcall.lua,v 1.13 2008/05/19 19:20:02 mascarenhas Exp $
  14. -------------------------------------------------------------------------------
  -----------------------------------------------------------------------------
  -- Probes whether a pcall-like function lets a coroutine yield through it.
  --
  -- A throwaway coroutine calls func(coroutine.yield, noop), i.e. it asks
  -- `func` to protect a call to coroutine.yield, and is then resumed twice.
  -- When the yield gets through, the first resume stops at it and the second
  -- lets `func` return, so the returned status is true. Otherwise the first
  -- resume already runs the coroutine to its end (normally or with an error),
  -- and the second resume fails on the dead coroutine, returning false.
  -- All results of the second coroutine.resume are returned.
  -----------------------------------------------------------------------------
  local function isCoroutineSafe(func)
    local function probe()
      return func(coroutine.yield, function() end)
    end
    local co = coroutine.create(probe)
    -- First step: should pause at the yield if func is yield-transparent.
    coroutine.resume(co)
    return coroutine.resume(co)
  end
  -- Fast path: when the running Lua already lets a coroutine yield from inside
  -- both pcall and xpcall, nothing needs emulating. The builtins are exported
  -- as the legacy globals and in the module table, and loading stops here.
  local builtinsAreYieldSafe = isCoroutineSafe(pcall) and isCoroutineSafe(xpcall)
  if builtinsAreYieldSafe then
    copcall = pcall
    coxpcall = xpcall
    return { pcall = pcall, xpcall = xpcall, running = coroutine.running }
  end
  -----------------------------------------------------------------------------
  -- Emulated coxpcall/copcall, used only when the builtin pcall/xpcall cannot
  -- be yielded across.
  -----------------------------------------------------------------------------
  -- Forward declarations: handleReturnValue and performResume call each other.
  local handleReturnValue, performResume
  -- Original builtins, kept for the paths that can use them directly.
  local oldpcall = pcall
  local oldxpcall = xpcall
  local running = coroutine.running
  -- table.unpack/table.pack when present, otherwise fallbacks. The fallback
  -- pack stores the argument count in `n` so trailing nils are not lost.
  local unpack = table.unpack or unpack
  local pack = table.pack or function(...)
    return { n = select("#", ...), ... }
  end
  -- Helper coroutine -> coroutine that was running when coxpcall created it.
  -- Weak keys (__mode = "k") so the map does not keep helpers alive.
  local coromap = setmetatable({}, { __mode = "k" })
  -----------------------------------------------------------------------------
  -- handleReturnValue(err, co, status, ...) interprets one coroutine.resume
  -- result for the helper coroutine `co`:
  --   * status true and `co` no longer suspended: `co` finished, so true plus
  --     its return values are returned.
  --   * status true and `co` suspended: `co` yielded. The yielded values are
  --     passed up with coroutine.yield, and whatever this coroutine is resumed
  --     with is fed back into `co` via performResume.
  --   * status false: `co` raised an error. The handler `err` is called with
  --     debug.traceback(co, <first error value>) followed by the error values,
  --     and false plus the handler's results are returned.
  -- performResume(err, co, ...) resumes `co` with `...` and hands the result
  -- to handleReturnValue. Both recurse through tail calls, so relaying many
  -- yields does not grow the stack.
  -----------------------------------------------------------------------------
  function handleReturnValue(err, co, status, ...)
    if status then
      if coroutine.status(co) ~= 'suspended' then
        return true, ...
      end
      return performResume(err, co, coroutine.yield(...))
    end
    local trace = debug.traceback(co, (...))
    return false, err(trace, ...)
  end

  function performResume(err, co, ...)
    return handleReturnValue(err, co, coroutine.resume(co, ...))
  end
  -- Marker handler used by copcall. It returns only its first argument, which
  -- handleReturnValue passes as the traceback string. coxpcall compares its
  -- handler against this exact function to recognise a plain pcall request.
  local function id(trace)
    return trace
  end
  -----------------------------------------------------------------------------
  -- Coroutine-safe xpcall: calls f(...) in protected mode with `err` as the
  -- message handler.
  --
  -- Inside a coroutine, f runs in a new helper coroutine driven by
  -- performResume, so yields from f are relayed to this coroutine's resumer
  -- instead of failing at a pcall boundary.
  -- Outside any coroutine (running() returns nil), the original builtins are
  -- used: oldpcall when `err` is the `id` marker passed by copcall, otherwise
  -- oldxpcall. xpcall is called with only f and the handler here, so any
  -- extra arguments are first bound into a closure.
  -----------------------------------------------------------------------------
  function coxpcall(f, err, ...)
    local caller = running()
    if caller then
      local created, co = oldpcall(coroutine.create, f)
      if not created then
        -- coroutine.create rejected f; a Lua closure around it is accepted.
        co = coroutine.create(function(...) return f(...) end)
      end
      -- Record who started this helper so corunning can see through it.
      coromap[co] = caller
      return performResume(err, co, ...)
    end

    if err == id then
      return oldpcall(f, ...)
    end

    local target = f
    if select("#", ...) > 0 then
      -- pack keeps the true argument count so trailing nils are preserved.
      local args = pack(...)
      target = function() return f(unpack(args, 1, args.n)) end
    end
    return oldxpcall(target, err)
  end
  -----------------------------------------------------------------------------
  -- Coroutine-aware replacement for coroutine.running.
  --
  -- coxpcall runs each protected call in a helper coroutine and records, in
  -- coromap, the coroutine that was running when the helper was created. This
  -- follows that chain from `coro` (or from running() when `coro` is nil) back
  -- to the coroutine that made the outermost protected call.
  --
  -- coro: optional thread to start from. Any other non-nil value fails the
  --       assert below.
  -- Returns the thread at the end of the chain, or nil when the starting point
  -- is nil (Lua 5.1's coroutine.running returns nil in the main thread).
  -----------------------------------------------------------------------------
  local function corunning(coro)
    if coro ~= nil then
      assert(type(coro)=="thread", "Bad argument; expected thread, got: "..type(coro))
    else
      coro = running()
    end
    -- coromap keys are helper coroutines and its values are the truthy
    -- running() results recorded by coxpcall, so this walk always ends on a
    -- thread or nil. The former `coro == "mainthread"` test could never match
    -- (a leftover from an old main-thread sentinel) and has been removed.
    while coromap[coro] do
      coro = coromap[coro]
    end
    return coro
  end
  -----------------------------------------------------------------------------
  -- Coroutine-safe pcall. It is coxpcall with the `id` marker as the handler,
  -- which coxpcall recognises: outside a coroutine it then calls the original
  -- pcall directly.
  -----------------------------------------------------------------------------
  function copcall(f, ...)
    local handler = id
    return coxpcall(f, handler, ...)
  end
  -- Module table: emulated pcall/xpcall, plus a running() that looks through
  -- the helper coroutines coxpcall creates.
  return {
    pcall = copcall,
    xpcall = coxpcall,
    running = corunning,
  }