blob: 115b43fd6e039e4e70cab4b8be3a9f39143ed6a2 [file] [log] [blame]
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
4
5#ifndef UNICODE
6#define UNICODE
7#endif
8
9#include "wpi/MulticastServiceResolver.h"
10
11#include <string>
12
13#include "DynamicDns.h"
14#include "wpi/ConvertUTF.h"
15#include "wpi/SmallString.h"
16#include "wpi/SmallVector.h"
17#include "wpi/StringExtras.h"
18
19#pragma comment(lib, "dnsapi")
20
21using namespace wpi;
22
23struct MulticastServiceResolver::Impl {
24 wpi::DynamicDns& dynamicDns = wpi::DynamicDns::GetDynamicDns();
25 std::wstring serviceType;
26 DNS_SERVICE_CANCEL serviceCancel{nullptr};
27
28 MulticastServiceResolver* resolver;
29
30 void onFound(ServiceData&& data) {
31 resolver->PushData(std::forward<ServiceData>(data));
32 }
33};
34
35MulticastServiceResolver::MulticastServiceResolver(
36 std::string_view serviceType) {
37 pImpl = std::make_unique<Impl>();
38 pImpl->resolver = this;
39
40 if (!pImpl->dynamicDns.CanDnsResolve) {
41 return;
42 }
43
44 wpi::SmallVector<wchar_t, 128> wideStorage;
45
46 if (wpi::ends_with_lower(serviceType, ".local")) {
47 wpi::sys::windows::UTF8ToUTF16(serviceType, wideStorage);
48 } else {
49 wpi::SmallString<128> storage;
50 storage.append(serviceType);
51 storage.append(".local");
52 wpi::sys::windows::UTF8ToUTF16(storage.str(), wideStorage);
53 }
54 pImpl->serviceType = std::wstring{wideStorage.data(), wideStorage.size()};
55}
56
57MulticastServiceResolver::~MulticastServiceResolver() noexcept {
58 Stop();
59}
60
61bool MulticastServiceResolver::HasImplementation() const {
62 return pImpl->dynamicDns.CanDnsResolve;
63}
64
65static _Function_class_(DNS_QUERY_COMPLETION_ROUTINE) VOID WINAPI
66 DnsCompletion(_In_ PVOID pQueryContext,
67 _Inout_ PDNS_QUERY_RESULT pQueryResults) {
68 MulticastServiceResolver::Impl* impl =
69 reinterpret_cast<MulticastServiceResolver::Impl*>(pQueryContext);
70
71 wpi::SmallVector<DNS_RECORDW*, 4> PtrRecords;
72 wpi::SmallVector<DNS_RECORDW*, 4> SrvRecords;
73 wpi::SmallVector<DNS_RECORDW*, 4> TxtRecords;
74 wpi::SmallVector<DNS_RECORDW*, 4> ARecords;
75
76 {
77 DNS_RECORDW* current = pQueryResults->pQueryRecords;
78 while (current != nullptr) {
79 switch (current->wType) {
80 case DNS_TYPE_PTR:
81 PtrRecords.push_back(current);
82 break;
83 case DNS_TYPE_SRV:
84 SrvRecords.push_back(current);
85 break;
86 case DNS_TYPE_TEXT:
87 TxtRecords.push_back(current);
88 break;
89 case DNS_TYPE_A:
90 ARecords.push_back(current);
91 break;
92 }
93 current = current->pNext;
94 }
95 }
96
97 for (DNS_RECORDW* Ptr : PtrRecords) {
98 if (std::wstring_view{Ptr->pName} != impl->serviceType) {
99 continue;
100 }
101
102 std::wstring_view nameHost = Ptr->Data.Ptr.pNameHost;
103 DNS_RECORDW* foundSrv = nullptr;
104 for (DNS_RECORDW* Srv : SrvRecords) {
105 if (std::wstring_view{Srv->pName} == nameHost) {
106 foundSrv = Srv;
107 break;
108 }
109 }
110
111 if (!foundSrv) {
112 continue;
113 }
114
115 for (DNS_RECORDW* A : ARecords) {
116 if (std::wstring_view{A->pName} ==
117 std::wstring_view{foundSrv->Data.Srv.pNameTarget}) {
118 MulticastServiceResolver::ServiceData data;
119 wpi::SmallString<128> storage;
120 for (DNS_RECORDW* Txt : TxtRecords) {
121 if (std::wstring_view{Txt->pName} == nameHost) {
122 for (DWORD i = 0; i < Txt->Data.Txt.dwStringCount; i++) {
123 std::wstring_view wideView = Txt->Data.TXT.pStringArray[i];
124 size_t splitIndex = wideView.find(L'=');
125 if (splitIndex == wideView.npos) {
126 // Todo make this just do key
127 continue;
128 }
129 storage.clear();
130 wpi::span<const wpi::UTF16> wideStr{
131 reinterpret_cast<const wpi::UTF16*>(wideView.data()),
132 splitIndex};
133 wpi::convertUTF16ToUTF8String(wideStr, storage);
134 auto& pair = data.txt.emplace_back(
135 std::pair<std::string, std::string>{storage.string(), {}});
136 storage.clear();
137 wideStr = wpi::span<const wpi::UTF16>{
138 reinterpret_cast<const wpi::UTF16*>(wideView.data() +
139 splitIndex + 1),
140 wideView.size() - splitIndex - 1};
141 wpi::convertUTF16ToUTF8String(wideStr, storage);
142 pair.second = storage.string();
143 }
144 }
145 }
146
147 storage.clear();
148 wpi::span<const wpi::UTF16> wideHostName{
149 reinterpret_cast<const wpi::UTF16*>(A->pName), wcslen(A->pName)};
150 wpi::convertUTF16ToUTF8String(wideHostName, storage);
151 storage.append(".");
152
153 data.hostName = storage.string();
154 storage.clear();
155
156 int len = nameHost.find(impl->serviceType.c_str());
157 wpi::span<const wpi::UTF16> wideServiceName{
158 reinterpret_cast<const wpi::UTF16*>(nameHost.data()),
159 nameHost.size()};
160 if (len != nameHost.npos) {
161 wideServiceName = wideServiceName.subspan(0, len - 1);
162 }
163 wpi::convertUTF16ToUTF8String(wideServiceName, storage);
164
165 data.serviceName = storage.string();
166 data.port = foundSrv->Data.Srv.wPort;
167 data.ipv4Address = A->Data.A.IpAddress;
168
169 impl->onFound(std::move(data));
170 }
171 }
172 }
173 DnsFree(pQueryResults->pQueryRecords, DNS_FREE_TYPE::DnsFreeRecordList);
174}
175
176void MulticastServiceResolver::Start() {
177 if (pImpl->serviceCancel.reserved != nullptr) {
178 return;
179 }
180
181 if (!pImpl->dynamicDns.CanDnsResolve) {
182 return;
183 }
184
185 DNS_SERVICE_BROWSE_REQUEST request = {};
186 request.InterfaceIndex = 0;
187 request.pQueryContext = pImpl.get();
188 request.QueryName = pImpl->serviceType.c_str();
189 request.Version = 2;
190 request.pBrowseCallbackV2 = DnsCompletion;
191 pImpl->dynamicDns.DnsServiceBrowsePtr(&request, &pImpl->serviceCancel);
192}
193
194void MulticastServiceResolver::Stop() {
195 if (!pImpl->dynamicDns.CanDnsResolve) {
196 return;
197 }
198
199 if (pImpl->serviceCancel.reserved == nullptr) {
200 return;
201 }
202
203 pImpl->dynamicDns.DnsServiceBrowseCancelPtr(&pImpl->serviceCancel);
204 pImpl->serviceCancel.reserved = nullptr;
205}