blob: 264b8f594dbbd4940f36ec2e044c3b10b4ec0a7c [file] [log] [blame]
James Kuszmaulb13e13f2023-11-22 20:44:04 -08001// Copyright (c) FIRST and other WPILib contributors.
2// Open Source Software; you can modify and/or share it under the terms of
3// the WPILib BSD license file in the root directory of this project.
4
5#pragma once
6
7#include <functional>
8#include <memory>
9#include <utility>
10
11#include <wpi/SmallVector.h>
12#include <wpi/SpanExtras.h>
13
14#include "WebSocketDebug.h"
15#include "wpinet/WebSocket.h"
16#include "wpinet/uv/Buffer.h"
17
18namespace wpi::detail {
19
20class SerializedFrames {
21 public:
22 SerializedFrames() = default;
23 SerializedFrames(const SerializedFrames&) = delete;
24 SerializedFrames& operator=(const SerializedFrames&) = delete;
25 ~SerializedFrames() { ReleaseBufs(); }
26
27 size_t AddFrame(const WebSocket::Frame& frame, bool server) {
28 if (server) {
29 return AddServerFrame(frame);
30 } else {
31 return AddClientFrame(frame);
32 }
33 }
34
35 size_t AddClientFrame(const WebSocket::Frame& frame);
36 size_t AddServerFrame(const WebSocket::Frame& frame);
37
38 void ReleaseBufs() {
39 for (auto&& buf : m_allocBufs) {
40 buf.Deallocate();
41 }
42 m_allocBufs.clear();
43 }
44
45 SmallVector<uv::Buffer, 4> m_allocBufs;
46 SmallVector<uv::Buffer, 4> m_bufs;
47 size_t m_allocBufPos = 0;
48};
49
50class WebSocketWriteReqBase {
51 public:
52 template <typename Stream, typename Req>
53 int Continue(Stream& stream, std::shared_ptr<Req> req);
54
55 SmallVector<uv::Buffer, 4> m_userBufs;
56 SerializedFrames m_frames;
57 SmallVector<int, 0> m_continueFrameOffs;
58 size_t m_continueBufPos = 0;
59 size_t m_continueFramePos = 0;
60};
61
62template <typename Stream, typename Req>
63int WebSocketWriteReqBase::Continue(Stream& stream, std::shared_ptr<Req> req) {
64 if (m_continueBufPos >= m_frames.m_bufs.size()) {
65 return 0; // nothing more to send
66 }
67
68 // try writing everything remaining
69 std::span bufs = std::span{m_frames.m_bufs}.subspan(m_continueBufPos);
70 int numBytes = 0;
71 for (auto&& buf : bufs) {
72 numBytes += buf.len;
73 }
74
75 int sentBytes = stream.TryWrite(bufs);
76 WS_DEBUG("TryWrite({}) -> {} (expected {})\n", bufs.size(), sentBytes,
77 numBytes);
78 if (sentBytes < 0) {
79 return sentBytes; // error
80 }
81
82 if (sentBytes == numBytes) {
83 m_continueBufPos = m_frames.m_bufs.size();
84 return 0; // nothing more to send
85 }
86
87 // we didn't send everything; deal with the leftovers
88
89 // figure out what the last (partially) frame sent actually was
90 auto offIt = m_continueFrameOffs.begin() + m_continueFramePos;
91 auto offEnd = m_continueFrameOffs.end();
92 while (offIt != offEnd && *offIt < sentBytes) {
93 ++offIt;
94 }
95 assert(offIt != offEnd);
96
97 // build a list of buffers to send as a normal write:
98 SmallVector<uv::Buffer, 4> writeBufs;
99 auto bufIt = bufs.begin();
100 auto bufEnd = bufs.end();
101
102 // start with the remaining portion of the last buffer actually sent
103 int pos = 0;
104 while (bufIt != bufEnd && pos < sentBytes) {
105 pos += (bufIt++)->len;
106 }
107 if (bufIt != bufs.begin() && pos != sentBytes) {
108 writeBufs.emplace_back(
109 wpi::take_back((bufIt - 1)->bytes(), pos - sentBytes));
110 }
111
112 // continue through the last buffer of the last partial frame
113 while (bufIt != bufEnd && offIt != offEnd && pos < *offIt) {
114 pos += bufIt->len;
115 writeBufs.emplace_back(*bufIt++);
116 }
117 if (offIt != offEnd) {
118 ++offIt;
119 }
120
121 // if writeBufs is still empty, write all of the next frame
122 if (writeBufs.empty()) {
123 while (bufIt != bufEnd && offIt != offEnd && pos < *offIt) {
124 pos += bufIt->len;
125 writeBufs.emplace_back(*bufIt++);
126 }
127 if (offIt != offEnd) {
128 ++offIt;
129 }
130 }
131
132 m_continueFramePos = offIt - m_continueFrameOffs.begin();
133 m_continueBufPos += bufIt - bufs.begin();
134
135 if (writeBufs.empty()) {
136 WS_DEBUG("Write Done\n");
137 return 0;
138 }
139 WS_DEBUG("Write({})\n", writeBufs.size());
140 stream.Write(writeBufs, req);
141 return 1;
142}
143
144template <typename MakeReq, typename Stream>
145std::span<const WebSocket::Frame> TrySendFrames(
146 bool server, Stream& stream, std::span<const WebSocket::Frame> frames,
147 MakeReq&& makeReq,
148 std::function<void(std::span<uv::Buffer>, uv::Error)> callback) {
149 WS_DEBUG("TrySendFrames({})\n", frames.size());
150 auto frameIt = frames.begin();
151 auto frameEnd = frames.end();
152 while (frameIt != frameEnd) {
153 auto frameStart = frameIt;
154
155 // build buffers to send
156 SerializedFrames sendFrames;
157 SmallVector<int, 32> frameOffs;
158 int numBytes = 0;
159 while (frameIt != frameEnd) {
160 frameOffs.emplace_back(numBytes);
161 numBytes += sendFrames.AddFrame(*frameIt++, server);
162 if ((server && (numBytes >= 65536 || frameOffs.size() > 32)) ||
163 (!server && numBytes >= 8192)) {
164 // don't waste too much memory or effort on header generation or masking
165 break;
166 }
167 }
168
169 // try to send
170 int sentBytes = stream.TryWrite(sendFrames.m_bufs);
171 WS_DEBUG("TryWrite({}) -> {} (expected {})\n", sendFrames.m_bufs.size(),
172 sentBytes, numBytes);
173
174 if (sentBytes == 0) {
175 // we haven't started a frame yet; clean up any bufs that have actually
176 // sent, and return unsent frames
177 SmallVector<uv::Buffer, 4> bufs;
178 for (auto it = frames.begin(); it != frameStart; ++it) {
179 bufs.append(it->data.begin(), it->data.end());
180 }
181 callback(bufs, {});
182#ifdef __clang__
183 // work around clang bug
184 return {frames.data() + (frameStart - frames.begin()),
185 frames.data() + (frameEnd - frames.begin())};
186#else
187 return {frameStart, frameEnd};
188#endif
189 } else if (sentBytes < 0) {
190 // error
191 SmallVector<uv::Buffer, 4> bufs;
192 for (auto&& frame : frames) {
193 bufs.append(frame.data.begin(), frame.data.end());
194 }
195 callback(bufs, uv::Error{sentBytes});
196 return frames;
197 } else if (sentBytes != numBytes) {
198 // we didn't send everything; deal with the leftovers
199
200 // figure out what the last (partially) frame sent actually was
201 auto offIt = frameOffs.begin();
202 auto offEnd = frameOffs.end();
203 bool isFin = true;
204 while (offIt != offEnd && *offIt < sentBytes) {
205 ++offIt;
206 isFin = (frameStart->opcode & WebSocket::kFlagFin) != 0;
207 ++frameStart;
208 }
209
210 if (offIt != offEnd && *offIt == sentBytes && isFin) {
211 // we finished at a normal FIN frame boundary; no need for a Write()
212 SmallVector<uv::Buffer, 4> bufs;
213 for (auto it = frames.begin(); it != frameStart; ++it) {
214 bufs.append(it->data.begin(), it->data.end());
215 }
216 callback(bufs, {});
217#ifdef __clang__
218 // work around clang bug
219 return {frames.data() + (frameStart - frames.begin()),
220 frames.data() + (frameEnd - frames.begin())};
221#else
222 return {frameStart, frameEnd};
223#endif
224 }
225
226 // build a list of buffers to send as a normal write:
227 SmallVector<uv::Buffer, 4> writeBufs;
228 auto bufIt = sendFrames.m_bufs.begin();
229 auto bufEnd = sendFrames.m_bufs.end();
230
231 // start with the remaining portion of the last buffer actually sent
232 int pos = 0;
233 while (bufIt != bufEnd && pos < sentBytes) {
234 pos += (bufIt++)->len;
235 }
236 if (bufIt != sendFrames.m_bufs.begin() && pos != sentBytes) {
237 writeBufs.emplace_back(
238 wpi::take_back((bufIt - 1)->bytes(), pos - sentBytes));
239 }
240
241 // continue through the last buffer of the last partial frame
242 while (bufIt != bufEnd && offIt != offEnd && pos < *offIt) {
243 pos += bufIt->len;
244 writeBufs.emplace_back(*bufIt++);
245 }
246 if (offIt != offEnd) {
247 ++offIt;
248 }
249
250 // move allocated buffers into request
251 auto req = makeReq(std::move(callback));
252 req->m_frames.m_allocBufs = std::move(sendFrames.m_allocBufs);
253 req->m_frames.m_allocBufPos = sendFrames.m_allocBufPos;
254
255 // if partial frame was non-FIN, put any additional non-FIN frames into
256 // continuation (so the caller isn't responsible for doing this)
257 size_t continuePos = 0;
258 while (frameStart != frameEnd && !isFin) {
259 if (offIt != offEnd) {
260 // we already generated the wire buffers for this frame, use them
261 while (pos < *offIt && bufIt != bufEnd) {
262 pos += bufIt->len;
263 continuePos += bufIt->len;
264 req->m_frames.m_bufs.emplace_back(*bufIt++);
265 }
266 ++offIt;
267 } else {
268 // WS_DEBUG("generating frame for continuation {} {}\n",
269 // frameStart->opcode, frameStart->data.size());
270 // need to generate and add this frame
271 continuePos += req->m_frames.AddFrame(*frameStart, server);
272 }
273 req->m_continueFrameOffs.emplace_back(continuePos);
274 isFin = (frameStart->opcode & WebSocket::kFlagFin) != 0;
275 ++frameStart;
276 }
277
278 // only the non-returned user buffers are added to the request
279 for (auto it = frames.begin(); it != frameStart; ++it) {
280 req->m_userBufs.append(it->data.begin(), it->data.end());
281 }
282
283 WS_DEBUG("Write({})\n", writeBufs.size());
284 stream.Write(writeBufs, req);
285#ifdef __clang__
286 // work around clang bug
287 return {frames.data() + (frameStart - frames.begin()),
288 frames.data() + (frameEnd - frames.begin())};
289#else
290 return {frameStart, frameEnd};
291#endif
292 }
293 }
294
295 // nothing left to send
296 SmallVector<uv::Buffer, 4> bufs;
297 for (auto&& frame : frames) {
298 bufs.append(frame.data.begin(), frame.data.end());
299 }
300 callback(bufs, {});
301 return {};
302}
303
304} // namespace wpi::detail