blob: 4af8cadf91454d0710e9b6b81bb2fa105b6de0b3 [file] [log] [blame]
Austin Schuh75263e32022-02-22 18:05:32 -08001// Copyright (c) FIRST and other WPILib contributors.
2// Open Source Software; you can modify and/or share it under the terms of
3// the WPILib BSD license file in the root directory of this project.
4
5#include "wpi/MulticastServiceResolver.h"
6
#include <netinet/in.h>
#include <poll.h>
#include <sys/socket.h>

#include <algorithm>
#include <atomic>
#include <memory>
#include <string>
#include <string_view>
#include <thread>
#include <utility>
#include <vector>

#include "ResolverThread.h"
#include "dns_sd.h"
#include "wpi/SmallVector.h"
17
18using namespace wpi;
19
20struct DnsResolveState {
21 DnsResolveState(MulticastServiceResolver::Impl* impl,
22 std::string_view serviceNameView)
23 : pImpl{impl} {
24 data.serviceName = serviceNameView;
25 }
26
27 DNSServiceRef ResolveRef = nullptr;
28 MulticastServiceResolver::Impl* pImpl;
29
30 MulticastServiceResolver::ServiceData data;
31};
32
33struct MulticastServiceResolver::Impl {
34 std::string serviceType;
35 MulticastServiceResolver* resolver;
36 std::shared_ptr<ResolverThread> thread = ResolverThread::Get();
37 std::vector<std::unique_ptr<DnsResolveState>> ResolveStates;
38 DNSServiceRef serviceRef = nullptr;
39
40 void onFound(ServiceData&& data) {
41 resolver->PushData(std::forward<ServiceData>(data));
42 }
43};
44
45MulticastServiceResolver::MulticastServiceResolver(
46 std::string_view serviceType) {
47 pImpl = std::make_unique<Impl>();
48 pImpl->serviceType = serviceType;
49 pImpl->resolver = this;
50}
51
52MulticastServiceResolver::~MulticastServiceResolver() noexcept {
53 Stop();
54}
55
56void ServiceGetAddrInfoReply(DNSServiceRef sdRef, DNSServiceFlags flags,
57 uint32_t interfaceIndex,
58 DNSServiceErrorType errorCode,
59 const char* hostname,
60 const struct sockaddr* address, uint32_t ttl,
61 void* context) {
62 if (errorCode != kDNSServiceErr_NoError) {
63 return;
64 }
65
66 DnsResolveState* resolveState = static_cast<DnsResolveState*>(context);
67
68 resolveState->data.hostName = hostname;
69 resolveState->data.ipv4Address =
70 reinterpret_cast<const struct sockaddr_in*>(address)->sin_addr.s_addr;
71
72 resolveState->pImpl->onFound(std::move(resolveState->data));
73
74 resolveState->pImpl->thread->RemoveServiceRefInThread(
75 resolveState->ResolveRef);
76
77 resolveState->pImpl->ResolveStates.erase(std::find_if(
78 resolveState->pImpl->ResolveStates.begin(),
79 resolveState->pImpl->ResolveStates.end(),
80 [resolveState](auto& a) { return a.get() == resolveState; }));
81}
82
83void ServiceResolveReply(DNSServiceRef sdRef, DNSServiceFlags flags,
84 uint32_t interfaceIndex, DNSServiceErrorType errorCode,
85 const char* fullname, const char* hosttarget,
86 uint16_t port, /* In network byte order */
87 uint16_t txtLen, const unsigned char* txtRecord,
88 void* context) {
89 if (errorCode != kDNSServiceErr_NoError) {
90 return;
91 }
92
93 DnsResolveState* resolveState = static_cast<DnsResolveState*>(context);
94 resolveState->pImpl->thread->RemoveServiceRefInThread(
95 resolveState->ResolveRef);
96 DNSServiceRefDeallocate(resolveState->ResolveRef);
97 resolveState->ResolveRef = nullptr;
98 resolveState->data.port = ntohs(port);
99
100 int txtCount = TXTRecordGetCount(txtLen, txtRecord);
101 char keyBuf[256];
102 uint8_t valueLen;
103 const void* value;
104
105 for (int i = 0; i < txtCount; i++) {
106 errorCode = TXTRecordGetItemAtIndex(txtLen, txtRecord, i, sizeof(keyBuf),
107 keyBuf, &valueLen, &value);
108 if (errorCode == kDNSServiceErr_NoError) {
109 if (valueLen == 0) {
110 // No value
111 resolveState->data.txt.emplace_back(
112 std::pair<std::string, std::string>{std::string{keyBuf}, {}});
113 } else {
114 resolveState->data.txt.emplace_back(std::pair<std::string, std::string>{
115 std::string{keyBuf},
116 std::string{reinterpret_cast<const char*>(value), valueLen}});
117 }
118 }
119 }
120
121 errorCode = DNSServiceGetAddrInfo(
122 &resolveState->ResolveRef, flags, interfaceIndex,
123 kDNSServiceProtocol_IPv4, hosttarget, ServiceGetAddrInfoReply, context);
124
125 if (errorCode == kDNSServiceErr_NoError) {
126 dnssd_sock_t socket = DNSServiceRefSockFD(resolveState->ResolveRef);
127 resolveState->pImpl->thread->AddServiceRef(resolveState->ResolveRef,
128 socket);
129 } else {
130 resolveState->pImpl->thread->RemoveServiceRefInThread(
131 resolveState->ResolveRef);
132 resolveState->pImpl->ResolveStates.erase(std::find_if(
133 resolveState->pImpl->ResolveStates.begin(),
134 resolveState->pImpl->ResolveStates.end(),
135 [resolveState](auto& a) { return a.get() == resolveState; }));
136 }
137}
138
139static void DnsCompletion(DNSServiceRef sdRef, DNSServiceFlags flags,
140 uint32_t interfaceIndex,
141 DNSServiceErrorType errorCode,
142 const char* serviceName, const char* regtype,
143 const char* replyDomain, void* context) {
144 if (errorCode != kDNSServiceErr_NoError) {
145 return;
146 }
147 if (!(flags & kDNSServiceFlagsAdd)) {
148 return;
149 }
150
151 MulticastServiceResolver::Impl* impl =
152 static_cast<MulticastServiceResolver::Impl*>(context);
153 auto& resolveState = impl->ResolveStates.emplace_back(
154 std::make_unique<DnsResolveState>(impl, serviceName));
155
156 errorCode = DNSServiceResolve(&resolveState->ResolveRef, 0, interfaceIndex,
157 serviceName, regtype, replyDomain,
158 ServiceResolveReply, resolveState.get());
159
160 if (errorCode == kDNSServiceErr_NoError) {
161 dnssd_sock_t socket = DNSServiceRefSockFD(resolveState->ResolveRef);
162 resolveState->pImpl->thread->AddServiceRef(resolveState->ResolveRef,
163 socket);
164 } else {
165 resolveState->pImpl->ResolveStates.erase(std::find_if(
166 resolveState->pImpl->ResolveStates.begin(),
167 resolveState->pImpl->ResolveStates.end(),
168 [r = resolveState.get()](auto& a) { return a.get() == r; }));
169 }
170}
171
172bool MulticastServiceResolver::HasImplementation() const {
173 return true;
174}
175
176void MulticastServiceResolver::Start() {
177 if (pImpl->serviceRef) {
178 return;
179 }
180
181 DNSServiceErrorType status =
182 DNSServiceBrowse(&pImpl->serviceRef, 0, 0, pImpl->serviceType.c_str(),
183 "local", DnsCompletion, pImpl.get());
184 if (status == kDNSServiceErr_NoError) {
185 dnssd_sock_t socket = DNSServiceRefSockFD(pImpl->serviceRef);
186 pImpl->thread->AddServiceRef(pImpl->serviceRef, socket);
187 }
188}
189
190void MulticastServiceResolver::Stop() {
191 if (!pImpl->serviceRef) {
192 return;
193 }
194 wpi::SmallVector<WPI_EventHandle, 8> cleanupEvents;
195 for (auto&& i : pImpl->ResolveStates) {
196 cleanupEvents.push_back(
197 pImpl->thread->RemoveServiceRefOutsideThread(i->ResolveRef));
198 }
199 cleanupEvents.push_back(
200 pImpl->thread->RemoveServiceRefOutsideThread(pImpl->serviceRef));
201 wpi::SmallVector<WPI_Handle, 8> signaledBuf;
202 signaledBuf.resize(cleanupEvents.size());
203 while (!cleanupEvents.empty()) {
204 auto signaled = wpi::WaitForObjects(cleanupEvents, signaledBuf);
205 for (auto&& s : signaled) {
206 cleanupEvents.erase(
207 std::find(cleanupEvents.begin(), cleanupEvents.end(), s));
208 }
209 }
210
211 pImpl->ResolveStates.clear();
212 pImpl->serviceRef = nullptr;
213}