package services

import (
	"context"
	"database/sql"
	"fmt"
	"net/url"
	"os"
	"path/filepath"
	"reflect"
	"sort"
	"strings"
	"sync"
	"time"

	"voidraft/internal/models"

	"github.com/wailsapp/wails/v3/pkg/application"
	"github.com/wailsapp/wails/v3/pkg/services/log"
	_ "modernc.org/sqlite"
)

const (
	dbName = "voidraft.db"

	// Documents table
	sqlCreateDocumentsTable = `
CREATE TABLE IF NOT EXISTS documents (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    title TEXT NOT NULL,
    content TEXT DEFAULT '∞∞∞text-a',
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
    is_deleted INTEGER DEFAULT 0,
    is_locked INTEGER DEFAULT 0
)`

	// Extensions table
	sqlCreateExtensionsTable = `
CREATE TABLE IF NOT EXISTS extensions (
    id TEXT PRIMARY KEY,
    enabled INTEGER NOT NULL DEFAULT 1,
    is_default INTEGER NOT NULL DEFAULT 0,
    config TEXT DEFAULT '{}',
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`

	// Key bindings table
	sqlCreateKeyBindingsTable = `
CREATE TABLE IF NOT EXISTS key_bindings (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    command TEXT NOT NULL,
    extension TEXT NOT NULL,
    key TEXT NOT NULL,
    enabled INTEGER NOT NULL DEFAULT 1,
    is_default INTEGER NOT NULL DEFAULT 0,
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
    UNIQUE(command, extension)
)`
)

// sqlConnectionPragmas are the SQLite performance/integrity settings.
//
// Most of these PRAGMAs (synchronous, cache_size, temp_store, foreign_keys,
// busy_timeout) are per-connection state. Running them once via db.Exec on a
// pooled *sql.DB only configures whichever connection served that call.
// Passing them as modernc.org/sqlite `_pragma` DSN parameters makes the
// driver run them on every connection it opens.
var sqlConnectionPragmas = []string{
	"journal_mode(WAL)",
	"synchronous(NORMAL)",
	"cache_size(-64000)",
	"temp_store(MEMORY)",
	"foreign_keys(1)",
	// The pool may hold several connections; wait for a competing writer
	// instead of failing immediately with SQLITE_BUSY.
	"busy_timeout(5000)",
}

// ColumnInfo stores the SQL type and default value of a column.
type ColumnInfo struct {
	SQLType      string
	DefaultValue string
}

// TableModel maps a table name to the Go model struct describing its columns.
type TableModel struct {
	TableName string
	Model     interface{}
}

// DatabaseService provides shared database functionality.
type DatabaseService struct {
	configService *ConfigService
	logger        *log.Service
	db            *sql.DB
	mu            sync.RWMutex // guards tableModels
	ctx           context.Context
	tableModels   []TableModel // registered table models
}

// NewDatabaseService creates a new database service and registers all models.
// A nil logger is replaced with a default one.
func NewDatabaseService(configService *ConfigService, logger *log.Service) *DatabaseService {
	if logger == nil {
		logger = log.New()
	}

	ds := &DatabaseService{
		configService: configService,
		logger:        logger,
	}

	// Register all models.
	ds.registerAllModels()

	return ds
}

// registerAllModels registers every data model, keeping the table-to-model
// mapping in one place.
func (ds *DatabaseService) registerAllModels() {
	ds.RegisterModel("documents", &models.Document{})
	ds.RegisterModel("extensions", &models.Extension{})
	ds.RegisterModel("key_bindings", &models.KeyBinding{})
}

// ServiceStartup initializes the service when the application starts.
func (ds *DatabaseService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
	ds.ctx = ctx
	return ds.initDatabase()
}

// initDatabase opens the SQLite database at the configured data path, then
// creates the tables and indexes and syncs them with the registered models.
// If any step after opening fails, the new handle is closed so it does not
// leak.
func (ds *DatabaseService) initDatabase() error {
	dbPath, err := ds.getDatabasePath()
	if err != nil {
		return fmt.Errorf("failed to get database path: %w", err)
	}

	// Make sure the directory holding the database file exists.
	if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
		return fmt.Errorf("failed to create database directory: %w", err)
	}

	db, err := sql.Open("sqlite", databaseDSN(dbPath))
	if err != nil {
		return fmt.Errorf("failed to open database: %w", err)
	}
	ds.db = db

	if err := ds.prepareDatabase(); err != nil {
		_ = db.Close() // best effort; the setup error is the one worth reporting
		return err
	}
	return nil
}

