virtio_transport_common.c

/*
 * common code for virtio vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 *
 * This work is licensed under the terms of the GNU GPL, version 2.
 */
#include <linux/spinlock.h>
#include <linux/module.h>
#include <linux/ctype.h>
#include <linux/list.h>
#include <linux/virtio.h>
#include <linux/virtio_ids.h>
#include <linux/virtio_config.h>
#include <linux/virtio_vsock.h>
#include <net/sock.h>
#include <net/af_vsock.h>

#define CREATE_TRACE_POINTS
#include <trace/events/vsock_virtio_transport_common.h>

/* How long to wait for graceful shutdown of a connection */
#define VSOCK_CLOSE_TIMEOUT (8 * HZ)
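
/* Resolve the transport registered with the vsock core back to its
 * containing struct virtio_transport, so that this common code can reach
 * the transport-specific callbacks such as ->send_pkt().
 */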
static const struct virtio_transport *virtio_transport_get_ops(void)
{
        const struct vsock_transport *t = vsock_core_get_transport();

        return container_of(t, struct virtio_transport, transport);
}
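
/* Allocate a packet and fill in its virtio_vsock_hdr from @info and the
 * source/destination address.  For data (OP_RW) packets the payload is
 * copied out of the caller's msghdr into a freshly allocated buffer;
 * control packets pass len == 0 and leave pkt->buf NULL.
 */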
struct virtio_vsock_pkt *
virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
                           size_t len,
                           u32 src_cid,
                           u32 src_port,
                           u32 dst_cid,
                           u32 dst_port)
{
        struct virtio_vsock_pkt *pkt;
        int err;

        pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
        if (!pkt)
                return NULL;

        pkt->hdr.type = cpu_to_le16(info->type);
        pkt->hdr.op = cpu_to_le16(info->op);
        pkt->hdr.src_cid = cpu_to_le64(src_cid);
        pkt->hdr.dst_cid = cpu_to_le64(dst_cid);
        pkt->hdr.src_port = cpu_to_le32(src_port);
        pkt->hdr.dst_port = cpu_to_le32(dst_port);
        pkt->hdr.flags = cpu_to_le32(info->flags);
        pkt->len = len;
        pkt->hdr.len = cpu_to_le32(len);
        pkt->reply = info->reply;
        pkt->vsk = info->vsk;

        if (info->msg && len > 0) {
                pkt->buf = kmalloc(len, GFP_KERNEL);
                if (!pkt->buf)
                        goto out_pkt;
                err = memcpy_from_msg(pkt->buf, info->msg, len);
                if (err)
                        goto out;
        }

        trace_virtio_transport_alloc_pkt(src_cid, src_port,
                                         dst_cid, dst_port,
                                         len,
                                         info->type,
                                         info->op,
                                         info->flags);

        return pkt;

out:
        kfree(pkt->buf);
out_pkt:
        kfree(pkt);
        return NULL;
}
EXPORT_SYMBOL_GPL(virtio_transport_alloc_pkt);
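
/* Build and transmit one packet on behalf of @vsk.  The requested length is
 * capped at VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE and then at the credit the peer
 * currently grants, so fewer bytes than requested may be sent; the value
 * returned by the transport's ->send_pkt() callback is passed straight back
 * to the caller.  Callers only fill in a virtio_vsock_pkt_info, e.g. (as in
 * virtio_transport_stream_enqueue() below):
 *
 *	struct virtio_vsock_pkt_info info = {
 *		.op = VIRTIO_VSOCK_OP_RW,
 *		.type = VIRTIO_VSOCK_TYPE_STREAM,
 *		.msg = msg,
 *		.pkt_len = len,
 *		.vsk = vsk,
 *	};
 *
 *	return virtio_transport_send_pkt_info(vsk, &info);
 */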
static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
                                          struct virtio_vsock_pkt_info *info)
{
        u32 src_cid, src_port, dst_cid, dst_port;
        struct virtio_vsock_sock *vvs;
        struct virtio_vsock_pkt *pkt;
        u32 pkt_len = info->pkt_len;

        src_cid = vm_sockets_get_local_cid();
        src_port = vsk->local_addr.svm_port;
        if (!info->remote_cid) {
                dst_cid = vsk->remote_addr.svm_cid;
                dst_port = vsk->remote_addr.svm_port;
        } else {
                dst_cid = info->remote_cid;
                dst_port = info->remote_port;
        }

        vvs = vsk->trans;

        /* we can send less than pkt_len bytes */
        if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
                pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;

        /* virtio_transport_get_credit might return less than pkt_len credit */
        pkt_len = virtio_transport_get_credit(vvs, pkt_len);

        /* Do not send zero length OP_RW pkt */
        if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
                return pkt_len;

        pkt = virtio_transport_alloc_pkt(info, pkt_len,
                                         src_cid, src_port,
                                         dst_cid, dst_port);
        if (!pkt) {
                virtio_transport_put_credit(vvs, pkt_len);
                return -ENOMEM;
        }

        virtio_transport_inc_tx_pkt(vvs, pkt);

        return virtio_transport_get_ops()->send_pkt(pkt);
}
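
/* Receive-side accounting for the credit scheme: rx_bytes tracks payload
 * bytes queued on the socket but not yet copied to userspace, while fwd_cnt
 * counts bytes already consumed and therefore free to be advertised back to
 * the peer.  Both counters are updated with rx_lock held by the callers.
 */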
static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
                                        struct virtio_vsock_pkt *pkt)
{
        vvs->rx_bytes += pkt->len;
}

static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
                                        struct virtio_vsock_pkt *pkt)
{
        vvs->rx_bytes -= pkt->len;
        vvs->fwd_cnt += pkt->len;
}

void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
{
        spin_lock_bh(&vvs->tx_lock);
        pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
        pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
        spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
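
/* Reserve up to @credit bytes of transmit credit.  The peer's free space is
 *
 *	peer_buf_alloc - (tx_cnt - peer_fwd_cnt)
 *
 * i.e. the buffer space it advertised minus the bytes we have sent that it
 * has not yet consumed.  virtio_transport_put_credit() undoes a reservation,
 * for example when packet allocation fails.
 */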
u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
        u32 ret;

        spin_lock_bh(&vvs->tx_lock);
        ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
        if (ret > credit)
                ret = credit;
        vvs->tx_cnt += ret;
        spin_unlock_bh(&vvs->tx_lock);

        return ret;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_credit);

void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
        spin_lock_bh(&vvs->tx_lock);
        vvs->tx_cnt -= credit;
        spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_put_credit);

static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
                                               int type,
                                               struct virtio_vsock_hdr *hdr)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
                .type = type,
                .vsk = vsk,
        };

        return virtio_transport_send_pkt_info(vsk, &info);
}
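
/* Copy up to @len bytes from the rx_queue into @msg.  rx_lock is dropped
 * around memcpy_to_msg() because copying to userspace may sleep; that is
 * safe because the caller holds the socket lock, so nobody else can dequeue
 * concurrently.  Fully consumed packets are freed, and a CREDIT_UPDATE is
 * sent afterwards so the peer learns about the buffer space just released.
 */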
static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
                                   struct msghdr *msg,
                                   size_t len)
{
        struct virtio_vsock_sock *vvs = vsk->trans;
        struct virtio_vsock_pkt *pkt;
        size_t bytes, total = 0;
        int err = -EFAULT;

        spin_lock_bh(&vvs->rx_lock);
        while (total < len && !list_empty(&vvs->rx_queue)) {
                pkt = list_first_entry(&vvs->rx_queue,
                                       struct virtio_vsock_pkt, list);

                bytes = len - total;
                if (bytes > pkt->len - pkt->off)
                        bytes = pkt->len - pkt->off;

                /* sk_lock is held by caller so no one else can dequeue.
                 * Unlock rx_lock since memcpy_to_msg() may sleep.
                 */
                spin_unlock_bh(&vvs->rx_lock);

                err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
                if (err)
                        goto out;

                spin_lock_bh(&vvs->rx_lock);

                total += bytes;
                pkt->off += bytes;
                if (pkt->off == pkt->len) {
                        virtio_transport_dec_rx_pkt(vvs, pkt);
                        list_del(&pkt->list);
                        virtio_transport_free_pkt(pkt);
                }
        }
        spin_unlock_bh(&vvs->rx_lock);

        /* Send a credit pkt to peer */
        virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
                                            NULL);

        return total;

out:
        if (total)
                err = total;
        return err;
}

ssize_t
virtio_transport_stream_dequeue(struct vsock_sock *vsk,
                                struct msghdr *msg,
                                size_t len, int flags)
{
        if (flags & MSG_PEEK)
                return -EOPNOTSUPP;

        return virtio_transport_stream_do_dequeue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);

int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
                               struct msghdr *msg,
                               size_t len, int flags)
{
        return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);

s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;
        s64 bytes;

        spin_lock_bh(&vvs->rx_lock);
        bytes = vvs->rx_bytes;
        spin_unlock_bh(&vvs->rx_lock);

        return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);

static s64 virtio_transport_has_space(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;
        s64 bytes;

        bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
        if (bytes < 0)
                bytes = 0;

        return bytes;
}

s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;
        s64 bytes;

        spin_lock_bh(&vvs->tx_lock);
        bytes = virtio_transport_has_space(vsk);
        spin_unlock_bh(&vvs->tx_lock);

        return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
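
/* Per-socket transport state.  A child socket created for an incoming
 * connection request inherits the buffer size limits and the peer's
 * advertised buffer space from its parent (listener) socket.
 */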
int virtio_transport_do_socket_init(struct vsock_sock *vsk,
                                    struct vsock_sock *psk)
{
        struct virtio_vsock_sock *vvs;

        vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
        if (!vvs)
                return -ENOMEM;

        vsk->trans = vvs;
        vvs->vsk = vsk;
        if (psk) {
                struct virtio_vsock_sock *ptrans = psk->trans;

                vvs->buf_size = ptrans->buf_size;
                vvs->buf_size_min = ptrans->buf_size_min;
                vvs->buf_size_max = ptrans->buf_size_max;
                vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
        } else {
                vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
                vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
                vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
        }

        vvs->buf_alloc = vvs->buf_size;

        spin_lock_init(&vvs->rx_lock);
        spin_lock_init(&vvs->tx_lock);
        INIT_LIST_HEAD(&vvs->rx_queue);

        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);

u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);

u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        return vvs->buf_size_min;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);

u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        return vvs->buf_size_max;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
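
/* Buffer size setters (wired up to the SO_VM_SOCKETS_BUFFER_* socket options
 * by the af_vsock core).  Values are clamped to VIRTIO_VSOCK_MAX_BUF_SIZE,
 * and the min <= size <= max ordering is kept by nudging the neighbouring
 * fields when a new value would violate it.  buf_alloc, the space advertised
 * to the peer, follows buf_size.
 */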
void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
                val = VIRTIO_VSOCK_MAX_BUF_SIZE;
        if (val < vvs->buf_size_min)
                vvs->buf_size_min = val;
        if (val > vvs->buf_size_max)
                vvs->buf_size_max = val;
        vvs->buf_size = val;
        vvs->buf_alloc = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);

void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
                val = VIRTIO_VSOCK_MAX_BUF_SIZE;
        if (val > vvs->buf_size)
                vvs->buf_size = val;
        vvs->buf_size_min = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);

void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
                val = VIRTIO_VSOCK_MAX_BUF_SIZE;
        if (val < vvs->buf_size)
                vvs->buf_size = val;
        vvs->buf_size_max = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
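
/* vsock core notification hooks.  Only the poll callbacks report anything
 * useful; the recv/send notify hooks below are intentionally empty because
 * this transport keeps no per-operation state, but the af_vsock core still
 * expects them to be implemented.
 */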
int
virtio_transport_notify_poll_in(struct vsock_sock *vsk,
                                size_t target,
                                bool *data_ready_now)
{
        if (vsock_stream_has_data(vsk))
                *data_ready_now = true;
        else
                *data_ready_now = false;

        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);

int
virtio_transport_notify_poll_out(struct vsock_sock *vsk,
                                 size_t target,
                                 bool *space_avail_now)
{
        s64 free_space;

        free_space = vsock_stream_has_space(vsk);
        if (free_space > 0)
                *space_avail_now = true;
        else if (free_space == 0)
                *space_avail_now = false;

        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);

int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
        size_t target, struct vsock_transport_recv_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);

int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
        size_t target, struct vsock_transport_recv_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);

int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
        size_t target, struct vsock_transport_recv_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);

int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
        size_t target, ssize_t copied, bool data_read,
        struct vsock_transport_recv_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);

int virtio_transport_notify_send_init(struct vsock_sock *vsk,
        struct vsock_transport_send_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);

int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
        struct vsock_transport_send_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);

int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
        struct vsock_transport_send_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);

int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
        ssize_t written, struct vsock_transport_send_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);

u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);

bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
{
        return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);

bool virtio_transport_stream_allow(u32 cid, u32 port)
{
        return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);

int virtio_transport_dgram_bind(struct vsock_sock *vsk,
                                struct sockaddr_vm *addr)
{
        return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);

bool virtio_transport_dgram_allow(u32 cid, u32 port)
{
        return false;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);

int virtio_transport_connect(struct vsock_sock *vsk)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_REQUEST,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
                .vsk = vsk,
        };

        return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_connect);

int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_SHUTDOWN,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
                .flags = (mode & RCV_SHUTDOWN ?
                          VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
                         (mode & SEND_SHUTDOWN ?
                          VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
                .vsk = vsk,
        };

        return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_shutdown);

int
virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
                               struct sockaddr_vm *remote_addr,
                               struct msghdr *msg,
                               size_t dgram_len)
{
        return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);

ssize_t
virtio_transport_stream_enqueue(struct vsock_sock *vsk,
                                struct msghdr *msg,
                                size_t len)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_RW,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
                .msg = msg,
                .pkt_len = len,
                .vsk = vsk,
        };

        return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);

void virtio_transport_destruct(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        kfree(vvs);
}
EXPORT_SYMBOL_GPL(virtio_transport_destruct);
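
/* Send a RST in response to @pkt, or spontaneously when @pkt is NULL.
 * Nothing is sent when the packet being answered is itself a RST, which
 * avoids RST ping-pong between the two ends.
 */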
static int virtio_transport_reset(struct vsock_sock *vsk,
                                  struct virtio_vsock_pkt *pkt)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_RST,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
                .reply = !!pkt,
                .vsk = vsk,
        };

        /* Send RST only if the original pkt is not a RST pkt */
        if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
                return 0;

        return virtio_transport_send_pkt_info(vsk, &info);
}

/* Normally packets are associated with a socket. There may be no socket if an
 * attempt was made to connect to a socket that does not exist.
 */
static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_RST,
                .type = le16_to_cpu(pkt->hdr.type),
                .reply = true,
        };

        /* Send RST only if the original pkt is not a RST pkt */
        if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
                return 0;

        pkt = virtio_transport_alloc_pkt(&info, 0,
                                         le64_to_cpu(pkt->hdr.dst_cid),
                                         le32_to_cpu(pkt->hdr.dst_port),
                                         le64_to_cpu(pkt->hdr.src_cid),
                                         le32_to_cpu(pkt->hdr.src_port));
        if (!pkt)
                return -ENOMEM;

        return virtio_transport_get_ops()->send_pkt(pkt);
}
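
/* Connection tear-down.  virtio_transport_close() tries a graceful shutdown
 * first: it sends SHUTDOWN, optionally lingers until SOCK_DONE, and, if the
 * peer has not answered by the time the socket is released, arms close_work
 * so the connection is reset and the socket removed after
 * VSOCK_CLOSE_TIMEOUT instead of lingering forever.
 */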
static void virtio_transport_wait_close(struct sock *sk, long timeout)
{
        if (timeout) {
                DEFINE_WAIT(wait);

                do {
                        prepare_to_wait(sk_sleep(sk), &wait,
                                        TASK_INTERRUPTIBLE);
                        if (sk_wait_event(sk, &timeout,
                                          sock_flag(sk, SOCK_DONE)))
                                break;
                } while (!signal_pending(current) && timeout);

                finish_wait(sk_sleep(sk), &wait);
        }
}

static void virtio_transport_do_close(struct vsock_sock *vsk,
                                      bool cancel_timeout)
{
        struct sock *sk = sk_vsock(vsk);

        sock_set_flag(sk, SOCK_DONE);
        vsk->peer_shutdown = SHUTDOWN_MASK;
        if (vsock_stream_has_data(vsk) <= 0)
                sk->sk_state = SS_DISCONNECTING;
        sk->sk_state_change(sk);

        if (vsk->close_work_scheduled &&
            (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
                vsk->close_work_scheduled = false;

                vsock_remove_sock(vsk);

                /* Release refcnt obtained when we scheduled the timeout */
                sock_put(sk);
        }
}

static void virtio_transport_close_timeout(struct work_struct *work)
{
        struct vsock_sock *vsk =
                container_of(work, struct vsock_sock, close_work.work);
        struct sock *sk = sk_vsock(vsk);

        sock_hold(sk);
        lock_sock(sk);

        if (!sock_flag(sk, SOCK_DONE)) {
                (void)virtio_transport_reset(vsk, NULL);

                virtio_transport_do_close(vsk, false);
        }

        vsk->close_work_scheduled = false;

        release_sock(sk);
        sock_put(sk);
}

/* User context, vsk->sk is locked */
static bool virtio_transport_close(struct vsock_sock *vsk)
{
        struct sock *sk = &vsk->sk;

        if (!(sk->sk_state == SS_CONNECTED ||
              sk->sk_state == SS_DISCONNECTING))
                return true;

        /* Already received SHUTDOWN from peer, reply with RST */
        if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
                (void)virtio_transport_reset(vsk, NULL);
                return true;
        }

        if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
                (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);

        if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
                virtio_transport_wait_close(sk, sk->sk_lingertime);

        if (sock_flag(sk, SOCK_DONE)) {
                return true;
        }

        sock_hold(sk);
        INIT_DELAYED_WORK(&vsk->close_work,
                          virtio_transport_close_timeout);
        vsk->close_work_scheduled = true;
        schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
        return false;
}

void virtio_transport_release(struct vsock_sock *vsk)
{
        struct sock *sk = &vsk->sk;
        bool remove_sock = true;

        lock_sock(sk);
        if (sk->sk_type == SOCK_STREAM)
                remove_sock = virtio_transport_close(vsk);
        release_sock(sk);

        if (remove_sock)
                vsock_remove_sock(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_release);
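
/* Receive handlers, one per socket state.  virtio_transport_recv_pkt()
 * below dispatches each incoming packet to the handler that matches
 * sk->sk_state.
 */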
static int
virtio_transport_recv_connecting(struct sock *sk,
                                 struct virtio_vsock_pkt *pkt)
{
        struct vsock_sock *vsk = vsock_sk(sk);
        int err;
        int skerr;

        switch (le16_to_cpu(pkt->hdr.op)) {
        case VIRTIO_VSOCK_OP_RESPONSE:
                sk->sk_state = SS_CONNECTED;
                sk->sk_socket->state = SS_CONNECTED;
                vsock_insert_connected(vsk);
                sk->sk_state_change(sk);
                break;
        case VIRTIO_VSOCK_OP_INVALID:
                break;
        case VIRTIO_VSOCK_OP_RST:
                skerr = ECONNRESET;
                err = 0;
                goto destroy;
        default:
                skerr = EPROTO;
                err = -EINVAL;
                goto destroy;
        }
        return 0;

destroy:
        virtio_transport_reset(vsk, pkt);
        sk->sk_state = SS_UNCONNECTED;
        sk->sk_err = skerr;
        sk->sk_error_report(sk);
        return err;
}

static int
virtio_transport_recv_connected(struct sock *sk,
                                struct virtio_vsock_pkt *pkt)
{
        struct vsock_sock *vsk = vsock_sk(sk);
        struct virtio_vsock_sock *vvs = vsk->trans;
        int err = 0;

        switch (le16_to_cpu(pkt->hdr.op)) {
        case VIRTIO_VSOCK_OP_RW:
                pkt->len = le32_to_cpu(pkt->hdr.len);
                pkt->off = 0;

                spin_lock_bh(&vvs->rx_lock);
                virtio_transport_inc_rx_pkt(vvs, pkt);
                list_add_tail(&pkt->list, &vvs->rx_queue);
                spin_unlock_bh(&vvs->rx_lock);

                sk->sk_data_ready(sk);
                return err;
        case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
                sk->sk_write_space(sk);
                break;
        case VIRTIO_VSOCK_OP_SHUTDOWN:
                if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
                        vsk->peer_shutdown |= RCV_SHUTDOWN;
                if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
                        vsk->peer_shutdown |= SEND_SHUTDOWN;
                if (vsk->peer_shutdown == SHUTDOWN_MASK &&
                    vsock_stream_has_data(vsk) <= 0)
                        sk->sk_state = SS_DISCONNECTING;
                if (le32_to_cpu(pkt->hdr.flags))
                        sk->sk_state_change(sk);
                break;
        case VIRTIO_VSOCK_OP_RST:
                virtio_transport_do_close(vsk, true);
                break;
        default:
                err = -EINVAL;
                break;
        }

        virtio_transport_free_pkt(pkt);
        return err;
}

static void
virtio_transport_recv_disconnecting(struct sock *sk,
                                    struct virtio_vsock_pkt *pkt)
{
        struct vsock_sock *vsk = vsock_sk(sk);

        if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
                virtio_transport_do_close(vsk, true);
}

static int
virtio_transport_send_response(struct vsock_sock *vsk,
                               struct virtio_vsock_pkt *pkt)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_RESPONSE,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
                .remote_cid = le64_to_cpu(pkt->hdr.src_cid),
                .remote_port = le32_to_cpu(pkt->hdr.src_port),
                .reply = true,
                .vsk = vsk,
        };

        return virtio_transport_send_pkt_info(vsk, &info);
}

/* Handle server socket */
static int
virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
{
        struct vsock_sock *vsk = vsock_sk(sk);
        struct vsock_sock *vchild;
        struct sock *child;

        if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
                virtio_transport_reset(vsk, pkt);
                return -EINVAL;
        }

        if (sk_acceptq_is_full(sk)) {
                virtio_transport_reset(vsk, pkt);
                return -ENOMEM;
        }

        child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
                               sk->sk_type, 0);
        if (!child) {
                virtio_transport_reset(vsk, pkt);
                return -ENOMEM;
        }

        sk->sk_ack_backlog++;

        lock_sock_nested(child, SINGLE_DEPTH_NESTING);

        child->sk_state = SS_CONNECTED;

        vchild = vsock_sk(child);
        vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
                        le32_to_cpu(pkt->hdr.dst_port));
        vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
                        le32_to_cpu(pkt->hdr.src_port));

        vsock_insert_connected(vchild);
        vsock_enqueue_accept(sk, child);
        virtio_transport_send_response(vchild, pkt);

        release_sock(child);

        sk->sk_data_ready(sk);
        return 0;
}
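
/* Every received header carries the peer's current buf_alloc and fwd_cnt,
 * so transmit credit is refreshed on every packet, not only on explicit
 * CREDIT_UPDATE messages.
 */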
static bool virtio_transport_space_update(struct sock *sk,
                                          struct virtio_vsock_pkt *pkt)
{
        struct vsock_sock *vsk = vsock_sk(sk);
        struct virtio_vsock_sock *vvs = vsk->trans;
        bool space_available;

        /* buf_alloc and fwd_cnt are always included in the hdr */
        spin_lock_bh(&vvs->tx_lock);
        vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
        vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
        space_available = virtio_transport_has_space(vsk);
        spin_unlock_bh(&vvs->tx_lock);
        return space_available;
}

/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
 * lock.
 */
void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
{
        struct sockaddr_vm src, dst;
        struct vsock_sock *vsk;
        struct sock *sk;
        bool space_available;

        vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
                        le32_to_cpu(pkt->hdr.src_port));
        vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
                        le32_to_cpu(pkt->hdr.dst_port));

        trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
                                        dst.svm_cid, dst.svm_port,
                                        le32_to_cpu(pkt->hdr.len),
                                        le16_to_cpu(pkt->hdr.type),
                                        le16_to_cpu(pkt->hdr.op),
                                        le32_to_cpu(pkt->hdr.flags),
                                        le32_to_cpu(pkt->hdr.buf_alloc),
                                        le32_to_cpu(pkt->hdr.fwd_cnt));

        if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
                (void)virtio_transport_reset_no_sock(pkt);
                goto free_pkt;
        }

        /* The socket must be in the connected or bound table,
         * otherwise send a reset back.
         */
        sk = vsock_find_connected_socket(&src, &dst);
        if (!sk) {
                sk = vsock_find_bound_socket(&dst);
                if (!sk) {
                        (void)virtio_transport_reset_no_sock(pkt);
                        goto free_pkt;
                }
        }

        vsk = vsock_sk(sk);

        space_available = virtio_transport_space_update(sk, pkt);

        lock_sock(sk);

        /* Update CID in case it has changed after a transport reset event */
        vsk->local_addr.svm_cid = dst.svm_cid;

        if (space_available)
                sk->sk_write_space(sk);

        switch (sk->sk_state) {
        case VSOCK_SS_LISTEN:
                virtio_transport_recv_listen(sk, pkt);
                virtio_transport_free_pkt(pkt);
                break;
        case SS_CONNECTING:
                virtio_transport_recv_connecting(sk, pkt);
                virtio_transport_free_pkt(pkt);
                break;
        case SS_CONNECTED:
                virtio_transport_recv_connected(sk, pkt);
                break;
        case SS_DISCONNECTING:
                virtio_transport_recv_disconnecting(sk, pkt);
                virtio_transport_free_pkt(pkt);
                break;
        default:
                virtio_transport_free_pkt(pkt);
                break;
        }
        release_sock(sk);

        /* Release refcnt obtained when we fetched this socket out of the
         * bound or connected list.
         */
        sock_put(sk);
        return;

free_pkt:
        virtio_transport_free_pkt(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);

void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
{
        kfree(pkt->buf);
        kfree(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("common code for virtio vsock");