fix: add IP address family handling and tests for DNS functionality #1572

This commit is contained in:
0xJacky
2026-03-15 02:28:49 +00:00
parent 7f7e569460
commit a2e1e8e31c
2 changed files with 86 additions and 11 deletions
+42 -11
View File
@@ -40,6 +40,14 @@ var (
ipRegex = regexp.MustCompile(`(?i)(?:[0-9a-f]{0,4}:){2,7}[0-9a-f]{0,4}|(?:\d{1,3}\.){3}\d{1,3}`)
)
type ipFamily int
const (
ipFamilyAny ipFamily = iota
ipFamilyV4
ipFamilyV6
)
// DefaultDDNSInterval returns the default polling interval in seconds.
func DefaultDDNSInterval() int {
return defaultDDNSIntervalSeconds
@@ -338,13 +346,13 @@ func resolvePublicIPs(ctx context.Context) (ipSnapshot, error) {
ipCtx, cancel := context.WithTimeout(ctx, ipDetectTimeout)
defer cancel()
if ip, err := fetchAnyIP(ipCtx, ipv4Endpoints); err == nil {
if ip, err := fetchAnyIP(ipCtx, ipv4Endpoints, ipFamilyV4); err == nil {
snapshot.IPv4 = ip
} else {
errs = append(errs, fmt.Sprintf("ipv4: %v", err))
}
if ip, err := fetchAnyIP(ipCtx, ipv6Endpoints); err == nil {
if ip, err := fetchAnyIP(ipCtx, ipv6Endpoints, ipFamilyV6); err == nil {
snapshot.IPv6 = ip
} else {
errs = append(errs, fmt.Sprintf("ipv6: %v", err))
@@ -357,7 +365,7 @@ func resolvePublicIPs(ctx context.Context) (ipSnapshot, error) {
return snapshot, nil
}
func fetchIP(ctx context.Context, endpoint string) (string, error) {
func fetchIP(ctx context.Context, endpoint string, family ipFamily) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return "", err
@@ -378,45 +386,45 @@ func fetchIP(ctx context.Context, endpoint string) (string, error) {
}
ipStr := strings.TrimSpace(string(body))
parsed, err := parseIPString(ipStr)
parsed, err := parseIPString(ipStr, family)
if err != nil {
return "", fmt.Errorf("invalid ip from %s: %v", endpoint, err)
}
return parsed, nil
}
func fetchAnyIP(ctx context.Context, endpoints []string) (string, error) {
func fetchAnyIP(ctx context.Context, endpoints []string, family ipFamily) (string, error) {
var errs []string
for _, ep := range endpoints {
if ip, err := fetchIP(ctx, ep); err == nil {
if ip, err := fetchIP(ctx, ep, family); err == nil {
return ip, nil
} else {
errs = append(errs, fmt.Sprintf("%s: %v", ep, err))
}
}
return "", fmt.Errorf(strings.Join(errs, "; "))
return "", errors.New(strings.Join(errs, "; "))
}
func parseIPString(val string) (string, error) {
func parseIPString(val string, family ipFamily) (string, error) {
trimmed := strings.TrimSpace(val)
if trimmed == "" {
return "", fmt.Errorf("empty ip string")
}
if ip := net.ParseIP(trimmed); ip != nil {
if ip := parseExpectedIP(trimmed, family); ip != nil {
return ip.String(), nil
}
for _, token := range strings.Fields(trimmed) {
cleaned := strings.Trim(token, " ,;[](){}<>")
if ip := net.ParseIP(cleaned); ip != nil {
if ip := parseExpectedIP(cleaned, family); ip != nil {
return ip.String(), nil
}
}
if matches := ipRegex.FindAllString(trimmed, -1); len(matches) > 0 {
for _, candidate := range matches {
if ip := net.ParseIP(candidate); ip != nil {
if ip := parseExpectedIP(candidate, family); ip != nil {
return ip.String(), nil
}
}
@@ -425,6 +433,29 @@ func parseIPString(val string) (string, error) {
return "", fmt.Errorf("no valid ip found")
}
func parseExpectedIP(val string, family ipFamily) net.IP {
ip := net.ParseIP(val)
if ip == nil {
return nil
}
switch family {
case ipFamilyV4:
if ip.To4() == nil {
return nil
}
case ipFamilyV6:
if ip.To4() != nil {
return nil
}
case ipFamilyAny:
default:
return nil
}
return ip
}
func sanitizeInterval(value int) int {
if value < minDDNSIntervalSeconds {
return minDDNSIntervalSeconds
+44
View File
@@ -0,0 +1,44 @@
package dns
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestParseIPStringRespectsAddressFamily(t *testing.T) {
t.Run("rejects ipv4 for ipv6", func(t *testing.T) {
_, err := parseIPString("203.0.113.10", ipFamilyV6)
require.Error(t, err)
})
t.Run("accepts embedded ipv6 text", func(t *testing.T) {
ip, err := parseIPString("Current IP: 2001:db8::10", ipFamilyV6)
require.NoError(t, err)
require.Equal(t, "2001:db8::10", ip)
})
t.Run("rejects ipv6 for ipv4", func(t *testing.T) {
_, err := parseIPString("2001:db8::10", ipFamilyV4)
require.Error(t, err)
})
}
func TestFetchAnyIPSkipsMismatchedAddressFamily(t *testing.T) {
ipv4Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("198.51.100.12"))
}))
defer ipv4Server.Close()
ipv6Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("2001:db8::20"))
}))
defer ipv6Server.Close()
ip, err := fetchAnyIP(context.Background(), []string{ipv4Server.URL, ipv6Server.URL}, ipFamilyV6)
require.NoError(t, err)
require.Equal(t, "2001:db8::20", ip)
}