♻️ Optimize code
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-git/go-git/v5"
|
||||
@@ -36,6 +37,8 @@ type BackupService struct {
|
||||
isInitialized bool
|
||||
autoBackupTicker *time.Ticker
|
||||
autoBackupStop chan bool
|
||||
autoBackupWg sync.WaitGroup // 等待自动备份goroutine完成
|
||||
mu sync.Mutex // 推送操作互斥锁
|
||||
|
||||
// 配置观察者取消函数
|
||||
cancelObserver CancelFunc
|
||||
@@ -86,6 +89,11 @@ func (s *BackupService) Initialize() error {
|
||||
return fmt.Errorf("initializing repository: %w", err)
|
||||
}
|
||||
|
||||
// 验证远程仓库连接
|
||||
if err := s.verifyRemoteConnection(config); err != nil {
|
||||
return fmt.Errorf("verifying remote connection: %w", err)
|
||||
}
|
||||
|
||||
// 启动自动备份
|
||||
if config.AutoBackup && config.BackupInterval > 0 {
|
||||
s.StartAutoBackup()
|
||||
@@ -161,6 +169,22 @@ func (s *BackupService) initializeRepository(config *models.GitBackupConfig, rep
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyRemoteConnection 验证远程仓库连接
|
||||
func (s *BackupService) verifyRemoteConnection(config *models.GitBackupConfig) error {
|
||||
auth, err := s.getAuthMethod(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
remote, err := s.repository.Remote("origin")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = remote.List(&git.ListOptions{Auth: auth})
|
||||
return err
|
||||
}
|
||||
|
||||
// getAuthMethod 根据配置获取认证方法
|
||||
func (s *BackupService) getAuthMethod(config *models.GitBackupConfig) (transport.AuthMethod, error) {
|
||||
switch config.AuthMethod {
|
||||
@@ -203,31 +227,15 @@ func (s *BackupService) serializeDatabase(repoPath string) error {
|
||||
return errors.New("database service not available")
|
||||
}
|
||||
|
||||
// 获取数据库路径
|
||||
dbPath, err := s.dbService.getDatabasePath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting database path: %w", err)
|
||||
}
|
||||
|
||||
// 关闭数据库连接以确保所有更改都写入磁盘
|
||||
if err := s.dbService.ServiceShutdown(); err != nil {
|
||||
s.logger.Error("Failed to close database connection", "error", err)
|
||||
}
|
||||
|
||||
// 直接复制数据库文件到序列化文件
|
||||
dbData, err := os.ReadFile(dbPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading database file: %w", err)
|
||||
}
|
||||
|
||||
binFilePath := filepath.Join(repoPath, dbSerializeFile)
|
||||
if err := os.WriteFile(binFilePath, dbData, 0644); err != nil {
|
||||
return fmt.Errorf("writing serialized database to file: %w", err)
|
||||
}
|
||||
|
||||
// 重新初始化数据库服务
|
||||
if err := s.dbService.initDatabase(); err != nil {
|
||||
return fmt.Errorf("reinitializing database: %w", err)
|
||||
// 使用 VACUUM INTO 创建数据库副本,不影响现有连接
|
||||
s.dbService.mu.RLock()
|
||||
_, err := s.dbService.db.Exec(fmt.Sprintf("VACUUM INTO '%s'", binFilePath))
|
||||
s.dbService.mu.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating database backup: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -235,6 +243,10 @@ func (s *BackupService) serializeDatabase(repoPath string) error {
|
||||
|
||||
// PushToRemote 推送本地更改到远程仓库
|
||||
func (s *BackupService) PushToRemote() error {
|
||||
// 互斥锁防止并发推送
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.isInitialized {
|
||||
return errors.New("backup service not initialized")
|
||||
}
|
||||
@@ -248,56 +260,62 @@ func (s *BackupService) PushToRemote() error {
|
||||
return errors.New("backup is disabled")
|
||||
}
|
||||
|
||||
// 数据库序列化文件的路径
|
||||
// 检查是否有未推送的commit
|
||||
hasUnpushed, err := s.hasUnpushedCommits()
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking unpushed commits: %w", err)
|
||||
}
|
||||
|
||||
binFilePath := filepath.Join(repoPath, dbSerializeFile)
|
||||
|
||||
// 函数返回前都删除临时文件
|
||||
defer func() {
|
||||
if _, err := os.Stat(binFilePath); err == nil {
|
||||
os.Remove(binFilePath)
|
||||
// 只有在没有未推送commit时才创建新commit
|
||||
if !hasUnpushed {
|
||||
// 序列化数据库
|
||||
if err := s.serializeDatabase(repoPath); err != nil {
|
||||
return fmt.Errorf("serializing database: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 序列化数据库
|
||||
if err := s.serializeDatabase(repoPath); err != nil {
|
||||
return fmt.Errorf("serializing database: %w", err)
|
||||
}
|
||||
// 获取工作树
|
||||
w, err := s.repository.Worktree()
|
||||
if err != nil {
|
||||
os.Remove(binFilePath)
|
||||
return fmt.Errorf("getting worktree: %w", err)
|
||||
}
|
||||
|
||||
// 获取工作树
|
||||
w, err := s.repository.Worktree()
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting worktree: %w", err)
|
||||
}
|
||||
// 添加序列化的数据库文件
|
||||
if _, err := w.Add(dbSerializeFile); err != nil {
|
||||
os.Remove(binFilePath)
|
||||
return fmt.Errorf("adding serialized database file: %w", err)
|
||||
}
|
||||
|
||||
// 添加序列化的数据库文件
|
||||
if _, err := w.Add(dbSerializeFile); err != nil {
|
||||
return fmt.Errorf("adding serialized database file: %w", err)
|
||||
}
|
||||
// 检查是否有变化需要提交
|
||||
status, err := w.Status()
|
||||
if err != nil {
|
||||
os.Remove(binFilePath)
|
||||
return fmt.Errorf("getting worktree status: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否有变化需要提交
|
||||
status, err := w.Status()
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting worktree status: %w", err)
|
||||
}
|
||||
|
||||
// 如果没有变化,直接返回
|
||||
if status.IsClean() {
|
||||
return errors.New("no changes to backup")
|
||||
}
|
||||
|
||||
// 创建提交
|
||||
_, err = w.Commit(fmt.Sprintf("Backup %s", time.Now().Format("2006-01-02 15:04:05")), &git.CommitOptions{
|
||||
Author: &object.Signature{
|
||||
Name: "voidraft",
|
||||
Email: "backup@voidraft.app",
|
||||
When: time.Now(),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "cannot create empty commit") {
|
||||
// 如果没有变化,删除文件并返回
|
||||
if status.IsClean() {
|
||||
os.Remove(binFilePath)
|
||||
return errors.New("no changes to backup")
|
||||
}
|
||||
return fmt.Errorf("creating commit: %w", err)
|
||||
|
||||
// 创建提交
|
||||
_, err = w.Commit(fmt.Sprintf("Backup %s", time.Now().Format("2006-01-02 15:04:05")), &git.CommitOptions{
|
||||
Author: &object.Signature{
|
||||
Name: "voidraft",
|
||||
Email: "backup@voidraft.app",
|
||||
When: time.Now(),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
os.Remove(binFilePath)
|
||||
if strings.Contains(err.Error(), "cannot create empty commit") {
|
||||
return errors.New("no changes to backup")
|
||||
}
|
||||
return fmt.Errorf("creating commit: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取认证方法并推送到远程
|
||||
@@ -306,25 +324,57 @@ func (s *BackupService) PushToRemote() error {
|
||||
return fmt.Errorf("getting auth method: %w", err)
|
||||
}
|
||||
|
||||
// 推送到远程仓库
|
||||
// 推送到远程仓库(包括之前失败的commit)
|
||||
if err := s.repository.Push(&git.PushOptions{
|
||||
RemoteName: "origin",
|
||||
Auth: auth,
|
||||
}); err != nil && !errors.Is(err, git.NoErrAlreadyUpToDate) {
|
||||
// 忽略一些常见的非错误情况
|
||||
if strings.Contains(err.Error(), "clean working tree") ||
|
||||
strings.Contains(err.Error(), "already up-to-date") ||
|
||||
strings.Contains(err.Error(), " clean working tree") ||
|
||||
strings.Contains(err.Error(), "reference not found") {
|
||||
// 更新最后推送时间
|
||||
return errors.New("no changes to backup")
|
||||
}
|
||||
return fmt.Errorf("push failed: %w", err)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 只在推送成功后删除临时文件
|
||||
os.Remove(binFilePath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasUnpushedCommits 检查是否有未推送的commit
|
||||
func (s *BackupService) hasUnpushedCommits() (bool, error) {
|
||||
localRef, err := s.repository.Head()
|
||||
if err != nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
config, _, err := s.getConfigAndPath()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
auth, err := s.getAuthMethod(config)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
remote, err := s.repository.Remote("origin")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
refs, err := remote.List(&git.ListOptions{Auth: auth})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
localHash := localRef.Hash()
|
||||
|
||||
for _, ref := range refs {
|
||||
if ref.Name() == localRef.Name() {
|
||||
return localHash != ref.Hash(), nil
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// StartAutoBackup 启动自动备份定时器
|
||||
func (s *BackupService) StartAutoBackup() error {
|
||||
config, _, err := s.getConfigAndPath()
|
||||
@@ -342,14 +392,13 @@ func (s *BackupService) StartAutoBackup() error {
|
||||
s.autoBackupTicker = time.NewTicker(time.Duration(config.BackupInterval) * time.Minute)
|
||||
s.autoBackupStop = make(chan bool)
|
||||
|
||||
s.autoBackupWg.Add(1)
|
||||
go func() {
|
||||
defer s.autoBackupWg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-s.autoBackupTicker.C:
|
||||
// 执行推送操作
|
||||
if err := s.PushToRemote(); err != nil {
|
||||
s.logger.Error("Auto backup failed", "error", err)
|
||||
}
|
||||
s.PushToRemote()
|
||||
case <-s.autoBackupStop:
|
||||
return
|
||||
}
|
||||
@@ -369,16 +418,18 @@ func (s *BackupService) StopAutoBackup() {
|
||||
if s.autoBackupStop != nil {
|
||||
close(s.autoBackupStop)
|
||||
s.autoBackupStop = nil
|
||||
s.autoBackupWg.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
// Reinitialize 重新初始化备份服务,用于响应配置变更
|
||||
func (s *BackupService) Reinitialize() error {
|
||||
// 停止自动备份
|
||||
// 先停止自动备份,等待goroutine完成
|
||||
s.StopAutoBackup()
|
||||
|
||||
// 重新设置标志
|
||||
s.mu.Lock()
|
||||
s.isInitialized = false
|
||||
s.mu.Unlock()
|
||||
|
||||
// 重新初始化
|
||||
return s.Initialize()
|
||||
@@ -386,11 +437,12 @@ func (s *BackupService) Reinitialize() error {
|
||||
|
||||
// HandleConfigChange 处理备份配置变更
|
||||
func (s *BackupService) HandleConfigChange(config *models.GitBackupConfig) error {
|
||||
|
||||
// 如果备份功能禁用,只需停止自动备份
|
||||
if !config.Enabled {
|
||||
s.StopAutoBackup()
|
||||
s.mu.Lock()
|
||||
s.isInitialized = false
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -11,9 +11,6 @@ import (
|
||||
)
|
||||
|
||||
// ObserverCallback 观察者回调函数
|
||||
// 参数:
|
||||
// - oldValue: 配置变更前的值
|
||||
// - newValue: 配置变更后的值
|
||||
type ObserverCallback func(oldValue, newValue interface{})
|
||||
|
||||
// CancelFunc 取消订阅函数
|
||||
@@ -28,7 +25,6 @@ type observer struct {
|
||||
}
|
||||
|
||||
// ConfigObserver 配置观察者系统
|
||||
// 提供轻量级的配置变更监听机制
|
||||
type ConfigObserver struct {
|
||||
observers map[string][]*observer // 路径 -> 观察者列表
|
||||
observerMu sync.RWMutex // 观察者锁
|
||||
@@ -53,19 +49,6 @@ func NewConfigObserver(logger *log.LogService) *ConfigObserver {
|
||||
}
|
||||
|
||||
// Watch 注册配置变更监听器
|
||||
// 参数:
|
||||
// - path: 配置路径,如 "general.enableGlobalHotkey"
|
||||
// - callback: 变更回调函数,接收旧值和新值
|
||||
//
|
||||
// 返回:
|
||||
// - CancelFunc: 取消监听的函数,务必在不需要时调用以避免内存泄漏
|
||||
//
|
||||
// 示例:
|
||||
//
|
||||
// cancel := observer.Watch("general.hotkey", func(old, new interface{}) {
|
||||
// fmt.Printf("配置从 %v 变更为 %v\n", old, new)
|
||||
// })
|
||||
// defer cancel() // 确保清理
|
||||
func (co *ConfigObserver) Watch(path string, callback ObserverCallback) CancelFunc {
|
||||
// 生成唯一ID
|
||||
id := fmt.Sprintf("obs_%d", co.nextObserverID.Add(1))
|
||||
@@ -88,17 +71,6 @@ func (co *ConfigObserver) Watch(path string, callback ObserverCallback) CancelFu
|
||||
}
|
||||
|
||||
// WatchWithContext 使用 Context 注册监听器,Context 取消时自动清理
|
||||
// 参数:
|
||||
// - ctx: Context,取消时自动移除观察者
|
||||
// - path: 配置路径
|
||||
// - callback: 变更回调函数
|
||||
//
|
||||
// 示例:
|
||||
//
|
||||
// ctx, cancel := context.WithCancel(context.Background())
|
||||
// defer cancel()
|
||||
// observer.WatchWithContext(ctx, "general.hotkey", callback)
|
||||
// // Context 取消时自动清理
|
||||
func (co *ConfigObserver) WatchWithContext(ctx context.Context, path string, callback ObserverCallback) {
|
||||
cancel := co.Watch(path, callback)
|
||||
go func() {
|
||||
@@ -132,12 +104,6 @@ func (co *ConfigObserver) removeObserver(path, id string) {
|
||||
}
|
||||
|
||||
// Notify 通知指定路径的所有观察者
|
||||
// 参数:
|
||||
// - path: 配置路径
|
||||
// - oldValue: 旧值
|
||||
// - newValue: 新值
|
||||
//
|
||||
// 注意:此方法会在独立的 goroutine 中异步执行回调,不会阻塞调用者
|
||||
func (co *ConfigObserver) Notify(path string, oldValue, newValue interface{}) {
|
||||
// 获取该路径的所有观察者(拷贝以避免并发问题)
|
||||
co.observerMu.RLock()
|
||||
@@ -222,7 +188,6 @@ func (co *ConfigObserver) Clear() {
|
||||
}
|
||||
|
||||
// Shutdown 关闭观察者系统
|
||||
// 等待所有正在执行的回调完成
|
||||
func (co *ConfigObserver) Shutdown() {
|
||||
// 取消 context
|
||||
co.cancel()
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/wailsapp/wails/v3/pkg/application"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -45,25 +46,29 @@ func NewConfigService(logger *log.LogService) *ConfigService {
|
||||
configDir := filepath.Join(homeDir, ".voidraft", "config")
|
||||
settingsPath := filepath.Join(configDir, "settings.json")
|
||||
|
||||
cs := &ConfigService{
|
||||
logger: logger,
|
||||
configDir: configDir,
|
||||
settingsPath: settingsPath,
|
||||
koanf: koanf.New("."),
|
||||
observerService := NewConfigObserver(logger)
|
||||
|
||||
configMigrator := NewConfigMigrator(logger, configDir, "settings", settingsPath)
|
||||
|
||||
return &ConfigService{
|
||||
logger: logger,
|
||||
configDir: configDir,
|
||||
settingsPath: settingsPath,
|
||||
koanf: koanf.New("."),
|
||||
observer: observerService,
|
||||
configMigrator: configMigrator,
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化配置观察者系统
|
||||
cs.observer = NewConfigObserver(logger)
|
||||
|
||||
// 初始化配置迁移器
|
||||
cs.configMigrator = NewConfigMigrator(logger, configDir, "settings", settingsPath)
|
||||
|
||||
cs.initConfig()
|
||||
|
||||
// ServiceStartup initializes the service when the application starts
|
||||
func (cs *ConfigService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
|
||||
err := cs.initConfig()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// 启动配置文件监听
|
||||
cs.startWatching()
|
||||
|
||||
return cs
|
||||
return nil
|
||||
}
|
||||
|
||||
// setDefaults 设置默认配置
|
||||
@@ -103,19 +108,10 @@ func (cs *ConfigService) MigrateConfig() error {
|
||||
}
|
||||
|
||||
defaultConfig := models.NewDefaultAppConfig()
|
||||
result, err := cs.configMigrator.AutoMigrate(defaultConfig, cs.koanf)
|
||||
|
||||
_, err := cs.configMigrator.AutoMigrate(defaultConfig, cs.koanf)
|
||||
if err != nil {
|
||||
cs.logger.Error("Failed to check config migration", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if result != nil && result.Migrated {
|
||||
cs.logger.Info("Config migration performed",
|
||||
"fields", result.MissingFields,
|
||||
"backup", result.BackupPath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -156,10 +152,11 @@ func (cs *ConfigService) startWatching() {
|
||||
cs.mu.Lock()
|
||||
oldSnapshot := cs.createConfigSnapshot()
|
||||
cs.koanf.Load(cs.fileProvider, jsonparser.Parser())
|
||||
newSnapshot := cs.createConfigSnapshot()
|
||||
cs.mu.Unlock()
|
||||
|
||||
// 检测配置变更并通知观察者
|
||||
cs.detectAndNotifyChanges(oldSnapshot)
|
||||
cs.notifyChanges(oldSnapshot, newSnapshot)
|
||||
})
|
||||
|
||||
}
|
||||
@@ -188,23 +185,32 @@ func (cs *ConfigService) GetConfig() (*models.AppConfig, error) {
|
||||
func (cs *ConfigService) Set(key string, value interface{}) error {
|
||||
cs.mu.Lock()
|
||||
|
||||
// 获取旧值
|
||||
// 获取旧值用于回滚
|
||||
oldValue := cs.koanf.Get(key)
|
||||
|
||||
// 设置值到koanf
|
||||
cs.koanf.Set(key, value)
|
||||
|
||||
// 更新时间戳
|
||||
cs.koanf.Set("metadata.lastUpdated", time.Now().Format(time.RFC3339))
|
||||
newTimestamp := time.Now().Format(time.RFC3339)
|
||||
cs.koanf.Set("metadata.lastUpdated", newTimestamp)
|
||||
|
||||
// 将配置写回文件
|
||||
err := cs.writeConfigToFile()
|
||||
cs.mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
// 写文件失败,回滚内存状态
|
||||
if oldValue != nil {
|
||||
cs.koanf.Set(key, oldValue)
|
||||
} else {
|
||||
cs.koanf.Delete(key)
|
||||
}
|
||||
cs.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
cs.mu.Unlock()
|
||||
|
||||
if cs.observer != nil {
|
||||
cs.observer.Notify(key, oldValue, value)
|
||||
}
|
||||
@@ -262,13 +268,14 @@ func (cs *ConfigService) ResetConfig() error {
|
||||
return err
|
||||
}
|
||||
|
||||
newSnapshot := cs.createConfigSnapshot()
|
||||
cs.mu.Unlock()
|
||||
|
||||
// 重新启动文件监听
|
||||
cs.startWatching()
|
||||
|
||||
// 检测配置变更并通知观察者
|
||||
cs.detectAndNotifyChanges(oldSnapshot)
|
||||
cs.notifyChanges(oldSnapshot, newSnapshot)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -297,14 +304,10 @@ func (cs *ConfigService) WatchWithContext(ctx context.Context, path string, call
|
||||
cs.observer.WatchWithContext(ctx, path, callback)
|
||||
}
|
||||
|
||||
// createConfigSnapshot 创建当前配置的快照
|
||||
// createConfigSnapshot 创建当前配置的快照(调用者需确保已持有锁)
|
||||
func (cs *ConfigService) createConfigSnapshot() map[string]interface{} {
|
||||
cs.mu.RLock()
|
||||
defer cs.mu.RUnlock()
|
||||
snapshot := make(map[string]interface{})
|
||||
allKeys := cs.koanf.All()
|
||||
|
||||
// 递归展平配置
|
||||
flattenMap("", allKeys, snapshot)
|
||||
return snapshot
|
||||
}
|
||||
@@ -331,11 +334,8 @@ func flattenMap(prefix string, data map[string]interface{}, result map[string]in
|
||||
}
|
||||
}
|
||||
|
||||
// detectAndNotifyChanges 检测配置变更并通知观察者
|
||||
func (cs *ConfigService) detectAndNotifyChanges(oldSnapshot map[string]interface{}) {
|
||||
// 创建新快照
|
||||
newSnapshot := cs.createConfigSnapshot()
|
||||
|
||||
// notifyChanges 检测配置变更并通知观察者
|
||||
func (cs *ConfigService) notifyChanges(oldSnapshot, newSnapshot map[string]interface{}) {
|
||||
// 检测变更
|
||||
changes := make(map[string]struct {
|
||||
OldValue interface{}
|
||||
|
||||
@@ -1,254 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/wailsapp/wails/v3/pkg/services/log"
|
||||
)
|
||||
|
||||
// StoreOption 存储服务配置选项
|
||||
type StoreOption struct {
|
||||
FilePath string
|
||||
AutoSave bool
|
||||
Logger *log.LogService
|
||||
}
|
||||
|
||||
// Store 泛型存储服务
|
||||
type Store[T any] struct {
|
||||
option StoreOption
|
||||
data atomic.Value // stores T
|
||||
dataMap sync.Map // thread-safe map
|
||||
unsaved atomic.Bool
|
||||
initOnce sync.Once
|
||||
logger *log.LogService
|
||||
}
|
||||
|
||||
// NewStore 存储服务
|
||||
func NewStore[T any](option StoreOption) *Store[T] {
|
||||
logger := option.Logger
|
||||
if logger == nil {
|
||||
logger = log.New()
|
||||
}
|
||||
|
||||
store := &Store[T]{
|
||||
option: option,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// 异步初始化
|
||||
store.initOnce.Do(func() {
|
||||
store.initialize()
|
||||
})
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
// initialize 初始化存储
|
||||
func (s *Store[T]) initialize() {
|
||||
// 确保目录存在
|
||||
if s.option.FilePath != "" {
|
||||
if err := os.MkdirAll(filepath.Dir(s.option.FilePath), 0755); err != nil {
|
||||
s.logger.Error("store: failed to create directory", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 加载数据
|
||||
s.load()
|
||||
}
|
||||
|
||||
// load 加载数据
|
||||
func (s *Store[T]) load() {
|
||||
if s.option.FilePath == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查文件是否存在
|
||||
if _, err := os.Stat(s.option.FilePath); os.IsNotExist(err) {
|
||||
return
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(s.option.FilePath)
|
||||
if err != nil {
|
||||
s.logger.Error("store: failed to read file", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var value T
|
||||
if err := json.Unmarshal(data, &value); err != nil {
|
||||
// 尝试加载为map格式
|
||||
var mapData map[string]any
|
||||
if err := json.Unmarshal(data, &mapData); err != nil {
|
||||
s.logger.Error("store: failed to parse data", "error", err)
|
||||
return
|
||||
}
|
||||
// 将map数据存储到sync.Map中
|
||||
for k, v := range mapData {
|
||||
s.dataMap.Store(k, v)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
s.data.Store(value)
|
||||
}
|
||||
|
||||
// Save 保存数据
|
||||
func (s *Store[T]) Save() error {
|
||||
if !s.unsaved.Load() {
|
||||
return nil // 没有未保存的更改
|
||||
}
|
||||
|
||||
if err := s.saveInternal(); err != nil {
|
||||
return fmt.Errorf("store: failed to save: %w", err)
|
||||
}
|
||||
|
||||
s.unsaved.Store(false)
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveInternal 内部保存实现
|
||||
func (s *Store[T]) saveInternal() error {
|
||||
if s.option.FilePath == "" {
|
||||
return fmt.Errorf("store: filepath not set")
|
||||
}
|
||||
|
||||
// 获取要保存的数据
|
||||
var data []byte
|
||||
var err error
|
||||
|
||||
if value := s.data.Load(); value != nil {
|
||||
data, err = json.MarshalIndent(value, "", " ")
|
||||
} else {
|
||||
// 如果没有结构化数据,保存map数据
|
||||
mapData := make(map[string]any)
|
||||
s.dataMap.Range(func(key, value any) bool {
|
||||
if k, ok := key.(string); ok {
|
||||
mapData[k] = value
|
||||
}
|
||||
return true
|
||||
})
|
||||
data, err = json.MarshalIndent(mapData, "", " ")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to serialize data: %w", err)
|
||||
}
|
||||
|
||||
// 原子写入
|
||||
return s.atomicWrite(data)
|
||||
}
|
||||
|
||||
// atomicWrite 原子写入文件
|
||||
func (s *Store[T]) atomicWrite(data []byte) error {
|
||||
dir := filepath.Dir(s.option.FilePath)
|
||||
|
||||
// 创建临时文件
|
||||
tempFile, err := os.CreateTemp(dir, "store-*.tmp")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temp file: %w", err)
|
||||
}
|
||||
|
||||
tempPath := tempFile.Name()
|
||||
defer func() {
|
||||
tempFile.Close()
|
||||
if err != nil {
|
||||
os.Remove(tempPath)
|
||||
}
|
||||
}()
|
||||
|
||||
// 写入数据并同步
|
||||
if _, err = tempFile.Write(data); err != nil {
|
||||
return fmt.Errorf("failed to write data: %w", err)
|
||||
}
|
||||
|
||||
if err = tempFile.Sync(); err != nil {
|
||||
return fmt.Errorf("failed to sync file: %w", err)
|
||||
}
|
||||
|
||||
if err = tempFile.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close temp file: %w", err)
|
||||
}
|
||||
|
||||
// 原子替换
|
||||
if err = os.Rename(tempPath, s.option.FilePath); err != nil {
|
||||
return fmt.Errorf("failed to rename file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get 获取数据
|
||||
func (s *Store[T]) Get() T {
|
||||
if value := s.data.Load(); value != nil {
|
||||
return value.(T)
|
||||
}
|
||||
var zero T
|
||||
return zero
|
||||
}
|
||||
|
||||
// GetProperty 获取指定属性
|
||||
func (s *Store[T]) GetProperty(key string) any {
|
||||
if key == "" {
|
||||
// 返回所有map数据
|
||||
result := make(map[string]any)
|
||||
s.dataMap.Range(func(k, v any) bool {
|
||||
if str, ok := k.(string); ok {
|
||||
result[str] = v
|
||||
}
|
||||
return true
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
if value, ok := s.dataMap.Load(key); ok {
|
||||
return value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set 设置数据
|
||||
func (s *Store[T]) Set(data T) error {
|
||||
s.data.Store(data)
|
||||
s.unsaved.Store(true)
|
||||
|
||||
if s.option.AutoSave {
|
||||
return s.saveInternal()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetProperty 设置指定属性
|
||||
func (s *Store[T]) SetProperty(key string, value any) error {
|
||||
s.dataMap.Store(key, value)
|
||||
s.unsaved.Store(true)
|
||||
|
||||
if s.option.AutoSave {
|
||||
return s.saveInternal()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除指定属性
|
||||
func (s *Store[T]) Delete(key string) error {
|
||||
s.dataMap.Delete(key)
|
||||
s.unsaved.Store(true)
|
||||
|
||||
if s.option.AutoSave {
|
||||
return s.saveInternal()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasUnsavedChanges 是否有未保存的更改
|
||||
func (s *Store[T]) HasUnsavedChanges() bool {
|
||||
return s.unsaved.Load()
|
||||
}
|
||||
Reference in New Issue
Block a user