sctp_lib: explicitly test for ipv6 support

Change-Id: I662138e6c9ea71cad751e2a057a642b85e6aabc7
Signed-off-by: Austin Schuh <austin.schuh@bluerivertech.com>
diff --git a/aos/network/BUILD b/aos/network/BUILD
index 0ed756c..323a334 100644
--- a/aos/network/BUILD
+++ b/aos/network/BUILD
@@ -553,6 +553,54 @@
     ],
 )
 
+# Shared library that sctp_lib_test.sh LD_PRELOADs into sctp_lib_test_binary.
+# It wraps socket() so the test can fake whether IPv6 is supported, based on
+# the has_ipv6 environment variable.
+cc_binary(
+    name = "sctp_lib_shim.so",
+    testonly = True,
+    srcs = [
+        "sctp_lib_shim.c",
+    ],
+    linkopts = [
+        "-ldl",
+    ],
+    linkshared = True,
+    target_compatible_with = ["@platforms//os:linux"],
+)
+
+cc_binary(
+    name = "sctp_lib_test_binary",
+    testonly = True,
+    srcs = [
+        "sctp_lib_test.cc",
+    ],
+    linkstatic = False,
+    target_compatible_with = ["@platforms//os:linux"],
+    deps = [
+        ":sctp_lib",
+        "//aos:init",
+    ],
+)
+
+# Runs sctp_lib_test_binary with IPv6 faked on and off and checks which
+# address family ResolveSocket() picks for localhost.
+sh_test(
+    name = "sctp_lib_test",
+    srcs = [
+        "sctp_lib_test.sh",
+    ],
+    args = [
+        "$(location :sctp_lib_test_binary)",
+        "$(location :sctp_lib_shim.so)",
+    ],
+    data = [
+        ":sctp_lib_shim.so",
+        ":sctp_lib_test_binary",
+    ],
+    target_compatible_with = ["@platforms//os:linux"],
+)
+
 cc_test(
     name = "multinode_timestamp_filter_test",
     srcs = [
diff --git a/aos/network/sctp_client.cc b/aos/network/sctp_client.cc
index 273cff8..1033e31 100644
--- a/aos/network/sctp_client.cc
+++ b/aos/network/sctp_client.cc
@@ -17,9 +17,12 @@
 namespace message_bridge {
 
 SctpClient::SctpClient(std::string_view remote_host, int remote_port,
-                       int streams, std::string_view local_host, int local_port)
-    : sockaddr_remote_(ResolveSocket(remote_host, remote_port)),
-      sockaddr_local_(ResolveSocket(local_host, local_port)) {
+                       int streams, std::string_view local_host,
+                       int local_port) {
+  // Probe for IPv6 once and resolve both addresses with the same answer.
+  const bool use_ipv6 = Ipv6Enabled();
+  sockaddr_local_ = ResolveSocket(local_host, local_port, use_ipv6);
+  sockaddr_remote_ = ResolveSocket(remote_host, remote_port, use_ipv6);
   sctp_.OpenSocket(sockaddr_local_);
 
   {
diff --git a/aos/network/sctp_lib.cc b/aos/network/sctp_lib.cc
index 7822e8e..99ead37 100644
--- a/aos/network/sctp_lib.cc
+++ b/aos/network/sctp_lib.cc
@@ -13,7 +13,7 @@

 #include "aos/util/file.h"

-DEFINE_string(interface, "", "ipv6 interface");
+DEFINE_string(interface, "", "network interface");
 DEFINE_bool(disable_ipv6, false, "disable ipv6");

 namespace aos {
@@ -31,19 +31,44 @@

 }  // namespace

-struct sockaddr_storage ResolveSocket(std::string_view host, int port) {
+// Returns whether IPv6 should be used. Returns false if --disable_ipv6 is set;
+// otherwise probes by opening (and immediately closing) an AF_INET6 SCTP
+// socket. EAFNOSUPPORT, EINVAL and EPROTONOSUPPORT mean IPv6 is unavailable;
+// any other failure is fatal.
+bool Ipv6Enabled() {
+  if (FLAGS_disable_ipv6) {
+    return false;
+  }
+  const int fd = socket(AF_INET6, SOCK_SEQPACKET, IPPROTO_SCTP);
+  if (fd != -1) {
+    close(fd);
+    return true;
+  }
+  switch (errno) {
+    case EAFNOSUPPORT:
+    case EINVAL:
+    case EPROTONOSUPPORT:
+      PLOG(INFO) << "no ipv6";
+      return false;
+    default:
+      PLOG(FATAL) << "Open socket failed";
+      return false;
+  }
+}
+
+struct sockaddr_storage ResolveSocket(std::string_view host, int port,
+                                      bool use_ipv6) {
   struct sockaddr_storage result;
   struct addrinfo *addrinfo_result;
   struct sockaddr_in *t_addr = (struct sockaddr_in *)&result;
   struct sockaddr_in6 *t_addr6 = (struct sockaddr_in6 *)&result;
   struct addrinfo hints;
   memset(&hints, 0, sizeof(hints));
-  if (FLAGS_disable_ipv6) {
+  if (!use_ipv6) {
     hints.ai_family = AF_INET;
   } else {
-    // IPv6 can handle IPv4 through IPv4-mapped IPv6 addresses
-    // but IPv4 can't handle IPv6 connections.
-    // The default, if unspecified, is to use IPv4.
+    // Prefer IPv6: it can also reach IPv4 hosts through the IPv4-mapped
+    // addresses AI_V4MAPPED produces, while IPv4 cannot reach IPv6 hosts.
     hints.ai_family = AF_INET6;
   }
   hints.ai_socktype = SOCK_SEQPACKET;
@@ -55,11 +80,6 @@
   hints.ai_flags = AI_PASSIVE | AI_V4MAPPED | AI_NUMERICSERV;
   int ret = getaddrinfo(host.empty() ? nullptr : std::string(host).c_str(),
                         std::to_string(port).c_str(), &hints, &addrinfo_result);
-  if (ret) {
-    hints.ai_family = AF_INET;
-    ret = getaddrinfo(host.empty() ? nullptr : std::string(host).c_str(),
-                      std::to_string(port).c_str(), &hints, &addrinfo_result);
-  }
   if (ret == EAI_SYSTEM) {
     PLOG(FATAL) << "getaddrinfo failed to look up '" << host << "'";
   } else if (ret != 0) {
diff --git a/aos/network/sctp_lib.h b/aos/network/sctp_lib.h
index 6b40600..e81e6b3 100644
--- a/aos/network/sctp_lib.h
+++ b/aos/network/sctp_lib.h
@@ -17,9 +17,19 @@
 namespace aos {
 namespace message_bridge {
 
+// Returns true if IPv6 should be used: --disable_ipv6 is not set and this
+// system can open an IPv6 SCTP socket.
+// We probe explicitly because without AI_ADDRCONFIG, getaddrinfo will happily
+// resolve hosts to IPv6 addresses that can't be used, while adding
+// AI_ADDRCONFIG breaks the unit tests because they only have loopback
+// addresses available.
+bool Ipv6Enabled();
+
-// Resolves a socket and returns the address.  This can be either an ipv4 or
-// ipv6 address.
-struct sockaddr_storage ResolveSocket(std::string_view host, int port);
+// Resolves host and port and returns the address. If use_ipv6 is true, this is
+// an ipv6 address (ipv4 hosts resolve to ipv4-mapped ipv6 addresses);
+// otherwise it is an ipv4 address. Callers normally pass Ipv6Enabled().
+struct sockaddr_storage ResolveSocket(std::string_view host, int port,
+                                      bool use_ipv6);
 
 // Returns a formatted version of the address.
 std::string Address(const struct sockaddr_storage &sockaddr);
diff --git a/aos/network/sctp_lib_shim.c b/aos/network/sctp_lib_shim.c
new file mode 100644
index 0000000..ab38c6c
--- /dev/null
+++ b/aos/network/sctp_lib_shim.c
@@ -0,0 +1,39 @@
+// LD_PRELOAD shim used by sctp_lib_test.sh to fake whether the system supports
+// IPv6: socket(AF_INET6, ...) fails with EAFNOSUPPORT unless the has_ipv6
+// environment variable starts with 'y'.
+#define _GNU_SOURCE
+#include <dlfcn.h>
+#include <errno.h>
+#include <stddef.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/socket.h>
+
+int socket(int domain, int type, int protocol) {
+  static int (*libsocket)(int domain, int type, int protocol) = NULL;
+  if (!libsocket) {
+    // Clear any stale error so the dlerror() check below reflects dlsym().
+    dlerror();
+    libsocket = dlsym(RTLD_NEXT, "socket");
+    const char *error = dlerror();
+    if (error != NULL) {
+      fprintf(stderr, "shim socket: %s\n", error);
+      exit(1);
+    }
+  }
+
+  if (domain == AF_INET6) {
+    // An unset has_ipv6 is treated as "no IPv6" instead of dereferencing the
+    // NULL that getenv() returns.
+    const char *has_ipv6 = getenv("has_ipv6");
+    if (has_ipv6 == NULL || has_ipv6[0] != 'y') {
+      errno = EAFNOSUPPORT;
+      return -1;
+    }
+    // Open an AF_INET socket instead since we don't actually know whether this
+    // system supports IPv6, and the caller only opens this socket to close it
+    // again immediately. Other address families pass through untouched.
+    domain = AF_INET;
+  }
+  return libsocket(domain, type, protocol);
+}
diff --git a/aos/network/sctp_lib_test.cc b/aos/network/sctp_lib_test.cc
new file mode 100644
index 0000000..ed3c15e
--- /dev/null
+++ b/aos/network/sctp_lib_test.cc
@@ -0,0 +1,18 @@
+// Resolves --host/--port with ResolveSocket(), using Ipv6Enabled() to choose
+// the address family, and logs the resulting family and address.
+// sctp_lib_test.sh parses the "Family ..." and "Address ..." log lines.
+#include "aos/network/sctp_lib.h"
+#include "aos/init.h"
+#include "gflags/gflags.h"
+
+DEFINE_string(host, "", "host to resolve");
+DEFINE_int32(port, 2977, "port to use");
+
+int main(int argc, char **argv) {
+  aos::InitGoogle(&argc, &argv);
+  struct sockaddr_storage sockaddr = aos::message_bridge::ResolveSocket(
+      FLAGS_host, FLAGS_port, aos::message_bridge::Ipv6Enabled());
+  LOG(INFO) << "Family " << aos::message_bridge::Family(sockaddr);
+  LOG(INFO) << "Address " << aos::message_bridge::Address(sockaddr);
+  return 0;
+}
diff --git a/aos/network/sctp_lib_test.sh b/aos/network/sctp_lib_test.sh
new file mode 100755
index 0000000..aa3fdbc
--- /dev/null
+++ b/aos/network/sctp_lib_test.sh
@@ -0,0 +1,67 @@
+#!/bin/bash
+
+# Checks which address family ResolveSocket() picks for localhost when IPv6 is
+# faked on or off by the sctp_lib_shim.so LD_PRELOAD shim.
+# Usage: sctp_lib_test.sh <sctp_lib_test_binary> <sctp_lib_shim.so>
+
+set -euo pipefail
+
+result=0
+
+# Compares the "Family" and "Address" lines in ${output} against the expected
+# family ($1) and address regex ($2). Returns 1 and marks the test as failed on
+# any mismatch.
+verify() {
+  local thisresult=0
+  local family
+  local address
+  family=$(echo "$output" | grep -Po '\] Family [^ ]+' | cut -f3 -d' ')
+  address=$(echo "$output" | grep -Po '\] Address [^ ]+' | cut -f3 -d' ')
+  if [[ "${family}" != "${1}" ]]; then
+    echo "Expected family ${1}, got ${family}" >&2
+    thisresult=1
+    result=1
+  fi
+  if [[ ! "${address}" =~ ${2} ]]; then
+    echo "Expected address ${2}, got ${address}" >&2
+    thisresult=1
+    result=1
+  fi
+  return $thisresult
+}
+
+# Runs the test binary on localhost with the shim preloaded. $1 (y/n) sets
+# has_ipv6 for the shim; the remaining arguments are passed to the binary.
+run_test() {
+  local has_ipv6
+  has_ipv6="${1}"
+  export has_ipv6
+  shift
+  LD_PRELOAD="${SHIM}" "${BINARY}" --host=localhost "$@" 2>&1
+}
+
+# Runs one case: $1 describes it, $2 is the expected family, $3 the expected
+# address regex, and the rest are passed to run_test. A binary that exits
+# non-zero is reported as a failure instead of silently aborting the script.
+check() {
+  local description="${1}"
+  local expected_family="${2}"
+  local expected_address="${3}"
+  shift 3
+  if ! output=$(run_test "$@"); then
+    echo "${description}: binary exited with an error:" >&2
+    echo "${output}" >&2
+    result=1
+    return 0
+  fi
+  verify "${expected_family}" "${expected_address}" || echo "${description} failed" >&2
+}
+
+BINARY="$1"
+SHIM="$2"
+
+check "IPv6 allowed with no arguments" AF_INET6 "(::ffff:127.0.0.1|::)" y
+check "IPv6 disallowed with no arguments" AF_INET "127\\.0\\.0\\.1" n
+check "IPv6 allowed with --disable_ipv6" AF_INET "127\\.0\\.0\\.1" y --disable_ipv6
+
+exit $result
diff --git a/aos/network/sctp_server.cc b/aos/network/sctp_server.cc
index 33046d0..93d1e88 100644
--- a/aos/network/sctp_server.cc
+++ b/aos/network/sctp_server.cc
@@ -20,8 +20,10 @@
 namespace aos {
 namespace message_bridge {
 
-SctpServer::SctpServer(int streams, std::string_view local_host, int local_port)
-    : sockaddr_local_(ResolveSocket(local_host, local_port)) {
+SctpServer::SctpServer(int streams, std::string_view local_host,
+                       int local_port) {
+  bool use_ipv6 = Ipv6Enabled();
+  sockaddr_local_ = ResolveSocket(local_host, local_port, use_ipv6);
   while (true) {
     sctp_.OpenSocket(sockaddr_local_);