Move snicheck to ctx instead of simulated routing

This commit is contained in:
Julien Salleyron
2026-05-28 10:30:07 +02:00
committed by GitHub
parent f9d9b72380
commit 5026ca97d0
19 changed files with 362 additions and 597 deletions
@@ -7,6 +7,10 @@
[entryPoints.websecure]
address = ":4443"
[entryPoints.websecure.http3]
[experimental]
http3 = true
[api]
insecure = true
@@ -32,6 +36,35 @@
[http.routers.router3.tls]
options = "mytls"
[http.routers.router4]
rule = "Host(`site4.www.snitest.com`)"
service = "service4"
[http.routers.router4.tls]
[http.routers.router4path]
rule = "Host(`site4.www.snitest.com`) && PathPrefix(`/foo`)"
service = "service4"
[http.routers.router4path.tls]
options = "mytls"
[http.routers.router5]
rule = "Host(`site5.www.snitest.com`)"
service = "service5"
[http.routers.router5.tls]
options = "mytls"
[http.routers.router5path]
rule = "Host(`site5.www.snitest.com`) && PathPrefix(`/bar`)"
service = "service5"
[http.routers.router5path.tls]
options = "mytls"
[http.routers.router6]
rule = "Host(`site6.www.snitest.com.`)"
service = "service6"
[http.routers.router6.tls]
options = "mytls"
[http.services.service1]
[[http.services.service1.loadBalancer.servers]]
url = "http://127.0.0.1:9010"
@@ -44,10 +77,22 @@
[[http.services.service3.loadBalancer.servers]]
url = "http://127.0.0.1:9030"
[http.services.service4]
[[http.services.service4.loadBalancer.servers]]
url = "http://127.0.0.1:9040"
[http.services.service5]
[[http.services.service5.loadBalancer.servers]]
url = "http://127.0.0.1:9050"
[http.services.service6]
[[http.services.service6.loadBalancer.servers]]
url = "http://127.0.0.1:9060"
[[tls.certificates]]
certFile = "fixtures/https/wildcard.www.snitest.com.cert"
keyFile = "fixtures/https/wildcard.www.snitest.com.key"
[tls.options]
[tls.options.mytls]
maxVersion = "VersionTLS12"
maxVersion = "VersionTLS13"
+37 -24
View File
@@ -13,6 +13,7 @@ import (
"time"
"github.com/BurntSushi/toml"
"github.com/quic-go/quic-go/http3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
@@ -254,7 +255,7 @@ func (s *HTTPSSuite) TestWithConflictingTLSOptions() {
assert.ErrorContains(s.T(), err, "tls: no supported versions satisfy MinVersion and MaxVersion")
// with unknown tls option
err = try.GetRequest("http://127.0.0.1:8080/api/rawdata", 1*time.Second, try.BodyContains(fmt.Sprintf("found different TLS options for routers on the same host %v, so using the default TLS options instead", tr4.TLSClientConfig.ServerName)))
err = try.GetRequest("http://127.0.0.1:8080/api/rawdata", 1*time.Second, try.BodyContains("found different TLS options for routers on the same host, so using the default TLS options instead"))
require.NoError(s.T(), err)
}
@@ -995,19 +996,20 @@ func (s *HTTPSSuite) TestWithDomainFronting() {
defer backend2.Close()
backend3 := startTestServer("9030", http.StatusOK, "server3")
defer backend3.Close()
backend5 := startTestServer("9050", http.StatusOK, "server5")
defer backend5.Close()
file := s.adaptFile("fixtures/https/https_domain_fronting.toml", struct{}{})
s.traefikCmd(withConfigFile(file))
// wait for Traefik
err := try.GetRequest("http://127.0.0.1:8080/api/rawdata", 500*time.Millisecond, try.BodyContains("Host(`site1.www.snitest.com`)"))
err := try.GetRequest("http://127.0.0.1:8080/api/rawdata", 1000*time.Millisecond, try.BodyContains("Host(`site1.www.snitest.com`)"))
require.NoError(s.T(), err)
testCases := []struct {
desc string
hostHeader string
serverName string
expectedError bool
expectedContent string
expectedStatusCode int
}{
@@ -1025,14 +1027,6 @@ func (s *HTTPSSuite) TestWithDomainFronting() {
expectedContent: "server3",
expectedStatusCode: http.StatusOK,
},
{
desc: "Spaces after the host header",
hostHeader: "site3.www.snitest.com ",
serverName: "site3.www.snitest.com",
expectedError: true,
expectedContent: "server3",
expectedStatusCode: http.StatusOK,
},
{
desc: "Spaces after the servername",
hostHeader: "site3.www.snitest.com",
@@ -1040,14 +1034,6 @@ func (s *HTTPSSuite) TestWithDomainFronting() {
expectedContent: "server3",
expectedStatusCode: http.StatusOK,
},
{
desc: "Spaces after the servername and host header",
hostHeader: "site3.www.snitest.com ",
serverName: "site3.www.snitest.com ",
expectedError: true,
expectedContent: "server3",
expectedStatusCode: http.StatusOK,
},
{
desc: "Domain Fronting with same tlsOptions should follow header",
hostHeader: "site1.www.snitest.com",
@@ -1083,6 +1069,34 @@ func (s *HTTPSSuite) TestWithDomainFronting() {
expectedContent: "server1",
expectedStatusCode: http.StatusOK,
},
{
desc: "Domain Fronting with ambiguous TLS options should produce a 421",
hostHeader: "site4.www.snitest.com",
serverName: "site3.www.snitest.com",
expectedContent: "",
expectedStatusCode: http.StatusMisdirectedRequest,
},
{
desc: "Domain Fronting with same non-default TLS options should not produce a 421",
hostHeader: "site5.www.snitest.com",
serverName: "site3.www.snitest.com",
expectedContent: "server5",
expectedStatusCode: http.StatusOK,
},
{
desc: "FQDN host header with empty SNI to non-default TLS options route should produce a 421",
hostHeader: "site3.www.snitest.com.",
serverName: "",
expectedContent: "",
expectedStatusCode: http.StatusMisdirectedRequest,
},
{
desc: "Non-FQDN host header with empty SNI matching FQDN route rule should produce a 421",
hostHeader: "site6.www.snitest.com",
serverName: "",
expectedContent: "",
expectedStatusCode: http.StatusMisdirectedRequest,
},
}
for _, test := range testCases {
@@ -1091,11 +1105,10 @@ func (s *HTTPSSuite) TestWithDomainFronting() {
req.Host = test.hostHeader
err = try.RequestWithTransport(req, 500*time.Millisecond, &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true, ServerName: test.serverName}}, try.StatusCodeIs(test.expectedStatusCode), try.BodyContains(test.expectedContent))
if test.expectedError {
assert.Error(s.T(), err)
} else {
require.NoError(s.T(), err)
}
assert.NoError(s.T(), err, "test %s failed with: %v", test.desc, err)
err = try.RequestWithTransport(req, 500*time.Millisecond, &http3.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true, ServerName: test.serverName}}, try.StatusCodeIs(test.expectedStatusCode), try.BodyContains(test.expectedContent))
assert.NoError(s.T(), err, "test %s failed with: %v", test.desc, err)
}
}
+1 -1
View File
@@ -667,7 +667,7 @@ func (s *SimpleSuite) TestRouterConfigErrors() {
s.traefikCmd(withConfigFile(file))
// All errors
err := try.GetRequest("http://127.0.0.1:8080/api/http/routers", 1000*time.Millisecond, try.BodyContains(`["middleware \"unknown@file\" does not exist","found different TLS options for routers on the same host snitest.net, so using the default TLS options instead"]`))
err := try.GetRequest("http://127.0.0.1:8080/api/http/routers", 1000*time.Millisecond, try.BodyContains(`["middleware \"unknown@file\" does not exist","found different TLS options for routers on the same host, so using the default TLS options instead"]`))
require.NoError(s.T(), err)
// router3 has an error because it uses an unknown entrypoint
+3 -3
View File
@@ -76,7 +76,7 @@ func Request(req *http.Request, timeout time.Duration, conditions ...ResponseCon
// the condition on the response.
// ResponseCondition may be nil, in which case only the request against the URL must
// succeed.
func RequestWithTransport(req *http.Request, timeout time.Duration, transport *http.Transport, conditions ...ResponseCondition) error {
func RequestWithTransport(req *http.Request, timeout time.Duration, transport http.RoundTripper, conditions ...ResponseCondition) error {
resp, err := doTryRequest(req, timeout, transport, conditions...)
if resp != nil && resp.Body != nil {
@@ -140,12 +140,12 @@ func doTryRequest(request *http.Request, timeout time.Duration, transport http.R
func doRequest(action timedAction, timeout time.Duration, request *http.Request, transport http.RoundTripper, conditions ...ResponseCondition) (*http.Response, error) {
var resp *http.Response
return resp, action(timeout, func() error {
var err error
client := http.DefaultClient
var client http.Client
if transport != nil {
client.Transport = transport
}
var err error
resp, err = client.Do(request)
if err != nil {
return err
+4 -3
View File
@@ -103,9 +103,10 @@ func (r *RouterDeniedEncodedPathCharacters) Map() map[string]struct{} {
// RouterTLSConfig holds the TLS configuration for a router.
type RouterTLSConfig struct {
Options string `json:"options,omitempty" toml:"options,omitempty" yaml:"options,omitempty" export:"true"`
CertResolver string `json:"certResolver,omitempty" toml:"certResolver,omitempty" yaml:"certResolver,omitempty" export:"true"`
Domains []types.Domain `json:"domains,omitempty" toml:"domains,omitempty" yaml:"domains,omitempty" export:"true"`
Options string `json:"options,omitempty" toml:"options,omitempty" yaml:"options,omitempty" export:"true"`
ResolvedOptions string `json:"-" toml:"-" yaml:"-" label:"-" file:"-" kv:"-" export:"false"`
CertResolver string `json:"certResolver,omitempty" toml:"certResolver,omitempty" yaml:"certResolver,omitempty" export:"true"`
Domains []types.Domain `json:"domains,omitempty" toml:"domains,omitempty" yaml:"domains,omitempty" export:"true"`
}
// +k8s:deepcopy-gen=true
+19 -82
View File
@@ -1,24 +1,26 @@
package snicheck
import (
"net"
"net/http"
"strings"
"github.com/traefik/traefik/v2/pkg/log"
"github.com/traefik/traefik/v2/pkg/middlewares/requestdecorator"
traefiktls "github.com/traefik/traefik/v2/pkg/tls"
"github.com/traefik/traefik/v2/pkg/tcp"
)
// SNICheck is an HTTP handler that checks whether the TLS configuration for the server name is the same as for the host header.
type SNICheck struct {
next http.Handler
tlsOptionsForHost map[string]string
next http.Handler
routerName string
tlsOptionsName string
}
// New creates a new SNICheck.
func New(tlsOptionsForHost map[string]string, next http.Handler) *SNICheck {
return &SNICheck{next: next, tlsOptionsForHost: tlsOptionsForHost}
func New(routerName, tlsOptionsName string, next http.Handler) *SNICheck {
return &SNICheck{
next: next,
routerName: routerName,
tlsOptionsName: tlsOptionsName,
}
}
func (s SNICheck) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
@@ -27,81 +29,16 @@ func (s SNICheck) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
host := getHost(req)
serverName := strings.TrimSpace(req.TLS.ServerName)
// Domain Fronting
if !strings.EqualFold(host, serverName) {
tlsOptionHeader := findTLSOptionName(s.tlsOptionsForHost, host, true)
tlsOptionSNI := findTLSOptionName(s.tlsOptionsForHost, serverName, false)
if tlsOptionHeader != tlsOptionSNI {
log.WithoutContext().
WithField("host", host).
WithField("req.Host", req.Host).
WithField("req.TLS.ServerName", req.TLS.ServerName).
Debugf("TLS options difference: SNI:%s, Header:%s", tlsOptionSNI, tlsOptionHeader)
http.Error(rw, http.StatusText(http.StatusMisdirectedRequest), http.StatusMisdirectedRequest)
return
}
tlsOptionsNameUsed := tcp.GetTLSOptionsName(req.Context())
if s.tlsOptionsName != tlsOptionsNameUsed {
log.WithoutContext().
WithField("routerName", s.routerName).
WithField("req.Host", req.Host).
WithField("req.TLS.ServerName", req.TLS.ServerName).
Debugf("TLS options difference: SNI:%s, Header:%s", tlsOptionsNameUsed, s.tlsOptionsName)
http.Error(rw, http.StatusText(http.StatusMisdirectedRequest), http.StatusMisdirectedRequest)
return
}
s.next.ServeHTTP(rw, req)
}
func getHost(req *http.Request) string {
h := requestdecorator.GetCNAMEFlatten(req.Context())
if h != "" {
return h
}
h = requestdecorator.GetCanonizedHost(req.Context())
if h != "" {
return h
}
host, _, err := net.SplitHostPort(req.Host)
if err != nil {
host = req.Host
}
return strings.TrimSpace(host)
}
func findTLSOptionName(tlsOptionsForHost map[string]string, host string, fqdn bool) string {
name := findTLSOptName(tlsOptionsForHost, host, fqdn)
if name != "" {
return name
}
name = findTLSOptName(tlsOptionsForHost, strings.ToLower(host), fqdn)
if name != "" {
return name
}
return traefiktls.DefaultTLSConfigName
}
func findTLSOptName(tlsOptionsForHost map[string]string, host string, fqdn bool) string {
if tlsOptions, ok := tlsOptionsForHost[host]; ok {
return tlsOptions
}
if !fqdn {
return ""
}
if last := len(host) - 1; last >= 0 && host[last] == '.' {
if tlsOptions, ok := tlsOptionsForHost[host[:last]]; ok {
return tlsOptions
}
return ""
}
if tlsOptions, ok := tlsOptionsForHost[host+"."]; ok {
return tlsOptions
}
return ""
}
-59
View File
@@ -1,59 +0,0 @@
package snicheck
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestSNICheck_ServeHTTP(t *testing.T) {
testCases := []struct {
desc string
tlsOptionsForHost map[string]string
host string
expected int
}{
{
desc: "no TLS options",
expected: http.StatusOK,
},
{
desc: "with TLS options",
tlsOptionsForHost: map[string]string{
"example.com": "foo",
},
expected: http.StatusOK,
},
{
desc: "server name and host doesn't have the same TLS configuration",
tlsOptionsForHost: map[string]string{
"example.com": "foo",
},
host: "example.com",
expected: http.StatusMisdirectedRequest,
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})
sniCheck := New(test.tlsOptionsForHost, next)
req := httptest.NewRequest(http.MethodGet, "https://localhost", nil)
if test.host != "" {
req.Host = test.host
}
recorder := httptest.NewRecorder()
sniCheck.ServeHTTP(recorder, req)
assert.Equal(t, test.expected, recorder.Code)
})
}
}
+3 -3
View File
@@ -61,10 +61,10 @@ type ConnData struct {
}
// NewConnData builds a connData struct from the given parameters.
func NewConnData(serverName string, conn tcp.WriteCloser, alpnProtos []string) (ConnData, error) {
remoteIP, _, err := net.SplitHostPort(conn.RemoteAddr().String())
func NewConnData(serverName string, remoteAddr net.Addr, alpnProtos []string) (ConnData, error) {
remoteIP, _, err := net.SplitHostPort(remoteAddr.String())
if err != nil {
return ConnData{}, fmt.Errorf("error while parsing remote address %q: %w", conn.RemoteAddr().String(), err)
return ConnData{}, fmt.Errorf("parsing remote address %q: %w", remoteAddr.String(), err)
}
// as per https://datatracker.ietf.org/doc/html/rfc6066:
+1 -1
View File
@@ -532,7 +532,7 @@ func Test_addTCPRoute(t *testing.T) {
remoteAddr: fakeAddr{addr: addr},
}
connData, err := NewConnData(test.serverName, conn, test.protos)
connData, err := NewConnData(test.serverName, conn.RemoteAddr(), test.protos)
require.NoError(t, err)
matchingHandler, _ := router.Match(connData)
+88 -8
View File
@@ -1,13 +1,16 @@
package server
import (
"context"
"fmt"
"slices"
"github.com/go-acme/lego/v4/challenge/tlsalpn01"
"github.com/traefik/traefik/v2/pkg/config/dynamic"
"github.com/traefik/traefik/v2/pkg/log"
httpmuxer "github.com/traefik/traefik/v2/pkg/muxer/http"
"github.com/traefik/traefik/v2/pkg/server/provider"
"github.com/traefik/traefik/v2/pkg/tls"
traefiktls "github.com/traefik/traefik/v2/pkg/tls"
)
func mergeConfiguration(configurations dynamic.Configurations, defaultEntryPoints []string) dynamic.Configuration {
@@ -31,8 +34,8 @@ func mergeConfiguration(configurations dynamic.Configurations, defaultEntryPoint
Services: make(map[string]*dynamic.UDPService),
},
TLS: &dynamic.TLSConfiguration{
Stores: make(map[string]tls.Store),
Options: make(map[string]tls.Options),
Stores: make(map[string]traefiktls.Store),
Options: make(map[string]traefiktls.Options),
},
}
@@ -101,7 +104,7 @@ func mergeConfiguration(configurations dynamic.Configurations, defaultEntryPoint
}
for key, store := range configuration.TLS.Stores {
if key != tls.DefaultTLSStoreName {
if key != traefiktls.DefaultTLSStoreName {
key = provider.MakeQualifiedName(pvd, key)
} else {
defaultTLSStoreProviders = append(defaultTLSStoreProviders, pvd)
@@ -123,19 +126,96 @@ func mergeConfiguration(configurations dynamic.Configurations, defaultEntryPoint
if len(defaultTLSStoreProviders) > 1 {
log.WithoutContext().Errorf("Default TLS Store defined in multiple providers: %v", defaultTLSStoreProviders)
delete(conf.TLS.Stores, tls.DefaultTLSStoreName)
delete(conf.TLS.Stores, traefiktls.DefaultTLSStoreName)
}
if len(defaultTLSOptionProviders) == 0 {
conf.TLS.Options[tls.DefaultTLSConfigName] = tls.DefaultTLSOptions
conf.TLS.Options[traefiktls.DefaultTLSConfigName] = traefiktls.DefaultTLSOptions
} else if len(defaultTLSOptionProviders) > 1 {
log.WithoutContext().Errorf("Default TLS Options defined in multiple providers %v", defaultTLSOptionProviders)
// We do not set an empty tls.TLS{} as above so that we actually get a "cascading failure" later on,
// i.e. routers depending on this missing TLS option will fail to initialize as well.
delete(conf.TLS.Options, tls.DefaultTLSConfigName)
delete(conf.TLS.Options, traefiktls.DefaultTLSConfigName)
}
return conf
return resolveHTTPTLSOptions(conf)
}
func resolveHTTPTLSOptions(cfg dynamic.Configuration) dynamic.Configuration {
if cfg.HTTP == nil || len(cfg.HTTP.Routers) == 0 {
return cfg
}
rts := make(map[string]*dynamic.Router)
// Keyed by domain, then by options reference.
// The actual source of truth for what TLS options will actually be used for the connection.
// As opposed to tlsOptionsForHost, it keeps track of all the (different) TLS
// options that occur for a given host name, so that later on we can set relevant
// errors and logging for all the routers concerned (i.e. wrongly configured).
tlsOptionsForHostSNI := map[string]map[string][]string{}
for routerHTTPName, routerHTTPConfig := range cfg.HTTP.Routers {
rts[routerHTTPName] = routerHTTPConfig.DeepCopy()
if routerHTTPConfig.TLS == nil {
continue
}
ctxRouter := log.With(provider.AddInContext(context.Background(), routerHTTPName), log.Str(log.RouterName, routerHTTPName))
logger := log.FromContext(ctxRouter)
tlsOptionsName := traefiktls.DefaultTLSConfigName
if len(routerHTTPConfig.TLS.Options) > 0 && routerHTTPConfig.TLS.Options != traefiktls.DefaultTLSConfigName {
tlsOptionsName = provider.GetQualifiedName(ctxRouter, routerHTTPConfig.TLS.Options)
}
domains, err := httpmuxer.ParseDomains(routerHTTPConfig.Rule)
if err != nil {
routerErr := fmt.Errorf("invalid rule %s, error: %w", routerHTTPConfig.Rule, err)
logger.Error(routerErr)
continue
}
if len(domains) == 0 {
rts[routerHTTPName].TLS.ResolvedOptions = "default"
logger.Warnf("No domain found in rule %v, the TLS options applied for this router will depend on the SNI of each request", routerHTTPConfig.Rule)
}
for _, domain := range domains {
// domain is already in lower case thanks to the domain parsing
if tlsOptionsForHostSNI[domain] == nil {
tlsOptionsForHostSNI[domain] = make(map[string][]string)
}
tlsOptionsForHostSNI[domain][tlsOptionsName] = append(tlsOptionsForHostSNI[domain][tlsOptionsName], routerHTTPName)
}
}
for hostSNI, tlsConfigs := range tlsOptionsForHostSNI {
if len(tlsConfigs) == 1 {
for optionsName, v := range tlsConfigs {
log.WithoutContext().Debugf("Adding route for %s with TLS options %s", hostSNI, optionsName)
for _, s := range v {
rts[s].TLS.ResolvedOptions = optionsName
}
}
continue
}
// multiple tlsConfigs
routers := make([]string, 0, len(tlsConfigs))
for _, v := range tlsConfigs {
for _, s := range v {
rts[s].TLS.ResolvedOptions = traefiktls.DefaultTLSConfigName
routers = append(routers, s)
}
}
log.WithoutContext().Warnf("Found different TLS options for routers on the same host %v, so using the default TLS options instead for these routers: %#v", hostSNI, routers)
}
cfg.HTTP.Routers = rts
return cfg
}
func applyModel(cfg dynamic.Configuration) dynamic.Configuration {
+7
View File
@@ -16,6 +16,7 @@ import (
"github.com/traefik/traefik/v2/pkg/middlewares/denyrouterrecursion"
metricsMiddle "github.com/traefik/traefik/v2/pkg/middlewares/metrics"
"github.com/traefik/traefik/v2/pkg/middlewares/recovery"
"github.com/traefik/traefik/v2/pkg/middlewares/snicheck"
"github.com/traefik/traefik/v2/pkg/middlewares/tracing"
httpmuxer "github.com/traefik/traefik/v2/pkg/muxer/http"
"github.com/traefik/traefik/v2/pkg/server/middleware"
@@ -229,6 +230,12 @@ func (m *Manager) buildHTTPHandler(ctx context.Context, router *runtime.RouterIn
})
}
if router.TLS != nil {
chain = chain.Append(func(next http.Handler) (http.Handler, error) {
return snicheck.New(routerName, router.TLS.ResolvedOptions, next), nil
})
}
return chain.Extend(*mHandler).Append(tHandler).Then(sHandler)
}
+24 -83
View File
@@ -2,7 +2,6 @@ package tcp
import (
"context"
"crypto/tls"
"errors"
"fmt"
"math"
@@ -11,7 +10,6 @@ import (
"github.com/traefik/traefik/v2/pkg/config/runtime"
"github.com/traefik/traefik/v2/pkg/log"
"github.com/traefik/traefik/v2/pkg/middlewares/snicheck"
httpmuxer "github.com/traefik/traefik/v2/pkg/muxer/http"
tcpmuxer "github.com/traefik/traefik/v2/pkg/muxer/tcp"
"github.com/traefik/traefik/v2/pkg/server/provider"
@@ -91,11 +89,6 @@ func (m *Manager) getHTTPRouters(ctx context.Context, entryPoints []string, tls
return make(map[string]map[string]*runtime.RouterInfo)
}
type nameAndConfig struct {
routerName string // just so we have it as additional information when logging
TLSConfig *tls.Config
}
func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string]*runtime.TCPRouterInfo, configsHTTP map[string]*runtime.RouterInfo, handlerHTTP, handlerHTTPS http.Handler) (*Router, error) {
// Build a new Router.
router, err := NewRouter()
@@ -113,18 +106,6 @@ func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string
log.FromContext(ctx).Errorf("Error during the build of the default TLS configuration: %v", err)
}
// Keyed by domain. The source of truth for doing SNI checking (domain fronting).
// As soon as there's (at least) two different tlsOptions found for the same domain,
// we set the value to the default TLS conf.
tlsOptionsForHost := map[string]string{}
// Keyed by domain, then by options reference.
// The actual source of truth for what TLS options will actually be used for the connection.
// As opposed to tlsOptionsForHost, it keeps track of all the (different) TLS
// options that occur for a given host name, so that later on we can set relevant
// errors and logging for all the routers concerned (i.e. wrongly configured).
tlsOptionsForHostSNI := map[string]map[string]nameAndConfig{}
for routerHTTPName, routerHTTPConfig := range configsHTTP {
if routerHTTPConfig.TLS == nil {
continue
@@ -133,11 +114,6 @@ func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string
ctxRouter := log.With(provider.AddInContext(ctx, routerHTTPName), log.Str(log.RouterName, routerHTTPName))
logger := log.FromContext(ctxRouter)
tlsOptionsName := traefiktls.DefaultTLSConfigName
if len(routerHTTPConfig.TLS.Options) > 0 && routerHTTPConfig.TLS.Options != traefiktls.DefaultTLSConfigName {
tlsOptionsName = provider.GetQualifiedName(ctxRouter, routerHTTPConfig.TLS.Options)
}
domains, err := httpmuxer.ParseDomains(routerHTTPConfig.Rule)
if err != nil {
routerErr := fmt.Errorf("invalid rule %s, error: %w", routerHTTPConfig.Rule, err)
@@ -152,7 +128,7 @@ func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string
// This is only about choosing the TLS configuration.
// The actual routing will be done further on by the HTTPS handler.
// See examples below.
router.AddHTTPTLSConfig("*", defaultTLSConf)
router.AddHTTPTLSConfig("*", defaultTLSConf, traefiktls.DefaultTLSConfigName)
// The server name (from a Host(SNI) rule) is the only parameter (available in HTTP routing rules) on which we can map a TLS config,
// because it is the only one accessible before decryption (we obtain it during the ClientHello).
@@ -180,79 +156,43 @@ func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string
logger.Warnf("No domain found in rule %v, the TLS options applied for this router will depend on the SNI of each request", routerHTTPConfig.Rule)
}
// Even if the TLS options mismatch between the configured and the resolved one is handled in the aggregator
// we also have to handle it here to be able to mark the router in error.
tlsOptionsName := traefiktls.DefaultTLSConfigName
if len(routerHTTPConfig.TLS.Options) > 0 && routerHTTPConfig.TLS.Options != traefiktls.DefaultTLSConfigName {
tlsOptionsName = provider.GetQualifiedName(ctxRouter, routerHTTPConfig.TLS.Options)
}
if routerHTTPConfig.TLS.ResolvedOptions != tlsOptionsName {
routerHTTPConfig.AddError(errors.New("found different TLS options for routers on the same host, so using the default TLS options instead"), false)
}
// Even though the error is seemingly ignored (aside from logging it),
// we actually rely later on the fact that a tls config is nil (which happens when an error is returned) to take special steps
// when assigning a handler to a route.
tlsConf, tlsConfErr := m.tlsManager.Get(traefiktls.DefaultTLSStoreName, tlsOptionsName)
tlsConf, tlsConfErr := m.tlsManager.Get(traefiktls.DefaultTLSStoreName, routerHTTPConfig.TLS.ResolvedOptions)
if tlsConfErr != nil {
// Note: we do not call AddError here because we already did so when buildRouterHandler errored for the same reason.
logger.Error(tlsConfErr)
}
for _, domain := range domains {
// domain is already in lower case thanks to the domain parsing
if tlsOptionsForHostSNI[domain] == nil {
tlsOptionsForHostSNI[domain] = make(map[string]nameAndConfig)
}
tlsOptionsForHostSNI[domain][tlsOptionsName] = nameAndConfig{
routerName: routerHTTPName,
TLSConfig: tlsConf,
}
if name, ok := tlsOptionsForHost[domain]; ok && name != tlsOptionsName {
// Different tlsOptions on the same domain, so fallback to default
tlsOptionsForHost[domain] = traefiktls.DefaultTLSConfigName
} else {
tlsOptionsForHost[domain] = tlsOptionsName
}
}
}
sniCheck := snicheck.New(tlsOptionsForHost, handlerHTTPS)
// Keep in mind that defaultTLSConf might be nil here.
router.SetHTTPSHandler(sniCheck, defaultTLSConf)
logger := log.FromContext(ctx)
for hostSNI, tlsConfigs := range tlsOptionsForHostSNI {
if len(tlsConfigs) == 1 {
var optionsName string
var config *tls.Config
for k, v := range tlsConfigs {
optionsName = k
config = v.TLSConfig
break
}
if config == nil {
if tlsConf == nil {
// we use nil config as a signal to insert a handler
// that enforces that TLS connection attempts to the corresponding (broken) router should fail.
logger.Debugf("Adding special closing route for %s because broken TLS options %s", hostSNI, optionsName)
router.AddHTTPTLSConfig(hostSNI, nil)
logger.Debugf("Adding special closing route for %s because of a broken TLS options %s", domain, routerHTTPConfig.TLS.ResolvedOptions)
router.AddHTTPTLSConfig(domain, nil, "")
continue
}
logger.Debugf("Adding route for %s with TLS options %s", hostSNI, optionsName)
router.AddHTTPTLSConfig(hostSNI, config)
continue
logger.Debugf("Adding route for %s with TLS options %s", domain, routerHTTPConfig.TLS.ResolvedOptions)
router.AddHTTPTLSConfig(domain, tlsConf, routerHTTPConfig.TLS.ResolvedOptions)
}
// multiple tlsConfigs
routers := make([]string, 0, len(tlsConfigs))
for _, v := range tlsConfigs {
configsHTTP[v.routerName].AddError(fmt.Errorf("found different TLS options for routers on the same host %v, so using the default TLS options instead", hostSNI), false)
routers = append(routers, v.routerName)
}
logger.Warnf("Found different TLS options for routers on the same host %v, so using the default TLS options instead for these routers: %#v", hostSNI, routers)
if defaultTLSConf == nil {
logger.Debugf("Adding special closing route for %s because broken default TLS options", hostSNI)
}
router.AddHTTPTLSConfig(hostSNI, defaultTLSConf)
}
// Keep in mind that defaultTLSConf might be nil here.
router.SetHTTPSHandler(handlerHTTPS, defaultTLSConf)
m.addTCPHandlers(ctx, configs, router)
return router, nil
@@ -385,8 +325,9 @@ func (m *Manager) addTCPHandlers(ctx context.Context, configs map[string]*runtim
}
handler = &tcp.TLSHandler{
Next: handler,
Config: tlsConf,
Next: handler,
Config: tlsConf,
TLSOptionsName: tlsOptionsName,
}
logger.Debugf("Adding TLS route for %q", routerConfig.Rule)
+4 -296
View File
@@ -1,14 +1,10 @@
package tcp
import (
"crypto/tls"
"math"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/traefik/traefik/v2/pkg/config/dynamic"
"github.com/traefik/traefik/v2/pkg/config/runtime"
tcpmiddleware "github.com/traefik/traefik/v2/pkg/server/middleware/tcp"
@@ -129,7 +125,8 @@ func TestRuntimeConfiguration(t *testing.T) {
Service: "foo-service",
Rule: "Host(`bar.foo`)",
TLS: &dynamic.RouterTLSConfig{
Options: "foo",
Options: "foo",
ResolvedOptions: "default",
},
},
},
@@ -139,7 +136,8 @@ func TestRuntimeConfiguration(t *testing.T) {
Service: "foo-service",
Rule: "Host(`bar.foo`) && PathPrefix(`/path`)",
TLS: &dynamic.RouterTLSConfig{
Options: "bar",
Options: "bar",
ResolvedOptions: "default",
},
},
},
@@ -396,293 +394,3 @@ func TestRuntimeConfiguration(t *testing.T) {
})
}
}
func TestDomainFronting(t *testing.T) {
tlsOptionsBase := map[string]traefiktls.Options{
"default": {
MinVersion: "VersionTLS10",
},
"host1@file": {
MinVersion: "VersionTLS12",
},
"host1@crd": {
MinVersion: "VersionTLS12",
},
}
entryPoints := []string{"web"}
tests := []struct {
desc string
routers map[string]*runtime.RouterInfo
tlsOptions map[string]traefiktls.Options
host string
ServerName string
expectedStatus int
}{
{
desc: "Request is misdirected when TLS options are different",
routers: map[string]*runtime.RouterInfo{
"router-1@file": {
Router: &dynamic.Router{
EntryPoints: entryPoints,
Rule: "Host(`host1.local`)",
TLS: &dynamic.RouterTLSConfig{
Options: "host1",
},
},
},
"router-2@file": {
Router: &dynamic.Router{
EntryPoints: entryPoints,
Rule: "Host(`host2.local`)",
TLS: &dynamic.RouterTLSConfig{},
},
},
},
tlsOptions: tlsOptionsBase,
host: "host1.local",
ServerName: "host2.local",
expectedStatus: http.StatusMisdirectedRequest,
},
{
desc: "Request is OK when TLS options are the same",
routers: map[string]*runtime.RouterInfo{
"router-1@file": {
Router: &dynamic.Router{
EntryPoints: entryPoints,
Rule: "Host(`host1.local`)",
TLS: &dynamic.RouterTLSConfig{
Options: "host1",
},
},
},
"router-2@file": {
Router: &dynamic.Router{
EntryPoints: entryPoints,
Rule: "Host(`host2.local`)",
TLS: &dynamic.RouterTLSConfig{
Options: "host1",
},
},
},
},
tlsOptions: tlsOptionsBase,
host: "host1.local",
ServerName: "host2.local",
expectedStatus: http.StatusOK,
},
{
desc: "Default TLS options is used when options are ambiguous for the same host",
routers: map[string]*runtime.RouterInfo{
"router-1@file": {
Router: &dynamic.Router{
EntryPoints: entryPoints,
Rule: "Host(`host1.local`)",
TLS: &dynamic.RouterTLSConfig{
Options: "host1",
},
},
},
"router-2@file": {
Router: &dynamic.Router{
EntryPoints: entryPoints,
Rule: "Host(`host1.local`) && PathPrefix(`/foo`)",
TLS: &dynamic.RouterTLSConfig{
Options: "default",
},
},
},
"router-3@file": {
Router: &dynamic.Router{
EntryPoints: entryPoints,
Rule: "Host(`host2.local`)",
TLS: &dynamic.RouterTLSConfig{
Options: "host1",
},
},
},
},
tlsOptions: tlsOptionsBase,
host: "host1.local",
ServerName: "host2.local",
expectedStatus: http.StatusMisdirectedRequest,
},
{
desc: "Default TLS options should not be used when options are the same for the same host",
routers: map[string]*runtime.RouterInfo{
"router-1@file": {
Router: &dynamic.Router{
EntryPoints: entryPoints,
Rule: "Host(`host1.local`)",
TLS: &dynamic.RouterTLSConfig{
Options: "host1",
},
},
},
"router-2@file": {
Router: &dynamic.Router{
EntryPoints: entryPoints,
Rule: "Host(`host1.local`) && PathPrefix(`/bar`)",
TLS: &dynamic.RouterTLSConfig{
Options: "host1",
},
},
},
"router-3@file": {
Router: &dynamic.Router{
EntryPoints: entryPoints,
Rule: "Host(`host2.local`)",
TLS: &dynamic.RouterTLSConfig{
Options: "host1",
},
},
},
},
tlsOptions: tlsOptionsBase,
host: "host1.local",
ServerName: "host2.local",
expectedStatus: http.StatusOK,
},
{
desc: "Request is misdirected when TLS options have the same name but from different providers",
routers: map[string]*runtime.RouterInfo{
"router-1@file": {
Router: &dynamic.Router{
EntryPoints: entryPoints,
Rule: "Host(`host1.local`)",
TLS: &dynamic.RouterTLSConfig{
Options: "host1",
},
},
},
"router-2@crd": {
Router: &dynamic.Router{
EntryPoints: entryPoints,
Rule: "Host(`host2.local`)",
TLS: &dynamic.RouterTLSConfig{
Options: "host1",
},
},
},
},
tlsOptions: tlsOptionsBase,
host: "host1.local",
ServerName: "host2.local",
expectedStatus: http.StatusMisdirectedRequest,
},
{
desc: "Request is OK when TLS options reference from a different provider is the same",
routers: map[string]*runtime.RouterInfo{
"router-1@file": {
Router: &dynamic.Router{
EntryPoints: entryPoints,
Rule: "Host(`host1.local`)",
TLS: &dynamic.RouterTLSConfig{
Options: "host1@crd",
},
},
},
"router-2@crd": {
Router: &dynamic.Router{
EntryPoints: entryPoints,
Rule: "Host(`host2.local`)",
TLS: &dynamic.RouterTLSConfig{
Options: "host1@crd",
},
},
},
},
tlsOptions: tlsOptionsBase,
host: "host1.local",
ServerName: "host2.local",
expectedStatus: http.StatusOK,
},
{
desc: "Request is misdirected when server name is empty and the host name is an FQDN, but router's rule is not",
routers: map[string]*runtime.RouterInfo{
"router-1@file": {
Router: &dynamic.Router{
EntryPoints: entryPoints,
Rule: "Host(`host1.local`)",
TLS: &dynamic.RouterTLSConfig{
Options: "host1@file",
},
},
},
},
tlsOptions: map[string]traefiktls.Options{
"default": {
MinVersion: "VersionTLS13",
},
"host1@file": {
MinVersion: "VersionTLS12",
},
},
host: "host1.local.",
expectedStatus: http.StatusMisdirectedRequest,
},
{
desc: "Request is misdirected when server name is empty and the host name is not FQDN, but router's rule is",
routers: map[string]*runtime.RouterInfo{
"router-1@file": {
Router: &dynamic.Router{
EntryPoints: entryPoints,
Rule: "Host(`host1.local.`)",
TLS: &dynamic.RouterTLSConfig{
Options: "host1@file",
},
},
},
},
tlsOptions: map[string]traefiktls.Options{
"default": {
MinVersion: "VersionTLS13",
},
"host1@file": {
MinVersion: "VersionTLS12",
},
},
host: "host1.local",
expectedStatus: http.StatusMisdirectedRequest,
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
conf := &runtime.Configuration{
Routers: test.routers,
}
serviceManager := tcp.NewManager(conf)
tlsManager := traefiktls.NewManager()
tlsManager.UpdateConfigs(t.Context(), map[string]traefiktls.Store{}, test.tlsOptions, []*traefiktls.CertAndStores{})
httpsHandler := map[string]http.Handler{
"web": http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {}),
}
middlewaresBuilder := tcpmiddleware.NewBuilder(conf.TCPMiddlewares)
routerManager := NewManager(conf, serviceManager, middlewaresBuilder, nil, httpsHandler, tlsManager)
routers := routerManager.BuildHandlers(t.Context(), entryPoints)
router, ok := routers["web"]
require.True(t, ok)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Host = test.host
req.TLS = &tls.ConnectionState{
ServerName: test.ServerName,
}
rw := httptest.NewRecorder()
router.GetHTTPSHandler().ServeHTTP(rw, req)
assert.Equal(t, test.expectedStatus, rw.Code)
})
}
}
+34 -17
View File
@@ -17,11 +17,17 @@ import (
"github.com/traefik/traefik/v2/pkg/log"
tcpmuxer "github.com/traefik/traefik/v2/pkg/muxer/tcp"
"github.com/traefik/traefik/v2/pkg/tcp"
traefiktls "github.com/traefik/traefik/v2/pkg/tls"
)
// errClientHelloRead is used as a sentinel error to break the TLS handshake once we have read the ClientHello.
var errClientHelloRead = errors.New("client hello successfully read")
type tlsConfigWithOptionsName struct {
cfg *tls.Config
optionsName string
}
// Router is a TCP router.
type Router struct {
acmeTLSPassthrough bool
@@ -48,7 +54,7 @@ type Router struct {
httpsTLSConfig *tls.Config // default TLS config
// hostHTTPTLSConfig contains TLS configs keyed by SNI.
// A nil config is the hint to set up a brokenTLSRouter.
hostHTTPTLSConfig map[string]*tls.Config // TLS configs keyed by SNI
hostHTTPTLSConfig map[string]tlsConfigWithOptionsName // TLS configs keyed by SNI
}
// NewRouter returns a new TCP router.
@@ -75,14 +81,20 @@ func NewRouter() (*Router, error) {
}, nil
}
// GetTLSGetClientInfo is called after a ClientHello is received from a client.
func (r *Router) GetTLSGetClientInfo() func(info *tls.ClientHelloInfo) (*tls.Config, error) {
return func(info *tls.ClientHelloInfo) (*tls.Config, error) {
if tlsConfig, ok := r.hostHTTPTLSConfig[info.ServerName]; ok {
return tlsConfig, nil
// HTTP3TLSConfigMatcherFunc returns a matcher func for HTTP/3 which returns a tls.Config with its corresponding
// TLSOptionName matching the given HostSNI in the connection data, or the default TLS config if there is no match.
func (r *Router) HTTP3TLSConfigMatcherFunc() func(connData tcpmuxer.ConnData) (*tls.Config, string, error) {
return func(connData tcpmuxer.ConnData) (*tls.Config, string, error) {
h, _ := r.muxerHTTPS.Match(connData)
if h == nil {
return r.httpsTLSConfig, traefiktls.DefaultTLSConfigName, nil
}
return r.httpsTLSConfig, nil
if tlsHandler, ok := h.(*tcp.TLSHandler); ok {
return tlsHandler.Config, tlsHandler.TLSOptionsName, nil
}
return nil, "", errors.New("matching handler is not a TLSHandler")
}
}
@@ -94,7 +106,7 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) {
// we would block forever on clientHelloInfo,
// which is why we want to detect and handle that case first and foremost.
if r.muxerTCP.HasRoutes() && !r.muxerTCPTLS.HasRoutes() && !r.muxerHTTPS.HasRoutes() {
connData, err := tcpmuxer.NewConnData("", conn, nil)
connData, err := tcpmuxer.NewConnData("", conn.RemoteAddr(), nil)
if err != nil {
log.WithoutContext().Errorf("Error while reading TCP connection data: %v", err)
conn.Close()
@@ -136,7 +148,7 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) {
log.WithoutContext().Errorf("Error while setting deadline: %v", err)
}
connData, err := tcpmuxer.NewConnData(hello.serverName, conn, hello.protos)
connData, err := tcpmuxer.NewConnData(hello.serverName, conn.RemoteAddr(), hello.protos)
if err != nil {
log.WithoutContext().Errorf("Error while reading TCP connection data: %v", err)
conn.Close()
@@ -212,12 +224,15 @@ func (r *Router) AddRoute(rule string, priority int, target tcp.Handler) error {
}
// AddHTTPTLSConfig defines a handler for a given sniHost and sets the matching tlsConfig.
func (r *Router) AddHTTPTLSConfig(sniHost string, config *tls.Config) {
func (r *Router) AddHTTPTLSConfig(sniHost string, config *tls.Config, optionsName string) {
if r.hostHTTPTLSConfig == nil {
r.hostHTTPTLSConfig = map[string]*tls.Config{}
r.hostHTTPTLSConfig = map[string]tlsConfigWithOptionsName{}
}
r.hostHTTPTLSConfig[sniHost] = config
r.hostHTTPTLSConfig[sniHost] = tlsConfigWithOptionsName{
cfg: config,
optionsName: optionsName,
}
}
// GetConn creates a connection proxy with a peeked string.
@@ -262,12 +277,13 @@ func (t *brokenTLSRouter) ServeTCP(conn tcp.WriteCloser) {
func (r *Router) SetHTTPSForwarder(handler tcp.Handler) {
for sniHost, tlsConf := range r.hostHTTPTLSConfig {
var tcpHandler tcp.Handler
if tlsConf == nil {
if tlsConf.cfg == nil {
tcpHandler = &brokenTLSRouter{}
} else {
tcpHandler = &tcp.TLSHandler{
Next: handler,
Config: tlsConf,
Next: handler,
Config: tlsConf.cfg,
TLSOptionsName: tlsConf.optionsName,
}
}
@@ -285,8 +301,9 @@ func (r *Router) SetHTTPSForwarder(handler tcp.Handler) {
}
r.httpsForwarder = &tcp.TLSHandler{
Next: handler,
Config: r.httpsTLSConfig,
Next: handler,
Config: r.httpsTLSConfig,
TLSOptionsName: "default",
}
}
+17 -4
View File
@@ -2,6 +2,7 @@ package tcp
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
@@ -20,7 +21,7 @@ import (
"github.com/traefik/traefik/v2/pkg/config/runtime"
tcpmiddleware "github.com/traefik/traefik/v2/pkg/server/middleware/tcp"
"github.com/traefik/traefik/v2/pkg/server/service/tcp"
tcp2 "github.com/traefik/traefik/v2/pkg/tcp"
traefiktcp "github.com/traefik/traefik/v2/pkg/tcp"
traefiktls "github.com/traefik/traefik/v2/pkg/tls"
"github.com/traefik/traefik/v2/pkg/tls/generate"
)
@@ -52,7 +53,7 @@ func (h *httpForwarder) Close() error {
}
// ServeTCP uses the connection to serve it later in "Accept".
func (h *httpForwarder) ServeTCP(conn tcp2.WriteCloser) {
func (h *httpForwarder) ServeTCP(conn traefiktcp.WriteCloser) {
h.connChan <- conn
}
@@ -621,6 +622,16 @@ func Test_Routing(t *testing.T) {
_, err = fmt.Fprint(w, "HTTPS")
require.NoError(t, err)
}),
ConnContext: func(ctx context.Context, c net.Conn) context.Context {
if tlsConn, ok := c.(*tls.Conn); ok {
if tlsConnWithOptionsName, ok := tlsConn.NetConn().(traefiktcp.TLSConn); ok {
return traefiktcp.AddTLSOptionsNameInContext(ctx, tlsConnWithOptionsName.TLSOptionsName)
}
}
return ctx
},
}
stoppedHTTPS := make(chan struct{})
@@ -812,7 +823,8 @@ func routerHTTPSPathPrefix(conf *runtime.Configuration) {
Service: "http",
Rule: "PathPrefix(`/`)",
TLS: &dynamic.RouterTLSConfig{
Options: "tls10",
Options: "tls10",
ResolvedOptions: "tls10",
},
},
}
@@ -826,7 +838,8 @@ func routerHTTPS(conf *runtime.Configuration) {
Service: "http",
Rule: "Host(`foo.bar`)",
TLS: &dynamic.RouterTLSConfig{
Options: "tls12",
Options: "tls12",
ResolvedOptions: "tls12",
},
},
}
+10
View File
@@ -2,6 +2,7 @@ package server
import (
"context"
"crypto/tls"
"errors"
"expvar"
"fmt"
@@ -610,6 +611,15 @@ func createHTTPServer(ctx context.Context, ln net.Listener, configuration *stati
HTTP2: &http.HTTP2Config{
MaxConcurrentStreams: int(configuration.HTTP2.MaxConcurrentStreams),
},
ConnContext: func(ctx context.Context, c net.Conn) context.Context {
if tlsConn, ok := c.(*tls.Conn); ok {
if tlsConnWithOptionsName, ok := tlsConn.NetConn().(tcp.TLSConn); ok {
return tcp.AddTLSOptionsNameInContext(ctx, tlsConnWithOptionsName.TLSOptionsName)
}
}
return ctx
},
}
if debugConnection || (configuration.Transport != nil && (configuration.Transport.KeepAliveMaxTime > 0 || configuration.Transport.KeepAliveMaxRequests > 0)) {
serverHTTP.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
+36 -7
View File
@@ -13,7 +13,9 @@ import (
"github.com/quic-go/quic-go/http3"
"github.com/traefik/traefik/v2/pkg/config/static"
"github.com/traefik/traefik/v2/pkg/log"
tcpmuxer "github.com/traefik/traefik/v2/pkg/muxer/tcp"
tcprouter "github.com/traefik/traefik/v2/pkg/server/router/tcp"
"github.com/traefik/traefik/v2/pkg/tcp"
)
type http3server struct {
@@ -22,7 +24,7 @@ type http3server struct {
http3conn net.PacketConn
lock sync.RWMutex
getter func(info *tls.ClientHelloInfo) (*tls.Config, error)
getter func(data tcpmuxer.ConnData) (*tls.Config, string, error)
}
func newHTTP3Server(ctx context.Context, configuration *static.EntryPoint, httpsServer *httpServer) (*http3server, error) {
@@ -41,8 +43,8 @@ func newHTTP3Server(ctx context.Context, configuration *static.EntryPoint, https
h3 := &http3server{
http3conn: conn,
getter: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
return nil, errors.New("no tls config")
getter: func(data tcpmuxer.ConnData) (*tls.Config, string, error) {
return nil, "", errors.New("no TLS config")
},
}
@@ -50,10 +52,18 @@ func newHTTP3Server(ctx context.Context, configuration *static.EntryPoint, https
Addr: configuration.GetAddress(),
Port: configuration.HTTP3.AdvertisedPort,
Handler: httpsServer.Server.(*http.Server).Handler,
TLSConfig: &tls.Config{GetConfigForClient: h3.getGetConfigForClient},
TLSConfig: &tls.Config{GetConfigForClient: h3.getTLSConfigForClient},
QUICConfig: &quic.Config{
Allow0RTT: false,
},
ConnContext: func(ctx context.Context, c *quic.Conn) context.Context {
tlsOptionsName, err := h3.getTLSOptionsName(c)
if err != nil {
log.WithoutContext().Errorf("Error getting TLS options name for client: %v", err)
return ctx
}
return tcp.AddTLSOptionsNameInContext(ctx, tlsOptionsName)
},
}
previousHandler := httpsServer.Server.(*http.Server).Handler
@@ -77,7 +87,7 @@ func (e *http3server) Switch(rt *tcprouter.Router) {
e.lock.Lock()
defer e.lock.Unlock()
e.getter = rt.GetTLSGetClientInfo()
e.getter = rt.HTTP3TLSConfigMatcherFunc()
}
func (e *http3server) Shutdown(_ context.Context) error {
@@ -85,9 +95,28 @@ func (e *http3server) Shutdown(_ context.Context) error {
return e.Server.Close()
}
func (e *http3server) getGetConfigForClient(info *tls.ClientHelloInfo) (*tls.Config, error) {
func (e *http3server) getTLSConfigForClient(info *tls.ClientHelloInfo) (*tls.Config, error) {
e.lock.RLock()
defer e.lock.RUnlock()
return e.getter(info)
connData, err := tcpmuxer.NewConnData(info.ServerName, info.Conn.RemoteAddr(), info.SupportedProtos)
if err != nil {
return nil, fmt.Errorf("creating ConnData from client hello: %w", err)
}
conf, _, err := e.getter(connData)
return conf, err
}
func (e *http3server) getTLSOptionsName(c *quic.Conn) (string, error) {
e.lock.RLock()
defer e.lock.RUnlock()
connData, err := tcpmuxer.NewConnData(c.ConnectionState().TLS.ServerName, c.RemoteAddr(), []string{c.ConnectionState().TLS.NegotiatedProtocol})
if err != nil {
return "", fmt.Errorf("creating ConnData from quic Conn: %w", err)
}
_, name, err := e.getter(connData)
return name, err
}
@@ -102,7 +102,7 @@ func TestHTTP3AdvertisedPort(t *testing.T) {
router.AddHTTPTLSConfig("*", &tls.Config{
Certificates: []tls.Certificate{tlsCert},
})
}, traefiktls.DefaultTLSConfigName)
router.SetHTTPSHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
}), nil)
@@ -164,7 +164,7 @@ func TestHTTP30RTT(t *testing.T) {
router.AddHTTPTLSConfig("example.com", &tls.Config{
Certificates: []tls.Certificate{tlsCert},
})
}, traefiktls.DefaultTLSConfigName)
router.SetHTTPSHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
}), nil)
+26 -3
View File
@@ -1,16 +1,39 @@
package tcp
import (
"context"
"crypto/tls"
)
// TLSConn is a TLS connection that also carries the name of the TLS config used.
type TLSConn struct {
WriteCloser
TLSOptionsName string
}
// TLSHandler handles TLS connections.
type TLSHandler struct {
Next Handler
Config *tls.Config
Next Handler
Config *tls.Config
TLSOptionsName string
}
// ServeTCP terminates the TLS connection.
func (t *TLSHandler) ServeTCP(conn WriteCloser) {
t.Next.ServeTCP(tls.Server(conn, t.Config))
t.Next.ServeTCP(tls.Server(TLSConn{WriteCloser: conn, TLSOptionsName: t.TLSOptionsName}, t.Config))
}
type tlsOptionsNameKey struct{}
func AddTLSOptionsNameInContext(ctx context.Context, name string) context.Context {
return context.WithValue(ctx, tlsOptionsNameKey{}, name)
}
func GetTLSOptionsName(ctx context.Context) string {
if name, ok := ctx.Value(tlsOptionsNameKey{}).(string); ok {
return name
}
return ""
}