/* virtio_transport_common.c */
  1. /*
  2. * common code for virtio vsock
  3. *
  4. * Copyright (C) 2013-2015 Red Hat, Inc.
  5. * Author: Asias He <asias@redhat.com>
  6. * Stefan Hajnoczi <stefanha@redhat.com>
  7. *
  8. * This work is licensed under the terms of the GNU GPL, version 2.
  9. */
  10. #include <linux/spinlock.h>
  11. #include <linux/module.h>
  12. #include <linux/sched/signal.h>
  13. #include <linux/ctype.h>
  14. #include <linux/list.h>
  15. #include <linux/virtio.h>
  16. #include <linux/virtio_ids.h>
  17. #include <linux/virtio_config.h>
  18. #include <linux/virtio_vsock.h>
  19. #include <uapi/linux/vsockmon.h>
  20. #include <net/sock.h>
  21. #include <net/af_vsock.h>
  22. #define CREATE_TRACE_POINTS
  23. #include <trace/events/vsock_virtio_transport_common.h>
  24. /* How long to wait for graceful shutdown of a connection */
  25. #define VSOCK_CLOSE_TIMEOUT (8 * HZ)
/* Upper bound, in bytes, on a single vsock packet payload buffer.
 * Exposed read-only (mode 0444) as a module parameter; default 64 KiB.
 * NOTE(review): exported here but not referenced elsewhere in this file —
 * presumably consumed by transport backends; verify the intended users.
 */
uint virtio_transport_max_vsock_pkt_buf_size = 64 * 1024;
module_param(virtio_transport_max_vsock_pkt_buf_size, uint, 0444);
EXPORT_SYMBOL_GPL(virtio_transport_max_vsock_pkt_buf_size);
  29. static const struct virtio_transport *virtio_transport_get_ops(void)
  30. {
  31. const struct vsock_transport *t = vsock_core_get_transport();
  32. return container_of(t, struct virtio_transport, transport);
  33. }
/* Allocate a packet and, when @info->msg is set, copy @len payload bytes
 * from it.  The header is filled in wire (little-endian) byte order.
 *
 * Returns the packet, or NULL on allocation/copy failure.  Ownership of the
 * returned packet passes to the caller, who releases it with
 * virtio_transport_free_pkt().
 */
static struct virtio_vsock_pkt *
virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
			   size_t len,
			   u32 src_cid,
			   u32 src_port,
			   u32 dst_cid,
			   u32 dst_port)
{
	struct virtio_vsock_pkt *pkt;
	int err;

	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
	if (!pkt)
		return NULL;

	pkt->hdr.type		= cpu_to_le16(info->type);
	pkt->hdr.op		= cpu_to_le16(info->op);
	pkt->hdr.src_cid	= cpu_to_le64(src_cid);
	pkt->hdr.dst_cid	= cpu_to_le64(dst_cid);
	pkt->hdr.src_port	= cpu_to_le32(src_port);
	pkt->hdr.dst_port	= cpu_to_le32(dst_port);
	pkt->hdr.flags		= cpu_to_le32(info->flags);
	pkt->len		= len;
	pkt->hdr.len		= cpu_to_le32(len);
	pkt->reply		= info->reply;
	pkt->vsk		= info->vsk;

	/* Control packets (len == 0) carry no payload buffer */
	if (info->msg && len > 0) {
		pkt->buf = kmalloc(len, GFP_KERNEL);
		if (!pkt->buf)
			goto out_pkt;
		err = memcpy_from_msg(pkt->buf, info->msg, len);
		if (err)
			goto out;
	}

	trace_virtio_transport_alloc_pkt(src_cid, src_port,
					 dst_cid, dst_port,
					 len,
					 info->type,
					 info->op,
					 info->flags);

	return pkt;

out:
	kfree(pkt->buf);
out_pkt:
	kfree(pkt);
	return NULL;
}
  79. /* Packet capture */
/* Build a vsockmon sk_buff describing @opaque (a struct virtio_vsock_pkt)
 * for packet-capture taps: af_vsockmon header + raw virtio header + payload.
 * Called from atomic context (GFP_ATOMIC); returns NULL on allocation
 * failure.
 */
