#include "net.h"
#include "log.h"

#include <errno.h>
#include <netdb.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

/* Backlog passed to listen(); on Linux the kernel silently caps it at somaxconn. */
enum { NET_LISTEN_BACKLOG = 4096 };

#define gai_print_errno(ec) print_error("getaddrinfo: %s\n", gai_strerror(ec))

/*
 * Create a listening stream socket bound to @port on the wildcard address
 * (AI_PASSIVE with a NULL node; AF_UNSPEC, so IPv4 or IPv6 depending on the
 * order getaddrinfo() returns candidates).
 *
 * Candidates are tried in order; the first one that can be created, given
 * SO_REUSEADDR and bound is used. The socket is close-on-exec.
 *
 * Returns the listening fd on success, -1 on failure (errors are logged).
 */
int net_bind(const char *port)
{
	struct addrinfo hints = {
		.ai_family = AF_UNSPEC,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE,
	};
	struct addrinfo *result = NULL, *it = NULL;
	int socket_fd = -1;
	int found;
	int ret;

	ret = getaddrinfo(NULL, port, &hints, &result);
	if (ret) {
		gai_print_errno(ret);
		return -1;
	}

	for (it = result; it; it = it->ai_next) {
		static const int yes = 1;

		socket_fd = socket(it->ai_family, it->ai_socktype | SOCK_CLOEXEC, it->ai_protocol);
		if (socket_fd < 0) {
			print_errno("socket");
			continue;
		}
		if (setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes)) == -1) {
			print_errno("setsockopt");
			close(socket_fd);
			continue;
		}
		if (bind(socket_fd, it->ai_addr, it->ai_addrlen) < 0) {
			print_errno("bind");
			close(socket_fd);
			continue;
		}
		break;
	}
	/* Record success before freeaddrinfo(): afterwards `it` is indeterminate. */
	found = it != NULL;
	freeaddrinfo(result);

	if (!found) {
		print_error("Couldn't bind to port %s\n", port);
		return -1;
	}

	if (listen(socket_fd, NET_LISTEN_BACKLOG) < 0) {
		print_errno("listen");
		close(socket_fd);
		return -1;
	}
	return socket_fd;
}

/*
 * Accept one pending connection on listening socket @fd. The peer address is
 * discarded and the new fd is close-on-exec.
 *
 * Returns the connected fd, or a negative value on failure (logged).
 */
int net_accept(int fd)
{
	int conn_fd = accept4(fd, NULL, NULL, SOCK_CLOEXEC);

	if (conn_fd < 0)
		print_errno("accept");
	return conn_fd;
}

/*
 * Open a close-on-exec stream connection to @host:@port, trying every
 * address getaddrinfo() returns (AF_UNSPEC) until one connects.
 *
 * Returns the connected fd on success, -1 on failure (errors are logged).
 */
