blob: 5ad1c3041a98f4ad704da63f125a53a74f70e8a1 [file] [log] [blame]
// Copyright (c) 2013, Matt Godbolt
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
25
26#include "internal/LogStream.h"
27
28#include "seasocks/Connection.h"
29#include "seasocks/Logger.h"
30#include "seasocks/Server.h"
31#include "seasocks/PageHandler.h"
32#include "seasocks/StringUtil.h"
33#include "seasocks/util/Json.h"
34
35#include <netinet/in.h>
36#include <netinet/tcp.h>
37
38#include <sys/epoll.h>
39#include <sys/eventfd.h>
40#include <sys/ioctl.h>
41#include <sys/socket.h>
42#include <sys/syscall.h>
43
44#include <memory>
45#include <stdexcept>
46#include <string.h>
47#include <unistd.h>
48
namespace {

// Lightweight wrapper so a raw epoll event mask can be streamed in
// human-readable form, e.g. "EPOLLIN, EPOLLOUT".
struct EventBits {
    uint32_t bits;
    explicit EventBits(uint32_t bits) : bits(bits) {}
};

// Streams the names of every listed EPOLL* flag set in `b`, separated by ", ".
// Bits that match none of the listed flags are not printed.
std::ostream& operator <<(std::ostream& o, const EventBits& b) {
    uint32_t bits = b.bits;
    // Print NAME if its bit is set, then clear it; once anything has been
    // printed `bits` differs from `b.bits`, which is how the separator is chosen.
#define DO_BIT(NAME) \
    do { if (bits & (NAME)) { if (bits != b.bits) {o << ", "; } o << #NAME; bits &= ~(NAME); } } while (0)
    DO_BIT(EPOLLIN);
    DO_BIT(EPOLLPRI);
    DO_BIT(EPOLLOUT);
    DO_BIT(EPOLLRDNORM);
    DO_BIT(EPOLLRDBAND);
    DO_BIT(EPOLLWRNORM);
    DO_BIT(EPOLLWRBAND);
    DO_BIT(EPOLLMSG);
    DO_BIT(EPOLLERR);
    DO_BIT(EPOLLHUP);
#ifdef EPOLLRDHUP
    DO_BIT(EPOLLRDHUP);
#endif
    DO_BIT(EPOLLONESHOT);
    DO_BIT(EPOLLET);
#undef DO_BIT
    return o;
}

// Timeout passed to each epoll_wait() made by Server::loop().
const int EpollTimeoutMillis = 500; // Twice a second is ample.
// Initial value of Server::_lameConnectionTimeoutSeconds: how long an accepted
// connection may go without receiving any bytes before it is killed.
const int DefaultLameConnectionTimeoutSeconds = 10;

}  // namespace

namespace seasocks {
namespace {

// Returns the kernel thread id of the calling thread.
//
// Declared inside seasocks on purpose: glibc >= 2.30 declares ::gettid() in
// <unistd.h>, and a second gettid() visible at global scope makes the
// unqualified gettid() calls in seasocks::Server ambiguous. Placed here,
// unqualified lookup from inside seasocks finds this declaration first and
// stops, so the calls resolve the same way on old and new glibc.
int gettid() {
    return static_cast<int>(syscall(SYS_gettid));
}

}  // namespace
}  // namespace seasocks
86
87namespace seasocks {
88
89Server::Server(std::shared_ptr<Logger> logger)
90: _logger(logger), _listenSock(-1), _epollFd(-1), _eventFd(-1),
91 _maxKeepAliveDrops(0),
92 _lameConnectionTimeoutSeconds(DefaultLameConnectionTimeoutSeconds),
93 _nextDeadConnectionCheck(0), _threadId(0), _terminate(false),
94 _expectedTerminate(false) {
95
96 _epollFd = epoll_create(10);
97 if (_epollFd == -1) {
98 LS_ERROR(_logger, "Unable to create epoll: " << getLastError());
99 return;
100 }
101
102 _eventFd = eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC);
103 if (_eventFd == -1) {
104 LS_ERROR(_logger, "Unable to create event FD: " << getLastError());
105 return;
106 }
107
108 epoll_event eventWake = { EPOLLIN, { &_eventFd } };
109 if (epoll_ctl(_epollFd, EPOLL_CTL_ADD, _eventFd, &eventWake) == -1) {
110 LS_ERROR(_logger, "Unable to add wake socket to epoll: " << getLastError());
111 return;
112 }
113}
114
115Server::~Server() {
116 LS_INFO(_logger, "Server destruction");
117 shutdown();
118 // Only shut the eventfd and epoll at the very end
119 if (_eventFd != -1) {
120 close(_eventFd);
121 }
122 if (_epollFd != -1) {
123 close(_epollFd);
124 }
125}
126
127void Server::shutdown() {
128 // Stop listening to any further incoming connections.
129 if (_listenSock != -1) {
130 close(_listenSock);
131 _listenSock = -1;
132 }
133 // Disconnect and close any current connections.
134 while (!_connections.empty()) {
135 // Deleting the connection closes it and removes it from 'this'.
136 Connection* toBeClosed = _connections.begin()->first;
137 toBeClosed->setLinger();
138 delete toBeClosed;
139 }
140}
141
142bool Server::makeNonBlocking(int fd) const {
143 int yesPlease = 1;
144 if (ioctl(fd, FIONBIO, &yesPlease) != 0) {
145 LS_ERROR(_logger, "Unable to make FD non-blocking: " << getLastError());
146 return false;
147 }
148 return true;
149}
150
151bool Server::configureSocket(int fd) const {
152 if (!makeNonBlocking(fd)) {
153 return false;
154 }
155 const int yesPlease = 1;
156 if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &yesPlease, sizeof(yesPlease)) == -1) {
157 LS_ERROR(_logger, "Unable to set reuse socket option: " << getLastError());
158 return false;
159 }
160 if (_maxKeepAliveDrops > 0) {
161 if (setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &yesPlease, sizeof(yesPlease)) == -1) {
162 LS_ERROR(_logger, "Unable to enable keepalive: " << getLastError());
163 return false;
164 }
165 const int oneSecond = 1;
166 if (setsockopt(fd, IPPROTO_TCP, TCP_KEEPIDLE, &oneSecond, sizeof(oneSecond)) == -1) {
167 LS_ERROR(_logger, "Unable to set idle probe: " << getLastError());
168 return false;
169 }
170 if (setsockopt(fd, IPPROTO_TCP, TCP_KEEPINTVL, &oneSecond, sizeof(oneSecond)) == -1) {
171 LS_ERROR(_logger, "Unable to set idle interval: " << getLastError());
172 return false;
173 }
174 if (setsockopt(fd, IPPROTO_TCP, TCP_KEEPCNT, &_maxKeepAliveDrops, sizeof(_maxKeepAliveDrops)) == -1) {
175 LS_ERROR(_logger, "Unable to set keep alive count: " << getLastError());
176 return false;
177 }
178 }
179 return true;
180}
181
182void Server::terminate() {
183 _expectedTerminate = true;
184 _terminate = true;
185 uint64_t one = 1;
186 if (_eventFd != -1 && ::write(_eventFd, &one, sizeof(one)) == -1) {
187 LS_ERROR(_logger, "Unable to post a wake event: " << getLastError());
188 }
189}
190
191bool Server::startListening(int port) {
192 return startListening(INADDR_ANY, port);
193}
194
195bool Server::startListening(uint32_t hostAddr, int port) {
196 if (_epollFd == -1 || _eventFd == -1) {
197 LS_ERROR(_logger, "Unable to serve, did not initialize properly.");
198 return false;
199 }
200
201 _listenSock = socket(AF_INET, SOCK_STREAM, 0);
202 if (_listenSock == -1) {
203 LS_ERROR(_logger, "Unable to create listen socket: " << getLastError());
204 return false;
205 }
206 if (!configureSocket(_listenSock)) {
207 return false;
208 }
209 sockaddr_in sock;
210 memset(&sock, 0, sizeof(sock));
211 sock.sin_port = htons(port);
212 sock.sin_addr.s_addr = htonl(hostAddr);
213 sock.sin_family = AF_INET;
214 if (bind(_listenSock, reinterpret_cast<const sockaddr*>(&sock), sizeof(sock)) == -1) {
215 LS_ERROR(_logger, "Unable to bind socket: " << getLastError());
216 return false;
217 }
218 if (listen(_listenSock, 5) == -1) {
219 LS_ERROR(_logger, "Unable to listen on socket: " << getLastError());
220 return false;
221 }
222 epoll_event event = { EPOLLIN, { this } };
223 if (epoll_ctl(_epollFd, EPOLL_CTL_ADD, _listenSock, &event) == -1) {
224 LS_ERROR(_logger, "Unable to add listen socket to epoll: " << getLastError());
225 return false;
226 }
227
228 char buf[1024];
229 ::gethostname(buf, sizeof(buf));
230 LS_INFO(_logger, "Listening on http://" << buf << ":" << port << "/");
231
232 return true;
233}
234
235void Server::handlePipe() {
236 uint64_t dummy;
237 while (::read(_eventFd, &dummy, sizeof(dummy)) != -1) {
238 // Spin, draining the pipe until it returns EWOULDBLOCK or similar.
239 }
240 if (errno != EAGAIN || errno != EWOULDBLOCK) {
241 LS_ERROR(_logger, "Error from wakeFd read: " << getLastError());
242 _terminate = true;
243 }
244 // It's a "wake up" event; this will just cause the epoll loop to wake up.
245}
246
247Server::NewState Server::handleConnectionEvents(Connection* connection, uint32_t events) {
248 if (events & ~(EPOLLIN|EPOLLOUT|EPOLLHUP|EPOLLERR)) {
249 LS_WARNING(_logger, "Got unhandled epoll event (" << EventBits(events) << ") on connection: "
250 << formatAddress(connection->getRemoteAddress()));
251 return Close;
252 } else if (events & EPOLLERR) {
253 LS_INFO(_logger, "Error on socket (" << EventBits(events) << "): "
254 << formatAddress(connection->getRemoteAddress()));
255 return Close;
256 } else if (events & EPOLLHUP) {
257 LS_DEBUG(_logger, "Graceful hang-up (" << EventBits(events) << ") of socket: "
258 << formatAddress(connection->getRemoteAddress()));
259 return Close;
260 } else {
261 if (events & EPOLLOUT) {
262 connection->handleDataReadyForWrite();
263 }
264 if (events & EPOLLIN) {
265 connection->handleDataReadyForRead();
266 }
267 }
268 return KeepOpen;
269}
270
271void Server::checkAndDispatchEpoll(int epollMillis) {
272 const int maxEvents = 256;
273 epoll_event events[maxEvents];
274
275 std::list<Connection*> toBeDeleted;
276 int numEvents = epoll_wait(_epollFd, events, maxEvents, epollMillis);
277 if (numEvents == -1) {
278 if (errno != EINTR) {
279 LS_ERROR(_logger, "Error from epoll_wait: " << getLastError());
280 }
281 return;
282 }
283 if (numEvents == maxEvents) {
284 static time_t lastWarnTime = 0;
285 time_t now = time(NULL);
286 if (now - lastWarnTime >= 60) {
287 LS_WARNING(_logger, "Full event queue; may start starving connections. "
288 "Will warn at most once a minute");
289 lastWarnTime = now;
290 }
291 }
292 for (int i = 0; i < numEvents; ++i) {
293 if (events[i].data.ptr == this) {
294 if (events[i].events & ~EPOLLIN) {
295 LS_SEVERE(_logger, "Got unexpected event on listening socket ("
296 << EventBits(events[i].events) << ") - terminating");
297 _terminate = true;
298 break;
299 }
300 handleAccept();
301 } else if (events[i].data.ptr == &_eventFd) {
302 if (events[i].events & ~EPOLLIN) {
303 LS_SEVERE(_logger, "Got unexpected event on management pipe ("
304 << EventBits(events[i].events) << ") - terminating");
305 _terminate = true;
306 break;
307 }
308 handlePipe();
309 } else {
310 auto connection = reinterpret_cast<Connection*>(events[i].data.ptr);
311 if (handleConnectionEvents(connection, events[i].events) == Close) {
312 toBeDeleted.push_back(connection);
313 }
314 }
315 }
316 // The connections are all deleted at the end so we've processed any other subject's
317 // closes etc before we call onDisconnect().
318 for (auto it = toBeDeleted.begin(); it != toBeDeleted.end(); ++it) {
319 auto connection = *it;
320 if (_connections.find(connection) == _connections.end()) {
321 LS_SEVERE(_logger, "Attempt to delete connection we didn't know about: " << (void*)connection
322 << formatAddress(connection->getRemoteAddress()));
323 _terminate = true;
324 break;
325 }
326 LS_DEBUG(_logger, "Deleting connection: " << formatAddress(connection->getRemoteAddress()));
327 delete connection;
328 }
329}
330
331void Server::setStaticPath(const char* staticPath) {
332 LS_INFO(_logger, "Serving content from " << staticPath);
333 _staticPath = staticPath;
334}
335
336bool Server::serve(const char* staticPath, int port) {
337 setStaticPath(staticPath);
338 if (!startListening(port)) {
339 return false;
340 }
341
342 return loop();
343}
344
345bool Server::loop() {
346 if (_listenSock == -1) {
347 LS_ERROR(_logger, "Server not initialised");
348 return false;
349 }
350
351 // Stash away "the" server thread id.
352 _threadId = gettid();
353
354 while (!_terminate) {
355 // Always process events first to catch start up events.
356 processEventQueue();
357 checkAndDispatchEpoll(EpollTimeoutMillis);
358 }
359 // Reasonable effort to ensure anything enqueued during terminate has a chance to run.
360 processEventQueue();
361 LS_INFO(_logger, "Server terminating");
362 shutdown();
363 return _expectedTerminate;
364}
365
366Server::PollResult Server::poll(int millis) {
367 // Grab the thread ID on the first poll.
368 if (_threadId == 0) _threadId = gettid();
369 if (_threadId != gettid()) {
370 LS_ERROR(_logger, "poll() called from the wrong thread");
371 return PollResult::Error;
372 }
373 if (_listenSock == -1) {
374 LS_ERROR(_logger, "Server not initialised");
375 return PollResult::Error;
376 }
377 processEventQueue();
378 checkAndDispatchEpoll(millis);
379 if (!_terminate) return PollResult::Continue;
380
381 // Reasonable effort to ensure anything enqueued during terminate has a chance to run.
382 processEventQueue();
383 LS_INFO(_logger, "Server terminating");
384 shutdown();
385
386 return _expectedTerminate ? PollResult::Terminated : PollResult::Error;
387}
388
389void Server::processEventQueue() {
390 for (;;) {
391 std::shared_ptr<Runnable> runnable = popNextRunnable();
392 if (!runnable) break;
393 runnable->run();
394 }
395 time_t now = time(NULL);
396 if (now >= _nextDeadConnectionCheck) {
397 std::list<Connection*> toRemove;
398 for (auto it = _connections.cbegin(); it != _connections.cend(); ++it) {
399 time_t numSecondsSinceConnection = now - it->second;
400 auto connection = it->first;
401 if (connection->bytesReceived() == 0 && numSecondsSinceConnection >= _lameConnectionTimeoutSeconds) {
402 LS_INFO(_logger, formatAddress(connection->getRemoteAddress())
403 << " : Killing lame connection - no bytes received after " << numSecondsSinceConnection << "s");
404 toRemove.push_back(connection);
405 }
406 }
407 for (auto it = toRemove.begin(); it != toRemove.end(); ++it) {
408 delete *it;
409 }
410 }
411}
412
413void Server::handleAccept() {
414 sockaddr_in address;
415 socklen_t addrLen = sizeof(address);
416 int fd = ::accept(_listenSock,
417 reinterpret_cast<sockaddr*>(&address),
418 &addrLen);
419 if (fd == -1) {
420 LS_ERROR(_logger, "Unable to accept: " << getLastError());
421 return;
422 }
423 if (!configureSocket(fd)) {
424 ::close(fd);
425 return;
426 }
427 LS_INFO(_logger, formatAddress(address) << " : Accepted on descriptor " << fd);
428 Connection* newConnection = new Connection(_logger, *this, fd, address);
429 epoll_event event = { EPOLLIN, { newConnection } };
430 if (epoll_ctl(_epollFd, EPOLL_CTL_ADD, fd, &event) == -1) {
431 LS_ERROR(_logger, "Unable to add socket to epoll: " << getLastError());
432 delete newConnection;
433 ::close(fd);
434 return;
435 }
436 _connections.insert(std::make_pair(newConnection, time(NULL)));
437}
438
439void Server::remove(Connection* connection) {
440 checkThread();
441 epoll_event event = { 0, { connection } };
442 if (epoll_ctl(_epollFd, EPOLL_CTL_DEL, connection->getFd(), &event) == -1) {
443 LS_ERROR(_logger, "Unable to remove from epoll: " << getLastError());
444 }
445 _connections.erase(connection);
446}
447
448bool Server::subscribeToWriteEvents(Connection* connection) {
449 epoll_event event = { EPOLLIN | EPOLLOUT, { connection } };
450 if (epoll_ctl(_epollFd, EPOLL_CTL_MOD, connection->getFd(), &event) == -1) {
451 LS_ERROR(_logger, "Unable to subscribe to write events: " << getLastError());
452 return false;
453 }
454 return true;
455}
456
457bool Server::unsubscribeFromWriteEvents(Connection* connection) {
458 epoll_event event = { EPOLLIN, { connection } };
459 if (epoll_ctl(_epollFd, EPOLL_CTL_MOD, connection->getFd(), &event) == -1) {
460 LS_ERROR(_logger, "Unable to unsubscribe from write events: " << getLastError());
461 return false;
462 }
463 return true;
464}
465
466void Server::addWebSocketHandler(const char* endpoint, std::shared_ptr<WebSocket::Handler> handler,
467 bool allowCrossOriginRequests) {
468 _webSocketHandlerMap[endpoint] = { handler, allowCrossOriginRequests };
469}
470
471void Server::addPageHandler(std::shared_ptr<PageHandler> handler) {
472 _pageHandlers.emplace_back(handler);
473}
474
475bool Server::isCrossOriginAllowed(const std::string &endpoint) const {
476 auto splits = split(endpoint, '?');
477 auto iter = _webSocketHandlerMap.find(splits[0]);
478 if (iter == _webSocketHandlerMap.end()) {
479 return false;
480 }
481 return iter->second.allowCrossOrigin;
482}
483
484std::shared_ptr<WebSocket::Handler> Server::getWebSocketHandler(const char* endpoint) const {
485 auto splits = split(endpoint, '?');
486 auto iter = _webSocketHandlerMap.find(splits[0]);
487 if (iter == _webSocketHandlerMap.end()) {
488 return std::shared_ptr<WebSocket::Handler>();
489 }
490 return iter->second.handler;
491}
492
493void Server::execute(std::shared_ptr<Runnable> runnable) {
494 std::unique_lock<decltype(_pendingRunnableMutex)> lock(_pendingRunnableMutex);
495 _pendingRunnables.push_back(runnable);
496 lock.unlock();
497
498 uint64_t one = 1;
499 if (_eventFd != -1 && ::write(_eventFd, &one, sizeof(one)) == -1) {
500 if (errno != EAGAIN && errno != EWOULDBLOCK) {
501 LS_ERROR(_logger, "Unable to post a wake event: " << getLastError());
502 }
503 }
504}
505
506std::shared_ptr<Server::Runnable> Server::popNextRunnable() {
507 std::lock_guard<decltype(_pendingRunnableMutex)> lock(_pendingRunnableMutex);
508 std::shared_ptr<Runnable> runnable;
509 if (!_pendingRunnables.empty()) {
510 runnable = _pendingRunnables.front();
511 _pendingRunnables.pop_front();
512 }
513 return runnable;
514}
515
516std::string Server::getStatsDocument() const {
517 std::ostringstream doc;
518 doc << "clear();" << std::endl;
519 for (auto it = _connections.begin(); it != _connections.end(); ++it) {
520 doc << "connection({";
521 auto connection = it->first;
522 jsonKeyPairToStream(doc,
523 "since", EpochTimeAsLocal(it->second),
524 "fd", connection->getFd(),
525 "id", reinterpret_cast<uint64_t>(connection),
526 "uri", connection->getRequestUri(),
527 "addr", formatAddress(connection->getRemoteAddress()),
528 "user", connection->credentials() ?
529 connection->credentials()->username : "(not authed)",
530 "input", connection->inputBufferSize(),
531 "read", connection->bytesReceived(),
532 "output", connection->outputBufferSize(),
533 "written", connection->bytesSent()
534 );
535 doc << "});" << std::endl;
536 }
537 return doc.str();
538}
539
540void Server::setLameConnectionTimeoutSeconds(int seconds) {
541 LS_INFO(_logger, "Setting lame connection timeout to " << seconds);
542 _lameConnectionTimeoutSeconds = seconds;
543}
544
545void Server::setMaxKeepAliveDrops(int maxKeepAliveDrops) {
546 LS_INFO(_logger, "Setting max keep alive drops to " << maxKeepAliveDrops);
547 _maxKeepAliveDrops = maxKeepAliveDrops;
548}
549
550void Server::checkThread() const {
551 auto thisTid = gettid();
552 if (thisTid != _threadId) {
553 std::ostringstream o;
554 o << "seasocks called on wrong thread : " << thisTid << " instead of " << _threadId;
555 LS_SEVERE(_logger, o.str());
556 throw std::runtime_error(o.str());
557 }
558}
559
560std::shared_ptr<Response> Server::handle(const Request &request) {
561 for (auto handler : _pageHandlers) {
562 auto result = handler->handle(request);
563 if (result != Response::unhandled()) return result;
564 }
565 return Response::unhandled();
566}
567
568} // namespace seasocks