James Kuszmaul | b13e13f | 2023-11-22 20:44:04 -0800 | [diff] [blame^] | 1 | // Copyright (c) FIRST and other WPILib contributors. |
| 2 | // Open Source Software; you can modify and/or share it under the terms of |
| 3 | // the WPILib BSD license file in the root directory of this project. |
| 4 | |
| 5 | #pragma once |
| 6 | |
| 7 | #include <functional> |
| 8 | #include <memory> |
| 9 | #include <utility> |
| 10 | |
| 11 | #include <wpi/SmallVector.h> |
| 12 | #include <wpi/SpanExtras.h> |
| 13 | |
| 14 | #include "WebSocketDebug.h" |
| 15 | #include "wpinet/WebSocket.h" |
| 16 | #include "wpinet/uv/Buffer.h" |
| 17 | |
| 18 | namespace wpi::detail { |
| 19 | |
| 20 | class SerializedFrames { |
| 21 | public: |
| 22 | SerializedFrames() = default; |
| 23 | SerializedFrames(const SerializedFrames&) = delete; |
| 24 | SerializedFrames& operator=(const SerializedFrames&) = delete; |
| 25 | ~SerializedFrames() { ReleaseBufs(); } |
| 26 | |
| 27 | size_t AddFrame(const WebSocket::Frame& frame, bool server) { |
| 28 | if (server) { |
| 29 | return AddServerFrame(frame); |
| 30 | } else { |
| 31 | return AddClientFrame(frame); |
| 32 | } |
| 33 | } |
| 34 | |
| 35 | size_t AddClientFrame(const WebSocket::Frame& frame); |
| 36 | size_t AddServerFrame(const WebSocket::Frame& frame); |
| 37 | |
| 38 | void ReleaseBufs() { |
| 39 | for (auto&& buf : m_allocBufs) { |
| 40 | buf.Deallocate(); |
| 41 | } |
| 42 | m_allocBufs.clear(); |
| 43 | } |
| 44 | |
| 45 | SmallVector<uv::Buffer, 4> m_allocBufs; |
| 46 | SmallVector<uv::Buffer, 4> m_bufs; |
| 47 | size_t m_allocBufPos = 0; |
| 48 | }; |
| 49 | |
/**
 * State carried by a write request that may need to keep sending frames after
 * its initial write completes.
 *
 * TrySendFrames() fills in these members when a non-blocking write only
 * partially succeeds; Continue() then consumes them.
 */
class WebSocketWriteReqBase {
 public:
  /**
   * Tries to send the remaining continuation buffers held in m_frames.
   *
   * @param stream stream providing TryWrite() and Write()
   * @param req request object passed through to stream.Write()
   * @return 0 if nothing more needs to be written, 1 if a Write() was issued,
   *         or the negative value returned by TryWrite() on error
   */
  template <typename Stream, typename Req>
  int Continue(Stream& stream, std::shared_ptr<Req> req);

  // User-provided data buffers of the frames this request has taken over.
  SmallVector<uv::Buffer, 4> m_userBufs;
  // Serialized buffers; m_frames.m_bufs holds the data Continue() still has
  // to send.
  SerializedFrames m_frames;
  // Cumulative byte offset at the end of each continuation frame.
  SmallVector<int, 0> m_continueFrameOffs;
  // Index into m_frames.m_bufs of the first buffer not yet handed to the
  // stream.
  size_t m_continueBufPos = 0;
  // Index into m_continueFrameOffs where Continue() resumes scanning.
  size_t m_continueFramePos = 0;
};
| 61 | |
/**
 * Sends as much of the remaining continuation data as possible.
 *
 * First attempts a TryWrite() of every buffer from m_continueBufPos onward.
 * If that sends everything, the request is done. Otherwise the unsent
 * remainder of the partially written frame (or, if the write stopped exactly
 * on a frame boundary, the whole next frame) is handed to stream.Write(), and
 * the positions are advanced so a later call resumes after it.
 *
 * @param stream stream providing TryWrite() and Write()
 * @param req request object passed through to stream.Write()
 * @return 0 if there is nothing more to write, 1 if a Write() was started,
 *         or the negative TryWrite() result on error
 */
template <typename Stream, typename Req>
int WebSocketWriteReqBase::Continue(Stream& stream, std::shared_ptr<Req> req) {
  if (m_continueBufPos >= m_frames.m_bufs.size()) {
    return 0;  // nothing more to send
  }

  // try writing everything remaining
  std::span bufs = std::span{m_frames.m_bufs}.subspan(m_continueBufPos);
  int numBytes = 0;
  for (auto&& buf : bufs) {
    numBytes += buf.len;
  }

  int sentBytes = stream.TryWrite(bufs);
  WS_DEBUG("TryWrite({}) -> {} (expected {})\n", bufs.size(), sentBytes,
           numBytes);
  if (sentBytes < 0) {
    return sentBytes;  // error
  }

  if (sentBytes == numBytes) {
    // everything went out; mark all buffers as consumed
    m_continueBufPos = m_frames.m_bufs.size();
    return 0;  // nothing more to send
  }

  // we didn't send everything; deal with the leftovers

  // figure out which frame was the last one (partially) sent: skip every
  // frame whose end offset lies before the number of bytes sent
  // NOTE(review): the offsets appear to be recorded relative to the start of
  // m_frames.m_bufs, while sentBytes is relative to m_continueBufPos; on a
  // second or later call this looks like it can select a later frame than
  // necessary -- TODO confirm this is intended.
  auto offIt = m_continueFrameOffs.begin() + m_continueFramePos;
  auto offEnd = m_continueFrameOffs.end();
  while (offIt != offEnd && *offIt < sentBytes) {
    ++offIt;
  }
  // NOTE(review): assert is used here, but this header does not include
  // <cassert> directly.
  assert(offIt != offEnd);

  // build a list of buffers to send as a normal write:
  SmallVector<uv::Buffer, 4> writeBufs;
  auto bufIt = bufs.begin();
  auto bufEnd = bufs.end();

  // start with the remaining portion of the last buffer actually sent;
  // after this loop, pos is the end offset of the buffer before bufIt
  int pos = 0;
  while (bufIt != bufEnd && pos < sentBytes) {
    pos += (bufIt++)->len;
  }
  if (bufIt != bufs.begin() && pos != sentBytes) {
    // the write stopped inside the previous buffer; queue its unsent tail
    writeBufs.emplace_back(
        wpi::take_back((bufIt - 1)->bytes(), pos - sentBytes));
  }

  // continue through the last buffer of the last partial frame
  while (bufIt != bufEnd && offIt != offEnd && pos < *offIt) {
    pos += bufIt->len;
    writeBufs.emplace_back(*bufIt++);
  }
  if (offIt != offEnd) {
    ++offIt;
  }

  // if writeBufs is still empty, write all of the next frame
  if (writeBufs.empty()) {
    while (bufIt != bufEnd && offIt != offEnd && pos < *offIt) {
      pos += bufIt->len;
      writeBufs.emplace_back(*bufIt++);
    }
    if (offIt != offEnd) {
      ++offIt;
    }
  }

  // remember where the next call should resume
  m_continueFramePos = offIt - m_continueFrameOffs.begin();
  m_continueBufPos += bufIt - bufs.begin();

  if (writeBufs.empty()) {
    WS_DEBUG("Write Done\n");
    return 0;
  }
  WS_DEBUG("Write({})\n", writeBufs.size());
  stream.Write(writeBufs, req);
  return 1;
}
| 143 | |
| 144 | template <typename MakeReq, typename Stream> |
| 145 | std::span<const WebSocket::Frame> TrySendFrames( |
| 146 | bool server, Stream& stream, std::span<const WebSocket::Frame> frames, |
| 147 | MakeReq&& makeReq, |
| 148 | std::function<void(std::span<uv::Buffer>, uv::Error)> callback) { |
| 149 | WS_DEBUG("TrySendFrames({})\n", frames.size()); |
| 150 | auto frameIt = frames.begin(); |
| 151 | auto frameEnd = frames.end(); |
| 152 | while (frameIt != frameEnd) { |
| 153 | auto frameStart = frameIt; |
| 154 | |
| 155 | // build buffers to send |
| 156 | SerializedFrames sendFrames; |
| 157 | SmallVector<int, 32> frameOffs; |
| 158 | int numBytes = 0; |
| 159 | while (frameIt != frameEnd) { |
| 160 | frameOffs.emplace_back(numBytes); |
| 161 | numBytes += sendFrames.AddFrame(*frameIt++, server); |
| 162 | if ((server && (numBytes >= 65536 || frameOffs.size() > 32)) || |
| 163 | (!server && numBytes >= 8192)) { |
| 164 | // don't waste too much memory or effort on header generation or masking |
| 165 | break; |
| 166 | } |
| 167 | } |
| 168 | |
| 169 | // try to send |
| 170 | int sentBytes = stream.TryWrite(sendFrames.m_bufs); |
| 171 | WS_DEBUG("TryWrite({}) -> {} (expected {})\n", sendFrames.m_bufs.size(), |
| 172 | sentBytes, numBytes); |
| 173 | |
| 174 | if (sentBytes == 0) { |
| 175 | // we haven't started a frame yet; clean up any bufs that have actually |
| 176 | // sent, and return unsent frames |
| 177 | SmallVector<uv::Buffer, 4> bufs; |
| 178 | for (auto it = frames.begin(); it != frameStart; ++it) { |
| 179 | bufs.append(it->data.begin(), it->data.end()); |
| 180 | } |
| 181 | callback(bufs, {}); |
| 182 | #ifdef __clang__ |
| 183 | // work around clang bug |
| 184 | return {frames.data() + (frameStart - frames.begin()), |
| 185 | frames.data() + (frameEnd - frames.begin())}; |
| 186 | #else |
| 187 | return {frameStart, frameEnd}; |
| 188 | #endif |
| 189 | } else if (sentBytes < 0) { |
| 190 | // error |
| 191 | SmallVector<uv::Buffer, 4> bufs; |
| 192 | for (auto&& frame : frames) { |
| 193 | bufs.append(frame.data.begin(), frame.data.end()); |
| 194 | } |
| 195 | callback(bufs, uv::Error{sentBytes}); |
| 196 | return frames; |
| 197 | } else if (sentBytes != numBytes) { |
| 198 | // we didn't send everything; deal with the leftovers |
| 199 | |
| 200 | // figure out what the last (partially) frame sent actually was |
| 201 | auto offIt = frameOffs.begin(); |
| 202 | auto offEnd = frameOffs.end(); |
| 203 | bool isFin = true; |
| 204 | while (offIt != offEnd && *offIt < sentBytes) { |
| 205 | ++offIt; |
| 206 | isFin = (frameStart->opcode & WebSocket::kFlagFin) != 0; |
| 207 | ++frameStart; |
| 208 | } |
| 209 | |
| 210 | if (offIt != offEnd && *offIt == sentBytes && isFin) { |
| 211 | // we finished at a normal FIN frame boundary; no need for a Write() |
| 212 | SmallVector<uv::Buffer, 4> bufs; |
| 213 | for (auto it = frames.begin(); it != frameStart; ++it) { |
| 214 | bufs.append(it->data.begin(), it->data.end()); |
| 215 | } |
| 216 | callback(bufs, {}); |
| 217 | #ifdef __clang__ |
| 218 | // work around clang bug |
| 219 | return {frames.data() + (frameStart - frames.begin()), |
| 220 | frames.data() + (frameEnd - frames.begin())}; |
| 221 | #else |
| 222 | return {frameStart, frameEnd}; |
| 223 | #endif |
| 224 | } |
| 225 | |
| 226 | // build a list of buffers to send as a normal write: |
| 227 | SmallVector<uv::Buffer, 4> writeBufs; |
| 228 | auto bufIt = sendFrames.m_bufs.begin(); |
| 229 | auto bufEnd = sendFrames.m_bufs.end(); |
| 230 | |
| 231 | // start with the remaining portion of the last buffer actually sent |
| 232 | int pos = 0; |
| 233 | while (bufIt != bufEnd && pos < sentBytes) { |
| 234 | pos += (bufIt++)->len; |
| 235 | } |
| 236 | if (bufIt != sendFrames.m_bufs.begin() && pos != sentBytes) { |
| 237 | writeBufs.emplace_back( |
| 238 | wpi::take_back((bufIt - 1)->bytes(), pos - sentBytes)); |
| 239 | } |
| 240 | |
| 241 | // continue through the last buffer of the last partial frame |
| 242 | while (bufIt != bufEnd && offIt != offEnd && pos < *offIt) { |
| 243 | pos += bufIt->len; |
| 244 | writeBufs.emplace_back(*bufIt++); |
| 245 | } |
| 246 | if (offIt != offEnd) { |
| 247 | ++offIt; |
| 248 | } |
| 249 | |
| 250 | // move allocated buffers into request |
| 251 | auto req = makeReq(std::move(callback)); |
| 252 | req->m_frames.m_allocBufs = std::move(sendFrames.m_allocBufs); |
| 253 | req->m_frames.m_allocBufPos = sendFrames.m_allocBufPos; |
| 254 | |
| 255 | // if partial frame was non-FIN, put any additional non-FIN frames into |
| 256 | // continuation (so the caller isn't responsible for doing this) |
| 257 | size_t continuePos = 0; |
| 258 | while (frameStart != frameEnd && !isFin) { |
| 259 | if (offIt != offEnd) { |
| 260 | // we already generated the wire buffers for this frame, use them |
| 261 | while (pos < *offIt && bufIt != bufEnd) { |
| 262 | pos += bufIt->len; |
| 263 | continuePos += bufIt->len; |
| 264 | req->m_frames.m_bufs.emplace_back(*bufIt++); |
| 265 | } |
| 266 | ++offIt; |
| 267 | } else { |
| 268 | // WS_DEBUG("generating frame for continuation {} {}\n", |
| 269 | // frameStart->opcode, frameStart->data.size()); |
| 270 | // need to generate and add this frame |
| 271 | continuePos += req->m_frames.AddFrame(*frameStart, server); |
| 272 | } |
| 273 | req->m_continueFrameOffs.emplace_back(continuePos); |
| 274 | isFin = (frameStart->opcode & WebSocket::kFlagFin) != 0; |
| 275 | ++frameStart; |
| 276 | } |
| 277 | |
| 278 | // only the non-returned user buffers are added to the request |
| 279 | for (auto it = frames.begin(); it != frameStart; ++it) { |
| 280 | req->m_userBufs.append(it->data.begin(), it->data.end()); |
| 281 | } |
| 282 | |
| 283 | WS_DEBUG("Write({})\n", writeBufs.size()); |
| 284 | stream.Write(writeBufs, req); |
| 285 | #ifdef __clang__ |
| 286 | // work around clang bug |
| 287 | return {frames.data() + (frameStart - frames.begin()), |
| 288 | frames.data() + (frameEnd - frames.begin())}; |
| 289 | #else |
| 290 | return {frameStart, frameEnd}; |
| 291 | #endif |
| 292 | } |
| 293 | } |
| 294 | |
| 295 | // nothing left to send |
| 296 | SmallVector<uv::Buffer, 4> bufs; |
| 297 | for (auto&& frame : frames) { |
| 298 | bufs.append(frame.data.begin(), frame.data.end()); |
| 299 | } |
| 300 | callback(bufs, {}); |
| 301 | return {}; |
| 302 | } |
| 303 | |
| 304 | } // namespace wpi::detail |