// prepareDatabase checks the connection (which also applies the DSN pragmas),
// then creates tables and indexes and syncs them with the registered models.
func (ds *DatabaseService) prepareDatabase() error {
	if err := ds.db.Ping(); err != nil {
		return fmt.Errorf("failed to ping database: %w", err)
	}
	if err := ds.createTables(); err != nil {
		return fmt.Errorf("failed to create tables: %w", err)
	}
	if err := ds.createIndexes(); err != nil {
		return fmt.Errorf("failed to create indexes: %w", err)
	}
	if err := ds.syncAllModelTables(); err != nil {
		return fmt.Errorf("failed to sync model tables: %w", err)
	}
	return nil
}

// databaseDSN appends sqlConnectionPragmas to path as `_pragma` query
// parameters understood by modernc.org/sqlite.
func databaseDSN(path string) string {
	q := url.Values{}
	for _, p := range sqlConnectionPragmas {
		q.Add("_pragma", p)
	}
	return path + "?" + q.Encode()
}

// getDatabasePath returns the database file path inside the configured data directory.
func (ds *DatabaseService) getDatabasePath() (string, error) {
	config, err := ds.configService.GetConfig()
	if err != nil {
		return "", err
	}
	return filepath.Join(config.General.DataPath, dbName), nil
}

// createTables creates all database tables if they do not exist yet.
func (ds *DatabaseService) createTables() error {
	tables := []string{
		sqlCreateDocumentsTable,
		sqlCreateExtensionsTable,
		sqlCreateKeyBindingsTable,
	}

	for _, table := range tables {
		if _, err := ds.db.Exec(table); err != nil {
			return err
		}
	}
	return nil
}

// createIndexes creates the database indexes if they do not exist yet.
func (ds *DatabaseService) createIndexes() error {
	indexes := []string{
		// Documents indexes
		`CREATE INDEX IF NOT EXISTS idx_documents_updated_at ON documents(updated_at DESC)`,
		`CREATE INDEX IF NOT EXISTS idx_documents_title ON documents(title)`,
		`CREATE INDEX IF NOT EXISTS idx_documents_is_deleted ON documents(is_deleted)`,
		// Extensions indexes
		`CREATE INDEX IF NOT EXISTS idx_extensions_enabled ON extensions(enabled)`,
		// Key bindings indexes
		`CREATE INDEX IF NOT EXISTS idx_key_bindings_command ON key_bindings(command)`,
		`CREATE INDEX IF NOT EXISTS idx_key_bindings_extension ON key_bindings(extension)`,
		`CREATE INDEX IF NOT EXISTS idx_key_bindings_enabled ON key_bindings(enabled)`,
	}

	for _, index := range indexes {
		if _, err := ds.db.Exec(index); err != nil {
			return err
		}
	}
	return nil
}

// RegisterModel records the mapping between tableName and model so that
// missing columns can be added to the table at startup.
func (ds *DatabaseService) RegisterModel(tableName string, model interface{}) {
	ds.mu.Lock()
	defer ds.mu.Unlock()

	ds.tableModels = append(ds.tableModels, TableModel{
		TableName: tableName,
		Model:     model,
	})
}

// syncAllModelTables syncs every registered model with its table.
func (ds *DatabaseService) syncAllModelTables() error {
	// Snapshot under the same lock RegisterModel uses.
	ds.mu.RLock()
	tableModels := append([]TableModel(nil), ds.tableModels...)
	ds.mu.RUnlock()

	for _, tm := range tableModels {
		if err := ds.syncModelTable(tm.TableName, tm.Model); err != nil {
			return fmt.Errorf("failed to sync table %s: %w", tm.TableName, err)
		}
	}
	return nil
}

// syncModelTable adds to tableName every column declared by model's `db` tags
// that the table does not have yet. Existing columns are never altered or dropped.
func (ds *DatabaseService) syncModelTable(tableName string, model interface{}) error {
	// Columns currently present in the table.
	columns, err := ds.getTableColumns(tableName)
	if err != nil {
		return fmt.Errorf("failed to get table columns: %w", err)
	}

	// Columns the model expects, derived via reflection.
	expectedColumns, err := ds.getModelColumns(model)
	if err != nil {
		return fmt.Errorf("failed to get model columns: %w", err)
	}

	// Map iteration order is random; sort so missing columns are appended in
	// the same order on every installation.
	missing := make([]string, 0, len(expectedColumns))
	for colName := range expectedColumns {
		if _, exists := columns[colName]; !exists {
			missing = append(missing, colName)
		}
	}
	sort.Strings(missing)

	for _, colName := range missing {
		if err := ds.addColumn(tableName, colName, expectedColumns[colName]); err != nil {
			return fmt.Errorf("failed to add column %s: %w", colName, err)
		}
	}
	return nil
}

