// SPDX-License-Identifier: GPL-2.0-only
/*
* vsock test utilities
*
* Copyright (C) 2017 Red Hat, Inc.
*
* Author: Stefan Hajnoczi <stefanha@redhat.com>
*/
#include <errno.h>
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <signal.h>
#include <unistd.h>
#include <assert.h>
#include <sys/epoll.h>
#include <sys/mman.h>
#include "timeout.h"
#include "control.h"
#include "util.h"
/* Install signal handlers */
void init_signals(void)
{
struct sigaction act = {
.sa_handler = sigalrm,
};
sigaction(SIGALRM, &act, NULL);
signal(SIGPIPE, SIG_IGN);
}
/* Parse a base-10 unsigned integer from <str>.
 *
 * <err_str> names the kind of value being parsed; it is only used in the
 * error message. On malformed input the message is printed to stderr and
 * the process exits with EXIT_FAILURE, so this function only returns on
 * success.
 *
 * Rejected input: conversion errors reported through errno, an empty string
 * (no characters consumed), trailing characters after the number, and values
 * that do not fit in an unsigned int.
 */
static unsigned int parse_uint(const char *str, const char *err_str)
{
	char *endptr = NULL;
	unsigned long n;

	/* strtoul() reports range errors only via errno, so clear it first */
	errno = 0;
	n = strtoul(str, &endptr, 10);

	/*
	 * endptr == str: nothing was converted (e.g. empty string).
	 * n != (unsigned int)n: the value would be truncated on return.
	 */
	if (errno || endptr == str || *endptr != '\0' ||
	    n != (unsigned int)n) {
		fprintf(stderr, "malformed %s \"%s\"\n", err_str, str);
		exit(EXIT_FAILURE);
	}

	return (unsigned int)n;
}
/* Parse a CID in string representation
 *
 * Delegates to parse_uint(), labelling any parse error with "CID".
 */
unsigned int parse_cid(const char *str)
{
	unsigned int cid = parse_uint(str, "CID");

	return cid;
}
/* Parse a port in string representation
 *
 * Delegates to parse_uint(), labelling any parse error with "port".
 */
unsigned int parse_port(const char *str)
{
	unsigned int port = parse_uint(str, "port");

	return port;
}
/* Wait for the remote to close the connection
 *
 * Registers <fd> with a temporary epoll instance for EPOLLRDHUP | EPOLLHUP
 * and waits up to TIMEOUT seconds for one of them. Any failure, including
 * the wait timing out, prints a message and exits with EXIT_FAILURE.
 */
void vsock_wait_remote_close(int fd)
{
	struct epoll_event ev;
	int nfds;
	int efd;

	efd = epoll_create1(0);
	if (efd == -1) {
		perror("epoll_create1");
		exit(EXIT_FAILURE);
	}

	ev.events = EPOLLRDHUP | EPOLLHUP;
	ev.data.fd = fd;
	if (epoll_ctl(efd, EPOLL_CTL_ADD, fd, &ev) == -1) {
		perror("epoll_ctl");
		exit(EXIT_FAILURE);
	}

	/* epoll_wait() takes milliseconds; TIMEOUT is scaled by 1000 */
	nfds = epoll_wait(efd, &ev, 1, TIMEOUT * 1000);
	switch (nfds) {
	case -1:
		perror("epoll_wait");
		exit(EXIT_FAILURE);
	case 0:
		fprintf(stderr, "epoll_wait timed out\n");
		exit(EXIT_FAILURE);
	default:
		break;
	}

	/* Exactly one event, for our fd, reporting a hangup */
	assert(nfds == 1);
	assert(ev.events & (EPOLLRDHUP | EPOLLHUP));
	assert(ev.data.fd == fd);

	close(efd);
}
/* Create socket <type>, bind to <cid, port> and return the file descriptor. */
int vsock_bind(unsigned int cid, unsigned int port, int type)
{
struct sockaddr_vm sa = {
.svm_family = AF_VSOCK,
.svm_cid = cid,
.svm_port = port,
};
int fd;
fd = socket(AF_VSOCK, type, 0);
if (fd < 0) {
perror("socket");
exit(EXIT_FAILURE);
}
if (bind(fd, (struct sockaddr *)&sa, sizeof(sa))) {
perror("bind");
exit(EXIT_FAILURE);
}
return fd;
}
int vsock_connect_fd(int fd, unsigned int cid, unsigned int port)
{
struct sockaddr_vm sa = {
.svm_family = AF_VSOCK,
.svm_cid = cid,
.svm_port = port,
};
int ret;
timeout_begin(TIMEOUT);
do {
ret = connect(fd, (struct sockaddr *)&sa, sizeof(sa));
timeout_check("connect");
} while (ret < 0 && errno == EINTR);
timeout_end();
return ret;
}
/* Bind to <bind_port>, connect to <cid, port> and return the file descriptor.
 *
 * The socket is bound to VMADDR_CID_ANY. If the connection attempt fails,
 * the error is reported with perror() and the process exits with
 * EXIT_FAILURE.
 */
int vsock_bind_connect(unsigned int cid, unsigned int port, unsigned int bind_port, int type)
{
	int sock = vsock_bind(VMADDR_CID_ANY, bind_port, type);

	if (vsock_connect_fd(sock, cid, port) != 0) {
		perror("connect");
		exit(EXIT_FAILURE);
	}

	return sock;
}
/* Connect to <cid, port> and return the file descriptor.
 *
 * Calls control_expectln("LISTENING") before creating the socket (presumably
 * waiting for the peer to announce that it is listening; see control.h to
 * confirm). A socket() failure is fatal. If connecting fails, the socket is
 * closed and -1 is returned with errno preserved from the failed connect.
 */
int vsock_connect(unsigned int cid, unsigned int port, int type)
{
	int saved_errno;
	int sock;

	control_expectln("LISTENING");

	sock = socket(AF_VSOCK, type, 0);
	if (sock < 0) {
		perror("socket");
		exit(EXIT_FAILURE);
	}

	if (vsock_connect_fd(sock, cid, port) == 0)
		return sock;

	/* close() may overwrite errno; keep the connect error for the caller */
	saved_errno = errno;
	close(sock);
	errno = saved_errno;

	return -1;
}
/* Connect a SOCK_STREAM socket to <cid, port>; see vsock_connect(). */
int vsock_stream_connect(unsigned int cid, unsigned int port)
{
	int fd = vsock_connect(cid, port, SOCK_STREAM);

	return fd;
}
/* Connect a SOCK_SEQPACKET socket to <cid, port>; see vsock_connect(). */
int vsock_seqpacket_connect(unsigned int cid, unsigned int port)
{
	int fd = vsock_connect(cid, port, SOCK_SEQPACKET);

	return fd;
}
/* Listen on <cid, port> and return the file descriptor. */
static int vsock_listen(unsigned int cid, unsigned int port, int type)
{
int