static struct sk_buff *virtio_transport_build_skb(void *opaque)
{
	struct virtio_vsock_pkt *pkt = opaque;
	struct af_vsockmon_hdr *hdr;
	struct sk_buff *skb;
	size_t payload_len;
	void *payload_buf;

	/* A packet could be split to fit the RX buffer, so we can retrieve
	 * the payload length from the header and the buffer pointer taking
	 * care of the offset in the original packet.
	 */
	payload_len = le32_to_cpu(pkt->hdr.len);
	payload_buf = pkt->buf + pkt->off;

	skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + payload_len,
			GFP_ATOMIC);
	if (!skb)
		return NULL;

	hdr = skb_put(skb, sizeof(*hdr));

	/* pkt->hdr is little-endian so no need to byteswap here */
	hdr->src_cid = pkt->hdr.src_cid;
	hdr->src_port = pkt->hdr.src_port;
	hdr->dst_cid = pkt->hdr.dst_cid;
	hdr->dst_port = pkt->hdr.dst_port;

	hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
	hdr->len = cpu_to_le16(sizeof(pkt->hdr));
	memset(hdr->reserved, 0, sizeof(hdr->reserved));

	/* Map virtio vsock ops onto the coarser vsockmon op categories */
	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_REQUEST:
	case VIRTIO_VSOCK_OP_RESPONSE:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
		break;
	case VIRTIO_VSOCK_OP_RST:
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
		break;
	case VIRTIO_VSOCK_OP_RW:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
		break;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
	case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
		break;
	default:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
		break;
	}

	skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr));

	if (payload_len) {
		skb_put_data(skb, payload_buf, payload_len);
	}

	return skb;
}
  132. void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
  133. {
  134. vsock_deliver_tap(virtio_transport_build_skb, pkt);
  135. }
  136. EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
  137. static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
  138. struct virtio_vsock_pkt_info *info)
  139. {
  140. u32 src_cid, src_port, dst_cid, dst_port;
  141. struct virtio_vsock_sock *vvs;
  142. struct virtio_vsock_pkt *pkt;
  143. u32 pkt_len = info->pkt_len;
  144. src_cid = vm_sockets_get_local_cid();
  145. src_port = vsk->local_addr.svm_port;
  146. if (!info->remote_cid) {
  147. dst_cid = vsk->remote_addr.svm_cid;
  148. dst_port = vsk->remote_addr.svm_port;
  149. } else {
  150. dst_cid = info->remote_cid;
  151. dst_port = info->remote_port;
  152. }
  153. vvs = vsk->trans;
  154. /* we can send less than pkt_len bytes */
  155. if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
  156. pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
  157. /* virtio_transport_get_credit might return less than pkt_len credit */
  158. pkt_len = virtio_transport_get_credit(vvs, pkt_len);
  159. /* Do not send zero length OP_RW pkt */
  160. if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
  161. return pkt_len;
  162. pkt = virtio_transport_alloc_pkt(info, pkt_len,
  163. src_cid, src_port,
  164. dst_cid, dst_port);
  165. if (!pkt) {
  166. virtio_transport_put_credit(vvs, pkt_len);
  167. return -ENOMEM;
  168. }
  169. virtio_transport_inc_tx_pkt(vvs, pkt);
  170. return virtio_transport_get_ops()->send_pkt(pkt);
  171. }
  172. static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
  173. struct virtio_vsock_pkt *pkt)
  174. {
  175. vvs->rx_bytes += pkt->len;
  176. }
  177. static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
  178. struct virtio_vsock_pkt *pkt)
  179. {
  180. vvs->rx_bytes -= pkt->len;
  181. vvs->fwd_cnt += pkt->len;
  182. }
  183. void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
  184. {
  185. spin_lock_bh(&vvs->tx_lock);
  186. pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
  187. pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
  188. spin_unlock_bh(&vvs->tx_lock);
  189. }
  190. EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
  191. u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
  192. {
  193. u32 ret;
  194. spin_lock_bh(&vvs->tx_lock);
  195. ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
  196. if (ret > credit)
  197. ret = credit;
  198. vvs->tx_cnt += ret;
  199. spin_unlock_bh(&vvs->tx_lock);
  200. return ret;
  201. }
  202. EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
  203. void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
  204. {
  205. spin_lock_bh(&vvs->tx_lock);
  206. vvs->tx_cnt -= credit;
  207. spin_unlock_bh(&vvs->tx_lock);
  208. }
  209. EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
  210. static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
  211. int type,
  212. struct virtio_vsock_hdr *hdr)
  213. {
  214. struct virtio_vsock_pkt_info info = {
  215. .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
  216. .type = type,
  217. .vsk = vsk,
  218. };
  219. return virtio_transport_send_pkt_info(vsk, &info);
  220. }
