🚧 Refactor basic services
This commit is contained in:
@@ -20,8 +20,6 @@ import (
|
||||
"github.com/wailsapp/wails/v3/pkg/services/log"
|
||||
|
||||
"voidraft/internal/models"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -223,20 +221,20 @@ func (s *BackupService) getAuthMethod(config *models.GitBackupConfig) (transport
|
||||
|
||||
// serializeDatabase 序列化数据库到文件
|
||||
func (s *BackupService) serializeDatabase(repoPath string) error {
|
||||
if s.dbService == nil || s.dbService.db == nil {
|
||||
return errors.New("database service not available")
|
||||
}
|
||||
|
||||
binFilePath := filepath.Join(repoPath, dbSerializeFile)
|
||||
|
||||
// 使用 VACUUM INTO 创建数据库副本,不影响现有连接
|
||||
s.dbService.mu.RLock()
|
||||
_, err := s.dbService.db.Exec(fmt.Sprintf("VACUUM INTO '%s'", binFilePath))
|
||||
s.dbService.mu.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating database backup: %w", err)
|
||||
}
|
||||
//if s.dbService == nil || s.dbService.Engine == nil {
|
||||
// return errors.New("database service not available")
|
||||
//}
|
||||
//
|
||||
//binFilePath := filepath.Join(repoPath, dbSerializeFile)
|
||||
//
|
||||
//// 使用 VACUUM INTO 创建数据库副本,不影响现有连接
|
||||
//s.dbService.mu.RLock()
|
||||
//_, err := s.dbService.db.Exec(fmt.Sprintf("VACUUM INTO '%s'", binFilePath))
|
||||
//s.dbService.mu.RUnlock()
|
||||
//
|
||||
//if err != nil {
|
||||
// return fmt.Errorf("creating database backup: %w", err)
|
||||
//}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -12,32 +12,23 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/wailsapp/wails/v3/pkg/application"
|
||||
"github.com/wailsapp/wails/v3/pkg/services/log"
|
||||
)
|
||||
|
||||
// MigrationStatus 迁移状态
|
||||
type MigrationStatus string
|
||||
|
||||
const (
|
||||
MigrationStatusMigrating MigrationStatus = "migrating"
|
||||
MigrationStatusCompleted MigrationStatus = "completed"
|
||||
MigrationStatusFailed MigrationStatus = "failed"
|
||||
)
|
||||
|
||||
// MigrationProgress 迁移进度信息
|
||||
type MigrationProgress struct {
|
||||
Status MigrationStatus `json:"status"`
|
||||
Progress float64 `json:"progress"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Progress float64 `json:"progress"` // 0-100
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// MigrationService 迁移服务
|
||||
type MigrationService struct {
|
||||
logger *log.LogService
|
||||
dbService *DatabaseService
|
||||
mu sync.RWMutex
|
||||
progress atomic.Value // stores MigrationProgress
|
||||
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
@@ -47,18 +38,11 @@ func NewMigrationService(dbService *DatabaseService, logger *log.LogService) *Mi
|
||||
if logger == nil {
|
||||
logger = log.New()
|
||||
}
|
||||
|
||||
ms := &MigrationService{
|
||||
logger: logger,
|
||||
dbService: dbService,
|
||||
}
|
||||
|
||||
// 初始化进度
|
||||
ms.progress.Store(MigrationProgress{
|
||||
Status: MigrationStatusCompleted,
|
||||
Progress: 0,
|
||||
})
|
||||
|
||||
ms.progress.Store(MigrationProgress{})
|
||||
return ms
|
||||
}
|
||||
|
||||
@@ -67,9 +51,15 @@ func (ms *MigrationService) GetProgress() MigrationProgress {
|
||||
return ms.progress.Load().(MigrationProgress)
|
||||
}
|
||||
|
||||
// updateProgress 更新进度
|
||||
func (ms *MigrationService) updateProgress(progress MigrationProgress) {
|
||||
ms.progress.Store(progress)
|
||||
// setProgress 设置进度
|
||||
func (ms *MigrationService) setProgress(progress float64) {
|
||||
ms.progress.Store(MigrationProgress{Progress: progress})
|
||||
}
|
||||
|
||||
// fail 标记失败并返回错误
|
||||
func (ms *MigrationService) fail(err error) error {
|
||||
ms.progress.Store(MigrationProgress{Error: err.Error()})
|
||||
return err
|
||||
}
|
||||
|
||||
// MigrateDirectory 迁移目录
|
||||
@@ -77,118 +67,84 @@ func (ms *MigrationService) MigrateDirectory(srcPath, dstPath string) error {
|
||||
// 创建可取消的上下文
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ms.mu.Lock()
|
||||
ms.ctx = ctx
|
||||
ms.cancel = cancel
|
||||
ms.ctx, ms.cancel = ctx, cancel
|
||||
ms.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
ms.mu.Lock()
|
||||
ms.cancel = nil
|
||||
ms.ctx = nil
|
||||
ms.ctx, ms.cancel = nil, nil
|
||||
ms.mu.Unlock()
|
||||
}()
|
||||
|
||||
// 初始化进度
|
||||
ms.updateProgress(MigrationProgress{
|
||||
Status: MigrationStatusMigrating,
|
||||
Progress: 0,
|
||||
})
|
||||
ms.setProgress(0)
|
||||
|
||||
// 预检查
|
||||
if err := ms.preCheck(srcPath, dstPath); err != nil {
|
||||
if err == errNoMigrationNeeded {
|
||||
ms.updateProgress(MigrationProgress{
|
||||
Status: MigrationStatusCompleted,
|
||||
Progress: 100,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
return ms.failWithError(err)
|
||||
needMigrate, err := ms.preCheck(srcPath, dstPath)
|
||||
if err != nil {
|
||||
return ms.fail(err)
|
||||
}
|
||||
if !needMigrate {
|
||||
ms.setProgress(100)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 迁移前断开数据库连接
|
||||
ms.updateProgress(MigrationProgress{
|
||||
Status: MigrationStatusMigrating,
|
||||
Progress: 10,
|
||||
})
|
||||
|
||||
ms.setProgress(10)
|
||||
if ms.dbService != nil {
|
||||
if err := ms.dbService.ServiceShutdown(); err != nil {
|
||||
ms.logger.Error("Failed to close database connection", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 确保失败时恢复数据库连接
|
||||
defer func() {
|
||||
if ms.dbService != nil {
|
||||
if err := ms.dbService.ServiceStartup(ctx, application.ServiceOptions{}); err != nil {
|
||||
ms.logger.Error("Failed to reconnect database", "error", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 执行原子迁移
|
||||
if err := ms.atomicMove(ctx, srcPath, dstPath); err != nil {
|
||||
return ms.failWithError(err)
|
||||
return ms.fail(err)
|
||||
}
|
||||
|
||||
// 迁移完成后重新连接数据库
|
||||
ms.updateProgress(MigrationProgress{
|
||||
Status: MigrationStatusMigrating,
|
||||
Progress: 95,
|
||||
})
|
||||
|
||||
if ms.dbService != nil {
|
||||
if err := ms.dbService.initDatabase(); err != nil {
|
||||
return ms.failWithError(fmt.Errorf("failed to reconnect database: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// 迁移完成
|
||||
ms.updateProgress(MigrationProgress{
|
||||
Status: MigrationStatusCompleted,
|
||||
Progress: 100,
|
||||
})
|
||||
|
||||
ms.setProgress(100)
|
||||
return nil
|
||||
}
|
||||
|
||||
var errNoMigrationNeeded = fmt.Errorf("no migration needed")
|
||||
|
||||
// preCheck 预检查
|
||||
func (ms *MigrationService) preCheck(srcPath, dstPath string) error {
|
||||
// 检查源目录是否存在
|
||||
// preCheck 预检查,返回是否需要迁移
|
||||
func (ms *MigrationService) preCheck(srcPath, dstPath string) (bool, error) {
|
||||
// 源目录不存在,无需迁移
|
||||
if _, err := os.Stat(srcPath); os.IsNotExist(err) {
|
||||
return errNoMigrationNeeded
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 如果路径相同,不需要迁移
|
||||
// 路径相同,无需迁移
|
||||
srcAbs, _ := filepath.Abs(srcPath)
|
||||
dstAbs, _ := filepath.Abs(dstPath)
|
||||
if srcAbs == dstAbs {
|
||||
return errNoMigrationNeeded
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 检查目标路径是否是源路径的子目录
|
||||
if ms.isSubDirectory(srcAbs, dstAbs) {
|
||||
return fmt.Errorf("target path cannot be a subdirectory of source path")
|
||||
// 目标不能是源的子目录
|
||||
if isSubDir(srcAbs, dstAbs) {
|
||||
return false, fmt.Errorf("target path cannot be a subdirectory of source path")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// failWithError 失败并记录错误
|
||||
func (ms *MigrationService) failWithError(err error) error {
|
||||
ms.updateProgress(MigrationProgress{
|
||||
Status: MigrationStatusFailed,
|
||||
Error: err.Error(),
|
||||
})
|
||||
return err
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// atomicMove 原子移动目录
|
||||
func (ms *MigrationService) atomicMove(ctx context.Context, srcPath, dstPath string) error {
|
||||
// 检查是否取消
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 确保目标目录的父目录存在
|
||||
// 确保目标父目录存在
|
||||
if err := os.MkdirAll(filepath.Dir(dstPath), 0755); err != nil {
|
||||
return fmt.Errorf("failed to create target parent directory: %v", err)
|
||||
return fmt.Errorf("failed to create target parent directory: %w", err)
|
||||
}
|
||||
|
||||
// 检查目标路径
|
||||
@@ -196,132 +152,100 @@ func (ms *MigrationService) atomicMove(ctx context.Context, srcPath, dstPath str
|
||||
return err
|
||||
}
|
||||
|
||||
// 尝试直接重命名
|
||||
ms.updateProgress(MigrationProgress{
|
||||
Status: MigrationStatusMigrating,
|
||||
Progress: 20,
|
||||
})
|
||||
ms.setProgress(20)
|
||||
|
||||
// 尝试直接重命名(同一文件系统时最快)
|
||||
if err := os.Rename(srcPath, dstPath); err == nil {
|
||||
// 重命名成功,更新进度到90%
|
||||
ms.updateProgress(MigrationProgress{
|
||||
Status: MigrationStatusMigrating,
|
||||
Progress: 90,
|
||||
})
|
||||
ms.setProgress(90)
|
||||
ms.logger.Info("Directory migration completed using direct rename", "src", srcPath, "dst", dstPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 重命名失败,使用压缩迁移
|
||||
// 重命名失败(跨文件系统),使用压缩迁移
|
||||
ms.logger.Info("Direct rename failed, using compress migration", "src", srcPath, "dst", dstPath)
|
||||
ms.updateProgress(MigrationProgress{
|
||||
Status: MigrationStatusMigrating,
|
||||
Progress: 30,
|
||||
})
|
||||
ms.setProgress(30)
|
||||
|
||||
return ms.compressMove(ctx, srcPath, dstPath)
|
||||
}
|
||||
|
||||
// checkTargetPath 检查目标路径
|
||||
// checkTargetPath 检查目标路径是否可用
|
||||
func (ms *MigrationService) checkTargetPath(dstPath string) error {
|
||||
stat, err := os.Stat(dstPath)
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check target path: %v", err)
|
||||
return fmt.Errorf("failed to check target path: %w", err)
|
||||
}
|
||||
|
||||
if !stat.IsDir() {
|
||||
return fmt.Errorf("target path exists but is not a directory")
|
||||
}
|
||||
|
||||
isEmpty, err := ms.isDirectoryEmpty(dstPath)
|
||||
isEmpty, err := isDirEmpty(dstPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check target directory: %v", err)
|
||||
return fmt.Errorf("failed to check target directory: %w", err)
|
||||
}
|
||||
if !isEmpty {
|
||||
return fmt.Errorf("target directory is not empty")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// compressMove 压缩迁移
|
||||
func (ms *MigrationService) compressMove(ctx context.Context, srcPath, dstPath string) error {
|
||||
tempZipFile := filepath.Join(os.TempDir(),
|
||||
fmt.Sprintf("voidraft_migration_%d.zip", time.Now().UnixNano()))
|
||||
|
||||
defer os.Remove(tempZipFile)
|
||||
tempZip := filepath.Join(os.TempDir(), fmt.Sprintf("voidraft_migration_%d.zip", time.Now().UnixNano()))
|
||||
defer os.Remove(tempZip)
|
||||
|
||||
// 压缩源目录
|
||||
ms.updateProgress(MigrationProgress{
|
||||
Status: MigrationStatusMigrating,
|
||||
Progress: 40,
|
||||
})
|
||||
|
||||
if err := ms.compressDirectory(ctx, srcPath, tempZipFile); err != nil {
|
||||
return fmt.Errorf("failed to compress source directory: %v", err)
|
||||
ms.setProgress(40)
|
||||
if err := ms.compressDir(ctx, srcPath, tempZip); err != nil {
|
||||
return fmt.Errorf("failed to compress source directory: %w", err)
|
||||
}
|
||||
|
||||
// 解压到目标位置
|
||||
ms.updateProgress(MigrationProgress{
|
||||
Status: MigrationStatusMigrating,
|
||||
Progress: 70,
|
||||
})
|
||||
|
||||
if err := ms.extractToDirectory(ctx, tempZipFile, dstPath); err != nil {
|
||||
return fmt.Errorf("failed to extract to target location: %v", err)
|
||||
ms.setProgress(70)
|
||||
if err := ms.extractZip(ctx, tempZip, dstPath); err != nil {
|
||||
return fmt.Errorf("failed to extract to target location: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否取消
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// 检查取消
|
||||
if err := ctx.Err(); err != nil {
|
||||
os.RemoveAll(dstPath)
|
||||
return ctx.Err()
|
||||
default:
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证迁移是否成功
|
||||
// 验证迁移结果
|
||||
if err := ms.verifyMigration(dstPath); err != nil {
|
||||
// 迁移验证失败,清理目标目录
|
||||
os.RemoveAll(dstPath)
|
||||
return fmt.Errorf("migration verification failed: %v", err)
|
||||
return fmt.Errorf("migration verification failed: %w", err)
|
||||
}
|
||||
|
||||
// 删除源目录
|
||||
ms.updateProgress(MigrationProgress{
|
||||
Status: MigrationStatusMigrating,
|
||||
Progress: 90,
|
||||
})
|
||||
ms.setProgress(90)
|
||||
os.RemoveAll(srcPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// compressDirectory 压缩目录到zip文件
|
||||
func (ms *MigrationService) compressDirectory(ctx context.Context, srcDir, zipFile string) error {
|
||||
zipWriter, err := os.Create(zipFile)
|
||||
// compressDir 压缩目录到zip文件
|
||||
func (ms *MigrationService) compressDir(ctx context.Context, srcDir, zipPath string) error {
|
||||
zipFile, err := os.Create(zipPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer zipWriter.Close()
|
||||
defer zipFile.Close()
|
||||
|
||||
zw := zip.NewWriter(zipWriter)
|
||||
zw := zip.NewWriter(zipFile)
|
||||
defer zw.Close()
|
||||
|
||||
return filepath.Walk(srcDir, func(filePath string, info os.FileInfo, err error) error {
|
||||
return filepath.Walk(srcDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查是否取消
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
relPath, err := filepath.Rel(srcDir, filePath)
|
||||
relPath, err := filepath.Rel(srcDir, path)
|
||||
if err != nil || relPath == "." {
|
||||
return err
|
||||
}
|
||||
@@ -345,27 +269,21 @@ func (ms *MigrationService) compressDirectory(ctx context.Context, srcDir, zipFi
|
||||
}
|
||||
|
||||
if !info.IsDir() {
|
||||
return ms.copyFileToZip(filePath, writer)
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
_, err = io.Copy(writer, file)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// copyFileToZip 复制文件到zip
|
||||
func (ms *MigrationService) copyFileToZip(filePath string, writer io.Writer) error {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
_, err = io.Copy(writer, file)
|
||||
return err
|
||||
}
|
||||
|
||||
// extractToDirectory 从zip文件解压到目录
|
||||
func (ms *MigrationService) extractToDirectory(ctx context.Context, zipFile, dstDir string) error {
|
||||
reader, err := zip.OpenReader(zipFile)
|
||||
// extractZip 解压zip文件到目录
|
||||
func (ms *MigrationService) extractZip(ctx context.Context, zipPath, dstDir string) error {
|
||||
reader, err := zip.OpenReader(zipPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -376,27 +294,22 @@ func (ms *MigrationService) extractToDirectory(ctx context.Context, zipFile, dst
|
||||
}
|
||||
|
||||
for _, file := range reader.File {
|
||||
// 检查是否取消
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ms.extractSingleFile(file, dstDir); err != nil {
|
||||
if err := extractFile(file, dstDir); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractSingleFile 解压单个文件
|
||||
func (ms *MigrationService) extractSingleFile(file *zip.File, dstDir string) error {
|
||||
// extractFile 解压单个文件
|
||||
func extractFile(file *zip.File, dstDir string) error {
|
||||
dstPath := filepath.Join(dstDir, file.Name)
|
||||
|
||||
// 安全检查:防止zip slip攻击
|
||||
if !strings.HasPrefix(dstPath, filepath.Clean(dstDir)+string(os.PathSeparator)) {
|
||||
if !strings.HasPrefix(filepath.Clean(dstPath), filepath.Clean(dstDir)+string(os.PathSeparator)) {
|
||||
return fmt.Errorf("invalid file path in archive: %s", file.Name)
|
||||
}
|
||||
|
||||
@@ -424,45 +337,23 @@ func (ms *MigrationService) extractSingleFile(file *zip.File, dstDir string) err
|
||||
return err
|
||||
}
|
||||
|
||||
// isDirectoryEmpty 检查目录是否为空
|
||||
func (ms *MigrationService) isDirectoryEmpty(dirPath string) (bool, error) {
|
||||
f, err := os.Open(dirPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
_, err = f.Readdir(1)
|
||||
return err == io.EOF, nil
|
||||
}
|
||||
|
||||
// isSubDirectory 检查target是否是parent的子目录
|
||||
func (ms *MigrationService) isSubDirectory(parent, target string) bool {
|
||||
parent = filepath.Clean(parent) + string(filepath.Separator)
|
||||
target = filepath.Clean(target) + string(filepath.Separator)
|
||||
return len(target) > len(parent) && strings.HasPrefix(target, parent)
|
||||
}
|
||||
|
||||
// verifyMigration 验证迁移是否成功
|
||||
// verifyMigration 验证迁移结果
|
||||
func (ms *MigrationService) verifyMigration(dstPath string) error {
|
||||
// 检查目标目录是否存在
|
||||
dstStat, err := os.Stat(dstPath)
|
||||
stat, err := os.Stat(dstPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("target directory does not exist: %v", err)
|
||||
return fmt.Errorf("target directory does not exist: %w", err)
|
||||
}
|
||||
if !dstStat.IsDir() {
|
||||
if !stat.IsDir() {
|
||||
return fmt.Errorf("target path is not a directory")
|
||||
}
|
||||
|
||||
// 简单验证:检查目标目录是否非空
|
||||
isEmpty, err := ms.isDirectoryEmpty(dstPath)
|
||||
isEmpty, err := isDirEmpty(dstPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check target directory: %v", err)
|
||||
return fmt.Errorf("failed to check target directory: %w", err)
|
||||
}
|
||||
if isEmpty {
|
||||
return fmt.Errorf("target directory is empty after migration")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -475,12 +366,30 @@ func (ms *MigrationService) CancelMigration() error {
|
||||
ms.cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("no active migration to cancel")
|
||||
}
|
||||
|
||||
// ServiceShutdown 服务关闭
|
||||
func (ms *MigrationService) ServiceShutdown() error {
|
||||
ms.CancelMigration()
|
||||
_ = ms.CancelMigration()
|
||||
return nil
|
||||
}
|
||||
|
||||
// isDirEmpty 检查目录是否为空
|
||||
func isDirEmpty(path string) (bool, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
_, err = f.Readdir(1)
|
||||
return err == io.EOF, nil
|
||||
}
|
||||
|
||||
// isSubDir 检查target是否是parent的子目录
|
||||
func isSubDir(parent, target string) bool {
|
||||
parent = filepath.Clean(parent) + string(filepath.Separator)
|
||||
target = filepath.Clean(target) + string(filepath.Separator)
|
||||
return len(target) > len(parent) && strings.HasPrefix(target, parent)
|
||||
}
|
||||
|
||||
@@ -6,376 +6,100 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"sync"
|
||||
"voidraft/internal/models"
|
||||
"time"
|
||||
|
||||
"voidraft/internal/models/ent"
|
||||
"voidraft/internal/models/ent/migrate"
|
||||
_ "voidraft/internal/models/ent/runtime"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/wailsapp/wails/v3/pkg/application"
|
||||
"github.com/wailsapp/wails/v3/pkg/services/log"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
const (
|
||||
dbName = "voidraft.db"
|
||||
|
||||
// SQLite performance optimization settings
|
||||
sqlOptimizationSettings = `
|
||||
PRAGMA journal_mode = WAL;
|
||||
PRAGMA synchronous = NORMAL;
|
||||
PRAGMA cache_size = -64000;
|
||||
PRAGMA temp_store = MEMORY;
|
||||
PRAGMA foreign_keys = ON;`
|
||||
|
||||
// Documents table
|
||||
sqlCreateDocumentsTable = `
|
||||
CREATE TABLE IF NOT EXISTS documents (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
title TEXT NOT NULL,
|
||||
content TEXT DEFAULT '∞∞∞text-a',
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
is_deleted INTEGER DEFAULT 0,
|
||||
is_locked INTEGER DEFAULT 0
|
||||
)`
|
||||
|
||||
// Extensions table
|
||||
sqlCreateExtensionsTable = `
|
||||
CREATE TABLE IF NOT EXISTS extensions (
|
||||
id TEXT PRIMARY KEY,
|
||||
enabled INTEGER NOT NULL DEFAULT 1,
|
||||
is_default INTEGER NOT NULL DEFAULT 0,
|
||||
config TEXT DEFAULT '{}',
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)`
|
||||
|
||||
// Key bindings table
|
||||
sqlCreateKeyBindingsTable = `
|
||||
CREATE TABLE IF NOT EXISTS key_bindings (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
command TEXT NOT NULL,
|
||||
extension TEXT NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
enabled INTEGER NOT NULL DEFAULT 1,
|
||||
is_default INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)`
|
||||
|
||||
// Themes table
|
||||
sqlCreateThemesTable = `
|
||||
CREATE TABLE IF NOT EXISTS themes (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
colors TEXT NOT NULL,
|
||||
is_default INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)`
|
||||
dbName = "voidraft.db"
|
||||
maxIdleConns = 10
|
||||
maxOpenConns = 100
|
||||
connMaxLife = time.Hour
|
||||
)
|
||||
|
||||
// ColumnInfo 存储列的信息
|
||||
type ColumnInfo struct {
|
||||
SQLType string
|
||||
DefaultValue string
|
||||
}
|
||||
|
||||
// TableModel 表示表与模型之间的映射关系
|
||||
type TableModel struct {
|
||||
TableName string
|
||||
Model interface{}
|
||||
}
|
||||
|
||||
// DatabaseService provides shared database functionality
|
||||
// DatabaseService 数据库服务
|
||||
type DatabaseService struct {
|
||||
configService *ConfigService
|
||||
logger *log.LogService
|
||||
Client *ent.Client
|
||||
db *sql.DB
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
tableModels []TableModel // 注册的表模型
|
||||
|
||||
// 配置观察者取消函数
|
||||
cancelObserver CancelFunc
|
||||
}
|
||||
|
||||
// NewDatabaseService creates a new database service
|
||||
// NewDatabaseService 创建数据库服务
|
||||
func NewDatabaseService(configService *ConfigService, logger *log.LogService) *DatabaseService {
|
||||
if logger == nil {
|
||||
logger = log.New()
|
||||
}
|
||||
|
||||
ds := &DatabaseService{
|
||||
return &DatabaseService{
|
||||
configService: configService,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// 注册所有模型
|
||||
ds.registerAllModels()
|
||||
|
||||
return ds
|
||||
}
|
||||
|
||||
// registerAllModels 注册所有数据模型,集中管理表-模型映射
|
||||
func (ds *DatabaseService) registerAllModels() {
|
||||
// 文档表
|
||||
ds.RegisterModel("documents", &models.Document{})
|
||||
// 扩展表
|
||||
ds.RegisterModel("extensions", &models.Extension{})
|
||||
// 快捷键表
|
||||
ds.RegisterModel("key_bindings", &models.KeyBinding{})
|
||||
// 主题表
|
||||
ds.RegisterModel("themes", &models.Theme{})
|
||||
}
|
||||
|
||||
// ServiceStartup initializes the service when the application starts
|
||||
func (ds *DatabaseService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
|
||||
ds.ctx = ctx
|
||||
return ds.initDatabase()
|
||||
}
|
||||
|
||||
// initDatabase initializes the SQLite database
|
||||
func (ds *DatabaseService) initDatabase() error {
|
||||
dbPath, err := ds.getDatabasePath()
|
||||
// ServiceStartup 服务启动
|
||||
func (s *DatabaseService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
|
||||
dbPath, err := s.getBasePath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get database path: %w", err)
|
||||
return fmt.Errorf("get database path error: %w", err)
|
||||
}
|
||||
|
||||
// 确保数据库目录存在
|
||||
dbDir := filepath.Dir(dbPath)
|
||||
if err := os.MkdirAll(dbDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create database directory: %w", err)
|
||||
if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
|
||||
return fmt.Errorf("create database directory error: %w", err)
|
||||
}
|
||||
|
||||
// 打开数据库连接
|
||||
ds.db, err = sql.Open("sqlite", dbPath)
|
||||
// _fk=1 启用外键,_journal_mode=WAL 启用 WAL 模式
|
||||
dsn := fmt.Sprintf("file:%s?_fk=1&_journal_mode=WAL&_synchronous=NORMAL&_busy_timeout=5000", dbPath)
|
||||
s.db, err = sql.Open("sqlite3", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open database: %w", err)
|
||||
return fmt.Errorf("open database error: %w", err)
|
||||
}
|
||||
|
||||
// 测试连接
|
||||
if err := ds.db.Ping(); err != nil {
|
||||
return fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
// 连接池配置
|
||||
s.db.SetMaxIdleConns(maxIdleConns)
|
||||
s.db.SetMaxOpenConns(maxOpenConns)
|
||||
s.db.SetConnMaxLifetime(connMaxLife)
|
||||
|
||||
// 应用性能优化设置
|
||||
if _, err := ds.db.Exec(sqlOptimizationSettings); err != nil {
|
||||
return fmt.Errorf("failed to apply optimization settings: %w", err)
|
||||
}
|
||||
// 创建 ent 客户端
|
||||
drv := entsql.OpenDB(dialect.SQLite, s.db)
|
||||
s.Client = ent.NewClient(ent.Driver(drv))
|
||||
|
||||
// 创建表和索引
|
||||
if err := ds.createTables(); err != nil {
|
||||
return fmt.Errorf("failed to create tables: %w", err)
|
||||
}
|
||||
|
||||
if err := ds.createIndexes(); err != nil {
|
||||
return fmt.Errorf("failed to create indexes: %w", err)
|
||||
}
|
||||
|
||||
// 执行模型与表结构同步
|
||||
if err := ds.syncAllModelTables(); err != nil {
|
||||
return fmt.Errorf("failed to sync model tables: %w", err)
|
||||
// 自动迁移
|
||||
if err := s.Client.Schema.Create(ctx,
|
||||
migrate.WithDropColumn(true),
|
||||
migrate.WithDropIndex(true),
|
||||
); err != nil {
|
||||
return fmt.Errorf("schema migration error: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getDatabasePath gets the database file path
|
||||
func (ds *DatabaseService) getDatabasePath() (string, error) {
|
||||
config, err := ds.configService.GetConfig()
|
||||
func (s *DatabaseService) getBasePath() (string, error) {
|
||||
config, err := s.configService.GetConfig()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(config.General.DataPath, dbName), nil
|
||||
}
|
||||
|
||||
// createTables creates all database tables
|
||||
func (ds *DatabaseService) createTables() error {
|
||||
tables := []string{
|
||||
sqlCreateDocumentsTable,
|
||||
sqlCreateExtensionsTable,
|
||||
sqlCreateKeyBindingsTable,
|
||||
sqlCreateThemesTable,
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
if _, err := ds.db.Exec(table); err != nil {
|
||||
// ServiceShutdown 服务关闭
|
||||
func (s *DatabaseService) ServiceShutdown() error {
|
||||
if s.Client != nil {
|
||||
if err := s.Client.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createIndexes creates database indexes
|
||||
func (ds *DatabaseService) createIndexes() error {
|
||||
indexes := []string{
|
||||
// Documents indexes
|
||||
`CREATE INDEX IF NOT EXISTS idx_documents_updated_at ON documents(updated_at DESC)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_documents_title ON documents(title)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_documents_is_deleted ON documents(is_deleted)`,
|
||||
// Extensions indexes
|
||||
`CREATE INDEX IF NOT EXISTS idx_extensions_enabled ON extensions(enabled)`,
|
||||
// Key bindings indexes
|
||||
`CREATE INDEX IF NOT EXISTS idx_key_bindings_command ON key_bindings(command)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_key_bindings_extension ON key_bindings(extension)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_key_bindings_enabled ON key_bindings(enabled)`,
|
||||
// Themes indexes
|
||||
`CREATE INDEX IF NOT EXISTS idx_themes_type ON themes(type)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_themes_is_default ON themes(is_default)`,
|
||||
}
|
||||
|
||||
for _, index := range indexes {
|
||||
if _, err := ds.db.Exec(index); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterModel 注册模型与表的映射关系
|
||||
func (ds *DatabaseService) RegisterModel(tableName string, model interface{}) {
|
||||
ds.mu.Lock()
|
||||
defer ds.mu.Unlock()
|
||||
|
||||
ds.tableModels = append(ds.tableModels, TableModel{
|
||||
TableName: tableName,
|
||||
Model: model,
|
||||
})
|
||||
}
|
||||
|
||||
// syncAllModelTables 同步所有注册的模型与表结构
|
||||
func (ds *DatabaseService) syncAllModelTables() error {
|
||||
for _, tm := range ds.tableModels {
|
||||
if err := ds.syncModelTable(tm.TableName, tm.Model); err != nil {
|
||||
return fmt.Errorf("failed to sync table %s: %w", tm.TableName, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncModelTable 同步模型与表结构
|
||||
func (ds *DatabaseService) syncModelTable(tableName string, model interface{}) error {
|
||||
// 获取表结构元数据
|
||||
columns, err := ds.getTableColumns(tableName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get table columns: %w", err)
|
||||
}
|
||||
|
||||
// 使用反射从模型中提取字段信息
|
||||
expectedColumns, err := ds.getModelColumns(model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get model columns: %w", err)
|
||||
}
|
||||
|
||||
// 检查缺失的列并添加
|
||||
for colName, colInfo := range expectedColumns {
|
||||
if _, exists := columns[colName]; !exists {
|
||||
// 执行添加列的SQL
|
||||
alterSQL := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s DEFAULT %s",
|
||||
tableName, colName, colInfo.SQLType, colInfo.DefaultValue)
|
||||
if _, err := ds.db.Exec(alterSQL); err != nil {
|
||||
return fmt.Errorf("failed to add column %s: %w", colName, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getTableColumns 获取表的列信息
|
||||
func (ds *DatabaseService) getTableColumns(table string) (map[string]string, error) {
|
||||
query := fmt.Sprintf("PRAGMA table_info(%s)", table)
|
||||
rows, err := ds.db.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns := make(map[string]string)
|
||||
for rows.Next() {
|
||||
var cid int
|
||||
var name, typeName string
|
||||
var notNull, pk int
|
||||
var dflt_value interface{}
|
||||
|
||||
if err := rows.Scan(&cid, &name, &typeName, ¬Null, &dflt_value, &pk); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns[name] = typeName
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
// getModelColumns 从模型结构体中提取数据库列信息
|
||||
func (ds *DatabaseService) getModelColumns(model interface{}) (map[string]ColumnInfo, error) {
|
||||
columns := make(map[string]ColumnInfo)
|
||||
|
||||
// 使用反射获取结构体的类型信息
|
||||
t := reflect.TypeOf(model)
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
if t.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("model must be a struct or a pointer to struct")
|
||||
}
|
||||
|
||||
// 遍历所有字段
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
|
||||
// 只处理有db标签的字段
|
||||
dbTag := field.Tag.Get("db")
|
||||
if dbTag == "" {
|
||||
// 如果没有db标签,跳过该字段
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取字段类型对应的SQL类型和默认值
|
||||
sqlType, defaultVal := getSQLTypeAndDefault(field.Type)
|
||||
|
||||
columns[dbTag] = ColumnInfo{
|
||||
SQLType: sqlType,
|
||||
DefaultValue: defaultVal,
|
||||
}
|
||||
}
|
||||
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
// getSQLTypeAndDefault 根据Go类型获取对应的SQL类型和默认值
|
||||
func getSQLTypeAndDefault(t reflect.Type) (string, string) {
|
||||
switch t.Kind() {
|
||||
case reflect.Bool:
|
||||
return "INTEGER", "0"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return "INTEGER", "0"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "REAL", "0.0"
|
||||
case reflect.String:
|
||||
return "TEXT", "''"
|
||||
default:
|
||||
return "TEXT", "NULL"
|
||||
}
|
||||
}
|
||||
|
||||
// ServiceShutdown shuts down the service when the application closes
|
||||
func (ds *DatabaseService) ServiceShutdown() error {
|
||||
// 取消配置观察者
|
||||
if ds.cancelObserver != nil {
|
||||
ds.cancelObserver()
|
||||
}
|
||||
|
||||
if ds.db != nil {
|
||||
return ds.db.Close()
|
||||
if s.db != nil {
|
||||
return s.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,428 +2,141 @@ package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
"voidraft/internal/models"
|
||||
"voidraft/internal/models/ent/document"
|
||||
|
||||
"github.com/wailsapp/wails/v3/pkg/application"
|
||||
"github.com/wailsapp/wails/v3/pkg/services/log"
|
||||
"voidraft/internal/models/ent"
|
||||
)
|
||||
|
||||
// SQL constants for document operations
|
||||
const (
|
||||
const defaultDocumentTitle = "default"
|
||||
const defaultDocumentContent = "\n∞∞∞text-a\n"
|
||||
|
||||
// Document operations
|
||||
sqlGetDocumentByID = `
|
||||
SELECT id, title, content, created_at, updated_at, is_deleted, is_locked
|
||||
FROM documents
|
||||
WHERE id = ?`
|
||||
|
||||
sqlInsertDocument = `
|
||||
INSERT INTO documents (title, content, created_at, updated_at, is_deleted, is_locked)
|
||||
VALUES (?, ?, ?, ?, 0, 0)`
|
||||
|
||||
sqlUpdateDocumentContent = `
|
||||
UPDATE documents
|
||||
SET content = ?, updated_at = ?
|
||||
WHERE id = ? AND is_deleted = 0`
|
||||
|
||||
sqlUpdateDocumentTitle = `
|
||||
UPDATE documents
|
||||
SET title = ?, updated_at = ?
|
||||
WHERE id = ? AND is_deleted = 0`
|
||||
|
||||
sqlMarkDocumentAsDeleted = `
|
||||
UPDATE documents
|
||||
SET is_deleted = 1, updated_at = ?
|
||||
WHERE id = ? AND is_locked = 0`
|
||||
|
||||
sqlRestoreDocument = `
|
||||
UPDATE documents
|
||||
SET is_deleted = 0, updated_at = ?
|
||||
WHERE id = ?`
|
||||
|
||||
sqlListAllDocumentsMeta = `
|
||||
SELECT id, title, created_at, updated_at, is_locked
|
||||
FROM documents
|
||||
WHERE is_deleted = 0
|
||||
ORDER BY updated_at DESC`
|
||||
|
||||
sqlListDeletedDocumentsMeta = `
|
||||
SELECT id, title, created_at, updated_at, is_locked
|
||||
FROM documents
|
||||
WHERE is_deleted = 1
|
||||
ORDER BY updated_at DESC`
|
||||
|
||||
sqlGetFirstDocumentID = `
|
||||
SELECT id FROM documents WHERE is_deleted = 0 ORDER BY id LIMIT 1`
|
||||
|
||||
sqlCountDocuments = `SELECT COUNT(*) FROM documents WHERE is_deleted = 0`
|
||||
|
||||
sqlSetDocumentLocked = `
|
||||
UPDATE documents
|
||||
SET is_locked = 1, updated_at = ?
|
||||
WHERE id = ?`
|
||||
|
||||
sqlSetDocumentUnlocked = `
|
||||
UPDATE documents
|
||||
SET is_locked = 0, updated_at = ?
|
||||
WHERE id = ?`
|
||||
|
||||
sqlDefaultDocumentID = 1 // 默认文档的ID
|
||||
)
|
||||
|
||||
// DocumentService provides document management functionality
|
||||
// DocumentService 文档服务
|
||||
type DocumentService struct {
|
||||
databaseService *DatabaseService
|
||||
logger *log.LogService
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
db *DatabaseService
|
||||
logger *log.LogService
|
||||
}
|
||||
|
||||
// NewDocumentService creates a new document service
|
||||
func NewDocumentService(databaseService *DatabaseService, logger *log.LogService) *DocumentService {
|
||||
// NewDocumentService 创建文档服务
|
||||
func NewDocumentService(db *DatabaseService, logger *log.LogService) *DocumentService {
|
||||
if logger == nil {
|
||||
logger = log.New()
|
||||
}
|
||||
|
||||
ds := &DocumentService{
|
||||
databaseService: databaseService,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
return ds
|
||||
return &DocumentService{db: db, logger: logger}
|
||||
}
|
||||
|
||||
// ServiceStartup initializes the service when the application starts
|
||||
func (ds *DocumentService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
|
||||
ds.ctx = ctx
|
||||
|
||||
// 确保默认文档存在
|
||||
if err := ds.ensureDefaultDocument(); err != nil {
|
||||
return fmt.Errorf("failed to ensure default document: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureDefaultDocument ensures a default document exists
|
||||
func (ds *DocumentService) ensureDefaultDocument() error {
|
||||
if ds.databaseService == nil || ds.databaseService.db == nil {
|
||||
return errors.New("database service not available")
|
||||
}
|
||||
|
||||
// Check if any document exists
|
||||
var count int64
|
||||
err := ds.databaseService.db.QueryRow(sqlCountDocuments).Scan(&count)
|
||||
// ServiceStartup 服务启动
|
||||
func (s *DocumentService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
|
||||
exists, err := s.db.Client.Document.Query().Exist(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query document count: %w", err)
|
||||
return fmt.Errorf("check document exists error: %w", err)
|
||||
}
|
||||
|
||||
// If no documents exist, create default document
|
||||
if count == 0 {
|
||||
defaultDoc := models.NewDefaultDocument()
|
||||
_, err := ds.CreateDocument(defaultDoc.Title)
|
||||
return err
|
||||
if !exists {
|
||||
_, err = s.CreateDocument(ctx, defaultDocumentTitle)
|
||||
}
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// GetDocumentByID gets a document by ID
|
||||
func (ds *DocumentService) GetDocumentByID(id int64) (*models.Document, error) {
|
||||
ds.mu.RLock()
|
||||
defer ds.mu.RUnlock()
|
||||
|
||||
if ds.databaseService == nil || ds.databaseService.db == nil {
|
||||
return nil, errors.New("database service not available")
|
||||
}
|
||||
|
||||
doc := &models.Document{}
|
||||
var isDeleted, isLocked int
|
||||
|
||||
err := ds.databaseService.db.QueryRow(sqlGetDocumentByID, id).Scan(
|
||||
&doc.ID,
|
||||
&doc.Title,
|
||||
&doc.Content,
|
||||
&doc.CreatedAt,
|
||||
&doc.UpdatedAt,
|
||||
&isDeleted,
|
||||
&isLocked,
|
||||
)
|
||||
|
||||
// GetDocumentByID 根据ID获取文档
|
||||
func (s *DocumentService) GetDocumentByID(ctx context.Context, id int) (*ent.Document, error) {
|
||||
doc, err := s.db.Client.Document.Get(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get document by ID: %w", err)
|
||||
return nil, fmt.Errorf("get document by id error: %w", err)
|
||||
}
|
||||
|
||||
// 转换布尔字段
|
||||
doc.IsDeleted = isDeleted == 1
|
||||
doc.IsLocked = isLocked == 1
|
||||
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
// CreateDocument creates a new document and returns the created document with ID
|
||||
func (ds *DocumentService) CreateDocument(title string) (*models.Document, error) {
|
||||
ds.mu.Lock()
|
||||
defer ds.mu.Unlock()
|
||||
|
||||
if ds.databaseService == nil || ds.databaseService.db == nil {
|
||||
return nil, errors.New("database service not available")
|
||||
}
|
||||
doc := models.NewDocument(title, "\n∞∞∞text-a\n")
|
||||
// 执行插入操作
|
||||
result, err := ds.databaseService.db.Exec(sqlInsertDocument,
|
||||
doc.Title, doc.Content, doc.CreatedAt, doc.UpdatedAt)
|
||||
// CreateDocument 创建文档
|
||||
func (s *DocumentService) CreateDocument(ctx context.Context, title string) (*ent.Document, error) {
|
||||
doc, err := s.db.Client.Document.Create().
|
||||
SetTitle(title).
|
||||
SetContent(defaultDocumentContent).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create document: %w", err)
|
||||
return nil, fmt.Errorf("create document error: %w", err)
|
||||
}
|
||||
|
||||
// 获取自增ID
|
||||
lastID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get last insert ID: %w", err)
|
||||
}
|
||||
|
||||
// 返回带ID的文档
|
||||
doc.ID = lastID
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
// LockDocument 锁定文档,防止删除
|
||||
func (ds *DocumentService) LockDocument(id int64) error {
|
||||
if ds.databaseService == nil || ds.databaseService.db == nil {
|
||||
return errors.New("database service not available")
|
||||
}
|
||||
// UpdateDocumentContent 更新文档内容
|
||||
func (s *DocumentService) UpdateDocumentContent(ctx context.Context, id int, content string) error {
|
||||
return s.db.Client.Document.UpdateOneID(id).
|
||||
SetContent(content).
|
||||
Exec(ctx)
|
||||
}
|
||||
|
||||
// 先检查文档是否存在且未删除
|
||||
doc, err := ds.GetDocumentByID(id)
|
||||
// UpdateDocumentTitle 更新文档标题
|
||||
func (s *DocumentService) UpdateDocumentTitle(ctx context.Context, id int, title string) error {
|
||||
return s.db.Client.Document.UpdateOneID(id).
|
||||
SetTitle(title).
|
||||
Exec(ctx)
|
||||
}
|
||||
|
||||
// LockDocument 锁定文档
|
||||
func (s *DocumentService) LockDocument(ctx context.Context, id int) error {
|
||||
doc, err := s.GetDocumentByID(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get document: %w", err)
|
||||
return err
|
||||
}
|
||||
if doc == nil {
|
||||
return fmt.Errorf("document not found: %d", id)
|
||||
}
|
||||
if doc.IsDeleted {
|
||||
return fmt.Errorf("cannot lock deleted document: %d", id)
|
||||
}
|
||||
|
||||
// 如果已经锁定,无需操作
|
||||
if doc.IsLocked {
|
||||
if doc.Locked {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 现在加锁执行锁定操作
|
||||
ds.mu.Lock()
|
||||
defer ds.mu.Unlock()
|
||||
|
||||
_, err = ds.databaseService.db.Exec(sqlSetDocumentLocked, time.Now().Format("2006-01-02 15:04:05"), id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to lock document: %w", err)
|
||||
}
|
||||
return nil
|
||||
return s.db.Client.Document.UpdateOneID(id).
|
||||
SetLocked(true).
|
||||
Exec(ctx)
|
||||
}
|
||||
|
||||
// UnlockDocument 解锁文档
|
||||
func (ds *DocumentService) UnlockDocument(id int64) error {
|
||||
if ds.databaseService == nil || ds.databaseService.db == nil {
|
||||
return errors.New("database service not available")
|
||||
}
|
||||
|
||||
// 先检查文档是否存在
|
||||
doc, err := ds.GetDocumentByID(id)
|
||||
func (s *DocumentService) UnlockDocument(ctx context.Context, id int) error {
|
||||
doc, err := s.GetDocumentByID(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get document: %w", err)
|
||||
return err
|
||||
}
|
||||
if doc == nil {
|
||||
return fmt.Errorf("document not found: %d", id)
|
||||
}
|
||||
|
||||
// 如果未锁定,无需操作
|
||||
if !doc.IsLocked {
|
||||
if !doc.Locked {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 现在加锁执行解锁操作
|
||||
ds.mu.Lock()
|
||||
defer ds.mu.Unlock()
|
||||
|
||||
_, err = ds.databaseService.db.Exec(sqlSetDocumentUnlocked, time.Now().Format("2006-01-02 15:04:05"), id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unlock document: %w", err)
|
||||
}
|
||||
return nil
|
||||
return s.db.Client.Document.UpdateOneID(id).
|
||||
SetLocked(false).
|
||||
Exec(ctx)
|
||||
}
|
||||
|
||||
// UpdateDocumentContent updates the content of a document
|
||||
func (ds *DocumentService) UpdateDocumentContent(id int64, content string) error {
|
||||
ds.mu.Lock()
|
||||
defer ds.mu.Unlock()
|
||||
|
||||
if ds.databaseService == nil || ds.databaseService.db == nil {
|
||||
return errors.New("database service not available")
|
||||
}
|
||||
|
||||
_, err := ds.databaseService.db.Exec(sqlUpdateDocumentContent, content, time.Now().Format("2006-01-02 15:04:05"), id)
|
||||
// DeleteDocument 删除文档
|
||||
func (s *DocumentService) DeleteDocument(ctx context.Context, id int) error {
|
||||
doc, err := s.GetDocumentByID(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update document content: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateDocumentTitle updates the title of a document
|
||||
func (ds *DocumentService) UpdateDocumentTitle(id int64, title string) error {
|
||||
ds.mu.Lock()
|
||||
defer ds.mu.Unlock()
|
||||
|
||||
if ds.databaseService == nil || ds.databaseService.db == nil {
|
||||
return errors.New("database service not available")
|
||||
}
|
||||
|
||||
_, err := ds.databaseService.db.Exec(sqlUpdateDocumentTitle, title, time.Now().Format("2006-01-02 15:04:05"), id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update document title: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteDocument marks a document as deleted (default document with ID=1 cannot be deleted)
|
||||
func (ds *DocumentService) DeleteDocument(id int64) error {
|
||||
if ds.databaseService == nil || ds.databaseService.db == nil {
|
||||
ds.logger.Error("database service not available")
|
||||
return errors.New("database service not available")
|
||||
}
|
||||
|
||||
// 不允许删除默认文档
|
||||
if id == sqlDefaultDocumentID {
|
||||
return fmt.Errorf("cannot delete the default document")
|
||||
}
|
||||
|
||||
// 先检查文档是否存在和锁定状态(不加锁避免死锁)
|
||||
doc, err := ds.GetDocumentByID(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get document: %w", err)
|
||||
return err
|
||||
}
|
||||
if doc == nil {
|
||||
return fmt.Errorf("document not found: %d", id)
|
||||
}
|
||||
if doc.IsLocked {
|
||||
if doc.Locked {
|
||||
return fmt.Errorf("cannot delete locked document: %d", id)
|
||||
}
|
||||
|
||||
// 现在加锁执行删除操作
|
||||
ds.mu.Lock()
|
||||
defer ds.mu.Unlock()
|
||||
|
||||
_, err = ds.databaseService.db.Exec(sqlMarkDocumentAsDeleted, time.Now().Format("2006-01-02 15:04:05"), id)
|
||||
count, err := s.db.Client.Document.Query().Count(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to mark document as deleted: %w", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
if count <= 1 {
|
||||
return errors.New("cannot delete the last document")
|
||||
}
|
||||
return s.db.Client.Document.DeleteOneID(id).Exec(ctx)
|
||||
}
|
||||
|
||||
// RestoreDocument restores a deleted document
|
||||
func (ds *DocumentService) RestoreDocument(id int64) error {
|
||||
ds.mu.Lock()
|
||||
defer ds.mu.Unlock()
|
||||
|
||||
if ds.databaseService == nil || ds.databaseService.db == nil {
|
||||
return errors.New("database service not available")
|
||||
}
|
||||
|
||||
_, err := ds.databaseService.db.Exec(sqlRestoreDocument, time.Now().Format("2006-01-02 15:04:05"), id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to restore document: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListAllDocumentsMeta lists all active (non-deleted) document metadata
|
||||
func (ds *DocumentService) ListAllDocumentsMeta() ([]*models.Document, error) {
|
||||
ds.mu.RLock()
|
||||
defer ds.mu.RUnlock()
|
||||
|
||||
if ds.databaseService == nil || ds.databaseService.db == nil {
|
||||
return nil, errors.New("database service not available")
|
||||
}
|
||||
|
||||
rows, err := ds.databaseService.db.Query(sqlListAllDocumentsMeta)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list document meta: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var documents []*models.Document
|
||||
for rows.Next() {
|
||||
doc := &models.Document{IsDeleted: false}
|
||||
var isLocked int
|
||||
|
||||
err := rows.Scan(
|
||||
&doc.ID,
|
||||
&doc.Title,
|
||||
&doc.CreatedAt,
|
||||
&doc.UpdatedAt,
|
||||
&isLocked,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan document row: %w", err)
|
||||
}
|
||||
|
||||
doc.IsLocked = isLocked == 1
|
||||
documents = append(documents, doc)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating document rows: %w", err)
|
||||
}
|
||||
|
||||
return documents, nil
|
||||
}
|
||||
|
||||
// ListDeletedDocumentsMeta lists all deleted document metadata
|
||||
func (ds *DocumentService) ListDeletedDocumentsMeta() ([]*models.Document, error) {
|
||||
ds.mu.RLock()
|
||||
defer ds.mu.RUnlock()
|
||||
|
||||
if ds.databaseService == nil || ds.databaseService.db == nil {
|
||||
return nil, errors.New("database service not available")
|
||||
}
|
||||
|
||||
rows, err := ds.databaseService.db.Query(sqlListDeletedDocumentsMeta)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list deleted document meta: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var documents []*models.Document
|
||||
for rows.Next() {
|
||||
doc := &models.Document{IsDeleted: true}
|
||||
var isLocked int
|
||||
|
||||
err := rows.Scan(
|
||||
&doc.ID,
|
||||
&doc.Title,
|
||||
&doc.CreatedAt,
|
||||
&doc.UpdatedAt,
|
||||
&isLocked,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan document row: %w", err)
|
||||
}
|
||||
|
||||
doc.IsLocked = isLocked == 1
|
||||
documents = append(documents, doc)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating deleted document rows: %w", err)
|
||||
}
|
||||
|
||||
return documents, nil
|
||||
// ListAllDocumentsMeta 列出所有文档
|
||||
func (s *DocumentService) ListAllDocumentsMeta(ctx context.Context) ([]*ent.Document, error) {
|
||||
return s.db.Client.Document.Query().Select(document.FieldID, document.FieldTitle, document.FieldUpdatedAt, document.FieldLocked, document.FieldCreatedAt).
|
||||
Order(ent.Desc(document.FieldUpdatedAt)).
|
||||
All(ctx)
|
||||
}
|
||||
|
||||
@@ -2,380 +2,165 @@ package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
"voidraft/internal/models"
|
||||
|
||||
"voidraft/internal/models/ent"
|
||||
"voidraft/internal/models/ent/extension"
|
||||
"voidraft/internal/models/schema/mixin"
|
||||
|
||||
"github.com/wailsapp/wails/v3/pkg/application"
|
||||
"github.com/wailsapp/wails/v3/pkg/services/log"
|
||||
)
|
||||
|
||||
// SQL constants for extension operations
|
||||
const (
|
||||
// Extension operations
|
||||
sqlGetAllExtensions = `
|
||||
SELECT id, enabled, is_default, config
|
||||
FROM extensions
|
||||
ORDER BY id`
|
||||
|
||||
sqlGetExtensionByID = `
|
||||
SELECT id, enabled, is_default, config
|
||||
FROM extensions
|
||||
WHERE id = ?`
|
||||
|
||||
sqlInsertExtension = `
|
||||
INSERT INTO extensions (id, enabled, is_default, config, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`
|
||||
|
||||
sqlUpdateExtension = `
|
||||
UPDATE extensions
|
||||
SET enabled = ?, config = ?, updated_at = ?
|
||||
WHERE id = ?`
|
||||
|
||||
sqlDeleteAllExtensions = `DELETE FROM extensions`
|
||||
)
|
||||
|
||||
// ExtensionService 扩展管理服务
|
||||
// ExtensionService 扩展服务
|
||||
type ExtensionService struct {
|
||||
databaseService *DatabaseService
|
||||
logger *log.LogService
|
||||
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
initOnce sync.Once
|
||||
db *DatabaseService
|
||||
logger *log.LogService
|
||||
}
|
||||
|
||||
// ExtensionError 扩展错误
|
||||
type ExtensionError struct {
|
||||
Operation string
|
||||
Extension string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *ExtensionError) Error() string {
|
||||
if e.Extension != "" {
|
||||
return fmt.Sprintf("extension %s for %s: %v", e.Operation, e.Extension, e.Err)
|
||||
}
|
||||
return fmt.Sprintf("extension %s: %v", e.Operation, e.Err)
|
||||
}
|
||||
|
||||
func (e *ExtensionError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
func (e *ExtensionError) Is(target error) bool {
|
||||
var extensionError *ExtensionError
|
||||
return errors.As(target, &extensionError)
|
||||
}
|
||||
|
||||
// NewExtensionService 创建扩展服务实例
|
||||
func NewExtensionService(databaseService *DatabaseService, logger *log.LogService) *ExtensionService {
|
||||
// NewExtensionService 创建扩展服务
|
||||
func NewExtensionService(db *DatabaseService, logger *log.LogService) *ExtensionService {
|
||||
if logger == nil {
|
||||
logger = log.New()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
service := &ExtensionService{
|
||||
databaseService: databaseService,
|
||||
logger: logger,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
return service
|
||||
return &ExtensionService{db: db, logger: logger}
|
||||
}
|
||||
|
||||
// initialize 初始化配置
|
||||
func (es *ExtensionService) initialize() {
|
||||
es.initOnce.Do(func() {
|
||||
if err := es.initDatabase(); err != nil {
|
||||
es.logger.Error("failed to initialize extension database", "error", err)
|
||||
}
|
||||
})
|
||||
// ServiceStartup 服务启动
|
||||
func (s *ExtensionService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
|
||||
return s.SyncExtensions(ctx)
|
||||
}
|
||||
|
||||
// initDatabase 初始化数据库数据
|
||||
func (es *ExtensionService) initDatabase() error {
|
||||
es.mu.Lock()
|
||||
defer es.mu.Unlock()
|
||||
|
||||
if es.databaseService == nil || es.databaseService.db == nil {
|
||||
return &ExtensionError{"check_db", "", errors.New("database service not available")}
|
||||
// SyncExtensions 同步扩展配置
|
||||
func (s *ExtensionService) SyncExtensions(ctx context.Context) error {
|
||||
defaults := models.NewDefaultExtensions()
|
||||
definedKeys := make(map[models.ExtensionKey]models.Extension)
|
||||
for _, ext := range defaults {
|
||||
definedKeys[ext.Key] = ext
|
||||
}
|
||||
|
||||
// 检查是否已有扩展数据
|
||||
var count int64
|
||||
err := es.databaseService.db.QueryRow("SELECT COUNT(*) FROM extensions").Scan(&count)
|
||||
// 获取数据库中已有的扩展
|
||||
existing, err := s.db.Client.Extension.Query().All(ctx)
|
||||
if err != nil {
|
||||
return &ExtensionError{"check_extensions_count", "", err}
|
||||
return fmt.Errorf("find extensions error: %w", err)
|
||||
}
|
||||
|
||||
// 如果没有数据,插入默认配置
|
||||
if count == 0 {
|
||||
if err := es.insertDefaultExtensions(); err != nil {
|
||||
es.logger.Error("Failed to insert default extensions", "error", err)
|
||||
return err
|
||||
existingKeys := make(map[string]bool)
|
||||
for _, ext := range existing {
|
||||
existingKeys[ext.Key] = true
|
||||
}
|
||||
|
||||
// 批量添加缺失的扩展
|
||||
var builders []*ent.ExtensionCreate
|
||||
for key, ext := range definedKeys {
|
||||
if !existingKeys[string(key)] {
|
||||
builders = append(builders, s.db.Client.Extension.Create().
|
||||
SetKey(string(ext.Key)).
|
||||
SetEnabled(ext.Enabled).
|
||||
SetConfig(ext.Config))
|
||||
}
|
||||
} else {
|
||||
// 检查并补充缺失的扩展
|
||||
if err := es.syncExtensions(); err != nil {
|
||||
es.logger.Error("Failed to ensure all extensions exist", "error", err)
|
||||
return err
|
||||
}
|
||||
if len(builders) > 0 {
|
||||
if _, err := s.db.Client.Extension.CreateBulk(builders...).Save(ctx); err != nil {
|
||||
return fmt.Errorf("bulk insert extensions error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 批量删除废弃的扩展
|
||||
var deleteIDs []int
|
||||
for _, ext := range existing {
|
||||
if _, ok := definedKeys[models.ExtensionKey(ext.Key)]; !ok {
|
||||
deleteIDs = append(deleteIDs, ext.ID)
|
||||
}
|
||||
}
|
||||
if len(deleteIDs) > 0 {
|
||||
if _, err := s.db.Client.Extension.Delete().
|
||||
Where(extension.IDIn(deleteIDs...)).
|
||||
Exec(mixin.SkipSoftDelete(ctx)); err != nil {
|
||||
return fmt.Errorf("bulk delete extensions error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// insertDefaultExtensions 插入默认扩展配置
|
||||
func (es *ExtensionService) insertDefaultExtensions() error {
|
||||
defaultSettings := models.NewDefaultExtensionSettings()
|
||||
now := time.Now().Format("2006-01-02 15:04:05")
|
||||
|
||||
for _, ext := range defaultSettings.Extensions {
|
||||
configJSON, err := json.Marshal(ext.Config)
|
||||
if err != nil {
|
||||
return &ExtensionError{"marshal_config", string(ext.ID), err}
|
||||
}
|
||||
|
||||
_, err = es.databaseService.db.Exec(sqlInsertExtension,
|
||||
string(ext.ID),
|
||||
ext.Enabled,
|
||||
ext.IsDefault,
|
||||
string(configJSON),
|
||||
now,
|
||||
now,
|
||||
)
|
||||
if err != nil {
|
||||
return &ExtensionError{"insert_extension", string(ext.ID), err}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
// GetAllExtensions 获取所有扩展
|
||||
func (s *ExtensionService) GetAllExtensions(ctx context.Context) ([]*ent.Extension, error) {
|
||||
return s.db.Client.Extension.Query().All(ctx)
|
||||
}
|
||||
|
||||
// syncExtensions 确保数据库中的扩展与代码定义同步
|
||||
func (es *ExtensionService) syncExtensions() error {
|
||||
defaultSettings := models.NewDefaultExtensionSettings()
|
||||
now := time.Now().Format("2006-01-02 15:04:05")
|
||||
|
||||
// 构建代码中定义的扩展ID集合
|
||||
definedExtensions := make(map[string]bool)
|
||||
for _, ext := range defaultSettings.Extensions {
|
||||
definedExtensions[string(ext.ID)] = true
|
||||
}
|
||||
|
||||
// 1. 添加缺失的扩展
|
||||
for _, ext := range defaultSettings.Extensions {
|
||||
var exists int
|
||||
err := es.databaseService.db.QueryRow("SELECT COUNT(*) FROM extensions WHERE id = ?", string(ext.ID)).Scan(&exists)
|
||||
if err != nil {
|
||||
return &ExtensionError{"check_extension_exists", string(ext.ID), err}
|
||||
}
|
||||
|
||||
if exists == 0 {
|
||||
configJSON, err := json.Marshal(ext.Config)
|
||||
if err != nil {
|
||||
return &ExtensionError{"marshal_config", string(ext.ID), err}
|
||||
}
|
||||
|
||||
_, err = es.databaseService.db.Exec(sqlInsertExtension,
|
||||
string(ext.ID),
|
||||
ext.Enabled,
|
||||
ext.IsDefault,
|
||||
string(configJSON),
|
||||
now,
|
||||
now,
|
||||
)
|
||||
if err != nil {
|
||||
return &ExtensionError{"insert_missing_extension", string(ext.ID), err}
|
||||
}
|
||||
es.logger.Info("Added missing extension to database", "id", ext.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 删除数据库中已不存在于代码定义的扩展
|
||||
rows, err := es.databaseService.db.Query("SELECT id FROM extensions")
|
||||
// GetExtensionByKey 根据Key获取扩展
|
||||
func (s *ExtensionService) GetExtensionByKey(ctx context.Context, key string) (*ent.Extension, error) {
|
||||
ext, err := s.db.Client.Extension.Query().
|
||||
Where(extension.Key(key)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return &ExtensionError{"query_all_extension_ids", "", err}
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var toDelete []string
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return &ExtensionError{"scan_extension_id", "", err}
|
||||
}
|
||||
if !definedExtensions[id] {
|
||||
toDelete = append(toDelete, id)
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("get extension error: %w", err)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return &ExtensionError{"iterate_extension_ids", "", err}
|
||||
}
|
||||
|
||||
// 删除不再定义的扩展
|
||||
for _, id := range toDelete {
|
||||
_, err := es.databaseService.db.Exec("DELETE FROM extensions WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return &ExtensionError{"delete_obsolete_extension", id, err}
|
||||
}
|
||||
es.logger.Info("Removed obsolete extension from database", "id", id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ServiceStartup 启动时调用
|
||||
func (es *ExtensionService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
|
||||
es.ctx = ctx
|
||||
|
||||
// 初始化数据库
|
||||
var initErr error
|
||||
es.initOnce.Do(func() {
|
||||
if err := es.initDatabase(); err != nil {
|
||||
es.logger.Error("failed to initialize extension database", "error", err)
|
||||
initErr = err
|
||||
}
|
||||
})
|
||||
return initErr
|
||||
}
|
||||
|
||||
// GetAllExtensions 获取所有扩展配置
|
||||
func (es *ExtensionService) GetAllExtensions() ([]models.Extension, error) {
|
||||
es.mu.RLock()
|
||||
defer es.mu.RUnlock()
|
||||
|
||||
if es.databaseService == nil || es.databaseService.db == nil {
|
||||
return nil, &ExtensionError{"query_db", "", errors.New("database service not available")}
|
||||
}
|
||||
|
||||
rows, err := es.databaseService.db.Query(sqlGetAllExtensions)
|
||||
if err != nil {
|
||||
return nil, &ExtensionError{"query_extensions", "", err}
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var extensions []models.Extension
|
||||
for rows.Next() {
|
||||
var ext models.Extension
|
||||
var id string
|
||||
var configJSON string
|
||||
var enabled, isDefault int
|
||||
|
||||
err := rows.Scan(
|
||||
&id,
|
||||
&enabled,
|
||||
&isDefault,
|
||||
&configJSON,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, &ExtensionError{"scan_extension", "", err}
|
||||
}
|
||||
|
||||
ext.ID = models.ExtensionID(id)
|
||||
ext.Enabled = enabled == 1
|
||||
ext.IsDefault = isDefault == 1
|
||||
|
||||
var config models.ExtensionConfig
|
||||
if err := json.Unmarshal([]byte(configJSON), &config); err != nil {
|
||||
return nil, &ExtensionError{"unmarshal_config", id, err}
|
||||
}
|
||||
ext.Config = config
|
||||
|
||||
extensions = append(extensions, ext)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, &ExtensionError{"iterate_extensions", "", err}
|
||||
}
|
||||
|
||||
return extensions, nil
|
||||
return ext, nil
|
||||
}
|
||||
|
||||
// UpdateExtensionEnabled 更新扩展启用状态
|
||||
func (es *ExtensionService) UpdateExtensionEnabled(id models.ExtensionID, enabled bool) error {
|
||||
return es.UpdateExtensionState(id, enabled, nil)
|
||||
}
|
||||
|
||||
// UpdateExtensionState 更新扩展状态
|
||||
func (es *ExtensionService) UpdateExtensionState(id models.ExtensionID, enabled bool, config models.ExtensionConfig) error {
|
||||
es.mu.Lock()
|
||||
defer es.mu.Unlock()
|
||||
|
||||
if es.databaseService == nil || es.databaseService.db == nil {
|
||||
return &ExtensionError{"check_db", string(id), errors.New("database service not available")}
|
||||
}
|
||||
|
||||
var configJSON []byte
|
||||
var err error
|
||||
|
||||
if config != nil {
|
||||
configJSON, err = json.Marshal(config)
|
||||
if err != nil {
|
||||
return &ExtensionError{"marshal_config", string(id), err}
|
||||
}
|
||||
} else {
|
||||
// 如果没有提供配置,保持原有配置
|
||||
var currentConfigJSON string
|
||||
err := es.databaseService.db.QueryRow("SELECT config FROM extensions WHERE id = ?", string(id)).Scan(¤tConfigJSON)
|
||||
if err != nil {
|
||||
return &ExtensionError{"query_current_config", string(id), err}
|
||||
}
|
||||
configJSON = []byte(currentConfigJSON)
|
||||
}
|
||||
|
||||
_, err = es.databaseService.db.Exec(sqlUpdateExtension,
|
||||
enabled,
|
||||
string(configJSON),
|
||||
time.Now().Format("2006-01-02 15:04:05"),
|
||||
string(id))
|
||||
func (s *ExtensionService) UpdateExtensionEnabled(ctx context.Context, key string, enabled bool) error {
|
||||
ext, err := s.GetExtensionByKey(ctx, key)
|
||||
if err != nil {
|
||||
return &ExtensionError{"update_extension", string(id), err}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetExtensionToDefault 重置扩展到默认状态
|
||||
func (es *ExtensionService) ResetExtensionToDefault(id models.ExtensionID) error {
|
||||
// 获取默认配置
|
||||
defaultSettings := models.NewDefaultExtensionSettings()
|
||||
defaultExtension := defaultSettings.GetExtensionByID(id)
|
||||
if defaultExtension == nil {
|
||||
return &ExtensionError{"default_extension_not_found", string(id), nil}
|
||||
}
|
||||
|
||||
return es.UpdateExtensionState(id, defaultExtension.Enabled, defaultExtension.Config)
|
||||
}
|
||||
|
||||
// ResetAllExtensionsToDefault 重置所有扩展到默认状态
|
||||
func (es *ExtensionService) ResetAllExtensionsToDefault() error {
|
||||
es.mu.Lock()
|
||||
defer es.mu.Unlock()
|
||||
|
||||
if es.databaseService == nil || es.databaseService.db == nil {
|
||||
return &ExtensionError{"check_db", "", errors.New("database service not available")}
|
||||
}
|
||||
|
||||
// 删除所有现有扩展
|
||||
_, err := es.databaseService.db.Exec(sqlDeleteAllExtensions)
|
||||
if err != nil {
|
||||
return &ExtensionError{"delete_all_extensions", "", err}
|
||||
}
|
||||
|
||||
// 插入默认扩展配置
|
||||
if err := es.insertDefaultExtensions(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
if ext == nil {
|
||||
return fmt.Errorf("extension not found: %s", key)
|
||||
}
|
||||
return s.db.Client.Extension.UpdateOneID(ext.ID).
|
||||
SetEnabled(enabled).
|
||||
Exec(ctx)
|
||||
}
|
||||
|
||||
// UpdateExtensionConfig 更新扩展配置
|
||||
func (s *ExtensionService) UpdateExtensionConfig(ctx context.Context, key string, config map[string]interface{}) error {
|
||||
ext, err := s.GetExtensionByKey(ctx, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ext == nil {
|
||||
return fmt.Errorf("extension not found: %s", key)
|
||||
}
|
||||
return s.db.Client.Extension.UpdateOneID(ext.ID).
|
||||
SetConfig(config).
|
||||
Exec(ctx)
|
||||
}
|
||||
|
||||
// ResetExtensionConfig 重置单个扩展到默认状态
|
||||
func (s *ExtensionService) ResetExtensionConfig(ctx context.Context, key string) error {
|
||||
defaults := models.NewDefaultExtensions()
|
||||
var defaultExt *models.Extension
|
||||
for _, ext := range defaults {
|
||||
if string(ext.Key) == key {
|
||||
defaultExt = &ext
|
||||
break
|
||||
}
|
||||
}
|
||||
if defaultExt == nil {
|
||||
return fmt.Errorf("default extension not found: %s", key)
|
||||
}
|
||||
|
||||
ext, err := s.GetExtensionByKey(ctx, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ext == nil {
|
||||
return fmt.Errorf("extension not found: %s", key)
|
||||
}
|
||||
|
||||
return s.db.Client.Extension.UpdateOneID(ext.ID).
|
||||
SetEnabled(defaultExt.Enabled).
|
||||
SetConfig(defaultExt.Config).
|
||||
Exec(ctx)
|
||||
}
|
||||
|
||||
// GetDefaultExtensions 获取默认扩展配置(用于前端绑定生成)
|
||||
func (s *ExtensionService) GetDefaultExtensions() []models.Extension {
|
||||
return models.NewDefaultExtensions()
|
||||
}
|
||||
|
||||
@@ -2,212 +2,141 @@ package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
"voidraft/internal/models"
|
||||
|
||||
"voidraft/internal/models/ent"
|
||||
"voidraft/internal/models/ent/keybinding"
|
||||
"voidraft/internal/models/schema/mixin"
|
||||
|
||||
"github.com/wailsapp/wails/v3/pkg/application"
|
||||
"github.com/wailsapp/wails/v3/pkg/services/log"
|
||||
)
|
||||
|
||||
// SQL 查询语句
|
||||
const (
|
||||
// 快捷键操作
|
||||
sqlGetAllKeyBindings = `
|
||||
SELECT command, extension, key, enabled, is_default
|
||||
FROM key_bindings
|
||||
ORDER BY command
|
||||
`
|
||||
|
||||
sqlGetKeyBindingByCommand = `
|
||||
SELECT command, extension, key, enabled, is_default
|
||||
FROM key_bindings
|
||||
WHERE command = ?
|
||||
`
|
||||
|
||||
sqlInsertKeyBinding = `
|
||||
INSERT INTO key_bindings (command, extension, key, enabled, is_default, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
sqlUpdateKeyBinding = `
|
||||
UPDATE key_bindings
|
||||
SET extension = ?, key = ?, enabled = ?, updated_at = ?
|
||||
WHERE command = ?
|
||||
`
|
||||
|
||||
sqlDeleteKeyBinding = `
|
||||
DELETE FROM key_bindings
|
||||
WHERE command = ?
|
||||
`
|
||||
|
||||
sqlDeleteAllKeyBindings = `
|
||||
DELETE FROM key_bindings
|
||||
`
|
||||
)
|
||||
|
||||
// KeyBindingService 快捷键管理服务
|
||||
// KeyBindingService 快捷键服务
|
||||
type KeyBindingService struct {
|
||||
databaseService *DatabaseService
|
||||
logger *log.LogService
|
||||
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
initOnce sync.Once
|
||||
db *DatabaseService
|
||||
logger *log.LogService
|
||||
}
|
||||
|
||||
// KeyBindingError 快捷键错误
|
||||
type KeyBindingError struct {
|
||||
Operation string
|
||||
Command string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *KeyBindingError) Error() string {
|
||||
if e.Command != "" {
|
||||
return fmt.Sprintf("keybinding %s for %s: %v", e.Operation, e.Command, e.Err)
|
||||
}
|
||||
return fmt.Sprintf("keybinding %s: %v", e.Operation, e.Err)
|
||||
}
|
||||
|
||||
func (e *KeyBindingError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
func (e *KeyBindingError) Is(target error) bool {
|
||||
var keyBindingError *KeyBindingError
|
||||
return errors.As(target, &keyBindingError)
|
||||
}
|
||||
|
||||
// NewKeyBindingService 创建快捷键服务实例
|
||||
func NewKeyBindingService(databaseService *DatabaseService, logger *log.LogService) *KeyBindingService {
|
||||
// NewKeyBindingService 创建快捷键服务
|
||||
func NewKeyBindingService(db *DatabaseService, logger *log.LogService) *KeyBindingService {
|
||||
if logger == nil {
|
||||
logger = log.New()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
service := &KeyBindingService{
|
||||
databaseService: databaseService,
|
||||
logger: logger,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
return service
|
||||
return &KeyBindingService{db: db, logger: logger}
|
||||
}
|
||||
|
||||
// initDatabase 初始化数据库数据
|
||||
func (kbs *KeyBindingService) initDatabase() error {
|
||||
kbs.mu.Lock()
|
||||
defer kbs.mu.Unlock()
|
||||
// ServiceStartup 服务启动
|
||||
func (s *KeyBindingService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
|
||||
return s.SyncKeyBindings(ctx)
|
||||
}
|
||||
|
||||
if kbs.databaseService == nil || kbs.databaseService.db == nil {
|
||||
return &KeyBindingError{"check_db", "", errors.New("database service not available")}
|
||||
// SyncKeyBindings 同步快捷键配置
|
||||
func (s *KeyBindingService) SyncKeyBindings(ctx context.Context) error {
|
||||
defaults := models.NewDefaultKeyBindings()
|
||||
definedKeys := make(map[models.KeyBindingKey]models.KeyBinding)
|
||||
for _, kb := range defaults {
|
||||
definedKeys[kb.Key] = kb
|
||||
}
|
||||
|
||||
// 检查是否已有快捷键数据
|
||||
var count int64
|
||||
err := kbs.databaseService.db.QueryRow("SELECT COUNT(*) FROM key_bindings").Scan(&count)
|
||||
// 获取数据库中已有的快捷键
|
||||
existing, err := s.db.Client.KeyBinding.Query().All(ctx)
|
||||
if err != nil {
|
||||
return &KeyBindingError{"check_keybindings_count", "", err}
|
||||
return fmt.Errorf("find key bindings error: %w", err)
|
||||
}
|
||||
|
||||
// 如果没有数据,插入默认配置
|
||||
if count == 0 {
|
||||
if err := kbs.insertDefaultKeyBindings(); err != nil {
|
||||
kbs.logger.Error("Failed to insert default key bindings", "error", err)
|
||||
return err
|
||||
existingKeys := make(map[string]bool)
|
||||
for _, kb := range existing {
|
||||
existingKeys[kb.Key] = true
|
||||
}
|
||||
|
||||
// 批量添加缺失的快捷键
|
||||
var builders []*ent.KeyBindingCreate
|
||||
for key, kb := range definedKeys {
|
||||
if !existingKeys[string(key)] {
|
||||
create := s.db.Client.KeyBinding.Create().
|
||||
SetKey(string(kb.Key)).
|
||||
SetCommand(kb.Command).
|
||||
SetEnabled(kb.Enabled)
|
||||
if kb.Extension != "" {
|
||||
create.SetExtension(string(kb.Extension))
|
||||
}
|
||||
builders = append(builders, create)
|
||||
}
|
||||
}
|
||||
if len(builders) > 0 {
|
||||
if _, err := s.db.Client.KeyBinding.CreateBulk(builders...).Save(ctx); err != nil {
|
||||
return fmt.Errorf("bulk insert key bindings error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 批量删除废弃的快捷键(硬删除)
|
||||
var deleteIDs []int
|
||||
for _, kb := range existing {
|
||||
if _, ok := definedKeys[models.KeyBindingKey(kb.Key)]; !ok {
|
||||
deleteIDs = append(deleteIDs, kb.ID)
|
||||
}
|
||||
}
|
||||
if len(deleteIDs) > 0 {
|
||||
if _, err := s.db.Client.KeyBinding.Delete().
|
||||
Where(keybinding.IDIn(deleteIDs...)).
|
||||
Exec(mixin.SkipSoftDelete(ctx)); err != nil {
|
||||
return fmt.Errorf("bulk delete key bindings error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// insertDefaultKeyBindings 插入默认快捷键配置
|
||||
func (kbs *KeyBindingService) insertDefaultKeyBindings() error {
|
||||
defaultConfig := models.NewDefaultKeyBindingConfig()
|
||||
now := time.Now().Format("2006-01-02 15:04:05")
|
||||
|
||||
for _, kb := range defaultConfig.KeyBindings {
|
||||
_, err := kbs.databaseService.db.Exec(sqlInsertKeyBinding,
|
||||
string(kb.Command), // 转换为字符串存储
|
||||
string(kb.Extension), // 转换为字符串存储
|
||||
kb.Key,
|
||||
kb.Enabled,
|
||||
kb.IsDefault,
|
||||
now,
|
||||
now,
|
||||
)
|
||||
if err != nil {
|
||||
return &KeyBindingError{"insert_keybinding", string(kb.Command), err}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
// GetAllKeyBindings 获取所有快捷键
|
||||
func (s *KeyBindingService) GetAllKeyBindings(ctx context.Context) ([]*ent.KeyBinding, error) {
|
||||
return s.db.Client.KeyBinding.Query().All(ctx)
|
||||
}
|
||||
|
||||
// GetAllKeyBindings 获取所有快捷键配置
|
||||
func (kbs *KeyBindingService) GetAllKeyBindings() ([]models.KeyBinding, error) {
|
||||
kbs.mu.RLock()
|
||||
defer kbs.mu.RUnlock()
|
||||
|
||||
if kbs.databaseService == nil || kbs.databaseService.db == nil {
|
||||
return nil, &KeyBindingError{"query_db", "", errors.New("database service not available")}
|
||||
}
|
||||
|
||||
rows, err := kbs.databaseService.db.Query(sqlGetAllKeyBindings)
|
||||
// GetKeyBindingByKey 根据Key获取快捷键
|
||||
func (s *KeyBindingService) GetKeyBindingByKey(ctx context.Context, key string) (*ent.KeyBinding, error) {
|
||||
kb, err := s.db.Client.KeyBinding.Query().
|
||||
Where(keybinding.Key(key)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, &KeyBindingError{"query_keybindings", "", err}
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var keyBindings []models.KeyBinding
|
||||
for rows.Next() {
|
||||
var kb models.KeyBinding
|
||||
var command, extension string
|
||||
var enabled, isDefault int
|
||||
|
||||
err := rows.Scan(
|
||||
&command,
|
||||
&extension,
|
||||
&kb.Key,
|
||||
&enabled,
|
||||
&isDefault,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, &KeyBindingError{"scan_keybinding", "", err}
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
kb.Command = models.KeyBindingCommand(command)
|
||||
kb.Extension = models.ExtensionID(extension)
|
||||
kb.Enabled = enabled == 1
|
||||
kb.IsDefault = isDefault == 1
|
||||
|
||||
keyBindings = append(keyBindings, kb)
|
||||
return nil, fmt.Errorf("get key binding error: %w", err)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, &KeyBindingError{"iterate_keybindings", "", err}
|
||||
}
|
||||
|
||||
return keyBindings, nil
|
||||
return kb, nil
|
||||
}
|
||||
|
||||
// ServiceStartup 启动时调用
|
||||
func (kbs *KeyBindingService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
|
||||
kbs.ctx = ctx
|
||||
// 初始化数据库
|
||||
var initErr error
|
||||
kbs.initOnce.Do(func() {
|
||||
if err := kbs.initDatabase(); err != nil {
|
||||
kbs.logger.Error("failed to initialize keybinding database", "error", err)
|
||||
initErr = err
|
||||
}
|
||||
})
|
||||
return initErr
|
||||
// UpdateKeyBindingCommand 更新快捷键命令
|
||||
func (s *KeyBindingService) UpdateKeyBindingCommand(ctx context.Context, key string, command string) error {
|
||||
kb, err := s.GetKeyBindingByKey(ctx, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if kb == nil {
|
||||
return fmt.Errorf("key binding not found: %s", key)
|
||||
}
|
||||
return s.db.Client.KeyBinding.UpdateOneID(kb.ID).
|
||||
SetCommand(command).
|
||||
Exec(ctx)
|
||||
}
|
||||
|
||||
// UpdateKeyBindingEnabled 更新快捷键启用状态
|
||||
func (s *KeyBindingService) UpdateKeyBindingEnabled(ctx context.Context, key string, enabled bool) error {
|
||||
kb, err := s.GetKeyBindingByKey(ctx, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if kb == nil {
|
||||
return fmt.Errorf("key binding not found: %s", key)
|
||||
}
|
||||
return s.db.Client.KeyBinding.UpdateOneID(kb.ID).
|
||||
SetEnabled(enabled).
|
||||
Exec(ctx)
|
||||
}
|
||||
|
||||
// GetDefaultKeyBindings 获取默认快捷键配置
|
||||
func (s *KeyBindingService) GetDefaultKeyBindings() []models.KeyBinding {
|
||||
return models.NewDefaultKeyBindings()
|
||||
}
|
||||
|
||||
@@ -140,7 +140,6 @@ func (sm *ServiceManager) GetServices() []application.Service {
|
||||
application.NewService(sm.systemService),
|
||||
application.NewService(sm.hotkeyService),
|
||||
application.NewService(sm.dialogService),
|
||||
application.NewService(sm.trayService),
|
||||
application.NewService(sm.startupService),
|
||||
application.NewService(sm.selfUpdateService),
|
||||
application.NewService(sm.translationService),
|
||||
|
||||
@@ -2,12 +2,12 @@ package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
"voidraft/internal/models"
|
||||
|
||||
"voidraft/internal/models/ent"
|
||||
"voidraft/internal/models/ent/theme"
|
||||
"voidraft/internal/models/schema/mixin"
|
||||
|
||||
"github.com/wailsapp/wails/v3/pkg/application"
|
||||
"github.com/wailsapp/wails/v3/pkg/services/log"
|
||||
@@ -15,152 +15,90 @@ import (
|
||||
|
||||
// ThemeService 主题服务
|
||||
type ThemeService struct {
|
||||
databaseService *DatabaseService
|
||||
logger *log.LogService
|
||||
ctx context.Context
|
||||
db *DatabaseService
|
||||
logger *log.LogService
|
||||
}
|
||||
|
||||
// NewThemeService 创建新的主题服务
|
||||
func NewThemeService(databaseService *DatabaseService, logger *log.LogService) *ThemeService {
|
||||
// NewThemeService 创建主题服务
|
||||
func NewThemeService(db *DatabaseService, logger *log.LogService) *ThemeService {
|
||||
if logger == nil {
|
||||
logger = log.New()
|
||||
}
|
||||
|
||||
return &ThemeService{
|
||||
databaseService: databaseService,
|
||||
logger: logger,
|
||||
}
|
||||
return &ThemeService{db: db, logger: logger}
|
||||
}
|
||||
|
||||
// ServiceStartup 服务启动
|
||||
func (ts *ThemeService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
|
||||
ts.ctx = ctx
|
||||
func (s *ThemeService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// getDB 获取数据库连接
|
||||
func (ts *ThemeService) getDB() *sql.DB {
|
||||
return ts.databaseService.db
|
||||
}
|
||||
|
||||
// GetThemeByName 通过名称获取主题覆盖,若不存在则返回 nil
|
||||
func (ts *ThemeService) GetThemeByName(name string) (*models.Theme, error) {
|
||||
db := ts.getDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("database not available")
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(name)
|
||||
// GetThemeByKey 根据Key获取主题
|
||||
func (s *ThemeService) GetThemeByKey(ctx context.Context, key string) (*ent.Theme, error) {
|
||||
trimmed := strings.TrimSpace(key)
|
||||
if trimmed == "" {
|
||||
return nil, fmt.Errorf("theme name cannot be empty")
|
||||
return nil, fmt.Errorf("theme key cannot be empty")
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT id, name, type, colors, is_default, created_at, updated_at
|
||||
FROM themes
|
||||
WHERE name = ?
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
theme := &models.Theme{}
|
||||
err := db.QueryRow(query, trimmed).Scan(
|
||||
&theme.ID,
|
||||
&theme.Name,
|
||||
&theme.Type,
|
||||
&theme.Colors,
|
||||
&theme.IsDefault,
|
||||
&theme.CreatedAt,
|
||||
&theme.UpdatedAt,
|
||||
)
|
||||
|
||||
t, err := s.db.Client.Theme.Query().
|
||||
Where(theme.Key(trimmed)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to query theme: %w", err)
|
||||
return nil, fmt.Errorf("get theme error: %w", err)
|
||||
}
|
||||
|
||||
return theme, nil
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// UpdateTheme 保存或更新主题覆盖
|
||||
func (ts *ThemeService) UpdateTheme(name string, colors models.ThemeColorConfig) error {
|
||||
db := ts.getDB()
|
||||
if db == nil {
|
||||
return fmt.Errorf("database not available")
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(name)
|
||||
// UpdateTheme 保存或更新主题
|
||||
func (s *ThemeService) UpdateTheme(ctx context.Context, key string, colors map[string]interface{}) error {
|
||||
trimmed := strings.TrimSpace(key)
|
||||
if trimmed == "" {
|
||||
return fmt.Errorf("theme name cannot be empty")
|
||||
return fmt.Errorf("theme key cannot be empty")
|
||||
}
|
||||
|
||||
if colors == nil {
|
||||
colors = models.ThemeColorConfig{}
|
||||
colors = map[string]interface{}{}
|
||||
}
|
||||
colors["themeName"] = trimmed
|
||||
|
||||
themeType := models.ThemeTypeDark
|
||||
// 判断主题类型
|
||||
themeType := theme.TypeDark
|
||||
if raw, ok := colors["dark"].(bool); ok && !raw {
|
||||
themeType = models.ThemeTypeLight
|
||||
themeType = theme.TypeLight
|
||||
}
|
||||
|
||||
now := time.Now().Format("2006-01-02 15:04:05")
|
||||
|
||||
existing, err := ts.GetThemeByName(trimmed)
|
||||
existing, err := s.GetThemeByKey(ctx, trimmed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
_, err = db.Exec(
|
||||
`INSERT INTO themes (name, type, colors, is_default, created_at, updated_at) VALUES (?, ?, ?, 0, ?, ?)`,
|
||||
trimmed,
|
||||
themeType,
|
||||
colors,
|
||||
now,
|
||||
now,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert theme: %w", err)
|
||||
}
|
||||
return nil
|
||||
// 插入新主题
|
||||
_, err = s.db.Client.Theme.Create().
|
||||
SetKey(trimmed).
|
||||
SetType(themeType).
|
||||
SetColors(colors).
|
||||
Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec(
|
||||
`UPDATE themes SET type = ?, colors = ?, updated_at = ? WHERE name = ?`,
|
||||
themeType,
|
||||
colors,
|
||||
now,
|
||||
trimmed,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update theme: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
// 更新现有主题
|
||||
return s.db.Client.Theme.UpdateOneID(existing.ID).
|
||||
SetType(themeType).
|
||||
SetColors(colors).
|
||||
Exec(ctx)
|
||||
}
|
||||
|
||||
// ResetTheme 删除指定主题的覆盖配置
|
||||
func (ts *ThemeService) ResetTheme(name string) error {
|
||||
db := ts.getDB()
|
||||
if db == nil {
|
||||
return fmt.Errorf("database not available")
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(name)
|
||||
// ResetTheme 删除主题
|
||||
func (s *ThemeService) ResetTheme(ctx context.Context, key string) error {
|
||||
trimmed := strings.TrimSpace(key)
|
||||
if trimmed == "" {
|
||||
return fmt.Errorf("theme name cannot be empty")
|
||||
return fmt.Errorf("theme key cannot be empty")
|
||||
}
|
||||
|
||||
if _, err := db.Exec(`DELETE FROM themes WHERE name = ?`, trimmed); err != nil {
|
||||
return fmt.Errorf("failed to reset theme: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ServiceShutdown 服务关闭
|
||||
func (ts *ThemeService) ServiceShutdown() error {
|
||||
return nil
|
||||
_, err := s.db.Client.Theme.Delete().
|
||||
Where(theme.Key(trimmed)).
|
||||
Exec(mixin.SkipSoftDelete(ctx))
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ func (ws *WindowService) OpenDocumentWindow(documentID int64) error {
|
||||
}
|
||||
|
||||
// 获取文档信息
|
||||
doc, err := ws.documentService.GetDocumentByID(documentID)
|
||||
doc, err := ws.documentService.GetDocumentByID(context.Background(), int(documentID))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get document: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"github.com/wailsapp/wails/v3/pkg/application"
|
||||
"github.com/wailsapp/wails/v3/pkg/events"
|
||||
"github.com/wailsapp/wails/v3/pkg/services/log"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
"voidraft/internal/common/helper"
|
||||
"voidraft/internal/models"
|
||||
|
||||
"github.com/wailsapp/wails/v3/pkg/application"
|
||||
"github.com/wailsapp/wails/v3/pkg/events"
|
||||
"github.com/wailsapp/wails/v3/pkg/services/log"
|
||||
)
|
||||
|
||||
// 防抖和检测常量
|
||||
@@ -22,6 +20,43 @@ const (
|
||||
dragDetectionThreshold = 40 * time.Millisecond
|
||||
)
|
||||
|
||||
// SnapEdge 表示吸附的边缘类型
|
||||
type SnapEdge int
|
||||
|
||||
const (
|
||||
SnapEdgeNone SnapEdge = iota // 未吸附
|
||||
SnapEdgeTop // 吸附到上边缘
|
||||
SnapEdgeRight // 吸附到右边缘
|
||||
SnapEdgeBottom // 吸附到下边缘
|
||||
SnapEdgeLeft // 吸附到左边缘
|
||||
SnapEdgeTopRight // 吸附到右上角
|
||||
SnapEdgeBottomRight // 吸附到右下角
|
||||
SnapEdgeBottomLeft // 吸附到左下角
|
||||
SnapEdgeTopLeft // 吸附到左上角
|
||||
)
|
||||
|
||||
// WindowPosition 窗口位置
|
||||
type WindowPosition struct {
|
||||
X int `json:"x"` // X坐标
|
||||
Y int `json:"y"` // Y坐标
|
||||
}
|
||||
|
||||
// SnapPosition 表示吸附的相对位置
|
||||
type SnapPosition struct {
|
||||
X int `json:"x"` // X轴相对偏移
|
||||
Y int `json:"y"` // Y轴相对偏移
|
||||
}
|
||||
|
||||
// WindowInfo 窗口信息
|
||||
type WindowInfo struct {
|
||||
DocumentID int64 `json:"documentID"` // 文档ID
|
||||
IsSnapped bool `json:"isSnapped"` // 是否处于吸附状态
|
||||
SnapOffset SnapPosition `json:"snapOffset"` // 与主窗口的相对位置偏移
|
||||
SnapEdge SnapEdge `json:"snapEdge"` // 吸附的边缘类型
|
||||
LastPos WindowPosition `json:"lastPos"` // 上一次记录的窗口位置
|
||||
MoveTime time.Time `json:"moveTime"` // 上次移动时间,用于判断移动速度
|
||||
}
|
||||
|
||||
// WindowSnapService 窗口吸附服务
|
||||
type WindowSnapService struct {
|
||||
logger *log.LogService
|
||||
@@ -38,11 +73,11 @@ type WindowSnapService struct {
|
||||
maxThreshold int // 最大阈值(像素)
|
||||
|
||||
// 位置缓存
|
||||
lastMainWindowPos models.WindowPosition // 缓存主窗口位置
|
||||
lastMainWindowSize [2]int // 缓存主窗口尺寸 [width, height]
|
||||
lastMainWindowPos WindowPosition // 缓存主窗口位置
|
||||
lastMainWindowSize [2]int // 缓存主窗口尺寸 [width, height]
|
||||
|
||||
// 管理的窗口
|
||||
managedWindows map[int64]*models.WindowInfo // documentID -> WindowInfo
|
||||
managedWindows map[int64]*WindowInfo // documentID -> WindowInfo
|
||||
windowRefs map[int64]*application.WebviewWindow // documentID -> Window引用
|
||||
|
||||
// 窗口尺寸缓存
|
||||
@@ -81,7 +116,7 @@ func NewWindowSnapService(logger *log.LogService, configService *ConfigService)
|
||||
baseThresholdRatio: 0.025, // 2.5%的主窗口宽度作为基础阈值
|
||||
minThreshold: 8, // 最小8像素(小屏幕保底)
|
||||
maxThreshold: 40, // 最大40像素(大屏幕上限)
|
||||
managedWindows: make(map[int64]*models.WindowInfo),
|
||||
managedWindows: make(map[int64]*WindowInfo),
|
||||
windowRefs: make(map[int64]*application.WebviewWindow),
|
||||
windowSizeCache: make(map[int64][2]int),
|
||||
isUpdatingPosition: make(map[int64]bool),
|
||||
@@ -114,12 +149,12 @@ func (wss *WindowSnapService) RegisterWindow(documentID int64, window *applicati
|
||||
// 获取初始位置
|
||||
x, y := window.Position()
|
||||
|
||||
windowInfo := &models.WindowInfo{
|
||||
windowInfo := &WindowInfo{
|
||||
DocumentID: documentID,
|
||||
IsSnapped: false,
|
||||
SnapOffset: models.SnapPosition{X: 0, Y: 0},
|
||||
SnapEdge: models.SnapEdgeNone,
|
||||
LastPos: models.WindowPosition{X: x, Y: y},
|
||||
SnapOffset: SnapPosition{X: 0, Y: 0},
|
||||
SnapEdge: SnapEdgeNone,
|
||||
LastPos: WindowPosition{X: x, Y: y},
|
||||
MoveTime: time.Now(),
|
||||
}
|
||||
|
||||
@@ -176,7 +211,7 @@ func (wss *WindowSnapService) SetSnapEnabled(enabled bool) {
|
||||
for _, windowInfo := range wss.managedWindows {
|
||||
if windowInfo.IsSnapped {
|
||||
windowInfo.IsSnapped = false
|
||||
windowInfo.SnapEdge = models.SnapEdgeNone
|
||||
windowInfo.SnapEdge = SnapEdgeNone
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -251,7 +286,7 @@ func (wss *WindowSnapService) cleanupMainWindowEvents() {
|
||||
}
|
||||
|
||||
// setupWindowEvents 为子窗口设置事件监听
|
||||
func (wss *WindowSnapService) setupWindowEvents(window *application.WebviewWindow, windowInfo *models.WindowInfo) {
|
||||
func (wss *WindowSnapService) setupWindowEvents(window *application.WebviewWindow, windowInfo *WindowInfo) {
|
||||
// 监听子窗口移动事件,保存清理函数
|
||||
unhook := window.RegisterHook(events.Common.WindowDidMove, func(event *application.WindowEvent) {
|
||||
wss.onChildWindowMoved(window, windowInfo)
|
||||
@@ -274,7 +309,7 @@ func (wss *WindowSnapService) updateMainWindowCacheLocked() {
|
||||
w, h := mainWindow.Size()
|
||||
wss.mu.Lock()
|
||||
|
||||
wss.lastMainWindowPos = models.WindowPosition{X: x, Y: y}
|
||||
wss.lastMainWindowPos = WindowPosition{X: x, Y: y}
|
||||
wss.lastMainWindowSize = [2]int{w, h}
|
||||
}
|
||||
|
||||
@@ -335,7 +370,7 @@ func (wss *WindowSnapService) onMainWindowMoved() {
|
||||
defer wss.mu.Unlock()
|
||||
|
||||
// 更新主窗口位置和尺寸缓存
|
||||
wss.lastMainWindowPos = models.WindowPosition{X: x, Y: y}
|
||||
wss.lastMainWindowPos = WindowPosition{X: x, Y: y}
|
||||
wss.lastMainWindowSize = [2]int{w, h}
|
||||
|
||||
// 只更新已吸附窗口的位置,无需重新检测所有窗口
|
||||
@@ -347,7 +382,7 @@ func (wss *WindowSnapService) onMainWindowMoved() {
|
||||
}
|
||||
|
||||
// onChildWindowMoved 子窗口移动事件处理
|
||||
func (wss *WindowSnapService) onChildWindowMoved(window *application.WebviewWindow, windowInfo *models.WindowInfo) {
|
||||
func (wss *WindowSnapService) onChildWindowMoved(window *application.WebviewWindow, windowInfo *WindowInfo) {
|
||||
if !wss.snapEnabled {
|
||||
return
|
||||
}
|
||||
@@ -361,7 +396,7 @@ func (wss *WindowSnapService) onChildWindowMoved(window *application.WebviewWind
|
||||
wss.mu.Unlock()
|
||||
|
||||
x, y := window.Position()
|
||||
currentPos := models.WindowPosition{X: x, Y: y}
|
||||
currentPos := WindowPosition{X: x, Y: y}
|
||||
|
||||
wss.mu.Lock()
|
||||
defer wss.mu.Unlock()
|
||||
@@ -392,7 +427,7 @@ func (wss *WindowSnapService) onChildWindowMoved(window *application.WebviewWind
|
||||
}
|
||||
|
||||
// updateSnappedWindowPosition 更新已吸附窗口的位置
|
||||
func (wss *WindowSnapService) updateSnappedWindowPosition(windowInfo *models.WindowInfo) {
|
||||
func (wss *WindowSnapService) updateSnappedWindowPosition(windowInfo *WindowInfo) {
|
||||
// 计算新的目标位置(基于主窗口新位置)
|
||||
expectedX := wss.lastMainWindowPos.X + windowInfo.SnapOffset.X
|
||||
expectedY := wss.lastMainWindowPos.Y + windowInfo.SnapOffset.Y
|
||||
@@ -413,11 +448,11 @@ func (wss *WindowSnapService) updateSnappedWindowPosition(windowInfo *models.Win
|
||||
// 清除更新标志
|
||||
wss.isUpdatingPosition[windowInfo.DocumentID] = false
|
||||
|
||||
windowInfo.LastPos = models.WindowPosition{X: expectedX, Y: expectedY}
|
||||
windowInfo.LastPos = WindowPosition{X: expectedX, Y: expectedY}
|
||||
}
|
||||
|
||||
// handleSnappedWindow 处理已吸附窗口的移动
|
||||
func (wss *WindowSnapService) handleSnappedWindow(window *application.WebviewWindow, windowInfo *models.WindowInfo, currentPos models.WindowPosition) {
|
||||
func (wss *WindowSnapService) handleSnappedWindow(window *application.WebviewWindow, windowInfo *WindowInfo, currentPos WindowPosition) {
|
||||
// 计算预期位置
|
||||
expectedX := wss.lastMainWindowPos.X + windowInfo.SnapOffset.X
|
||||
expectedY := wss.lastMainWindowPos.Y + windowInfo.SnapOffset.Y
|
||||
@@ -434,12 +469,12 @@ func (wss *WindowSnapService) handleSnappedWindow(window *application.WebviewWin
|
||||
if isUserDrag {
|
||||
// 用户主动拖拽,解除吸附
|
||||
windowInfo.IsSnapped = false
|
||||
windowInfo.SnapEdge = models.SnapEdgeNone
|
||||
windowInfo.SnapEdge = SnapEdgeNone
|
||||
}
|
||||
}
|
||||
|
||||
// handleUnsnappedWindow 处理未吸附窗口的移动,返回是否成功吸附
|
||||
func (wss *WindowSnapService) handleUnsnappedWindow(window *application.WebviewWindow, windowInfo *models.WindowInfo, currentPos models.WindowPosition, lastMoveTime time.Time) bool {
|
||||
func (wss *WindowSnapService) handleUnsnappedWindow(window *application.WebviewWindow, windowInfo *WindowInfo, currentPos WindowPosition, lastMoveTime time.Time) bool {
|
||||
// 检查是否应该吸附
|
||||
should, snapEdge := wss.shouldSnapToMainWindow(window, windowInfo, currentPos, lastMoveTime)
|
||||
if should {
|
||||
@@ -474,11 +509,11 @@ func (wss *WindowSnapService) handleUnsnappedWindow(window *application.WebviewW
|
||||
}
|
||||
|
||||
// shouldSnapToMainWindow 吸附检测
|
||||
func (wss *WindowSnapService) shouldSnapToMainWindow(window *application.WebviewWindow, windowInfo *models.WindowInfo, currentPos models.WindowPosition, lastMoveTime time.Time) (bool, models.SnapEdge) {
|
||||
func (wss *WindowSnapService) shouldSnapToMainWindow(window *application.WebviewWindow, windowInfo *WindowInfo, currentPos WindowPosition, lastMoveTime time.Time) (bool, SnapEdge) {
|
||||
// 防抖:移动太快时不检测(使用统一的防抖阈值)
|
||||
timeSinceLastMove := time.Since(lastMoveTime)
|
||||
if timeSinceLastMove < debounceThreshold {
|
||||
return false, models.SnapEdgeNone
|
||||
return false, SnapEdgeNone
|
||||
}
|
||||
|
||||
// 使用缓存的主窗口位置和尺寸
|
||||
@@ -507,7 +542,7 @@ func (wss *WindowSnapService) shouldSnapToMainWindow(window *application.Webview
|
||||
|
||||
// 简化的距离计算结构
|
||||
type snapCheck struct {
|
||||
edge models.SnapEdge
|
||||
edge SnapEdge
|
||||
distance float64
|
||||
priority int // 1=角落, 2=边缘
|
||||
}
|
||||
@@ -516,14 +551,14 @@ func (wss *WindowSnapService) shouldSnapToMainWindow(window *application.Webview
|
||||
|
||||
// 检查角落吸附(优先级1)
|
||||
cornerChecks := []struct {
|
||||
edge models.SnapEdge
|
||||
edge SnapEdge
|
||||
dx int
|
||||
dy int
|
||||
}{
|
||||
{models.SnapEdgeTopRight, mainRight - windowLeft, mainTop - windowBottom},
|
||||
{models.SnapEdgeBottomRight, mainRight - windowLeft, mainBottom - windowTop},
|
||||
{models.SnapEdgeBottomLeft, mainLeft - windowRight, mainBottom - windowTop},
|
||||
{models.SnapEdgeTopLeft, mainLeft - windowRight, mainTop - windowBottom},
|
||||
{SnapEdgeTopRight, mainRight - windowLeft, mainTop - windowBottom},
|
||||
{SnapEdgeBottomRight, mainRight - windowLeft, mainBottom - windowTop},
|
||||
{SnapEdgeBottomLeft, mainLeft - windowRight, mainBottom - windowTop},
|
||||
{SnapEdgeTopLeft, mainLeft - windowRight, mainTop - windowBottom},
|
||||
}
|
||||
|
||||
for _, check := range cornerChecks {
|
||||
@@ -538,13 +573,13 @@ func (wss *WindowSnapService) shouldSnapToMainWindow(window *application.Webview
|
||||
// 如果没有角落吸附,检查边缘吸附(优先级2)
|
||||
if bestSnap == nil {
|
||||
edgeChecks := []struct {
|
||||
edge models.SnapEdge
|
||||
edge SnapEdge
|
||||
distance float64
|
||||
}{
|
||||
{models.SnapEdgeRight, math.Abs(float64(mainRight - windowLeft))},
|
||||
{models.SnapEdgeLeft, math.Abs(float64(mainLeft - windowRight))},
|
||||
{models.SnapEdgeBottom, math.Abs(float64(mainBottom - windowTop))},
|
||||
{models.SnapEdgeTop, math.Abs(float64(mainTop - windowBottom))},
|
||||
{SnapEdgeRight, math.Abs(float64(mainRight - windowLeft))},
|
||||
{SnapEdgeLeft, math.Abs(float64(mainLeft - windowRight))},
|
||||
{SnapEdgeBottom, math.Abs(float64(mainBottom - windowTop))},
|
||||
{SnapEdgeTop, math.Abs(float64(mainTop - windowBottom))},
|
||||
}
|
||||
|
||||
for _, check := range edgeChecks {
|
||||
@@ -557,14 +592,14 @@ func (wss *WindowSnapService) shouldSnapToMainWindow(window *application.Webview
|
||||
}
|
||||
|
||||
if bestSnap == nil {
|
||||
return false, models.SnapEdgeNone
|
||||
return false, SnapEdgeNone
|
||||
}
|
||||
|
||||
return true, bestSnap.edge
|
||||
}
|
||||
|
||||
// calculateSnapPosition 计算吸附目标位置
|
||||
func (wss *WindowSnapService) calculateSnapPosition(snapEdge models.SnapEdge, currentPos models.WindowPosition, documentID int64, window *application.WebviewWindow) models.WindowPosition {
|
||||
func (wss *WindowSnapService) calculateSnapPosition(snapEdge SnapEdge, currentPos WindowPosition, documentID int64, window *application.WebviewWindow) WindowPosition {
|
||||
// 使用缓存的主窗口信息
|
||||
mainPos := wss.lastMainWindowPos
|
||||
mainWidth := wss.lastMainWindowSize[0]
|
||||
@@ -574,43 +609,43 @@ func (wss *WindowSnapService) calculateSnapPosition(snapEdge models.SnapEdge, cu
|
||||
windowWidth, windowHeight := wss.getWindowSizeCached(documentID, window)
|
||||
|
||||
switch snapEdge {
|
||||
case models.SnapEdgeRight:
|
||||
return models.WindowPosition{
|
||||
case SnapEdgeRight:
|
||||
return WindowPosition{
|
||||
X: mainPos.X + mainWidth,
|
||||
Y: currentPos.Y, // 保持当前Y位置
|
||||
}
|
||||
case models.SnapEdgeLeft:
|
||||
return models.WindowPosition{
|
||||
case SnapEdgeLeft:
|
||||
return WindowPosition{
|
||||
X: mainPos.X - windowWidth,
|
||||
Y: currentPos.Y,
|
||||
}
|
||||
case models.SnapEdgeBottom:
|
||||
return models.WindowPosition{
|
||||
case SnapEdgeBottom:
|
||||
return WindowPosition{
|
||||
X: currentPos.X,
|
||||
Y: mainPos.Y + mainHeight,
|
||||
}
|
||||
case models.SnapEdgeTop:
|
||||
return models.WindowPosition{
|
||||
case SnapEdgeTop:
|
||||
return WindowPosition{
|
||||
X: currentPos.X,
|
||||
Y: mainPos.Y - windowHeight,
|
||||
}
|
||||
case models.SnapEdgeTopRight:
|
||||
return models.WindowPosition{
|
||||
case SnapEdgeTopRight:
|
||||
return WindowPosition{
|
||||
X: mainPos.X + mainWidth,
|
||||
Y: mainPos.Y - windowHeight,
|
||||
}
|
||||
case models.SnapEdgeBottomRight:
|
||||
return models.WindowPosition{
|
||||
case SnapEdgeBottomRight:
|
||||
return WindowPosition{
|
||||
X: mainPos.X + mainWidth,
|
||||
Y: mainPos.Y + mainHeight,
|
||||
}
|
||||
case models.SnapEdgeBottomLeft:
|
||||
return models.WindowPosition{
|
||||
case SnapEdgeBottomLeft:
|
||||
return WindowPosition{
|
||||
X: mainPos.X - windowWidth,
|
||||
Y: mainPos.Y + mainHeight,
|
||||
}
|
||||
case models.SnapEdgeTopLeft:
|
||||
return models.WindowPosition{
|
||||
case SnapEdgeTopLeft:
|
||||
return WindowPosition{
|
||||
X: mainPos.X - windowWidth,
|
||||
Y: mainPos.Y - windowHeight,
|
||||
}
|
||||
@@ -636,7 +671,7 @@ func (wss *WindowSnapService) Cleanup() {
|
||||
}
|
||||
|
||||
// 清空管理的窗口
|
||||
wss.managedWindows = make(map[int64]*models.WindowInfo)
|
||||
wss.managedWindows = make(map[int64]*WindowInfo)
|
||||
wss.windowRefs = make(map[int64]*application.WebviewWindow)
|
||||
wss.windowSizeCache = make(map[int64][2]int)
|
||||
wss.isUpdatingPosition = make(map[int64]bool)
|
||||
|
||||
Reference in New Issue
Block a user