mirror of
https://github.com/0xJacky/nginx-ui.git
synced 2026-06-19 07:36:59 +00:00
feat(cert): persist draft on issuance entry, status transitions on completion
This commit is contained in:
+156
-46
@@ -1,17 +1,18 @@
|
||||
package certificate
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/0xJacky/Nginx-UI/internal/cert"
|
||||
"github.com/0xJacky/Nginx-UI/internal/helper"
|
||||
"github.com/0xJacky/Nginx-UI/internal/middleware"
|
||||
"github.com/0xJacky/Nginx-UI/internal/translation"
|
||||
"github.com/0xJacky/Nginx-UI/model"
|
||||
"github.com/0xJacky/Nginx-UI/query"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-acme/lego/v5/certcrypto"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/uozi-tech/cosy/logger"
|
||||
"gorm.io/gen/field"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -34,35 +35,48 @@ func IssueCert(c *gin.Context) {
|
||||
CheckOrigin: middleware.CheckWebSocketOrigin,
|
||||
}
|
||||
|
||||
// upgrade http to websocket
|
||||
ws, err := upGrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
defer ws.Close()
|
||||
|
||||
wsWriter := helper.NewSafeWebSocketWriter(ws)
|
||||
|
||||
// read
|
||||
payload := &cert.ConfigPayload{}
|
||||
|
||||
err = ws.ReadJSON(payload)
|
||||
if err != nil {
|
||||
if err := ws.ReadJSON(payload); err != nil {
|
||||
logger.Error(err)
|
||||
return
|
||||
}
|
||||
payload.KeyType = payload.GetKeyType()
|
||||
|
||||
certModel, err := model.FirstOrInit(name, payload.GetKeyType())
|
||||
certModel, err := persistCertDraft(name, payload)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
_ = wsWriter.WriteJSON(IssueCertResponse{Status: Error, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
payload.CertID = certModel.ID
|
||||
|
||||
// Defer guard: if the function returns while still pending (panic / unexpected path),
|
||||
// the record would otherwise be orphaned. Convert to failure with a generic message.
|
||||
defer func() {
|
||||
var current model.Cert
|
||||
db := model.UseDB()
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
if e := db.Where("id = ?", certModel.ID).First(¤t).Error; e != nil {
|
||||
return
|
||||
}
|
||||
if current.Status == model.CertStatusPending {
|
||||
markCertFailure(certModel.ID, "Issuance interrupted before completion.")
|
||||
}
|
||||
}()
|
||||
|
||||
// Hydrate payload.Resource from the existing cert (for renewal path).
|
||||
if certModel.SSLCertificatePath != "" {
|
||||
certInfo, _ := cert.GetCertInfo(certModel.SSLCertificatePath)
|
||||
if certInfo != nil {
|
||||
@@ -72,58 +86,154 @@ func IssueCert(c *gin.Context) {
|
||||
}
|
||||
|
||||
log := cert.NewLogger()
|
||||
log.SetCertModel(&certModel)
|
||||
log.SetCertModel(certModel)
|
||||
log.SetWebSocket(wsWriter)
|
||||
defer log.Close()
|
||||
|
||||
err = cert.IssueCert(payload, log)
|
||||
if err != nil {
|
||||
if err := cert.IssueCert(payload, log); err != nil {
|
||||
log.Error(err)
|
||||
_ = wsWriter.WriteJSON(IssueCertResponse{
|
||||
Status: Error,
|
||||
Message: err.Error(),
|
||||
})
|
||||
markCertFailure(certModel.ID, shortError(err))
|
||||
_ = wsWriter.WriteJSON(IssueCertResponse{Status: Error, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
cert := query.Cert
|
||||
markCertSuccess(certModel.ID, payload.GetCertificatePath(), payload.GetCertificateKeyPath(), payload.Resource)
|
||||
|
||||
_, err = cert.Where(cert.Name.Eq(name), cert.Filename.Eq(name),
|
||||
cert.KeyType.In(helper.GetKeyTypeAliasStrings(payload.KeyType)...)).
|
||||
Assign(field.Attrs(&model.Cert{
|
||||
KeyType: payload.KeyType,
|
||||
Domains: payload.ServerName,
|
||||
SSLCertificatePath: payload.GetCertificatePath(),
|
||||
SSLCertificateKeyPath: payload.GetCertificateKeyPath(),
|
||||
AutoCert: model.AutoCertEnabled,
|
||||
ChallengeMethod: payload.ChallengeMethod,
|
||||
DnsCredentialID: payload.DNSCredentialID,
|
||||
ACMEUserID: payload.ACMEUserID,
|
||||
Resource: payload.Resource,
|
||||
MustStaple: payload.MustStaple,
|
||||
LegoDisableCNAMESupport: payload.LegoDisableCNAMESupport,
|
||||
Log: log.ToString(),
|
||||
RevokeOld: payload.RevokeOld,
|
||||
})).FirstOrCreate()
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
_ = wsWriter.WriteJSON(IssueCertResponse{
|
||||
Status: Error,
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
err = wsWriter.WriteJSON(IssueCertResponse{
|
||||
if err := wsWriter.WriteJSON(IssueCertResponse{
|
||||
Status: Success,
|
||||
Message: translation.C("[Nginx UI] Issued certificate successfully").ToString(),
|
||||
SSLCertificate: payload.GetCertificatePath(),
|
||||
SSLCertificateKey: payload.GetCertificateKeyPath(),
|
||||
KeyType: payload.GetKeyType(),
|
||||
})
|
||||
if err != nil {
|
||||
}); err != nil {
|
||||
if helper.IsUnexpectedWebsocketError(err) {
|
||||
logger.Error(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// persistCertDraft inserts or updates a Cert row representing an in-flight issuance.
|
||||
// The row is keyed by (name, filename, key_type). All user-submitted config is captured
|
||||
// up-front so a failure preserves enough state for a one-click retry.
|
||||
func persistCertDraft(name string, payload *cert.ConfigPayload) (*model.Cert, error) {
|
||||
db := model.UseDB()
|
||||
normalizedKeyType := helper.GetKeyType(payload.GetKeyType())
|
||||
keyTypeAliases := helper.GetKeyTypeAliasStrings(normalizedKeyType)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
seed := &model.Cert{
|
||||
Name: name,
|
||||
Filename: name,
|
||||
KeyType: normalizedKeyType,
|
||||
Domains: payload.ServerName,
|
||||
ChallengeMethod: payload.ChallengeMethod,
|
||||
DnsCredentialID: payload.DNSCredentialID,
|
||||
ACMEUserID: payload.ACMEUserID,
|
||||
AutoCert: model.AutoCertEnabled,
|
||||
MustStaple: payload.MustStaple,
|
||||
LegoDisableCNAMESupport: payload.LegoDisableCNAMESupport,
|
||||
RevokeOld: payload.RevokeOld,
|
||||
Status: model.CertStatusPending,
|
||||
LastError: "",
|
||||
LastAttemptAt: &now,
|
||||
}
|
||||
|
||||
// FirstOrCreate by (name, filename, key_type). When the row exists,
|
||||
// `seed` is hydrated with the existing record (preserving SSLCertificatePath,
|
||||
// Resource, etc.) so we can read those fields on the renewal path below.
|
||||
if err := db.Where("name = ? AND filename = ? AND key_type IN ?", name, name, keyTypeAliases).
|
||||
FirstOrCreate(seed).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Refresh all user-submitted config and reset issuance state to pending.
|
||||
// Use struct + Select so GORM applies the `serializer:json` tag for Domains
|
||||
// AND writes the zero-valued LastError ("") instead of skipping it.
|
||||
updates := &model.Cert{
|
||||
Domains: payload.ServerName,
|
||||
ChallengeMethod: payload.ChallengeMethod,
|
||||
DnsCredentialID: payload.DNSCredentialID,
|
||||
ACMEUserID: payload.ACMEUserID,
|
||||
AutoCert: model.AutoCertEnabled,
|
||||
MustStaple: payload.MustStaple,
|
||||
LegoDisableCNAMESupport: payload.LegoDisableCNAMESupport,
|
||||
RevokeOld: payload.RevokeOld,
|
||||
Status: model.CertStatusPending,
|
||||
LastError: "",
|
||||
LastAttemptAt: &now,
|
||||
}
|
||||
if err := db.Model(&model.Cert{}).Where("id = ?", seed.ID).
|
||||
Select(
|
||||
"domains", "challenge_method", "dns_credential_id", "acme_user_id",
|
||||
"auto_cert", "must_staple", "lego_disable_cname_support",
|
||||
"revoke_old", "status", "last_error", "last_attempt_at",
|
||||
).
|
||||
Updates(updates).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Re-read so the caller has the fully-populated struct (Resource, paths, etc.).
|
||||
var fresh model.Cert
|
||||
if err := db.Where("id = ?", seed.ID).First(&fresh).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &fresh, nil
|
||||
}
|
||||
|
||||
// markCertFailure updates only the failure-related columns. It explicitly
|
||||
// avoids touching SSLCertificatePath / SSLCertificateKeyPath / Resource so
|
||||
// a renew failure does not destroy the previously-issued certificate.
|
||||
// Map-based Updates is safe here because neither column has a serializer tag.
|
||||
func markCertFailure(id uint64, lastError string) {
|
||||
db := model.UseDB()
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
if err := db.Model(&model.Cert{}).Where("id = ?", id).Updates(map[string]any{
|
||||
"status": model.CertStatusFailure,
|
||||
"last_error": lastError,
|
||||
}).Error; err != nil {
|
||||
logger.Errorf("markCertFailure: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// markCertSuccess updates the cert with the freshly-issued paths and Resource,
|
||||
// flips status to success, and clears any prior last_error. Uses struct + Select
|
||||
// so GORM applies the `serializer:json[aes]` tag for Resource AND writes the
|
||||
// zero-valued LastError ("").
|
||||
func markCertSuccess(id uint64, sslCertificatePath, sslCertificateKeyPath string, resource *model.CertificateResource) {
|
||||
db := model.UseDB()
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
updates := &model.Cert{
|
||||
SSLCertificatePath: sslCertificatePath,
|
||||
SSLCertificateKeyPath: sslCertificateKeyPath,
|
||||
Resource: resource,
|
||||
Status: model.CertStatusSuccess,
|
||||
LastError: "",
|
||||
}
|
||||
cols := []string{"ssl_certificate_path", "ssl_certificate_key_path", "status", "last_error"}
|
||||
if resource != nil {
|
||||
cols = append(cols, "resource")
|
||||
}
|
||||
if err := db.Model(&model.Cert{}).Where("id = ?", id).
|
||||
Select(cols).Updates(updates).Error; err != nil {
|
||||
logger.Errorf("markCertSuccess: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// shortError trims and truncates an error for UI display in last_error.
|
||||
// Returns "" for nil so a successful retry can clear the prior error.
|
||||
func shortError(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
msg := strings.TrimSpace(err.Error())
|
||||
const max = 500
|
||||
if len(msg) > max {
|
||||
msg = msg[:max] + "…"
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
@@ -0,0 +1,156 @@
|
||||
package certificate
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/0xJacky/Nginx-UI/internal/cert"
|
||||
"github.com/0xJacky/Nginx-UI/model"
|
||||
"github.com/go-acme/lego/v5/certcrypto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func setupCertTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
// Use a per-test private in-memory DB. The literal ":memory:" (no shared cache)
|
||||
// gives each gorm.Open a fresh isolated database, preventing cross-test pollution.
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.AutoMigrate(&model.Cert{}))
|
||||
model.Use(db)
|
||||
t.Cleanup(func() { model.Use(nil) })
|
||||
return db
|
||||
}
|
||||
|
||||
func TestPersistCertDraftCreatesPendingRecord(t *testing.T) {
|
||||
db := setupCertTestDB(t)
|
||||
|
||||
payload := &cert.ConfigPayload{
|
||||
ServerName: []string{"example.com", "*.example.com"},
|
||||
ChallengeMethod: "dns01",
|
||||
DNSCredentialID: 42,
|
||||
ACMEUserID: 7,
|
||||
KeyType: certcrypto.RSA2048,
|
||||
MustStaple: true,
|
||||
LegoDisableCNAMESupport: true,
|
||||
RevokeOld: true,
|
||||
}
|
||||
|
||||
got, err := persistCertDraft("example.com", payload)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, got.ID)
|
||||
assert.Equal(t, model.CertStatusPending, got.Status)
|
||||
assert.Equal(t, "", got.LastError)
|
||||
assert.NotNil(t, got.LastAttemptAt)
|
||||
assert.WithinDuration(t, time.Now(), *got.LastAttemptAt, 5*time.Second)
|
||||
|
||||
var fromDB model.Cert
|
||||
require.NoError(t, db.First(&fromDB, got.ID).Error)
|
||||
assert.Equal(t, []string{"example.com", "*.example.com"}, fromDB.Domains)
|
||||
assert.Equal(t, "dns01", fromDB.ChallengeMethod)
|
||||
assert.Equal(t, uint64(42), fromDB.DnsCredentialID)
|
||||
assert.Equal(t, uint64(7), fromDB.ACMEUserID)
|
||||
assert.True(t, fromDB.MustStaple)
|
||||
assert.True(t, fromDB.LegoDisableCNAMESupport)
|
||||
assert.True(t, fromDB.RevokeOld)
|
||||
assert.Equal(t, model.AutoCertEnabled, fromDB.AutoCert)
|
||||
}
|
||||
|
||||
func TestPersistCertDraftReusesExistingRow(t *testing.T) {
|
||||
db := setupCertTestDB(t)
|
||||
existing := model.Cert{
|
||||
Name: "example.com",
|
||||
Filename: "example.com",
|
||||
KeyType: certcrypto.RSA2048,
|
||||
Status: model.CertStatusFailure,
|
||||
LastError: "prior failure",
|
||||
}
|
||||
require.NoError(t, db.Create(&existing).Error)
|
||||
|
||||
payload := &cert.ConfigPayload{
|
||||
ServerName: []string{"example.com"},
|
||||
ChallengeMethod: "http01",
|
||||
KeyType: certcrypto.RSA2048,
|
||||
}
|
||||
|
||||
got, err := persistCertDraft("example.com", payload)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, existing.ID, got.ID)
|
||||
assert.Equal(t, model.CertStatusPending, got.Status)
|
||||
assert.Equal(t, "", got.LastError)
|
||||
|
||||
var count int64
|
||||
require.NoError(t, db.Model(&model.Cert{}).Where("name = ?", "example.com").Count(&count).Error)
|
||||
assert.Equal(t, int64(1), count, "should reuse, not duplicate")
|
||||
}
|
||||
|
||||
func TestMarkCertFailureSetsStatusAndError(t *testing.T) {
|
||||
db := setupCertTestDB(t)
|
||||
c := model.Cert{Name: "example.com", Filename: "example.com", Status: model.CertStatusPending}
|
||||
require.NoError(t, db.Create(&c).Error)
|
||||
|
||||
markCertFailure(c.ID, "DNS challenge timed out after 60s")
|
||||
|
||||
var got model.Cert
|
||||
require.NoError(t, db.First(&got, c.ID).Error)
|
||||
assert.Equal(t, model.CertStatusFailure, got.Status)
|
||||
assert.Equal(t, "DNS challenge timed out after 60s", got.LastError)
|
||||
}
|
||||
|
||||
func TestMarkCertFailureDoesNotClobberResourceOrPaths(t *testing.T) {
|
||||
db := setupCertTestDB(t)
|
||||
c := model.Cert{
|
||||
Name: "example.com",
|
||||
Filename: "example.com",
|
||||
Status: model.CertStatusPending,
|
||||
SSLCertificatePath: "/etc/nginx/ssl/example.com/fullchain.cer",
|
||||
SSLCertificateKeyPath: "/etc/nginx/ssl/example.com/private.key",
|
||||
}
|
||||
require.NoError(t, db.Create(&c).Error)
|
||||
|
||||
markCertFailure(c.ID, "renewal failed")
|
||||
|
||||
var got model.Cert
|
||||
require.NoError(t, db.First(&got, c.ID).Error)
|
||||
assert.Equal(t, "/etc/nginx/ssl/example.com/fullchain.cer", got.SSLCertificatePath, "must not erase paths")
|
||||
assert.Equal(t, "/etc/nginx/ssl/example.com/private.key", got.SSLCertificateKeyPath)
|
||||
}
|
||||
|
||||
func TestMarkCertSuccessClearsLastError(t *testing.T) {
|
||||
db := setupCertTestDB(t)
|
||||
c := model.Cert{
|
||||
Name: "example.com",
|
||||
Filename: "example.com",
|
||||
Status: model.CertStatusPending,
|
||||
LastError: "stale error",
|
||||
}
|
||||
require.NoError(t, db.Create(&c).Error)
|
||||
|
||||
markCertSuccess(c.ID, "/etc/nginx/ssl/example.com/fullchain.cer", "/etc/nginx/ssl/example.com/private.key", nil)
|
||||
|
||||
var got model.Cert
|
||||
require.NoError(t, db.First(&got, c.ID).Error)
|
||||
assert.Equal(t, model.CertStatusSuccess, got.Status)
|
||||
assert.Equal(t, "", got.LastError)
|
||||
assert.Equal(t, "/etc/nginx/ssl/example.com/fullchain.cer", got.SSLCertificatePath)
|
||||
assert.Equal(t, "/etc/nginx/ssl/example.com/private.key", got.SSLCertificateKeyPath)
|
||||
}
|
||||
|
||||
func TestShortError(t *testing.T) {
|
||||
assert.Equal(t, "", shortError(nil))
|
||||
assert.Equal(t, "hello", shortError(errString(" hello ")))
|
||||
long := make([]byte, 600)
|
||||
for i := range long {
|
||||
long[i] = 'a'
|
||||
}
|
||||
got := shortError(errString(string(long)))
|
||||
assert.Equal(t, 500+len("…"), len(got))
|
||||
assert.Equal(t, "…", got[len(got)-len("…"):])
|
||||
}
|
||||
|
||||
type errString string
|
||||
|
||||
func (e errString) Error() string { return string(e) }
|
||||
Reference in New Issue
Block a user