src/util/socket.h (view raw)
1/* Copyright (c) 2013-2014 Jeffrey Pfau
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6#ifndef SOCKET_H
7#define SOCKET_H
8
9#include "util/common.h"
10
#ifdef __cplusplus
/* C++ has no `restrict` keyword; map it to the compiler extension spelling. */
#define restrict __restrict__
#endif

#ifndef _WIN32
#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/select.h>
#include <sys/socket.h>

/* POSIX sockets are plain file descriptors; failing calls yield a negative value. */
typedef int Socket;
#define INVALID_SOCKET (-1)
#define SOCKET_FAILED(s) ((s) < 0)
#else
#include <winsock2.h>
#include <ws2tcpip.h>

/* Winsock hands out SOCKET handles and signals failure with INVALID_SOCKET. */
typedef SOCKET Socket;
#define SOCKET_FAILED(s) ((s) == INVALID_SOCKET)
#endif
34
/* Which member of struct Address's union is valid. */
enum IP {
	IPV4 = 0,
	IPV6 = 1
};

/* An IPv4 or IPv6 host address, tagged by `version`. */
struct Address {
	enum IP version;
	union {
		/* Stored exactly as it is assigned to sin_addr.s_addr (network byte order). */
		uint32_t ipv4;
		/* Copied byte-for-byte to/from sin6_addr.s6_addr. */
		uint8_t ipv6[16];
	};
};
47
#if defined(_3DS)
#include <3ds.h>
#include <malloc.h>

/* Alignment and size of the buffer that SocketSubsystemInit() hands to socInit(). */
#define SOCU_ALIGN 0x1000
#define SOCU_BUFFERSIZE 0x100000

/* Socket-service buffer; declared extern here, so it is defined in some other
 * translation unit. NULL until SocketSubsystemInit() allocates it. */
extern u32* SOCUBuffer;
#endif
57
/* Bring up the platform socket layer before any other Socket* call.
 *
 * Windows: starts Winsock 2.2 (the WSAStartup result is not reported).
 * 3DS: allocates SOCUBuffer once and passes it to socInit(); if the allocation
 * fails, nothing is initialised and a later call may retry.
 * Other platforms: no-op. */
static inline void SocketSubsystemInit(void) {
#ifdef _WIN32
	WSADATA data;
	WSAStartup(MAKEWORD(2, 2), &data);
#elif defined(_3DS)
	if (!SOCUBuffer) {
		SOCUBuffer = memalign(SOCU_ALIGN, SOCU_BUFFERSIZE);
		/* Never hand socInit() a NULL buffer. */
		if (SOCUBuffer) {
			socInit(SOCUBuffer, SOCU_BUFFERSIZE);
		}
	}
#endif
}
69
/* Tear down the platform socket layer set up by SocketSubsystemInit().
 *
 * Windows: calls WSACleanup().
 * 3DS: shuts the socket service down and frees SOCUBuffer, but only if the
 * buffer exists, so calling this without a prior successful init (or twice)
 * does not call socExit() on an uninitialised service.
 * Other platforms: no-op. */
static inline void SocketSubsystemDeinit(void) {
#ifdef _WIN32
	WSACleanup();
#elif defined(_3DS)
	/* Same guard as SocketSubsystemInit's `if (!SOCUBuffer)`, inverted. */
	if (SOCUBuffer) {
		socExit();
		free(SOCUBuffer);
		SOCUBuffer = NULL;
	}
#endif
}
79
/* Return the error code of the most recent failed socket call on this platform:
 * WSAGetLastError() on Windows, errno elsewhere. */
static inline int SocketError(void) {
#ifdef _WIN32
	return WSAGetLastError();
#else
	return errno;
#endif
}
87
/* Report whether the last socket error means "operation would block", i.e. a
 * non-blocking call had nothing to do yet rather than genuinely failing. */