int net_connect(const char *host, const char *port)
{
	struct addrinfo hints = {
		.ai_family = AF_UNSPEC,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *result = NULL, *it = NULL;
	int socket_fd = -1;
	int found;
	int ret;

	ret = getaddrinfo(host, port, &hints, &result);
	if (ret) {
		gai_print_errno(ret);
		return -1;
	}

	for (it = result; it; it = it->ai_next) {
		socket_fd = socket(it->ai_family, it->ai_socktype | SOCK_CLOEXEC, it->ai_protocol);
		if (socket_fd < 0) {
			print_errno("socket");
			continue;
		}
		if (connect(socket_fd, it->ai_addr, it->ai_addrlen) < 0) {
			print_errno("connect");
			close(socket_fd);
			continue;
		}
		break;
	}
	/* Record success before freeaddrinfo(): afterwards `it` is indeterminate. */
	found = it != NULL;
	freeaddrinfo(result);

	if (!found) {
		/* getaddrinfo() accepts a NULL host; never hand NULL to %s. */
		print_error("Couldn't connect to host %s, port %s\n",
			    host ? host : "(null)", port);
		return -1;
	}
	return socket_fd;
}

/*
 * Write exactly @len bytes from @buf to @fd, looping over short writes and
 * retrying when interrupted by a signal (EINTR).
 *
 * Returns 0 on success, -1 on failure (logged).
 */
int net_send_all(int fd, const void *buf, size_t len)
{
	const char *p = buf;
	size_t sent_total = 0;

	while (sent_total < len) {
		ssize_t sent_now = write(fd, p + sent_total, len - sent_total);

		if (sent_now < 0) {
			if (errno == EINTR)
				continue;
			print_errno("write");
			return -1;
		}
		sent_total += (size_t)sent_now;
	}
	return 0;
}

/*
 * Read up to @len bytes from @fd into @buf, looping over short reads until
 * @len bytes have arrived or the peer reaches EOF. Interrupted reads (EINTR)
 * are retried.
 *
 * Returns the number of bytes read (less than @len only on EOF), or -1 on
 * error (logged).
 */
ssize_t net_recv_all(int fd, void *buf, size_t len)
{
	char *p = buf;
	size_t read_total = 0;

	while (read_total < len) {
		/* Advance into the buffer and only ask for what is still missing. */
		ssize_t read_now = read(fd, p + read_total, len - read_total);

		if (read_now < 0) {
			if (errno == EINTR)
				continue;
			print_errno("read");
			return -1;
		}
		if (read_now == 0)
			break; /* EOF before @len bytes arrived */
		read_total += (size_t)read_now;
	}
	return (ssize_t)read_total;
}

/*
 * Send a length-prefixed message: first @len as a raw size_t, then @len
 * bytes of @buf.
 *
 * NOTE: the prefix is a native-endian, native-width size_t, so both peers must
 * share the same size_t layout. Kept as-is for wire compatibility.
 *
 * Returns 0 on success, -1 on failure (logged).
 */
int net_send_buf(int fd, const void *buf, size_t len)
{
	int ret = net_send_all(fd, &len, sizeof(len));

	if (ret < 0)
		return ret;
	return net_send_all(fd, buf, len);
}

/*
 * Receive one message written by net_send_buf(). On success *@buf points to a
 * freshly malloc()ed buffer of *@len bytes that the caller must free().
 * On failure *@buf is NULL and nothing needs freeing.
 *
 * NOTE(review): *@len comes straight from the peer and is passed to malloc()
 * unchecked; consider capping it if the peer is untrusted.
 *
 * Returns 0 on success, -1 on failure (logged).
 */
int net_recv_buf(int fd, void **buf, size_t *len)
{
	ssize_t nb;

	*buf = NULL;

	nb = net_recv_all(fd, len, sizeof(*len));
	if (nb < 0)
		return -1;
	if ((size_t)nb != sizeof(*len)) {
		print_error("Couldn't read buffer length\n");
		return -1;
	}

	/*
	 * malloc(0) may legitimately return NULL; allocate at least one byte so
	 * an empty message is not mistaken for an allocation failure.
	 */
	*buf = malloc(*len ? *len : 1);
	if (!*buf) {
		print_errno("malloc");
		return -1;
	}

	nb = net_recv_all(fd, *buf, *len);
	if (nb < 0)
		goto free_buf;
	if ((size_t)nb != *len) {
		print_error("Couldn't read the entire buffer\n");
		goto free_buf;
	}
	return 0;

free_buf:
	free(*buf);
	*buf = NULL;
	return -1;
}

/*
 * Receive one net_send_buf() message whose length must be exactly @len, and
 * copy it into the caller-provided @buf.
 *
 * Returns 0 on success, -1 on failure or length mismatch (logged).
 */
int net_recv_static(int fd, void *buf, size_t len)
{
	void *actual_buf = NULL;
	size_t actual_len = 0;
	int ret;

	ret = net_recv_buf(fd, &actual_buf, &actual_len);
	if (ret < 0)
		return ret;

	if (actual_len != len) {
		print_error("Expected message length: %zu, actual: %zu\n", len, actual_len);
		ret = -1;
		goto free_buf;
	}
	memcpy(buf, actual_buf, len);

free_buf:
	free(actual_buf);
	return ret;
}