375 lines
9.6 KiB
Go
375 lines
9.6 KiB
Go
package services
|
||
|
||
import (
	"context"
	"database/sql"
	"fmt"
	"os"
	"path/filepath"
	"reflect"
	"sort"
	"strings"
	"sync"
	"time"
	"voidraft/internal/models"

	"github.com/wailsapp/wails/v3/pkg/application"
	"github.com/wailsapp/wails/v3/pkg/services/log"
	_ "modernc.org/sqlite"
)
|
||
|
||
// dbName is the file name of the SQLite database inside the configured data
// directory.
const dbName = "voidraft.db"

// SQL executed by initDatabase when the database is opened.
const (
	// sqlOptimizationSettings tunes SQLite for a single-user desktop workload:
	// WAL journaling, relaxed fsync, a larger page cache, in-memory temp
	// storage and foreign-key enforcement.
	sqlOptimizationSettings = `
	PRAGMA journal_mode = WAL;
	PRAGMA synchronous = NORMAL;
	PRAGMA cache_size = -64000;
	PRAGMA temp_store = MEMORY;
	PRAGMA foreign_keys = ON;`

	// sqlCreateDocumentsTable creates the documents table. New documents start
	// with the content '∞∞∞text-a'.
	sqlCreateDocumentsTable = `
	CREATE TABLE IF NOT EXISTS documents (
		id INTEGER PRIMARY KEY AUTOINCREMENT,
		title TEXT NOT NULL,
		content TEXT DEFAULT '∞∞∞text-a',
		created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
		updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
		is_deleted INTEGER DEFAULT 0,
		is_locked INTEGER DEFAULT 0
	)`

	// sqlCreateExtensionsTable creates the extensions table, keyed by the
	// extension's text id, with its enabled/default flags and a config string
	// that defaults to '{}'.
	sqlCreateExtensionsTable = `
	CREATE TABLE IF NOT EXISTS extensions (
		id TEXT PRIMARY KEY,
		enabled INTEGER NOT NULL DEFAULT 1,
		is_default INTEGER NOT NULL DEFAULT 0,
		config TEXT DEFAULT '{}',
		created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
		updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
	)`

	// sqlCreateKeyBindingsTable creates the key_bindings table. Each
	// (command, extension) pair may appear only once.
	sqlCreateKeyBindingsTable = `
	CREATE TABLE IF NOT EXISTS key_bindings (
		id INTEGER PRIMARY KEY AUTOINCREMENT,
		command TEXT NOT NULL,
		extension TEXT NOT NULL,
		key TEXT NOT NULL,
		enabled INTEGER NOT NULL DEFAULT 1,
		is_default INTEGER NOT NULL DEFAULT 0,
		created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
		updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
		UNIQUE(command, extension)
	)`
)
|
||
|
||
// ColumnInfo describes a column a model expects its table to have: the SQL
// type to declare and the literal used in the column's DEFAULT clause when the
// column has to be added to an existing table.
type ColumnInfo struct {
	SQLType, DefaultValue string
}
|
||
|
||
// TableModel maps a database table to the Go model type whose `db`-tagged
// fields describe the columns the table is expected to have.
type TableModel struct {
	// TableName is the SQL table name.
	TableName string
	// Model is a struct value or pointer to struct that is inspected by
	// reflection. Only its type matters, not its field values.
	Model any
}
|
||
|
||
// DatabaseService provides shared database functionality
|
||
type DatabaseService struct {
|
||
configService *ConfigService
|
||
logger *log.Service
|
||
db *sql.DB
|
||
mu sync.RWMutex
|
||
ctx context.Context
|
||
tableModels []TableModel // 注册的表模型
|
||
}
|
||
|
||
// NewDatabaseService creates a new database service
|
||
func NewDatabaseService(configService *ConfigService, logger *log.Service) *DatabaseService {
|
||
if logger == nil {
|
||
logger = log.New()
|
||
}
|
||
|
||
ds := &DatabaseService{
|
||
configService: configService,
|
||
logger: logger,
|
||
}
|
||
|
||
// 注册所有模型
|
||
ds.registerAllModels()
|
||
|
||
return ds
|
||
}
|
||
|
||
// registerAllModels 注册所有数据模型,集中管理表-模型映射
|
||
func (ds *DatabaseService) registerAllModels() {
|
||
// 文档表
|
||
ds.RegisterModel("documents", &models.Document{})
|
||
// 扩展表
|
||
ds.RegisterModel("extensions", &models.Extension{})
|
||
// 快捷键表
|
||
ds.RegisterModel("key_bindings", &models.KeyBinding{})
|
||
}
|
||
|
||
// ServiceStartup initializes the service when the application starts
|
||
func (ds *DatabaseService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
|
||
ds.ctx = ctx
|
||
return ds.initDatabase()
|
||
}
|
||
|
||
// initDatabase opens the SQLite database under the configured data path and
// brings its schema up to date: tables, then indexes, then model column sync.
//
// If any step fails, the newly opened handle is closed and ds.db is restored
// to the value it had before the call. A failed (re)initialization therefore
// never leaks a connection pool and never leaves a half-initialized handle in
// ds.db.
func (ds *DatabaseService) initDatabase() error {
	dbPath, err := ds.getDatabasePath()
	if err != nil {
		return fmt.Errorf("failed to get database path: %w", err)
	}

	// Make sure the directory that holds the database file exists.
	if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
		return fmt.Errorf("failed to create database directory: %w", err)
	}

	db, err := ds.openDatabase(dbPath)
	if err != nil {
		return err
	}

	// createTables, createIndexes and syncAllModelTables all operate on ds.db,
	// so publish the new handle while they run and roll back if they fail.
	previous := ds.db
	ds.db = db
	if err := ds.setupSchema(); err != nil {
		ds.db = previous
		_ = db.Close() // best effort: the schema error is the one worth reporting
		return err
	}
	return nil
}

