diff --git a/api/certificate/issue.go b/api/certificate/issue.go index c22ea78b..c4cf2015 100644 --- a/api/certificate/issue.go +++ b/api/certificate/issue.go @@ -43,6 +43,8 @@ func IssueCert(c *gin.Context) { defer ws.Close() + wsWriter := helper.NewSafeWebSocketWriter(ws) + // read payload := &cert.ConfigPayload{} @@ -70,13 +72,13 @@ func IssueCert(c *gin.Context) { log := cert.NewLogger() log.SetCertModel(&certModel) - log.SetWebSocket(ws) + log.SetWebSocket(wsWriter) defer log.Close() err = cert.IssueCert(payload, log) if err != nil { log.Error(err) - _ = ws.WriteJSON(IssueCertResponse{ + _ = wsWriter.WriteJSON(IssueCertResponse{ Status: Error, Message: err.Error(), }) @@ -102,13 +104,13 @@ func IssueCert(c *gin.Context) { })).FirstOrCreate() if err != nil { logger.Error(err) - _ = ws.WriteJSON(IssueCertResponse{ + _ = wsWriter.WriteJSON(IssueCertResponse{ Status: Error, Message: err.Error(), }) return } - err = ws.WriteJSON(IssueCertResponse{ + err = wsWriter.WriteJSON(IssueCertResponse{ Status: Success, Message: translation.C("[Nginx UI] Issued certificate successfully").ToString(), SSLCertificate: payload.GetCertificatePath(), diff --git a/api/certificate/revoke.go b/api/certificate/revoke.go index 3d354e91..c7b07744 100644 --- a/api/certificate/revoke.go +++ b/api/certificate/revoke.go @@ -2,6 +2,7 @@ package certificate import ( "github.com/0xJacky/Nginx-UI/internal/cert" + "github.com/0xJacky/Nginx-UI/internal/helper" "github.com/0xJacky/Nginx-UI/internal/middleware" "github.com/0xJacky/Nginx-UI/internal/translation" "github.com/0xJacky/Nginx-UI/query" @@ -16,7 +17,7 @@ type RevokeCertResponse struct { *translation.Container } -func handleRevokeCertLogChan(conn *websocket.Conn, logChan chan string) { +func handleRevokeCertLogChan(writer *helper.SafeWebSocketWriter, logChan chan string) { defer func() { if err := recover(); err != nil { logger.Error(err) @@ -24,7 +25,7 @@ func handleRevokeCertLogChan(conn *websocket.Conn, logChan chan string) { }() for logString := range logChan { - err := conn.WriteJSON(RevokeCertResponse{ + err := writer.WriteJSON(RevokeCertResponse{ Status: Info, Container: translation.C(logString), }) @@ -54,12 +55,14 @@ func RevokeCert(c *gin.Context) { _ = ws.Close() }(ws) + wsWriter := helper.NewSafeWebSocketWriter(ws) + // Get certificate from database certQuery := query.Cert certModel, err := certQuery.FirstByID(id) if err != nil { logger.Error(err) - _ = ws.WriteJSON(RevokeCertResponse{ + _ = wsWriter.WriteJSON(RevokeCertResponse{ Status: Error, Container: translation.C("Certificate not found: %{error}", map[string]any{ "error": err.Error(), @@ -83,17 +86,17 @@ func RevokeCert(c *gin.Context) { errChan := make(chan error, 1) certLogger := cert.NewLogger() - certLogger.SetWebSocket(ws) + certLogger.SetWebSocket(wsWriter) defer certLogger.Close() go cert.RevokeCert(payload, certLogger, logChan, errChan) - go handleRevokeCertLogChan(ws, logChan) + go handleRevokeCertLogChan(wsWriter, logChan) // block, until errChan closes for err = range errChan { logger.Error(err) - err = ws.WriteJSON(RevokeCertResponse{ + err = wsWriter.WriteJSON(RevokeCertResponse{ Status: Error, Container: translation.C("Failed to revoke certificate: %{error}", map[string]any{ "error": err.Error(), @@ -109,7 +112,7 @@ func RevokeCert(c *gin.Context) { err = certModel.Remove() if err != nil { logger.Error(err) - _ = ws.WriteJSON(RevokeCertResponse{ + _ = wsWriter.WriteJSON(RevokeCertResponse{ Status: Error, Container: translation.C("Failed to delete certificate from database: %{error}", map[string]any{ "error": err.Error(), @@ -118,7 +121,7 @@ func RevokeCert(c *gin.Context) { return } - err = ws.WriteJSON(RevokeCertResponse{ + err = wsWriter.WriteJSON(RevokeCertResponse{ Status: Success, Container: translation.C("Certificate revoked successfully"), }) diff --git a/api/nginx_log/websocket.go b/api/nginx_log/websocket.go index 9e8d1df7..20b5bdb5 100644 --- a/api/nginx_log/websocket.go +++ b/api/nginx_log/websocket.go @@ -64,7 +64,7 @@ func getLogPath(control *controlStruct) (logPath string, err error) { } // tailNginxLog tails the specified log file and sends each line to the websocket -func tailNginxLog(ws *websocket.Conn, controlChan chan controlStruct, errChan chan error) { +func tailNginxLog(writer *helper.SafeWebSocketWriter, controlChan chan controlStruct, errChan chan error) { defer func() { if err := recover(); err != nil { buf := make([]byte, 1024) @@ -117,7 +117,7 @@ func tailNginxLog(ws *websocket.Conn, controlChan chan controlStruct, errChan ch continue } - err = ws.WriteMessage(websocket.TextMessage, []byte(line.Text)) + err = writer.WriteMessage(websocket.TextMessage, []byte(line.Text)) if err != nil { if helper.IsUnexpectedWebsocketError(err) { errChan <- errors.Wrap(err, "error tailNginxLog write message") @@ -182,15 +182,17 @@ func Log(c *gin.Context) { defer ws.Close() + wsWriter := helper.NewSafeWebSocketWriter(ws) + errChan := make(chan error, 1) controlChan := make(chan controlStruct, 1) - go tailNginxLog(ws, controlChan, errChan) + go tailNginxLog(wsWriter, controlChan, errChan) go handleLogControl(ws, controlChan, errChan) if err = <-errChan; err != nil { logger.Error(err) - _ = ws.WriteMessage(websocket.TextMessage, []byte(err.Error())) + _ = wsWriter.WriteMessage(websocket.TextMessage, []byte(err.Error())) return } } diff --git a/api/sites/websocket.go b/api/sites/websocket.go index d32e8250..77f02983 100644 --- a/api/sites/websocket.go +++ b/api/sites/websocket.go @@ -1,7 +1,9 @@ package sites import ( + "errors" "sync" + "time" "github.com/0xJacky/Nginx-UI/internal/helper" "github.com/0xJacky/Nginx-UI/internal/middleware" @@ -42,40 +44,114 @@ var upgrader = websocket.Upgrader{ // WSManager WebSocket connection manager type WSManager struct { - connections map[*websocket.Conn]bool + connections map[*websocket.Conn]*WSClient mutex sync.RWMutex } +var errClientUnavailable = errors.New("websocket client unavailable") + +// WSClient wraps a websocket connection and handles serialized writes. +type WSClient struct { + conn *websocket.Conn + send chan interface{} + mutex sync.RWMutex + closed bool +} + +func (c *WSClient) trySend(v interface{}) bool { + c.mutex.RLock() + if c.closed { + c.mutex.RUnlock() + return false + } + + select { + case c.send <- v: + c.mutex.RUnlock() + return true + default: + c.mutex.RUnlock() + return false + } +} + +func (c *WSClient) closeSendChannel() { + c.mutex.Lock() + if c.closed { + c.mutex.Unlock() + return + } + + close(c.send) + c.closed = true + c.mutex.Unlock() +} + +func (c *WSClient) writePump() { + for message := range c.send { + c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if err := c.conn.WriteJSON(message); err != nil { + logger.Error("Failed to write site websocket message:", err) + return + } + } +} + var wsManager = &WSManager{ - connections: make(map[*websocket.Conn]bool), + connections: make(map[*websocket.Conn]*WSClient), } // AddConnection adds a WebSocket connection to the manager -func (wm *WSManager) AddConnection(conn *websocket.Conn) { +func (wm *WSManager) AddConnection(conn *websocket.Conn) *WSClient { wm.mutex.Lock() defer wm.mutex.Unlock() - wm.connections[conn] = true + client := &WSClient{ + conn: conn, + send: make(chan interface{}, 16), + } + wm.connections[conn] = client + return client } // RemoveConnection removes a WebSocket connection from the manager func (wm *WSManager) RemoveConnection(conn *websocket.Conn) { wm.mutex.Lock() - defer wm.mutex.Unlock() - delete(wm.connections, conn) + client, ok := wm.connections[conn] + if ok { + delete(wm.connections, conn) + } + wm.mutex.Unlock() + + if ok { + client.closeSendChannel() + } +} + +func (wm *WSManager) activeClients() []*WSClient { + wm.mutex.RLock() + if len(wm.connections) == 0 { + wm.mutex.RUnlock() + return nil + } + + clients := make([]*WSClient, 0, len(wm.connections)) + for _, client := range wm.connections { + clients = append(clients, client) + } + wm.mutex.RUnlock() + + return clients } // BroadcastUpdate sends updates to all connected WebSocket clients func (wm *WSManager) BroadcastUpdate(sites []*sitecheck.SiteInfo) { - wm.mutex.RLock() - defer wm.mutex.RUnlock() + for _, client := range wm.activeClients() { + if err := sendSiteData(client, MessageTypeUpdate, sites); err == nil { + continue + } - for conn := range wm.connections { - go func(c *websocket.Conn) { - if err := sendSiteData(c, MessageTypeUpdate, sites); err != nil { - wm.RemoveConnection(c) - c.Close() - } - }(conn) + wm.RemoveConnection(client.conn) + client.conn.Close() } } @@ -94,13 +170,13 @@ func InitWebSocketNotifications() { // SiteNavigationWebSocket handles WebSocket connections for real-time site status updates func SiteNavigationWebSocket(c *gin.Context) { - ctx := c.Request.Context() - conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { logger.Error("WebSocket upgrade failed:", err) return } + + client := wsManager.AddConnection(conn) defer func() { wsManager.RemoveConnection(conn) conn.Close() @@ -108,38 +184,39 @@ func SiteNavigationWebSocket(c *gin.Context) { logger.Info("Site navigation WebSocket connection established") - // Register connection with manager - wsManager.AddConnection(conn) - service := sitecheck.GetService() + go client.writePump() + // Send initial data - if err := sendSiteData(conn, MessageTypeInitial, service.GetSites()); err != nil { - logger.Error("Failed to send initial data:", err) + if err := sendSiteData(client, MessageTypeInitial, service.GetSites()); err != nil { + logger.Error("Failed to queue initial site data:", err) return } - // Handle incoming messages from client - go handleClientMessages(conn, service) - - <-ctx.Done() - logger.Info("Request context cancelled, closing WebSocket") + handleClientMessages(client, service) + logger.Info("Site navigation WebSocket connection closed") } // sendSiteData sends site data via WebSocket -func sendSiteData(conn *websocket.Conn, msgType string, sites []*sitecheck.SiteInfo) error { +func sendSiteData(client *WSClient, msgType string, sites []*sitecheck.SiteInfo) error { message := ServerMessage{ Type: msgType, Data: sites, } - return conn.WriteJSON(message) + + if !client.trySend(message) { + return errClientUnavailable + } + + return nil } // handleClientMessages handles incoming WebSocket messages -func handleClientMessages(conn *websocket.Conn, service *sitecheck.Service) { +func handleClientMessages(client *WSClient, service *sitecheck.Service) { for { var msg ClientMessage - if err := conn.ReadJSON(&msg); err != nil { + if err := client.conn.ReadJSON(&msg); err != nil { if helper.IsUnexpectedWebsocketError(err) { logger.Error("WebSocket read error:", err) } @@ -152,8 +229,8 @@ func handleClientMessages(conn *websocket.Conn, service *sitecheck.Service) { service.RefreshSites() case MessageTypePing: pongMsg := PongMessage{Type: MessageTypePong} - if err := conn.WriteJSON(pongMsg); err != nil { - logger.Error("Failed to send pong:", err) + if !client.trySend(pongMsg) { + logger.Error("Failed to queue pong response:", errClientUnavailable) return } } diff --git a/api/system/upgrade.go b/api/system/upgrade.go index 4c6df5e6..86ae182c 100644 --- a/api/system/upgrade.go +++ b/api/system/upgrade.go @@ -61,6 +61,8 @@ func PerformCoreUpgrade(c *gin.Context) { } defer ws.Close() + wsWriter := helper.NewSafeWebSocketWriter(ws) + var control upgrader.Control err = ws.ReadJSON(&control) @@ -70,8 +72,8 @@ func PerformCoreUpgrade(c *gin.Context) { return } if helper.InNginxUIOfficialDocker() && helper.DockerSocketExists() { - upgrader.DockerUpgrade(ws, &control) + upgrader.DockerUpgrade(wsWriter, &control) } else { - upgrader.BinaryUpgrade(ws, &control) + upgrader.BinaryUpgrade(wsWriter, &control) } } diff --git a/internal/cert/logger.go b/internal/cert/logger.go index 23be71b0..a4a03c3c 100644 --- a/internal/cert/logger.go +++ b/internal/cert/logger.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/0xJacky/Nginx-UI/internal/helper" "github.com/0xJacky/Nginx-UI/internal/translation" "github.com/0xJacky/Nginx-UI/model" "github.com/gorilla/websocket" @@ -15,7 +16,7 @@ import ( type Logger struct { buffer []string cert *model.Cert - ws *websocket.Conn + ws *helper.SafeWebSocketWriter trans *translation.Container mu sync.Mutex msgCh chan []byte @@ -52,7 +53,7 @@ func (t *Logger) SetCertModel(cert *model.Cert) { t.cert = cert } -func (t *Logger) SetWebSocket(ws *websocket.Conn) { +func (t *Logger) SetWebSocket(ws *helper.SafeWebSocketWriter) { t.mu.Lock() defer t.mu.Unlock() t.ws = ws diff --git a/internal/helper/websocket_writer.go b/internal/helper/websocket_writer.go new file mode 100644 index 00000000..80dc3452 --- /dev/null +++ b/internal/helper/websocket_writer.go @@ -0,0 +1,37 @@ +package helper + +import ( + "sync" + "time" + + "github.com/gorilla/websocket" +) + +// SafeWebSocketWriter serializes writes for a websocket connection. +type SafeWebSocketWriter struct { + conn *websocket.Conn + mutex sync.Mutex +} + +// NewSafeWebSocketWriter creates a serialized writer for a websocket connection. +func NewSafeWebSocketWriter(conn *websocket.Conn) *SafeWebSocketWriter { + return &SafeWebSocketWriter{conn: conn} +} + +// WriteJSON writes JSON data with serialized access to the websocket connection. +func (w *SafeWebSocketWriter) WriteJSON(v interface{}) error { + w.mutex.Lock() + defer w.mutex.Unlock() + + w.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + return w.conn.WriteJSON(v) +} + +// WriteMessage writes a websocket message with serialized access to the connection. +func (w *SafeWebSocketWriter) WriteMessage(messageType int, data []byte) error { + w.mutex.Lock() + defer w.mutex.Unlock() + + w.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + return w.conn.WriteMessage(messageType, data) +} diff --git a/internal/upgrader/binary.go b/internal/upgrader/binary.go index 525c817d..787d33f8 100644 --- a/internal/upgrader/binary.go +++ b/internal/upgrader/binary.go @@ -3,8 +3,8 @@ package upgrader import ( "os" + "github.com/0xJacky/Nginx-UI/internal/helper" "github.com/0xJacky/Nginx-UI/settings" - "github.com/gorilla/websocket" "github.com/uozi-tech/cosy/logger" ) @@ -15,7 +15,7 @@ type Control struct { } // BinaryUpgrade Upgrade the binary -func BinaryUpgrade(ws *websocket.Conn, control *Control) { +func BinaryUpgrade(ws *helper.SafeWebSocketWriter, control *Control) { _ = ws.WriteJSON(CoreUpgradeResp{ Status: UpgradeStatusInfo, Message: "Initialing core upgrader", diff --git a/internal/upgrader/docker.go b/internal/upgrader/docker.go index 6eb78f48..56cafcb6 100644 --- a/internal/upgrader/docker.go +++ b/internal/upgrader/docker.go @@ -2,12 +2,12 @@ package upgrader import ( "github.com/0xJacky/Nginx-UI/internal/docker" - "github.com/gorilla/websocket" + "github.com/0xJacky/Nginx-UI/internal/helper" "github.com/uozi-tech/cosy/logger" ) // DockerUpgrade Upgrade the Docker container -func DockerUpgrade(ws *websocket.Conn, control *Control) { +func DockerUpgrade(ws *helper.SafeWebSocketWriter, control *Control) { progressChan := make(chan float64) // Start a goroutine to listen for progress updates and send them via WebSocket