src/util/socket.h (view raw)
1/* Copyright (c) 2013-2014 Jeffrey Pfau
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6#ifndef SOCKET_H
7#define SOCKET_H
8
9#include "util/common.h"
10
/* C++ has no `restrict` keyword; map it to the GCC/Clang `__restrict__`
 * extension so code using C99 `restrict` still compiles when this header is
 * included from C++. */
#ifdef __cplusplus
#define restrict __restrict__
#endif
14
/* memset/memcpy are used by the address-building helpers on every platform. */
#include <string.h>

#ifdef _WIN32
#include <winsock2.h>
#include <ws2tcpip.h>

/* Winsock handles are unsigned; failure is signalled by the INVALID_SOCKET
 * sentinel rather than by a negative value. */
#define SOCKET_FAILED(s) ((s) == INVALID_SOCKET)
typedef SOCKET Socket;
#else
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
/* select(), fd_set and struct timeval used by SocketPoll. */
#include <sys/select.h>
#include <sys/socket.h>
/* read(), write() and close() used on POSIX socket descriptors. */
#include <unistd.h>

/* POSIX sockets are plain file descriptors; any negative value is a failure. */
#define INVALID_SOCKET (-1)
#define SOCKET_FAILED(s) ((s) < 0)
typedef int Socket;
#endif
31
/* IP protocol version tag carried in struct Address::version. */
enum IP {
	IPV4 = 0, /* address lives in Address.ipv4 */
	IPV6 = 1  /* address lives in Address.ipv6 */
};
36
/* A network address tagged with the IP version it holds. */
struct Address {
	/* Selects the meaningful union member: ipv4 for IPV4, ipv6 for IPV6.
	 * The bind/connect helpers in this header treat any value other than
	 * IPV4 as IPv6. */
	enum IP version;
	/* Anonymous union: requires C11 (or a compiler extension in older modes). */
	union {
		/* Stored into sockaddr_in.sin_addr.s_addr as-is (no htonl), so it is
		 * presumably expected in network byte order -- confirm with callers. */
		uint32_t ipv4;
		/* The 16 raw bytes of an IPv6 address, copied as-is into sin6_addr. */
		uint8_t ipv6[16];
	};
};
44
/*
 * Perform process-wide socket-library setup required before any other
 * Socket* call. On Windows this requests Winsock 2.2; on other platforms it
 * does nothing.
 */