/* Copy up to @len bytes of queued RX payload into @msg.
 *
 * Partially consumed packets stay at the head of rx_queue with pkt->off
 * advanced; fully consumed packets are unlinked and freed.  A credit update
 * is sent afterwards so the peer learns about the freed space.
 *
 * Returns bytes copied, or a negative errno if nothing was copied before a
 * copy fault.
 */
static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct virtio_vsock_pkt *pkt;
	size_t bytes, total = 0;
	int err = -EFAULT;

	spin_lock_bh(&vvs->rx_lock);
	while (total < len && !list_empty(&vvs->rx_queue)) {
		pkt = list_first_entry(&vvs->rx_queue,
				       struct virtio_vsock_pkt, list);

		/* Take what the caller still wants, capped by what this
		 * packet still holds.
		 */
		bytes = len - total;
		if (bytes > pkt->len - pkt->off)
			bytes = pkt->len - pkt->off;

		/* sk_lock is held by caller so no one else can dequeue.
		 * Unlock rx_lock since memcpy_to_msg() may sleep.
		 */
		spin_unlock_bh(&vvs->rx_lock);

		err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
		if (err)
			goto out;

		spin_lock_bh(&vvs->rx_lock);

		total += bytes;
		pkt->off += bytes;
		if (pkt->off == pkt->len) {
			virtio_transport_dec_rx_pkt(vvs, pkt);
			list_del(&pkt->list);
			virtio_transport_free_pkt(pkt);
		}
	}
	spin_unlock_bh(&vvs->rx_lock);

	/* Send a credit pkt to peer */
	virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
					    NULL);

	return total;

