/* src/util/socket.h */
/* Copyright (c) 2013-2014 Jeffrey Pfau
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6#ifndef SOCKET_H
7#define SOCKET_H
8
#include "util/common.h"

#include <limits.h>
10
/* C++ has no `restrict` keyword. When this header is compiled as C++ and the
 * build has not already provided a definition, spell it with the
 * `__restrict__` compiler extension instead. */
#ifdef __cplusplus
#ifndef restrict
#define restrict __restrict__
#endif
#endif
14
/* Platform socket handle type and failure test:
 *   Socket           - native socket handle (SOCKET on Windows, an int file
 *                      descriptor elsewhere).
 *   INVALID_SOCKET   - the "no socket" value returned by the helpers below.
 *   SOCKET_FAILED(s) - nonzero if s is not a usable socket.
 */
#ifdef _WIN32
#include <ws2tcpip.h>

#define SOCKET_FAILED(s) ((s) == INVALID_SOCKET)
typedef SOCKET Socket;
#else
#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/select.h>
#include <sys/socket.h>
/* Declares read(), write() and close(), which the POSIX wrappers in this
 * header call directly; do not rely on another header pulling it in. */
#include <unistd.h>

/* Winsock supplies INVALID_SOCKET; POSIX socket calls signal failure with -1. */
#define INVALID_SOCKET (-1)
/* Any negative descriptor counts as a failure. */
#define SOCKET_FAILED(s) ((s) < 0)
typedef int Socket;
#endif
33
/* Which member of struct Address's union holds the address. */
enum IP {
	IPV4 = 0,
	IPV6 = 1
};

/* A host address of either family; `version` says which union member is valid.
 * The helpers in this header copy these bytes straight into sockaddr
 * structures without conversion, so ipv4 is stored in network byte order and
 * ipv6 holds the 16 address bytes in network order. */
struct Address {
	enum IP version;
	union {
		uint32_t ipv4;
		uint8_t ipv6[16];
	};
};
46
/* 3DS-only support for the SOC (socket) service used by the functions below. */
#ifdef _3DS
#include <3ds.h>
#include <malloc.h>

/* Alignment, in bytes, passed to memalign() when allocating SOCUBuffer. */
#define SOCU_ALIGN 0x1000
/* Size, in bytes, of SOCUBuffer; passed to both memalign() and socInit(). */
#define SOCU_BUFFERSIZE 0x100000

/* Buffer handed to socInit(). It is NULL while the socket subsystem is not
 * initialized: SocketSubsystemInit() allocates it only when NULL, and
 * SocketSubsystemDeinit() frees it and resets it to NULL. Declared here only;
 * the definition lives in a source file elsewhere. */
extern u32* SOCUBuffer;
#endif
56
/* Bring up the platform socket layer; call before any other Socket* helper.
 *
 * - Windows: starts Winsock 2.2 with WSAStartup(). Its result is not checked
 *   here. NOTE(review): a failed WSAStartup() leaves the later WSACleanup()
 *   in SocketSubsystemDeinit() unbalanced; confirm callers tolerate this.
 * - 3DS: allocates SOCUBuffer and starts the SOC service, once. Calls made
 *   while SOCUBuffer is set do nothing. If the allocation or socInit() fails,
 *   SOCUBuffer is left NULL so that a later call can retry and
 *   SocketSubsystemDeinit() does not tear down a service that never started.
 * - Other platforms: nothing to do.
 */
static inline void SocketSubsystemInit(void) {
#ifdef _WIN32
	WSADATA data;
	WSAStartup(MAKEWORD(2, 2), &data);
#elif defined(_3DS)
	if (SOCUBuffer) {
		return;
	}
	SOCUBuffer = memalign(SOCU_ALIGN, SOCU_BUFFERSIZE);
	if (!SOCUBuffer) {
		/* Never hand socInit() a NULL buffer. */
		return;
	}
	/* libctru Result values are negative on failure. */
	if (socInit(SOCUBuffer, SOCU_BUFFERSIZE) < 0) {
		free(SOCUBuffer);
		SOCUBuffer = NULL;
	}
#endif
}
68
/* Undo SocketSubsystemInit().
 *
 * - Windows: calls WSACleanup().
 * - 3DS: stops the SOC service and frees SOCUBuffer, but only while the
 *   buffer is allocated. Calling this twice, or after a failed or missing
 *   SocketSubsystemInit(), therefore does nothing instead of calling socExit()
 *   on a service that was never started.
 * - Other platforms: nothing to do.
 */
