// SPDX-License-Identifier: GPL-2.0-only
/*
* vhost transport for vsock
*
* Copyright (C) 2013-2015 Red Hat, Inc.
* Author: Asias He <asias@redhat.com>
* Stefan Hajnoczi <stefanha@redhat.com>
*/
#include <linux/miscdevice.h>
#include <linux/atomic.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/vmalloc.h>
#include <net/sock.h>
#include <linux/virtio_vsock.h>
#include <linux/vhost.h>
#include <linux/hashtable.h>
#include <net/af_vsock.h>
#include "vhost.h"
/* CID reported for the host side by vhost_transport_get_local_cid(). */
#define VHOST_VSOCK_DEFAULT_HOST_CID 2

/*
 * Byte budget for one run of a virtqueue job: once this many bytes have
 * been transferred, the job is requeued so that a single virtqueue cannot
 * starve the others.
 */
#define VHOST_VSOCK_WEIGHT (512 * 1024)

/*
 * Packet budget for one run of a virtqueue job. The byte budget alone does
 * not bound the work when every packet is tiny, so this caps the packet
 * count as well, again so that one virtqueue cannot starve the others.
 */
#define VHOST_VSOCK_PKT_WEIGHT 256
/*
 * Device feature bits offered by this vhost device. The entries are plain
 * ints that VHOST_FEATURES_U64() below folds into VHOST_VSOCK_FEATURES.
 * NOTE(review): presumably these are bit numbers rather than masks, and
 * the second VHOST_FEATURES_U64() argument selects the first 64-bit
 * feature word; confirm against vhost.h.
 */
static const int vhost_vsock_bits[] = {
	VHOST_FEATURES,
	VIRTIO_F_ACCESS_PLATFORM,
	VIRTIO_VSOCK_F_SEQPACKET,
};

#define VHOST_VSOCK_FEATURES VHOST_FEATURES_U64(vhost_vsock_bits, 0)

/* Backend feature mask: only IOTLB message format v2 is advertised. */
enum {
	VHOST_VSOCK_BACKEND_FEATURES = (1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2)
};
/*
 * Registry of every vhost_vsock instance on the system.
 *
 * vhost_vsock_hash has 2^8 buckets and is keyed by guest CID (see
 * vhost_vsock_get()). Writers modify it while holding vhost_vsock_mutex;
 * readers may instead walk it inside an RCU read-side critical section.
 */
static DEFINE_MUTEX(vhost_vsock_mutex);
static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8);
/* Per-device state for one vhost-vsock instance. */
struct vhost_vsock {
	/* Embedded vhost device; passed to vhost_{enable,disable}_notify() */
	struct vhost_dev dev;
	/*
	 * Virtqueues, indexed by the VSOCK_VQ_* constants (the send path uses
	 * vqs[VSOCK_VQ_TX]). NOTE(review): the other index's role is not
	 * visible in this chunk.
	 */
	struct vhost_virtqueue vqs[2];
	/* Link to global vhost_vsock_hash, writes use vhost_vsock_mutex */
	struct hlist_node hash;
	/*
	 * NOTE(review): presumably the work item that runs the host->guest
	 * send path; where it is queued is not visible here — confirm.
	 */
	struct vhost_work send_pkt_work;
	struct sk_buff_head send_pkt_queue; /* host->guest pending packets */
	/*
	 * NOTE(review): by its name, a count of queued reply packets; its
	 * users are outside this chunk — confirm semantics.
	 */
	atomic_t queued_replies;
	/* Guest context ID; 0 means no CID has been assigned yet */
	u32 guest_cid;
	/*
	 * NOTE(review): presumably set when VIRTIO_VSOCK_F_SEQPACKET is
	 * negotiated with the guest — confirm.
	 */
	bool seqpacket_allow;
};
/*
 * Report the local (host-side) CID of this transport. The value is fixed:
 * it is always VHOST_VSOCK_DEFAULT_HOST_CID.
 */
static u32 vhost_transport_get_local_cid(void)
{
	const u32 host_cid = VHOST_VSOCK_DEFAULT_HOST_CID;

	return host_cid;
}
/**
 * vhost_vsock_get - find the vhost_vsock instance that owns a guest CID
 * @guest_cid: guest context ID to look up
 *
 * Instances that have not been assigned a CID yet (guest_cid == 0) are
 * never returned.
 *
 * Context: the caller must be inside an RCU read-side critical section or
 * hold vhost_vsock_mutex. The returned pointer may only be dereferenced
 * while still within that section.
 *
 * Return: the matching instance, or NULL if there is none.
 */
static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
{
	struct vhost_vsock *vsock;

	hash_for_each_possible_rcu(vhost_vsock_hash, vsock, hash, guest_cid,
				   lockdep_is_held(&vhost_vsock_mutex)) {
		/* Read the CID once so both checks see the same value */
		const u32 cid = vsock->guest_cid;

		if (cid != 0 && cid == guest_cid)
			return vsock;
	}

	return NULL;
}
static void
vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
struct vhost_virtqueue *vq)
{
struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
int pkts = 0, total_len = 0;
bool added = false;
bool restart_tx = false;
mutex_lock(&vq->mutex);
if (!vhost_vq_get_backend(vq))
goto out;
if (!vq_meta_prefetch(vq))
goto out;
/* Avoid further vmexits, we're already processing the virtqueue */
vhost_disable_notify(&vsock->dev, vq);
do {
struct virtio_vsock_hdr *hdr;
size_t iov_len, payload_len;
struct iov_iter iov_iter;
u32 flags_to_restore = 0;
struct sk_buff *skb;
unsigned out, in;
size_t nbytes;
u32 offset;
int head;
skb = virtio_vsock_skb_dequeue(&vsock->send_pkt_queue);
if (!skb) {
vhost_enable_notify(&vsock->dev, vq);
break;
}
head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
&out, &in, NULL, NULL);
if (head < 0) {
virtio_vsock_skb_queue_head(&vsock->send_pkt_queue, skb);
break;
}
if (head == vq->num) {
virtio_vsock_skb_queue_head(&vsock->send_pkt_queue, skb);
/* We cannot finish yet if more buffers snuck in while
* re-enabling notify.
*/
if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
vhost_disable_notify(&vsock->dev, vq);
continue;
}
break;
}
if (out) {
kfree_skb(skb);
vq_err(vq, "Expected 0 output buffers, got %u\n", out);
break;
}
iov_len = iov_length(&vq->iov[out], in);
if (iov_len < sizeof(*hdr)) {
kfree_skb(skb);
vq_err(vq, "Buffer len [%zu] too small\n", iov_len);
break;
}
iov_iter_init(&iov_iter, ITER_DEST, &vq->iov[out], in, iov_len);
offset = VIRTIO_VSOCK_SKB_CB(skb)->offset;
payload_len = skb->len - offset;
hdr = virtio_vsock_hdr(skb);
/* If the packet is greater than the space available in the
* buffer, we split it using multiple buffers.
*/
if (payload_len > iov_len - sizeof(*hdr)) {
payload_len = iov_len - sizeof(*hdr);
/* As we are copying pieces of large packet's buffer