out:
	/* Report partial progress in preference to the copy error */
	if (total)
		err = total;
	return err;
}
  263. ssize_t
  264. virtio_transport_stream_dequeue(struct vsock_sock *vsk,
  265. struct msghdr *msg,
  266. size_t len, int flags)
  267. {
  268. if (flags & MSG_PEEK)
  269. return -EOPNOTSUPP;
  270. return virtio_transport_stream_do_dequeue(vsk, msg, len);
  271. }
  272. EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
  273. int
  274. virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
  275. struct msghdr *msg,
  276. size_t len, int flags)
  277. {
  278. return -EOPNOTSUPP;
  279. }
  280. EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
  281. s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
  282. {
  283. struct virtio_vsock_sock *vvs = vsk->trans;
  284. s64 bytes;
  285. spin_lock_bh(&vvs->rx_lock);
  286. bytes = vvs->rx_bytes;
  287. spin_unlock_bh(&vvs->rx_lock);
  288. return bytes;
  289. }
  290. EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
  291. static s64 virtio_transport_has_space(struct vsock_sock *vsk)
  292. {
  293. struct virtio_vsock_sock *vvs = vsk->trans;
  294. s64 bytes;
  295. bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
  296. if (bytes < 0)
  297. bytes = 0;
  298. return bytes;
  299. }
  300. s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
  301. {
  302. struct virtio_vsock_sock *vvs = vsk->trans;
  303. s64 bytes;
  304. spin_lock_bh(&vvs->tx_lock);
  305. bytes = virtio_transport_has_space(vsk);
  306. spin_unlock_bh(&vvs->tx_lock);
  307. return bytes;
  308. }
  309. EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
  310. int virtio_transport_do_socket_init(struct vsock_sock *vsk,
  311. struct vsock_sock *psk)
  312. {
  313. struct virtio_vsock_sock *vvs;
  314. vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
  315. if (!vvs)
  316. return -ENOMEM;
  317. vsk->trans = vvs;
  318. vvs->vsk = vsk;
  319. if (psk) {
  320. struct virtio_vsock_sock *ptrans = psk->trans;
  321. vvs->buf_size = ptrans->buf_size;
  322. vvs->buf_size_min = ptrans->buf_size_min;
  323. vvs->buf_size_max = ptrans->buf_size_max;
  324. vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
  325. } else {
  326. vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
  327. vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
  328. vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
  329. }
  330. vvs->buf_alloc = vvs->buf_size;
  331. spin_lock_init(&vvs->rx_lock);
  332. spin_lock_init(&vvs->tx_lock);
  333. INIT_LIST_HEAD(&vvs->rx_queue);
  334. return 0;
  335. }
  336. EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
  337. u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
  338. {
  339. struct virtio_vsock_sock *vvs = vsk->trans;
  340. return vvs->buf_size;
  341. }
  342. EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
  343. u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
  344. {
  345. struct virtio_vsock_sock *vvs = vsk->trans;
  346. return vvs->buf_size_min;
  347. }
  348. EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
  349. u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
  350. {
  351. struct virtio_vsock_sock *vvs = vsk->trans;
  352. return vvs->buf_size_max;
  353. }
  354. EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
  355. void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
  356. {
  357. struct virtio_vsock_sock *vvs = vsk->trans;
  358. if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
  359. val = VIRTIO_VSOCK_MAX_BUF_SIZE;
  360. if (val < vvs->buf_size_min)
  361. vvs->buf_size_min = val;
  362. if (val > vvs->buf_size_max)
  363. vvs->buf_size_max = val;
  364. vvs->buf_size = val;
  365. vvs->buf_alloc = val;
  366. }
  367. EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
  368. void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
  369. {
  370. struct virtio_vsock_sock *vvs = vsk->trans;
  371. if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
  372. val = VIRTIO_VSOCK_MAX_BUF_SIZE;
  373. if (val > vvs->buf_size)
  374. vvs->buf_size = val;
  375. vvs->buf_size_min = val;
  376. }
  377. EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
  378. void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
  379. {
  380. struct virtio_vsock_sock *vvs = vsk->trans;
  381. if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
  382. val = VIRTIO_VSOCK_MAX_BUF_SIZE;
  383. if (val < vvs->buf_size)
  384. vvs->buf_size = val;
  385. vvs->buf_size_max = val;
  386. }
  387. EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
  388. int
  389. virtio_transport_notify_poll_in(struct vsock_sock *vsk,
  390. size_t target,
  391. bool *data_ready_now)
  392. {
  393. if (vsock_stream_has_data(vsk))
  394. *data_ready_now = true;
  395. else
  396. *data_ready_now = false;
  397. return 0;
  398. }
  399. EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
  400. int
  401. virtio_transport_notify_poll_out(struct vsock_sock *vsk,
  402. size_t target,
  403. bool *space_avail_now)
  404. {
  405. s64 free_space;
  406. free_space = vsock_stream_has_space(vsk);
  407. if (free_space > 0)
  408. *space_avail_now = true;
  409. else if (free_space == 0)
  410. *space_avail_now = false;
  411. return 0;
  412. }
  413. EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
/* No-op: this transport keeps no per-receive notification state. */
int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
/* No-op notification hook (nothing to do before blocking in recv). */
int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
/* No-op notification hook (nothing to do before dequeueing). */
int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
/* No-op notification hook (nothing to do after dequeueing). */
int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
	size_t target, ssize_t copied, bool data_read,
	struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