// openDatabase opens the SQLite file at dbPath, verifies the connection and
// applies the performance settings. On error the handle is closed and nil is
// returned.
func (ds *DatabaseService) openDatabase(dbPath string) (*sql.DB, error) {
	// synchronous, cache_size, temp_store and foreign_keys are per-connection
	// settings, but database/sql keeps a pool of connections. Passing them as
	// _pragma DSN parameters makes the modernc.org/sqlite driver apply them to
	// every connection it opens, not only the one that runs the Exec below.
	// journal_mode is left to that Exec because WAL mode persists in the file.
	pragmas := []string{
		"_pragma=synchronous(NORMAL)",
		"_pragma=cache_size(-64000)",
		"_pragma=temp_store(MEMORY)",
		"_pragma=foreign_keys(1)",
	}
	db, err := sql.Open("sqlite", dbPath+"?"+strings.Join(pragmas, "&"))
	if err != nil {
		return nil, fmt.Errorf("failed to open database: %w", err)
	}

	// sql.Open does not connect; Ping forces the first real connection.
	if err := db.Ping(); err != nil {
		_ = db.Close()
		return nil, fmt.Errorf("failed to ping database: %w", err)
	}

	if _, err := db.Exec(sqlOptimizationSettings); err != nil {
		_ = db.Close()
		return nil, fmt.Errorf("failed to apply optimization settings: %w", err)
	}
	return db, nil
}

