feat(cert): persist draft on issuance entry, status transitions on completion

This commit is contained in:
Hintay
2026-05-23 05:03:17 +09:00
parent 43d83f49a9
commit 8acffbf078
2 changed files with 312 additions and 46 deletions
+156 -46
View File
@@ -1,17 +1,18 @@
package certificate
import (
"strings"
"time"
"github.com/0xJacky/Nginx-UI/internal/cert"
"github.com/0xJacky/Nginx-UI/internal/helper"
"github.com/0xJacky/Nginx-UI/internal/middleware"
"github.com/0xJacky/Nginx-UI/internal/translation"
"github.com/0xJacky/Nginx-UI/model"
"github.com/0xJacky/Nginx-UI/query"
"github.com/gin-gonic/gin"
"github.com/go-acme/lego/v5/certcrypto"
"github.com/gorilla/websocket"
"github.com/uozi-tech/cosy/logger"
"gorm.io/gen/field"
)
const (
@@ -34,35 +35,48 @@ func IssueCert(c *gin.Context) {
CheckOrigin: middleware.CheckWebSocketOrigin,
}
// upgrade http to websocket
ws, err := upGrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
return
}
defer ws.Close()
wsWriter := helper.NewSafeWebSocketWriter(ws)
// read
payload := &cert.ConfigPayload{}
err = ws.ReadJSON(payload)
if err != nil {
if err := ws.ReadJSON(payload); err != nil {
logger.Error(err)
return
}
payload.KeyType = payload.GetKeyType()
certModel, err := model.FirstOrInit(name, payload.GetKeyType())
certModel, err := persistCertDraft(name, payload)
if err != nil {
logger.Error(err)
_ = wsWriter.WriteJSON(IssueCertResponse{Status: Error, Message: err.Error()})
return
}
payload.CertID = certModel.ID
// Defer guard: if the function returns while still pending (panic / unexpected path),
// the record would otherwise be orphaned. Convert to failure with a generic message.
defer func() {
var current model.Cert
db := model.UseDB()
if db == nil {
return
}
if e := db.Where("id = ?", certModel.ID).First(&current).Error; e != nil {
return
}
if current.Status == model.CertStatusPending {
markCertFailure(certModel.ID, "Issuance interrupted before completion.")
}
}()
// Hydrate payload.Resource from the existing cert (for renewal path).
if certModel.SSLCertificatePath != "" {
certInfo, _ := cert.GetCertInfo(certModel.SSLCertificatePath)
if certInfo != nil {
@@ -72,58 +86,154 @@ func IssueCert(c *gin.Context) {
}
log := cert.NewLogger()
log.SetCertModel(&certModel)
log.SetCertModel(certModel)
log.SetWebSocket(wsWriter)
defer log.Close()
err = cert.IssueCert(payload, log)
if err != nil {
if err := cert.IssueCert(payload, log); err != nil {
log.Error(err)
_ = wsWriter.WriteJSON(IssueCertResponse{
Status: Error,
Message: err.Error(),
})
markCertFailure(certModel.ID, shortError(err))
_ = wsWriter.WriteJSON(IssueCertResponse{Status: Error, Message: err.Error()})
return
}
cert := query.Cert
markCertSuccess(certModel.ID, payload.GetCertificatePath(), payload.GetCertificateKeyPath(), payload.Resource)
_, err = cert.Where(cert.Name.Eq(name), cert.Filename.Eq(name),
cert.KeyType.In(helper.GetKeyTypeAliasStrings(payload.KeyType)...)).
Assign(field.Attrs(&model.Cert{
KeyType: payload.KeyType,
Domains: payload.ServerName,
SSLCertificatePath: payload.GetCertificatePath(),
SSLCertificateKeyPath: payload.GetCertificateKeyPath(),
AutoCert: model.AutoCertEnabled,
ChallengeMethod: payload.ChallengeMethod,
DnsCredentialID: payload.DNSCredentialID,
ACMEUserID: payload.ACMEUserID,
Resource: payload.Resource,
MustStaple: payload.MustStaple,
LegoDisableCNAMESupport: payload.LegoDisableCNAMESupport,
Log: log.ToString(),
RevokeOld: payload.RevokeOld,
})).FirstOrCreate()
if err != nil {
logger.Error(err)
_ = wsWriter.WriteJSON(IssueCertResponse{
Status: Error,
Message: err.Error(),
})
return
}
err = wsWriter.WriteJSON(IssueCertResponse{
if err := wsWriter.WriteJSON(IssueCertResponse{
Status: Success,
Message: translation.C("[Nginx UI] Issued certificate successfully").ToString(),
SSLCertificate: payload.GetCertificatePath(),
SSLCertificateKey: payload.GetCertificateKeyPath(),
KeyType: payload.GetKeyType(),
})
if err != nil {
}); err != nil {
if helper.IsUnexpectedWebsocketError(err) {
logger.Error(err)
}
return
}
}
// persistCertDraft inserts or updates a Cert row representing an in-flight issuance.
// The row is keyed by (name, filename, key_type). All user-submitted config is captured
// up-front so a failure preserves enough state for a one-click retry.
func persistCertDraft(name string, payload *cert.ConfigPayload) (*model.Cert, error) {
db := model.UseDB()
normalizedKeyType := helper.GetKeyType(payload.GetKeyType())
keyTypeAliases := helper.GetKeyTypeAliasStrings(normalizedKeyType)
now := time.Now()
seed := &model.Cert{
Name: name,
Filename: name,
KeyType: normalizedKeyType,
Domains: payload.ServerName,
ChallengeMethod: payload.ChallengeMethod,
DnsCredentialID: payload.DNSCredentialID,
ACMEUserID: payload.ACMEUserID,
AutoCert: model.AutoCertEnabled,
MustStaple: payload.MustStaple,
LegoDisableCNAMESupport: payload.LegoDisableCNAMESupport,
RevokeOld: payload.RevokeOld,
Status: model.CertStatusPending,
LastError: "",
LastAttemptAt: &now,
}
// FirstOrCreate by (name, filename, key_type). When the row exists,
// `seed` is hydrated with the existing record (preserving SSLCertificatePath,
// Resource, etc.) so we can read those fields on the renewal path below.
if err := db.Where("name = ? AND filename = ? AND key_type IN ?", name, name, keyTypeAliases).
FirstOrCreate(seed).Error; err != nil {
return nil, err
}
// Refresh all user-submitted config and reset issuance state to pending.
// Use struct + Select so GORM applies the `serializer:json` tag for Domains
// AND writes the zero-valued LastError ("") instead of skipping it.
updates := &model.Cert{
Domains: payload.ServerName,
ChallengeMethod: payload.ChallengeMethod,
DnsCredentialID: payload.DNSCredentialID,
ACMEUserID: payload.ACMEUserID,
AutoCert: model.AutoCertEnabled,
MustStaple: payload.MustStaple,
LegoDisableCNAMESupport: payload.LegoDisableCNAMESupport,
RevokeOld: payload.RevokeOld,
Status: model.CertStatusPending,
LastError: "",
LastAttemptAt: &now,
}
if err := db.Model(&model.Cert{}).Where("id = ?", seed.ID).
Select(
"domains", "challenge_method", "dns_credential_id", "acme_user_id",
"auto_cert", "must_staple", "lego_disable_cname_support",
"revoke_old", "status", "last_error", "last_attempt_at",
).
Updates(updates).Error; err != nil {
return nil, err
}
// Re-read so the caller has the fully-populated struct (Resource, paths, etc.).
var fresh model.Cert
if err := db.Where("id = ?", seed.ID).First(&fresh).Error; err != nil {
return nil, err
}
return &fresh, nil
}
// markCertFailure updates only the failure-related columns. It explicitly
// avoids touching SSLCertificatePath / SSLCertificateKeyPath / Resource so
// a renew failure does not destroy the previously-issued certificate.
// Map-based Updates is safe here because neither column has a serializer tag.
func markCertFailure(id uint64, lastError string) {
db := model.UseDB()
if db == nil {
return
}
if err := db.Model(&model.Cert{}).Where("id = ?", id).Updates(map[string]any{
"status": model.CertStatusFailure,
"last_error": lastError,
}).Error; err != nil {
logger.Errorf("markCertFailure: %v", err)
}
}
// markCertSuccess updates the cert with the freshly-issued paths and Resource,
// flips status to success, and clears any prior last_error. Uses struct + Select
// so GORM applies the `serializer:json[aes]` tag for Resource AND writes the
// zero-valued LastError ("").
func markCertSuccess(id uint64, sslCertificatePath, sslCertificateKeyPath string, resource *model.CertificateResource) {
db := model.UseDB()
if db == nil {
return
}
updates := &model.Cert{
SSLCertificatePath: sslCertificatePath,
SSLCertificateKeyPath: sslCertificateKeyPath,
Resource: resource,
Status: model.CertStatusSuccess,
LastError: "",
}
cols := []string{"ssl_certificate_path", "ssl_certificate_key_path", "status", "last_error"}
if resource != nil {
cols = append(cols, "resource")
}
if err := db.Model(&model.Cert{}).Where("id = ?", id).
Select(cols).Updates(updates).Error; err != nil {
logger.Errorf("markCertSuccess: %v", err)
}
}
// shortError trims and truncates an error for UI display in last_error.
// Returns "" for nil so a successful retry can clear the prior error.
func shortError(err error) string {
if err == nil {
return ""
}
msg := strings.TrimSpace(err.Error())
const max = 500
if len(msg) > max {
msg = msg[:max] + "…"
}
return msg
}
+156
View File
@@ -0,0 +1,156 @@
package certificate
import (
"testing"
"time"
"github.com/0xJacky/Nginx-UI/internal/cert"
"github.com/0xJacky/Nginx-UI/model"
"github.com/go-acme/lego/v5/certcrypto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
func setupCertTestDB(t *testing.T) *gorm.DB {
t.Helper()
// Use a per-test private in-memory DB. The literal ":memory:" (no shared cache)
// gives each gorm.Open a fresh isolated database, preventing cross-test pollution.
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
require.NoError(t, err)
require.NoError(t, db.AutoMigrate(&model.Cert{}))
model.Use(db)
t.Cleanup(func() { model.Use(nil) })
return db
}
func TestPersistCertDraftCreatesPendingRecord(t *testing.T) {
db := setupCertTestDB(t)
payload := &cert.ConfigPayload{
ServerName: []string{"example.com", "*.example.com"},
ChallengeMethod: "dns01",
DNSCredentialID: 42,
ACMEUserID: 7,
KeyType: certcrypto.RSA2048,
MustStaple: true,
LegoDisableCNAMESupport: true,
RevokeOld: true,
}
got, err := persistCertDraft("example.com", payload)
require.NoError(t, err)
assert.NotZero(t, got.ID)
assert.Equal(t, model.CertStatusPending, got.Status)
assert.Equal(t, "", got.LastError)
assert.NotNil(t, got.LastAttemptAt)
assert.WithinDuration(t, time.Now(), *got.LastAttemptAt, 5*time.Second)
var fromDB model.Cert
require.NoError(t, db.First(&fromDB, got.ID).Error)
assert.Equal(t, []string{"example.com", "*.example.com"}, fromDB.Domains)
assert.Equal(t, "dns01", fromDB.ChallengeMethod)
assert.Equal(t, uint64(42), fromDB.DnsCredentialID)
assert.Equal(t, uint64(7), fromDB.ACMEUserID)
assert.True(t, fromDB.MustStaple)
assert.True(t, fromDB.LegoDisableCNAMESupport)
assert.True(t, fromDB.RevokeOld)
assert.Equal(t, model.AutoCertEnabled, fromDB.AutoCert)
}
func TestPersistCertDraftReusesExistingRow(t *testing.T) {
db := setupCertTestDB(t)
existing := model.Cert{
Name: "example.com",
Filename: "example.com",
KeyType: certcrypto.RSA2048,
Status: model.CertStatusFailure,
LastError: "prior failure",
}
require.NoError(t, db.Create(&existing).Error)
payload := &cert.ConfigPayload{
ServerName: []string{"example.com"},
ChallengeMethod: "http01",
KeyType: certcrypto.RSA2048,
}
got, err := persistCertDraft("example.com", payload)
require.NoError(t, err)
assert.Equal(t, existing.ID, got.ID)
assert.Equal(t, model.CertStatusPending, got.Status)
assert.Equal(t, "", got.LastError)
var count int64
require.NoError(t, db.Model(&model.Cert{}).Where("name = ?", "example.com").Count(&count).Error)
assert.Equal(t, int64(1), count, "should reuse, not duplicate")
}
func TestMarkCertFailureSetsStatusAndError(t *testing.T) {
db := setupCertTestDB(t)
c := model.Cert{Name: "example.com", Filename: "example.com", Status: model.CertStatusPending}
require.NoError(t, db.Create(&c).Error)
markCertFailure(c.ID, "DNS challenge timed out after 60s")
var got model.Cert
require.NoError(t, db.First(&got, c.ID).Error)
assert.Equal(t, model.CertStatusFailure, got.Status)
assert.Equal(t, "DNS challenge timed out after 60s", got.LastError)
}
func TestMarkCertFailureDoesNotClobberResourceOrPaths(t *testing.T) {
db := setupCertTestDB(t)
c := model.Cert{
Name: "example.com",
Filename: "example.com",
Status: model.CertStatusPending,
SSLCertificatePath: "/etc/nginx/ssl/example.com/fullchain.cer",
SSLCertificateKeyPath: "/etc/nginx/ssl/example.com/private.key",
}
require.NoError(t, db.Create(&c).Error)
markCertFailure(c.ID, "renewal failed")
var got model.Cert
require.NoError(t, db.First(&got, c.ID).Error)
assert.Equal(t, "/etc/nginx/ssl/example.com/fullchain.cer", got.SSLCertificatePath, "must not erase paths")
assert.Equal(t, "/etc/nginx/ssl/example.com/private.key", got.SSLCertificateKeyPath)
}
func TestMarkCertSuccessClearsLastError(t *testing.T) {
db := setupCertTestDB(t)
c := model.Cert{
Name: "example.com",
Filename: "example.com",
Status: model.CertStatusPending,
LastError: "stale error",
}
require.NoError(t, db.Create(&c).Error)
markCertSuccess(c.ID, "/etc/nginx/ssl/example.com/fullchain.cer", "/etc/nginx/ssl/example.com/private.key", nil)
var got model.Cert
require.NoError(t, db.First(&got, c.ID).Error)
assert.Equal(t, model.CertStatusSuccess, got.Status)
assert.Equal(t, "", got.LastError)
assert.Equal(t, "/etc/nginx/ssl/example.com/fullchain.cer", got.SSLCertificatePath)
assert.Equal(t, "/etc/nginx/ssl/example.com/private.key", got.SSLCertificateKeyPath)
}
func TestShortError(t *testing.T) {
assert.Equal(t, "", shortError(nil))
assert.Equal(t, "hello", shortError(errString(" hello ")))
long := make([]byte, 600)
for i := range long {
long[i] = 'a'
}
got := shortError(errString(string(long)))
assert.Equal(t, 500+len("…"), len(got))
assert.Equal(t, "…", got[len(got)-len("…"):])
}
type errString string
func (e errString) Error() string { return string(e) }