/* No-op: this transport keeps no per-send notification state. */
int virtio_transport_notify_send_init(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
/* No-op notification hook (nothing to do before blocking in send). */
int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
/* No-op notification hook (nothing to do before enqueueing). */
int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
/* No-op notification hook (nothing to do after enqueueing). */
int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
	ssize_t written, struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
  463. u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
  464. {
  465. struct virtio_vsock_sock *vvs = vsk->trans;
  466. return vvs->buf_size;
  467. }
  468. EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
/* Virtio stream sockets are always considered active. */
bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
/* No per-address filtering: all stream connections are allowed. */
bool virtio_transport_stream_allow(u32 cid, u32 port)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
/* Datagram sockets are not supported by the virtio transport. */
int virtio_transport_dgram_bind(struct vsock_sock *vsk,
				struct sockaddr_vm *addr)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
/* Datagram sockets are not supported, so nothing is allowed. */
bool virtio_transport_dgram_allow(u32 cid, u32 port)
{
	return false;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
  490. int virtio_transport_connect(struct vsock_sock *vsk)
  491. {
  492. struct virtio_vsock_pkt_info info = {
  493. .op = VIRTIO_VSOCK_OP_REQUEST,
  494. .type = VIRTIO_VSOCK_TYPE_STREAM,
  495. .vsk = vsk,
  496. };
  497. return virtio_transport_send_pkt_info(vsk, &info);
  498. }
  499. EXPORT_SYMBOL_GPL(virtio_transport_connect);
  500. int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
  501. {
  502. struct virtio_vsock_pkt_info info = {
  503. .op = VIRTIO_VSOCK_OP_SHUTDOWN,
  504. .type = VIRTIO_VSOCK_TYPE_STREAM,
  505. .flags = (mode & RCV_SHUTDOWN ?
  506. VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
  507. (mode & SEND_SHUTDOWN ?
  508. VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
  509. .vsk = vsk,
  510. };
  511. return virtio_transport_send_pkt_info(vsk, &info);
  512. }
  513. EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
/* Datagram sockets are not supported by the virtio transport. */
int
virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
			       struct sockaddr_vm *remote_addr,
			       struct msghdr *msg,
			       size_t dgram_len)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
  523. ssize_t
  524. virtio_transport_stream_enqueue(struct vsock_sock *vsk,
  525. struct msghdr *msg,
  526. size_t len)
  527. {
  528. struct virtio_vsock_pkt_info info = {
  529. .op = VIRTIO_VSOCK_OP_RW,
  530. .type = VIRTIO_VSOCK_TYPE_STREAM,
  531. .msg = msg,
  532. .pkt_len = len,
  533. .vsk = vsk,
  534. };
  535. return virtio_transport_send_pkt_info(vsk, &info);
  536. }
  537. EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
  538. void virtio_transport_destruct(struct vsock_sock *vsk)
  539. {
  540. struct virtio_vsock_sock *vvs = vsk->trans;
  541. kfree(vvs);
  542. }
  543. EXPORT_SYMBOL_GPL(virtio_transport_destruct);
  544. static int virtio_transport_reset(struct vsock_sock *vsk,
  545. struct virtio_vsock_pkt *pkt)
  546. {
  547. struct virtio_vsock_pkt_info info = {
  548. .op = VIRTIO_VSOCK_OP_RST,
  549. .type = VIRTIO_VSOCK_TYPE_STREAM,
  550. .reply = !!pkt,
  551. .vsk = vsk,
  552. };
  553. /* Send RST only if the original pkt is not a RST pkt */
  554. if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
  555. return 0;
  556. return virtio_transport_send_pkt_info(vsk, &info);
  557. }
  558. /* Normally packets are associated with a socket. There may be no socket if an
  559. * attempt was made to connect to a socket that does not exist.
  560. */
  561. static int virtio_transport_reset_no_sock(const struct virtio_transport *t,
  562. struct virtio_vsock_pkt *pkt)
  563. {
  564. struct virtio_vsock_pkt *reply;
  565. struct virtio_vsock_pkt_info info = {
  566. .op = VIRTIO_VSOCK_OP_RST,
  567. .type = le16_to_cpu(pkt->hdr.type),
  568. .reply = true,
  569. };
  570. /* Send RST only if the original pkt is not a RST pkt */
  571. if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
  572. return 0;
  573. reply = virtio_transport_alloc_pkt(&info, 0,
  574. le64_to_cpu(pkt->hdr.dst_cid),
  575. le32_to_cpu(pkt->hdr.dst_port),
  576. le64_to_cpu(pkt->hdr.src_cid),
  577. le32_to_cpu(pkt->hdr.src_port));
  578. if (!reply)
  579. return -ENOMEM;
  580. if (!t) {
  581. virtio_transport_free_pkt(reply);
  582. return -ENOTCONN;
  583. }
  584. return t->send_pkt(reply);
  585. }
/* Block for up to @timeout jiffies waiting for the close handshake to
 * finish (SOCK_DONE set).  Returns early if the task is signalled or the
 * timeout expires.  A zero @timeout returns immediately.
 */
static void virtio_transport_wait_close(struct sock *sk, long timeout)
{
	if (timeout) {
		DEFINE_WAIT_FUNC(wait, woken_wake_function);

		add_wait_queue(sk_sleep(sk), &wait);

		do {
			if (sk_wait_event(sk, &timeout,
					  sock_flag(sk, SOCK_DONE), &wait))
				break;
		} while (!signal_pending(current) && timeout);

		remove_wait_queue(sk_sleep(sk), &wait);
	}
}
/* Complete the close of @vsk: mark it done, record full peer shutdown, and
 * wake waiters.  If a close timeout was scheduled, tear it down (optionally
 * cancelling the pending work) and drop the reference taken for it.
 */
static void virtio_transport_do_close(struct vsock_sock *vsk,
				      bool cancel_timeout)
{
	struct sock *sk = sk_vsock(vsk);

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	/* Keep the socket readable until queued data is drained */
	if (vsock_stream_has_data(vsk) <= 0)
		sk->sk_state = TCP_CLOSING;
	sk->sk_state_change(sk);

	if (vsk->close_work_scheduled &&
	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
		vsk->close_work_scheduled = false;

		vsock_remove_sock(vsk);

		/* Release refcnt obtained when we scheduled the timeout */
		sock_put(sk);
	}
}
/* Delayed-work handler: the peer did not complete the graceful shutdown in
 * time, so force the close with an RST.
 */
static void virtio_transport_close_timeout(struct work_struct *work)
{
	struct vsock_sock *vsk =
		container_of(work, struct vsock_sock, close_work.work);
	struct sock *sk = sk_vsock(vsk);

	sock_hold(sk);
	lock_sock(sk);

	if (!sock_flag(sk, SOCK_DONE)) {
		(void)virtio_transport_reset(vsk, NULL);

		virtio_transport_do_close(vsk, false);
	}

	vsk->close_work_scheduled = false;

	release_sock(sk);
	sock_put(sk);
}
/* User context, vsk->sk is locked */
/* Begin closing @vsk.  Returns true when the socket can be removed
 * immediately; returns false when a graceful shutdown is in flight and the
 * scheduled timeout work (which holds a sock reference) will finish the
 * teardown later.
 */
static bool virtio_transport_close(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;

	if (!(sk->sk_state == TCP_ESTABLISHED ||
	      sk->sk_state == TCP_CLOSING))
		return true;

	/* Already received SHUTDOWN from peer, reply with RST */
	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
		(void)virtio_transport_reset(vsk, NULL);
		return true;
	}

	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);

	/* SO_LINGER: wait for the peer's acknowledgement before returning */
	if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
		virtio_transport_wait_close(sk, sk->sk_lingertime);

	if (sock_flag(sk, SOCK_DONE)) {
		return true;
	}

	/* Reference released by virtio_transport_close_timeout() or
	 * virtio_transport_do_close().
	 */
	sock_hold(sk);
	INIT_DELAYED_WORK(&vsk->close_work,
			  virtio_transport_close_timeout);
	vsk->close_work_scheduled = true;
	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
	return false;
}
/* Release @vsk: start the close handshake for stream sockets, drop all
 * packets still queued for reading, and remove the socket unless a graceful
 * shutdown is still pending (in which case the close timeout work removes
 * it later).
 */
void virtio_transport_release(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct virtio_vsock_pkt *pkt, *tmp;
	struct sock *sk = &vsk->sk;
	bool remove_sock = true;

	lock_sock_nested(sk, SINGLE_DEPTH_NESTING);
	if (sk->sk_type == SOCK_STREAM)
		remove_sock = virtio_transport_close(vsk);

	/* Free undelivered RX packets */
	list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) {
		list_del(&pkt->list);
		virtio_transport_free_pkt(pkt);
	}
	release_sock(sk);

	if (remove_sock)
		vsock_remove_sock(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_release);
/* Handle a packet for a socket in TCP_SYN_SENT: a RESPONSE completes the
 * handshake, an RST (or any unexpected op) tears the socket down.  The
 * caller owns and frees @pkt.
 */
static int
virtio_transport_recv_connecting(struct sock *sk,
				 struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	int err;
	int skerr;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RESPONSE:
		/* Handshake complete: publish the socket as connected and
		 * wake the connect()ing task.
		 */
		sk->sk_state = TCP_ESTABLISHED;
		sk->sk_socket->state = SS_CONNECTED;
		vsock_insert_connected(vsk);
		sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_INVALID:
		break;
	case VIRTIO_VSOCK_OP_RST:
		skerr = ECONNRESET;
		err = 0;
		goto destroy;
	default:
		skerr = EPROTO;
		err = -EINVAL;
		goto destroy;
	}
	return 0;

destroy:
	virtio_transport_reset(vsk, pkt);
	sk->sk_state = TCP_CLOSE;
	sk->sk_err = skerr;
	sk->sk_error_report(sk);
	return err;
}
/* Handle a packet for an established socket.
 *
 * Ownership note: an OP_RW packet is queued on rx_queue and kept (early
 * return); every other op falls through to virtio_transport_free_pkt().
 */
