blob: 79fc7b5ef71edd85e6fd37ccaf36b5067b3bd62b [file] [log] [blame]
James Kuszmaulcf324122023-01-14 14:07:17 -08001// Copyright (c) FIRST and other WPILib contributors.
2// Open Source Software; you can modify and/or share it under the terms of
3// the WPILib BSD license file in the root directory of this project.
4
5#ifndef UNICODE
6#define UNICODE
7#endif
8
9#include "wpinet/MulticastServiceResolver.h"
10
#include <algorithm>
#include <memory>
#include <span>
#include <string>
#include <string_view>
#include <utility>

#include <wpi/ConvertUTF.h>
#include <wpi/SmallString.h>
#include <wpi/SmallVector.h>
#include <wpi/StringExtras.h>
17
18#include "DynamicDns.h"
19
20#pragma comment(lib, "dnsapi")
21
22using namespace wpi;
23
24struct MulticastServiceResolver::Impl {
25 wpi::DynamicDns& dynamicDns = wpi::DynamicDns::GetDynamicDns();
26 std::wstring serviceType;
27 DNS_SERVICE_CANCEL serviceCancel{nullptr};
28
29 MulticastServiceResolver* resolver;
30
31 void onFound(ServiceData&& data) {
32 resolver->PushData(std::forward<ServiceData>(data));
33 }
34};
35
36MulticastServiceResolver::MulticastServiceResolver(
37 std::string_view serviceType) {
38 pImpl = std::make_unique<Impl>();
39 pImpl->resolver = this;
40
41 if (!pImpl->dynamicDns.CanDnsResolve) {
42 return;
43 }
44
45 wpi::SmallVector<wchar_t, 128> wideStorage;
46
47 if (wpi::ends_with_lower(serviceType, ".local")) {
48 wpi::sys::windows::UTF8ToUTF16(serviceType, wideStorage);
49 } else {
50 wpi::SmallString<128> storage;
51 storage.append(serviceType);
52 storage.append(".local");
53 wpi::sys::windows::UTF8ToUTF16(storage.str(), wideStorage);
54 }
55 pImpl->serviceType = std::wstring{wideStorage.data(), wideStorage.size()};
56}
57
58MulticastServiceResolver::~MulticastServiceResolver() noexcept {
59 Stop();
60}
61
62bool MulticastServiceResolver::HasImplementation() const {
63 return pImpl->dynamicDns.CanDnsResolve;
64}
65
66static _Function_class_(DNS_QUERY_COMPLETION_ROUTINE) VOID WINAPI
67 DnsCompletion(_In_ PVOID pQueryContext,
68 _Inout_ PDNS_QUERY_RESULT pQueryResults) {
69 MulticastServiceResolver::Impl* impl =
70 reinterpret_cast<MulticastServiceResolver::Impl*>(pQueryContext);
71
72 wpi::SmallVector<DNS_RECORDW*, 4> PtrRecords;
73 wpi::SmallVector<DNS_RECORDW*, 4> SrvRecords;
74 wpi::SmallVector<DNS_RECORDW*, 4> TxtRecords;
75 wpi::SmallVector<DNS_RECORDW*, 4> ARecords;
76
77 {
78 DNS_RECORDW* current = pQueryResults->pQueryRecords;
79 while (current != nullptr) {
80 switch (current->wType) {
81 case DNS_TYPE_PTR:
82 PtrRecords.push_back(current);
83 break;
84 case DNS_TYPE_SRV:
85 SrvRecords.push_back(current);
86 break;
87 case DNS_TYPE_TEXT:
88 TxtRecords.push_back(current);
89 break;
90 case DNS_TYPE_A:
91 ARecords.push_back(current);
92 break;
93 }
94 current = current->pNext;
95 }
96 }
97
98 for (DNS_RECORDW* Ptr : PtrRecords) {
99 if (std::wstring_view{Ptr->pName} != impl->serviceType) {
100 continue;
101 }
102
103 std::wstring_view nameHost = Ptr->Data.Ptr.pNameHost;
104 DNS_RECORDW* foundSrv = nullptr;
105 for (DNS_RECORDW* Srv : SrvRecords) {
106 if (std::wstring_view{Srv->pName} == nameHost) {
107 foundSrv = Srv;
108 break;
109 }
110 }
111
112 if (!foundSrv) {
113 continue;
114 }
115
116 for (DNS_RECORDW* A : ARecords) {
117 if (std::wstring_view{A->pName} ==
118 std::wstring_view{foundSrv->Data.Srv.pNameTarget}) {
119 MulticastServiceResolver::ServiceData data;
120 wpi::SmallString<128> storage;
121 for (DNS_RECORDW* Txt : TxtRecords) {
122 if (std::wstring_view{Txt->pName} == nameHost) {
123 for (DWORD i = 0; i < Txt->Data.Txt.dwStringCount; i++) {
124 std::wstring_view wideView = Txt->Data.TXT.pStringArray[i];
125 size_t splitIndex = wideView.find(L'=');
126 if (splitIndex == wideView.npos) {
127 // Todo make this just do key
128 continue;
129 }
130 storage.clear();
131 std::span<const wpi::UTF16> wideStr{
132 reinterpret_cast<const wpi::UTF16*>(wideView.data()),
133 splitIndex};
134 wpi::convertUTF16ToUTF8String(wideStr, storage);
135 auto& pair =
136 data.txt.emplace_back(std::pair<std::string, std::string>{
137 std::string{storage}, {}});
138 storage.clear();
139 wideStr = std::span<const wpi::UTF16>{
140 reinterpret_cast<const wpi::UTF16*>(wideView.data() +
141 splitIndex + 1),
142 wideView.size() - splitIndex - 1};
143 wpi::convertUTF16ToUTF8String(wideStr, storage);
144 pair.second = std::string{storage};
145 }
146 }
147 }
148
149 storage.clear();
150 std::span<const wpi::UTF16> wideHostName{
151 reinterpret_cast<const wpi::UTF16*>(A->pName), wcslen(A->pName)};
152 wpi::convertUTF16ToUTF8String(wideHostName, storage);
153 storage.append(".");
154
155 data.hostName = std::string{storage};
156 storage.clear();
157
158 int len = nameHost.find(impl->serviceType.c_str());
159 std::span<const wpi::UTF16> wideServiceName{
160 reinterpret_cast<const wpi::UTF16*>(nameHost.data()),
161 nameHost.size()};
162 if (len != nameHost.npos) {
163 wideServiceName = wideServiceName.subspan(0, len - 1);
164 }
165 wpi::convertUTF16ToUTF8String(wideServiceName, storage);
166
167 data.serviceName = std::string{storage};
168 data.port = foundSrv->Data.Srv.wPort;
169 data.ipv4Address = A->Data.A.IpAddress;
170
171 impl->onFound(std::move(data));
172 }
173 }
174 }
175 DnsFree(pQueryResults->pQueryRecords, DNS_FREE_TYPE::DnsFreeRecordList);
176}
177
178void MulticastServiceResolver::Start() {
179 if (pImpl->serviceCancel.reserved != nullptr) {
180 return;
181 }
182
183 if (!pImpl->dynamicDns.CanDnsResolve) {
184 return;
185 }
186
187 DNS_SERVICE_BROWSE_REQUEST request = {};
188 request.InterfaceIndex = 0;
189 request.pQueryContext = pImpl.get();
190 request.QueryName = pImpl->serviceType.c_str();
191 request.Version = 2;
192 request.pBrowseCallbackV2 = DnsCompletion;
193 pImpl->dynamicDns.DnsServiceBrowsePtr(&request, &pImpl->serviceCancel);
194}
195
196void MulticastServiceResolver::Stop() {
197 if (!pImpl->dynamicDns.CanDnsResolve) {
198 return;
199 }
200
201 if (pImpl->serviceCancel.reserved == nullptr) {
202 return;
203 }
204
205 pImpl->dynamicDns.DnsServiceBrowseCancelPtr(&pImpl->serviceCancel);
206 pImpl->serviceCancel.reserved = nullptr;
207}