diff --git a/api/upstream/router.go b/api/upstream/router.go index 0960ca31..08fc1e73 100644 --- a/api/upstream/router.go +++ b/api/upstream/router.go @@ -2,9 +2,12 @@ package upstream import "github.com/gin-gonic/gin" -func InitRouter(r *gin.RouterGroup) { +func InitHTTPRouter(r *gin.RouterGroup) { r.GET("/upstream/availability", GetAvailability) - r.GET("/upstream/availability_ws", AvailabilityWebSocket) r.GET("/upstream/sockets", GetSocketList) r.PUT("/upstream/socket/:socket", UpdateSocketConfig) } + +func InitWebSocketRouter(r *gin.RouterGroup) { + r.GET("/upstream/availability_ws", AvailabilityWebSocket) +} diff --git a/api/upstream/router_test.go b/api/upstream/router_test.go new file mode 100644 index 00000000..e61048d6 --- /dev/null +++ b/api/upstream/router_test.go @@ -0,0 +1,74 @@ +package upstream + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +func TestRouteRegistrationSeparatesHTTPAndWebSocketEndpoints(t *testing.T) { + router := gin.New() + + httpGroup := router.Group("/", func(c *gin.Context) { + c.Header("X-Upstream-Group", "http") + c.AbortWithStatus(http.StatusNoContent) + }) + InitHTTPRouter(httpGroup) + + wsGroup := router.Group("/", func(c *gin.Context) { + c.Header("X-Upstream-Group", "ws") + c.AbortWithStatus(http.StatusNoContent) + }) + InitWebSocketRouter(wsGroup) + + testCases := []struct { + name string + method string + target string + expectedMark string + }{ + { + name: "availability uses http proxy group", + method: http.MethodGet, + target: "/upstream/availability", + expectedMark: "http", + }, + { + name: "socket list uses http proxy group", + method: http.MethodGet, + target: "/upstream/sockets", + expectedMark: "http", + }, + { + name: "socket update uses http proxy group", + method: http.MethodPut, + target: "/upstream/socket/127.0.0.1%3A8080", + expectedMark: "http", + }, + { + name: "availability websocket uses websocket proxy group", + method: http.MethodGet, + target: "/upstream/availability_ws", + expectedMark: "ws", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(tc.method, tc.target, nil) + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusNoContent, w.Code) + require.Equal(t, tc.expectedMark, w.Header().Get("X-Upstream-Group")) + }) + } +} diff --git a/router/routers.go b/router/routers.go index df5bec6b..92d63cef 100644 --- a/router/routers.go +++ b/router/routers.go @@ -107,6 +107,7 @@ func InitRouter() { external_notify.InitRouter(g) backup.InitAutoBackupRouter(g) nginxLog.InitRouter(g) + upstream.InitHTTPRouter(g) g.GET("/geolite/status", geolite.GetStatus) } @@ -121,7 +122,7 @@ func InitRouter() { terminal.InitRouter(o) } nginxLog.InitWebSocketRouter(w) - upstream.InitRouter(w) + upstream.InitWebSocketRouter(w) system.InitWebSocketRouter(w) nginx.InitWebSocketRouter(w) cluster.InitWebSocketRouter(w)