static int
virtio_transport_recv_connected(struct sock *sk,
				struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct virtio_vsock_sock *vvs = vsk->trans;
	int err = 0;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RW:
		pkt->len = le32_to_cpu(pkt->hdr.len);
		pkt->off = 0;

		spin_lock_bh(&vvs->rx_lock);
		virtio_transport_inc_rx_pkt(vvs, pkt);
		list_add_tail(&pkt->list, &vvs->rx_queue);
		spin_unlock_bh(&vvs->rx_lock);

		sk->sk_data_ready(sk);
		return err;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
		/* Peer credit was refreshed by the caller's
		 * virtio_transport_space_update(); just wake writers.
		 */
		sk->sk_write_space(sk);
		break;
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
			vsk->peer_shutdown |= RCV_SHUTDOWN;
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
			vsk->peer_shutdown |= SEND_SHUTDOWN;
		/* Enter CLOSING only once both directions are shut and no
		 * data remains to be read.
		 */
		if (vsk->peer_shutdown == SHUTDOWN_MASK &&
		    vsock_stream_has_data(vsk) <= 0)
			sk->sk_state = TCP_CLOSING;
		if (le32_to_cpu(pkt->hdr.flags))
			sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_RST:
		virtio_transport_do_close(vsk, true);
		break;
	default:
		err = -EINVAL;
		break;
	}

	virtio_transport_free_pkt(pkt);
	return err;
}
  749. static void
  750. virtio_transport_recv_disconnecting(struct sock *sk,
  751. struct virtio_vsock_pkt *pkt)
  752. {
  753. struct vsock_sock *vsk = vsock_sk(sk);
  754. if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
  755. virtio_transport_do_close(vsk, true);
  756. }
  757. static int
  758. virtio_transport_send_response(struct vsock_sock *vsk,
  759. struct virtio_vsock_pkt *pkt)
  760. {
  761. struct virtio_vsock_pkt_info info = {
  762. .op = VIRTIO_VSOCK_OP_RESPONSE,
  763. .type = VIRTIO_VSOCK_TYPE_STREAM,
  764. .remote_cid = le64_to_cpu(pkt->hdr.src_cid),
  765. .remote_port = le32_to_cpu(pkt->hdr.src_port),
  766. .reply = true,
  767. .vsk = vsk,
  768. };
  769. return virtio_transport_send_pkt_info(vsk, &info);
  770. }
