net.nim 74 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013
  1. #
  2. #
  3. # Nim's Runtime Library
  4. # (c) Copyright 2015 Dominik Picheta
  5. #
  6. # See the file "copying.txt", included in this
  7. # distribution, for details about the copyright.
  8. #
  9. ## This module implements a high-level cross-platform sockets interface.
  10. ## The procedures implemented in this module are primarily for blocking sockets.
  11. ## For asynchronous non-blocking sockets use the ``asyncnet`` module together
  12. ## with the ``asyncdispatch`` module.
  13. ##
  14. ## The first thing you will always need to do in order to start using sockets,
  15. ## is to create a new instance of the ``Socket`` type using the ``newSocket``
  16. ## procedure.
  17. ##
  18. ## SSL
  19. ## ====
  20. ##
  21. ## In order to use the SSL procedures defined in this module, you will need to
  22. ## compile your application with the ``-d:ssl`` flag. See the
  23. ## `newContext<net.html#newContext%2Cstring%2Cstring%2Cstring%2Cstring%2Cstring>`_
  24. ## procedure for additional details.
  25. ##
  26. ## Examples
  27. ## ========
  28. ##
  29. ## Connecting to a server
  30. ## ----------------------
  31. ##
  32. ## After you create a socket with the ``newSocket`` procedure, you can easily
  33. ## connect it to a server running at a known hostname (or IP address) and port.
  34. ## To do so over TCP, use the example below.
  35. ##
## .. code-block:: Nim
##   var socket = newSocket()
##   socket.connect("google.com", Port(80))
  39. ##
  40. ## For SSL, use the following example (and make sure to compile with ``-d:ssl``):
  41. ##
## .. code-block:: Nim
##   var socket = newSocket()
##   var ctx = newContext()
##   wrapSocket(ctx, socket)
##   socket.connect("google.com", Port(443))
  47. ##
  48. ## UDP is a connectionless protocol, so UDP sockets don't have to explicitly
  49. ## call the `connect <net.html#connect%2CSocket%2Cstring>`_ procedure. They can
  50. ## simply start sending data immediately.
  51. ##
## .. code-block:: Nim
##   var socket = newSocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP)
##   socket.sendTo("192.168.0.1", Port(27960), "status\n")
  55. ##
  56. ## Creating a server
  57. ## -----------------
  58. ##
  59. ## After you create a socket with the ``newSocket`` procedure, you can create a
  60. ## TCP server by calling the ``bindAddr`` and ``listen`` procedures.
  61. ##
## .. code-block:: Nim
##   var socket = newSocket()
##   socket.bindAddr(Port(1234))
##   socket.listen()
  66. ##
## You can then begin accepting connections using the ``accept`` procedure,
## or ``acceptAddr`` if you also need the client's address, as shown below.
##
## .. code-block:: Nim
##   var client: Socket
##   var address = ""
##   while true:
##     socket.acceptAddr(client, address)
##     echo("Client connected from: ", address)
  75. import std/private/since
  76. import nativesockets, os, strutils, times, sets, options, std/monotimes
  77. from ssl_certs import scanSSLCertificates
  78. import ssl_config
  79. export nativesockets.Port, nativesockets.`$`, nativesockets.`==`
  80. export Domain, SockType, Protocol
  81. const useWinVersion = defined(Windows) or defined(nimdoc)
  82. const defineSsl = defined(ssl) or defined(nimdoc)
  83. when useWinVersion:
  84. from winlean import WSAESHUTDOWN
  85. when defineSsl:
  86. import openssl
# Note: The enumerations are mapped to Windows' constants.
when defineSsl:
  type
    Certificate* = string ## DER encoded certificate
    SslError* = object of CatchableError
      ## Raised for SSL/TLS failures, e.g. by `raiseSSLError`.
    SslCVerifyMode* = enum ## Certificate verification mode (see `newContext`).
      CVerifyNone,           ## certificates are not verified
      CVerifyPeer,           ## the peer's certificate is verified (except on
                             ## Windows or with -d:nimDisableCertificateValidation)
      CVerifyPeerUseEnvVars  ## like `CVerifyPeer`; `newContext` documents that
                             ## SSL_CERT_FILE and SSL_CERT_DIR are also used
    SslProtVersion* = enum ## Protocol selection for `newContext`.
      protSSLv2,  ## rejected by `newContext` as insecure
      protSSLv3,  ## rejected by `newContext` as insecure
      protTLSv1,  ## uses OpenSSL's ``TLSv1_method``
      protSSLv23  ## uses OpenSSL's ``SSLv23_method``; `newContext`'s default
    SslContext* = ref object ## Wrapper around an OpenSSL ``SSL_CTX``.
      context*: SslCtx ## the underlying OpenSSL context
      referencedData: HashSet[int]
        ## ex-data indexes for which `setExtraData` holds a GC reference
      extraInternal: SslContextExtraInternal ## PSK callbacks of this context
    SslAcceptResult* = enum
      AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess
    SslHandshakeType* = enum
      handshakeAsClient, handshakeAsServer
    # Client PSK callback: receives the server's identity hint ("" when none
    # was sent) and returns the identity and pre-shared key to use.
    SslClientGetPskFunc* = proc(hint: string): tuple[identity: string, psk: string]
    # Server PSK callback: receives a client identity and returns the
    # pre-shared key for it (the server-side caller is not visible here).
    SslServerGetPskFunc* = proc(identity: string): string
    SslContextExtraInternal = ref object of RootRef
      ## Internal per-context storage for the PSK callbacks.
      serverGetPskFunc: SslServerGetPskFunc
      clientGetPskFunc: SslClientGetPskFunc
else:
  type
    # Empty placeholder so the name exists when SSL support is not compiled in.
    SslContext* = ref object # TODO: Workaround #4797.
  112. const
  113. BufferSize*: int = 4000 ## size of a buffered socket's buffer
  114. MaxLineLength* = 1_000_000
type
  SocketImpl* = object ## socket type
    fd: SocketHandle # native OS handle, as passed to `newSocket`
    isBuffered: bool # determines whether this socket is buffered.
    # Read buffer. The inclusive range 0..BufferSize gives BufferSize + 1
    # slots.
    buffer: array[0..BufferSize, char]
    currPos: int # current index in buffer
    bufLen: int # current length of buffer
    when defineSsl:
      isSsl: bool
      sslHandle: SslPtr
      sslContext: SslContext
      # NOTE(review): the name suggests the opposite of this comment —
      # confirm against the handshake code before relying on it.
      sslNoHandshake: bool # True if needs handshake.
      sslHasPeekChar: bool
      sslPeekChar: char
      sslNoShutdown: bool # True if shutdown shouldn't be done.
    lastError: OSErrorCode ## stores the last error on this socket
    domain: Domain # address family recorded by `newSocket`
    sockType: SockType # socket type recorded by `newSocket`
    protocol: Protocol # protocol recorded by `newSocket`
  Socket* = ref SocketImpl
  SOBool* = enum ## Boolean socket options.
    OptAcceptConn, OptBroadcast, OptDebug, OptDontRoute, OptKeepAlive,
    OptOOBInline, OptReuseAddr, OptReusePort, OptNoDelay
  ReadLineResult* = enum ## result for readLineAsync
    ReadFullLine, ReadPartialLine, ReadDisconnected, ReadNone
  TimeoutError* = object of CatchableError
  SocketFlag* {.pure.} = enum
    Peek, ## translated to ``MSG_PEEK`` by `toOSFlags`
    SafeDisconn ## Ensures disconnection exceptions (ECONNRESET, EPIPE etc) are not thrown.
# Style checks are switched off around this section, presumably because of
# the underscore field names `address_v6` / `address_v4`.
when defined(nimHasStyleChecks):
  {.push styleChecks: off.}
type
  IpAddressFamily* {.pure.} = enum ## Describes the type of an IP address
    IPv6, ## IPv6 address
    IPv4 ## IPv4 address
  IpAddress* = object ## stores an arbitrary IP address
    case family*: IpAddressFamily ## the type of the IP address (IPv4 or IPv6)
    of IpAddressFamily.IPv6:
      address_v6*: array[0..15, uint8] ## Contains the IP address in bytes in
                                       ## case of IPv6, most significant byte
                                       ## first (network byte order)
    of IpAddressFamily.IPv4:
      address_v4*: array[0..3, uint8] ## Contains the IP address in bytes in
                                      ## case of IPv4, most significant byte
                                      ## first (network byte order)
when defined(nimHasStyleChecks):
  {.pop.}