static inline void SocketSubsystemInitialize(void) {
#ifdef _WIN32
	/* WSAStartup writes implementation details into a caller-supplied WSADATA;
	 * passing NULL makes it fail with WSAEFAULT and leaves Winsock
	 * uninitialized, so a real out-structure is required. */
	WSADATA data;
	WSAStartup(MAKEWORD(2, 2), &data);
#endif
}
50
51static inline ssize_t SocketSend(Socket socket, const void* buffer, size_t size) {
52 return write(socket, buffer, size);
53}
54
55static inline ssize_t SocketRecv(Socket socket, void* buffer, size_t size) {
56 return read(socket, buffer, size);
57}
58
59static inline Socket SocketOpenTCP(int port, const struct Address* bindAddress) {
60 Socket sock = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
61 if (SOCKET_FAILED(sock)) {
62 return sock;
63 }
64
65 int err;
66 if (!bindAddress) {
67 struct sockaddr_in bindInfo;
68 memset(&bindInfo, 0, sizeof(bindInfo));
69 bindInfo.sin_family = AF_INET;
70 bindInfo.sin_port = htons(port);
71 err = bind(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
72 } else if (bindAddress->version == IPV4) {
73 struct sockaddr_in bindInfo;
74 memset(&bindInfo, 0, sizeof(bindInfo));
75 bindInfo.sin_family = AF_INET;
76 bindInfo.sin_port = htons(port);
77 bindInfo.sin_addr.s_addr = bindAddress->ipv4;
78 err = bind(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
79 } else {
80 struct sockaddr_in6 bindInfo;
81 memset(&bindInfo, 0, sizeof(bindInfo));
82 bindInfo.sin6_family = AF_INET6;
83 bindInfo.sin6_port = htons(port);
84 memcpy(bindInfo.sin6_addr.s6_addr, bindAddress->ipv6, sizeof(bindInfo.sin6_addr.s6_addr));
85 err = bind(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
86
87 }
88 if (err) {
89 close(sock);
90 return -1;
91 }
92 return sock;
93}
94
95static inline Socket SocketConnectTCP(int port, const struct Address* destinationAddress) {
96 Socket sock = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
97 if (SOCKET_FAILED(sock)) {
98 return sock;
99 }
100
101 int err;
102 if (!destinationAddress) {
103 struct sockaddr_in bindInfo;
104 memset(&bindInfo, 0, sizeof(bindInfo));
105 bindInfo.sin_family = AF_INET;
106 bindInfo.sin_port = htons(port);
107 err = connect(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
108 } else if (destinationAddress->version == IPV4) {
109 struct sockaddr_in bindInfo;
110 memset(&bindInfo, 0, sizeof(bindInfo));
111 bindInfo.sin_family = AF_INET;
112 bindInfo.sin_port = htons(port);
113 bindInfo.sin_addr.s_addr = destinationAddress->ipv4;
114 err = connect(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
115 } else {
116 struct sockaddr_in6 bindInfo;
117 memset(&bindInfo, 0, sizeof(bindInfo));
118 bindInfo.sin6_family = AF_INET6;
119 bindInfo.sin6_port = htons(port);
120 memcpy(bindInfo.sin6_addr.s6_addr, destinationAddress->ipv6, sizeof(bindInfo.sin6_addr.s6_addr));
121 err = connect(sock, (const struct sockaddr*) &bindInfo, sizeof(bindInfo));
122 }
123
124 if (err) {
125 close(sock);
126 return -1;
127 }
128 return sock;
129}
130
131static inline Socket SocketListen(Socket socket, int queueLength) {
132 return listen(socket, queueLength);
133}
134
135static inline Socket SocketAccept(Socket socket, struct Address* address) {
136 if (!address) {
137 return accept(socket, 0, 0);
138 }
139 if (address->version == IPV4) {
140 struct sockaddr_in addrInfo;
141 memset(&addrInfo, 0, sizeof(addrInfo));
142 addrInfo.sin_family = AF_INET;
143 addrInfo.sin_addr.s_addr = address->ipv4;
144 socklen_t len = sizeof(addrInfo);
145 return accept(socket, (struct sockaddr*) &addrInfo, &len);
146 } else {
147 struct sockaddr_in6 addrInfo;
148 memset(&addrInfo, 0, sizeof(addrInfo));
149 addrInfo.sin6_family = AF_INET6;
150 memcpy(addrInfo.sin6_addr.s6_addr, address->ipv6, sizeof(addrInfo.sin6_addr.s6_addr));
151 socklen_t len = sizeof(addrInfo);
152 return accept(socket, (struct sockaddr*) &addrInfo, &len);
153 }
154}
155
156static inline int SocketClose(Socket socket) {
157 return close(socket) >= 0;
158}
159
160static inline int SocketSetBlocking(Socket socket, bool blocking) {
161#ifdef _WIN32
162 u_long unblocking = !blocking;
163 return ioctlsocket(socket, FIONBIO, &unblocking) == NO_ERROR;
164#else
165 int flags = fcntl(socket, F_GETFL);
166 if (flags == -1) {
167 return 0;
168 }
169 if (blocking) {
170 flags &= ~O_NONBLOCK;
171 } else {
172 flags |= O_NONBLOCK;
173 }
174 return fcntl(socket, F_SETFL, flags) >= 0;
175#endif
176}
177
178static inline int SocketSetTCPPush(Socket socket, int push) {
179 return setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (char*) &push, sizeof(int)) >= 0;
180}
181
182static inline int SocketPoll(size_t nSockets, Socket* reads, Socket* writes, Socket* errors, int64_t timeoutMillis) {
183 fd_set rset;
184 fd_set wset;
185 fd_set eset;
186 FD_ZERO(&rset);
187 FD_ZERO(&wset);
188 FD_ZERO(&eset);
189 size_t i;
190 Socket maxFd = 0;
191 if (reads) {
192 for (i = 0; i < nSockets; ++i) {
193 if (SOCKET_FAILED(reads[i])) {
194 break;
195 }
196 if (reads[i] > maxFd) {
197 maxFd = reads[i];
198 }
199 FD_SET(reads[i], &rset);
200 reads[i] = INVALID_SOCKET;
201 }
202 }
203 if (writes) {
204 for (i = 0; i < nSockets; ++i) {
205 if (SOCKET_FAILED(writes[i])) {
206 break;
207 }
208 if (writes[i] > maxFd) {
209 maxFd = writes[i];
210 }
211 FD_SET(writes[i], &wset);
212 writes[i] = INVALID_SOCKET;
213 }
214 }
215 if (errors) {
216 for (i = 0; i < nSockets; ++i) {
217 if (SOCKET_FAILED(errors[i])) {
218 break;
219 }
220 if (errors[i] > maxFd) {
221 maxFd = errors[i];
222 }
223 FD_SET(errors[i], &eset);
224 errors[i] = INVALID_SOCKET;
225 }
226 }
227 struct timeval tv;
228 tv.tv_sec = timeoutMillis / 1000;
229 tv.tv_usec = (timeoutMillis % 1000) * 1000;
230 int result = select(maxFd, &rset, &wset, &eset, timeoutMillis < 0 ? 0 : &tv);
231 int r = 0;
232 int w = 0;
233 int e = 0;
234 Socket j;
235 for (j = 0; j < maxFd; ++j) {
236 if (reads && FD_ISSET(j, &rset)) {
237 reads[r] = j;
238 ++r;
239 }
240 if (writes && FD_ISSET(j, &wset)) {
241 writes[w] = j;
242 ++w;
243 }
244 if (errors && FD_ISSET(j, &eset)) {
245 errors[e] = j;
246 ++e;
247 }
248 }
249 return result;
250}
251
252#endif