From fc10c88395f3fefd3b30f3ef9354c45cef7136a6 Mon Sep 17 00:00:00 2001 From: Nikias Bassen Date: Wed, 6 Mar 2024 03:15:47 +0100 Subject: socket: Make sure errno is always set on error, and always return a meaningful error code --- src/socket.c | 102 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 97 insertions(+), 5 deletions(-) diff --git a/src/socket.c b/src/socket.c index 6f85f2f..5276b1e 100644 --- a/src/socket.c +++ b/src/socket.c @@ -150,6 +150,51 @@ enum poll_status poll_status_error }; +#ifdef WIN32 +static inline __attribute__((always_inline)) int WSAError_to_errno(int wsaerr) +{ + switch (wsaerr) { + case WSAEINVAL: + return EINVAL; + case WSAENOTSOCK: + return ENOTSOCK; + case WSAENOTCONN: + return ENOTCONN; + case WSAESHUTDOWN: + return ENOTCONN; + case WSAECONNRESET: + return ECONNRESET; + case WSAECONNABORTED: + return ECONNABORTED; + case WSAECONNREFUSED: + return ECONNREFUSED; + case WSAENETDOWN: + return ENETDOWN; + case WSAENETRESET: + return ENETRESET; + case WSAEHOSTUNREACH: + return EHOSTUNREACH; + case WSAETIMEDOUT: + return ETIMEDOUT; + case WSAEWOULDBLOCK: + return EWOULDBLOCK; + case WSAEINPROGRESS: + return EINPROGRESS; + case WSAENOBUFS: + return ENOBUFS; + case WSAEINTR: + return EINTR; + case WSAEACCES: + return EACCES; + case WSAEFAULT: + return EFAULT; + default: + break; + } + return wsaerr; +} +#endif + // timeout of -1 means infinity static inline __attribute__((always_inline)) enum poll_status poll_wrapper(int fd, fd_mode mode, int timeout) { @@ -387,8 +432,17 @@ int socket_connect_unix(const char *filename) socklen_t len = sizeof(so_error); getsockopt(sfd, SOL_SOCKET, SO_ERROR, (void*)&so_error, &len); if (so_error == 0) { + errno = 0; break; } + errno = so_error; + } else { + int so_error = 0; + socklen_t len = sizeof(so_error); + getsockopt(sfd, SOL_SOCKET, SO_ERROR, (void*)&so_error, &len); + if (so_error != 0) { + errno = so_error; + } } } socket_close(sfd); @@ -1064,7 +1118,20 @@ int socket_connect_addr(struct sockaddr* addr, uint16_t port) errno = 0; break; } +#ifdef WIN32 + so_error = WSAError_to_errno(so_error); +#endif errno = so_error; + } else { + int so_error = 0; + socklen_t len = sizeof(so_error); + getsockopt(sfd, SOL_SOCKET, SO_ERROR, (void*)&so_error, &len); + if (so_error != 0) { +#ifdef WIN32 + so_error = WSAError_to_errno(so_error); +#endif + errno = so_error; + } } } socket_close(sfd); @@ -1173,8 +1240,23 @@ int socket_connect(const char *addr, uint16_t port) socklen_t len = sizeof(so_error); getsockopt(sfd, SOL_SOCKET, SO_ERROR, (void*)&so_error, &len); if (so_error == 0) { + errno = 0; break; } +#ifdef WIN32 + so_error = WSAError_to_errno(so_error); +#endif + errno = so_error; + } else { + int so_error = 0; + socklen_t len = sizeof(so_error); + getsockopt(sfd, SOL_SOCKET, SO_ERROR, (void*)&so_error, &len); + if (so_error != 0) { +#ifdef WIN32 + so_error = WSAError_to_errno(so_error); +#endif + errno = so_error; + } } } socket_close(sfd); @@ -1208,7 +1290,7 @@ int socket_check_fd(int fd, fd_mode fdm, unsigned int timeout) if (fd < 0) { if (verbose >= 2) fprintf(stderr, "ERROR: invalid fd in check_fd %d\n", fd); - return -1; + return -EINVAL; } int timeout_ms; @@ -1229,10 +1311,10 @@ int socket_check_fd(int fd, fd_mode fdm, unsigned int timeout) default: if (verbose >= 2) fprintf(stderr, "%s: poll_wrapper failed\n", __func__); - return -1; + return -ECONNRESET; } - return -1; + return -ECONNRESET; } int socket_accept(int fd, uint16_t port) @@ -1286,13 +1368,16 @@ int socket_receive_timeout(int fd, void *data, size_t length, int flags, unsigne } // if we get here, there _is_ data available result = recv(fd, data, length, flags); - if (res > 0 && result == 0) { + if (result == 0) { // but this is an error condition if (verbose >= 3) fprintf(stderr, "%s: fd=%d recv returned 0\n", __func__, fd); return -ECONNRESET; } if (result < 0) { +#ifdef WIN32 + errno = WSAError_to_errno(WSAGetLastError()); +#endif return -errno; } return result; @@ -1308,7 +1393,14 @@ int socket_send(int fd, void *data, size_t length) #ifdef MSG_NOSIGNAL flags |= MSG_NOSIGNAL; #endif - return send(fd, data, length, flags); + int s = (int)send(fd, data, length, flags); + if (s < 0) { +#ifdef WIN32 + errno = WSAError_to_errno(WSAGetLastError()); +#endif + return -errno; + } + return s; } int socket_get_socket_port(int fd, uint16_t *port) -- cgit v1.1-32-gdbae