# Forward declaration so earlier procs can report socket errors; the
# implementation is defined further down in this module.
proc socketError*(socket: Socket, err: int = -1, async = false,
                  lastError = (-1).OSErrorCode,
                  flags: set[SocketFlag] = {}): void {.gcsafe.}
  163. proc isDisconnectionError*(flags: set[SocketFlag],
  164. lastError: OSErrorCode): bool =
  165. ## Determines whether ``lastError`` is a disconnection error. Only does this
  166. ## if flags contains ``SafeDisconn``.
  167. when useWinVersion:
  168. SocketFlag.SafeDisconn in flags and
  169. (lastError.int32 == WSAECONNRESET or
  170. lastError.int32 == WSAECONNABORTED or
  171. lastError.int32 == WSAENETRESET or
  172. lastError.int32 == WSAEDISCON or
  173. lastError.int32 == WSAESHUTDOWN or
  174. lastError.int32 == ERROR_NETNAME_DELETED)
  175. else:
  176. SocketFlag.SafeDisconn in flags and
  177. (lastError.int32 == ECONNRESET or
  178. lastError.int32 == EPIPE or
  179. lastError.int32 == ENETRESET)
  180. proc toOSFlags*(socketFlags: set[SocketFlag]): cint =
  181. ## Converts the flags into the underlying OS representation.
  182. for f in socketFlags:
  183. case f
  184. of SocketFlag.Peek:
  185. result = result or MSG_PEEK
  186. of SocketFlag.SafeDisconn: continue
  187. proc newSocket*(fd: SocketHandle, domain: Domain = AF_INET,
  188. sockType: SockType = SOCK_STREAM,
  189. protocol: Protocol = IPPROTO_TCP, buffered = true): owned(Socket) =
  190. ## Creates a new socket as specified by the params.
  191. assert fd != osInvalidSocket
  192. result = Socket(
  193. fd: fd,
  194. isBuffered: buffered,
  195. domain: domain,
  196. sockType: sockType,
  197. protocol: protocol)
  198. if buffered:
  199. result.currPos = 0
  200. # Set SO_NOSIGPIPE on OS X.
  201. when defined(macosx) and not defined(nimdoc):
  202. setSockOptInt(fd, SOL_SOCKET, SO_NOSIGPIPE, 1)
  203. proc newSocket*(domain, sockType, protocol: cint, buffered = true,
  204. inheritable = defined(nimInheritHandles)): owned(Socket) =
  205. ## Creates a new socket.
  206. ##
  207. ## The SocketHandle associated with the resulting Socket will not be
  208. ## inheritable by child processes by default. This can be changed via
  209. ## the `inheritable` parameter.
  210. ##
  211. ## If an error occurs OSError will be raised.
  212. let fd = createNativeSocket(domain, sockType, protocol, inheritable)
  213. if fd == osInvalidSocket:
  214. raiseOSError(osLastError())
  215. result = newSocket(fd, domain.Domain, sockType.SockType, protocol.Protocol,
  216. buffered)
  217. proc newSocket*(domain: Domain = AF_INET, sockType: SockType = SOCK_STREAM,
  218. protocol: Protocol = IPPROTO_TCP, buffered = true,
  219. inheritable = defined(nimInheritHandles)): owned(Socket) =
  220. ## Creates a new socket.
  221. ##
  222. ## The SocketHandle associated with the resulting Socket will not be
  223. ## inheritable by child processes by default. This can be changed via
  224. ## the `inheritable` parameter.
  225. ##
  226. ## If an error occurs OSError will be raised.
  227. let fd = createNativeSocket(domain, sockType, protocol, inheritable)
  228. if fd == osInvalidSocket:
  229. raiseOSError(osLastError())
  230. result = newSocket(fd, domain, sockType, protocol, buffered)
  231. proc parseIPv4Address(addressStr: string): IpAddress =
  232. ## Parses IPv4 addresses
  233. ## Raises ValueError on errors
  234. var
  235. byteCount = 0
  236. currentByte: uint16 = 0
  237. separatorValid = false
  238. result = IpAddress(family: IpAddressFamily.IPv4)
  239. for i in 0 .. high(addressStr):
  240. if addressStr[i] in strutils.Digits: # Character is a number
  241. currentByte = currentByte * 10 +
  242. cast[uint16](ord(addressStr[i]) - ord('0'))
  243. if currentByte > 255'u16:
  244. raise newException(ValueError,
  245. "Invalid IP Address. Value is out of range")
  246. separatorValid = true
  247. elif addressStr[i] == '.': # IPv4 address separator
  248. if not separatorValid or byteCount >= 3:
  249. raise newException(ValueError,
  250. "Invalid IP Address. The address consists of too many groups")
  251. result.address_v4[byteCount] = cast[uint8](currentByte)
  252. currentByte = 0
  253. byteCount.inc
  254. separatorValid = false
  255. else:
  256. raise newException(ValueError,
  257. "Invalid IP Address. Address contains an invalid character")
  258. if byteCount != 3 or not separatorValid:
  259. raise newException(ValueError, "Invalid IP Address")
  260. result.address_v4[byteCount] = cast[uint8](currentByte)
  261. proc parseIPv6Address(addressStr: string): IpAddress =
  262. ## Parses IPv6 addresses
  263. ## Raises ValueError on errors
  264. result = IpAddress(family: IpAddressFamily.IPv6)
  265. if addressStr.len < 2:
  266. raise newException(ValueError, "Invalid IP Address")
  267. var
  268. groupCount = 0
  269. currentGroupStart = 0
  270. currentShort: uint32 = 0
  271. separatorValid = true
  272. dualColonGroup = -1
  273. lastWasColon = false
  274. v4StartPos = -1
  275. byteCount = 0
  276. for i, c in addressStr:
  277. if c == ':':
  278. if not separatorValid:
  279. raise newException(ValueError,
  280. "Invalid IP Address. Address contains an invalid separator")
  281. if lastWasColon:
  282. if dualColonGroup != -1:
  283. raise newException(ValueError,
  284. "Invalid IP Address. Address contains more than one \"::\" separator")
  285. dualColonGroup = groupCount
  286. separatorValid = false
  287. elif i != 0 and i != high(addressStr):
  288. if groupCount >= 8:
  289. raise newException(ValueError,
  290. "Invalid IP Address. The address consists of too many groups")
  291. result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8)
  292. result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF)
  293. currentShort = 0
  294. groupCount.inc()
  295. if dualColonGroup != -1: separatorValid = false
  296. elif i == 0: # only valid if address starts with ::
  297. if addressStr[1] != ':':
  298. raise newException(ValueError,
  299. "Invalid IP Address. Address may not start with \":\"")
  300. else: # i == high(addressStr) - only valid if address ends with ::
  301. if addressStr[high(addressStr)-1] != ':':
  302. raise newException(ValueError,
  303. "Invalid IP Address. Address may not end with \":\"")
  304. lastWasColon = true
  305. currentGroupStart = i + 1
  306. elif c == '.': # Switch to parse IPv4 mode
  307. if i < 3 or not separatorValid or groupCount >= 7:
  308. raise newException(ValueError, "Invalid IP Address")
  309. v4StartPos = currentGroupStart
  310. currentShort = 0
  311. separatorValid = false
  312. break
  313. elif c in strutils.HexDigits:
  314. if c in strutils.Digits: # Normal digit
  315. currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('0'))
  316. elif c >= 'a' and c <= 'f': # Lower case hex
  317. currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('a')) + 10
  318. else: # Upper case hex
  319. currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('A')) + 10
  320. if currentShort > 65535'u32:
  321. raise newException(ValueError,
  322. "Invalid IP Address. Value is out of range")
  323. lastWasColon = false
  324. separatorValid = true
  325. else:
  326. raise newException(ValueError,
  327. "Invalid IP Address. Address contains an invalid character")
  328. if v4StartPos == -1: # Don't parse v4. Copy the remaining v6 stuff
  329. if separatorValid: # Copy remaining data
  330. if groupCount >= 8:
  331. raise newException(ValueError,
  332. "Invalid IP Address. The address consists of too many groups")
  333. result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8)
  334. result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF)
  335. groupCount.inc()
  336. else: # Must parse IPv4 address
  337. for i, c in addressStr[v4StartPos..high(addressStr)]:
  338. if c in strutils.Digits: # Character is a number
  339. currentShort = currentShort * 10 + cast[uint32](ord(c) - ord('0'))
  340. if currentShort > 255'u32:
  341. raise newException(ValueError,
  342. "Invalid IP Address. Value is out of range")
  343. separatorValid = true
  344. elif c == '.': # IPv4 address separator
  345. if not separatorValid or byteCount >= 3:
  346. raise newException(ValueError, "Invalid IP Address")
  347. result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort)
  348. currentShort = 0
  349. byteCount.inc()
  350. separatorValid = false
  351. else: # Invalid character
  352. raise newException(ValueError,
  353. "Invalid IP Address. Address contains an invalid character")
  354. if byteCount != 3 or not separatorValid:
  355. raise newException(ValueError, "Invalid IP Address")
  356. result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort)
  357. groupCount += 2
  358. # Shift and fill zeros in case of ::
  359. if groupCount > 8:
  360. raise newException(ValueError,
  361. "Invalid IP Address. The address consists of too many groups")
  362. elif groupCount < 8: # must fill
  363. if dualColonGroup == -1:
  364. raise newException(ValueError,
  365. "Invalid IP Address. The address consists of too few groups")
  366. var toFill = 8 - groupCount # The number of groups to fill
  367. var toShift = groupCount - dualColonGroup # Nr of known groups after ::
  368. for i in 0..2*toShift-1: # shift
  369. result.address_v6[15-i] = result.address_v6[groupCount*2-i-1]
  370. for i in 0..2*toFill-1: # fill with 0s
  371. result.address_v6[dualColonGroup*2+i] = 0
  372. elif dualColonGroup != -1:
  373. raise newException(ValueError,
  374. "Invalid IP Address. The address consists of too many groups")
  375. proc parseIpAddress*(addressStr: string): IpAddress =
  376. ## Parses an IP address
  377. ## Raises ValueError on error
  378. if addressStr.len == 0:
  379. raise newException(ValueError, "IP Address string is empty")
  380. if addressStr.contains(':'):
  381. return parseIPv6Address(addressStr)
  382. else:
  383. return parseIPv4Address(addressStr)
  384. proc isIpAddress*(addressStr: string): bool {.tags: [].} =
  385. ## Checks if a string is an IP address
  386. ## Returns true if it is, false otherwise
  387. try:
  388. discard parseIpAddress(addressStr)
  389. except ValueError:
  390. return false
  391. return true
  392. proc toSockAddr*(address: IpAddress, port: Port, sa: var Sockaddr_storage,
  393. sl: var SockLen) =
  394. ## Converts `IpAddress` and `Port` to `SockAddr` and `SockLen`
  395. let port = htons(uint16(port))
  396. case address.family
  397. of IpAddressFamily.IPv4:
  398. sl = sizeof(Sockaddr_in).SockLen
  399. let s = cast[ptr Sockaddr_in](addr sa)
  400. s.sin_family = type(s.sin_family)(toInt(AF_INET))
  401. s.sin_port = port
  402. copyMem(addr s.sin_addr, unsafeAddr address.address_v4[0],
  403. sizeof(s.sin_addr))
  404. of IpAddressFamily.IPv6:
  405. sl = sizeof(Sockaddr_in6).SockLen
  406. let s = cast[ptr Sockaddr_in6](addr sa)
  407. s.sin6_family = type(s.sin6_family)(toInt(AF_INET6))
  408. s.sin6_port = port
  409. copyMem(addr s.sin6_addr, unsafeAddr address.address_v6[0],
  410. sizeof(s.sin6_addr))
  411. proc fromSockAddrAux(sa: ptr Sockaddr_storage, sl: SockLen,
  412. address: var IpAddress, port: var Port) =
  413. if sa.ss_family.cint == toInt(AF_INET) and sl == sizeof(Sockaddr_in).SockLen:
  414. address = IpAddress(family: IpAddressFamily.IPv4)
  415. let s = cast[ptr Sockaddr_in](sa)
  416. copyMem(addr address.address_v4[0], addr s.sin_addr,
  417. sizeof(address.address_v4))
  418. port = ntohs(s.sin_port).Port
  419. elif sa.ss_family.cint == toInt(AF_INET6) and
  420. sl == sizeof(Sockaddr_in6).SockLen:
  421. address = IpAddress(family: IpAddressFamily.IPv6)
  422. let s = cast[ptr Sockaddr_in6](sa)
  423. copyMem(addr address.address_v6[0], addr s.sin6_addr,
  424. sizeof(address.address_v6))
  425. port = ntohs(s.sin6_port).Port
  426. else:
  427. raise newException(ValueError, "Neither IPv4 nor IPv6")
  428. proc fromSockAddr*(sa: Sockaddr_storage | SockAddr | Sockaddr_in | Sockaddr_in6,
  429. sl: SockLen, address: var IpAddress, port: var Port) {.inline.} =
  430. ## Converts `SockAddr` and `SockLen` to `IpAddress` and `Port`. Raises
  431. ## `ObjectConversionDefect` in case of invalid `sa` and `sl` arguments.
  432. fromSockAddrAux(cast[ptr Sockaddr_storage](unsafeAddr sa), sl, address, port)
  433. when defineSsl:
  # OpenSSL library initialisation. These statements run once, when this
  # module is initialised, and only in builds with SSL support.
  # `SslLibraryInit` must report success (1); `doAssert` keeps that check
  # even in release builds.
  CRYPTO_malloc_init()
  doAssert SslLibraryInit() == 1
  # Register human-readable error strings (presumably what ERR_error_string
  # in raiseSSLError renders).
  SSL_load_error_strings()
  ERR_load_BIO_strings()
  OpenSSL_add_all_algorithms()
  proc raiseSSLError*(s = "") =
    ## Raises an `SslError`.
    ##
    ## A non-empty `s` is used as the message. Otherwise the message comes
    ## from the most recent entry of OpenSSL's error queue (peeked, not
    ## removed), or is ``"No error reported."`` when the queue is empty. Two
    ## known error codes get a hint about upgrading OpenSSL prepended.
    if s.len > 0:
      raise newException(SslError, s)
    let code = ERR_peek_last_error()
    if code == 0:
      raise newException(SslError, "No error reported.")
    var msg = $ERR_error_string(code, nil)
    if code == 336032814 or code == 336032784:
      msg = "Please upgrade your OpenSSL library, it does not support the " &
        "necessary protocols. OpenSSL error is: " & msg
    raise newException(SslError, msg)
  proc getExtraData*(ctx: SslContext, index: int): RootRef =
    ## Returns the object stored with `setExtraData` under `index`.
    ##
    ## Raises `IndexDefect` if nothing was stored under `index`, and
    ## `SslError` if OpenSSL returns nil for the slot.
    if index notin ctx.referencedData:
      raise newException(IndexDefect, "No data with that index.")
    let stored = SSL_CTX_get_ex_data(ctx.context, cint(index))
    if stored == nil:
      raiseSSLError()
    result = cast[RootRef](stored)
  proc setExtraData*(ctx: SslContext, index: int, data: RootRef) =
    ## Stores arbitrary data inside SslContext. The unique `index`
    ## should be retrieved using getSslContextExtraDataIndex.
    ##
    ## A GC reference to `data` is held until it is replaced by another call
    ## with the same `index`, or until `destroyContext`. Raises `SslError` if
    ## OpenSSL cannot store the pointer; the previously stored value (if any)
    ## then stays in place and stays referenced.
    let previous =
      if index in ctx.referencedData: getExtraData(ctx, index)
      else: nil
    # SSL_CTX_set_ex_data returns 1 on success and 0 on failure (never -1).
    if ctx.context.SSL_CTX_set_ex_data(index.cint, cast[pointer](data)) != 1:
      raiseSSLError()
    ctx.referencedData.incl(index)
    # Take the new reference before dropping the old one, in case both are
    # the same object.
    GC_ref(data)
    if previous != nil:
      GC_unref(previous)
  # http://simplestcodings.blogspot.co.uk/2010/08/secure-server-client-using-openssl-in-c.html
  proc loadCertificates(ctx: SslCtx, certFile, keyFile: string) =
    ## Loads the certificate chain from `certFile` and the PEM private key
    ## from `keyFile` into `ctx`. When a key is given, also checks it with
    ## ``SSL_CTX_check_private_key``. Empty paths are skipped.
    ##
    ## Raises `IOError` if a given file does not exist, and `SslError` if
    ## OpenSSL rejects the certificate or the key.
    let
      haveCert = certFile != ""
      haveKey = keyFile != ""
    if haveCert and not fileExists(certFile):
      raise newException(system.IOError,
        "Certificate file could not be found: " & certFile)
    if haveKey and not fileExists(keyFile):
      raise newException(system.IOError, "Key file could not be found: " & keyFile)
    if haveCert and SSL_CTX_use_certificate_chain_file(ctx, certFile) != 1:
      raiseSSLError()
    # TODO: Password? www.rtfm.com/openssl-examples/part1.pdf
    if haveKey:
      if SSL_CTX_use_PrivateKey_file(ctx, keyFile, SSL_FILETYPE_PEM) != 1:
        raiseSSLError()
      if SSL_CTX_check_private_key(ctx) != 1:
        raiseSSLError("Verification of private key file failed.")
  proc newContext*(protVersion = protSSLv23, verifyMode = CVerifyPeer,
                   certFile = "", keyFile = "", cipherList = CiphersIntermediate,
                   caDir = "", caFile = ""): SSLContext =
    ## Creates an SSL context.
    ##
    ## `protVersion` selects the protocol method:
    ##
    ## - ``protSSLv23`` (the default) uses OpenSSL's version-flexible method.
    ## - ``protTLSv1`` uses the TLS 1.0 method.
    ## - ``protSSLv2`` and ``protSSLv3`` are rejected with an `SslError`
    ##   because they are no longer secure.
    ##
    ## There are three options for verify mode:
    ## ``CVerifyNone``: certificates are not verified;
    ## ``CVerifyPeer``: certificates are verified;
    ## ``CVerifyPeerUseEnvVars``: certificates are verified and the optional
    ## environment variables SSL_CERT_FILE and SSL_CERT_DIR are also used to
    ## locate certificates
    ##
    ## The `nimDisableCertificateValidation` define overrides verifyMode and
    ## disables certificate verification globally! Certificate verification
    ## is also disabled when compiling for Windows.
    ##
    ## Unless verification is disabled, CA certificates are loaded from
    ## `caFile` and/or `caDir` when either is set. Otherwise the locations
    ## provided by the `ssl_certs <ssl_certs.html>`_ module are tried in turn
    ## until one loads.
    ##
    ## `certFile` and `keyFile` specify the certificate file path and the key
    ## file path; a server socket will most likely not work without these.
    ##
    ## Certificates can be generated using the following command:
    ## - ``openssl req -x509 -nodes -days 365 -newkey rsa:4096 -keyout mykey.pem -out mycert.pem``
    ## or using ECDSA:
    ## - ``openssl ecparam -out mykey.pem -name secp256k1 -genkey``
    ## - ``openssl req -new -key mykey.pem -x509 -nodes -days 365 -out mycert.pem``
    ##
    ## Raises `SslError` if OpenSSL fails to create or configure the context.
    ## Raises `IOError` if a certificate or key file is missing, or if no CA
    ## certificates could be loaded.
    var newCTX: SslCtx
    case protVersion
    of protSSLv23:
      newCTX = SSL_CTX_new(SSLv23_method()) # SSlv2,3 and TLS1 support.
    of protSSLv2:
      raiseSSLError("SSLv2 is no longer secure and has been deprecated, use protSSLv23")
    of protSSLv3:
      raiseSSLError("SSLv3 is no longer secure and has been deprecated, use protSSLv23")
    of protTLSv1:
      newCTX = SSL_CTX_new(TLSv1_method())

    # SSL_CTX_new returns nil on failure. Check before the context is first
    # used, not after it has already been configured.
    if newCTX == nil:
      raiseSSLError()

    # Free the context if any later step raises, so it does not leak.
    # Ownership passes to the returned SslContext on success.
    var handedOver = false
    defer:
      if not handedOver:
        newCTX.SSL_CTX_free()

    if newCTX.SSL_CTX_set_cipher_list(cipherList) != 1:
      raiseSSLError()
    when not defined(openssl10) and not defined(libressl):
      let sslVersion = getOpenSSLVersion()
      # NOTE(review): `not sslVersion == 0x020000000` appears to parse as
      # `(not sslVersion) == 0x020000000` (bitwise not), which makes this
      # branch effectively dead. `!=` was probably intended, but enabling it
      # would pass the TLS 1.2-style `cipherList` to SSL_CTX_set_ciphersuites.
      # Confirm before changing.
      if sslVersion >= 0x010101000 and not sslVersion == 0x020000000:
        # In OpenSSL >= 1.1.1, TLSv1.3 cipher suites can only be configured via
        # this API.
        if newCTX.SSL_CTX_set_ciphersuites(cipherList) != 1:
          raiseSSLError()
    # Automatically the best ECDH curve for client exchange. Without this, ECDH
    # ciphers will be ignored by the server.
    #
    # From OpenSSL >= 1.1.0, this setting is set by default and can't be
    # overriden.
    if newCTX.SSL_CTX_set_ecdh_auto(1) != 1:
      raiseSSLError()

    when defined(nimDisableCertificateValidation) or defined(windows):
      newCTX.SSL_CTX_set_verify(SSL_VERIFY_NONE, nil)
    else:
      case verifyMode
      of CVerifyPeer, CVerifyPeerUseEnvVars:
        newCTX.SSL_CTX_set_verify(SSL_VERIFY_PEER, nil)
      of CVerifyNone:
        newCTX.SSL_CTX_set_verify(SSL_VERIFY_NONE, nil)

    discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY)
    newCTX.loadCertificates(certFile, keyFile)

    when not defined(nimDisableCertificateValidation) and not defined(windows):
      if verifyMode != CVerifyNone:
        # SSL_CTX_load_verify_locations returns 1 on success and 0 on failure.
        # An unused location must be passed as nil: OpenSSL does not skip an
        # empty string, it tries to load it and fails.
        if caDir != "" or caFile != "":
          # Use the caDir and caFile parameters if set
          let
            caFileC: cstring = if caFile == "": nil else: caFile.cstring
            caDirC: cstring = if caDir == "": nil else: caDir.cstring
          if newCTX.SSL_CTX_load_verify_locations(caFileC, caDirC) != 1:
            raise newException(IOError, "Failed to load SSL/TLS CA certificate(s).")
        else:
          # Try the known certificate locations until one loads.
          # NOTE(review): CVerifyPeerUseEnvVars is documented to also honour
          # SSL_CERT_FILE / SSL_CERT_DIR, but verifyMode is not passed to
          # scanSSLCertificates here. Confirm against the ssl_certs module.
          var found = false
          for fn in scanSSLCertificates():
            if newCTX.SSL_CTX_load_verify_locations(fn.cstring, nil) == 1:
              found = true
              break
          if not found:
            raise newException(IOError, "No SSL/TLS CA certificates found.")

    result = SSLContext(context: newCTX, referencedData: initHashSet[int](),
      extraInternal: new(SslContextExtraInternal))
    handedOver = true
  proc getExtraInternal(ctx: SslContext): SslContextExtraInternal =
    ## Returns the internal per-context state of `ctx` (where the PSK
    ## client/server callbacks are stored).
    result = ctx.extraInternal
  proc destroyContext*(ctx: SslContext) =
    ## Free memory referenced by SslContext.
    ##
    ## Drops one GC reference for every extra-data index recorded in
    ## `ctx.referencedData`, then frees the underlying OpenSSL context.
    for index in ctx.referencedData:
      let data = getExtraData(ctx, index)
      GC_unref(data.RootRef)
    SSL_CTX_free(ctx.context)
  proc `pskIdentityHint=`*(ctx: SslContext, hint: string) =
    ## Sets the PSK identity hint on `ctx` via
    ## `SSL_CTX_use_psk_identity_hint`.
    ##
    ## Only used in PSK ciphersuites. Raises `SslError` if OpenSSL reports a
    ## non-positive result.
    let rc = SSL_CTX_use_psk_identity_hint(ctx.context, hint)
    if rc <= 0:
      raiseSSLError()
  proc clientGetPskFunc*(ctx: SslContext): SslClientGetPskFunc =
    ## Returns the client PSK callback currently registered on `ctx`.
    let extra = ctx.getExtraInternal()
    result = extra.clientGetPskFunc
  proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring;
      max_identity_len: cuint; psk: ptr cuchar;
      max_psk_len: cuint): cuint {.cdecl.} =
    ## OpenSSL client-side PSK callback trampoline.
    ##
    ## Passes the server's identity hint (or "" when none was sent) to the
    ## user's `clientGetPskFunc`. It then copies the returned identity,
    ## NUL-terminated, into `identity` and the key into `psk`.
    ##
    ## Returns the key length. Returns 0 to signal failure when either value
    ## does not fit into the buffers OpenSSL provided.
    # NOTE(review): this wrapper is built from the raw SSL_CTX only, so its
    # `extraInternal` field is presumably left at its default (nil) rather than
    # the instance populated by `clientGetPskFunc=`; confirm the lookup below
    # really reaches the registered callback.
    let ctx = SslContext(context: ssl.SSL_get_SSL_CTX)
    let hintString = if hint == nil: "" else: $hint
    let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString)
    # The bound check must use the returned key, not the destination pointer.
    if pskString.len.cuint > max_psk_len:
      return 0
    # `>=` keeps room for the terminating NUL byte copied below.
    if identityString.len.cuint >= max_identity_len:
      return 0
    # Copy the identity plus its trailing zero byte. Using the key's length
    # here would overrun `identity` or truncate the identity string.
    copyMem(identity, identityString.cstring, identityString.len + 1)
    copyMem(psk, pskString.cstring, pskString.len)
    return pskString.len.cuint
  proc `clientGetPskFunc=`*(ctx: SslContext, fun: SslClientGetPskFunc) =
    ## Registers `fun` as the function that, given the identity hint from the
    ## server, returns the client identity and the PSK.
    ##
    ## Passing `nil` also removes the OpenSSL-level client callback.
    ## Only used in PSK ciphersuites.
    let extra = ctx.getExtraInternal()
    extra.clientGetPskFunc = fun
    SSL_CTX_set_psk_client_callback(ctx.context,
      if fun == nil: nil else: pskClientCallback)
  proc serverGetPskFunc*(ctx: SslContext): SslServerGetPskFunc =
    ## Returns the server PSK callback currently registered on `ctx`.
    let extra = ctx.getExtraInternal()
    result = extra.serverGetPskFunc
  proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar;
      max_psk_len: cint): cuint {.cdecl.} =
    ## OpenSSL server-side PSK callback trampoline.
    ##
    ## Asks the user's `serverGetPskFunc` for the key matching the client
    ## `identity` and copies it into `psk`. Returns the key length, or 0 when
    ## the key does not fit into the `max_psk_len` bytes OpenSSL provided.
    # NOTE(review): `ssl` is declared as `SslCtx` although it is passed to
    # SSL_get_SSL_CTX; also, as in the client trampoline, the wrapper below
    # presumably has no `extraInternal` set — confirm both.
    let ctx = SslContext(context: ssl.SSL_get_SSL_CTX)
    let pskString = (ctx.serverGetPskFunc)($identity)
    # The bound check must use the returned key, not the destination pointer.
    if pskString.len.cint > max_psk_len:
      return 0
    copyMem(psk, pskString.cstring, pskString.len)
    return pskString.len.cuint
  proc `serverGetPskFunc=`*(ctx: SslContext, fun: SslServerGetPskFunc) =
    ## Registers `fun` as the function that maps a client identity to its PSK.
    ##
    ## Passing `nil` also removes the OpenSSL-level server callback.
    ## Only used in PSK ciphersuites.
    let extra = ctx.getExtraInternal()
    extra.serverGetPskFunc = fun
    SSL_CTX_set_psk_server_callback(ctx.context,
      if fun == nil: nil else: pskServerCallback)
  proc getPskIdentity*(socket: Socket): string =
    ## Returns the PSK identity the client presented on `socket`.
    ##
    ## `socket` must be an SSL socket (checked with `assert`).
    assert socket.isSsl
    let identity = SSL_get_psk_identity(socket.sslHandle)
    result = $identity
  proc wrapSocket*(ctx: SslContext, socket: Socket) =
    ## Turns the unconnected `socket` into an SSL socket bound to `ctx`.
    ##
    ## Creates a fresh SSL handle from `ctx`, resets the per-socket SSL state
    ## flags and attaches the handle to the socket's descriptor. The SSL
    ## session itself starts when the socket is connected. Raises `SslError`
    ## if the handle cannot be created or attached.
    ##
    ## FIXME:
    ## **Disclaimer**: This code is not well tested, may be very unsafe and
    ## prone to security vulnerabilities.
    assert(not socket.isSsl)
    socket.isSsl = true
    socket.sslContext = ctx
    socket.sslNoHandshake = false
    socket.sslHasPeekChar = false
    socket.sslNoShutdown = false
    let handle = SSL_new(ctx.context)
    socket.sslHandle = handle
    if handle.isNil:
      raiseSSLError()
    if SSL_set_fd(handle, socket.fd) != 1:
      raiseSSLError()
  proc checkCertName(socket: Socket, hostname: string) =
    ## Checks with `X509_check_host` that the peer certificate's Subject
    ## Alternative Name (SAN) or Subject CommonName (CN) matches `hostname`.
    ## Wildcards match only in the left-most label.
    ## When name starts with a dot it will be matched by a certificate valid
    ## for any subdomain.
    ##
    ## Raises `SslError` when there is no peer certificate or the name does not
    ## match. Compiled to a no-op with `nimDisableCertificateValidation` or on
    ## Windows.
    when not defined(nimDisableCertificateValidation) and not defined(windows):
      const
        X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT = 0x1.cuint
        peernameSize = 1024
      assert socket.isSSL
      let certificate = SSL_get_peer_certificate(socket.sslHandle)
      if certificate.isNil:
        raiseSSLError("No SSL certificate found.")
      var peername: string = newString(peernameSize)
      let match = X509_check_host(certificate, hostname.cstring,
        hostname.len.cint, X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT, peername)
      if match != 1:
        raiseSSLError("SSL Certificate check failed.")
  proc wrapConnectedSocket*(ctx: SSLContext, socket: Socket,
                            handshake: SslHandshakeType,
                            hostname: string = "") =
    ## Wraps an already connected `socket` in `ctx` and immediately performs
    ## the SSL handshake in the role given by `handshake`.
    ##
    ## As a client, a non-empty, non-IP `hostname` is sent via SNI. Unless
    ## validation is disabled (or on Windows), the same hostname is also checked
    ## against the server certificate.
    ##
    ## FIXME:
    ## **Disclaimer**: This code is not well tested, may be very unsafe and
    ## prone to security vulnerabilities.
    wrapSocket(ctx, socket)
    case handshake
    of handshakeAsClient:
      let namedHost = hostname.len > 0 and not isIpAddress(hostname)
      if namedHost:
        # Discard result in case OpenSSL version doesn't support SNI, or we're
        # not using TLSv1+
        discard SSL_set_tlsext_host_name(socket.sslHandle, hostname)
      ErrClearError()
      let connectRet = SSL_connect(socket.sslHandle)
      socketError(socket, connectRet)
      when not defined(nimDisableCertificateValidation) and not defined(windows):
        if namedHost:
          socket.checkCertName(hostname)
    of handshakeAsServer:
      ErrClearError()
      let acceptRet = SSL_accept(socket.sslHandle)
      socketError(socket, acceptRet)
  proc getPeerCertificates*(sslHandle: SslPtr): seq[Certificate] {.since: (1, 1).} =
    ## Returns the certificate chain received by the peer we are connected to
    ## through the OpenSSL connection represented by ``sslHandle``.
    ## The handshake must have been completed and the certificate chain must
    ## have been verified successfully or else an empty sequence is returned.
    ## The chain is ordered from leaf certificate to root certificate.
    result = newSeq[Certificate]()
    if SSL_get_verify_result(sslHandle) != X509_V_OK:
      return
    let chain = SSL_get0_verified_chain(sslHandle)
    if chain.isNil:
      return
    # A half-open range is empty when the chain has no entries.
    for idx in 0 ..< OPENSSL_sk_num(chain):
      let cert = cast[PX509](OPENSSL_sk_value(chain, idx))
      result.add(i2d_X509(cert))
  proc getPeerCertificates*(socket: Socket): seq[Certificate] {.since: (1, 1).} =
    ## Returns the certificate chain received by the peer we are connected to
    ## through the given socket, or an empty sequence for non-SSL sockets.
    ## The handshake must have been completed and the certificate chain must
    ## have been verified successfully or else an empty sequence is returned.
    ## The chain is ordered from leaf certificate to root certificate.
    result =
      if socket.isSsl: getPeerCertificates(socket.sslHandle)
      else: newSeq[Certificate]()
  proc `sessionIdContext=`*(ctx: SslContext, sidCtx: string) =
    ## Sets the session id context in which a session can be reused.
    ## Used for permitting clients to reuse a session id instead of
    ## doing a new handshake.
    ##
    ## TLS clients might attempt to resume a session using the session id context,
    ## thus it must be set if verifyMode is set to CVerifyPeer or CVerifyPeerUseEnvVars,
    ## otherwise the connection will fail and SslError will be raised if resumption occurs.
    ##
    ## - Only useful if set server-side.
    ## - Should be unique per-application to prevent clients from malfunctioning.
    ## - sidCtx must be at most 32 characters in length; longer values raise
    ##   `SslError`.
    const maxSidCtxLen = 32
    if sidCtx.len > maxSidCtxLen:
      # A length of exactly 32 is accepted, so the message says "at most"
      # (it previously said "shorter than", contradicting the check).
      raiseSSLError("sessionIdContext must be at most 32 characters")
    SSL_CTX_set_session_id_context(ctx.context, sidCtx, sidCtx.len)
  proc getSocketError*(socket: Socket): OSErrorCode =
    ## Returns ``osLastError()`` if it holds an error; otherwise falls back to
    ## the error previously saved in ``socket.lastError``.
    ##
    ## Raises `OSError` when neither source holds a non-zero code.
    let fromOs = osLastError()
    result = if fromOs != 0.OSErrorCode: fromOs else: socket.lastError
    if result == 0.OSErrorCode:
      raiseOSError(result, "No valid socket error code available")
  proc socketError*(socket: Socket, err: int = -1, async = false,
                    lastError = (-1).OSErrorCode,
                    flags: set[SocketFlag] = {}) =
    ## Raises an OSError based on the error code returned by ``SSL_get_error``
    ## (for SSL sockets) and ``osLastError`` otherwise.
    ##
    ## If ``async`` is ``true`` no error will be thrown in the case when the
    ## error was caused by no data being available to be read.
    ##
    ## For SSL sockets an error is examined only when ``err <= 0``. For other
    ## sockets it is examined only when ``err == -1``; any other value raises
    ## nothing.
    ##
    ## ``lastError`` overrides the OS error code used for non-SSL sockets; when
    ## left at ``-1`` the code comes from ``getSocketError``.
    ##
    ## If ``flags`` contains ``SafeDisconn``, no exception will be raised
    ## when the error was caused by a peer disconnection.
    when defineSsl:
      if socket.isSsl:
        if err <= 0:
          let ret = SSL_get_error(socket.sslHandle, err.cint)
          case ret
          of SSL_ERROR_ZERO_RETURN:
            raiseSSLError("TLS/SSL connection failed to initiate, socket closed prematurely.")
          of SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT,
             SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_READ:
            if async:
              return
            else: raiseSSLError("Not enough data on socket.")
          of SSL_ERROR_WANT_X509_LOOKUP:
            raiseSSLError("Function for x509 lookup has been called.")
          of SSL_ERROR_SYSCALL:
            # SSL shutdown must not be done if a fatal error occurred.
            socket.sslNoShutdown = true
            let osErr = osLastError()
            if not flags.isDisconnectionError(osErr):
              var errStr = "IO error has occurred "
              let sslErr = ERR_peek_last_error()
              if sslErr == 0 and err == 0:
                errStr.add "because an EOF was observed that violates the protocol"
              elif sslErr == 0 and err == -1:
                errStr.add "in the BIO layer"
              else:
                # Previously an inner `let errStr` shadowed the prefix, so the
                # message was the OpenSSL string repeated twice.
                raiseSSLError(errStr & ": " & $ERR_error_string(sslErr, nil))
              raiseOSError(osErr, errStr)
          of SSL_ERROR_SSL:
            # SSL shutdown must not be done if a fatal error occurred.
            socket.sslNoShutdown = true
            raiseSSLError()
          else: raiseSSLError("Unknown Error")

    if err == -1 and not (when defineSsl: socket.isSsl else: false):
      let lastE = if lastError.int == -1: getSocketError(socket) else: lastError
      if not flags.isDisconnectionError(lastE):
        if async:
          when useWinVersion:
            if lastE.int32 == WSAEWOULDBLOCK:
              return
            else: raiseOSError(lastE)
          else:
            if lastE.int32 == EAGAIN or lastE.int32 == EWOULDBLOCK:
              return
            else: raiseOSError(lastE)
        else: raiseOSError(lastE)
  proc listen*(socket: Socket, backlog = SOMAXCONN) {.tags: [ReadIOEffect].} =
    ## Puts ``socket`` into listening mode so it accepts connections.
    ## ``backlog`` caps the length of the queue of pending connections.
    ##
    ## Raises an OSError error upon failure.
    let rc = nativesockets.listen(socket.fd, backlog)
    if rc < 0'i32:
      raiseOSError(osLastError())
  proc bindAddr*(socket: Socket, port = Port(0), address = "") {.
    tags: [ReadIOEffect].} =
    ## Binds ``address``:``port`` to the socket.
    ##
    ## An empty ``address`` binds the wildcard address of the socket's family
    ## ("::" for IPv6, "0.0.0.0" for IPv4); other families then raise
    ## ValueError. Raises OSError if the bind itself fails.
    let realaddr =
      if address.len > 0: address
      else:
        case socket.domain
        of AF_INET6: "::"
        of AF_INET: "0.0.0.0"
        else:
          raise newException(ValueError,
            "Unknown socket address family and no address specified to bindAddr")
    let aiList = getAddrInfo(realaddr, port, socket.domain)
    let rc = bindAddr(socket.fd, aiList.ai_addr, aiList.ai_addrlen.SockLen)
    # The address list is released on both paths before any error is raised.
    freeaddrinfo(aiList)
    if rc < 0'i32:
      raiseOSError(osLastError())
  proc acceptAddr*(server: Socket, client: var owned(Socket), address: var string,
                   flags = {SocketFlag.SafeDisconn},
                   inheritable = defined(nimInheritHandles)) {.
                   tags: [ReadIOEffect], gcsafe, locks: 0.} =
    ## Blocks until a connection is being made from a client. When a connection
    ## is made sets ``client`` to the client socket and ``address`` to the address
    ## of the connecting client.
    ## This function will raise OSError if an error occurs.
    ##
    ## The resulting client will inherit any properties of the server socket. For
    ## example: whether the socket is buffered or not.
    ##
    ## The SocketHandle associated with the resulting client will not be
    ## inheritable by child processes by default. This can be changed via
    ## the `inheritable` parameter.
    ##
    ## The ``accept`` call may result in an error if the connecting socket
    ## disconnects during the duration of the ``accept``. If the ``SafeDisconn``
    ## flag is specified then this error will not be raised and instead
    ## accept will be called again.
    if client.isNil:
      new(client)
    var ret = accept(server.fd, inheritable)
    # Retry in a loop. The previous recursive retry fell through to
    # `raiseOSError` even after the retried accept succeeded.
    while ret[0] == osInvalidSocket:
      let err = osLastError()
      if not flags.isDisconnectionError(err):
        raiseOSError(err)
      ret = accept(server.fd, inheritable)
    let sock = ret[0]
    address = ret[1]
    client.fd = sock
    client.domain = getSockDomain(sock)
    client.isBuffered = server.isBuffered

    # Handle SSL.
    when defineSsl:
      if server.isSsl:
        # We must wrap the client sock in a ssl context.
        server.sslContext.wrapSocket(client)
        ErrClearError()
        let sslRet = SSL_accept(client.sslHandle)
        socketError(client, sslRet, false)
  when false: #defineSsl:
    # NOTE(review): this block is disabled by `when false` and is never
    # compiled. It refers to `SSL_acceptResult` and `acceptAddrPlain`, which are
    # not defined in the visible part of this file; confirm they still exist
    # before re-enabling it.
    proc acceptAddrSSL*(server: Socket, client: var Socket,
                        address: var string): SSL_acceptResult {.
                        tags: [ReadIOEffect].} =
      ## This procedure should only be used for non-blocking **SSL** sockets.
      ## It will immediately return with one of the following values:
      ##
      ## ``AcceptSuccess`` will be returned when a client has been successfully
      ## accepted and the handshake has been successfully performed between
      ## ``server`` and the newly connected client.
      ##
      ## ``AcceptNoHandshake`` will be returned when a client has been accepted
      ## but no handshake could be performed. This can happen when the client
      ## connects but does not yet initiate a handshake. In this case
      ## ``acceptAddrSSL`` should be called again with the same parameters.
      ##
      ## ``AcceptNoClient`` will be returned when no client is currently attempting
      ## to connect.
      # Performs (or resumes) the non-blocking server-side SSL handshake on
      # `client`, returning AcceptNoHandshake from the enclosing proc when
      # OpenSSL needs more I/O.
      template doHandshake(): untyped =
        when defineSsl:
          if server.isSsl:
            client.setBlocking(false)
            # We must wrap the client sock in a ssl context.
            if not client.isSsl or client.sslHandle == nil:
              server.sslContext.wrapSocket(client)
            ErrClearError()
            let ret = SSL_accept(client.sslHandle)
            # NOTE(review): `ret` is never reassigned inside this loop. When
            # SSL_get_error reports SSL_ERROR_WANT_ACCEPT no branch returns or
            # raises, so the loop appears to spin forever. Confirm before
            # re-enabling.
            while ret <= 0:
              let err = SSL_get_error(client.sslHandle, ret)
              if err != SSL_ERROR_WANT_ACCEPT:
                case err
                of SSL_ERROR_ZERO_RETURN:
                  raiseSSLError("TLS/SSL connection failed to initiate, socket closed prematurely.")
                of SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE,
                   SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT:
                  client.sslNoHandshake = true
                  return AcceptNoHandshake
                of SSL_ERROR_WANT_X509_LOOKUP:
                  raiseSSLError("Function for x509 lookup has been called.")
                of SSL_ERROR_SYSCALL, SSL_ERROR_SSL:
                  raiseSSLError()
                else:
                  raiseSSLError("Unknown error")
            client.sslNoHandshake = false

      # A previously accepted client whose handshake did not finish only needs
      # the handshake retried; otherwise accept a new client first.
      if client.isSsl and client.sslNoHandshake:
        doHandshake()
        return AcceptSuccess
      else:
        acceptAddrPlain(AcceptNoClient, AcceptSuccess):
          doHandshake()
  proc accept*(server: Socket, client: var owned(Socket),
               flags = {SocketFlag.SafeDisconn},
               inheritable = defined(nimInheritHandles))
              {.tags: [ReadIOEffect].} =
    ## Equivalent to ``acceptAddr`` but doesn't return the address, only the
    ## socket.
    ##
    ## The SocketHandle associated with the resulting client will not be
    ## inheritable by child processes by default. This can be changed via
    ## the `inheritable` parameter.
    ##
    ## The ``accept`` call may result in an error if the connecting socket
    ## disconnects during the duration of the ``accept``. If the ``SafeDisconn``
    ## flag is specified then this error will not be raised and instead
    ## accept will be called again.
    var addrDummy = ""
    # Forward `inheritable`; previously it was silently dropped, so acceptAddr
    # always used its own default regardless of what the caller asked for.
    acceptAddr(server, client, addrDummy, flags, inheritable)
  957. when defined(posix) and not defined(lwip):
  958. from posix import Sigset, sigwait, sigismember, sigemptyset, sigaddset,
  959. sigprocmask, pthread_sigmask, SIGPIPE, SIG_BLOCK, SIG_UNBLOCK
  template blockSigpipe(body: untyped): untyped =
    ## Temporary block SIGPIPE within the provided code block. If SIGPIPE is
    ## raised for the duration of the code block, it will be queued and will be
    ## raised once the block ends.
    ##
    ## Within the block a `selectSigpipe()` template is provided which can be
    ## used to remove SIGPIPE from the queue. Note that if SIGPIPE is **not**
    ## raised at the time of call, it will block until SIGPIPE is raised.
    ##
    ## If SIGPIPE has already been blocked at the time of execution, the
    ## signal mask is left as-is and `selectSigpipe()` will become a no-op.
    ##
    ## For convenience, this template is also available for non-POSIX system,
    ## where `body` will be executed as-is.
    ##
    ## Raises OSError if any of the signal-set or mask operations fail.
    when not defined(posix) or defined(lwip):
      body
    else:
      template sigmask(how: cint, set, oset: var Sigset): untyped {.gensym.} =
        ## Alias for pthread_sigmask or sigprocmask depending on the status
        ## of --threads
        # NOTE(review): POSIX pthread_sigmask reports failure by returning an
        # error number rather than -1, so the `== -1` checks below may not
        # detect failures in --threads builds; confirm.
        when compileOption("threads"):
          pthread_sigmask(how, set, oset)
        else:
          sigprocmask(how, set, oset)

      var oldSet, watchSet: Sigset
      if sigemptyset(oldSet) == -1:
        raiseOSError(osLastError())
      if sigemptyset(watchSet) == -1:
        raiseOSError(osLastError())

      if sigaddset(watchSet, SIGPIPE) == -1:
        raiseOSError(osLastError(), "Couldn't add SIGPIPE to Sigset")

      # Block SIGPIPE and capture the previous mask so it can be inspected.
      if sigmask(SIG_BLOCK, watchSet, oldSet) == -1:
        raiseOSError(osLastError(), "Couldn't block SIGPIPE")

      # If the caller had already blocked SIGPIPE, leave the mask (and any
      # pending SIGPIPE) untouched: no unblock afterwards and no sigwait.
      let alreadyBlocked = sigismember(oldSet, SIGPIPE) == 1

      # Consumes one pending SIGPIPE via sigwait (blocking until one arrives).
      template selectSigpipe(): untyped {.used.} =
        if not alreadyBlocked:
          var signal: cint
          let err = sigwait(watchSet, signal)
          if err != 0:
            raiseOSError(err.OSErrorCode, "Couldn't select SIGPIPE")
          assert signal == SIGPIPE

      try:
        body
      finally:
        # Restore delivery of SIGPIPE only if this template was the one that
        # blocked it.
        if not alreadyBlocked:
          if sigmask(SIG_UNBLOCK, watchSet, oldSet) == -1:
            raiseOSError(osLastError(), "Couldn't unblock SIGPIPE")
  proc close*(socket: Socket, flags = {SocketFlag.SafeDisconn}) =
    ## Closes a socket.
    ##
    ## If `socket` is an SSL/TLS socket, this proc will also send a closure
    ## notification to the peer. If `SafeDisconn` is in `flags`, failure to do so
    ## due to disconnections will be ignored. This is generally safe in
    ## practice. See
    ## `here <https://security.stackexchange.com/a/82044>`_ for more details.
    ##
    ## The SSL handle (if any) is freed and the descriptor closed in a
    ## `finally` block, so both happen even if the shutdown step raises.
    ## Afterwards ``socket.fd`` is ``osInvalidSocket``.
    try:
      when defineSsl:
        if socket.isSsl and socket.sslHandle != nil:
          # Don't call SSL_shutdown if the connection has not been fully
          # established, see:
          # https://github.com/openssl/openssl/issues/710#issuecomment-253897666
          if not socket.sslNoShutdown and SSL_in_init(socket.sslHandle) == 0:
            # As we are closing the underlying socket immediately afterwards,
            # it is valid, under the TLS standard, to perform a unidirectional
            # shutdown i.e not wait for the peers "close notify" alert with a second
            # call to SSL_shutdown
            blockSigpipe:
              ErrClearError()
              let res = SSL_shutdown(socket.sslHandle)
              # 0: our close notify was sent (unidirectional shutdown is
              # enough here, see above); 1: shutdown complete. Anything else
              # is an error to classify.
              if res == 0:
                discard
              elif res != 1:
                let
                  err = osLastError()
                  sslError = SSL_get_error(socket.sslHandle, res)

                # If a close notification is received, failures outside of the
                # protocol will be returned as SSL_ERROR_ZERO_RETURN instead
                # of SSL_ERROR_SYSCALL. This fact is deduced by digging into
                # SSL_get_error() source code.
                if sslError == SSL_ERROR_ZERO_RETURN or
                   sslError == SSL_ERROR_SYSCALL:
                  when defined(posix) and not defined(macosx) and
                       not defined(nimdoc):
                    if err == EPIPE.OSErrorCode:
                      # Clear the SIGPIPE that's been raised due to
                      # the disconnection.
                      selectSigpipe()
                  else:
                    discard
                  # Disconnection-type errors are swallowed when SafeDisconn
                  # is in `flags`; everything else is reported.
                  if not flags.isDisconnectionError(err):
                    socketError(socket, res, lastError = err, flags = flags)
                else:
                  socketError(socket, res, lastError = err, flags = flags)
    finally:
      when defineSsl:
        if socket.isSsl and socket.sslHandle != nil:
          SSL_free(socket.sslHandle)
          socket.sslHandle = nil

      socket.fd.close()
      socket.fd = osInvalidSocket
  1060. when defined(posix):
  1061. from posix import TCP_NODELAY
  1062. else:
  1063. from winlean import TCP_NODELAY
  proc toCInt*(opt: SOBool): cint =
    ## Maps a ``SOBool`` to the C socket-option constant that the
    ## getSockOptInt/setSockOptInt wrappers expect.
    result =
      case opt
      of OptAcceptConn: SO_ACCEPTCONN
      of OptBroadcast: SO_BROADCAST
      of OptDebug: SO_DEBUG
      of OptDontRoute: SO_DONTROUTE
      of OptKeepAlive: SO_KEEPALIVE
      of OptOOBInline: SO_OOBINLINE
      of OptReuseAddr: SO_REUSEADDR
      of OptReusePort: SO_REUSEPORT
      of OptNoDelay: TCP_NODELAY
  proc getSockOpt*(socket: Socket, opt: SOBool, level = SOL_SOCKET): bool {.
    tags: [ReadIOEffect].} =
    ## Reads option ``opt`` at ``level`` and reports it as a boolean
    ## (any non-zero value is ``true``).
    result = getSockOptInt(socket.fd, level.cint, opt.toCInt) != 0
  proc getLocalAddr*(socket: Socket): (string, Port) =
    ## Returns the local address and port number the socket is bound to.
    ##
    ## This is high-level interface for `getsockname`:idx:.
    result = socket.fd.getLocalAddr(socket.domain)
  proc getPeerAddr*(socket: Socket): (string, Port) =
    ## Returns the address and port number of the remote peer.
    ##
    ## This is high-level interface for `getpeername`:idx:.
    result = socket.fd.getPeerAddr(socket.domain)
  proc setSockOpt*(socket: Socket, opt: SOBool, value: bool,
                   level = SOL_SOCKET) {.tags: [WriteIOEffect].} =
    ## Sets option ``opt`` at ``level`` to ``value`` (written as 1 or 0).
    ##
    ## .. code-block:: Nim
    ##   var socket = newSocket()
    ##   socket.setSockOpt(OptReusePort, true)
    ##   socket.setSockOpt(OptNoDelay, true, level=IPPROTO_TCP.toInt)
    ##
    let flagValue = value.ord.cint
    setSockOptInt(socket.fd, level.cint, opt.toCInt, flagValue)
  when defined(posix) or defined(nimdoc):
    proc connectUnix*(socket: Socket, path: string) =
      ## Connects ``socket`` to the Unix domain socket at ``path``.
      ## This only works on Unix-style systems: Mac OS X, BSD and Linux
      ##
      ## Raises OSError on failure.
      when not defined(nimdoc):
        var socketAddr = makeUnixAddr(path)
        # Address length covers the family field plus the path bytes.
        let addrLen = (sizeof(socketAddr.sun_family) + path.len).SockLen
        let rc = socket.fd.connect(cast[ptr SockAddr](addr socketAddr), addrLen)
        if rc != 0'i32:
          raiseOSError(osLastError())

    proc bindUnix*(socket: Socket, path: string) =
      ## Binds ``socket`` to the Unix domain socket path ``path``.
      ## This only works on Unix-style systems: Mac OS X, BSD and Linux
      ##
      ## Raises OSError on failure.
      when not defined(nimdoc):
        var socketAddr = makeUnixAddr(path)
        let addrLen = (sizeof(socketAddr.sun_family) + path.len).SockLen
        let rc = socket.fd.bindAddr(cast[ptr SockAddr](addr socketAddr), addrLen)
        if rc != 0'i32:
          raiseOSError(osLastError())
  when defined(ssl):
    proc gotHandshake*(socket: Socket): bool =
      ## Reports whether the SSL handshake between ``socket`` and its peer has
      ## taken place.
      ##
      ## Throws SslError if ``socket`` is not an SSL socket.
      if not socket.isSsl:
        raiseSSLError("Socket is not an SSL socket.")
      result = not socket.sslNoHandshake
  proc hasDataBuffered*(s: Socket): bool =
    ## Returns ``true`` when ``s`` holds unread bytes in its internal buffer or,
    ## for SSL sockets, a pending peeked character.
    result = s.isBuffered and s.bufLen > 0 and s.currPos != s.bufLen
    when defineSsl:
      if not result and s.isSsl:
        result = s.sslHasPeekChar
  proc select(readfd: Socket, timeout = 500): int =
    ## Used for socket operation timeouts.
    ##
    ## Returns 1 straight away when data is already buffered. Otherwise returns
    ## the result of ``selectRead`` on the socket with ``timeout`` ms.
    result = 1
    if not readfd.hasDataBuffered:
      var watched = @[readfd.fd]
      result = selectRead(watched, timeout)
  proc isClosed(socket: Socket): bool =
    ## A socket counts as closed once its descriptor is ``osInvalidSocket``.
    result = socket.fd == osInvalidSocket
  proc uniRecv(socket: Socket, buffer: pointer, size, flags: cint): int =
    ## Reads up to ``size`` bytes into ``buffer``. SSL sockets use
    ## ``SSL_read``, which ignores ``flags``; other sockets use ``recv``.
    ## Returns the raw result of whichever call was made.
    assert(not socket.isClosed, "Cannot `recv` on a closed socket")
    when defineSsl:
      if socket.isSsl:
        ErrClearError()
        return SSL_read(socket.sslHandle, buffer, size)
    result = recv(socket.fd, buffer, size, flags)
  proc readIntoBuf(socket: Socket, flags: int32): int =
    ## Refills the socket's internal buffer with a single ``uniRecv`` call and
    ## returns that call's result.
    ##
    ## The read position is reset to 0. The buffer length becomes the number of
    ## bytes read, or 0 when nothing was read. A negative result also saves
    ## ``osLastError()`` in ``socket.lastError``.
    result = uniRecv(socket, addr(socket.buffer), socket.buffer.high, flags)
    if result < 0:
      # Save it in case it gets reset (the Nim codegen occasionally may call
      # Win API functions which reset it).
      socket.lastError = osLastError()
    socket.currPos = 0
    socket.bufLen = max(result, 0)
  template retRead(flags, readBytes: int) {.dirty.} =
    # Refills `socket`'s buffer. If nothing arrived, returns from the caller:
    # `readBytes` when some data was already copied, else the refill result.
    let res = socket.readIntoBuf(flags.int32)
    if res <= 0:
      return (if readBytes > 0: readBytes else: res)
  proc recv*(socket: Socket, data: pointer, size: int): int {.tags: [
      ReadIOEffect].} =
    ## Receives data from a socket.
    ##
    ## Buffered sockets keep copying from the internal buffer, refilling it as
    ## needed, until ``size`` bytes are copied or a refill yields nothing.
    ## Unbuffered SSL sockets first hand out a pending peeked character.
    ##
    ## **Note**: This is a low-level function, you may be interested in the higher
    ## level versions of this function which are also named ``recv``.
    if size == 0: return
    if socket.isBuffered:
      if socket.bufLen == 0:
        retRead(0'i32, 0)
      var dest = cast[cstring](data)
      var copied = 0
      while copied < size:
        if socket.currPos >= socket.bufLen:
          retRead(0'i32, copied)
        let chunk = min(socket.bufLen - socket.currPos, size - copied)
        assert size - copied >= chunk
        copyMem(addr(dest[copied]), addr(socket.buffer[socket.currPos]), chunk)
        inc(copied, chunk)
        inc(socket.currPos, chunk)
      result = copied
    else:
      when defineSsl:
        if socket.isSsl and socket.sslHasPeekChar:
          # TODO: Merge this peek char mess into uniRecv
          copyMem(data, addr(socket.sslPeekChar), 1)
          socket.sslHasPeekChar = false
          result = 1
          if size > 1:
            var rest = cast[cstring](data)
            result += uniRecv(socket, addr(rest[1]), cint(size - 1), 0'i32)
        elif socket.isSsl:
          result = uniRecv(socket, data, size.cint, 0'i32)
        else:
          result = recv(socket.fd, data, size.cint, 0'i32)
      else:
        result = recv(socket.fd, data, size.cint, 0'i32)
      if result < 0:
        # Save the error in case it gets reset.
        socket.lastError = osLastError()
  proc waitFor(socket: Socket, waited: var Duration, timeout, size: int,
               funcName: string): int {.tags: [TimeEffect].} =
    ## Determines how many bytes can be read right now; never more than
    ## ``size``.
    ##
    ## - ``timeout == -1``: returns ``size`` without waiting.
    ## - Buffered data: returns what is left in the buffer.
    ## - SSL sockets: returns 1 for a peeked character, or the ``SSL_pending``
    ##   count.
    ## - Otherwise: waits with ``select`` and returns 1, adding the wait to
    ##   ``waited``.
    ##
    ## A TimeoutError is raised if no data shows up within ``timeout`` ms in
    ## total (as tracked by ``waited``).
    if size <= 0: assert false
    if timeout == -1: return size
    if socket.isBuffered and socket.bufLen != 0 and
        socket.bufLen != socket.currPos:
      return min(socket.bufLen - socket.currPos, size)
    if timeout - waited.inMilliseconds < 1:
      raise newException(TimeoutError, "Call to '" & funcName & "' timed out.")
    when defineSsl:
      if socket.isSsl:
        if socket.hasDataBuffered:
          # sslPeekChar is present.
          return 1
        let sslPending = SSL_pending(socket.sslHandle)
        if sslPending != 0:
          return min(sslPending, size)
    let startTime = getMonoTime()
    let selRet = select(socket, (timeout - waited.inMilliseconds).int)
    if selRet < 0: raiseOSError(osLastError())
    if selRet != 1:
      raise newException(TimeoutError, "Call to '" & funcName & "' timed out.")
    waited += (getMonoTime() - startTime)
    result = 1
  proc recv*(socket: Socket, data: pointer, size: int, timeout: int): int {.
    tags: [ReadIOEffect, TimeEffect].} =
    ## Overload with a ``timeout`` parameter in milliseconds.
    ##
    ## Reads until ``size`` bytes arrive or the peer closes. Returns the total
    ## read, or a negative low-level result immediately.
    var waited: Duration # duration already waited
    var total = 0
    var dest = cast[cstring](data)
    while total < size:
      let avail = waitFor(socket, waited, timeout, size - total, "recv")
      assert avail <= size - total
      result = recv(socket, addr(dest[total]), avail)
      if result == 0: break
      if result < 0: return result
      total += result
    result = total
  proc recv*(socket: Socket, data: var string, size: int, timeout = -1,
             flags = {SocketFlag.SafeDisconn}): int =
    ## Higher-level version of ``recv``.
    ##
    ## Reads **up to** ``size`` bytes from ``socket`` into ``data``.
    ##
    ## For buffered sockets this function will attempt to read all the requested
    ## data. It will read this data in ``BufferSize`` chunks.
    ##
    ## For unbuffered sockets this function makes no effort to read
    ## all the data requested. It will return as much data as the operating system
    ## gives it.
    ##
    ## When 0 is returned the socket's connection has been closed.
    ##
    ## This function will throw an OSError exception when an error occurs. A value
    ## lower than 0 is never returned.
    ##
    ## A timeout may be specified in milliseconds, if enough data is not received
    ## within the time specified a TimeoutError exception will be raised.
    ##
    ## **Note**: ``data`` must be initialised.
    ##
    ## **Warning**: Only the ``SafeDisconn`` flag is currently supported.
    data.setLen(size)
    result =
      if timeout == -1: recv(socket, cstring(data), size)
      else: recv(socket, cstring(data), size, timeout)
    if result >= 0:
      # Trim to what was actually received.
      data.setLen(result)
    else:
      data.setLen(0)
      let lastError = getSocketError(socket)
      socket.socketError(result, lastError = lastError, flags = flags)
  1301. proc recv*(socket: Socket, size: int, timeout = -1,
  1302. flags = {SocketFlag.SafeDisconn}): string {.inline.} =
  1303. ## Higher-level version of ``recv`` which returns a string.
  1304. ##
  1305. ## Reads **up to** ``size`` bytes from ``socket`` into ``buf``.
  1306. ##
  1307. ## For buffered sockets this function will attempt to read all the requested
  1308. ## data. It will read this data in ``BufferSize`` chunks.
  1309. ##
  1310. ## For unbuffered sockets this function makes no effort to read
  1311. ## all the data requested. It will return as much data as the operating system
  1312. ## gives it.
  1313. ##
  1314. ## When ``""`` is returned the socket's connection has been closed.
  1315. ##
  1316. ## This function will throw an OSError exception when an error occurs.
  1317. ##
  1318. ## A timeout may be specified in milliseconds, if enough data is not received
  1319. ## within the time specified a TimeoutError exception will be raised.
  1320. ##
  1321. ##
  1322. ## **Warning**: Only the ``SafeDisconn`` flag is currently supported.
  1323. result = newString(size)
  1324. discard recv(socket, result, size, timeout, flags)
  1325. proc peekChar(socket: Socket, c: var char): int {.tags: [ReadIOEffect].} =
  1326. if socket.isBuffered:
  1327. result = 1
  1328. if socket.bufLen == 0 or socket.currPos > socket.bufLen-1:
  1329. var res = socket.readIntoBuf(0'i32)
  1330. if res <= 0:
  1331. result = res
  1332. c = socket.buffer[socket.currPos]
  1333. else:
  1334. when defineSsl:
  1335. if socket.isSsl:
  1336. if not socket.sslHasPeekChar:
  1337. result = uniRecv(socket, addr(socket.sslPeekChar), 1, 0'i32)
  1338. socket.sslHasPeekChar = true
  1339. c = socket.sslPeekChar
  1340. return
  1341. result = recv(socket.fd, addr(c), 1, MSG_PEEK)
  1342. proc readLine*(socket: Socket, line: var TaintedString, timeout = -1,
  1343. flags = {SocketFlag.SafeDisconn}, maxLength = MaxLineLength) {.
  1344. tags: [ReadIOEffect, TimeEffect].} =
  1345. ## Reads a line of data from ``socket``.
  1346. ##
  1347. ## If a full line is read ``\r\L`` is not
  1348. ## added to ``line``, however if solely ``\r\L`` is read then ``line``
  1349. ## will be set to it.
  1350. ##
  1351. ## If the socket is disconnected, ``line`` will be set to ``""``.
  1352. ##
  1353. ## An OSError exception will be raised in the case of a socket error.
  1354. ##
  1355. ## A timeout can be specified in milliseconds, if data is not received within
  1356. ## the specified time a TimeoutError exception will be raised.
  1357. ##
  1358. ## The ``maxLength`` parameter determines the maximum amount of characters
  1359. ## that can be read. The result is truncated after that.
  1360. ##
  1361. ## **Warning**: Only the ``SafeDisconn`` flag is currently supported.
  1362. template addNLIfEmpty() =
  1363. if line.len == 0:
  1364. line.string.add("\c\L")
  1365. template raiseSockError() {.dirty.} =
  1366. let lastError = getSocketError(socket)
  1367. if flags.isDisconnectionError(lastError):
  1368. setLen(line.string, 0)
  1369. socket.socketError(n, lastError = lastError, flags = flags)
  1370. var waited: Duration
  1371. setLen(line.string, 0)
  1372. while true:
  1373. var c: char
  1374. discard waitFor(socket, waited, timeout, 1, "readLine")
  1375. var n = recv(socket, addr(c), 1)
  1376. if n < 0: raiseSockError()
  1377. elif n == 0: setLen(line.string, 0); return
  1378. if c == '\r':
  1379. discard waitFor(socket, waited, timeout, 1, "readLine")
  1380. n = peekChar(socket, c)
  1381. if n > 0 and c == '\L':
  1382. discard recv(socket, addr(c), 1)
  1383. elif n <= 0: raiseSockError()
  1384. addNLIfEmpty()
  1385. return
  1386. elif c == '\L':
  1387. addNLIfEmpty()
  1388. return
  1389. add(line.string, c)
  1390. # Verify that this isn't a DOS attack: #3847.
  1391. if line.string.len > maxLength: break
  1392. proc recvLine*(socket: Socket, timeout = -1,
  1393. flags = {SocketFlag.SafeDisconn},
  1394. maxLength = MaxLineLength): TaintedString =
  1395. ## Reads a line of data from ``socket``.
  1396. ##
  1397. ## If a full line is read ``\r\L`` is not
  1398. ## added to the result, however if solely ``\r\L`` is read then the result
  1399. ## will be set to it.
  1400. ##
  1401. ## If the socket is disconnected, the result will be set to ``""``.
  1402. ##
  1403. ## An OSError exception will be raised in the case of a socket error.
  1404. ##
  1405. ## A timeout can be specified in milliseconds, if data is not received within
  1406. ## the specified time a TimeoutError exception will be raised.
  1407. ##
  1408. ## The ``maxLength`` parameter determines the maximum amount of characters
  1409. ## that can be read. The result is truncated after that.
  1410. ##
  1411. ## **Warning**: Only the ``SafeDisconn`` flag is currently supported.
  1412. result = "".TaintedString
  1413. readLine(socket, result, timeout, flags, maxLength)
  1414. proc recvFrom*(socket: Socket, data: var string, length: int,
  1415. address: var string, port: var Port, flags = 0'i32): int {.
  1416. tags: [ReadIOEffect].} =
  1417. ## Receives data from ``socket``. This function should normally be used with
  1418. ## connection-less sockets (UDP sockets).
  1419. ##
  1420. ## If an error occurs an OSError exception will be raised. Otherwise the return
  1421. ## value will be the length of data received.
  1422. ##
  1423. ## **Warning:** This function does not yet have a buffered implementation,
  1424. ## so when ``socket`` is buffered the non-buffered implementation will be
  1425. ## used. Therefore if ``socket`` contains something in its buffer this
  1426. ## function will make no effort to return it.
  1427. template adaptRecvFromToDomain(domain: Domain) =
  1428. var addrLen = sizeof(sockAddress).SockLen
  1429. result = recvfrom(socket.fd, cstring(data), length.cint, flags.cint,
  1430. cast[ptr SockAddr](addr(sockAddress)), addr(addrLen))
  1431. if result != -1:
  1432. data.setLen(result)
  1433. address = getAddrString(cast[ptr SockAddr](addr(sockAddress)))
  1434. when domain == AF_INET6:
  1435. port = ntohs(sockAddress.sin6_port).Port
  1436. else:
  1437. port = ntohs(sockAddress.sin_port).Port
  1438. else:
  1439. raiseOSError(osLastError())
  1440. assert(socket.protocol != IPPROTO_TCP, "Cannot `recvFrom` on a TCP socket")
  1441. # TODO: Buffered sockets
  1442. data.setLen(length)
  1443. case socket.domain
  1444. of AF_INET6:
  1445. var sockAddress: Sockaddr_in6
  1446. adaptRecvFromToDomain(AF_INET6)
  1447. of AF_INET:
  1448. var sockAddress: Sockaddr_in
  1449. adaptRecvFromToDomain(AF_INET)
  1450. else:
  1451. raise newException(ValueError, "Unknown socket address family")
  1452. proc skip*(socket: Socket, size: int, timeout = -1) =
  1453. ## Skips ``size`` amount of bytes.
  1454. ##
  1455. ## An optional timeout can be specified in milliseconds, if skipping the
  1456. ## bytes takes longer than specified a TimeoutError exception will be raised.
  1457. ##
  1458. ## Returns the number of skipped bytes.
  1459. var waited: Duration
  1460. var dummy = alloc(size)
  1461. var bytesSkipped = 0
  1462. while bytesSkipped != size:
  1463. let avail = waitFor(socket, waited, timeout, size-bytesSkipped, "skip")
  1464. bytesSkipped += recv(socket, dummy, avail)
  1465. dealloc(dummy)
  1466. proc send*(socket: Socket, data: pointer, size: int): int {.
  1467. tags: [WriteIOEffect].} =
  1468. ## Sends data to a socket.
  1469. ##
  1470. ## **Note**: This is a low-level version of ``send``. You likely should use
  1471. ## the version below.
  1472. assert(not socket.isClosed, "Cannot `send` on a closed socket")
  1473. when defineSsl:
  1474. if socket.isSsl:
  1475. ErrClearError()
  1476. return SSL_write(socket.sslHandle, cast[cstring](data), size)
  1477. when useWinVersion or defined(macosx):
  1478. result = send(socket.fd, data, size.cint, 0'i32)
  1479. else:
  1480. when defined(solaris):
  1481. const MSG_NOSIGNAL = 0
  1482. result = send(socket.fd, data, size, int32(MSG_NOSIGNAL))
  1483. proc send*(socket: Socket, data: string,
  1484. flags = {SocketFlag.SafeDisconn}) {.tags: [WriteIOEffect].} =
  1485. ## sends data to a socket.
  1486. let sent = send(socket, cstring(data), data.len)
  1487. if sent < 0:
  1488. let lastError = osLastError()
  1489. socketError(socket, lastError = lastError, flags = flags)
  1490. if sent != data.len:
  1491. raiseOSError(osLastError(), "Could not send all data.")
  1492. template `&=`*(socket: Socket; data: typed) =
  1493. ## an alias for 'send'.
  1494. send(socket, data)
  1495. proc trySend*(socket: Socket, data: string): bool {.tags: [WriteIOEffect].} =
  1496. ## Safe alternative to ``send``. Does not raise an OSError when an error occurs,
  1497. ## and instead returns ``false`` on failure.
  1498. result = send(socket, cstring(data), data.len) == data.len
  1499. proc sendTo*(socket: Socket, address: string, port: Port, data: pointer,
  1500. size: int, af: Domain = AF_INET, flags = 0'i32) {.
  1501. tags: [WriteIOEffect].} =
  1502. ## This proc sends ``data`` to the specified ``address``,
  1503. ## which may be an IP address or a hostname, if a hostname is specified
  1504. ## this function will try each IP of that hostname.
  1505. ##
  1506. ## If an error occurs an OSError exception will be raised.
  1507. ##
  1508. ## **Note:** You may wish to use the high-level version of this function
  1509. ## which is defined below.
  1510. ##
  1511. ## **Note:** This proc is not available for SSL sockets.
  1512. assert(socket.protocol != IPPROTO_TCP, "Cannot `sendTo` on a TCP socket")
  1513. assert(not socket.isClosed, "Cannot `sendTo` on a closed socket")
  1514. var aiList = getAddrInfo(address, port, af, socket.sockType, socket.protocol)
  1515. # try all possibilities:
  1516. var success = false
  1517. var it = aiList
  1518. var result = 0
  1519. while it != nil:
  1520. result = sendto(socket.fd, data, size.cint, flags.cint, it.ai_addr,
  1521. it.ai_addrlen.SockLen)
  1522. if result != -1'i32:
  1523. success = true
  1524. break
  1525. it = it.ai_next
  1526. let osError = osLastError()
  1527. freeaddrinfo(aiList)
  1528. if not success:
  1529. raiseOSError(osError)
  1530. proc sendTo*(socket: Socket, address: string, port: Port,
  1531. data: string) {.tags: [WriteIOEffect].} =
  1532. ## This proc sends ``data`` to the specified ``address``,
  1533. ## which may be an IP address or a hostname, if a hostname is specified
  1534. ## this function will try each IP of that hostname.
  1535. ##
  1536. ## If an error occurs an OSError exception will be raised.
  1537. ##
  1538. ## This is the high-level version of the above ``sendTo`` function.
  1539. socket.sendTo(address, port, cstring(data), data.len, socket.domain)
  1540. proc isSsl*(socket: Socket): bool =
  1541. ## Determines whether ``socket`` is a SSL socket.
  1542. when defineSsl:
  1543. result = socket.isSsl
  1544. else:
  1545. result = false
  1546. proc getFd*(socket: Socket): SocketHandle = return socket.fd
  1547. ## Returns the socket's file descriptor
  1548. when defined(nimHasStyleChecks):
  1549. {.push styleChecks: off.}
  1550. proc IPv4_any*(): IpAddress =
  1551. ## Returns the IPv4 any address, which can be used to listen on all available
  1552. ## network adapters
  1553. result = IpAddress(
  1554. family: IpAddressFamily.IPv4,
  1555. address_v4: [0'u8, 0, 0, 0])
  1556. proc IPv4_loopback*(): IpAddress =
  1557. ## Returns the IPv4 loopback address (127.0.0.1)
  1558. result = IpAddress(
  1559. family: IpAddressFamily.IPv4,
  1560. address_v4: [127'u8, 0, 0, 1])
  1561. proc IPv4_broadcast*(): IpAddress =
  1562. ## Returns the IPv4 broadcast address (255.255.255.255)
  1563. result = IpAddress(
  1564. family: IpAddressFamily.IPv4,
  1565. address_v4: [255'u8, 255, 255, 255])
  1566. proc IPv6_any*(): IpAddress =
  1567. ## Returns the IPv6 any address (::0), which can be used
  1568. ## to listen on all available network adapters
  1569. result = IpAddress(
  1570. family: IpAddressFamily.IPv6,
  1571. address_v6: [0'u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
  1572. proc IPv6_loopback*(): IpAddress =
  1573. ## Returns the IPv6 loopback address (::1)
  1574. result = IpAddress(
  1575. family: IpAddressFamily.IPv6,
  1576. address_v6: [0'u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])
  1577. when defined(nimHasStyleChecks):
  1578. {.pop.}
  1579. proc `==`*(lhs, rhs: IpAddress): bool =
  1580. ## Compares two IpAddresses for Equality. Returns true if the addresses are equal
  1581. if lhs.family != rhs.family: return false
  1582. if lhs.family == IpAddressFamily.IPv4:
  1583. for i in low(lhs.address_v4) .. high(lhs.address_v4):
  1584. if lhs.address_v4[i] != rhs.address_v4[i]: return false
  1585. else: # IPv6
  1586. for i in low(lhs.address_v6) .. high(lhs.address_v6):
  1587. if lhs.address_v6[i] != rhs.address_v6[i]: return false
  1588. return true
  1589. proc `$`*(address: IpAddress): string =
  1590. ## Converts an IpAddress into the textual representation
  1591. result = ""
  1592. case address.family
  1593. of IpAddressFamily.IPv4:
  1594. for i in 0 .. 3:
  1595. if i != 0:
  1596. result.add('.')
  1597. result.add($address.address_v4[i])
  1598. of IpAddressFamily.IPv6:
  1599. var
  1600. currentZeroStart = -1
  1601. currentZeroCount = 0
  1602. biggestZeroStart = -1
  1603. biggestZeroCount = 0
  1604. # Look for the largest block of zeros
  1605. for i in 0..7:
  1606. var isZero = address.address_v6[i*2] == 0 and address.address_v6[i*2+1] == 0
  1607. if isZero:
  1608. if currentZeroStart == -1:
  1609. currentZeroStart = i
  1610. currentZeroCount = 1
  1611. else:
  1612. currentZeroCount.inc()
  1613. if currentZeroCount > biggestZeroCount:
  1614. biggestZeroCount = currentZeroCount
  1615. biggestZeroStart = currentZeroStart
  1616. else:
  1617. currentZeroStart = -1
  1618. if biggestZeroCount == 8: # Special case ::0
  1619. result.add("::")
  1620. else: # Print address
  1621. var printedLastGroup = false
  1622. for i in 0..7:
  1623. var word: uint16 = (cast[uint16](address.address_v6[i*2])) shl 8
  1624. word = word or cast[uint16](address.address_v6[i*2+1])
  1625. if biggestZeroCount != 0 and # Check if group is in skip group
  1626. (i >= biggestZeroStart and i < (biggestZeroStart + biggestZeroCount)):
  1627. if i == biggestZeroStart: # skip start
  1628. result.add("::")
  1629. printedLastGroup = false
  1630. else:
  1631. if printedLastGroup:
  1632. result.add(':')
  1633. var
  1634. afterLeadingZeros = false
  1635. mask = 0xF000'u16
  1636. for j in 0'u16..3'u16:
  1637. var val = (mask and word) shr (4'u16*(3'u16-j))
  1638. if val != 0 or afterLeadingZeros:
  1639. if val < 0xA:
  1640. result.add(chr(uint16(ord('0'))+val))
  1641. else: # val >= 0xA
  1642. result.add(chr(uint16(ord('a'))+val-0xA))
  1643. afterLeadingZeros = true
  1644. mask = mask shr 4
  1645. printedLastGroup = true
  1646. proc dial*(address: string, port: Port,
  1647. protocol = IPPROTO_TCP, buffered = true): owned(Socket)
  1648. {.tags: [ReadIOEffect, WriteIOEffect].} =
  1649. ## Establishes connection to the specified ``address``:``port`` pair via the
  1650. ## specified protocol. The procedure iterates through possible
  1651. ## resolutions of the ``address`` until it succeeds, meaning that it
  1652. ## seamlessly works with both IPv4 and IPv6.
  1653. ## Returns Socket ready to send or receive data.
  1654. let sockType = protocol.toSockType()
  1655. let aiList = getAddrInfo(address, port, AF_UNSPEC, sockType, protocol)
  1656. var fdPerDomain: array[low(Domain).ord..high(Domain).ord, SocketHandle]
  1657. for i in low(fdPerDomain)..high(fdPerDomain):
  1658. fdPerDomain[i] = osInvalidSocket
  1659. template closeUnusedFds(domainToKeep = -1) {.dirty.} =
  1660. for i, fd in fdPerDomain:
  1661. if fd != osInvalidSocket and i != domainToKeep:
  1662. fd.close()
  1663. var success = false
  1664. var lastError: OSErrorCode
  1665. var it = aiList
  1666. var domain: Domain
  1667. var lastFd: SocketHandle
  1668. while it != nil:
  1669. let domainOpt = it.ai_family.toKnownDomain()
  1670. if domainOpt.isNone:
  1671. it = it.ai_next
  1672. continue
  1673. domain = domainOpt.unsafeGet()
  1674. lastFd = fdPerDomain[ord(domain)]
  1675. if lastFd == osInvalidSocket:
  1676. lastFd = createNativeSocket(domain, sockType, protocol)
  1677. if lastFd == osInvalidSocket:
  1678. # we always raise if socket creation failed, because it means a
  1679. # network system problem (e.g. not enough FDs), and not an unreachable
  1680. # address.
  1681. let err = osLastError()
  1682. freeaddrinfo(aiList)
  1683. closeUnusedFds()
  1684. raiseOSError(err)
  1685. fdPerDomain[ord(domain)] = lastFd
  1686. if connect(lastFd, it.ai_addr, it.ai_addrlen.SockLen) == 0'i32:
  1687. success = true
  1688. break
  1689. lastError = osLastError()
  1690. it = it.ai_next
  1691. freeaddrinfo(aiList)
  1692. closeUnusedFds(ord(domain))
  1693. if success:
  1694. result = newSocket(lastFd, domain, sockType, protocol)
  1695. elif lastError != 0.OSErrorCode:
  1696. raiseOSError(lastError)
  1697. else:
  1698. raise newException(IOError, "Couldn't resolve address: " & address)
  1699. proc connect*(socket: Socket, address: string,
  1700. port = Port(0)) {.tags: [ReadIOEffect].} =
  1701. ## Connects socket to ``address``:``port``. ``Address`` can be an IP address or a
  1702. ## host name. If ``address`` is a host name, this function will try each IP
  1703. ## of that host name. ``htons`` is already performed on ``port`` so you must
  1704. ## not do it.
  1705. ##
  1706. ## If ``socket`` is an SSL socket a handshake will be automatically performed.
  1707. var aiList = getAddrInfo(address, port, socket.domain)
  1708. # try all possibilities:
  1709. var success = false
  1710. var lastError: OSErrorCode
  1711. var it = aiList
  1712. while it != nil:
  1713. if connect(socket.fd, it.ai_addr, it.ai_addrlen.SockLen) == 0'i32:
  1714. success = true
  1715. break
  1716. else: lastError = osLastError()
  1717. it = it.ai_next
  1718. freeaddrinfo(aiList)
  1719. if not success: raiseOSError(lastError)
  1720. when defineSsl:
  1721. if socket.isSsl:
  1722. # RFC3546 for SNI specifies that IP addresses are not allowed.
  1723. if not isIpAddress(address):
  1724. # Discard result in case OpenSSL version doesn't support SNI, or we're
  1725. # not using TLSv1+
  1726. discard SSL_set_tlsext_host_name(socket.sslHandle, address)
  1727. ErrClearError()
  1728. let ret = SSL_connect(socket.sslHandle)
  1729. socketError(socket, ret)
  1730. when not defined(nimDisableCertificateValidation) and not defined(windows):
  1731. if not isIpAddress(address):
  1732. socket.checkCertName(address)
  1733. proc connectAsync(socket: Socket, name: string, port = Port(0),
  1734. af: Domain = AF_INET) {.tags: [ReadIOEffect].} =
  1735. ## A variant of ``connect`` for non-blocking sockets.
  1736. ##
  1737. ## This procedure will immediately return, it will not block until a connection
  1738. ## is made. It is up to the caller to make sure the connection has been established
  1739. ## by checking (using ``select``) whether the socket is writeable.
  1740. ##
  1741. ## **Note**: For SSL sockets, the ``handshake`` procedure must be called
  1742. ## whenever the socket successfully connects to a server.
  1743. var aiList = getAddrInfo(name, port, af)
  1744. # try all possibilities:
  1745. var success = false
  1746. var lastError: OSErrorCode
  1747. var it = aiList
  1748. while it != nil:
  1749. var ret = connect(socket.fd, it.ai_addr, it.ai_addrlen.SockLen)
  1750. if ret == 0'i32:
  1751. success = true
  1752. break
  1753. else:
  1754. lastError = osLastError()
  1755. when useWinVersion:
  1756. # Windows EINTR doesn't behave same as POSIX.
  1757. if lastError.int32 == WSAEWOULDBLOCK:
  1758. success = true
  1759. break
  1760. else:
  1761. if lastError.int32 == EINTR or lastError.int32 == EINPROGRESS:
  1762. success = true
  1763. break
  1764. it = it.ai_next
  1765. freeaddrinfo(aiList)
  1766. if not success: raiseOSError(lastError)
  1767. proc connect*(socket: Socket, address: string, port = Port(0),
  1768. timeout: int) {.tags: [ReadIOEffect, WriteIOEffect].} =
  1769. ## Connects to server as specified by ``address`` on port specified by ``port``.
  1770. ##
  1771. ## The ``timeout`` parameter specifies the time in milliseconds to allow for
  1772. ## the connection to the server to be made.
  1773. ##
  1774. ## **Warning:** This procedure appears to be broken for SSL connections as of
  1775. ## Nim v1.0.2. Consider using the other `connect` procedure. See
  1776. ## https://github.com/nim-lang/Nim/issues/15215 for more info.
  1777. socket.fd.setBlocking(false)
  1778. socket.connectAsync(address, port, socket.domain)
  1779. var s = @[socket.fd]
  1780. if selectWrite(s, timeout) != 1:
  1781. raise newException(TimeoutError, "Call to 'connect' timed out.")
  1782. else:
  1783. let res = getSockOptInt(socket.fd, SOL_SOCKET, SO_ERROR)
  1784. if res != 0:
  1785. raiseOSError(OSErrorCode(res))
  1786. when defineSsl and not defined(nimdoc):
  1787. if socket.isSsl:
  1788. socket.fd.setBlocking(true)
  1789. doAssert socket.gotHandshake()
  1790. socket.fd.setBlocking(true)
  1791. proc getPrimaryIPAddr*(dest = parseIpAddress("8.8.8.8")): IpAddress =
  1792. ## Finds the local IP address, usually assigned to eth0 on LAN or wlan0 on WiFi,
  1793. ## used to reach an external address. Useful to run local services.
  1794. ##
  1795. ## No traffic is sent.
  1796. ##
  1797. ## Supports IPv4 and v6.
  1798. ## Raises OSError if external networking is not set up.
  1799. ##
  1800. ## .. code-block:: Nim
  1801. ## echo $getPrimaryIPAddr() # "192.168.1.2"
  1802. let socket =
  1803. if dest.family == IpAddressFamily.IPv4:
  1804. newSocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP)
  1805. else:
  1806. newSocket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP)
  1807. try:
  1808. socket.connect($dest, 80.Port)
  1809. result = socket.getLocalAddr()[0].parseIpAddress()
  1810. finally:
  1811. socket.close()