static inline bool SocketWouldBlock(void) {
#ifdef _WIN32
	return SocketError() == WSAEWOULDBLOCK;
#else
	/* EWOULDBLOCK and EAGAIN may or may not be the same value, so accept both;
	 * read the error code once and compare the saved value. */
	int error = SocketError();
	return error == EWOULDBLOCK || error == EAGAIN;
#endif
}
95
96static inline ssize_t SocketSend(Socket socket, const void* buffer, size_t size) {
97#ifdef _WIN32
98 return send(socket, (const char*) buffer, size, 0);
99#else
100 return write(socket, buffer, size);
101#endif
102}
103
104static inline ssize_t SocketRecv(Socket socket, void* buffer, size_t size) {
105#ifdef _WIN32
106 return recv(socket, (char*) buffer, size, 0);
107#else
108 return read(socket, buffer, size);
109#endif
110}
111
112static inline int SocketClose(Socket socket) {
113#ifdef _WIN32
114 return closesocket(socket) == 0;
115#else
116 return close(socket) >= 0;
117#endif
118}
119
120static inline Socket SocketOpenTCP(int port, const struct Address* bindAddress) {
121 Socket sock = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
122 if (SOCKET_FAILED(sock)) {
123 return sock;
124 }
125
126 int err;
127 if (!bindAddress) {
128 struct sockaddr_in bindInfo;
129 memset(&bindInfo, 0, sizeof(bindInfo));
130 bindInfo.sin_family = AF_INET;
131 bindInfo.sin_port = htons(port);
132#ifndef _3DS
133 bindInfo.sin_addr.s_addr = INADDR_ANY;
134#else
135 bindInfo.sin_addr.s_addr = gethostid();
136#endif
137 err = bind(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
138 } else if (bindAddress->version == IPV4) {
139 struct sockaddr_in bindInfo;
140 memset(&bindInfo, 0, sizeof(bindInfo));
141 bindInfo.sin_family = AF_INET;
142 bindInfo.sin_port = htons(port);
143 bindInfo.sin_addr.s_addr = bindAddress->ipv4;
144 err = bind(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
145#ifndef _3DS
146 } else {
147 struct sockaddr_in6 bindInfo;
148 memset(&bindInfo, 0, sizeof(bindInfo));
149 bindInfo.sin6_family = AF_INET6;
150 bindInfo.sin6_port = htons(port);
151 memcpy(bindInfo.sin6_addr.s6_addr, bindAddress->ipv6, sizeof(bindInfo.sin6_addr.s6_addr));
152 err = bind(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
153#endif
154 }
155 if (err) {
156 SocketClose(sock);
157 return INVALID_SOCKET;
158 }
159 return sock;
160}
161
162static inline Socket SocketConnectTCP(int port, const struct Address* destinationAddress) {
163 Socket sock = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
164 if (SOCKET_FAILED(sock)) {
165 return sock;
166 }
167
168 int err;
169 if (!destinationAddress) {
170 struct sockaddr_in bindInfo;
171 memset(&bindInfo, 0, sizeof(bindInfo));
172 bindInfo.sin_family = AF_INET;
173 bindInfo.sin_port = htons(port);
174 err = connect(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
175 } else if (destinationAddress->version == IPV4) {
176 struct sockaddr_in bindInfo;
177 memset(&bindInfo, 0, sizeof(bindInfo));
178 bindInfo.sin_family = AF_INET;
179 bindInfo.sin_port = htons(port);
180 bindInfo.sin_addr.s_addr = destinationAddress->ipv4;
181 err = connect(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
182#ifndef _3DS
183 } else {
184 struct sockaddr_in6 bindInfo;
185 memset(&bindInfo, 0, sizeof(bindInfo));
186 bindInfo.sin6_family = AF_INET6;
187 bindInfo.sin6_port = htons(port);
188 memcpy(bindInfo.sin6_addr.s6_addr, destinationAddress->ipv6, sizeof(bindInfo.sin6_addr.s6_addr));
189 err = connect(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
190#endif
191 }
192
193 if (err) {
194 SocketClose(sock);
195 return INVALID_SOCKET;
196 }
197 return sock;
198}
199
200static inline Socket SocketListen(Socket socket, int queueLength) {
201 return listen(socket, queueLength);
202}
203
204static inline Socket SocketAccept(Socket socket, struct Address* address) {
205 if (!address) {
206 return accept(socket, 0, 0);
207 }
208 if (address->version == IPV4) {
209 struct sockaddr_in addrInfo;
210 memset(&addrInfo, 0, sizeof(addrInfo));
211 addrInfo.sin_family = AF_INET;
212 addrInfo.sin_addr.s_addr = address->ipv4;
213 socklen_t len = sizeof(addrInfo);
214 return accept(socket, (struct sockaddr*) &addrInfo, &len);
215#ifndef _3DS
216 } else {
217 struct sockaddr_in6 addrInfo;
218 memset(&addrInfo, 0, sizeof(addrInfo));
219 addrInfo.sin6_family = AF_INET6;
220 memcpy(addrInfo.sin6_addr.s6_addr, address->ipv6, sizeof(addrInfo.sin6_addr.s6_addr));
221 socklen_t len = sizeof(addrInfo);
222 return accept(socket, (struct sockaddr*) &addrInfo, &len);
223#endif
224 }
225 return INVALID_SOCKET;
226}
227
228static inline int SocketSetBlocking(Socket socket, bool blocking) {
229#ifdef _WIN32
230 u_long unblocking = !blocking;
231 return ioctlsocket(socket, FIONBIO, &unblocking) == NO_ERROR;
232#else
233 int flags = fcntl(socket, F_GETFL);
234 if (flags == -1) {
235 return 0;
236 }
237 if (blocking) {
238 flags &= ~O_NONBLOCK;
239 } else {
240 flags |= O_NONBLOCK;
241 }
242 return fcntl(socket, F_SETFL, flags) >= 0;
243#endif
244}
245
246static inline int SocketSetTCPPush(Socket socket, int push) {
247 return setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (char*) &push, sizeof(int)) >= 0;
248}
249
250static inline int SocketPoll(size_t nSockets, Socket* reads, Socket* writes, Socket* errors, int64_t timeoutMillis) {
251 fd_set rset;
252 fd_set wset;
253 fd_set eset;
254 FD_ZERO(&rset);
255 FD_ZERO(&wset);
256 FD_ZERO(&eset);
257 size_t i;
258 Socket maxFd = 0;
259 if (reads) {
260 for (i = 0; i < nSockets; ++i) {
261 if (SOCKET_FAILED(reads[i])) {
262 break;
263 }
264 if (reads[i] > maxFd) {
265 maxFd = reads[i];
266 }
267 FD_SET(reads[i], &rset);
268 reads[i] = INVALID_SOCKET;
269 }
270 }
271 if (writes) {
272 for (i = 0; i < nSockets; ++i) {
273 if (SOCKET_FAILED(writes[i])) {
274 break;
275 }
276 if (writes[i] > maxFd) {
277 maxFd = writes[i];
278 }
279 FD_SET(writes[i], &wset);
280 writes[i] = INVALID_SOCKET;
281 }
282 }
283 if (errors) {
284 for (i = 0; i < nSockets; ++i) {
285 if (SOCKET_FAILED(errors[i])) {
286 break;
287 }
288 if (errors[i] > maxFd) {
289 maxFd = errors[i];
290 }
291 FD_SET(errors[i], &eset);
292 errors[i] = INVALID_SOCKET;
293 }
294 }
295 struct timeval tv;
296 tv.tv_sec = timeoutMillis / 1000;
297 tv.tv_usec = (timeoutMillis % 1000) * 1000;
298 int result = select(maxFd + 1, &rset, &wset, &eset, timeoutMillis < 0 ? 0 : &tv);
299 int r = 0;
300 int w = 0;
301 int e = 0;
302 Socket j;
303 for (j = 0; j < maxFd; ++j) {
304 if (reads && FD_ISSET(j, &rset)) {
305 reads[r] = j;
306 ++r;
307 }
308 if (writes && FD_ISSET(j, &wset)) {
309 writes[w] = j;
310 ++w;
311 }
312 if (errors && FD_ISSET(j, &eset)) {
313 errors[e] = j;
314 ++e;
315 }
316 }
317 return result;
318}
319
320#endif