// setupSchema creates the tables and indexes on ds.db, then adds any columns
// that the registered models declare but the tables are missing.
func (ds *DatabaseService) setupSchema() error {
	if err := ds.createTables(); err != nil {
		return fmt.Errorf("failed to create tables: %w", err)
	}
	if err := ds.createIndexes(); err != nil {
		return fmt.Errorf("failed to create indexes: %w", err)
	}
	if err := ds.syncAllModelTables(); err != nil {
		return fmt.Errorf("failed to sync model tables: %w", err)
	}
	return nil
}
|
||
|
||
// getDatabasePath returns the location of the database file: dbName joined
// onto the data directory from the current configuration. Any error from
// loading the configuration is returned unchanged.
func (ds *DatabaseService) getDatabasePath() (string, error) {
	cfg, err := ds.configService.GetConfig()
	if err != nil {
		return "", err
	}
	dataDir := cfg.General.DataPath
	return filepath.Join(dataDir, dbName), nil
}
|
||
|
||
// createTables creates all database tables
|
||
func (ds *DatabaseService) createTables() error {
|
||
tables := []string{
|
||
sqlCreateDocumentsTable,
|
||
sqlCreateExtensionsTable,
|
||
sqlCreateKeyBindingsTable,
|
||
}
|
||
|
||
for _, table := range tables {
|
||
if _, err := ds.db.Exec(table); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// createIndexes creates database indexes
|
||
func (ds *DatabaseService) createIndexes() error {
|
||
indexes := []string{
|
||
// Documents indexes
|
||
`CREATE INDEX IF NOT EXISTS idx_documents_updated_at ON documents(updated_at DESC)`,
|
||
`CREATE INDEX IF NOT EXISTS idx_documents_title ON documents(title)`,
|
||
`CREATE INDEX IF NOT EXISTS idx_documents_is_deleted ON documents(is_deleted)`,
|
||
// Extensions indexes
|
||
`CREATE INDEX IF NOT EXISTS idx_extensions_enabled ON extensions(enabled)`,
|
||
// Key bindings indexes
|
||
`CREATE INDEX IF NOT EXISTS idx_key_bindings_command ON key_bindings(command)`,
|
||
`CREATE INDEX IF NOT EXISTS idx_key_bindings_extension ON key_bindings(extension)`,
|
||
`CREATE INDEX IF NOT EXISTS idx_key_bindings_enabled ON key_bindings(enabled)`,
|
||
}
|
||
|
||
for _, index := range indexes {
|
||
if _, err := ds.db.Exec(index); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// RegisterModel 注册模型与表的映射关系
|
||
func (ds *DatabaseService) RegisterModel(tableName string, model interface{}) {
|
||
ds.mu.Lock()
|
||
defer ds.mu.Unlock()
|
||
|
||
ds.tableModels = append(ds.tableModels, TableModel{
|
||
TableName: tableName,
|
||
Model: model,
|
||
})
|
||
}
|
||
|
||
// syncAllModelTables reconciles every registered model with its table, in
// registration order, and stops at the first failure. The error is wrapped
// with the failing table's name.
//
// RegisterModel mutates tableModels under ds.mu, so the registry is copied
// under a read lock first. Working on the copy means the lock is not held
// during the schema work itself.
func (ds *DatabaseService) syncAllModelTables() error {
	ds.mu.RLock()
	registered := make([]TableModel, len(ds.tableModels))
	copy(registered, ds.tableModels)
	ds.mu.RUnlock()

	for _, tm := range registered {
		if err := ds.syncModelTable(tm.TableName, tm.Model); err != nil {
			return fmt.Errorf("failed to sync table %s: %w", tm.TableName, err)
		}
	}
	return nil
}
|
||
|
||
// syncModelTable 同步模型与表结构
|
||
func (ds *DatabaseService) syncModelTable(tableName string, model interface{}) error {
|
||
// 获取表结构元数据
|
||
columns, err := ds.getTableColumns(tableName)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to get table columns: %w", err)
|
||
}
|
||
|
||
// 使用反射从模型中提取字段信息
|
||
expectedColumns, err := ds.getModelColumns(model)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to get model columns: %w", err)
|
||
}
|
||
|
||
// 检查缺失的列并添加
|
||
for colName, colInfo := range expectedColumns {
|
||
if _, exists := columns[colName]; !exists {
|
||
// 执行添加列的SQL
|
||
alterSQL := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s DEFAULT %s",
|
||
tableName, colName, colInfo.SQLType, colInfo.DefaultValue)
|
||
if _, err := ds.db.Exec(alterSQL); err != nil {
|
||
return fmt.Errorf("failed to add column %s: %w", colName, err)
|
||
}
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// getTableColumns 获取表的列信息
|
||
func (ds *DatabaseService) getTableColumns(table string) (map[string]string, error) {
|
||
query := fmt.Sprintf("PRAGMA table_info(%s)", table)
|
||
rows, err := ds.db.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
columns := make(map[string]string)
|
||
for rows.Next() {
|
||
var cid int
|
||
var name, typeName string
|
||
var notNull, pk int
|
||
var dflt_value interface{}
|
||
|
||
if err := rows.Scan(&cid, &name, &typeName, ¬Null, &dflt_value, &pk); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
columns[name] = typeName
|
||
}
|
||
|
||
if err := rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return columns, nil
|
||
}
|
||
|
||
// getModelColumns derives the columns a table should have from a model.
//
// model must be a struct or a pointer to a struct; anything else, including
// nil, is an error. Every field with a `db` struct tag becomes a column:
//   - The column name is the part of the tag before the first comma, so
//     options like `db:"name,omitempty"` do not leak into the name.
//   - Fields tagged `db:"-"`, or whose name part is empty, are skipped.
//   - Column types and defaults come from getSQLTypeAndDefault.
func (ds *DatabaseService) getModelColumns(model any) (map[string]ColumnInfo, error) {
	t := reflect.TypeOf(model)
	// reflect.TypeOf(nil) returns a nil Type, and calling Kind on it panics.
	if t == nil {
		return nil, fmt.Errorf("model must be a struct or a pointer to struct")
	}
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t.Kind() != reflect.Struct {
		return nil, fmt.Errorf("model must be a struct or a pointer to struct")
	}

	columns := make(map[string]ColumnInfo, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)

		// Fields without a db tag are not persisted and are skipped.
		name, _, _ := strings.Cut(field.Tag.Get("db"), ",")
		if name == "" || name == "-" {
			continue
		}

		sqlType, defaultVal := getSQLTypeAndDefault(field.Type)
		columns[name] = ColumnInfo{
			SQLType:      sqlType,
			DefaultValue: defaultVal,
		}
	}
	return columns, nil
}
|
||
|
||
// getSQLTypeAndDefault maps a Go field type to the SQLite column type and the
// literal used in the column's DEFAULT clause.
//
// Mapping:
//   - bool and all integer kinds: INTEGER, default 0.
//   - floats: REAL, default 0.0.
//   - strings: TEXT, default ''.
//   - time.Time: DATETIME, default CURRENT_TIMESTAMP.
//   - []byte: BLOB, default NULL.
//   - Pointer types (e.g. *int64, *time.Time) are nullable versions of their
//     element type: they keep the element's SQL type with a NULL default.
//   - Anything else falls back to TEXT with a NULL default.
func getSQLTypeAndDefault(t reflect.Type) (string, string) {
	if t.Kind() == reflect.Ptr {
		sqlType, _ := getSQLTypeAndDefault(t.Elem())
		return sqlType, "NULL"
	}
	if t == reflect.TypeOf(time.Time{}) {
		return "DATETIME", "CURRENT_TIMESTAMP"
	}

	switch t.Kind() {
	case reflect.Bool,
		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return "INTEGER", "0"
	case reflect.Float32, reflect.Float64:
		return "REAL", "0.0"
	case reflect.String:
		return "TEXT", "''"
	case reflect.Slice:
		if t.Elem().Kind() == reflect.Uint8 {
			return "BLOB", "NULL"
		}
	}
	return "TEXT", "NULL"
}
|
||
|
||
// ServiceShutdown shuts down the service when the application closes
|
||
func (ds *DatabaseService) ServiceShutdown() error {
|
||
if ds.db != nil {
|
||
return ds.db.Close()
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// OnDataPathChanged handles data path changes
|
||
func (ds *DatabaseService) OnDataPathChanged() error {
|
||
// 关闭当前连接
|
||
if ds.db != nil {
|
||
if err := ds.db.Close(); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
// 用新路径重新初始化
|
||
return ds.initDatabase()
|
||
}
|