src/util/socket.h (view raw)
1/* Copyright (c) 2013-2014 Jeffrey Pfau
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6#ifndef SOCKET_H
7#define SOCKET_H
8
9#include "util/common.h"
10
/* C++ has no `restrict` keyword. When this header is pulled into a C++
 * translation unit, map the name onto the `__restrict__` spelling so that
 * declarations written with `restrict` still parse. */
#if defined(__cplusplus)
#define restrict __restrict__
#endif
14
/* Platform layer: pulls in the native socket headers and provides a common
 * vocabulary for the rest of this header:
 *   Socket           - the native socket handle type
 *   INVALID_SOCKET   - sentinel value for "no socket"
 *   SOCKET_FAILED(s) - true when `s` is not a usable socket handle
 */
#if defined(_WIN32)
#include <winsock2.h>
#include <ws2tcpip.h>

/* Winsock already supplies INVALID_SOCKET; failure is equality with it */
#define SOCKET_FAILED(s) ((s) == INVALID_SOCKET)
typedef SOCKET Socket;
#else
#include <errno.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/select.h>
#include <sys/socket.h>

/* Here sockets are plain int descriptors: define the same sentinel name
 * that Winsock uses and treat every negative value as a failure */
#define INVALID_SOCKET (-1)
#define SOCKET_FAILED(s) ((s) < 0)
typedef int Socket;
#endif
33
/* IP protocol version stored in a struct Address. */
enum IP {
	IPV4 = 0, /* the address is held in the `ipv4` member */
	IPV6 = 1  /* the address is held in the `ipv6` member */
};
38
39struct Address {
40 enum IP version;
41 union {
42 uint32_t ipv4;
43 uint8_t ipv6[16];
44 };
45};
46
/* Prepare the platform socket layer for use.
 *
 * On Windows this calls WSAStartup() requesting Winsock 2.2; the result of
 * that call is not reported to the caller. On every other platform this is
 * a no-op.
 */
static inline void SocketSubsystemInit(void) {
#ifdef _WIN32
	WSADATA data;
	WSAStartup(MAKEWORD(2, 2), &data);
#endif
}
53
/* Fetch the error code left behind by the most recent socket call.
 *
 * Returns WSAGetLastError() on Windows and the current value of errno
 * everywhere else.
 */
static inline int SocketError(void) {
#ifdef _WIN32
	return WSAGetLastError();
#else
	return errno;
#endif
}
61
/* Report whether the most recent socket error means "the operation would
 * have blocked".
 *
 * Returns true when SocketError() is WSAEWOULDBLOCK on Windows, or
 * EWOULDBLOCK / EAGAIN on other platforms.
 */
