mirror of
https://github.com/0xJacky/nginx-ui.git
synced 2026-06-19 07:36:59 +00:00
fix: add IP address family handling and tests for DNS functionality #1572
This commit is contained in:
+42
-11
@@ -40,6 +40,14 @@ var (
|
||||
ipRegex = regexp.MustCompile(`(?i)(?:[0-9a-f]{0,4}:){2,7}[0-9a-f]{0,4}|(?:\d{1,3}\.){3}\d{1,3}`)
|
||||
)
|
||||
|
||||
type ipFamily int
|
||||
|
||||
const (
|
||||
ipFamilyAny ipFamily = iota
|
||||
ipFamilyV4
|
||||
ipFamilyV6
|
||||
)
|
||||
|
||||
// DefaultDDNSInterval returns the default polling interval in seconds.
|
||||
func DefaultDDNSInterval() int {
|
||||
return defaultDDNSIntervalSeconds
|
||||
@@ -338,13 +346,13 @@ func resolvePublicIPs(ctx context.Context) (ipSnapshot, error) {
|
||||
ipCtx, cancel := context.WithTimeout(ctx, ipDetectTimeout)
|
||||
defer cancel()
|
||||
|
||||
if ip, err := fetchAnyIP(ipCtx, ipv4Endpoints); err == nil {
|
||||
if ip, err := fetchAnyIP(ipCtx, ipv4Endpoints, ipFamilyV4); err == nil {
|
||||
snapshot.IPv4 = ip
|
||||
} else {
|
||||
errs = append(errs, fmt.Sprintf("ipv4: %v", err))
|
||||
}
|
||||
|
||||
if ip, err := fetchAnyIP(ipCtx, ipv6Endpoints); err == nil {
|
||||
if ip, err := fetchAnyIP(ipCtx, ipv6Endpoints, ipFamilyV6); err == nil {
|
||||
snapshot.IPv6 = ip
|
||||
} else {
|
||||
errs = append(errs, fmt.Sprintf("ipv6: %v", err))
|
||||
@@ -357,7 +365,7 @@ func resolvePublicIPs(ctx context.Context) (ipSnapshot, error) {
|
||||
return snapshot, nil
|
||||
}
|
||||
|
||||
func fetchIP(ctx context.Context, endpoint string) (string, error) {
|
||||
func fetchIP(ctx context.Context, endpoint string, family ipFamily) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -378,45 +386,45 @@ func fetchIP(ctx context.Context, endpoint string) (string, error) {
|
||||
}
|
||||
|
||||
ipStr := strings.TrimSpace(string(body))
|
||||
parsed, err := parseIPString(ipStr)
|
||||
parsed, err := parseIPString(ipStr, family)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid ip from %s: %v", endpoint, err)
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func fetchAnyIP(ctx context.Context, endpoints []string) (string, error) {
|
||||
func fetchAnyIP(ctx context.Context, endpoints []string, family ipFamily) (string, error) {
|
||||
var errs []string
|
||||
for _, ep := range endpoints {
|
||||
if ip, err := fetchIP(ctx, ep); err == nil {
|
||||
if ip, err := fetchIP(ctx, ep, family); err == nil {
|
||||
return ip, nil
|
||||
} else {
|
||||
errs = append(errs, fmt.Sprintf("%s: %v", ep, err))
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf(strings.Join(errs, "; "))
|
||||
return "", errors.New(strings.Join(errs, "; "))
|
||||
}
|
||||
|
||||
func parseIPString(val string) (string, error) {
|
||||
func parseIPString(val string, family ipFamily) (string, error) {
|
||||
trimmed := strings.TrimSpace(val)
|
||||
if trimmed == "" {
|
||||
return "", fmt.Errorf("empty ip string")
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(trimmed); ip != nil {
|
||||
if ip := parseExpectedIP(trimmed, family); ip != nil {
|
||||
return ip.String(), nil
|
||||
}
|
||||
|
||||
for _, token := range strings.Fields(trimmed) {
|
||||
cleaned := strings.Trim(token, " ,;[](){}<>")
|
||||
if ip := net.ParseIP(cleaned); ip != nil {
|
||||
if ip := parseExpectedIP(cleaned, family); ip != nil {
|
||||
return ip.String(), nil
|
||||
}
|
||||
}
|
||||
|
||||
if matches := ipRegex.FindAllString(trimmed, -1); len(matches) > 0 {
|
||||
for _, candidate := range matches {
|
||||
if ip := net.ParseIP(candidate); ip != nil {
|
||||
if ip := parseExpectedIP(candidate, family); ip != nil {
|
||||
return ip.String(), nil
|
||||
}
|
||||
}
|
||||
@@ -425,6 +433,29 @@ func parseIPString(val string) (string, error) {
|
||||
return "", fmt.Errorf("no valid ip found")
|
||||
}
|
||||
|
||||
func parseExpectedIP(val string, family ipFamily) net.IP {
|
||||
ip := net.ParseIP(val)
|
||||
if ip == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch family {
|
||||
case ipFamilyV4:
|
||||
if ip.To4() == nil {
|
||||
return nil
|
||||
}
|
||||
case ipFamilyV6:
|
||||
if ip.To4() != nil {
|
||||
return nil
|
||||
}
|
||||
case ipFamilyAny:
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
return ip
|
||||
}
|
||||
|
||||
func sanitizeInterval(value int) int {
|
||||
if value < minDDNSIntervalSeconds {
|
||||
return minDDNSIntervalSeconds
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseIPStringRespectsAddressFamily(t *testing.T) {
|
||||
t.Run("rejects ipv4 for ipv6", func(t *testing.T) {
|
||||
_, err := parseIPString("203.0.113.10", ipFamilyV6)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("accepts embedded ipv6 text", func(t *testing.T) {
|
||||
ip, err := parseIPString("Current IP: 2001:db8::10", ipFamilyV6)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "2001:db8::10", ip)
|
||||
})
|
||||
|
||||
t.Run("rejects ipv6 for ipv4", func(t *testing.T) {
|
||||
_, err := parseIPString("2001:db8::10", ipFamilyV4)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFetchAnyIPSkipsMismatchedAddressFamily(t *testing.T) {
|
||||
ipv4Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte("198.51.100.12"))
|
||||
}))
|
||||
defer ipv4Server.Close()
|
||||
|
||||
ipv6Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte("2001:db8::20"))
|
||||
}))
|
||||
defer ipv6Server.Close()
|
||||
|
||||
ip, err := fetchAnyIP(context.Background(), []string{ipv4Server.URL, ipv6Server.URL}, ipFamilyV6)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "2001:db8::20", ip)
|
||||
}
|
||||
Reference in New Issue
Block a user