// addColumn appends a column to tableName.
//
// SQLite's ALTER TABLE ADD COLUMN rejects non-constant defaults such as
// CURRENT_TIMESTAMP. For those, the column is added without a default and
// existing rows are back-filled with the expression. Both steps run in one
// transaction, so a failed back-fill does not leave a half-migrated column
// behind.
func (ds *DatabaseService) addColumn(tableName, colName string, col ColumnInfo) error {
	table, column := quoteIdent(tableName), quoteIdent(colName)

	if !isNonConstantDefault(col.DefaultValue) {
		_, err := ds.db.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s DEFAULT %s",
			table, column, col.SQLType, col.DefaultValue))
		return err
	}

	tx, err := ds.db.Begin()
	if err != nil {
		return err
	}
	defer func() { _ = tx.Rollback() }() // no-op after a successful Commit

	if _, err := tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", table, column, col.SQLType)); err != nil {
		return err
	}
	if _, err := tx.Exec(fmt.Sprintf("UPDATE %s SET %s = %s", table, column, col.DefaultValue)); err != nil {
		return err
	}
	return tx.Commit()
}

// isNonConstantDefault reports whether def is a default that SQLite forbids
// in ALTER TABLE ADD COLUMN: CURRENT_TIME, CURRENT_DATE, CURRENT_TIMESTAMP,
// or a parenthesized expression.
func isNonConstantDefault(def string) bool {
	d := strings.ToUpper(strings.TrimSpace(def))
	switch d {
	case "CURRENT_TIME", "CURRENT_DATE", "CURRENT_TIMESTAMP":
		return true
	}
	return strings.HasPrefix(d, "(")
}

// quoteIdent quotes name as an SQLite identifier, doubling embedded quotes.
func quoteIdent(name string) string {
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}

// getTableColumns returns a map from column name to declared column type for
// table, read from PRAGMA table_info. A table that does not exist yields an
// empty map.
func (ds *DatabaseService) getTableColumns(table string) (map[string]string, error) {
	rows, err := ds.db.Query(fmt.Sprintf("PRAGMA table_info(%s)", quoteIdent(table)))
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	columns := make(map[string]string)
	for rows.Next() {
		var (
			cid            int
			name, typeName string
			notNull, pk    int
			dfltValue      interface{}
		)
		if err := rows.Scan(&cid, &name, &typeName, &notNull, &dfltValue, &pk); err != nil {
			return nil, err
		}
		columns[name] = typeName
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return columns, nil
}

// getModelColumns extracts column definitions from the `db` struct tags of
// model, which must be a struct or a pointer to one.
//
// Fields without a `db` tag, or tagged `db:"-"`, are skipped. Anything after
// a comma in the tag (e.g. `db:"name,omitempty"`) is treated as an option and
// ignored.
func (ds *DatabaseService) getModelColumns(model interface{}) (map[string]ColumnInfo, error) {
	t := reflect.TypeOf(model)
	if t == nil {
		return nil, fmt.Errorf("model must be a struct or a pointer to struct")
	}
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t.Kind() != reflect.Struct {
		return nil, fmt.Errorf("model must be a struct or a pointer to struct")
	}

	columns := make(map[string]ColumnInfo)
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)

		name, _, _ := strings.Cut(field.Tag.Get("db"), ",")
		name = strings.TrimSpace(name)
		if name == "" || name == "-" {
			continue
		}

		sqlType, defaultVal := getSQLTypeAndDefault(field.Type)
		columns[name] = ColumnInfo{
			SQLType:      sqlType,
			DefaultValue: defaultVal,
		}
	}
	return columns, nil
}

// getSQLTypeAndDefault maps a Go type to an SQLite column type and a default
// value expression.
//
// Pointer types are treated as nullable: they keep their element's column
// type and default to NULL. time.Time maps to DATETIME with CURRENT_TIMESTAMP.
// Any other unsupported type falls back to TEXT NULL.
func getSQLTypeAndDefault(t reflect.Type) (string, string) {
	if t.Kind() == reflect.Ptr {
		sqlType, _ := getSQLTypeAndDefault(t.Elem())
		return sqlType, "NULL"
	}
	if t == reflect.TypeOf(time.Time{}) {
		return "DATETIME", "CURRENT_TIMESTAMP"
	}

	switch t.Kind() {
	case reflect.Bool:
		return "INTEGER", "0"
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return "INTEGER", "0"
	case reflect.Float32, reflect.Float64:
		return "REAL", "0.0"
	case reflect.String:
		return "TEXT", "''"
	default:
		return "TEXT", "NULL"
	}
}

// ServiceShutdown closes the database when the application shuts down.
func (ds *DatabaseService) ServiceShutdown() error {
	if ds.db != nil {
		return ds.db.Close()
	}
	return nil
}

// OnDataPathChanged closes the current database and reopens it at the newly
// configured data path.
func (ds *DatabaseService) OnDataPathChanged() error {
	if ds.db != nil {
		if err := ds.db.Close(); err != nil {
			return err
		}
	}
	return ds.initDatabase()
}