/*
 *    Copyright (c) 2021, The OpenThread Authors.
 *    All rights reserved.
 *
 *    Redistribution and use in source and binary forms, with or without
 *    modification, are permitted provided that the following conditions are met:
 *    1. Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *    2. Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *    3. Neither the name of the copyright holder nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 *    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *    AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *    IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 *    ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 *    LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 *    CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 *    SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 *    INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 *    CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 *    ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 *    POSSIBILITY OF SUCH DAMAGE.
 */

/**
 * @file
 *   The file implements the DNS-SD Discovery Proxy.
 */

#if OTBR_ENABLE_DNSSD_DISCOVERY_PROXY

#define OTBR_LOG_TAG "DPROXY"

#include "sdp_proxy/discovery_proxy.hpp"

#include <algorithm>
#include <string>

#include <assert.h>

#include <openthread/dnssd_server.h>

#include "common/code_utils.hpp"
#include "common/dns_utils.hpp"
#include "common/logging.hpp"
#include "utils/dns_utils.hpp"
#include "utils/string_utils.hpp"

namespace otbr {
namespace Dnssd {

/**
 * Compares two DNS labels/names ignoring letter case.
 *
 * @param[in] aLabel1  The first label.
 * @param[in] aLabel2  The second label.
 *
 * @returns The result of `StringUtils::EqualCaseInsensitive()` for the two labels.
 */
static inline bool DnsLabelsEqual(const std::string &aLabel1, const std::string &aLabel2)
{
    return StringUtils::EqualCaseInsensitive(aLabel1, aLabel2);
}

/**
 * Constructs the proxy in the disabled state and registers a reset handler on @p aHost
 * which re-installs the DNS-SD query subscribe/unsubscribe callbacks.
 *
 * @param[in] aHost       The host whose OpenThread instance is queried.
 * @param[in] aPublisher  The mDNS publisher used to subscribe to services and hosts.
 */
DiscoveryProxy::DiscoveryProxy(Ncp::RcpHost &aHost, Mdns::Publisher &aPublisher)
    : mHost(aHost)
    , mMdnsPublisher(aPublisher)
    , mIsEnabled(false)
{
    // NOTE(review): the handler installs the callbacks without consulting `mIsEnabled`;
    // confirm this is the intended behavior after a reset while the proxy is disabled.
    mHost.RegisterResetHandler([this]() {
        otDnssdQuerySetCallbacks(mHost.GetInstance(), &DiscoveryProxy::OnDiscoveryProxySubscribe,
                                 &DiscoveryProxy::OnDiscoveryProxyUnsubscribe, this);
    });
}

/**
 * Enables or disables the proxy. Does nothing when the requested state equals the current one;
 * otherwise calls `Start()` or `Stop()` accordingly.
 *
 * @param[in] aIsEnabled  The requested state.
 */
void DiscoveryProxy::SetEnabled(bool aIsEnabled)
{
    VerifyOrExit(IsEnabled() != aIsEnabled);

    mIsEnabled = aIsEnabled;

    if (mIsEnabled)
    {
        Start();
    }
    else
    {
        Stop();
    }

exit:
    return;
}

/**
 * Installs the DNS-SD query callbacks and registers discovery callbacks with the mDNS publisher.
 *
 * Asserts that no publisher subscription is currently registered (`mSubscriberId == 0`).
 */
void DiscoveryProxy::Start(void)
{
    assert(mSubscriberId == 0);

    otDnssdQuerySetCallbacks(mHost.GetInstance(), &DiscoveryProxy::OnDiscoveryProxySubscribe,
                             &DiscoveryProxy::OnDiscoveryProxyUnsubscribe, this);

    mSubscriberId = mMdnsPublisher.AddSubscriptionCallbacks(
        [this](const std::string &aType, const Mdns::Publisher::DiscoveredInstanceInfo &aInstanceInfo) {
            // Instances flagged as removed are not forwarded.
            if (!aInstanceInfo.mRemoved)
            {
                OnServiceDiscovered(aType, aInstanceInfo);
            }
        },
        [this](const std::string &aHostName, const Mdns::Publisher::DiscoveredHostInfo &aHostInfo) {
            OnHostDiscovered(aHostName, aHostInfo);
        });

    otbrLogInfo("Started");
}

/**
 * Clears the DNS-SD query callbacks and, if a publisher subscription is registered,
 * removes it and resets `mSubscriberId` to 0.
 */
void DiscoveryProxy::Stop(void)
{
    otDnssdQuerySetCallbacks(mHost.GetInstance(), nullptr, nullptr, nullptr);

    if (mSubscriberId > 0)
    {
        mMdnsPublisher.RemoveSubscriptionCallbacks(mSubscriberId);
        mSubscriberId = 0;
    }

    otbrLogInfo("Stopped");
}

/**
 * Static trampoline passed to `otDnssdQuerySetCallbacks()`; forwards to the member overload.
 *
 * @param[in] aContext   The `DiscoveryProxy` instance that was registered as context.
 * @param[in] aFullName  The full name being subscribed.
 */
void DiscoveryProxy::OnDiscoveryProxySubscribe(void *aContext, const char *aFullName)
{
    static_cast<DiscoveryProxy *>(aContext)->OnDiscoveryProxySubscribe(aFullName);
}

/**
 * Handles a subscribe request for @p aFullName.
 *
 * The mDNS subscription is only started when exactly one query matches the name
 * (presumably the query that triggered this call is already counted - confirm against
 * the OpenThread DNS-SD server API). A name without a host part is subscribed as a
 * service (optionally a specific instance); otherwise it is subscribed as a host.
 *
 * @param[in] aFullName  The full DNS name being subscribed.
 */
void DiscoveryProxy::OnDiscoveryProxySubscribe(const char *aFullName)
{
    std::string fullName(aFullName);
    DnsNameInfo nameInfo = SplitFullDnsName(fullName);

    otbrLogInfo("Subscribe: %s", fullName.c_str());

    if (GetServiceSubscriptionCount(nameInfo) == 1)
    {
        if (nameInfo.mHostName.empty())
        {
            mMdnsPublisher.SubscribeService(nameInfo.mServiceName, nameInfo.mInstanceName);
        }
        else
        {
            mMdnsPublisher.SubscribeHost(nameInfo.mHostName);
        }
    }
}

/**
 * Static trampoline passed to `otDnssdQuerySetCallbacks()`; forwards to the member overload.
 *
 * @param[in] aContext   The `DiscoveryProxy` instance that was registered as context.
 * @param[in] aFullName  The full name being unsubscribed.
 */
void DiscoveryProxy::OnDiscoveryProxyUnsubscribe(void *aContext, const char *aFullName)
{
    static_cast<DiscoveryProxy *>(aContext)->OnDiscoveryProxyUnsubscribe(aFullName);
}

/**
 * Handles an unsubscribe request for @p aFullName.
 *
 * The mDNS subscription is only stopped when exactly one query matches the name
 * (presumably the query being removed is still counted at this point - confirm against
 * the OpenThread DNS-SD server API).
 *
 * @param[in] aFullName  The full DNS name being unsubscribed.
 */
void DiscoveryProxy::OnDiscoveryProxyUnsubscribe(const char *aFullName)
{
    std::string fullName(aFullName);
    DnsNameInfo nameInfo = SplitFullDnsName(fullName);

    otbrLogInfo("Unsubscribe: %s", fullName.c_str());

    if (GetServiceSubscriptionCount(nameInfo) == 1)
    {
        if (nameInfo.mHostName.empty())
        {
            mMdnsPublisher.UnsubscribeService(nameInfo.mServiceName, nameInfo.mInstanceName);
        }
        else
        {
            mMdnsPublisher.UnsubscribeHost(nameInfo.mHostName);
        }
    }
}

/**
 * Reports a discovered service instance to every browse/resolve query that matches it.
 *
 * @param[in] aType          The service type of the discovered instance.
 * @param[in] aInstanceInfo  The discovered instance information.
 */
void DiscoveryProxy::OnServiceDiscovered(const std::string                             &aType,
                                         const Mdns::Publisher::DiscoveredInstanceInfo &aInstanceInfo)
{
    otDnssdServiceInstanceInfo instanceInfo;
    const otDnssdQuery        *query                 = nullptr;
    std::string                unescapedInstanceName = DnsUtils::UnescapeInstanceName(aInstanceInfo.mName);

    otbrLogInfo("Service discovered: %s, instance %s hostname %s addresses %zu port %d priority %d "
                "weight %d",
                aType.c_str(), aInstanceInfo.mName.c_str(), aInstanceInfo.mHostName.c_str(),
                aInstanceInfo.mAddresses.size(), aInstanceInfo.mPort, aInstanceInfo.mPriority, aInstanceInfo.mWeight);

    instanceInfo.mAddressNum = aInstanceInfo.mAddresses.size();

    if (!aInstanceInfo.mAddresses.empty())
    {
        // Assumes the publisher's address entries are layout-compatible with `otIp6Address` - TODO confirm.
        instanceInfo.mAddresses = reinterpret_cast<const otIp6Address *>(&aInstanceInfo.mAddresses[0]);
    }
    else
    {
        instanceInfo.mAddresses = nullptr;
    }

    instanceInfo.mPort      = aInstanceInfo.mPort;
    instanceInfo.mPriority  = aInstanceInfo.mPriority;
    instanceInfo.mWeight    = aInstanceInfo.mWeight;
    instanceInfo.mTxtLength = static_cast<uint16_t>(aInstanceInfo.mTxtData.size());
    instanceInfo.mTxtData   = aInstanceInfo.mTxtData.data();
    instanceInfo.mTtl       = CapTtl(aInstanceInfo.mTtl);

    while ((query = otDnssdGetNextQuery(mHost.GetInstance(), query)) != nullptr)
    {
        std::string      instanceName;
        std::string      serviceName;
        std::string      domain;
        char             queryName[OT_DNS_MAX_NAME_SIZE];
        otDnssdQueryType type = otDnssdGetQueryTypeAndName(query, &queryName);
        otbrError        splitError;

        switch (type)
        {
        case OT_DNSSD_QUERY_TYPE_BROWSE:
            splitError = SplitFullServiceName(queryName, serviceName, domain);
            break;
        case OT_DNSSD_QUERY_TYPE_RESOLVE:
            splitError = SplitFullServiceInstanceName(queryName, instanceName, serviceName, domain);
            break;
        default:
            splitError = OTBR_ERROR_NOT_FOUND;
            break;
        }

        if (splitError != OTBR_ERROR_NONE)
        {
            // Incoming service/instance was not what current query wanted to see, move on.
            continue;
        }

        // A browse query leaves `instanceName` empty and therefore matches any instance of the type.
        if (DnsLabelsEqual(serviceName, aType) &&
            (instanceName.empty() || DnsLabelsEqual(instanceName, unescapedInstanceName)))
        {
            std::string serviceFullName    = aType + "." + domain;
            std::string translatedHostName = TranslateDomain(aInstanceInfo.mHostName, domain);
            std::string instanceFullName   = unescapedInstanceName + "." + serviceFullName;

            // The strings above must stay alive until the call below returns; `instanceInfo` only borrows them.
            instanceInfo.mFullName = instanceFullName.c_str();
            instanceInfo.mHostName = translatedHostName.c_str();

            otDnssdQueryHandleDiscoveredServiceInstance(mHost.GetInstance(), serviceFullName.c_str(), &instanceInfo);
        }
    }
}

/**
 * Reports a discovered host to every host-resolve query whose host label matches @p aHostName.
 *
 * @param[in] aHostName  The host name (label) that was discovered.
 * @param[in] aHostInfo  The discovered host information.
 */
void DiscoveryProxy::OnHostDiscovered(const std::string                         &aHostName,
                                      const Mdns::Publisher::DiscoveredHostInfo &aHostInfo)
{
    otDnssdHostInfo     hostInfo;
    const otDnssdQuery *query            = nullptr;
    std::string         resolvedHostName = aHostInfo.mHostName;

    otbrLogInfo("Host discovered: %s hostname %s addresses %zu", aHostName.c_str(), aHostInfo.mHostName.c_str(),
                aHostInfo.mAddresses.size());

    // Fall back to "<aHostName>.local." when the publisher did not provide a resolved name.
    if (resolvedHostName.empty())
    {
        resolvedHostName = aHostName + ".local.";
    }

    hostInfo.mAddressNum = aHostInfo.mAddresses.size();

    if (!aHostInfo.mAddresses.empty())
    {
        // Assumes the publisher's address entries are layout-compatible with `otIp6Address` - TODO confirm.
        hostInfo.mAddresses = reinterpret_cast<const otIp6Address *>(&aHostInfo.mAddresses[0]);
    }
    else
    {
        hostInfo.mAddresses = nullptr;
    }

    hostInfo.mTtl = CapTtl(aHostInfo.mTtl);

    while ((query = otDnssdGetNextQuery(mHost.GetInstance(), query)) != nullptr)
    {
        std::string      hostName, domain;
        char             queryName[OT_DNS_MAX_NAME_SIZE];
        otDnssdQueryType type = otDnssdGetQueryTypeAndName(query, &queryName);

        if (type != OT_DNSSD_QUERY_TYPE_RESOLVE_HOST)
        {
            continue;
        }

        if (SplitFullHostName(queryName, hostName, domain) != OTBR_ERROR_NONE)
        {
            continue;
        }

        if (DnsLabelsEqual(hostName, aHostName))
        {
            std::string hostFullName = TranslateDomain(resolvedHostName, domain);

            otDnssdQueryHandleDiscoveredHost(mHost.GetInstance(), hostFullName.c_str(), &hostInfo);
        }
    }
}

/**
 * Rewrites a host name in the "local." domain so that it lives in @p aTargetDomain.
 *
 * @param[in] aName          The full host name to translate.
 * @param[in] aTargetDomain  The domain to substitute for "local.".
 *
 * @returns "<host>.<aTargetDomain>" when @p aName splits successfully and its domain is "local."
 *          (compared case-insensitively); otherwise @p aName unchanged.
 */
std::string DiscoveryProxy::TranslateDomain(const std::string &aName, const std::string &aTargetDomain)
{
    std::string targetName;
    std::string hostName;
    std::string domain;

    VerifyOrExit(OTBR_ERROR_NONE == SplitFullHostName(aName, hostName, domain), targetName = aName);
    VerifyOrExit(DnsLabelsEqual(domain, "local."), targetName = aName);

    targetName = hostName + "." + aTargetDomain;

exit:
    otbrLogDebug("Translate domain: %s => %s", aName.c_str(), targetName.c_str());
    return targetName;
}

/**
 * Counts the queries returned by `otDnssdGetNextQuery()` whose instance, service and host
 * names all equal those of @p aNameInfo (case-insensitively).
 *
 * @param[in] aNameInfo  The split name to look for.
 *
 * @returns The number of matching queries.
 */
int DiscoveryProxy::GetServiceSubscriptionCount(const DnsNameInfo &aNameInfo) const
{
    const otDnssdQuery *query = nullptr;
    int                 count = 0;

    while ((query = otDnssdGetNextQuery(mHost.GetInstance(), query)) != nullptr)
    {
        char        queryName[OT_DNS_MAX_NAME_SIZE];
        DnsNameInfo queryInfo;

        otDnssdGetQueryTypeAndName(query, &queryName);
        queryInfo = SplitFullDnsName(queryName);

        count += (DnsLabelsEqual(aNameInfo.mInstanceName, queryInfo.mInstanceName) &&
                  DnsLabelsEqual(aNameInfo.mServiceName, queryInfo.mServiceName) &&
                  DnsLabelsEqual(aNameInfo.mHostName, queryInfo.mHostName));
    }

    return count;
}

/**
 * Limits a TTL to `kServiceTtlCapLimit`.
 *
 * @param[in] aTtl  The TTL to cap.
 *
 * @returns The smaller of @p aTtl and `kServiceTtlCapLimit`.
 */
uint32_t DiscoveryProxy::CapTtl(uint32_t aTtl)
{
    return std::min(aTtl, static_cast<uint32_t>(kServiceTtlCapLimit));
}

} // namespace Dnssd
} // namespace otbr

#endif // OTBR_ENABLE_DNSSD_DISCOVERY_PROXY