blob: 3d19738b943349564c62a6b17561326dd25c7c08 [file] [log] [blame]
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
4
5#include "Downloader.h"
6
7#include <libssh/sftp.h>
8
9#ifdef _WIN32
10#include <fcntl.h>
11#include <io.h>
12#else
13#include <sys/fcntl.h>
14#endif
15
16#include <algorithm>
17#include <filesystem>
18
19#include <fmt/format.h>
20#include <glass/Storage.h>
21#include <imgui.h>
22#include <imgui_stdlib.h>
23#include <portable-file-dialogs.h>
24#include <wpi/StringExtras.h>
25#include <wpi/fs.h>
26
27#include "Sftp.h"
28
29Downloader::Downloader(glass::Storage& storage)
30 : m_serverTeam{storage.GetString("serverTeam")},
31 m_remoteDir{storage.GetString("remoteDir", "/home/lvuser")},
32 m_username{storage.GetString("username", "lvuser")},
33 m_localDir{storage.GetString("localDir")},
34 m_deleteAfter{storage.GetBool("deleteAfter", true)},
35 m_thread{[this] { ThreadMain(); }} {}
36
37Downloader::~Downloader() {
38 {
39 std::scoped_lock lock{m_mutex};
40 m_state = kExit;
41 }
42 m_cv.notify_all();
43 m_thread.join();
44}
45
46void Downloader::DisplayConnect() {
47 // IP or Team Number text box
48 ImGui::SetNextItemWidth(ImGui::GetFontSize() * 12);
49 ImGui::InputText("Team Number / Address", &m_serverTeam);
50
51 // Username/password
52 ImGui::SetNextItemWidth(ImGui::GetFontSize() * 12);
53 ImGui::InputText("Username", &m_username);
54 ImGui::SetNextItemWidth(ImGui::GetFontSize() * 12);
55 ImGui::InputText("Password", &m_password, ImGuiInputTextFlags_Password);
56
57 // Connect button
58 if (ImGui::Button("Connect")) {
59 m_state = kConnecting;
60 m_cv.notify_all();
61 }
62}
63
64void Downloader::DisplayDisconnectButton() {
65 if (ImGui::Button("Disconnect")) {
66 m_state = kDisconnecting;
67 m_cv.notify_all();
68 }
69}
70
71void Downloader::DisplayRemoteDirSelector() {
72 ImGui::SameLine();
73 if (ImGui::Button("Refresh")) {
74 m_state = kGetFiles;
75 m_cv.notify_all();
76 }
77
78 ImGui::SameLine();
79 if (ImGui::Button("Deselect All")) {
80 for (auto&& download : m_downloadList) {
81 download.enabled = false;
82 }
83 }
84
85 ImGui::SameLine();
86 if (ImGui::Button("Select All")) {
87 for (auto&& download : m_downloadList) {
88 download.enabled = true;
89 }
90 }
91
92 // Remote directory text box
93 ImGui::SetNextItemWidth(ImGui::GetFontSize() * 20);
94 if (ImGui::InputText("Remote Dir", &m_remoteDir,
95 ImGuiInputTextFlags_EnterReturnsTrue)) {
96 m_state = kGetFiles;
97 m_cv.notify_all();
98 }
99
100 // List directories
101 for (auto&& dir : m_dirList) {
102 if (ImGui::Selectable(dir.c_str())) {
103 if (dir == "..") {
104 if (wpi::ends_with(m_remoteDir, '/')) {
105 m_remoteDir.resize(m_remoteDir.size() - 1);
106 }
107 m_remoteDir = wpi::rsplit(m_remoteDir, '/').first;
108 if (m_remoteDir.empty()) {
109 m_remoteDir = "/";
110 }
111 } else {
112 if (!wpi::ends_with(m_remoteDir, '/')) {
113 m_remoteDir += '/';
114 }
115 m_remoteDir += dir;
116 }
117 m_state = kGetFiles;
118 m_cv.notify_all();
119 }
120 }
121}
122
123void Downloader::DisplayLocalDirSelector() {
124 // Local directory text / select button
125 if (ImGui::Button("Select Download Folder...")) {
126 m_localDirSelector =
127 std::make_unique<pfd::select_folder>("Select Download Folder");
128 }
129 ImGui::TextUnformatted(m_localDir.c_str());
130
131 // Delete after download (checkbox)
132 ImGui::Checkbox("Delete after download", &m_deleteAfter);
133
134 // Download button
135 if (!m_localDir.empty()) {
136 if (ImGui::Button("Download")) {
137 m_state = kDownload;
138 m_cv.notify_all();
139 }
140 }
141}
142
143size_t Downloader::DisplayFiles() {
144 // List of files (multi-select) (changes to progress bar for downloading)
145 size_t fileCount = 0;
146 if (ImGui::BeginTable(
147 "files", 3,
148 ImGuiTableFlags_Borders | ImGuiTableFlags_SizingStretchProp)) {
149 ImGui::TableSetupColumn("File");
150 ImGui::TableSetupColumn("Size");
151 ImGui::TableSetupColumn("Download");
152 ImGui::TableHeadersRow();
153 for (auto&& download : m_downloadList) {
154 if ((m_state == kDownload || m_state == kDownloadDone) &&
155 !download.enabled) {
156 continue;
157 }
158
159 ++fileCount;
160
161 ImGui::TableNextRow();
162 ImGui::TableNextColumn();
163 ImGui::TextUnformatted(download.name.c_str());
164 ImGui::TableNextColumn();
165 auto sizeText = fmt::format("{}", download.size);
166 ImGui::TextUnformatted(sizeText.c_str());
167 ImGui::TableNextColumn();
168 if (m_state == kDownload || m_state == kDownloadDone) {
169 if (!download.status.empty()) {
170 ImGui::TextUnformatted(download.status.c_str());
171 } else {
172 ImGui::ProgressBar(download.complete);
173 }
174 } else {
175 auto checkboxLabel = fmt::format("##{}", download.name);
176 ImGui::Checkbox(checkboxLabel.c_str(), &download.enabled);
177 }
178 }
179 ImGui::EndTable();
180 }
181
182 return fileCount;
183}
184
185void Downloader::Display() {
186 if (m_localDirSelector && m_localDirSelector->ready(0)) {
187 m_localDir = m_localDirSelector->result();
188 m_localDirSelector.reset();
189 }
190
191 std::scoped_lock lock{m_mutex};
192
193 if (!m_error.empty()) {
194 ImGui::TextUnformatted(m_error.c_str());
195 }
196
197 switch (m_state) {
198 case kDisconnected:
199 DisplayConnect();
200 break;
201 case kConnecting:
202 DisplayDisconnectButton();
203 ImGui::Text("Connecting to %s...", m_serverTeam.c_str());
204 break;
205 case kDisconnecting:
206 ImGui::TextUnformatted("Disconnecting...");
207 break;
208 case kConnected:
209 case kGetFiles:
210 DisplayDisconnectButton();
211 DisplayRemoteDirSelector();
212 if (DisplayFiles() > 0) {
213 DisplayLocalDirSelector();
214 }
215 break;
216 case kDownload:
217 case kDownloadDone:
218 DisplayDisconnectButton();
219 DisplayFiles();
220 if (m_state == kDownloadDone) {
221 if (ImGui::Button("Download complete!")) {
222 m_state = kGetFiles;
223 m_cv.notify_all();
224 }
225 }
226 break;
227 default:
228 break;
229 }
230}
231
232void Downloader::ThreadMain() {
233 std::unique_ptr<sftp::Session> session;
234
235 static constexpr size_t kBufSize = 32 * 1024;
236 std::unique_ptr<uint8_t[]> copyBuf = std::make_unique<uint8_t[]>(kBufSize);
237
238 std::unique_lock lock{m_mutex};
239 while (m_state != kExit) {
240 State prev = m_state;
241 m_cv.wait(lock, [&] { return m_state != prev; });
242 m_error.clear();
243 try {
244 switch (m_state) {
245 case kConnecting:
246 if (auto team = wpi::parse_integer<unsigned int>(m_serverTeam, 10)) {
247 // team number
248 session = std::make_unique<sftp::Session>(
249 fmt::format("roborio-{}-frc.local", team.value()), 22,
250 m_username, m_password);
251 } else {
252 session = std::make_unique<sftp::Session>(m_serverTeam, 22,
253 m_username, m_password);
254 }
255 lock.unlock();
256 try {
257 session->Connect();
258 } catch (...) {
259 lock.lock();
260 throw;
261 }
262 lock.lock();
263 // FALLTHROUGH
264 case kGetFiles: {
265 std::string dir = m_remoteDir;
266 std::vector<sftp::Attributes> fileList;
267 lock.unlock();
268 try {
269 fileList = session->ReadDir(dir);
270 } catch (sftp::Exception& ex) {
271 lock.lock();
272 if (ex.err == SSH_FX_OK || ex.err == SSH_FX_CONNECTION_LOST) {
273 throw;
274 }
275 m_error = ex.what();
276 m_dirList.clear();
277 m_downloadList.clear();
278 m_state = kConnected;
279 break;
280 }
281 std::sort(
282 fileList.begin(), fileList.end(),
283 [](const auto& l, const auto& r) { return l.name < r.name; });
284 lock.lock();
285
286 m_dirList.clear();
287 m_downloadList.clear();
288 for (auto&& attr : fileList) {
289 if (attr.type == SSH_FILEXFER_TYPE_DIRECTORY) {
290 if (attr.name != ".") {
291 m_dirList.emplace_back(attr.name);
292 }
293 } else if (attr.type == SSH_FILEXFER_TYPE_REGULAR &&
294 (attr.flags & SSH_FILEXFER_ATTR_SIZE) != 0 &&
295 wpi::ends_with(attr.name, ".wpilog")) {
296 m_downloadList.emplace_back(attr.name, attr.size);
297 }
298 }
299
300 m_state = kConnected;
301 break;
302 }
303 case kDisconnecting:
304 session.reset();
305 m_state = kDisconnected;
306 break;
307 case kDownload: {
308 for (auto&& download : m_downloadList) {
309 if (m_state != kDownload) {
310 // user aborted
311 break;
312 }
313 if (!download.enabled) {
314 continue;
315 }
316
317 auto remoteFilename = fmt::format(
318 "{}{}{}", m_remoteDir,
319 wpi::ends_with(m_remoteDir, '/') ? "" : "/", download.name);
320 auto localFilename = fs::path{m_localDir} / download.name;
321 uint64_t fileSize = download.size;
322
323 lock.unlock();
324
325 // open local file
326 std::error_code ec;
327 fs::file_t of = fs::OpenFileForWrite(localFilename, ec,
328 fs::CD_CreateNew, fs::OF_None);
329 if (ec) {
330 // failed to open
331 lock.lock();
332 download.status = ec.message();
333 continue;
334 }
335 int ofd = fs::FileToFd(of, ec, fs::OF_None);
336 if (ofd == -1 || ec) {
337 // failed to convert to fd
338 lock.lock();
339 download.status = ec.message();
340 continue;
341 }
342
343 try {
344 // open remote file
345 sftp::File f = session->Open(remoteFilename, O_RDONLY, 0);
346
347 // copy in chunks
348 uint64_t total = 0;
349 while (total < fileSize) {
350 uint64_t toCopy = (std::min)(fileSize - total,
351 static_cast<uint64_t>(kBufSize));
352 auto copied = f.Read(copyBuf.get(), toCopy);
353 if (write(ofd, copyBuf.get(), copied) !=
354 static_cast<int64_t>(copied)) {
355 // error writing
356 close(ofd);
357 fs::remove(localFilename, ec);
358 lock.lock();
359 download.status = "error writing local file";
360 goto err;
361 }
362 total += copied;
363 lock.lock();
364 download.complete = static_cast<float>(total) / fileSize;
365 lock.unlock();
366 }
367
368 // close local file
369 close(ofd);
370 ofd = -1;
371
372 // delete remote file (if enabled)
373 if (m_deleteAfter) {
374 f = sftp::File{};
375 session->Unlink(remoteFilename);
376 }
377 } catch (sftp::Exception& ex) {
378 if (ofd != -1) {
379 // close local file and delete it (due to failure)
380 close(ofd);
381 fs::remove(localFilename, ec);
382 }
383 lock.lock();
384 download.status = ex.what();
385 if (ex.err == SSH_FX_OK || ex.err == SSH_FX_CONNECTION_LOST) {
386 throw;
387 }
388 continue;
389 }
390 lock.lock();
391 err : {}
392 }
393 if (m_state == kDownload) {
394 m_state = kDownloadDone;
395 }
396 break;
397 }
398 default:
399 break;
400 }
401 } catch (sftp::Exception& ex) {
402 m_error = ex.what();
403 session.reset();
404 m_state = kDisconnected;
405 }
406 }
407}