blob: eecf819814fd8d2a3fd08a34f04e5e6e30d9f51b [file] [log] [blame]
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
4
5#if defined(__APPLE__)
6
#include "wpinet/MulticastServiceResolver.h"

#include <netinet/in.h>
#include <poll.h>
#include <sys/socket.h>

#include <algorithm>
#include <atomic>
#include <memory>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <wpi/SmallVector.h>

#include "ResolverThread.h"
#include "dns_sd.h"
20
21using namespace wpi;
22
23struct DnsResolveState {
24 DnsResolveState(MulticastServiceResolver::Impl* impl,
25 std::string_view serviceNameView)
26 : pImpl{impl} {
27 data.serviceName = serviceNameView;
28 }
29
30 DNSServiceRef ResolveRef = nullptr;
31 MulticastServiceResolver::Impl* pImpl;
32
33 MulticastServiceResolver::ServiceData data;
34};
35
36struct MulticastServiceResolver::Impl {
37 std::string serviceType;
38 MulticastServiceResolver* resolver;
39 std::shared_ptr<ResolverThread> thread = ResolverThread::Get();
40 std::vector<std::unique_ptr<DnsResolveState>> ResolveStates;
41 DNSServiceRef serviceRef = nullptr;
42
43 void onFound(ServiceData&& data) {
44 resolver->PushData(std::forward<ServiceData>(data));
45 }
46};
47
48MulticastServiceResolver::MulticastServiceResolver(
49 std::string_view serviceType) {
50 pImpl = std::make_unique<Impl>();
51 pImpl->serviceType = serviceType;
52 pImpl->resolver = this;
53}
54
55MulticastServiceResolver::~MulticastServiceResolver() noexcept {
56 Stop();
57}
58
59void ServiceGetAddrInfoReply(DNSServiceRef sdRef, DNSServiceFlags flags,
60 uint32_t interfaceIndex,
61 DNSServiceErrorType errorCode,
62 const char* hostname,
63 const struct sockaddr* address, uint32_t ttl,
64 void* context) {
65 if (errorCode != kDNSServiceErr_NoError) {
66 return;
67 }
68
69 DnsResolveState* resolveState = static_cast<DnsResolveState*>(context);
70
71 resolveState->data.hostName = hostname;
72 resolveState->data.ipv4Address =
73 reinterpret_cast<const struct sockaddr_in*>(address)->sin_addr.s_addr;
74
75 resolveState->pImpl->onFound(std::move(resolveState->data));
76
77 resolveState->pImpl->thread->RemoveServiceRefInThread(
78 resolveState->ResolveRef);
79
80 resolveState->pImpl->ResolveStates.erase(std::find_if(
81 resolveState->pImpl->ResolveStates.begin(),
82 resolveState->pImpl->ResolveStates.end(),
83 [resolveState](auto& a) { return a.get() == resolveState; }));
84}
85
86void ServiceResolveReply(DNSServiceRef sdRef, DNSServiceFlags flags,
87 uint32_t interfaceIndex, DNSServiceErrorType errorCode,
88 const char* fullname, const char* hosttarget,
89 uint16_t port, /* In network byte order */
90 uint16_t txtLen, const unsigned char* txtRecord,
91 void* context) {
92 if (errorCode != kDNSServiceErr_NoError) {
93 return;
94 }
95
96 DnsResolveState* resolveState = static_cast<DnsResolveState*>(context);
97 resolveState->pImpl->thread->RemoveServiceRefInThread(
98 resolveState->ResolveRef);
99 DNSServiceRefDeallocate(resolveState->ResolveRef);
100 resolveState->ResolveRef = nullptr;
101 resolveState->data.port = ntohs(port);
102
103 int txtCount = TXTRecordGetCount(txtLen, txtRecord);
104 char keyBuf[256];
105 uint8_t valueLen;
106 const void* value;
107
108 for (int i = 0; i < txtCount; i++) {
109 errorCode = TXTRecordGetItemAtIndex(txtLen, txtRecord, i, sizeof(keyBuf),
110 keyBuf, &valueLen, &value);
111 if (errorCode == kDNSServiceErr_NoError) {
112 if (valueLen == 0) {
113 // No value
114 resolveState->data.txt.emplace_back(
115 std::pair<std::string, std::string>{std::string{keyBuf}, {}});
116 } else {
117 resolveState->data.txt.emplace_back(std::pair<std::string, std::string>{
118 std::string{keyBuf},
119 std::string{reinterpret_cast<const char*>(value), valueLen}});
120 }
121 }
122 }
123
124 errorCode = DNSServiceGetAddrInfo(
125 &resolveState->ResolveRef, flags, interfaceIndex,
126 kDNSServiceProtocol_IPv4, hosttarget, ServiceGetAddrInfoReply, context);
127
128 if (errorCode == kDNSServiceErr_NoError) {
129 dnssd_sock_t socket = DNSServiceRefSockFD(resolveState->ResolveRef);
130 resolveState->pImpl->thread->AddServiceRef(resolveState->ResolveRef,
131 socket);
132 } else {
133 resolveState->pImpl->thread->RemoveServiceRefInThread(
134 resolveState->ResolveRef);
135 resolveState->pImpl->ResolveStates.erase(std::find_if(
136 resolveState->pImpl->ResolveStates.begin(),
137 resolveState->pImpl->ResolveStates.end(),
138 [resolveState](auto& a) { return a.get() == resolveState; }));
139 }
140}
141
142static void DnsCompletion(DNSServiceRef sdRef, DNSServiceFlags flags,
143 uint32_t interfaceIndex,
144 DNSServiceErrorType errorCode,
145 const char* serviceName, const char* regtype,
146 const char* replyDomain, void* context) {
147 if (errorCode != kDNSServiceErr_NoError) {
148 return;
149 }
150 if (!(flags & kDNSServiceFlagsAdd)) {
151 return;
152 }
153
154 MulticastServiceResolver::Impl* impl =
155 static_cast<MulticastServiceResolver::Impl*>(context);
156 auto& resolveState = impl->ResolveStates.emplace_back(
157 std::make_unique<DnsResolveState>(impl, serviceName));
158
159 errorCode = DNSServiceResolve(&resolveState->ResolveRef, 0, interfaceIndex,
160 serviceName, regtype, replyDomain,
161 ServiceResolveReply, resolveState.get());
162
163 if (errorCode == kDNSServiceErr_NoError) {
164 dnssd_sock_t socket = DNSServiceRefSockFD(resolveState->ResolveRef);
165 resolveState->pImpl->thread->AddServiceRef(resolveState->ResolveRef,
166 socket);
167 } else {
168 resolveState->pImpl->ResolveStates.erase(std::find_if(
169 resolveState->pImpl->ResolveStates.begin(),
170 resolveState->pImpl->ResolveStates.end(),
171 [r = resolveState.get()](auto& a) { return a.get() == r; }));
172 }
173}
174
175bool MulticastServiceResolver::HasImplementation() const {
176 return true;
177}
178
179void MulticastServiceResolver::Start() {
180 if (pImpl->serviceRef) {
181 return;
182 }
183
184 DNSServiceErrorType status =
185 DNSServiceBrowse(&pImpl->serviceRef, 0, 0, pImpl->serviceType.c_str(),
186 "local", DnsCompletion, pImpl.get());
187 if (status == kDNSServiceErr_NoError) {
188 dnssd_sock_t socket = DNSServiceRefSockFD(pImpl->serviceRef);
189 pImpl->thread->AddServiceRef(pImpl->serviceRef, socket);
190 }
191}
192
193void MulticastServiceResolver::Stop() {
194 if (!pImpl->serviceRef) {
195 return;
196 }
197 wpi::SmallVector<WPI_EventHandle, 8> cleanupEvents;
198 for (auto&& i : pImpl->ResolveStates) {
199 cleanupEvents.push_back(
200 pImpl->thread->RemoveServiceRefOutsideThread(i->ResolveRef));
201 }
202 cleanupEvents.push_back(
203 pImpl->thread->RemoveServiceRefOutsideThread(pImpl->serviceRef));
204 wpi::SmallVector<WPI_Handle, 8> signaledBuf;
205 signaledBuf.resize(cleanupEvents.size());
206 while (!cleanupEvents.empty()) {
207 auto signaled = wpi::WaitForObjects(cleanupEvents, signaledBuf);
208 for (auto&& s : signaled) {
209 cleanupEvents.erase(
210 std::find(cleanupEvents.begin(), cleanupEvents.end(), s));
211 }
212 }
213
214 pImpl->ResolveStates.clear();
215 pImpl->serviceRef = nullptr;
216}
217
218#endif // defined(__APPLE__)