static inline void SocketSubsystemDeinit(void) {
#ifdef _WIN32
	WSACleanup();
#elif defined(_3DS)
	if (!SOCUBuffer) {
		return;
	}
	socExit();
	free(SOCUBuffer);
	SOCUBuffer = NULL;
#endif
}
78
/* Return the error code left by the most recent failed socket call:
 * WSAGetLastError() on Windows, errno everywhere else. */
static inline int SocketError(void) {
#ifdef _WIN32
	return WSAGetLastError();
#else
	return errno;
#endif
}
86
/* Return true if the last socket error (see SocketError()) means the
 * operation would have blocked on a non-blocking socket, rather than a real
 * failure. */
static inline bool SocketWouldBlock(void) {
#ifdef _WIN32
	return SocketError() == WSAEWOULDBLOCK;
#else
	/* POSIX permits EWOULDBLOCK and EAGAIN to be distinct values; accept both.
	 * Read the error once so both comparisons test the same value. */
	int error = SocketError();
	return error == EWOULDBLOCK || error == EAGAIN;
#endif
}
94
95static inline ssize_t SocketSend(Socket socket, const void* buffer, size_t size) {
96#ifdef _WIN32
97 return send(socket, (const char*) buffer, size, 0);
98#else
99 return write(socket, buffer, size);
100#endif
101}
102
103static inline ssize_t SocketRecv(Socket socket, void* buffer, size_t size) {
104#ifdef _WIN32
105 return recv(socket, (char*) buffer, size, 0);
106#else
107 return read(socket, buffer, size);
108#endif
109}
110
111static inline int SocketClose(Socket socket) {
112#ifdef _WIN32
113 return closesocket(socket) == 0;
114#else
115 return close(socket) >= 0;
116#endif
117}
118
119static inline Socket SocketOpenTCP(int port, const struct Address* bindAddress) {
120 Socket sock = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
121 if (SOCKET_FAILED(sock)) {
122 return sock;
123 }
124
125 int err;
126 if (!bindAddress) {
127 struct sockaddr_in bindInfo;
128 memset(&bindInfo, 0, sizeof(bindInfo));
129 bindInfo.sin_family = AF_INET;
130 bindInfo.sin_port = htons(port);
131#ifndef _3DS
132 bindInfo.sin_addr.s_addr = INADDR_ANY;
133#else
134 bindInfo.sin_addr.s_addr = gethostid();
135#endif
136 err = bind(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
137 } else if (bindAddress->version == IPV4) {
138 struct sockaddr_in bindInfo;
139 memset(&bindInfo, 0, sizeof(bindInfo));
140 bindInfo.sin_family = AF_INET;
141 bindInfo.sin_port = htons(port);
142 bindInfo.sin_addr.s_addr = bindAddress->ipv4;
143 err = bind(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
144#ifndef _3DS
145 } else {
146 struct sockaddr_in6 bindInfo;
147 memset(&bindInfo, 0, sizeof(bindInfo));
148 bindInfo.sin6_family = AF_INET6;
149 bindInfo.sin6_port = htons(port);
150 memcpy(bindInfo.sin6_addr.s6_addr, bindAddress->ipv6, sizeof(bindInfo.sin6_addr.s6_addr));
151 err = bind(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
152#endif
153 }
154 if (err) {
155 SocketClose(sock);
156 return INVALID_SOCKET;
157 }
158 return sock;
159}
160
161static inline Socket SocketConnectTCP(int port, const struct Address* destinationAddress) {
162 Socket sock = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
163 if (SOCKET_FAILED(sock)) {
164 return sock;
165 }
166
167 int err;
168 if (!destinationAddress) {
169 struct sockaddr_in bindInfo;
170 memset(&bindInfo, 0, sizeof(bindInfo));
171 bindInfo.sin_family = AF_INET;
172 bindInfo.sin_port = htons(port);
173 err = connect(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
174 } else if (destinationAddress->version == IPV4) {
175 struct sockaddr_in bindInfo;
176 memset(&bindInfo, 0, sizeof(bindInfo));
177 bindInfo.sin_family = AF_INET;
178 bindInfo.sin_port = htons(port);
179 bindInfo.sin_addr.s_addr = destinationAddress->ipv4;
180 err = connect(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
181#ifndef _3DS
182 } else {
183 struct sockaddr_in6 bindInfo;
184 memset(&bindInfo, 0, sizeof(bindInfo));
185 bindInfo.sin6_family = AF_INET6;
186 bindInfo.sin6_port = htons(port);
187 memcpy(bindInfo.sin6_addr.s6_addr, destinationAddress->ipv6, sizeof(bindInfo.sin6_addr.s6_addr));
188 err = connect(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
189#endif
190 }
191
192 if (err) {
193 SocketClose(sock);
194 return INVALID_SOCKET;
195 }
196 return sock;
197}
198
199static inline Socket SocketListen(Socket socket, int queueLength) {
200 return listen(socket, queueLength);
201}
202
203static inline Socket SocketAccept(Socket socket, struct Address* address) {
204 if (!address) {
205 return accept(socket, 0, 0);
206 }
207 if (address->version == IPV4) {
208 struct sockaddr_in addrInfo;
209 memset(&addrInfo, 0, sizeof(addrInfo));
210 addrInfo.sin_family = AF_INET;
211 addrInfo.sin_addr.s_addr = address->ipv4;
212 socklen_t len = sizeof(addrInfo);
213 return accept(socket, (struct sockaddr*) &addrInfo, &len);
214#ifndef _3DS
215 } else {
216 struct sockaddr_in6 addrInfo;
217 memset(&addrInfo, 0, sizeof(addrInfo));
218 addrInfo.sin6_family = AF_INET6;
219 memcpy(addrInfo.sin6_addr.s6_addr, address->ipv6, sizeof(addrInfo.sin6_addr.s6_addr));
220 socklen_t len = sizeof(addrInfo);
221 return accept(socket, (struct sockaddr*) &addrInfo, &len);
222#endif
223 }
224 return INVALID_SOCKET;
225}
226
227static inline int SocketSetBlocking(Socket socket, bool blocking) {
228#ifdef _WIN32
229 u_long unblocking = !blocking;
230 return ioctlsocket(socket, FIONBIO, &unblocking) == NO_ERROR;
231#else
232 int flags = fcntl(socket, F_GETFL);
233 if (flags == -1) {
234 return 0;
235 }
236 if (blocking) {
237 flags &= ~O_NONBLOCK;
238 } else {
239 flags |= O_NONBLOCK;
240 }
241 return fcntl(socket, F_SETFL, flags) >= 0;
242#endif
243}
244
245static inline int SocketSetTCPPush(Socket socket, int push) {
246 return setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (char*) &push, sizeof(int)) >= 0;
247}
248
249static inline int SocketPoll(size_t nSockets, Socket* reads, Socket* writes, Socket* errors, int64_t timeoutMillis) {
250 fd_set rset;
251 fd_set wset;
252 fd_set eset;
253 FD_ZERO(&rset);
254 FD_ZERO(&wset);
255 FD_ZERO(&eset);
256 size_t i;
257 Socket maxFd = 0;
258 if (reads) {
259 for (i = 0; i < nSockets; ++i) {
260 if (SOCKET_FAILED(reads[i])) {
261 break;
262 }
263 if (reads[i] > maxFd) {
264 maxFd = reads[i];
265 }
266 FD_SET(reads[i], &rset);
267 reads[i] = INVALID_SOCKET;
268 }
269 }
270 if (writes) {
271 for (i = 0; i < nSockets; ++i) {
272 if (SOCKET_FAILED(writes[i])) {
273 break;
274 }
275 if (writes[i] > maxFd) {
276 maxFd = writes[i];
277 }
278 FD_SET(writes[i], &wset);
279 writes[i] = INVALID_SOCKET;
280 }
281 }
282 if (errors) {
283 for (i = 0; i < nSockets; ++i) {
284 if (SOCKET_FAILED(errors[i])) {
285 break;
286 }
287 if (errors[i] > maxFd) {
288 maxFd = errors[i];
289 }
290 FD_SET(errors[i], &eset);
291 errors[i] = INVALID_SOCKET;
292 }
293 }
294 struct timeval tv;
295 tv.tv_sec = timeoutMillis / 1000;
296 tv.tv_usec = (timeoutMillis % 1000) * 1000;
297 int result = select(maxFd + 1, &rset, &wset, &eset, timeoutMillis < 0 ? 0 : &tv);
298 int r = 0;
299 int w = 0;
300 int e = 0;
301 Socket j;
302 for (j = 0; j < maxFd; ++j) {
303 if (reads && FD_ISSET(j, &rset)) {
304 reads[r] = j;
305 ++r;
306 }
307 if (writes && FD_ISSET(j, &wset)) {
308 writes[w] = j;
309 ++w;
310 }
311 if (errors && FD_ISSET(j, &eset)) {
312 errors[e] = j;
313 ++e;
314 }
315 }
316 return result;
317}
318
319#endif