/* Handle server socket */
/* Accept path: a REQUEST on a listening socket spawns a connected child,
 * queues it on the accept queue, and answers with a RESPONSE.  Anything
 * other than a REQUEST, or a full accept queue, is answered with a reset.
 * The caller owns and frees @pkt.
 */
static int
virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct vsock_sock *vchild;
	struct sock *child;

	if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
		virtio_transport_reset(vsk, pkt);
		return -EINVAL;
	}

	if (sk_acceptq_is_full(sk)) {
		virtio_transport_reset(vsk, pkt);
		return -ENOMEM;
	}

	child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
			       sk->sk_type, 0);
	if (!child) {
		virtio_transport_reset(vsk, pkt);
		return -ENOMEM;
	}

	sk->sk_ack_backlog++;

	lock_sock_nested(child, SINGLE_DEPTH_NESTING);

	child->sk_state = TCP_ESTABLISHED;

	/* Child addressing mirrors the incoming packet: its local address is
	 * the packet's destination, its remote address the packet's source.
	 */
	vchild = vsock_sk(child);
	vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
			le32_to_cpu(pkt->hdr.dst_port));
	vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
			le32_to_cpu(pkt->hdr.src_port));

	vsock_insert_connected(vchild);
	vsock_enqueue_accept(sk, child);
	virtio_transport_send_response(vchild, pkt);

	release_sock(child);

	sk->sk_data_ready(sk);
	return 0;
}
  807. static bool virtio_transport_space_update(struct sock *sk,
  808. struct virtio_vsock_pkt *pkt)
  809. {
  810. struct vsock_sock *vsk = vsock_sk(sk);
  811. struct virtio_vsock_sock *vvs = vsk->trans;
  812. bool space_available;
  813. /* buf_alloc and fwd_cnt is always included in the hdr */
  814. spin_lock_bh(&vvs->tx_lock);
  815. vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
  816. vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
  817. space_available = virtio_transport_has_space(vsk);
  818. spin_unlock_bh(&vvs->tx_lock);
  819. return space_available;
  820. }
