blob: 9c250a0545cb33787b6b99d76a3e6f9699edfd20 [file] [log] [blame]
Austin Schuh24adb6b2015-09-06 17:37:40 -07001// Copyright (c) 2013, Matt Godbolt
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without
5// modification, are permitted provided that the following conditions are met:
6//
7// Redistributions of source code must retain the above copyright notice, this
8// list of conditions and the following disclaimer.
9//
10// Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13//
14// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
18// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
19// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
20// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
21// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
22// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
23// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
24// POSSIBILITY OF SUCH DAMAGE.
25
26#include "internal/Config.h"
27#include "internal/Embedded.h"
28#include "internal/HeaderMap.h"
29#include "internal/HybiAccept.h"
30#include "internal/HybiPacketDecoder.h"
31#include "internal/LogStream.h"
32#include "internal/PageRequest.h"
33#include "internal/Version.h"
34
35#include "md5/md5.h"
36
37#include "seasocks/Connection.h"
38#include "seasocks/Credentials.h"
39#include "seasocks/Logger.h"
40#include "seasocks/PageHandler.h"
41#include "seasocks/Server.h"
42#include "seasocks/StringUtil.h"
43#include "seasocks/ToString.h"
44
45#include <sys/stat.h>
46#include <sys/types.h>
47
48#include <assert.h>
49#include <ctype.h>
50#include <errno.h>
51#include <fcntl.h>
52#include <fstream>
53#include <iostream>
54#include <limits>
55#include <sstream>
56#include <stdio.h>
57#include <string.h>
58#include <unistd.h>
59#include <unordered_map>
60
61namespace {
62
// Decodes a Hixie-76 "Sec-WebSocket-Key1/2" header value: all decimal digits
// in the key are concatenated into one number, which is then divided by the
// number of space characters in the key. Returns 0 if the key has no spaces
// (instead of dividing by zero). Arithmetic wraps modulo 2^32.
uint32_t parseWebSocketKey(const std::string& key) {
    uint32_t digits = 0;
    uint32_t spaces = 0;
    for (std::string::size_type i = 0; i < key.size(); ++i) {
        const char ch = key[i];
        if (ch == ' ') {
            spaces++;
        } else if (ch >= '0' && ch <= '9') {
            digits = digits * 10 + static_cast<uint32_t>(ch - '0');
        }
    }
    if (spaces == 0) {
        return 0;
    }
    return digits / spaces;
}
75
// Finds the next CRLF-terminated line in [first, last). On success the CR is
// overwritten with a NUL, `first` is advanced past the CRLF, and the start of
// the line is returned as a C string. If `colon` is non-null and *colon is
// still null, *colon receives the address of the first ':' seen before the
// CRLF. Returns null (leaving `first` unchanged) if no CRLF is present.
char* extractLine(uint8_t*& first, uint8_t* last, char** colon = nullptr) {
    for (uint8_t* cursor = first; last - cursor > 1; ++cursor) {
        const bool atCrlf = cursor[0] == '\r' && cursor[1] == '\n';
        if (atCrlf) {
            *cursor = 0;  // terminate the line in place
            char* const line = reinterpret_cast<char*>(first);
            first = cursor + 2;
            return line;
        }
        const bool wantColon = colon != nullptr && *colon == nullptr;
        if (wantColon && *cursor == ':') {
            *colon = reinterpret_cast<char*>(cursor);
        }
    }
    return nullptr;
}
90
// Formats `time` as an HTTP date (RFC 7231 IMF-fixdate), e.g.
// "Wed, 20 Apr 2011 17:31:28 GMT".
// The zone is written as the literal "GMT" that HTTP requires: "%Z" on a
// gmtime_r() result is libc-dependent (some platforms print "UTC").
// Returns an empty string if the time cannot be converted or formatted,
// rather than returning an uninitialised buffer.
// NOTE(review): %a/%b follow LC_TIME; this assumes the process keeps the
// default "C" locale.
std::string webtime(time_t time) {
    struct tm tm;
    if (gmtime_r(&time, &tm) == nullptr) {
        return std::string();
    }
    char buf[64];
    const size_t len = strftime(buf, sizeof(buf), "%a, %d %b %Y %H:%M:%S GMT", &tm);
    return std::string(buf, len);
}
99
100std::string now() {
101 return webtime(time(NULL));
102}
103
// Owns a file descriptor opened read-only from `filename` and closes it on
// destruction. Copying is disabled so the descriptor is closed exactly once.
// ok() reports whether open() succeeded. The object converts implicitly to
// the raw descriptor (-1 on failure) so it can be passed to POSIX calls.
class RaiiFd {
    int _fd;
public:
    // explicit: a bare const char* must not silently become an open file.
    explicit RaiiFd(const char* filename)
        : _fd(::open(filename, O_RDONLY)) {}
    RaiiFd(const RaiiFd&) = delete;
    RaiiFd& operator=(const RaiiFd&) = delete;
    ~RaiiFd() {
        if (_fd != -1) {
            ::close(_fd);
        }
    }
    bool ok() const {
        return _fd != -1;
    }
    operator int() const {
        return _fd;
    }
};
124
// File extension (without the dot) -> MIME type used for Content-Type.
const std::unordered_map<std::string, std::string> contentTypes = {
    // Text and markup.
    { "txt", "text/plain" },
    { "css", "text/css" },
    { "csv", "text/csv" },
    { "htm", "text/html" },
    { "html", "text/html" },
    { "xml", "text/xml" },
    { "js", "text/javascript" }, // Technically it should be application/javascript (RFC 4329), but IE8 struggles with that
    // Application formats.
    { "xhtml", "application/xhtml+xml" },
    { "json", "application/json" },
    { "pdf", "application/pdf" },
    { "zip", "application/zip" },
    { "tar", "application/x-tar" },
    { "swf", "application/x-shockwave-flash" },
    // Images.
    { "gif", "image/gif" },
    { "jpeg", "image/jpeg" },
    { "jpg", "image/jpeg" },
    { "tiff", "image/tiff" },
    { "tif", "image/tiff" },
    { "png", "image/png" },
    { "svg", "image/svg+xml" },
    { "ico", "image/x-icon" },
    // Audio and fonts.
    { "mp3", "audio/mpeg" },
    { "wav", "audio/x-wav" },
    { "ttf", "font/ttf" },
};
151
// Returns the extension of the last path component of `path` (the text after
// its final '.'), or "" if that component contains no '.'. Dots in directory
// names are ignored, so "/a.d/file" has no extension instead of "d/file".
std::string getExt(const std::string& path) {
    const auto lastSlash = path.find_last_of('/');
    const auto nameStart = (lastSlash == std::string::npos) ? 0 : lastSlash + 1;
    const auto lastDot = path.find_last_of('.');
    if (lastDot == std::string::npos || lastDot < nameStart) {
        return "";
    }
    return path.substr(lastDot + 1);
}
159
160const std::string& getContentType(const std::string& path) {
161 auto it = contentTypes.find(getExt(path));
162 if (it != contentTypes.end()) {
163 return it->second;
164 }
165 static const std::string defaultType("text/html");
166 return defaultType;
167}
168
169// Cacheability is only set for resources that *REQUIRE* caching for browser support reasons.
170// It's off for everything else to save on browser reload headaches during development, at
171// least until we support ETags or If-Modified-Since: type checking, which we may never do.
172bool isCacheable(const std::string& path) {
173 std::string extension = getExt(path);
174 if (extension == "mp3" || extension == "wav") {
175 return true;
176 }
177 return false;
178}
179
// Limit on buffered outgoing data (write()) and on a request's Content-Length.
constexpr size_t MaxBufferSize = 16 * 1024 * 1024;
// Chunk size used for socket reads and static-file reads.
constexpr size_t ReadWriteBufferSize = 16 * 1024;
// Limit on unparsed WebSocket input left in the read buffer.
constexpr size_t MaxWebsocketMessageSize = 16 * 1024;
// Limit on the size of the request header block.
constexpr size_t MaxHeadersSize = 64 * 1024;
184
185class PrefixWrapper : public seasocks::Logger {
186 std::string _prefix;
187 std::shared_ptr<Logger> _logger;
188public:
189 PrefixWrapper(const std::string& prefix, std::shared_ptr<Logger> logger)
190 : _prefix(prefix), _logger(logger) {}
191
192 virtual void log(Level level, const char* message) {
193 _logger->log(level, (_prefix + message).c_str());
194 }
195};
196
// Returns true if the comma-separated Connection header value `connection`
// contains the token `type` (e.g. "keep-alive, Upgrade" contains "upgrade").
// Each element is trimmed of surrounding whitespace on both sides and
// compared case-insensitively. Characters are cast to unsigned char before
// the <ctype.h> calls, since passing a negative char is undefined.
bool hasConnectionType(const std::string& connection, const std::string& type) {
    size_t tokenStart = 0;
    while (tokenStart <= connection.size()) {
        size_t tokenEnd = connection.find(',', tokenStart);
        if (tokenEnd == std::string::npos) {
            tokenEnd = connection.size();
        }
        size_t begin = tokenStart;
        size_t end = tokenEnd;
        while (begin < end && isspace(static_cast<unsigned char>(connection[begin]))) {
            ++begin;
        }
        while (end > begin && isspace(static_cast<unsigned char>(connection[end - 1]))) {
            --end;
        }
        if (end - begin == type.size()) {
            bool same = true;
            for (size_t i = 0; same && i < type.size(); ++i) {
                same = tolower(static_cast<unsigned char>(connection[begin + i]))
                    == tolower(static_cast<unsigned char>(type[i]));
            }
            if (same) {
                return true;
            }
        }
        tokenStart = tokenEnd + 1;
    }
    return false;
}
206
207} // namespace
208
209namespace seasocks {
210
211Connection::Connection(
212 std::shared_ptr<Logger> logger,
213 ServerImpl& server,
214 int fd,
215 const sockaddr_in& address)
216 : _logger(new PrefixWrapper(formatAddress(address) + " : ", logger)),
217 _server(server),
218 _fd(fd),
219 _shutdown(false),
220 _hadSendError(false),
221 _closeOnEmpty(false),
222 _registeredForWriteEvents(false),
223 _address(address),
224 _bytesSent(0),
225 _bytesReceived(0),
226 _shutdownByUser(false),
227 _state(READING_HEADERS) {
228}
229
230Connection::~Connection() {
231 _server.checkThread();
232 finalise();
233}
234
235void Connection::close() {
236 // This is the user-side close requests ONLY! You should Call closeInternal
237 _shutdownByUser = true;
238 closeInternal();
239}
240
241void Connection::closeWhenEmpty() {
242 if (_outBuf.empty()) {
243 closeInternal();
244 } else {
245 _closeOnEmpty = true;
246 }
247}
248
249void Connection::closeInternal() {
250 // It only actually only calls shutdown on the socket,
251 // leaving the close of the FD and the cleanup until the destructor runs.
252 _server.checkThread();
253 if (_fd != -1 && !_shutdown && ::shutdown(_fd, SHUT_RDWR) == -1) {
254 LS_WARNING(_logger, "Unable to shutdown socket : " << getLastError());
255 }
256 _shutdown = true;
257}
258
259
260void Connection::finalise() {
261 if (_webSocketHandler) {
262 _webSocketHandler->onDisconnect(this);
263 _webSocketHandler.reset();
264 }
265 if (_fd != -1) {
266 _server.remove(this);
267 LS_DEBUG(_logger, "Closing socket");
268 ::close(_fd);
269 }
270 _fd = -1;
271}
272
273int Connection::safeSend(const void* data, size_t size) {
274 if (_fd == -1 || _hadSendError || _shutdown) {
275 // Ignore further writes to the socket, it's already closed or has been shutdown
276 return -1;
277 }
278 int sendResult = ::send(_fd, data, size, MSG_NOSIGNAL);
279 if (sendResult == -1) {
280 if (errno == EAGAIN || errno == EWOULDBLOCK) {
281 // Treat this as if zero bytes were written.
282 return 0;
283 }
284 LS_WARNING(_logger, "Unable to write to socket : " << getLastError() << " - disabling further writes");
285 closeInternal();
286 } else {
287 _bytesSent += sendResult;
288 }
289 return sendResult;
290}
291
292bool Connection::write(const void* data, size_t size, bool flushIt) {
293 if (closed() || _closeOnEmpty) {
294 return false;
295 }
296 if (size) {
297 int bytesSent = 0;
298 if (_outBuf.empty() && flushIt) {
299 // Attempt fast path, send directly.
300 bytesSent = safeSend(data, size);
301 if (bytesSent == static_cast<int>(size)) {
302 // We sent directly.
303 return true;
304 }
305 if (bytesSent == -1) {
306 return false;
307 }
308 }
309 size_t bytesToBuffer = size - bytesSent;
310 size_t endOfBuffer = _outBuf.size();
311 size_t newBufferSize = endOfBuffer + bytesToBuffer;
312 if (newBufferSize >= MaxBufferSize) {
313 LS_WARNING(_logger, "Closing connection: buffer size too large ("
314 << newBufferSize << " >= " << MaxBufferSize << ")");
315 closeInternal();
316 return false;
317 }
318 _outBuf.resize(newBufferSize);
319 memcpy(&_outBuf[endOfBuffer], reinterpret_cast<const uint8_t*>(data) + bytesSent, bytesToBuffer);
320 }
321 if (flushIt) {
322 return flush();
323 }
324 return true;
325}
326
327bool Connection::bufferLine(const char* line) {
328 static const char crlf[] = { '\r', '\n' };
329 if (!write(line, strlen(line), false)) return false;
330 return write(crlf, 2, false);
331}
332
333bool Connection::bufferLine(const std::string& line) {
334 std::string lineAndCrlf = line + "\r\n";
335 return write(lineAndCrlf.c_str(), lineAndCrlf.length(), false);
336}
337
338void Connection::handleDataReadyForRead() {
339 if (closed()) {
340 return;
341 }
342 size_t curSize = _inBuf.size();
343 _inBuf.resize(curSize + ReadWriteBufferSize);
344 int result = ::read(_fd, &_inBuf[curSize], ReadWriteBufferSize);
345 if (result == -1) {
346 LS_WARNING(_logger, "Unable to read from socket : " << getLastError());
347 return;
348 }
349 if (result == 0) {
350 LS_DEBUG(_logger, "Remote end closed connection");
351 closeInternal();
352 return;
353 }
354 _bytesReceived += result;
355 _inBuf.resize(curSize + result);
356 handleNewData();
357}
358
359void Connection::handleDataReadyForWrite() {
360 if (closed()) {
361 return;
362 }
363 flush();
364}
365
366bool Connection::flush() {
367 if (_outBuf.empty()) {
368 return true;
369 }
370 int numSent = safeSend(&_outBuf[0], _outBuf.size());
371 if (numSent == -1) {
372 return false;
373 }
374 _outBuf.erase(_outBuf.begin(), _outBuf.begin() + numSent);
375 if (_outBuf.size() > 0 && !_registeredForWriteEvents) {
376 if (!_server.subscribeToWriteEvents(this)) {
377 return false;
378 }
379 _registeredForWriteEvents = true;
380 } else if (_outBuf.empty() && _registeredForWriteEvents) {
381 if (!_server.unsubscribeFromWriteEvents(this)) {
382 return false;
383 }
384 _registeredForWriteEvents = false;
385 }
386 if (_outBuf.empty() && !closed() && _closeOnEmpty) {
387 LS_DEBUG(_logger, "Ready for close, now empty");
388 closeInternal();
389 }
390 return true;
391}
392
393bool Connection::closed() const {
394 return _fd == -1 || _shutdown;
395}
396
397void Connection::handleNewData() {
398 switch (_state) {
399 case READING_HEADERS:
400 handleHeaders();
401 break;
402 case READING_WEBSOCKET_KEY3:
403 handleWebSocketKey3();
404 break;
405 case HANDLING_HIXIE_WEBSOCKET:
406 handleHixieWebSocket();
407 break;
408 case HANDLING_HYBI_WEBSOCKET:
409 handleHybiWebSocket();
410 break;
411 case BUFFERING_POST_DATA:
412 handleBufferingPostData();
413 break;
414 default:
415 assert(false);
416 break;
417 }
418}
419
420void Connection::handleHeaders() {
421 if (_inBuf.size() < 4) {
422 return;
423 }
424 for (size_t i = 0; i <= _inBuf.size() - 4; ++i) {
425 if (_inBuf[i] == '\r' &&
426 _inBuf[i+1] == '\n' &&
427 _inBuf[i+2] == '\r' &&
428 _inBuf[i+3] == '\n') {
429 if (!processHeaders(&_inBuf[0], &_inBuf[i + 2])) {
430 closeInternal();
431 return;
432 }
433 _inBuf.erase(_inBuf.begin(), _inBuf.begin() + i + 4);
434 handleNewData();
435 return;
436 }
437 }
438 if (_inBuf.size() > MaxHeadersSize) {
439 sendUnsupportedError("Headers too big");
440 }
441}
442
443void Connection::handleWebSocketKey3() {
444 constexpr auto WebSocketKeyLen = 8u;
445 if (_inBuf.size() < WebSocketKeyLen) {
446 return;
447 }
448
449 struct {
450 uint32_t key1;
451 uint32_t key2;
452 char key3[WebSocketKeyLen];
453 } md5Source;
454
455 auto key1 = parseWebSocketKey(_request->getHeader("Sec-WebSocket-Key1"));
456 auto key2 = parseWebSocketKey(_request->getHeader("Sec-WebSocket-Key2"));
457
458 LS_DEBUG(_logger, "Got a hixie websocket with key1=0x" << std::hex << key1 << ", key2=0x" << key2);
459
460 md5Source.key1 = htonl(key1);
461 md5Source.key2 = htonl(key2);
462 memcpy(&md5Source.key3, &_inBuf[0], WebSocketKeyLen);
463
464 uint8_t digest[16];
465 md5_state_t md5state;
466 md5_init(&md5state);
467 md5_append(&md5state, reinterpret_cast<const uint8_t*>(&md5Source), sizeof(md5Source));
468 md5_finish(&md5state, digest);
469
470 LS_DEBUG(_logger, "Attempting websocket upgrade");
471
472 bufferResponseAndCommonHeaders(ResponseCode::WebSocketProtocolHandshake);
473 bufferLine("Upgrade: websocket");
474 bufferLine("Connection: Upgrade");
475 bool allowCrossOrigin = _server.isCrossOriginAllowed(_request->getRequestUri());
476 if (_request->hasHeader("Origin") && allowCrossOrigin) {
477 bufferLine("Sec-WebSocket-Origin: " + _request->getHeader("Origin"));
478 }
479 if (_request->hasHeader("Host")) {
480 auto host = _request->getHeader("Host");
481 if (!allowCrossOrigin) {
482 bufferLine("Sec-WebSocket-Origin: http://" + host);
483 }
484 bufferLine("Sec-WebSocket-Location: ws://" + host + _request->getRequestUri());
485 }
486 bufferLine("");
487
488 write(&digest, 16, true);
489
490 _state = HANDLING_HIXIE_WEBSOCKET;
491 _inBuf.erase(_inBuf.begin(), _inBuf.begin() + 8);
492 if (_webSocketHandler) {
493 _webSocketHandler->onConnect(this);
494 }
495}
496
497void Connection::handleBufferingPostData() {
498 if (_request->consumeContent(_inBuf)) {
499 _state = READING_HEADERS;
500 if (!handlePageRequest()) {
501 closeInternal();
502 }
503 }
504}
505
506void Connection::send(const char* webSocketResponse) {
507 _server.checkThread();
508 if (_shutdown) {
509 if (_shutdownByUser) {
510 LS_ERROR(_logger, "Server wrote to connection after closing it");
511 }
512 return;
513 }
514 auto messageLength = strlen(webSocketResponse);
515 if (_state == HANDLING_HIXIE_WEBSOCKET) {
516 uint8_t zero = 0;
517 if (!write(&zero, 1, false)) return;
518 if (!write(webSocketResponse, messageLength, false)) return;
519 uint8_t effeff = 0xff;
520 write(&effeff, 1, true);
521 return;
522 }
523 sendHybi(HybiPacketDecoder::OPCODE_TEXT, reinterpret_cast<const uint8_t*>(webSocketResponse), messageLength);
524}
525
526void Connection::send(const uint8_t* data, size_t length) {
527 _server.checkThread();
528 if (_shutdown) {
529 if (_shutdownByUser) {
530 LS_ERROR(_logger, "Client wrote to connection after closing it");
531 }
532 return;
533 }
534 if (_state == HANDLING_HIXIE_WEBSOCKET) {
535 LS_ERROR(_logger, "Hixie does not support binary");
536 return;
537 }
538 sendHybi(HybiPacketDecoder::OPCODE_BINARY, data, length);
539}
540
541void Connection::sendHybi(int opcode, const uint8_t* webSocketResponse, size_t messageLength) {
542 uint8_t firstByte = 0x80 | opcode;
543 if (!write(&firstByte, 1, false)) return;
544 if (messageLength < 126) {
545 uint8_t nextByte = messageLength; // No MASK bit set.
546 if (!write(&nextByte, 1, false)) return;
547 } else if (messageLength < 65536) {
548 uint8_t nextByte = 126; // No MASK bit set.
549 if (!write(&nextByte, 1, false)) return;
550 auto lengthBytes = htons(messageLength);
551 if (!write(&lengthBytes, 2, false)) return;
552 } else {
553 uint8_t nextByte = 127; // No MASK bit set.
554 if (!write(&nextByte, 1, false)) return;
555 uint64_t lengthBytes = __bswap_64(messageLength);
556 if (!write(&lengthBytes, 8, false)) return;
557 }
558 write(webSocketResponse, messageLength, true);
559}
560
561std::shared_ptr<Credentials> Connection::credentials() const {
562 _server.checkThread();
563 return _request ? _request->credentials() : std::shared_ptr<Credentials>();
564}
565
566void Connection::handleHixieWebSocket() {
567 if (_inBuf.empty()) {
568 return;
569 }
570 size_t messageStart = 0;
571 while (messageStart < _inBuf.size()) {
572 if (_inBuf[messageStart] != 0) {
573 LS_WARNING(_logger, "Error in WebSocket input stream (got " << (int)_inBuf[messageStart] << ")");
574 closeInternal();
575 return;
576 }
577 // TODO: UTF-8
578 size_t endOfMessage = 0;
579 for (size_t i = messageStart + 1; i < _inBuf.size(); ++i) {
580 if (_inBuf[i] == 0xff) {
581 endOfMessage = i;
582 break;
583 }
584 }
585 if (endOfMessage != 0) {
586 _inBuf[endOfMessage] = 0;
587 handleWebSocketTextMessage(reinterpret_cast<const char*>(&_inBuf[messageStart + 1]));
588 messageStart = endOfMessage + 1;
589 } else {
590 break;
591 }
592 }
593 if (messageStart != 0) {
594 _inBuf.erase(_inBuf.begin(), _inBuf.begin() + messageStart);
595 }
596 if (_inBuf.size() > MaxWebsocketMessageSize) {
597 LS_WARNING(_logger, "WebSocket message too long");
598 closeInternal();
599 }
600}
601
602void Connection::handleHybiWebSocket() {
603 if (_inBuf.empty()) {
604 return;
605 }
606 HybiPacketDecoder decoder(*_logger, _inBuf);
607 bool done = false;
608 while (!done) {
609 std::vector<uint8_t> decodedMessage;
610 switch (decoder.decodeNextMessage(decodedMessage)) {
611 default:
612 closeInternal();
613 LS_WARNING(_logger, "Unknown HybiPacketDecoder state");
614 return;
615 case HybiPacketDecoder::Error:
616 closeInternal();
617 return;
618 case HybiPacketDecoder::TextMessage:
619 decodedMessage.push_back(0); // avoids a copy
620 handleWebSocketTextMessage(reinterpret_cast<const char*>(&decodedMessage[0]));
621 break;
622 case HybiPacketDecoder::BinaryMessage:
623 handleWebSocketBinaryMessage(decodedMessage);
624 break;
625 case HybiPacketDecoder::Ping:
626 sendHybi(HybiPacketDecoder::OPCODE_PONG, &decodedMessage[0], decodedMessage.size());
627 break;
628 case HybiPacketDecoder::NoMessage:
629 done = true;
630 break;
631 case HybiPacketDecoder::Close:
632 LS_DEBUG(_logger, "Received WebSocket close");
633 closeInternal();
634 return;
635 }
636 }
637 if (decoder.numBytesDecoded() != 0) {
638 _inBuf.erase(_inBuf.begin(), _inBuf.begin() + decoder.numBytesDecoded());
639 }
640 if (_inBuf.size() > MaxWebsocketMessageSize) {
641 LS_WARNING(_logger, "WebSocket message too long");
642 closeInternal();
643 }
644}
645
646void Connection::handleWebSocketTextMessage(const char* message) {
647 LS_DEBUG(_logger, "Got text web socket message: '" << message << "'");
648 if (_webSocketHandler) {
649 _webSocketHandler->onData(this, message);
650 }
651}
652
653void Connection::handleWebSocketBinaryMessage(const std::vector<uint8_t>& message) {
654 LS_DEBUG(_logger, "Got binary web socket message (size: " << message.size() << ")");
655 if (_webSocketHandler) {
656 _webSocketHandler->onData(this, &message[0], message.size());
657 }
658}
659
660bool Connection::sendError(ResponseCode errorCode, const std::string& body) {
661 assert(_state != HANDLING_HIXIE_WEBSOCKET);
662 auto errorNumber = static_cast<int>(errorCode);
663 auto message = ::name(errorCode);
664 bufferResponseAndCommonHeaders(errorCode);
665 auto errorContent = findEmbeddedContent("/_error.html");
666 std::string document;
667 if (errorContent) {
668 document.assign(errorContent->data, errorContent->data + errorContent->length);
669 replace(document, "%%ERRORCODE%%", toString(errorNumber));
670 replace(document, "%%MESSAGE%%", message);
671 replace(document, "%%BODY%%", body);
672 } else {
673 std::stringstream documentStr;
674 documentStr << "<html><head><title>" << errorNumber << " - " << message << "</title></head>"
675 << "<body><h1>" << errorNumber << " - " << message << "</h1>"
676 << "<div>" << body << "</div><hr/><div><i>Powered by seasocks</i></div></body></html>";
677 document = documentStr.str();
678 }
679 bufferLine("Content-Length: " + toString(document.length()));
680 bufferLine("Connection: close");
681 bufferLine("");
682 bufferLine(document);
683 if (!flush()) {
684 return false;
685 }
686 closeWhenEmpty();
687 return true;
688}
689
690bool Connection::sendUnsupportedError(const std::string& reason) {
691 return sendError(ResponseCode::NotImplemented, reason);
692}
693
694bool Connection::send404() {
695 auto path = getRequestUri();
696 auto embedded = findEmbeddedContent(path);
697 if (embedded) {
698 return sendData(getContentType(path), embedded->data, embedded->length);
699 } else if (strcmp(path.c_str(), "/_livestats.js") == 0) {
700 auto stats = _server.getStatsDocument();
701 return sendData("text/javascript", stats.c_str(), stats.length());
702 } else {
703 return sendError(ResponseCode::NotFound, "Unable to find resource for: " + path);
704 }
705}
706
707bool Connection::sendBadRequest(const std::string& reason) {
708 return sendError(ResponseCode::BadRequest, reason);
709}
710
711bool Connection::sendISE(const std::string& error) {
712 return sendError(ResponseCode::InternalServerError, error);
713}
714
715bool Connection::processHeaders(uint8_t* first, uint8_t* last) {
716 // Ideally we'd copy off [first, last] now into a header structure here.
717 // Be careful about lifetimes though and multiple requests coming in, should
718 // we ever support HTTP pipelining and/or long-lived requests.
719 char* requestLine = extractLine(first, last);
720 assert(requestLine != NULL);
721
722 LS_ACCESS(_logger, "Request: " << requestLine);
723
724 const char* verbText = shift(requestLine);
725 if (!verbText) {
726 return sendBadRequest("Malformed request line");
727 }
728 auto verb = Request::verb(verbText);
729 if (verb == Request::Verb::Invalid) {
730 return sendBadRequest("Malformed request line");
731 }
732 const char* requestUri = shift(requestLine);
733 if (requestUri == NULL) {
734 return sendBadRequest("Malformed request line");
735 }
736
737 const char* httpVersion = shift(requestLine);
738 if (httpVersion == NULL) {
739 return sendBadRequest("Malformed request line");
740 }
741 if (strcmp(httpVersion, "HTTP/1.1") != 0) {
742 return sendUnsupportedError("Unsupported HTTP version");
743 }
744 if (*requestLine != 0) {
745 return sendBadRequest("Trailing crap after http version");
746 }
747
748 HeaderMap headers(31);
749 while (first < last) {
750 char* colonPos = NULL;
751 char* headerLine = extractLine(first, last, &colonPos);
752 assert(headerLine != NULL);
753 if (colonPos == NULL) {
754 return sendBadRequest("Malformed header");
755 }
756 *colonPos = 0;
757 const char* key = headerLine;
758 const char* value = skipWhitespace(colonPos + 1);
759 LS_DEBUG(_logger, "Key: " << key << " || " << value);
760#if HAVE_UNORDERED_MAP_EMPLACE
761 headers.emplace(key, value);
762#else
763 headers.insert(std::make_pair(key, value));
764#endif
765 }
766
767 if (headers.count("Connection") && headers.count("Upgrade")
768 && hasConnectionType(headers["Connection"], "Upgrade")
769 && caseInsensitiveSame(headers["Upgrade"], "websocket")) {
770 LS_INFO(_logger, "Websocket request for " << requestUri << "'");
771 if (verb != Request::Verb::Get) {
772 return sendBadRequest("Non-GET WebSocket request");
773 }
774 _webSocketHandler = _server.getWebSocketHandler(requestUri);
775 if (!_webSocketHandler) {
776 LS_WARNING(_logger, "Couldn't find WebSocket end point for '" << requestUri << "'");
777 return send404();
778 }
779 verb = Request::Verb::WebSocket;
780 }
781
782 _request.reset(new PageRequest(_address, requestUri, verb, std::move(headers)));
783
784 const EmbeddedContent *embedded = findEmbeddedContent(requestUri);
785 if (verb == Request::Verb::Get && embedded) {
786 // MRG: one day, this could be a request handler.
787 return sendData(getContentType(requestUri), embedded->data, embedded->length);
788 }
789
790 if (_request->contentLength() > MaxBufferSize) {
791 return sendBadRequest("Content length too long");
792 }
793 if (_request->contentLength() == 0) {
794 return handlePageRequest();
795 }
796 _state = BUFFERING_POST_DATA;
797 return true;
798}
799
800bool Connection::handlePageRequest() {
801 std::shared_ptr<Response> response;
802 try {
803 response = _server.handle(*_request);
804 } catch (const std::exception& e) {
805 LS_ERROR(_logger, "page error: " << e.what());
806 return sendISE(e.what());
807 } catch (...) {
808 LS_ERROR(_logger, "page error: (unknown)");
809 return sendISE("(unknown)");
810 }
811 auto uri = _request->getRequestUri();
812 if (!response && _request->verb() == Request::Verb::WebSocket) {
813 _webSocketHandler = _server.getWebSocketHandler(uri.c_str());
814 auto webSocketVersion = atoi(_request->getHeader("Sec-WebSocket-Version").c_str());
815 if (!_webSocketHandler) {
816 LS_WARNING(_logger, "Couldn't find WebSocket end point for '" << uri << "'");
817 return send404();
818 }
819 if (webSocketVersion == 0) {
820 // Hixie
821 _state = READING_WEBSOCKET_KEY3;
822 return true;
823 }
824 auto hybiKey = _request->getHeader("Sec-WebSocket-Key");
825 return handleHybiHandshake(webSocketVersion, hybiKey);
826 }
827 return sendResponse(response);
828}
829
830bool Connection::sendResponse(std::shared_ptr<Response> response) {
831 const auto requestUri = _request->getRequestUri();
832 if (response == Response::unhandled()) {
833 return sendStaticData();
834 }
835 if (response->responseCode() == ResponseCode::NotFound) {
836 // TODO: better here; we use this purely to serve our own embedded content.
837 return send404();
838 } else if (!isOk(response->responseCode())) {
839 return sendError(response->responseCode(), response->payload());
840 }
841
842 bufferResponseAndCommonHeaders(response->responseCode());
843 bufferLine("Content-Length: " + toString(response->payloadSize()));
844 bufferLine("Content-Type: " + response->contentType());
845 if (response->keepConnectionAlive()) {
846 bufferLine("Connection: keep-alive");
847 } else {
848 bufferLine("Connection: close");
849 }
850 bufferLine("Last-Modified: " + now());
851 bufferLine("Cache-Control: no-store");
852 bufferLine("Pragma: no-cache");
853 bufferLine("Expires: " + now());
854 auto headers = response->getAdditionalHeaders();
855 for (auto it = headers.begin(); it != headers.end(); ++it) {
856 bufferLine(it->first + ": " + it->second);
857 }
858 bufferLine("");
859
860 if (!write(response->payload(), response->payloadSize(), true)) {
861 return false;
862 }
863 if (!response->keepConnectionAlive()) {
864 closeWhenEmpty();
865 }
866 return true;
867}
868
869bool Connection::handleHybiHandshake(
870 int webSocketVersion,
871 const std::string& webSocketKey) {
872 if (webSocketVersion != 8 && webSocketVersion != 13) {
873 return sendBadRequest("Invalid websocket version");
874 }
875 LS_DEBUG(_logger, "Got a hybi-8 websocket with key=" << webSocketKey);
876
877 LS_DEBUG(_logger, "Attempting websocket upgrade");
878
879 bufferResponseAndCommonHeaders(ResponseCode::WebSocketProtocolHandshake);
880 bufferLine("Upgrade: websocket");
881 bufferLine("Connection: Upgrade");
882 bufferLine("Sec-WebSocket-Accept: " + getAcceptKey(webSocketKey));
883 bufferLine("");
884 flush();
885
886 if (_webSocketHandler) {
887 _webSocketHandler->onConnect(this);
888 }
889 _state = HANDLING_HYBI_WEBSOCKET;
890 return true;
891}
892
893bool Connection::parseRange(const std::string& rangeStr, Range& range) const {
894 size_t minusPos = rangeStr.find('-');
895 if (minusPos == std::string::npos) {
896 LS_WARNING(_logger, "Bad range: '" << rangeStr << "'");
897 return false;
898 }
899 if (minusPos == 0) {
900 // A range like "-500" means 500 bytes from end of file to end.
901 range.start = atoi(rangeStr.c_str());
902 range.end = std::numeric_limits<long>::max();
903 return true;
904 } else {
905 range.start = atoi(rangeStr.substr(0, minusPos).c_str());
906 if (minusPos == rangeStr.size()-1) {
907 range.end = std::numeric_limits<long>::max();
908 } else {
909 range.end = atoi(rangeStr.substr(minusPos + 1).c_str());
910 }
911 return true;
912 }
913 return false;
914}
915
916bool Connection::parseRanges(const std::string& range, std::list<Range>& ranges) const {
917 static const std::string expectedPrefix = "bytes=";
918 if (range.length() < expectedPrefix.length() || range.substr(0, expectedPrefix.length()) != expectedPrefix) {
919 LS_WARNING(_logger, "Bad range request prefix: '" << range << "'");
920 return false;
921 }
922 auto rangesText = split(range.substr(expectedPrefix.length()), ',');
923 for (auto it = rangesText.begin(); it != rangesText.end(); ++it) {
924 Range r;
925 if (!parseRange(*it, r)) {
926 return false;
927 }
928 ranges.push_back(r);
929 }
930 return !ranges.empty();
931}
932
933// Sends HTTP 200 or 206, content-length, and range info as needed. Returns the actual file ranges
934// needing sending.
935std::list<Connection::Range> Connection::processRangesForStaticData(const std::list<Range>& origRanges, long fileSize) {
936 if (origRanges.empty()) {
937 // Easy case: a non-range request.
938 bufferResponseAndCommonHeaders(ResponseCode::Ok);
939 bufferLine("Content-Length: " + toString(fileSize));
940 return { Range { 0, fileSize - 1 } };
941 }
942
943 // Partial content request.
944 bufferResponseAndCommonHeaders(ResponseCode::PartialContent);
945 int contentLength = 0;
946 std::ostringstream rangeLine;
947 rangeLine << "Content-Range: bytes ";
948 std::list<Range> sendRanges;
949 for (auto rangeIter = origRanges.cbegin(); rangeIter != origRanges.cend(); ++rangeIter) {
950 Range actualRange = *rangeIter;
951 if (actualRange.start < 0) {
952 actualRange.start += fileSize;
953 }
954 if (actualRange.start >= fileSize) {
955 actualRange.start = fileSize - 1;
956 }
957 if (actualRange.end >= fileSize) {
958 actualRange.end = fileSize - 1;
959 }
960 contentLength += actualRange.length();
961 sendRanges.push_back(actualRange);
962 rangeLine << actualRange.start << "-" << actualRange.end;
963 }
964 rangeLine << "/" << fileSize;
965 bufferLine(rangeLine.str());
966 bufferLine("Content-Length: " + toString(contentLength));
967 return sendRanges;
968}
969
970bool Connection::sendStaticData() {
971 // TODO: fold this into the handler way of doing things.
972 std::string path = _server.getStaticPath() + getRequestUri();
973 auto rangeHeader = getHeader("Range");
974 // Trim any trailing queries.
975 size_t queryPos = path.find('?');
976 if (queryPos != path.npos) {
977 path.resize(queryPos);
978 }
979 if (*path.rbegin() == '/') {
980 path += "index.html";
981 }
982 RaiiFd input(path.c_str());
983 struct stat stat;
984 if (!input.ok() || ::fstat(input, &stat) == -1) {
985 return send404();
986 }
987 std::list<Range> ranges;
988 if (!rangeHeader.empty() && !parseRanges(rangeHeader, ranges)) {
989 return sendBadRequest("Bad range header");
990 }
991 ranges = processRangesForStaticData(ranges, stat.st_size);
992 bufferLine("Content-Type: " + getContentType(path));
993 bufferLine("Connection: keep-alive");
994 bufferLine("Accept-Ranges: bytes");
995 bufferLine("Last-Modified: " + webtime(stat.st_mtime));
996 if (!isCacheable(path)) {
997 bufferLine("Cache-Control: no-store");
998 bufferLine("Pragma: no-cache");
999 bufferLine("Expires: " + now());
1000 }
1001 bufferLine("");
1002 if (!flush()) {
1003 return false;
1004 }
1005
1006 for (auto rangeIter = ranges.cbegin(); rangeIter != ranges.cend(); ++rangeIter) {
1007 if (::lseek(input, rangeIter->start, SEEK_SET) == -1) {
1008 // We've (probably) already sent data.
1009 return false;
1010 }
1011 auto bytesLeft = rangeIter->length();
1012 while (bytesLeft) {
1013 char buf[ReadWriteBufferSize];
1014 auto bytesRead = ::read(input, buf, std::min(sizeof(buf), bytesLeft));
1015 if (bytesRead <= 0) {
1016 const static std::string unexpectedEof("Unexpected EOF");
1017 LS_ERROR(_logger, "Error reading file: " << (bytesRead == 0 ? unexpectedEof : getLastError()));
1018 // We can't send an error document as we've sent the header.
1019 return false;
1020 }
1021 bytesLeft -= bytesRead;
1022 if (!write(buf, bytesRead, true)) {
1023 return false;
1024 }
1025 }
1026 }
1027 return true;
1028}
1029
1030bool Connection::sendData(const std::string& type, const char* start, size_t size) {
1031 bufferResponseAndCommonHeaders(ResponseCode::Ok);
1032 bufferLine("Content-Type: " + type);
1033 bufferLine("Content-Length: " + toString(size));
1034 bufferLine("Connection: keep-alive");
1035 bufferLine("");
1036 bool result = write(start, size, true);
1037 return result;
1038}
1039
1040void Connection::bufferResponseAndCommonHeaders(ResponseCode code) {
1041 auto responseCodeInt = static_cast<int>(code);
1042 auto responseCodeName = ::name(code);
1043 auto response = std::string("HTTP/1.1 " + toString(responseCodeInt) + " " + responseCodeName);
1044 LS_ACCESS(_logger, "Response: " << response);
1045 bufferLine(response);
1046 bufferLine("Server: " SEASOCKS_VERSION_STRING);
1047 bufferLine("Date: " + now());
1048 bufferLine("Access-Control-Allow-Origin: *");
1049}
1050
1051void Connection::setLinger() {
1052 if (_fd == -1) {
1053 return;
1054 }
1055 const int secondsToLinger = 1;
1056 struct linger linger = { true, secondsToLinger };
1057 if (::setsockopt(_fd, SOL_SOCKET, SO_LINGER, &linger, sizeof(linger)) == -1) {
1058 LS_INFO(_logger, "Unable to set linger on socket");
1059 }
1060}
1061
1062bool Connection::hasHeader(const std::string& header) const {
1063 return _request ? _request->hasHeader(header) : false;
1064}
1065
1066std::string Connection::getHeader(const std::string& header) const {
1067 return _request ? _request->getHeader(header) : "";
1068}
1069
1070const std::string& Connection::getRequestUri() const {
1071 static const std::string empty;
1072 return _request ? _request->getRequestUri() : empty;
1073}
1074
1075} // seasocks