🚧 Refactor basic services

This commit is contained in:
2025-12-14 02:19:50 +08:00
parent d16905c0a3
commit cc4c2189dc
126 changed files with 18164 additions and 4247 deletions

View File

@@ -20,8 +20,6 @@ import (
"github.com/wailsapp/wails/v3/pkg/services/log"
"voidraft/internal/models"
_ "modernc.org/sqlite"
)
const (
@@ -223,20 +221,20 @@ func (s *BackupService) getAuthMethod(config *models.GitBackupConfig) (transport
// serializeDatabase 序列化数据库到文件
func (s *BackupService) serializeDatabase(repoPath string) error {
if s.dbService == nil || s.dbService.db == nil {
return errors.New("database service not available")
}
binFilePath := filepath.Join(repoPath, dbSerializeFile)
// 使用 VACUUM INTO 创建数据库副本,不影响现有连接
s.dbService.mu.RLock()
_, err := s.dbService.db.Exec(fmt.Sprintf("VACUUM INTO '%s'", binFilePath))
s.dbService.mu.RUnlock()
if err != nil {
return fmt.Errorf("creating database backup: %w", err)
}
//if s.dbService == nil || s.dbService.Engine == nil {
// return errors.New("database service not available")
//}
//
//binFilePath := filepath.Join(repoPath, dbSerializeFile)
//
//// 使用 VACUUM INTO 创建数据库副本,不影响现有连接
//s.dbService.mu.RLock()
//_, err := s.dbService.db.Exec(fmt.Sprintf("VACUUM INTO '%s'", binFilePath))
//s.dbService.mu.RUnlock()
//
//if err != nil {
// return fmt.Errorf("creating database backup: %w", err)
//}
return nil
}

View File

@@ -12,32 +12,23 @@ import (
"sync/atomic"
"time"
"github.com/wailsapp/wails/v3/pkg/application"
"github.com/wailsapp/wails/v3/pkg/services/log"
)
// MigrationStatus 迁移状态
type MigrationStatus string
const (
MigrationStatusMigrating MigrationStatus = "migrating"
MigrationStatusCompleted MigrationStatus = "completed"
MigrationStatusFailed MigrationStatus = "failed"
)
// MigrationProgress 迁移进度信息
type MigrationProgress struct {
Status MigrationStatus `json:"status"`
Progress float64 `json:"progress"`
Error string `json:"error,omitempty"`
Progress float64 `json:"progress"` // 0-100
Error string `json:"error,omitempty"`
}
// MigrationService 迁移服务
type MigrationService struct {
logger *log.LogService
dbService *DatabaseService
mu sync.RWMutex
progress atomic.Value // stores MigrationProgress
mu sync.Mutex
ctx context.Context
cancel context.CancelFunc
}
@@ -47,18 +38,11 @@ func NewMigrationService(dbService *DatabaseService, logger *log.LogService) *Mi
if logger == nil {
logger = log.New()
}
ms := &MigrationService{
logger: logger,
dbService: dbService,
}
// 初始化进度
ms.progress.Store(MigrationProgress{
Status: MigrationStatusCompleted,
Progress: 0,
})
ms.progress.Store(MigrationProgress{})
return ms
}
@@ -67,9 +51,15 @@ func (ms *MigrationService) GetProgress() MigrationProgress {
return ms.progress.Load().(MigrationProgress)
}
// updateProgress 更新进度
func (ms *MigrationService) updateProgress(progress MigrationProgress) {
ms.progress.Store(progress)
// setProgress 设置进度
func (ms *MigrationService) setProgress(progress float64) {
ms.progress.Store(MigrationProgress{Progress: progress})
}
// fail 标记失败并返回错误
func (ms *MigrationService) fail(err error) error {
ms.progress.Store(MigrationProgress{Error: err.Error()})
return err
}
// MigrateDirectory 迁移目录
@@ -77,118 +67,84 @@ func (ms *MigrationService) MigrateDirectory(srcPath, dstPath string) error {
// 创建可取消的上下文
ctx, cancel := context.WithCancel(context.Background())
ms.mu.Lock()
ms.ctx = ctx
ms.cancel = cancel
ms.ctx, ms.cancel = ctx, cancel
ms.mu.Unlock()
defer func() {
ms.mu.Lock()
ms.cancel = nil
ms.ctx = nil
ms.ctx, ms.cancel = nil, nil
ms.mu.Unlock()
}()
// 初始化进度
ms.updateProgress(MigrationProgress{
Status: MigrationStatusMigrating,
Progress: 0,
})
ms.setProgress(0)
// 预检查
if err := ms.preCheck(srcPath, dstPath); err != nil {
if err == errNoMigrationNeeded {
ms.updateProgress(MigrationProgress{
Status: MigrationStatusCompleted,
Progress: 100,
})
return nil
}
return ms.failWithError(err)
needMigrate, err := ms.preCheck(srcPath, dstPath)
if err != nil {
return ms.fail(err)
}
if !needMigrate {
ms.setProgress(100)
return nil
}
// 迁移前断开数据库连接
ms.updateProgress(MigrationProgress{
Status: MigrationStatusMigrating,
Progress: 10,
})
ms.setProgress(10)
if ms.dbService != nil {
if err := ms.dbService.ServiceShutdown(); err != nil {
ms.logger.Error("Failed to close database connection", "error", err)
}
}
// 确保失败时恢复数据库连接
defer func() {
if ms.dbService != nil {
if err := ms.dbService.ServiceStartup(ctx, application.ServiceOptions{}); err != nil {
ms.logger.Error("Failed to reconnect database", "error", err)
}
}
}()
// 执行原子迁移
if err := ms.atomicMove(ctx, srcPath, dstPath); err != nil {
return ms.failWithError(err)
return ms.fail(err)
}
// 迁移完成后重新连接数据库
ms.updateProgress(MigrationProgress{
Status: MigrationStatusMigrating,
Progress: 95,
})
if ms.dbService != nil {
if err := ms.dbService.initDatabase(); err != nil {
return ms.failWithError(fmt.Errorf("failed to reconnect database: %v", err))
}
}
// 迁移完成
ms.updateProgress(MigrationProgress{
Status: MigrationStatusCompleted,
Progress: 100,
})
ms.setProgress(100)
return nil
}
var errNoMigrationNeeded = fmt.Errorf("no migration needed")
// preCheck 预检查
func (ms *MigrationService) preCheck(srcPath, dstPath string) error {
// 检查源目录是否存在
// preCheck 预检查,返回是否需要迁移
func (ms *MigrationService) preCheck(srcPath, dstPath string) (bool, error) {
// 源目录不存在,无需迁移
if _, err := os.Stat(srcPath); os.IsNotExist(err) {
return errNoMigrationNeeded
return false, nil
}
// 如果路径相同,不需要迁移
// 路径相同,无需迁移
srcAbs, _ := filepath.Abs(srcPath)
dstAbs, _ := filepath.Abs(dstPath)
if srcAbs == dstAbs {
return errNoMigrationNeeded
return false, nil
}
// 检查目标路径是否是源路径的子目录
if ms.isSubDirectory(srcAbs, dstAbs) {
return fmt.Errorf("target path cannot be a subdirectory of source path")
// 目标不能是源的子目录
if isSubDir(srcAbs, dstAbs) {
return false, fmt.Errorf("target path cannot be a subdirectory of source path")
}
return nil
}
// failWithError 失败并记录错误
func (ms *MigrationService) failWithError(err error) error {
ms.updateProgress(MigrationProgress{
Status: MigrationStatusFailed,
Error: err.Error(),
})
return err
return true, nil
}
// atomicMove 原子移动目录
func (ms *MigrationService) atomicMove(ctx context.Context, srcPath, dstPath string) error {
// 检查是否取消
select {
case <-ctx.Done():
return ctx.Err()
default:
if err := ctx.Err(); err != nil {
return err
}
// 确保目标目录的父目录存在
// 确保目标父目录存在
if err := os.MkdirAll(filepath.Dir(dstPath), 0755); err != nil {
return fmt.Errorf("failed to create target parent directory: %v", err)
return fmt.Errorf("failed to create target parent directory: %w", err)
}
// 检查目标路径
@@ -196,132 +152,100 @@ func (ms *MigrationService) atomicMove(ctx context.Context, srcPath, dstPath str
return err
}
// 尝试直接重命名
ms.updateProgress(MigrationProgress{
Status: MigrationStatusMigrating,
Progress: 20,
})
ms.setProgress(20)
// 尝试直接重命名(同一文件系统时最快)
if err := os.Rename(srcPath, dstPath); err == nil {
// 重命名成功更新进度到90%
ms.updateProgress(MigrationProgress{
Status: MigrationStatusMigrating,
Progress: 90,
})
ms.setProgress(90)
ms.logger.Info("Directory migration completed using direct rename", "src", srcPath, "dst", dstPath)
return nil
}
// 重命名失败,使用压缩迁移
// 重命名失败(跨文件系统),使用压缩迁移
ms.logger.Info("Direct rename failed, using compress migration", "src", srcPath, "dst", dstPath)
ms.updateProgress(MigrationProgress{
Status: MigrationStatusMigrating,
Progress: 30,
})
ms.setProgress(30)
return ms.compressMove(ctx, srcPath, dstPath)
}
// checkTargetPath 检查目标路径
// checkTargetPath 检查目标路径是否可用
func (ms *MigrationService) checkTargetPath(dstPath string) error {
stat, err := os.Stat(dstPath)
if os.IsNotExist(err) {
return nil
}
if err != nil {
return fmt.Errorf("failed to check target path: %v", err)
return fmt.Errorf("failed to check target path: %w", err)
}
if !stat.IsDir() {
return fmt.Errorf("target path exists but is not a directory")
}
isEmpty, err := ms.isDirectoryEmpty(dstPath)
isEmpty, err := isDirEmpty(dstPath)
if err != nil {
return fmt.Errorf("failed to check target directory: %v", err)
return fmt.Errorf("failed to check target directory: %w", err)
}
if !isEmpty {
return fmt.Errorf("target directory is not empty")
}
return nil
}
// compressMove 压缩迁移
func (ms *MigrationService) compressMove(ctx context.Context, srcPath, dstPath string) error {
tempZipFile := filepath.Join(os.TempDir(),
fmt.Sprintf("voidraft_migration_%d.zip", time.Now().UnixNano()))
defer os.Remove(tempZipFile)
tempZip := filepath.Join(os.TempDir(), fmt.Sprintf("voidraft_migration_%d.zip", time.Now().UnixNano()))
defer os.Remove(tempZip)
// 压缩源目录
ms.updateProgress(MigrationProgress{
Status: MigrationStatusMigrating,
Progress: 40,
})
if err := ms.compressDirectory(ctx, srcPath, tempZipFile); err != nil {
return fmt.Errorf("failed to compress source directory: %v", err)
ms.setProgress(40)
if err := ms.compressDir(ctx, srcPath, tempZip); err != nil {
return fmt.Errorf("failed to compress source directory: %w", err)
}
// 解压到目标位置
ms.updateProgress(MigrationProgress{
Status: MigrationStatusMigrating,
Progress: 70,
})
if err := ms.extractToDirectory(ctx, tempZipFile, dstPath); err != nil {
return fmt.Errorf("failed to extract to target location: %v", err)
ms.setProgress(70)
if err := ms.extractZip(ctx, tempZip, dstPath); err != nil {
return fmt.Errorf("failed to extract to target location: %w", err)
}
// 检查是否取消
select {
case <-ctx.Done():
// 检查取消
if err := ctx.Err(); err != nil {
os.RemoveAll(dstPath)
return ctx.Err()
default:
return err
}
// 验证迁移是否成功
// 验证迁移结果
if err := ms.verifyMigration(dstPath); err != nil {
// 迁移验证失败,清理目标目录
os.RemoveAll(dstPath)
return fmt.Errorf("migration verification failed: %v", err)
return fmt.Errorf("migration verification failed: %w", err)
}
// 删除源目录
ms.updateProgress(MigrationProgress{
Status: MigrationStatusMigrating,
Progress: 90,
})
ms.setProgress(90)
os.RemoveAll(srcPath)
return nil
}
// compressDirectory 压缩目录到zip文件
func (ms *MigrationService) compressDirectory(ctx context.Context, srcDir, zipFile string) error {
zipWriter, err := os.Create(zipFile)
// compressDir 压缩目录到zip文件
func (ms *MigrationService) compressDir(ctx context.Context, srcDir, zipPath string) error {
zipFile, err := os.Create(zipPath)
if err != nil {
return err
}
defer zipWriter.Close()
defer zipFile.Close()
zw := zip.NewWriter(zipWriter)
zw := zip.NewWriter(zipFile)
defer zw.Close()
return filepath.Walk(srcDir, func(filePath string, info os.FileInfo, err error) error {
return filepath.Walk(srcDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// 检查是否取消
select {
case <-ctx.Done():
return ctx.Err()
default:
if err := ctx.Err(); err != nil {
return err
}
relPath, err := filepath.Rel(srcDir, filePath)
relPath, err := filepath.Rel(srcDir, path)
if err != nil || relPath == "." {
return err
}
@@ -345,27 +269,21 @@ func (ms *MigrationService) compressDirectory(ctx context.Context, srcDir, zipFi
}
if !info.IsDir() {
return ms.copyFileToZip(filePath, writer)
file, err := os.Open(path)
if err != nil {
return err
}
defer file.Close()
_, err = io.Copy(writer, file)
return err
}
return nil
})
}
// copyFileToZip 复制文件到zip
func (ms *MigrationService) copyFileToZip(filePath string, writer io.Writer) error {
file, err := os.Open(filePath)
if err != nil {
return err
}
defer file.Close()
_, err = io.Copy(writer, file)
return err
}
// extractToDirectory 从zip文件解压到目录
func (ms *MigrationService) extractToDirectory(ctx context.Context, zipFile, dstDir string) error {
reader, err := zip.OpenReader(zipFile)
// extractZip 解压zip文件到目录
func (ms *MigrationService) extractZip(ctx context.Context, zipPath, dstDir string) error {
reader, err := zip.OpenReader(zipPath)
if err != nil {
return err
}
@@ -376,27 +294,22 @@ func (ms *MigrationService) extractToDirectory(ctx context.Context, zipFile, dst
}
for _, file := range reader.File {
// 检查是否取消
select {
case <-ctx.Done():
return ctx.Err()
default:
if err := ctx.Err(); err != nil {
return err
}
if err := ms.extractSingleFile(file, dstDir); err != nil {
if err := extractFile(file, dstDir); err != nil {
return err
}
}
return nil
}
// extractSingleFile 解压单个文件
func (ms *MigrationService) extractSingleFile(file *zip.File, dstDir string) error {
// extractFile 解压单个文件
func extractFile(file *zip.File, dstDir string) error {
dstPath := filepath.Join(dstDir, file.Name)
// 安全检查防止zip slip攻击
if !strings.HasPrefix(dstPath, filepath.Clean(dstDir)+string(os.PathSeparator)) {
if !strings.HasPrefix(filepath.Clean(dstPath), filepath.Clean(dstDir)+string(os.PathSeparator)) {
return fmt.Errorf("invalid file path in archive: %s", file.Name)
}
@@ -424,45 +337,23 @@ func (ms *MigrationService) extractSingleFile(file *zip.File, dstDir string) err
return err
}
// isDirectoryEmpty 检查目录是否为空
func (ms *MigrationService) isDirectoryEmpty(dirPath string) (bool, error) {
f, err := os.Open(dirPath)
if err != nil {
return false, err
}
defer f.Close()
_, err = f.Readdir(1)
return err == io.EOF, nil
}
// isSubDirectory 检查target是否是parent的子目录
func (ms *MigrationService) isSubDirectory(parent, target string) bool {
parent = filepath.Clean(parent) + string(filepath.Separator)
target = filepath.Clean(target) + string(filepath.Separator)
return len(target) > len(parent) && strings.HasPrefix(target, parent)
}
// verifyMigration 验证迁移是否成功
// verifyMigration 验证迁移结果
func (ms *MigrationService) verifyMigration(dstPath string) error {
// 检查目标目录是否存在
dstStat, err := os.Stat(dstPath)
stat, err := os.Stat(dstPath)
if err != nil {
return fmt.Errorf("target directory does not exist: %v", err)
return fmt.Errorf("target directory does not exist: %w", err)
}
if !dstStat.IsDir() {
if !stat.IsDir() {
return fmt.Errorf("target path is not a directory")
}
// 简单验证:检查目标目录是否非空
isEmpty, err := ms.isDirectoryEmpty(dstPath)
isEmpty, err := isDirEmpty(dstPath)
if err != nil {
return fmt.Errorf("failed to check target directory: %v", err)
return fmt.Errorf("failed to check target directory: %w", err)
}
if isEmpty {
return fmt.Errorf("target directory is empty after migration")
}
return nil
}
@@ -475,12 +366,30 @@ func (ms *MigrationService) CancelMigration() error {
ms.cancel()
return nil
}
return fmt.Errorf("no active migration to cancel")
}
// ServiceShutdown 服务关闭
func (ms *MigrationService) ServiceShutdown() error {
ms.CancelMigration()
_ = ms.CancelMigration()
return nil
}
// isDirEmpty 检查目录是否为空
func isDirEmpty(path string) (bool, error) {
f, err := os.Open(path)
if err != nil {
return false, err
}
defer f.Close()
_, err = f.Readdir(1)
return err == io.EOF, nil
}
// isSubDir 检查target是否是parent的子目录
func isSubDir(parent, target string) bool {
parent = filepath.Clean(parent) + string(filepath.Separator)
target = filepath.Clean(target) + string(filepath.Separator)
return len(target) > len(parent) && strings.HasPrefix(target, parent)
}

View File

@@ -6,376 +6,100 @@ import (
"fmt"
"os"
"path/filepath"
"reflect"
"sync"
"voidraft/internal/models"
"time"
"voidraft/internal/models/ent"
"voidraft/internal/models/ent/migrate"
_ "voidraft/internal/models/ent/runtime"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "github.com/mattn/go-sqlite3"
"github.com/wailsapp/wails/v3/pkg/application"
"github.com/wailsapp/wails/v3/pkg/services/log"
_ "modernc.org/sqlite"
)
const (
dbName = "voidraft.db"
// SQLite performance optimization settings
sqlOptimizationSettings = `
PRAGMA journal_mode = WAL;
PRAGMA synchronous = NORMAL;
PRAGMA cache_size = -64000;
PRAGMA temp_store = MEMORY;
PRAGMA foreign_keys = ON;`
// Documents table
sqlCreateDocumentsTable = `
CREATE TABLE IF NOT EXISTS documents (
id INTEGER PRIMARY KEY AUTOINCREMENT,
title TEXT NOT NULL,
content TEXT DEFAULT '∞∞∞text-a',
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
is_deleted INTEGER DEFAULT 0,
is_locked INTEGER DEFAULT 0
)`
// Extensions table
sqlCreateExtensionsTable = `
CREATE TABLE IF NOT EXISTS extensions (
id TEXT PRIMARY KEY,
enabled INTEGER NOT NULL DEFAULT 1,
is_default INTEGER NOT NULL DEFAULT 0,
config TEXT DEFAULT '{}',
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
)`
// Key bindings table
sqlCreateKeyBindingsTable = `
CREATE TABLE IF NOT EXISTS key_bindings (
id INTEGER PRIMARY KEY AUTOINCREMENT,
command TEXT NOT NULL,
extension TEXT NOT NULL,
key TEXT NOT NULL,
enabled INTEGER NOT NULL DEFAULT 1,
is_default INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
)`
// Themes table
sqlCreateThemesTable = `
CREATE TABLE IF NOT EXISTS themes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
type TEXT NOT NULL,
colors TEXT NOT NULL,
is_default INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
)`
dbName = "voidraft.db"
maxIdleConns = 10
maxOpenConns = 100
connMaxLife = time.Hour
)
// ColumnInfo 存储列的信息
type ColumnInfo struct {
SQLType string
DefaultValue string
}
// TableModel 表示表与模型之间的映射关系
type TableModel struct {
TableName string
Model interface{}
}
// DatabaseService provides shared database functionality
// DatabaseService 数据库服务
type DatabaseService struct {
configService *ConfigService
logger *log.LogService
Client *ent.Client
db *sql.DB
mu sync.RWMutex
ctx context.Context
tableModels []TableModel // 注册的表模型
// 配置观察者取消函数
cancelObserver CancelFunc
}
// NewDatabaseService creates a new database service
// NewDatabaseService 创建数据库服务
func NewDatabaseService(configService *ConfigService, logger *log.LogService) *DatabaseService {
if logger == nil {
logger = log.New()
}
ds := &DatabaseService{
return &DatabaseService{
configService: configService,
logger: logger,
}
// 注册所有模型
ds.registerAllModels()
return ds
}
// registerAllModels 注册所有数据模型,集中管理表-模型映射
func (ds *DatabaseService) registerAllModels() {
// 文档表
ds.RegisterModel("documents", &models.Document{})
// 扩展表
ds.RegisterModel("extensions", &models.Extension{})
// 快捷键表
ds.RegisterModel("key_bindings", &models.KeyBinding{})
// 主题表
ds.RegisterModel("themes", &models.Theme{})
}
// ServiceStartup initializes the service when the application starts
func (ds *DatabaseService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
ds.ctx = ctx
return ds.initDatabase()
}
// initDatabase initializes the SQLite database
func (ds *DatabaseService) initDatabase() error {
dbPath, err := ds.getDatabasePath()
// ServiceStartup 服务启动
func (s *DatabaseService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
dbPath, err := s.getBasePath()
if err != nil {
return fmt.Errorf("failed to get database path: %w", err)
return fmt.Errorf("get database path error: %w", err)
}
// 确保数据库目录存在
dbDir := filepath.Dir(dbPath)
if err := os.MkdirAll(dbDir, 0755); err != nil {
return fmt.Errorf("failed to create database directory: %w", err)
if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
return fmt.Errorf("create database directory error: %w", err)
}
// 打开数据库连接
ds.db, err = sql.Open("sqlite", dbPath)
// _fk=1 启用外键_journal_mode=WAL 启用 WAL 模式
dsn := fmt.Sprintf("file:%s?_fk=1&_journal_mode=WAL&_synchronous=NORMAL&_busy_timeout=5000", dbPath)
s.db, err = sql.Open("sqlite3", dsn)
if err != nil {
return fmt.Errorf("failed to open database: %w", err)
return fmt.Errorf("open database error: %w", err)
}
// 测试连接
if err := ds.db.Ping(); err != nil {
return fmt.Errorf("failed to ping database: %w", err)
}
// 连接池配置
s.db.SetMaxIdleConns(maxIdleConns)
s.db.SetMaxOpenConns(maxOpenConns)
s.db.SetConnMaxLifetime(connMaxLife)
// 应用性能优化设置
if _, err := ds.db.Exec(sqlOptimizationSettings); err != nil {
return fmt.Errorf("failed to apply optimization settings: %w", err)
}
// 创建 ent 客户端
drv := entsql.OpenDB(dialect.SQLite, s.db)
s.Client = ent.NewClient(ent.Driver(drv))
// 创建表和索引
if err := ds.createTables(); err != nil {
return fmt.Errorf("failed to create tables: %w", err)
}
if err := ds.createIndexes(); err != nil {
return fmt.Errorf("failed to create indexes: %w", err)
}
// 执行模型与表结构同步
if err := ds.syncAllModelTables(); err != nil {
return fmt.Errorf("failed to sync model tables: %w", err)
// 自动迁移
if err := s.Client.Schema.Create(ctx,
migrate.WithDropColumn(true),
migrate.WithDropIndex(true),
); err != nil {
return fmt.Errorf("schema migration error: %w", err)
}
return nil
}
// getDatabasePath gets the database file path
func (ds *DatabaseService) getDatabasePath() (string, error) {
config, err := ds.configService.GetConfig()
func (s *DatabaseService) getBasePath() (string, error) {
config, err := s.configService.GetConfig()
if err != nil {
return "", err
}
return filepath.Join(config.General.DataPath, dbName), nil
}
// createTables creates all database tables
func (ds *DatabaseService) createTables() error {
tables := []string{
sqlCreateDocumentsTable,
sqlCreateExtensionsTable,
sqlCreateKeyBindingsTable,
sqlCreateThemesTable,
}
for _, table := range tables {
if _, err := ds.db.Exec(table); err != nil {
// ServiceShutdown 服务关闭
func (s *DatabaseService) ServiceShutdown() error {
if s.Client != nil {
if err := s.Client.Close(); err != nil {
return err
}
}
return nil
}
// createIndexes creates database indexes
func (ds *DatabaseService) createIndexes() error {
indexes := []string{
// Documents indexes
`CREATE INDEX IF NOT EXISTS idx_documents_updated_at ON documents(updated_at DESC)`,
`CREATE INDEX IF NOT EXISTS idx_documents_title ON documents(title)`,
`CREATE INDEX IF NOT EXISTS idx_documents_is_deleted ON documents(is_deleted)`,
// Extensions indexes
`CREATE INDEX IF NOT EXISTS idx_extensions_enabled ON extensions(enabled)`,
// Key bindings indexes
`CREATE INDEX IF NOT EXISTS idx_key_bindings_command ON key_bindings(command)`,
`CREATE INDEX IF NOT EXISTS idx_key_bindings_extension ON key_bindings(extension)`,
`CREATE INDEX IF NOT EXISTS idx_key_bindings_enabled ON key_bindings(enabled)`,
// Themes indexes
`CREATE INDEX IF NOT EXISTS idx_themes_type ON themes(type)`,
`CREATE INDEX IF NOT EXISTS idx_themes_is_default ON themes(is_default)`,
}
for _, index := range indexes {
if _, err := ds.db.Exec(index); err != nil {
return err
}
}
return nil
}
// RegisterModel 注册模型与表的映射关系
func (ds *DatabaseService) RegisterModel(tableName string, model interface{}) {
ds.mu.Lock()
defer ds.mu.Unlock()
ds.tableModels = append(ds.tableModels, TableModel{
TableName: tableName,
Model: model,
})
}
// syncAllModelTables 同步所有注册的模型与表结构
func (ds *DatabaseService) syncAllModelTables() error {
for _, tm := range ds.tableModels {
if err := ds.syncModelTable(tm.TableName, tm.Model); err != nil {
return fmt.Errorf("failed to sync table %s: %w", tm.TableName, err)
}
}
return nil
}
// syncModelTable 同步模型与表结构
func (ds *DatabaseService) syncModelTable(tableName string, model interface{}) error {
// 获取表结构元数据
columns, err := ds.getTableColumns(tableName)
if err != nil {
return fmt.Errorf("failed to get table columns: %w", err)
}
// 使用反射从模型中提取字段信息
expectedColumns, err := ds.getModelColumns(model)
if err != nil {
return fmt.Errorf("failed to get model columns: %w", err)
}
// 检查缺失的列并添加
for colName, colInfo := range expectedColumns {
if _, exists := columns[colName]; !exists {
// 执行添加列的SQL
alterSQL := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s DEFAULT %s",
tableName, colName, colInfo.SQLType, colInfo.DefaultValue)
if _, err := ds.db.Exec(alterSQL); err != nil {
return fmt.Errorf("failed to add column %s: %w", colName, err)
}
}
}
return nil
}
// getTableColumns 获取表的列信息
func (ds *DatabaseService) getTableColumns(table string) (map[string]string, error) {
query := fmt.Sprintf("PRAGMA table_info(%s)", table)
rows, err := ds.db.Query(query)
if err != nil {
return nil, err
}
defer rows.Close()
columns := make(map[string]string)
for rows.Next() {
var cid int
var name, typeName string
var notNull, pk int
var dflt_value interface{}
if err := rows.Scan(&cid, &name, &typeName, &notNull, &dflt_value, &pk); err != nil {
return nil, err
}
columns[name] = typeName
}
if err := rows.Err(); err != nil {
return nil, err
}
return columns, nil
}
// getModelColumns 从模型结构体中提取数据库列信息
func (ds *DatabaseService) getModelColumns(model interface{}) (map[string]ColumnInfo, error) {
columns := make(map[string]ColumnInfo)
// 使用反射获取结构体的类型信息
t := reflect.TypeOf(model)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return nil, fmt.Errorf("model must be a struct or a pointer to struct")
}
// 遍历所有字段
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
// 只处理有db标签的字段
dbTag := field.Tag.Get("db")
if dbTag == "" {
// 如果没有db标签跳过该字段
continue
}
// 获取字段类型对应的SQL类型和默认值
sqlType, defaultVal := getSQLTypeAndDefault(field.Type)
columns[dbTag] = ColumnInfo{
SQLType: sqlType,
DefaultValue: defaultVal,
}
}
return columns, nil
}
// getSQLTypeAndDefault 根据Go类型获取对应的SQL类型和默认值
func getSQLTypeAndDefault(t reflect.Type) (string, string) {
switch t.Kind() {
case reflect.Bool:
return "INTEGER", "0"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return "INTEGER", "0"
case reflect.Float32, reflect.Float64:
return "REAL", "0.0"
case reflect.String:
return "TEXT", "''"
default:
return "TEXT", "NULL"
}
}
// ServiceShutdown shuts down the service when the application closes
func (ds *DatabaseService) ServiceShutdown() error {
// 取消配置观察者
if ds.cancelObserver != nil {
ds.cancelObserver()
}
if ds.db != nil {
return ds.db.Close()
if s.db != nil {
return s.db.Close()
}
return nil
}

View File

@@ -2,428 +2,141 @@ package services
import (
"context"
"database/sql"
"errors"
"fmt"
"sync"
"time"
"voidraft/internal/models"
"voidraft/internal/models/ent/document"
"github.com/wailsapp/wails/v3/pkg/application"
"github.com/wailsapp/wails/v3/pkg/services/log"
"voidraft/internal/models/ent"
)
// SQL constants for document operations
const (
const defaultDocumentTitle = "default"
const defaultDocumentContent = "\n∞∞∞text-a\n"
// Document operations
sqlGetDocumentByID = `
SELECT id, title, content, created_at, updated_at, is_deleted, is_locked
FROM documents
WHERE id = ?`
sqlInsertDocument = `
INSERT INTO documents (title, content, created_at, updated_at, is_deleted, is_locked)
VALUES (?, ?, ?, ?, 0, 0)`
sqlUpdateDocumentContent = `
UPDATE documents
SET content = ?, updated_at = ?
WHERE id = ? AND is_deleted = 0`
sqlUpdateDocumentTitle = `
UPDATE documents
SET title = ?, updated_at = ?
WHERE id = ? AND is_deleted = 0`
sqlMarkDocumentAsDeleted = `
UPDATE documents
SET is_deleted = 1, updated_at = ?
WHERE id = ? AND is_locked = 0`
sqlRestoreDocument = `
UPDATE documents
SET is_deleted = 0, updated_at = ?
WHERE id = ?`
sqlListAllDocumentsMeta = `
SELECT id, title, created_at, updated_at, is_locked
FROM documents
WHERE is_deleted = 0
ORDER BY updated_at DESC`
sqlListDeletedDocumentsMeta = `
SELECT id, title, created_at, updated_at, is_locked
FROM documents
WHERE is_deleted = 1
ORDER BY updated_at DESC`
sqlGetFirstDocumentID = `
SELECT id FROM documents WHERE is_deleted = 0 ORDER BY id LIMIT 1`
sqlCountDocuments = `SELECT COUNT(*) FROM documents WHERE is_deleted = 0`
sqlSetDocumentLocked = `
UPDATE documents
SET is_locked = 1, updated_at = ?
WHERE id = ?`
sqlSetDocumentUnlocked = `
UPDATE documents
SET is_locked = 0, updated_at = ?
WHERE id = ?`
sqlDefaultDocumentID = 1 // 默认文档的ID
)
// DocumentService provides document management functionality
// DocumentService 文档服务
type DocumentService struct {
databaseService *DatabaseService
logger *log.LogService
mu sync.RWMutex
ctx context.Context
db *DatabaseService
logger *log.LogService
}
// NewDocumentService creates a new document service
func NewDocumentService(databaseService *DatabaseService, logger *log.LogService) *DocumentService {
// NewDocumentService 创建文档服务
func NewDocumentService(db *DatabaseService, logger *log.LogService) *DocumentService {
if logger == nil {
logger = log.New()
}
ds := &DocumentService{
databaseService: databaseService,
logger: logger,
}
return ds
return &DocumentService{db: db, logger: logger}
}
// ServiceStartup initializes the service when the application starts
func (ds *DocumentService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
ds.ctx = ctx
// 确保默认文档存在
if err := ds.ensureDefaultDocument(); err != nil {
return fmt.Errorf("failed to ensure default document: %w", err)
}
return nil
}
// ensureDefaultDocument ensures a default document exists
func (ds *DocumentService) ensureDefaultDocument() error {
if ds.databaseService == nil || ds.databaseService.db == nil {
return errors.New("database service not available")
}
// Check if any document exists
var count int64
err := ds.databaseService.db.QueryRow(sqlCountDocuments).Scan(&count)
// ServiceStartup 服务启动
func (s *DocumentService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
exists, err := s.db.Client.Document.Query().Exist(ctx)
if err != nil {
return fmt.Errorf("failed to query document count: %w", err)
return fmt.Errorf("check document exists error: %w", err)
}
// If no documents exist, create default document
if count == 0 {
defaultDoc := models.NewDefaultDocument()
_, err := ds.CreateDocument(defaultDoc.Title)
return err
if !exists {
_, err = s.CreateDocument(ctx, defaultDocumentTitle)
}
return nil
return err
}
// GetDocumentByID gets a document by ID
func (ds *DocumentService) GetDocumentByID(id int64) (*models.Document, error) {
ds.mu.RLock()
defer ds.mu.RUnlock()
if ds.databaseService == nil || ds.databaseService.db == nil {
return nil, errors.New("database service not available")
}
doc := &models.Document{}
var isDeleted, isLocked int
err := ds.databaseService.db.QueryRow(sqlGetDocumentByID, id).Scan(
&doc.ID,
&doc.Title,
&doc.Content,
&doc.CreatedAt,
&doc.UpdatedAt,
&isDeleted,
&isLocked,
)
// GetDocumentByID 根据ID获取文档
func (s *DocumentService) GetDocumentByID(ctx context.Context, id int) (*ent.Document, error) {
doc, err := s.db.Client.Document.Get(ctx, id)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
if ent.IsNotFound(err) {
return nil, nil
}
return nil, fmt.Errorf("failed to get document by ID: %w", err)
return nil, fmt.Errorf("get document by id error: %w", err)
}
// 转换布尔字段
doc.IsDeleted = isDeleted == 1
doc.IsLocked = isLocked == 1
return doc, nil
}
// CreateDocument creates a new document and returns the created document with ID
func (ds *DocumentService) CreateDocument(title string) (*models.Document, error) {
ds.mu.Lock()
defer ds.mu.Unlock()
if ds.databaseService == nil || ds.databaseService.db == nil {
return nil, errors.New("database service not available")
}
doc := models.NewDocument(title, "\n∞∞∞text-a\n")
// 执行插入操作
result, err := ds.databaseService.db.Exec(sqlInsertDocument,
doc.Title, doc.Content, doc.CreatedAt, doc.UpdatedAt)
// CreateDocument 创建文档
func (s *DocumentService) CreateDocument(ctx context.Context, title string) (*ent.Document, error) {
doc, err := s.db.Client.Document.Create().
SetTitle(title).
SetContent(defaultDocumentContent).
Save(ctx)
if err != nil {
return nil, fmt.Errorf("failed to create document: %w", err)
return nil, fmt.Errorf("create document error: %w", err)
}
// 获取自增ID
lastID, err := result.LastInsertId()
if err != nil {
return nil, fmt.Errorf("failed to get last insert ID: %w", err)
}
// 返回带ID的文档
doc.ID = lastID
return doc, nil
}
// LockDocument 锁定文档,防止删除
func (ds *DocumentService) LockDocument(id int64) error {
if ds.databaseService == nil || ds.databaseService.db == nil {
return errors.New("database service not available")
}
// UpdateDocumentContent 更新文档内容
func (s *DocumentService) UpdateDocumentContent(ctx context.Context, id int, content string) error {
return s.db.Client.Document.UpdateOneID(id).
SetContent(content).
Exec(ctx)
}
// 先检查文档是否存在且未删除
doc, err := ds.GetDocumentByID(id)
// UpdateDocumentTitle 更新文档标题
func (s *DocumentService) UpdateDocumentTitle(ctx context.Context, id int, title string) error {
return s.db.Client.Document.UpdateOneID(id).
SetTitle(title).
Exec(ctx)
}
// LockDocument 锁定文档
func (s *DocumentService) LockDocument(ctx context.Context, id int) error {
doc, err := s.GetDocumentByID(ctx, id)
if err != nil {
return fmt.Errorf("failed to get document: %w", err)
return err
}
if doc == nil {
return fmt.Errorf("document not found: %d", id)
}
if doc.IsDeleted {
return fmt.Errorf("cannot lock deleted document: %d", id)
}
// 如果已经锁定,无需操作
if doc.IsLocked {
if doc.Locked {
return nil
}
// 现在加锁执行锁定操作
ds.mu.Lock()
defer ds.mu.Unlock()
_, err = ds.databaseService.db.Exec(sqlSetDocumentLocked, time.Now().Format("2006-01-02 15:04:05"), id)
if err != nil {
return fmt.Errorf("failed to lock document: %w", err)
}
return nil
return s.db.Client.Document.UpdateOneID(id).
SetLocked(true).
Exec(ctx)
}
// UnlockDocument 解锁文档
func (ds *DocumentService) UnlockDocument(id int64) error {
if ds.databaseService == nil || ds.databaseService.db == nil {
return errors.New("database service not available")
}
// 先检查文档是否存在
doc, err := ds.GetDocumentByID(id)
func (s *DocumentService) UnlockDocument(ctx context.Context, id int) error {
doc, err := s.GetDocumentByID(ctx, id)
if err != nil {
return fmt.Errorf("failed to get document: %w", err)
return err
}
if doc == nil {
return fmt.Errorf("document not found: %d", id)
}
// 如果未锁定,无需操作
if !doc.IsLocked {
if !doc.Locked {
return nil
}
// 现在加锁执行解锁操作
ds.mu.Lock()
defer ds.mu.Unlock()
_, err = ds.databaseService.db.Exec(sqlSetDocumentUnlocked, time.Now().Format("2006-01-02 15:04:05"), id)
if err != nil {
return fmt.Errorf("failed to unlock document: %w", err)
}
return nil
return s.db.Client.Document.UpdateOneID(id).
SetLocked(false).
Exec(ctx)
}
// UpdateDocumentContent updates the content of a document
func (ds *DocumentService) UpdateDocumentContent(id int64, content string) error {
ds.mu.Lock()
defer ds.mu.Unlock()
if ds.databaseService == nil || ds.databaseService.db == nil {
return errors.New("database service not available")
}
_, err := ds.databaseService.db.Exec(sqlUpdateDocumentContent, content, time.Now().Format("2006-01-02 15:04:05"), id)
// DeleteDocument 删除文档
func (s *DocumentService) DeleteDocument(ctx context.Context, id int) error {
doc, err := s.GetDocumentByID(ctx, id)
if err != nil {
return fmt.Errorf("failed to update document content: %w", err)
}
return nil
}
// UpdateDocumentTitle updates the title of a document
func (ds *DocumentService) UpdateDocumentTitle(id int64, title string) error {
ds.mu.Lock()
defer ds.mu.Unlock()
if ds.databaseService == nil || ds.databaseService.db == nil {
return errors.New("database service not available")
}
_, err := ds.databaseService.db.Exec(sqlUpdateDocumentTitle, title, time.Now().Format("2006-01-02 15:04:05"), id)
if err != nil {
return fmt.Errorf("failed to update document title: %w", err)
}
return nil
}
// DeleteDocument marks a document as deleted (default document with ID=1 cannot be deleted)
func (ds *DocumentService) DeleteDocument(id int64) error {
if ds.databaseService == nil || ds.databaseService.db == nil {
ds.logger.Error("database service not available")
return errors.New("database service not available")
}
// 不允许删除默认文档
if id == sqlDefaultDocumentID {
return fmt.Errorf("cannot delete the default document")
}
// 先检查文档是否存在和锁定状态(不加锁避免死锁)
doc, err := ds.GetDocumentByID(id)
if err != nil {
return fmt.Errorf("failed to get document: %w", err)
return err
}
if doc == nil {
return fmt.Errorf("document not found: %d", id)
}
if doc.IsLocked {
if doc.Locked {
return fmt.Errorf("cannot delete locked document: %d", id)
}
// 现在加锁执行删除操作
ds.mu.Lock()
defer ds.mu.Unlock()
_, err = ds.databaseService.db.Exec(sqlMarkDocumentAsDeleted, time.Now().Format("2006-01-02 15:04:05"), id)
count, err := s.db.Client.Document.Query().Count(ctx)
if err != nil {
return fmt.Errorf("failed to mark document as deleted: %w", err)
return err
}
return nil
if count <= 1 {
return errors.New("cannot delete the last document")
}
return s.db.Client.Document.DeleteOneID(id).Exec(ctx)
}
// RestoreDocument restores a deleted document
func (ds *DocumentService) RestoreDocument(id int64) error {
ds.mu.Lock()
defer ds.mu.Unlock()
if ds.databaseService == nil || ds.databaseService.db == nil {
return errors.New("database service not available")
}
_, err := ds.databaseService.db.Exec(sqlRestoreDocument, time.Now().Format("2006-01-02 15:04:05"), id)
if err != nil {
return fmt.Errorf("failed to restore document: %w", err)
}
return nil
}
// ListAllDocumentsMeta lists all active (non-deleted) document metadata
func (ds *DocumentService) ListAllDocumentsMeta() ([]*models.Document, error) {
ds.mu.RLock()
defer ds.mu.RUnlock()
if ds.databaseService == nil || ds.databaseService.db == nil {
return nil, errors.New("database service not available")
}
rows, err := ds.databaseService.db.Query(sqlListAllDocumentsMeta)
if err != nil {
return nil, fmt.Errorf("failed to list document meta: %w", err)
}
defer rows.Close()
var documents []*models.Document
for rows.Next() {
doc := &models.Document{IsDeleted: false}
var isLocked int
err := rows.Scan(
&doc.ID,
&doc.Title,
&doc.CreatedAt,
&doc.UpdatedAt,
&isLocked,
)
if err != nil {
return nil, fmt.Errorf("failed to scan document row: %w", err)
}
doc.IsLocked = isLocked == 1
documents = append(documents, doc)
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating document rows: %w", err)
}
return documents, nil
}
// ListDeletedDocumentsMeta lists all deleted document metadata
func (ds *DocumentService) ListDeletedDocumentsMeta() ([]*models.Document, error) {
ds.mu.RLock()
defer ds.mu.RUnlock()
if ds.databaseService == nil || ds.databaseService.db == nil {
return nil, errors.New("database service not available")
}
rows, err := ds.databaseService.db.Query(sqlListDeletedDocumentsMeta)
if err != nil {
return nil, fmt.Errorf("failed to list deleted document meta: %w", err)
}
defer rows.Close()
var documents []*models.Document
for rows.Next() {
doc := &models.Document{IsDeleted: true}
var isLocked int
err := rows.Scan(
&doc.ID,
&doc.Title,
&doc.CreatedAt,
&doc.UpdatedAt,
&isLocked,
)
if err != nil {
return nil, fmt.Errorf("failed to scan document row: %w", err)
}
doc.IsLocked = isLocked == 1
documents = append(documents, doc)
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating deleted document rows: %w", err)
}
return documents, nil
// ListAllDocumentsMeta 列出所有文档
func (s *DocumentService) ListAllDocumentsMeta(ctx context.Context) ([]*ent.Document, error) {
return s.db.Client.Document.Query().Select(document.FieldID, document.FieldTitle, document.FieldUpdatedAt, document.FieldLocked, document.FieldCreatedAt).
Order(ent.Desc(document.FieldUpdatedAt)).
All(ctx)
}

View File

@@ -2,380 +2,165 @@ package services
import (
"context"
"encoding/json"
"errors"
"fmt"
"sync"
"time"
"voidraft/internal/models"
"voidraft/internal/models/ent"
"voidraft/internal/models/ent/extension"
"voidraft/internal/models/schema/mixin"
"github.com/wailsapp/wails/v3/pkg/application"
"github.com/wailsapp/wails/v3/pkg/services/log"
)
// SQL constants for extension operations
const (
// Extension operations
sqlGetAllExtensions = `
SELECT id, enabled, is_default, config
FROM extensions
ORDER BY id`
sqlGetExtensionByID = `
SELECT id, enabled, is_default, config
FROM extensions
WHERE id = ?`
sqlInsertExtension = `
INSERT INTO extensions (id, enabled, is_default, config, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?)`
sqlUpdateExtension = `
UPDATE extensions
SET enabled = ?, config = ?, updated_at = ?
WHERE id = ?`
sqlDeleteAllExtensions = `DELETE FROM extensions`
)
// ExtensionService 扩展管理服务
// ExtensionService 扩展服务
type ExtensionService struct {
databaseService *DatabaseService
logger *log.LogService
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
initOnce sync.Once
db *DatabaseService
logger *log.LogService
}
// ExtensionError 扩展错误
type ExtensionError struct {
Operation string
Extension string
Err error
}
func (e *ExtensionError) Error() string {
if e.Extension != "" {
return fmt.Sprintf("extension %s for %s: %v", e.Operation, e.Extension, e.Err)
}
return fmt.Sprintf("extension %s: %v", e.Operation, e.Err)
}
func (e *ExtensionError) Unwrap() error {
return e.Err
}
func (e *ExtensionError) Is(target error) bool {
var extensionError *ExtensionError
return errors.As(target, &extensionError)
}
// NewExtensionService 创建扩展服务实例
func NewExtensionService(databaseService *DatabaseService, logger *log.LogService) *ExtensionService {
// NewExtensionService 创建扩展服务
func NewExtensionService(db *DatabaseService, logger *log.LogService) *ExtensionService {
if logger == nil {
logger = log.New()
}
ctx, cancel := context.WithCancel(context.Background())
service := &ExtensionService{
databaseService: databaseService,
logger: logger,
ctx: ctx,
cancel: cancel,
}
return service
return &ExtensionService{db: db, logger: logger}
}
// initialize 初始化配置
func (es *ExtensionService) initialize() {
es.initOnce.Do(func() {
if err := es.initDatabase(); err != nil {
es.logger.Error("failed to initialize extension database", "error", err)
}
})
// ServiceStartup 服务启动
func (s *ExtensionService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
return s.SyncExtensions(ctx)
}
// initDatabase 初始化数据库数据
func (es *ExtensionService) initDatabase() error {
es.mu.Lock()
defer es.mu.Unlock()
if es.databaseService == nil || es.databaseService.db == nil {
return &ExtensionError{"check_db", "", errors.New("database service not available")}
// SyncExtensions 同步扩展配置
func (s *ExtensionService) SyncExtensions(ctx context.Context) error {
defaults := models.NewDefaultExtensions()
definedKeys := make(map[models.ExtensionKey]models.Extension)
for _, ext := range defaults {
definedKeys[ext.Key] = ext
}
// 检查是否已有扩展数据
var count int64
err := es.databaseService.db.QueryRow("SELECT COUNT(*) FROM extensions").Scan(&count)
// 获取数据库中已有扩展
existing, err := s.db.Client.Extension.Query().All(ctx)
if err != nil {
return &ExtensionError{"check_extensions_count", "", err}
return fmt.Errorf("find extensions error: %w", err)
}
// 如果没有数据,插入默认配置
if count == 0 {
if err := es.insertDefaultExtensions(); err != nil {
es.logger.Error("Failed to insert default extensions", "error", err)
return err
existingKeys := make(map[string]bool)
for _, ext := range existing {
existingKeys[ext.Key] = true
}
// 批量添加缺失的扩展
var builders []*ent.ExtensionCreate
for key, ext := range definedKeys {
if !existingKeys[string(key)] {
builders = append(builders, s.db.Client.Extension.Create().
SetKey(string(ext.Key)).
SetEnabled(ext.Enabled).
SetConfig(ext.Config))
}
} else {
// 检查并补充缺失的扩展
if err := es.syncExtensions(); err != nil {
es.logger.Error("Failed to ensure all extensions exist", "error", err)
return err
}
if len(builders) > 0 {
if _, err := s.db.Client.Extension.CreateBulk(builders...).Save(ctx); err != nil {
return fmt.Errorf("bulk insert extensions error: %w", err)
}
}
// 批量删除废弃的扩展
var deleteIDs []int
for _, ext := range existing {
if _, ok := definedKeys[models.ExtensionKey(ext.Key)]; !ok {
deleteIDs = append(deleteIDs, ext.ID)
}
}
if len(deleteIDs) > 0 {
if _, err := s.db.Client.Extension.Delete().
Where(extension.IDIn(deleteIDs...)).
Exec(mixin.SkipSoftDelete(ctx)); err != nil {
return fmt.Errorf("bulk delete extensions error: %w", err)
}
}
return nil
}
// insertDefaultExtensions 插入默认扩展配置
func (es *ExtensionService) insertDefaultExtensions() error {
defaultSettings := models.NewDefaultExtensionSettings()
now := time.Now().Format("2006-01-02 15:04:05")
for _, ext := range defaultSettings.Extensions {
configJSON, err := json.Marshal(ext.Config)
if err != nil {
return &ExtensionError{"marshal_config", string(ext.ID), err}
}
_, err = es.databaseService.db.Exec(sqlInsertExtension,
string(ext.ID),
ext.Enabled,
ext.IsDefault,
string(configJSON),
now,
now,
)
if err != nil {
return &ExtensionError{"insert_extension", string(ext.ID), err}
}
}
return nil
// GetAllExtensions 获取所有扩展
func (s *ExtensionService) GetAllExtensions(ctx context.Context) ([]*ent.Extension, error) {
return s.db.Client.Extension.Query().All(ctx)
}
// syncExtensions 确保数据库中的扩展与代码定义同步
func (es *ExtensionService) syncExtensions() error {
defaultSettings := models.NewDefaultExtensionSettings()
now := time.Now().Format("2006-01-02 15:04:05")
// 构建代码中定义的扩展ID集合
definedExtensions := make(map[string]bool)
for _, ext := range defaultSettings.Extensions {
definedExtensions[string(ext.ID)] = true
}
// 1. 添加缺失的扩展
for _, ext := range defaultSettings.Extensions {
var exists int
err := es.databaseService.db.QueryRow("SELECT COUNT(*) FROM extensions WHERE id = ?", string(ext.ID)).Scan(&exists)
if err != nil {
return &ExtensionError{"check_extension_exists", string(ext.ID), err}
}
if exists == 0 {
configJSON, err := json.Marshal(ext.Config)
if err != nil {
return &ExtensionError{"marshal_config", string(ext.ID), err}
}
_, err = es.databaseService.db.Exec(sqlInsertExtension,
string(ext.ID),
ext.Enabled,
ext.IsDefault,
string(configJSON),
now,
now,
)
if err != nil {
return &ExtensionError{"insert_missing_extension", string(ext.ID), err}
}
es.logger.Info("Added missing extension to database", "id", ext.ID)
}
}
// 2. 删除数据库中已不存在于代码定义的扩展
rows, err := es.databaseService.db.Query("SELECT id FROM extensions")
// GetExtensionByKey 根据Key获取扩展
func (s *ExtensionService) GetExtensionByKey(ctx context.Context, key string) (*ent.Extension, error) {
ext, err := s.db.Client.Extension.Query().
Where(extension.Key(key)).
Only(ctx)
if err != nil {
return &ExtensionError{"query_all_extension_ids", "", err}
}
defer rows.Close()
var toDelete []string
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return &ExtensionError{"scan_extension_id", "", err}
}
if !definedExtensions[id] {
toDelete = append(toDelete, id)
if ent.IsNotFound(err) {
return nil, nil
}
return nil, fmt.Errorf("get extension error: %w", err)
}
if err = rows.Err(); err != nil {
return &ExtensionError{"iterate_extension_ids", "", err}
}
// 删除不再定义的扩展
for _, id := range toDelete {
_, err := es.databaseService.db.Exec("DELETE FROM extensions WHERE id = ?", id)
if err != nil {
return &ExtensionError{"delete_obsolete_extension", id, err}
}
es.logger.Info("Removed obsolete extension from database", "id", id)
}
return nil
}
// ServiceStartup 启动时调用
func (es *ExtensionService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
es.ctx = ctx
// 初始化数据库
var initErr error
es.initOnce.Do(func() {
if err := es.initDatabase(); err != nil {
es.logger.Error("failed to initialize extension database", "error", err)
initErr = err
}
})
return initErr
}
// GetAllExtensions 获取所有扩展配置
func (es *ExtensionService) GetAllExtensions() ([]models.Extension, error) {
es.mu.RLock()
defer es.mu.RUnlock()
if es.databaseService == nil || es.databaseService.db == nil {
return nil, &ExtensionError{"query_db", "", errors.New("database service not available")}
}
rows, err := es.databaseService.db.Query(sqlGetAllExtensions)
if err != nil {
return nil, &ExtensionError{"query_extensions", "", err}
}
defer rows.Close()
var extensions []models.Extension
for rows.Next() {
var ext models.Extension
var id string
var configJSON string
var enabled, isDefault int
err := rows.Scan(
&id,
&enabled,
&isDefault,
&configJSON,
)
if err != nil {
return nil, &ExtensionError{"scan_extension", "", err}
}
ext.ID = models.ExtensionID(id)
ext.Enabled = enabled == 1
ext.IsDefault = isDefault == 1
var config models.ExtensionConfig
if err := json.Unmarshal([]byte(configJSON), &config); err != nil {
return nil, &ExtensionError{"unmarshal_config", id, err}
}
ext.Config = config
extensions = append(extensions, ext)
}
if err = rows.Err(); err != nil {
return nil, &ExtensionError{"iterate_extensions", "", err}
}
return extensions, nil
return ext, nil
}
// UpdateExtensionEnabled 更新扩展启用状态
func (es *ExtensionService) UpdateExtensionEnabled(id models.ExtensionID, enabled bool) error {
return es.UpdateExtensionState(id, enabled, nil)
}
// UpdateExtensionState 更新扩展状态
func (es *ExtensionService) UpdateExtensionState(id models.ExtensionID, enabled bool, config models.ExtensionConfig) error {
es.mu.Lock()
defer es.mu.Unlock()
if es.databaseService == nil || es.databaseService.db == nil {
return &ExtensionError{"check_db", string(id), errors.New("database service not available")}
}
var configJSON []byte
var err error
if config != nil {
configJSON, err = json.Marshal(config)
if err != nil {
return &ExtensionError{"marshal_config", string(id), err}
}
} else {
// 如果没有提供配置,保持原有配置
var currentConfigJSON string
err := es.databaseService.db.QueryRow("SELECT config FROM extensions WHERE id = ?", string(id)).Scan(&currentConfigJSON)
if err != nil {
return &ExtensionError{"query_current_config", string(id), err}
}
configJSON = []byte(currentConfigJSON)
}
_, err = es.databaseService.db.Exec(sqlUpdateExtension,
enabled,
string(configJSON),
time.Now().Format("2006-01-02 15:04:05"),
string(id))
func (s *ExtensionService) UpdateExtensionEnabled(ctx context.Context, key string, enabled bool) error {
ext, err := s.GetExtensionByKey(ctx, key)
if err != nil {
return &ExtensionError{"update_extension", string(id), err}
}
return nil
}
// ResetExtensionToDefault 重置扩展到默认状态
func (es *ExtensionService) ResetExtensionToDefault(id models.ExtensionID) error {
// 获取默认配置
defaultSettings := models.NewDefaultExtensionSettings()
defaultExtension := defaultSettings.GetExtensionByID(id)
if defaultExtension == nil {
return &ExtensionError{"default_extension_not_found", string(id), nil}
}
return es.UpdateExtensionState(id, defaultExtension.Enabled, defaultExtension.Config)
}
// ResetAllExtensionsToDefault 重置所有扩展到默认状态
func (es *ExtensionService) ResetAllExtensionsToDefault() error {
es.mu.Lock()
defer es.mu.Unlock()
if es.databaseService == nil || es.databaseService.db == nil {
return &ExtensionError{"check_db", "", errors.New("database service not available")}
}
// 删除所有现有扩展
_, err := es.databaseService.db.Exec(sqlDeleteAllExtensions)
if err != nil {
return &ExtensionError{"delete_all_extensions", "", err}
}
// 插入默认扩展配置
if err := es.insertDefaultExtensions(); err != nil {
return err
}
return nil
if ext == nil {
return fmt.Errorf("extension not found: %s", key)
}
return s.db.Client.Extension.UpdateOneID(ext.ID).
SetEnabled(enabled).
Exec(ctx)
}
// UpdateExtensionConfig 更新扩展配置
func (s *ExtensionService) UpdateExtensionConfig(ctx context.Context, key string, config map[string]interface{}) error {
ext, err := s.GetExtensionByKey(ctx, key)
if err != nil {
return err
}
if ext == nil {
return fmt.Errorf("extension not found: %s", key)
}
return s.db.Client.Extension.UpdateOneID(ext.ID).
SetConfig(config).
Exec(ctx)
}
// ResetExtensionConfig 重置单个扩展到默认状态
func (s *ExtensionService) ResetExtensionConfig(ctx context.Context, key string) error {
defaults := models.NewDefaultExtensions()
var defaultExt *models.Extension
for _, ext := range defaults {
if string(ext.Key) == key {
defaultExt = &ext
break
}
}
if defaultExt == nil {
return fmt.Errorf("default extension not found: %s", key)
}
ext, err := s.GetExtensionByKey(ctx, key)
if err != nil {
return err
}
if ext == nil {
return fmt.Errorf("extension not found: %s", key)
}
return s.db.Client.Extension.UpdateOneID(ext.ID).
SetEnabled(defaultExt.Enabled).
SetConfig(defaultExt.Config).
Exec(ctx)
}
// GetDefaultExtensions 获取默认扩展配置(用于前端绑定生成)
func (s *ExtensionService) GetDefaultExtensions() []models.Extension {
return models.NewDefaultExtensions()
}

View File

@@ -2,212 +2,141 @@ package services
import (
"context"
"errors"
"fmt"
"sync"
"time"
"voidraft/internal/models"
"voidraft/internal/models/ent"
"voidraft/internal/models/ent/keybinding"
"voidraft/internal/models/schema/mixin"
"github.com/wailsapp/wails/v3/pkg/application"
"github.com/wailsapp/wails/v3/pkg/services/log"
)
// SQL 查询语句
const (
// 快捷键操作
sqlGetAllKeyBindings = `
SELECT command, extension, key, enabled, is_default
FROM key_bindings
ORDER BY command
`
sqlGetKeyBindingByCommand = `
SELECT command, extension, key, enabled, is_default
FROM key_bindings
WHERE command = ?
`
sqlInsertKeyBinding = `
INSERT INTO key_bindings (command, extension, key, enabled, is_default, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
`
sqlUpdateKeyBinding = `
UPDATE key_bindings
SET extension = ?, key = ?, enabled = ?, updated_at = ?
WHERE command = ?
`
sqlDeleteKeyBinding = `
DELETE FROM key_bindings
WHERE command = ?
`
sqlDeleteAllKeyBindings = `
DELETE FROM key_bindings
`
)
// KeyBindingService 快捷键管理服务
// KeyBindingService 快捷键服务
type KeyBindingService struct {
databaseService *DatabaseService
logger *log.LogService
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
initOnce sync.Once
db *DatabaseService
logger *log.LogService
}
// KeyBindingError 快捷键错误
type KeyBindingError struct {
Operation string
Command string
Err error
}
func (e *KeyBindingError) Error() string {
if e.Command != "" {
return fmt.Sprintf("keybinding %s for %s: %v", e.Operation, e.Command, e.Err)
}
return fmt.Sprintf("keybinding %s: %v", e.Operation, e.Err)
}
func (e *KeyBindingError) Unwrap() error {
return e.Err
}
func (e *KeyBindingError) Is(target error) bool {
var keyBindingError *KeyBindingError
return errors.As(target, &keyBindingError)
}
// NewKeyBindingService 创建快捷键服务实例
func NewKeyBindingService(databaseService *DatabaseService, logger *log.LogService) *KeyBindingService {
// NewKeyBindingService 创建快捷键服务
func NewKeyBindingService(db *DatabaseService, logger *log.LogService) *KeyBindingService {
if logger == nil {
logger = log.New()
}
ctx, cancel := context.WithCancel(context.Background())
service := &KeyBindingService{
databaseService: databaseService,
logger: logger,
ctx: ctx,
cancel: cancel,
}
return service
return &KeyBindingService{db: db, logger: logger}
}
// initDatabase 初始化数据库数据
func (kbs *KeyBindingService) initDatabase() error {
kbs.mu.Lock()
defer kbs.mu.Unlock()
// ServiceStartup 服务启动
func (s *KeyBindingService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
return s.SyncKeyBindings(ctx)
}
if kbs.databaseService == nil || kbs.databaseService.db == nil {
return &KeyBindingError{"check_db", "", errors.New("database service not available")}
// SyncKeyBindings 同步快捷键配置
func (s *KeyBindingService) SyncKeyBindings(ctx context.Context) error {
defaults := models.NewDefaultKeyBindings()
definedKeys := make(map[models.KeyBindingKey]models.KeyBinding)
for _, kb := range defaults {
definedKeys[kb.Key] = kb
}
// 检查是否已有快捷键数据
var count int64
err := kbs.databaseService.db.QueryRow("SELECT COUNT(*) FROM key_bindings").Scan(&count)
// 获取数据库中已有快捷键
existing, err := s.db.Client.KeyBinding.Query().All(ctx)
if err != nil {
return &KeyBindingError{"check_keybindings_count", "", err}
return fmt.Errorf("find key bindings error: %w", err)
}
// 如果没有数据,插入默认配置
if count == 0 {
if err := kbs.insertDefaultKeyBindings(); err != nil {
kbs.logger.Error("Failed to insert default key bindings", "error", err)
return err
existingKeys := make(map[string]bool)
for _, kb := range existing {
existingKeys[kb.Key] = true
}
// 批量添加缺失的快捷键
var builders []*ent.KeyBindingCreate
for key, kb := range definedKeys {
if !existingKeys[string(key)] {
create := s.db.Client.KeyBinding.Create().
SetKey(string(kb.Key)).
SetCommand(kb.Command).
SetEnabled(kb.Enabled)
if kb.Extension != "" {
create.SetExtension(string(kb.Extension))
}
builders = append(builders, create)
}
}
if len(builders) > 0 {
if _, err := s.db.Client.KeyBinding.CreateBulk(builders...).Save(ctx); err != nil {
return fmt.Errorf("bulk insert key bindings error: %w", err)
}
}
// 批量删除废弃的快捷键(硬删除)
var deleteIDs []int
for _, kb := range existing {
if _, ok := definedKeys[models.KeyBindingKey(kb.Key)]; !ok {
deleteIDs = append(deleteIDs, kb.ID)
}
}
if len(deleteIDs) > 0 {
if _, err := s.db.Client.KeyBinding.Delete().
Where(keybinding.IDIn(deleteIDs...)).
Exec(mixin.SkipSoftDelete(ctx)); err != nil {
return fmt.Errorf("bulk delete key bindings error: %w", err)
}
}
return nil
}
// insertDefaultKeyBindings 插入默认快捷键配置
func (kbs *KeyBindingService) insertDefaultKeyBindings() error {
defaultConfig := models.NewDefaultKeyBindingConfig()
now := time.Now().Format("2006-01-02 15:04:05")
for _, kb := range defaultConfig.KeyBindings {
_, err := kbs.databaseService.db.Exec(sqlInsertKeyBinding,
string(kb.Command), // 转换为字符串存储
string(kb.Extension), // 转换为字符串存储
kb.Key,
kb.Enabled,
kb.IsDefault,
now,
now,
)
if err != nil {
return &KeyBindingError{"insert_keybinding", string(kb.Command), err}
}
}
return nil
// GetAllKeyBindings 获取所有快捷键
func (s *KeyBindingService) GetAllKeyBindings(ctx context.Context) ([]*ent.KeyBinding, error) {
return s.db.Client.KeyBinding.Query().All(ctx)
}
// GetAllKeyBindings 获取所有快捷键配置
func (kbs *KeyBindingService) GetAllKeyBindings() ([]models.KeyBinding, error) {
kbs.mu.RLock()
defer kbs.mu.RUnlock()
if kbs.databaseService == nil || kbs.databaseService.db == nil {
return nil, &KeyBindingError{"query_db", "", errors.New("database service not available")}
}
rows, err := kbs.databaseService.db.Query(sqlGetAllKeyBindings)
// GetKeyBindingByKey 根据Key获取快捷键
func (s *KeyBindingService) GetKeyBindingByKey(ctx context.Context, key string) (*ent.KeyBinding, error) {
kb, err := s.db.Client.KeyBinding.Query().
Where(keybinding.Key(key)).
Only(ctx)
if err != nil {
return nil, &KeyBindingError{"query_keybindings", "", err}
}
defer rows.Close()
var keyBindings []models.KeyBinding
for rows.Next() {
var kb models.KeyBinding
var command, extension string
var enabled, isDefault int
err := rows.Scan(
&command,
&extension,
&kb.Key,
&enabled,
&isDefault,
)
if err != nil {
return nil, &KeyBindingError{"scan_keybinding", "", err}
if ent.IsNotFound(err) {
return nil, nil
}
kb.Command = models.KeyBindingCommand(command)
kb.Extension = models.ExtensionID(extension)
kb.Enabled = enabled == 1
kb.IsDefault = isDefault == 1
keyBindings = append(keyBindings, kb)
return nil, fmt.Errorf("get key binding error: %w", err)
}
if err = rows.Err(); err != nil {
return nil, &KeyBindingError{"iterate_keybindings", "", err}
}
return keyBindings, nil
return kb, nil
}
// ServiceStartup 启动时调用
func (kbs *KeyBindingService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
kbs.ctx = ctx
// 初始化数据库
var initErr error
kbs.initOnce.Do(func() {
if err := kbs.initDatabase(); err != nil {
kbs.logger.Error("failed to initialize keybinding database", "error", err)
initErr = err
}
})
return initErr
// UpdateKeyBindingCommand 更新快捷键命令
func (s *KeyBindingService) UpdateKeyBindingCommand(ctx context.Context, key string, command string) error {
kb, err := s.GetKeyBindingByKey(ctx, key)
if err != nil {
return err
}
if kb == nil {
return fmt.Errorf("key binding not found: %s", key)
}
return s.db.Client.KeyBinding.UpdateOneID(kb.ID).
SetCommand(command).
Exec(ctx)
}
// UpdateKeyBindingEnabled 更新快捷键启用状态
func (s *KeyBindingService) UpdateKeyBindingEnabled(ctx context.Context, key string, enabled bool) error {
kb, err := s.GetKeyBindingByKey(ctx, key)
if err != nil {
return err
}
if kb == nil {
return fmt.Errorf("key binding not found: %s", key)
}
return s.db.Client.KeyBinding.UpdateOneID(kb.ID).
SetEnabled(enabled).
Exec(ctx)
}
// GetDefaultKeyBindings 获取默认快捷键配置
func (s *KeyBindingService) GetDefaultKeyBindings() []models.KeyBinding {
return models.NewDefaultKeyBindings()
}

View File

@@ -140,7 +140,6 @@ func (sm *ServiceManager) GetServices() []application.Service {
application.NewService(sm.systemService),
application.NewService(sm.hotkeyService),
application.NewService(sm.dialogService),
application.NewService(sm.trayService),
application.NewService(sm.startupService),
application.NewService(sm.selfUpdateService),
application.NewService(sm.translationService),

View File

@@ -2,12 +2,12 @@ package services
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
"voidraft/internal/models"
"voidraft/internal/models/ent"
"voidraft/internal/models/ent/theme"
"voidraft/internal/models/schema/mixin"
"github.com/wailsapp/wails/v3/pkg/application"
"github.com/wailsapp/wails/v3/pkg/services/log"
@@ -15,152 +15,90 @@ import (
// ThemeService 主题服务
type ThemeService struct {
databaseService *DatabaseService
logger *log.LogService
ctx context.Context
db *DatabaseService
logger *log.LogService
}
// NewThemeService 创建新的主题服务
func NewThemeService(databaseService *DatabaseService, logger *log.LogService) *ThemeService {
// NewThemeService 创建主题服务
func NewThemeService(db *DatabaseService, logger *log.LogService) *ThemeService {
if logger == nil {
logger = log.New()
}
return &ThemeService{
databaseService: databaseService,
logger: logger,
}
return &ThemeService{db: db, logger: logger}
}
// ServiceStartup 服务启动
func (ts *ThemeService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
ts.ctx = ctx
func (s *ThemeService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
return nil
}
// getDB 获取数据库连接
func (ts *ThemeService) getDB() *sql.DB {
return ts.databaseService.db
}
// GetThemeByName 通过名称获取主题覆盖,若不存在则返回 nil
func (ts *ThemeService) GetThemeByName(name string) (*models.Theme, error) {
db := ts.getDB()
if db == nil {
return nil, fmt.Errorf("database not available")
}
trimmed := strings.TrimSpace(name)
// GetThemeByKey 根据Key获取主题
func (s *ThemeService) GetThemeByKey(ctx context.Context, key string) (*ent.Theme, error) {
trimmed := strings.TrimSpace(key)
if trimmed == "" {
return nil, fmt.Errorf("theme name cannot be empty")
return nil, fmt.Errorf("theme key cannot be empty")
}
query := `
SELECT id, name, type, colors, is_default, created_at, updated_at
FROM themes
WHERE name = ?
LIMIT 1
`
theme := &models.Theme{}
err := db.QueryRow(query, trimmed).Scan(
&theme.ID,
&theme.Name,
&theme.Type,
&theme.Colors,
&theme.IsDefault,
&theme.CreatedAt,
&theme.UpdatedAt,
)
t, err := s.db.Client.Theme.Query().
Where(theme.Key(trimmed)).
Only(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
if ent.IsNotFound(err) {
return nil, nil
}
return nil, fmt.Errorf("failed to query theme: %w", err)
return nil, fmt.Errorf("get theme error: %w", err)
}
return theme, nil
return t, nil
}
// UpdateTheme 保存或更新主题覆盖
func (ts *ThemeService) UpdateTheme(name string, colors models.ThemeColorConfig) error {
db := ts.getDB()
if db == nil {
return fmt.Errorf("database not available")
}
trimmed := strings.TrimSpace(name)
// UpdateTheme 保存或更新主题
func (s *ThemeService) UpdateTheme(ctx context.Context, key string, colors map[string]interface{}) error {
trimmed := strings.TrimSpace(key)
if trimmed == "" {
return fmt.Errorf("theme name cannot be empty")
return fmt.Errorf("theme key cannot be empty")
}
if colors == nil {
colors = models.ThemeColorConfig{}
colors = map[string]interface{}{}
}
colors["themeName"] = trimmed
themeType := models.ThemeTypeDark
// 判断主题类型
themeType := theme.TypeDark
if raw, ok := colors["dark"].(bool); ok && !raw {
themeType = models.ThemeTypeLight
themeType = theme.TypeLight
}
now := time.Now().Format("2006-01-02 15:04:05")
existing, err := ts.GetThemeByName(trimmed)
existing, err := s.GetThemeByKey(ctx, trimmed)
if err != nil {
return err
}
if existing == nil {
_, err = db.Exec(
`INSERT INTO themes (name, type, colors, is_default, created_at, updated_at) VALUES (?, ?, ?, 0, ?, ?)`,
trimmed,
themeType,
colors,
now,
now,
)
if err != nil {
return fmt.Errorf("failed to insert theme: %w", err)
}
return nil
// 插入新主题
_, err = s.db.Client.Theme.Create().
SetKey(trimmed).
SetType(themeType).
SetColors(colors).
Save(ctx)
return err
}
_, err = db.Exec(
`UPDATE themes SET type = ?, colors = ?, updated_at = ? WHERE name = ?`,
themeType,
colors,
now,
trimmed,
)
if err != nil {
return fmt.Errorf("failed to update theme: %w", err)
}
return nil
// 更新现有主题
return s.db.Client.Theme.UpdateOneID(existing.ID).
SetType(themeType).
SetColors(colors).
Exec(ctx)
}
// ResetTheme 删除指定主题的覆盖配置
func (ts *ThemeService) ResetTheme(name string) error {
db := ts.getDB()
if db == nil {
return fmt.Errorf("database not available")
}
trimmed := strings.TrimSpace(name)
// ResetTheme 删除主题
func (s *ThemeService) ResetTheme(ctx context.Context, key string) error {
trimmed := strings.TrimSpace(key)
if trimmed == "" {
return fmt.Errorf("theme name cannot be empty")
return fmt.Errorf("theme key cannot be empty")
}
if _, err := db.Exec(`DELETE FROM themes WHERE name = ?`, trimmed); err != nil {
return fmt.Errorf("failed to reset theme: %w", err)
}
return nil
}
// ServiceShutdown 服务关闭
func (ts *ThemeService) ServiceShutdown() error {
return nil
_, err := s.db.Client.Theme.Delete().
Where(theme.Key(trimmed)).
Exec(mixin.SkipSoftDelete(ctx))
return err
}

View File

@@ -52,7 +52,7 @@ func (ws *WindowService) OpenDocumentWindow(documentID int64) error {
}
// 获取文档信息
doc, err := ws.documentService.GetDocumentByID(documentID)
doc, err := ws.documentService.GetDocumentByID(context.Background(), int(documentID))
if err != nil {
return fmt.Errorf("failed to get document: %w", err)
}

View File

@@ -1,15 +1,13 @@
package services
import (
"github.com/wailsapp/wails/v3/pkg/application"
"github.com/wailsapp/wails/v3/pkg/events"
"github.com/wailsapp/wails/v3/pkg/services/log"
"math"
"sync"
"time"
"voidraft/internal/common/helper"
"voidraft/internal/models"
"github.com/wailsapp/wails/v3/pkg/application"
"github.com/wailsapp/wails/v3/pkg/events"
"github.com/wailsapp/wails/v3/pkg/services/log"
)
// 防抖和检测常量
@@ -22,6 +20,43 @@ const (
dragDetectionThreshold = 40 * time.Millisecond
)
// SnapEdge 表示吸附的边缘类型
type SnapEdge int
const (
SnapEdgeNone SnapEdge = iota // 未吸附
SnapEdgeTop // 吸附到上边缘
SnapEdgeRight // 吸附到右边缘
SnapEdgeBottom // 吸附到下边缘
SnapEdgeLeft // 吸附到左边缘
SnapEdgeTopRight // 吸附到右上角
SnapEdgeBottomRight // 吸附到右下角
SnapEdgeBottomLeft // 吸附到左下角
SnapEdgeTopLeft // 吸附到左上角
)
// WindowPosition 窗口位置
type WindowPosition struct {
X int `json:"x"` // X坐标
Y int `json:"y"` // Y坐标
}
// SnapPosition 表示吸附的相对位置
type SnapPosition struct {
X int `json:"x"` // X轴相对偏移
Y int `json:"y"` // Y轴相对偏移
}
// WindowInfo 窗口信息
type WindowInfo struct {
DocumentID int64 `json:"documentID"` // 文档ID
IsSnapped bool `json:"isSnapped"` // 是否处于吸附状态
SnapOffset SnapPosition `json:"snapOffset"` // 与主窗口的相对位置偏移
SnapEdge SnapEdge `json:"snapEdge"` // 吸附的边缘类型
LastPos WindowPosition `json:"lastPos"` // 上一次记录的窗口位置
MoveTime time.Time `json:"moveTime"` // 上次移动时间,用于判断移动速度
}
// WindowSnapService 窗口吸附服务
type WindowSnapService struct {
logger *log.LogService
@@ -38,11 +73,11 @@ type WindowSnapService struct {
maxThreshold int // 最大阈值(像素)
// 位置缓存
lastMainWindowPos models.WindowPosition // 缓存主窗口位置
lastMainWindowSize [2]int // 缓存主窗口尺寸 [width, height]
lastMainWindowPos WindowPosition // 缓存主窗口位置
lastMainWindowSize [2]int // 缓存主窗口尺寸 [width, height]
// 管理的窗口
managedWindows map[int64]*models.WindowInfo // documentID -> WindowInfo
managedWindows map[int64]*WindowInfo // documentID -> WindowInfo
windowRefs map[int64]*application.WebviewWindow // documentID -> Window引用
// 窗口尺寸缓存
@@ -81,7 +116,7 @@ func NewWindowSnapService(logger *log.LogService, configService *ConfigService)
baseThresholdRatio: 0.025, // 2.5%的主窗口宽度作为基础阈值
minThreshold: 8, // 最小8像素小屏幕保底
maxThreshold: 40, // 最大40像素大屏幕上限
managedWindows: make(map[int64]*models.WindowInfo),
managedWindows: make(map[int64]*WindowInfo),
windowRefs: make(map[int64]*application.WebviewWindow),
windowSizeCache: make(map[int64][2]int),
isUpdatingPosition: make(map[int64]bool),
@@ -114,12 +149,12 @@ func (wss *WindowSnapService) RegisterWindow(documentID int64, window *applicati
// 获取初始位置
x, y := window.Position()
windowInfo := &models.WindowInfo{
windowInfo := &WindowInfo{
DocumentID: documentID,
IsSnapped: false,
SnapOffset: models.SnapPosition{X: 0, Y: 0},
SnapEdge: models.SnapEdgeNone,
LastPos: models.WindowPosition{X: x, Y: y},
SnapOffset: SnapPosition{X: 0, Y: 0},
SnapEdge: SnapEdgeNone,
LastPos: WindowPosition{X: x, Y: y},
MoveTime: time.Now(),
}
@@ -176,7 +211,7 @@ func (wss *WindowSnapService) SetSnapEnabled(enabled bool) {
for _, windowInfo := range wss.managedWindows {
if windowInfo.IsSnapped {
windowInfo.IsSnapped = false
windowInfo.SnapEdge = models.SnapEdgeNone
windowInfo.SnapEdge = SnapEdgeNone
}
}
}
@@ -251,7 +286,7 @@ func (wss *WindowSnapService) cleanupMainWindowEvents() {
}
// setupWindowEvents 为子窗口设置事件监听
func (wss *WindowSnapService) setupWindowEvents(window *application.WebviewWindow, windowInfo *models.WindowInfo) {
func (wss *WindowSnapService) setupWindowEvents(window *application.WebviewWindow, windowInfo *WindowInfo) {
// 监听子窗口移动事件,保存清理函数
unhook := window.RegisterHook(events.Common.WindowDidMove, func(event *application.WindowEvent) {
wss.onChildWindowMoved(window, windowInfo)
@@ -274,7 +309,7 @@ func (wss *WindowSnapService) updateMainWindowCacheLocked() {
w, h := mainWindow.Size()
wss.mu.Lock()
wss.lastMainWindowPos = models.WindowPosition{X: x, Y: y}
wss.lastMainWindowPos = WindowPosition{X: x, Y: y}
wss.lastMainWindowSize = [2]int{w, h}
}
@@ -335,7 +370,7 @@ func (wss *WindowSnapService) onMainWindowMoved() {
defer wss.mu.Unlock()
// 更新主窗口位置和尺寸缓存
wss.lastMainWindowPos = models.WindowPosition{X: x, Y: y}
wss.lastMainWindowPos = WindowPosition{X: x, Y: y}
wss.lastMainWindowSize = [2]int{w, h}
// 只更新已吸附窗口的位置,无需重新检测所有窗口
@@ -347,7 +382,7 @@ func (wss *WindowSnapService) onMainWindowMoved() {
}
// onChildWindowMoved 子窗口移动事件处理
func (wss *WindowSnapService) onChildWindowMoved(window *application.WebviewWindow, windowInfo *models.WindowInfo) {
func (wss *WindowSnapService) onChildWindowMoved(window *application.WebviewWindow, windowInfo *WindowInfo) {
if !wss.snapEnabled {
return
}
@@ -361,7 +396,7 @@ func (wss *WindowSnapService) onChildWindowMoved(window *application.WebviewWind
wss.mu.Unlock()
x, y := window.Position()
currentPos := models.WindowPosition{X: x, Y: y}
currentPos := WindowPosition{X: x, Y: y}
wss.mu.Lock()
defer wss.mu.Unlock()
@@ -392,7 +427,7 @@ func (wss *WindowSnapService) onChildWindowMoved(window *application.WebviewWind
}
// updateSnappedWindowPosition 更新已吸附窗口的位置
func (wss *WindowSnapService) updateSnappedWindowPosition(windowInfo *models.WindowInfo) {
func (wss *WindowSnapService) updateSnappedWindowPosition(windowInfo *WindowInfo) {
// 计算新的目标位置(基于主窗口新位置)
expectedX := wss.lastMainWindowPos.X + windowInfo.SnapOffset.X
expectedY := wss.lastMainWindowPos.Y + windowInfo.SnapOffset.Y
@@ -413,11 +448,11 @@ func (wss *WindowSnapService) updateSnappedWindowPosition(windowInfo *models.Win
// 清除更新标志
wss.isUpdatingPosition[windowInfo.DocumentID] = false
windowInfo.LastPos = models.WindowPosition{X: expectedX, Y: expectedY}
windowInfo.LastPos = WindowPosition{X: expectedX, Y: expectedY}
}
// handleSnappedWindow 处理已吸附窗口的移动
func (wss *WindowSnapService) handleSnappedWindow(window *application.WebviewWindow, windowInfo *models.WindowInfo, currentPos models.WindowPosition) {
func (wss *WindowSnapService) handleSnappedWindow(window *application.WebviewWindow, windowInfo *WindowInfo, currentPos WindowPosition) {
// 计算预期位置
expectedX := wss.lastMainWindowPos.X + windowInfo.SnapOffset.X
expectedY := wss.lastMainWindowPos.Y + windowInfo.SnapOffset.Y
@@ -434,12 +469,12 @@ func (wss *WindowSnapService) handleSnappedWindow(window *application.WebviewWin
if isUserDrag {
// 用户主动拖拽,解除吸附
windowInfo.IsSnapped = false
windowInfo.SnapEdge = models.SnapEdgeNone
windowInfo.SnapEdge = SnapEdgeNone
}
}
// handleUnsnappedWindow 处理未吸附窗口的移动,返回是否成功吸附
func (wss *WindowSnapService) handleUnsnappedWindow(window *application.WebviewWindow, windowInfo *models.WindowInfo, currentPos models.WindowPosition, lastMoveTime time.Time) bool {
func (wss *WindowSnapService) handleUnsnappedWindow(window *application.WebviewWindow, windowInfo *WindowInfo, currentPos WindowPosition, lastMoveTime time.Time) bool {
// 检查是否应该吸附
should, snapEdge := wss.shouldSnapToMainWindow(window, windowInfo, currentPos, lastMoveTime)
if should {
@@ -474,11 +509,11 @@ func (wss *WindowSnapService) handleUnsnappedWindow(window *application.WebviewW
}
// shouldSnapToMainWindow 吸附检测
func (wss *WindowSnapService) shouldSnapToMainWindow(window *application.WebviewWindow, windowInfo *models.WindowInfo, currentPos models.WindowPosition, lastMoveTime time.Time) (bool, models.SnapEdge) {
func (wss *WindowSnapService) shouldSnapToMainWindow(window *application.WebviewWindow, windowInfo *WindowInfo, currentPos WindowPosition, lastMoveTime time.Time) (bool, SnapEdge) {
// 防抖:移动太快时不检测(使用统一的防抖阈值)
timeSinceLastMove := time.Since(lastMoveTime)
if timeSinceLastMove < debounceThreshold {
return false, models.SnapEdgeNone
return false, SnapEdgeNone
}
// 使用缓存的主窗口位置和尺寸
@@ -507,7 +542,7 @@ func (wss *WindowSnapService) shouldSnapToMainWindow(window *application.Webview
// 简化的距离计算结构
type snapCheck struct {
edge models.SnapEdge
edge SnapEdge
distance float64
priority int // 1=角落, 2=边缘
}
@@ -516,14 +551,14 @@ func (wss *WindowSnapService) shouldSnapToMainWindow(window *application.Webview
// 检查角落吸附优先级1
cornerChecks := []struct {
edge models.SnapEdge
edge SnapEdge
dx int
dy int
}{
{models.SnapEdgeTopRight, mainRight - windowLeft, mainTop - windowBottom},
{models.SnapEdgeBottomRight, mainRight - windowLeft, mainBottom - windowTop},
{models.SnapEdgeBottomLeft, mainLeft - windowRight, mainBottom - windowTop},
{models.SnapEdgeTopLeft, mainLeft - windowRight, mainTop - windowBottom},
{SnapEdgeTopRight, mainRight - windowLeft, mainTop - windowBottom},
{SnapEdgeBottomRight, mainRight - windowLeft, mainBottom - windowTop},
{SnapEdgeBottomLeft, mainLeft - windowRight, mainBottom - windowTop},
{SnapEdgeTopLeft, mainLeft - windowRight, mainTop - windowBottom},
}
for _, check := range cornerChecks {
@@ -538,13 +573,13 @@ func (wss *WindowSnapService) shouldSnapToMainWindow(window *application.Webview
// 如果没有角落吸附检查边缘吸附优先级2
if bestSnap == nil {
edgeChecks := []struct {
edge models.SnapEdge
edge SnapEdge
distance float64
}{
{models.SnapEdgeRight, math.Abs(float64(mainRight - windowLeft))},
{models.SnapEdgeLeft, math.Abs(float64(mainLeft - windowRight))},
{models.SnapEdgeBottom, math.Abs(float64(mainBottom - windowTop))},
{models.SnapEdgeTop, math.Abs(float64(mainTop - windowBottom))},
{SnapEdgeRight, math.Abs(float64(mainRight - windowLeft))},
{SnapEdgeLeft, math.Abs(float64(mainLeft - windowRight))},
{SnapEdgeBottom, math.Abs(float64(mainBottom - windowTop))},
{SnapEdgeTop, math.Abs(float64(mainTop - windowBottom))},
}
for _, check := range edgeChecks {
@@ -557,14 +592,14 @@ func (wss *WindowSnapService) shouldSnapToMainWindow(window *application.Webview
}
if bestSnap == nil {
return false, models.SnapEdgeNone
return false, SnapEdgeNone
}
return true, bestSnap.edge
}
// calculateSnapPosition 计算吸附目标位置
func (wss *WindowSnapService) calculateSnapPosition(snapEdge models.SnapEdge, currentPos models.WindowPosition, documentID int64, window *application.WebviewWindow) models.WindowPosition {
func (wss *WindowSnapService) calculateSnapPosition(snapEdge SnapEdge, currentPos WindowPosition, documentID int64, window *application.WebviewWindow) WindowPosition {
// 使用缓存的主窗口信息
mainPos := wss.lastMainWindowPos
mainWidth := wss.lastMainWindowSize[0]
@@ -574,43 +609,43 @@ func (wss *WindowSnapService) calculateSnapPosition(snapEdge models.SnapEdge, cu
windowWidth, windowHeight := wss.getWindowSizeCached(documentID, window)
switch snapEdge {
case models.SnapEdgeRight:
return models.WindowPosition{
case SnapEdgeRight:
return WindowPosition{
X: mainPos.X + mainWidth,
Y: currentPos.Y, // 保持当前Y位置
}
case models.SnapEdgeLeft:
return models.WindowPosition{
case SnapEdgeLeft:
return WindowPosition{
X: mainPos.X - windowWidth,
Y: currentPos.Y,
}
case models.SnapEdgeBottom:
return models.WindowPosition{
case SnapEdgeBottom:
return WindowPosition{
X: currentPos.X,
Y: mainPos.Y + mainHeight,
}
case models.SnapEdgeTop:
return models.WindowPosition{
case SnapEdgeTop:
return WindowPosition{
X: currentPos.X,
Y: mainPos.Y - windowHeight,
}
case models.SnapEdgeTopRight:
return models.WindowPosition{
case SnapEdgeTopRight:
return WindowPosition{
X: mainPos.X + mainWidth,
Y: mainPos.Y - windowHeight,
}
case models.SnapEdgeBottomRight:
return models.WindowPosition{
case SnapEdgeBottomRight:
return WindowPosition{
X: mainPos.X + mainWidth,
Y: mainPos.Y + mainHeight,
}
case models.SnapEdgeBottomLeft:
return models.WindowPosition{
case SnapEdgeBottomLeft:
return WindowPosition{
X: mainPos.X - windowWidth,
Y: mainPos.Y + mainHeight,
}
case models.SnapEdgeTopLeft:
return models.WindowPosition{
case SnapEdgeTopLeft:
return WindowPosition{
X: mainPos.X - windowWidth,
Y: mainPos.Y - windowHeight,
}
@@ -636,7 +671,7 @@ func (wss *WindowSnapService) Cleanup() {
}
// 清空管理的窗口
wss.managedWindows = make(map[int64]*models.WindowInfo)
wss.managedWindows = make(map[int64]*WindowInfo)
wss.windowRefs = make(map[int64]*application.WebviewWindow)
wss.windowSizeCache = make(map[int64][2]int)
wss.isUpdatingPosition = make(map[int64]bool)