asynccommon.nim 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. template createAsyncNativeSocketImpl(domain, sockType, protocol) =
  2. let handle = newNativeSocket(domain, sockType, protocol)
  3. if handle == osInvalidSocket:
  4. return osInvalidSocket.AsyncFD
  5. handle.setBlocking(false)
  6. when defined(macosx) and not defined(nimdoc):
  7. handle.setSockOptInt(SOL_SOCKET, SO_NOSIGPIPE, 1)
  8. result = handle.AsyncFD
  9. register(result)
  10. proc createAsyncNativeSocket*(domain: cint, sockType: cint,
  11. protocol: cint): AsyncFD =
  12. createAsyncNativeSocketImpl(domain, sockType, protocol)
  13. proc createAsyncNativeSocket*(domain: Domain = Domain.AF_INET,
  14. sockType: SockType = SOCK_STREAM,
  15. protocol: Protocol = IPPROTO_TCP): AsyncFD =
  16. createAsyncNativeSocketImpl(domain, sockType, protocol)
  17. proc newAsyncNativeSocket*(domain: cint, sockType: cint,
  18. protocol: cint): AsyncFD {.deprecated: "use createAsyncNativeSocket instead".} =
  19. createAsyncNativeSocketImpl(domain, sockType, protocol)
  20. proc newAsyncNativeSocket*(domain: Domain = Domain.AF_INET,
  21. sockType: SockType = SOCK_STREAM,
  22. protocol: Protocol = IPPROTO_TCP): AsyncFD
  23. {.deprecated: "use createAsyncNativeSocket instead".} =
  24. createAsyncNativeSocketImpl(domain, sockType, protocol)
  25. when defined(windows) or defined(nimdoc):
  26. proc bindToDomain(handle: SocketHandle, domain: Domain) =
  27. # Extracted into a separate proc, because connect() on Windows requires
  28. # the socket to be initially bound.
  29. template doBind(saddr) =
  30. if bindAddr(handle, cast[ptr SockAddr](addr(saddr)),
  31. sizeof(saddr).SockLen) < 0'i32:
  32. raiseOSError(osLastError())
  33. if domain == Domain.AF_INET6:
  34. var saddr: Sockaddr_in6
  35. saddr.sin6_family = uint16(toInt(domain))
  36. doBind(saddr)
  37. else:
  38. var saddr: Sockaddr_in
  39. saddr.sin_family = uint16(toInt(domain))
  40. doBind(saddr)
  41. proc doConnect(socket: AsyncFD, addrInfo: ptr AddrInfo): Future[void] =
  42. let retFuture = newFuture[void]("doConnect")
  43. result = retFuture
  44. var ol = PCustomOverlapped()
  45. GC_ref(ol)
  46. ol.data = CompletionData(fd: socket, cb:
  47. proc (fd: AsyncFD, bytesCount: Dword, errcode: OSErrorCode) =
  48. if not retFuture.finished:
  49. if errcode == OSErrorCode(-1):
  50. retFuture.complete()
  51. else:
  52. retFuture.fail(newException(OSError, osErrorMsg(errcode)))
  53. )
  54. let ret = connectEx(socket.SocketHandle, addrInfo.ai_addr,
  55. cint(addrInfo.ai_addrlen), nil, 0, nil,
  56. cast[POVERLAPPED](ol))
  57. if ret:
  58. # Request to connect completed immediately.
  59. retFuture.complete()
  60. # We don't deallocate ``ol`` here because even though this completed
  61. # immediately poll will still be notified about its completion and it
  62. # will free ``ol``.
  63. else:
  64. let lastError = osLastError()
  65. if lastError.int32 != ERROR_IO_PENDING:
  66. # With ERROR_IO_PENDING ``ol`` will be deallocated in ``poll``,
  67. # and the future will be completed/failed there, too.
  68. GC_unref(ol)
  69. retFuture.fail(newException(OSError, osErrorMsg(lastError)))
  70. else:
  71. proc doConnect(socket: AsyncFD, addrInfo: ptr AddrInfo): Future[void] =
  72. let retFuture = newFuture[void]("doConnect")
  73. result = retFuture
  74. proc cb(fd: AsyncFD): bool =
  75. let ret = SocketHandle(fd).getSockOptInt(
  76. cint(SOL_SOCKET), cint(SO_ERROR))
  77. if ret == 0:
  78. # We have connected.
  79. retFuture.complete()
  80. return true
  81. elif ret == EINTR:
  82. # interrupted, keep waiting
  83. return false
  84. else:
  85. retFuture.fail(newException(OSError, osErrorMsg(OSErrorCode(ret))))
  86. return true
  87. let ret = connect(socket.SocketHandle,
  88. addrInfo.ai_addr,
  89. addrInfo.ai_addrlen.Socklen)
  90. if ret == 0:
  91. # Request to connect completed immediately.
  92. retFuture.complete()
  93. else:
  94. let lastError = osLastError()
  95. if lastError.int32 == EINTR or lastError.int32 == EINPROGRESS:
  96. addWrite(socket, cb)
  97. else:
  98. retFuture.fail(newException(OSError, osErrorMsg(lastError)))
  99. template asyncAddrInfoLoop(addrInfo: ptr AddrInfo, fd: untyped,
  100. protocol: Protocol = IPPROTO_RAW) =
  101. ## Iterates through the AddrInfo linked list asynchronously
  102. ## until the connection can be established.
  103. const shouldCreateFd = not declared(fd)
  104. when shouldCreateFd:
  105. let sockType = protocol.toSockType()
  106. var fdPerDomain: array[low(Domain).ord..high(Domain).ord, AsyncFD]
  107. for i in low(fdPerDomain)..high(fdPerDomain):
  108. fdPerDomain[i] = osInvalidSocket.AsyncFD
  109. template closeUnusedFds(domainToKeep = -1) {.dirty.} =
  110. for i, fd in fdPerDomain:
  111. if fd != osInvalidSocket.AsyncFD and i != domainToKeep:
  112. fd.closeSocket()
  113. var lastException: ref Exception
  114. var curAddrInfo = addrInfo
  115. var domain: Domain
  116. when shouldCreateFd:
  117. var curFd: AsyncFD
  118. else:
  119. var curFd = fd
  120. proc tryNextAddrInfo(fut: Future[void]) {.gcsafe.} =
  121. if fut == nil or fut.failed:
  122. if fut != nil:
  123. lastException = fut.readError()
  124. while curAddrInfo != nil:
  125. let domainOpt = curAddrInfo.ai_family.toKnownDomain()
  126. if domainOpt.isSome:
  127. domain = domainOpt.unsafeGet()
  128. break
  129. curAddrInfo = curAddrInfo.ai_next
  130. if curAddrInfo == nil:
  131. freeAddrInfo(addrInfo)
  132. when shouldCreateFd:
  133. closeUnusedFds()
  134. if lastException != nil:
  135. retFuture.fail(lastException)
  136. else:
  137. retFuture.fail(newException(
  138. IOError, "Couldn't resolve address: " & address))
  139. return
  140. when shouldCreateFd:
  141. curFd = fdPerDomain[ord(domain)]
  142. if curFd == osInvalidSocket.AsyncFD:
  143. try:
  144. curFd = newAsyncNativeSocket(domain, sockType, protocol)
  145. except:
  146. freeAddrInfo(addrInfo)
  147. closeUnusedFds()
  148. raise getCurrentException()
  149. when defined(windows):
  150. curFd.SocketHandle.bindToDomain(domain)
  151. fdPerDomain[ord(domain)] = curFd
  152. doConnect(curFd, curAddrInfo).callback = tryNextAddrInfo
  153. curAddrInfo = curAddrInfo.ai_next
  154. else:
  155. freeAddrInfo(addrInfo)
  156. when shouldCreateFd:
  157. closeUnusedFds(ord(domain))
  158. retFuture.complete(curFd)
  159. else:
  160. retFuture.complete()
  161. tryNextAddrInfo(nil)
  162. proc dial*(address: string, port: Port,
  163. protocol: Protocol = IPPROTO_TCP): Future[AsyncFD] =
  164. ## Establishes connection to the specified ``address``:``port`` pair via the
  165. ## specified protocol. The procedure iterates through possible
  166. ## resolutions of the ``address`` until it succeeds, meaning that it
  167. ## seamlessly works with both IPv4 and IPv6.
  168. ## Returns the async file descriptor, registered in the dispatcher of
  169. ## the current thread, ready to send or receive data.
  170. let retFuture = newFuture[AsyncFD]("dial")
  171. result = retFuture
  172. let sockType = protocol.toSockType()
  173. let aiList = getAddrInfo(address, port, Domain.AF_UNSPEC, sockType, protocol)
  174. asyncAddrInfoLoop(aiList, noFD, protocol)
  175. proc connect*(socket: AsyncFD, address: string, port: Port,
  176. domain = Domain.AF_INET): Future[void] =
  177. let retFuture = newFuture[void]("connect")
  178. result = retFuture
  179. when defined(windows):
  180. verifyPresence(socket)
  181. else:
  182. assert getSockDomain(socket.SocketHandle) == domain
  183. let aiList = getAddrInfo(address, port, domain)
  184. when defined(windows):
  185. socket.SocketHandle.bindToDomain(domain)
  186. asyncAddrInfoLoop(aiList, socket)