refactor(ws): implement SafeWebSocketWriter for serialized access

- Introduced SafeWebSocketWriter to ensure thread-safe writes to WebSocket connections.
- Updated WebSocket handling in certificate issuance, revocation, Nginx log tailing, and system upgrades to use the new writer.
- Enhanced WebSocket client management in the site navigation module for improved message handling and connection stability.
This commit is contained in:
0xJacky
2026-04-04 02:01:20 +00:00
parent b9e1951423
commit 9f1b9bbbba
9 changed files with 181 additions and 57 deletions
+6 -4
View File
@@ -43,6 +43,8 @@ func IssueCert(c *gin.Context) {
defer ws.Close()
wsWriter := helper.NewSafeWebSocketWriter(ws)
// read
payload := &cert.ConfigPayload{}
@@ -70,13 +72,13 @@ func IssueCert(c *gin.Context) {
log := cert.NewLogger()
log.SetCertModel(&certModel)
log.SetWebSocket(ws)
log.SetWebSocket(wsWriter)
defer log.Close()
err = cert.IssueCert(payload, log)
if err != nil {
log.Error(err)
_ = ws.WriteJSON(IssueCertResponse{
_ = wsWriter.WriteJSON(IssueCertResponse{
Status: Error,
Message: err.Error(),
})
@@ -102,13 +104,13 @@ func IssueCert(c *gin.Context) {
})).FirstOrCreate()
if err != nil {
logger.Error(err)
_ = ws.WriteJSON(IssueCertResponse{
_ = wsWriter.WriteJSON(IssueCertResponse{
Status: Error,
Message: err.Error(),
})
return
}
err = ws.WriteJSON(IssueCertResponse{
err = wsWriter.WriteJSON(IssueCertResponse{
Status: Success,
Message: translation.C("[Nginx UI] Issued certificate successfully").ToString(),
SSLCertificate: payload.GetCertificatePath(),
+11 -8
View File
@@ -2,6 +2,7 @@ package certificate
import (
"github.com/0xJacky/Nginx-UI/internal/cert"
"github.com/0xJacky/Nginx-UI/internal/helper"
"github.com/0xJacky/Nginx-UI/internal/middleware"
"github.com/0xJacky/Nginx-UI/internal/translation"
"github.com/0xJacky/Nginx-UI/query"
@@ -16,7 +17,7 @@ type RevokeCertResponse struct {
*translation.Container
}
func handleRevokeCertLogChan(conn *websocket.Conn, logChan chan string) {
func handleRevokeCertLogChan(writer *helper.SafeWebSocketWriter, logChan chan string) {
defer func() {
if err := recover(); err != nil {
logger.Error(err)
@@ -24,7 +25,7 @@ func handleRevokeCertLogChan(conn *websocket.Conn, logChan chan string) {
}()
for logString := range logChan {
err := conn.WriteJSON(RevokeCertResponse{
err := writer.WriteJSON(RevokeCertResponse{
Status: Info,
Container: translation.C(logString),
})
@@ -54,12 +55,14 @@ func RevokeCert(c *gin.Context) {
_ = ws.Close()
}(ws)
wsWriter := helper.NewSafeWebSocketWriter(ws)
// Get certificate from database
certQuery := query.Cert
certModel, err := certQuery.FirstByID(id)
if err != nil {
logger.Error(err)
_ = ws.WriteJSON(RevokeCertResponse{
_ = wsWriter.WriteJSON(RevokeCertResponse{
Status: Error,
Container: translation.C("Certificate not found: %{error}", map[string]any{
"error": err.Error(),
@@ -83,17 +86,17 @@ func RevokeCert(c *gin.Context) {
errChan := make(chan error, 1)
certLogger := cert.NewLogger()
certLogger.SetWebSocket(ws)
certLogger.SetWebSocket(wsWriter)
defer certLogger.Close()
go cert.RevokeCert(payload, certLogger, logChan, errChan)
go handleRevokeCertLogChan(ws, logChan)
go handleRevokeCertLogChan(wsWriter, logChan)
// block, until errChan closes
for err = range errChan {
logger.Error(err)
err = ws.WriteJSON(RevokeCertResponse{
err = wsWriter.WriteJSON(RevokeCertResponse{
Status: Error,
Container: translation.C("Failed to revoke certificate: %{error}", map[string]any{
"error": err.Error(),
@@ -109,7 +112,7 @@ func RevokeCert(c *gin.Context) {
err = certModel.Remove()
if err != nil {
logger.Error(err)
_ = ws.WriteJSON(RevokeCertResponse{
_ = wsWriter.WriteJSON(RevokeCertResponse{
Status: Error,
Container: translation.C("Failed to delete certificate from database: %{error}", map[string]any{
"error": err.Error(),
@@ -118,7 +121,7 @@ func RevokeCert(c *gin.Context) {
return
}
err = ws.WriteJSON(RevokeCertResponse{
err = wsWriter.WriteJSON(RevokeCertResponse{
Status: Success,
Container: translation.C("Certificate revoked successfully"),
})
+6 -4
View File
@@ -64,7 +64,7 @@ func getLogPath(control *controlStruct) (logPath string, err error) {
}
// tailNginxLog tails the specified log file and sends each line to the websocket
func tailNginxLog(ws *websocket.Conn, controlChan chan controlStruct, errChan chan error) {
func tailNginxLog(writer *helper.SafeWebSocketWriter, controlChan chan controlStruct, errChan chan error) {
defer func() {
if err := recover(); err != nil {
buf := make([]byte, 1024)
@@ -117,7 +117,7 @@ func tailNginxLog(ws *websocket.Conn, controlChan chan controlStruct, errChan ch
continue
}
err = ws.WriteMessage(websocket.TextMessage, []byte(line.Text))
err = writer.WriteMessage(websocket.TextMessage, []byte(line.Text))
if err != nil {
if helper.IsUnexpectedWebsocketError(err) {
errChan <- errors.Wrap(err, "error tailNginxLog write message")
@@ -182,15 +182,17 @@ func Log(c *gin.Context) {
defer ws.Close()
wsWriter := helper.NewSafeWebSocketWriter(ws)
errChan := make(chan error, 1)
controlChan := make(chan controlStruct, 1)
go tailNginxLog(ws, controlChan, errChan)
go tailNginxLog(wsWriter, controlChan, errChan)
go handleLogControl(ws, controlChan, errChan)
if err = <-errChan; err != nil {
logger.Error(err)
_ = ws.WriteMessage(websocket.TextMessage, []byte(err.Error()))
_ = wsWriter.WriteMessage(websocket.TextMessage, []byte(err.Error()))
return
}
}
+110 -33
View File
@@ -1,7 +1,9 @@
package sites
import (
"errors"
"sync"
"time"
"github.com/0xJacky/Nginx-UI/internal/helper"
"github.com/0xJacky/Nginx-UI/internal/middleware"
@@ -42,40 +44,114 @@ var upgrader = websocket.Upgrader{
// WSManager WebSocket connection manager
type WSManager struct {
connections map[*websocket.Conn]bool
connections map[*websocket.Conn]*WSClient
mutex sync.RWMutex
}
var errClientUnavailable = errors.New("websocket client unavailable")
// WSClient wraps a websocket connection and handles serialized writes.
type WSClient struct {
conn *websocket.Conn
send chan interface{}
mutex sync.RWMutex
closed bool
}
func (c *WSClient) trySend(v interface{}) bool {
c.mutex.RLock()
if c.closed {
c.mutex.RUnlock()
return false
}
select {
case c.send <- v:
c.mutex.RUnlock()
return true
default:
c.mutex.RUnlock()
return false
}
}
func (c *WSClient) closeSendChannel() {
c.mutex.Lock()
if c.closed {
c.mutex.Unlock()
return
}
close(c.send)
c.closed = true
c.mutex.Unlock()
}
func (c *WSClient) writePump() {
for message := range c.send {
c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := c.conn.WriteJSON(message); err != nil {
logger.Error("Failed to write site websocket message:", err)
return
}
}
}
var wsManager = &WSManager{
connections: make(map[*websocket.Conn]bool),
connections: make(map[*websocket.Conn]*WSClient),
}
// AddConnection adds a WebSocket connection to the manager
func (wm *WSManager) AddConnection(conn *websocket.Conn) {
func (wm *WSManager) AddConnection(conn *websocket.Conn) *WSClient {
wm.mutex.Lock()
defer wm.mutex.Unlock()
wm.connections[conn] = true
client := &WSClient{
conn: conn,
send: make(chan interface{}, 16),
}
wm.connections[conn] = client
return client
}
// RemoveConnection removes a WebSocket connection from the manager
func (wm *WSManager) RemoveConnection(conn *websocket.Conn) {
wm.mutex.Lock()
defer wm.mutex.Unlock()
delete(wm.connections, conn)
client, ok := wm.connections[conn]
if ok {
delete(wm.connections, conn)
}
wm.mutex.Unlock()
if ok {
client.closeSendChannel()
}
}
func (wm *WSManager) activeClients() []*WSClient {
wm.mutex.RLock()
if len(wm.connections) == 0 {
wm.mutex.RUnlock()
return nil
}
clients := make([]*WSClient, 0, len(wm.connections))
for _, client := range wm.connections {
clients = append(clients, client)
}
wm.mutex.RUnlock()
return clients
}
// BroadcastUpdate sends updates to all connected WebSocket clients
func (wm *WSManager) BroadcastUpdate(sites []*sitecheck.SiteInfo) {
wm.mutex.RLock()
defer wm.mutex.RUnlock()
for _, client := range wm.activeClients() {
if err := sendSiteData(client, MessageTypeUpdate, sites); err == nil {
continue
}
for conn := range wm.connections {
go func(c *websocket.Conn) {
if err := sendSiteData(c, MessageTypeUpdate, sites); err != nil {
wm.RemoveConnection(c)
c.Close()
}
}(conn)
wm.RemoveConnection(client.conn)
client.conn.Close()
}
}
@@ -94,13 +170,13 @@ func InitWebSocketNotifications() {
// SiteNavigationWebSocket handles WebSocket connections for real-time site status updates
func SiteNavigationWebSocket(c *gin.Context) {
ctx := c.Request.Context()
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error("WebSocket upgrade failed:", err)
return
}
client := wsManager.AddConnection(conn)
defer func() {
wsManager.RemoveConnection(conn)
conn.Close()
@@ -108,38 +184,39 @@ func SiteNavigationWebSocket(c *gin.Context) {
logger.Info("Site navigation WebSocket connection established")
// Register connection with manager
wsManager.AddConnection(conn)
service := sitecheck.GetService()
go client.writePump()
// Send initial data
if err := sendSiteData(conn, MessageTypeInitial, service.GetSites()); err != nil {
logger.Error("Failed to send initial data:", err)
if err := sendSiteData(client, MessageTypeInitial, service.GetSites()); err != nil {
logger.Error("Failed to queue initial site data:", err)
return
}
// Handle incoming messages from client
go handleClientMessages(conn, service)
<-ctx.Done()
logger.Info("Request context cancelled, closing WebSocket")
handleClientMessages(client, service)
logger.Info("Site navigation WebSocket connection closed")
}
// sendSiteData sends site data via WebSocket
func sendSiteData(conn *websocket.Conn, msgType string, sites []*sitecheck.SiteInfo) error {
func sendSiteData(client *WSClient, msgType string, sites []*sitecheck.SiteInfo) error {
message := ServerMessage{
Type: msgType,
Data: sites,
}
return conn.WriteJSON(message)
if !client.trySend(message) {
return errClientUnavailable
}
return nil
}
// handleClientMessages handles incoming WebSocket messages
func handleClientMessages(conn *websocket.Conn, service *sitecheck.Service) {
func handleClientMessages(client *WSClient, service *sitecheck.Service) {
for {
var msg ClientMessage
if err := conn.ReadJSON(&msg); err != nil {
if err := client.conn.ReadJSON(&msg); err != nil {
if helper.IsUnexpectedWebsocketError(err) {
logger.Error("WebSocket read error:", err)
}
@@ -152,8 +229,8 @@ func handleClientMessages(conn *websocket.Conn, service *sitecheck.Service) {
service.RefreshSites()
case MessageTypePing:
pongMsg := PongMessage{Type: MessageTypePong}
if err := conn.WriteJSON(pongMsg); err != nil {
logger.Error("Failed to send pong:", err)
if !client.trySend(pongMsg) {
logger.Error("Failed to queue pong response:", errClientUnavailable)
return
}
}
+4 -2
View File
@@ -61,6 +61,8 @@ func PerformCoreUpgrade(c *gin.Context) {
}
defer ws.Close()
wsWriter := helper.NewSafeWebSocketWriter(ws)
var control upgrader.Control
err = ws.ReadJSON(&control)
@@ -70,8 +72,8 @@ func PerformCoreUpgrade(c *gin.Context) {
return
}
if helper.InNginxUIOfficialDocker() && helper.DockerSocketExists() {
upgrader.DockerUpgrade(ws, &control)
upgrader.DockerUpgrade(wsWriter, &control)
} else {
upgrader.BinaryUpgrade(ws, &control)
upgrader.BinaryUpgrade(wsWriter, &control)
}
}
+3 -2
View File
@@ -6,6 +6,7 @@ import (
"sync"
"time"
"github.com/0xJacky/Nginx-UI/internal/helper"
"github.com/0xJacky/Nginx-UI/internal/translation"
"github.com/0xJacky/Nginx-UI/model"
"github.com/gorilla/websocket"
@@ -15,7 +16,7 @@ import (
type Logger struct {
buffer []string
cert *model.Cert
ws *websocket.Conn
ws *helper.SafeWebSocketWriter
trans *translation.Container
mu sync.Mutex
msgCh chan []byte
@@ -52,7 +53,7 @@ func (t *Logger) SetCertModel(cert *model.Cert) {
t.cert = cert
}
func (t *Logger) SetWebSocket(ws *websocket.Conn) {
func (t *Logger) SetWebSocket(ws *helper.SafeWebSocketWriter) {
t.mu.Lock()
defer t.mu.Unlock()
t.ws = ws
+37
View File
@@ -0,0 +1,37 @@
package helper
import (
"sync"
"time"
"github.com/gorilla/websocket"
)
// SafeWebSocketWriter serializes writes for a websocket connection.
type SafeWebSocketWriter struct {
conn *websocket.Conn
mutex sync.Mutex
}
// NewSafeWebSocketWriter creates a serialized writer for a websocket connection.
func NewSafeWebSocketWriter(conn *websocket.Conn) *SafeWebSocketWriter {
return &SafeWebSocketWriter{conn: conn}
}
// WriteJSON writes JSON data with serialized access to the websocket connection.
func (w *SafeWebSocketWriter) WriteJSON(v interface{}) error {
w.mutex.Lock()
defer w.mutex.Unlock()
w.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
return w.conn.WriteJSON(v)
}
// WriteMessage writes a websocket message with serialized access to the connection.
func (w *SafeWebSocketWriter) WriteMessage(messageType int, data []byte) error {
w.mutex.Lock()
defer w.mutex.Unlock()
w.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
return w.conn.WriteMessage(messageType, data)
}
+2 -2
View File
@@ -3,8 +3,8 @@ package upgrader
import (
"os"
"github.com/0xJacky/Nginx-UI/internal/helper"
"github.com/0xJacky/Nginx-UI/settings"
"github.com/gorilla/websocket"
"github.com/uozi-tech/cosy/logger"
)
@@ -15,7 +15,7 @@ type Control struct {
}
// BinaryUpgrade Upgrade the binary
func BinaryUpgrade(ws *websocket.Conn, control *Control) {
func BinaryUpgrade(ws *helper.SafeWebSocketWriter, control *Control) {
_ = ws.WriteJSON(CoreUpgradeResp{
Status: UpgradeStatusInfo,
Message: "Initialing core upgrader",
+2 -2
View File
@@ -2,12 +2,12 @@ package upgrader
import (
"github.com/0xJacky/Nginx-UI/internal/docker"
"github.com/gorilla/websocket"
"github.com/0xJacky/Nginx-UI/internal/helper"
"github.com/uozi-tech/cosy/logger"
)
// DockerUpgrade Upgrade the Docker container
func DockerUpgrade(ws *websocket.Conn, control *Control) {
func DockerUpgrade(ws *helper.SafeWebSocketWriter, control *Control) {
progressChan := make(chan float64)
// Start a goroutine to listen for progress updates and send them via WebSocket