/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
 * lock.
 */
/* Top-level RX dispatcher: locate the destination socket for @pkt, refresh
 * peer credit, and hand the packet to the handler for the socket's state.
 *
 * Ownership note: every path frees @pkt except TCP_ESTABLISHED, where
 * virtio_transport_recv_connected() takes ownership (it may queue the
 * packet on rx_queue).
 */
void virtio_transport_recv_pkt(struct virtio_transport *t,
			       struct virtio_vsock_pkt *pkt)
{
	struct sockaddr_vm src, dst;
	struct vsock_sock *vsk;
	struct sock *sk;
	bool space_available;

	vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
			le32_to_cpu(pkt->hdr.src_port));
	vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
			le32_to_cpu(pkt->hdr.dst_port));

	trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
					dst.svm_cid, dst.svm_port,
					le32_to_cpu(pkt->hdr.len),
					le16_to_cpu(pkt->hdr.type),
					le16_to_cpu(pkt->hdr.op),
					le32_to_cpu(pkt->hdr.flags),
					le32_to_cpu(pkt->hdr.buf_alloc),
					le32_to_cpu(pkt->hdr.fwd_cnt));

	/* Only stream sockets are implemented; reset anything else */
	if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
		(void)virtio_transport_reset_no_sock(t, pkt);
		goto free_pkt;
	}

	/* The socket must be in connected or bound table
	 * otherwise send reset back
	 */
	sk = vsock_find_connected_socket(&src, &dst);
	if (!sk) {
		sk = vsock_find_bound_socket(&dst);
		if (!sk) {
			(void)virtio_transport_reset_no_sock(t, pkt);
			goto free_pkt;
		}
	}

	vsk = vsock_sk(sk);

	lock_sock(sk);

	space_available = virtio_transport_space_update(sk, pkt);

	/* Update CID in case it has changed after a transport reset event */
	vsk->local_addr.svm_cid = dst.svm_cid;

	if (space_available)
		sk->sk_write_space(sk);

	switch (sk->sk_state) {
	case TCP_LISTEN:
		virtio_transport_recv_listen(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	case TCP_SYN_SENT:
		virtio_transport_recv_connecting(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	case TCP_ESTABLISHED:
		/* recv_connected() owns pkt from here on */
		virtio_transport_recv_connected(sk, pkt);
		break;
	case TCP_CLOSING:
		virtio_transport_recv_disconnecting(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	default:
		(void)virtio_transport_reset_no_sock(t, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	}
	release_sock(sk);

	/* Release refcnt obtained when we fetched this socket out of the
	 * bound or connected list.
	 */
	sock_put(sk);
	return;

free_pkt:
	virtio_transport_free_pkt(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
/* Free a packet and its payload buffer (pkt->buf may be NULL for control
 * packets; kfree(NULL) is a no-op).
 */
void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
{
	kfree(pkt->buf);
	kfree(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
/* Module metadata */
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("common code for virtio vsock");