static inline bool SocketWouldBlock(void) {
	/* Read the error code once so both comparisons see the same value */
	int error = SocketError();
#ifdef _WIN32
	return error == WSAEWOULDBLOCK;
#else
	return error == EWOULDBLOCK || error == EAGAIN;
#endif
}
69
70static inline ssize_t SocketSend(Socket socket, const void* buffer, size_t size) {
71#ifdef _WIN32
72 return send(socket, (const char*) buffer, size, 0);
73#else
74 return write(socket, buffer, size);
75#endif
76}
77
78static inline ssize_t SocketRecv(Socket socket, void* buffer, size_t size) {
79#ifdef _WIN32
80 return recv(socket, (char*) buffer, size, 0);
81#else
82 return read(socket, buffer, size);
83#endif
84}
85
86static inline int SocketClose(Socket socket) {
87#ifdef _WIN32
88 return closesocket(socket) == 0;
89#else
90 return close(socket) >= 0;
91#endif
92}
93
94static inline Socket SocketOpenTCP(int port, const struct Address* bindAddress) {
95 Socket sock = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
96 if (SOCKET_FAILED(sock)) {
97 return sock;
98 }
99
100 int err;
101 if (!bindAddress) {
102 struct sockaddr_in bindInfo;
103 memset(&bindInfo, 0, sizeof(bindInfo));
104 bindInfo.sin_family = AF_INET;
105 bindInfo.sin_port = htons(port);
106 err = bind(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
107 } else if (bindAddress->version == IPV4) {
108 struct sockaddr_in bindInfo;
109 memset(&bindInfo, 0, sizeof(bindInfo));
110 bindInfo.sin_family = AF_INET;
111 bindInfo.sin_port = htons(port);
112 bindInfo.sin_addr.s_addr = bindAddress->ipv4;
113 err = bind(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
114 } else {
115 struct sockaddr_in6 bindInfo;
116 memset(&bindInfo, 0, sizeof(bindInfo));
117 bindInfo.sin6_family = AF_INET6;
118 bindInfo.sin6_port = htons(port);
119 memcpy(bindInfo.sin6_addr.s6_addr, bindAddress->ipv6, sizeof(bindInfo.sin6_addr.s6_addr));
120 err = bind(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
121 }
122 if (err) {
123 SocketClose(sock);
124 return INVALID_SOCKET;
125 }
126 return sock;
127}
128
129static inline Socket SocketConnectTCP(int port, const struct Address* destinationAddress) {
130 Socket sock = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
131 if (SOCKET_FAILED(sock)) {
132 return sock;
133 }
134
135 int err;
136 if (!destinationAddress) {
137 struct sockaddr_in bindInfo;
138 memset(&bindInfo, 0, sizeof(bindInfo));
139 bindInfo.sin_family = AF_INET;
140 bindInfo.sin_port = htons(port);
141 err = connect(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
142 } else if (destinationAddress->version == IPV4) {
143 struct sockaddr_in bindInfo;
144 memset(&bindInfo, 0, sizeof(bindInfo));
145 bindInfo.sin_family = AF_INET;
146 bindInfo.sin_port = htons(port);
147 bindInfo.sin_addr.s_addr = destinationAddress->ipv4;
148 err = connect(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
149 } else {
150 struct sockaddr_in6 bindInfo;
151 memset(&bindInfo, 0, sizeof(bindInfo));
152 bindInfo.sin6_family = AF_INET6;
153 bindInfo.sin6_port = htons(port);
154 memcpy(bindInfo.sin6_addr.s6_addr, destinationAddress->ipv6, sizeof(bindInfo.sin6_addr.s6_addr));
155 err = connect(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
156 }
157
158 if (err) {
159 SocketClose(sock);
160 return INVALID_SOCKET;
161 }
162 return sock;
163}
164
165static inline Socket SocketListen(Socket socket, int queueLength) {
166 return listen(socket, queueLength);
167}
168
169static inline Socket SocketAccept(Socket socket, struct Address* address) {
170 if (!address) {
171 return accept(socket, 0, 0);
172 }
173 if (address->version == IPV4) {
174 struct sockaddr_in addrInfo;
175 memset(&addrInfo, 0, sizeof(addrInfo));
176 addrInfo.sin_family = AF_INET;
177 addrInfo.sin_addr.s_addr = address->ipv4;
178 socklen_t len = sizeof(addrInfo);
179 return accept(socket, (struct sockaddr*) &addrInfo, &len);
180 } else {
181 struct sockaddr_in6 addrInfo;
182 memset(&addrInfo, 0, sizeof(addrInfo));
183 addrInfo.sin6_family = AF_INET6;
184 memcpy(addrInfo.sin6_addr.s6_addr, address->ipv6, sizeof(addrInfo.sin6_addr.s6_addr));
185 socklen_t len = sizeof(addrInfo);
186 return accept(socket, (struct sockaddr*) &addrInfo, &len);
187 }
188}
189
190static inline int SocketSetBlocking(Socket socket, bool blocking) {
191#ifdef _WIN32
192 u_long unblocking = !blocking;
193 return ioctlsocket(socket, FIONBIO, &unblocking) == NO_ERROR;
194#else
195 int flags = fcntl(socket, F_GETFL);
196 if (flags == -1) {
197 return 0;
198 }
199 if (blocking) {
200 flags &= ~O_NONBLOCK;
201 } else {
202 flags |= O_NONBLOCK;
203 }
204 return fcntl(socket, F_SETFL, flags) >= 0;
205#endif
206}
207
208static inline int SocketSetTCPPush(Socket socket, int push) {
209 return setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (char*) &push, sizeof(int)) >= 0;
210}
211
212static inline int SocketPoll(size_t nSockets, Socket* reads, Socket* writes, Socket* errors, int64_t timeoutMillis) {
213 fd_set rset;
214 fd_set wset;
215 fd_set eset;
216 FD_ZERO(&rset);
217 FD_ZERO(&wset);
218 FD_ZERO(&eset);
219 size_t i;
220 Socket maxFd = 0;
221 if (reads) {
222 for (i = 0; i < nSockets; ++i) {
223 if (SOCKET_FAILED(reads[i])) {
224 break;
225 }
226 if (reads[i] > maxFd) {
227 maxFd = reads[i];
228 }
229 FD_SET(reads[i], &rset);
230 reads[i] = INVALID_SOCKET;
231 }
232 }
233 if (writes) {
234 for (i = 0; i < nSockets; ++i) {
235 if (SOCKET_FAILED(writes[i])) {
236 break;
237 }
238 if (writes[i] > maxFd) {
239 maxFd = writes[i];
240 }
241 FD_SET(writes[i], &wset);
242 writes[i] = INVALID_SOCKET;
243 }
244 }
245 if (errors) {
246 for (i = 0; i < nSockets; ++i) {
247 if (SOCKET_FAILED(errors[i])) {
248 break;
249 }
250 if (errors[i] > maxFd) {
251 maxFd = errors[i];
252 }
253 FD_SET(errors[i], &eset);
254 errors[i] = INVALID_SOCKET;
255 }
256 }
257 struct timeval tv;
258 tv.tv_sec = timeoutMillis / 1000;
259 tv.tv_usec = (timeoutMillis % 1000) * 1000;
260 int result = select(maxFd + 1, &rset, &wset, &eset, timeoutMillis < 0 ? 0 : &tv);
261 int r = 0;
262 int w = 0;
263 int e = 0;
264 Socket j;
265 for (j = 0; j < maxFd; ++j) {
266 if (reads && FD_ISSET(j, &rset)) {
267 reads[r] = j;
268 ++r;
269 }
270 if (writes && FD_ISSET(j, &wset)) {
271 writes[w] = j;
272 ++w;
273 }
274 if (errors && FD_ISSET(j, &eset)) {
275 errors[e] = j;
276 ++e;
277 }
278 }
279 return result;
280}
281
282#endif