package services

import (
	"context"
	"fmt"

	"voidraft/internal/models"
	"voidraft/internal/models/ent"
	"voidraft/internal/models/ent/extension"
	"voidraft/internal/models/ent/keybinding"
	"voidraft/internal/models/schema/mixin"

	"github.com/wailsapp/wails/v3/pkg/application"
	"github.com/wailsapp/wails/v3/pkg/services/log"
)

// ExtensionService manages the persisted extension records and keeps the
// keybindings that belong to an extension in step with its enabled flag.
type ExtensionService struct {
	db     *DatabaseService
	logger *log.LogService
}

// NewExtensionService creates an ExtensionService. A nil logger is replaced
// with a fresh log.New() instance.
func NewExtensionService(db *DatabaseService, logger *log.LogService) *ExtensionService {
	if logger == nil {
		logger = log.New()
	}
	return &ExtensionService{db: db, logger: logger}
}

// ServiceStartup runs when the service starts and synchronises the extension
// table with the built-in defaults.
func (s *ExtensionService) ServiceStartup(ctx context.Context, options application.ServiceOptions) error {
	return s.SyncExtensions(ctx)
}

// SyncExtensions reconciles the extension table with models.NewDefaultExtensions().
//
// It does two things:
//   - Defaults that have no row yet are bulk-inserted with their default
//     enabled flag and config.
//   - Rows whose name is no longer among the defaults are bulk-deleted.
//
// Rows for extensions that are still defined are left untouched, so their
// stored enabled flag and config are kept.
func (s *ExtensionService) SyncExtensions(ctx context.Context) error {
	defaults := models.NewDefaultExtensions()
	defined := make(map[string]struct{}, len(defaults))
	for _, ext := range defaults {
		defined[string(ext.Name)] = struct{}{}
	}

	// Load the extensions currently stored in the database.
	existing, err := s.db.Client.Extension.Query().All(ctx)
	if err != nil {
		return fmt.Errorf("find extensions error: %w", err)
	}
	present := make(map[string]struct{}, len(existing))
	for _, ext := range existing {
		present[ext.Name] = struct{}{}
	}

	// Walk the defaults slice rather than a map so missing extensions are
	// inserted in their declared order instead of random map-iteration order.
	var builders []*ent.ExtensionCreate
	for _, ext := range defaults {
		name := string(ext.Name)
		if _, ok := present[name]; ok {
			continue
		}
		// Mark as present so a name listed twice in the defaults is inserted once.
		present[name] = struct{}{}
		builders = append(builders, s.db.Client.Extension.Create().
			SetName(name).
			SetEnabled(ext.Enabled).
			SetConfig(ext.Config))
	}
	if len(builders) > 0 {
		if _, err := s.db.Client.Extension.CreateBulk(builders...).Save(ctx); err != nil {
			return fmt.Errorf("bulk insert extensions error: %w", err)
		}
	}

	// Remove rows for extensions that are no longer defined. The context is
	// wrapped with mixin.SkipSoftDelete — presumably so the rows are removed
	// physically rather than soft-deleted (verify against the mixin).
	var deleteIDs []int
	for _, ext := range existing {
		if _, ok := defined[ext.Name]; !ok {
			deleteIDs = append(deleteIDs, ext.ID)
		}
	}
	if len(deleteIDs) > 0 {
		if _, err := s.db.Client.Extension.Delete().
			Where(extension.IDIn(deleteIDs...)).
			Exec(mixin.SkipSoftDelete(ctx)); err != nil {
			return fmt.Errorf("bulk delete extensions error: %w", err)
		}
	}
	return nil
}

// GetExtensions returns every extension row returned by the default query.
func (s *ExtensionService) GetExtensions(ctx context.Context) ([]*ent.Extension, error) {
	return s.db.Client.Extension.Query().All(ctx)
}

// GetExtensionByID looks up one extension by its ID. It returns (nil, nil)
// when no such extension exists, and a wrapped error for any other failure.
func (s *ExtensionService) GetExtensionByID(ctx context.Context, id int) (*ent.Extension, error) {
	ext, err := s.db.Client.Extension.Query().
		Where(extension.ID(id)).
		Only(ctx)
	if err != nil {
		if ent.IsNotFound(err) {
			return nil, nil
		}
		return nil, fmt.Errorf("get extension error: %w", err)
	}
	return ext, nil
}

// UpdateExtensionEnabled sets the enabled flag of an extension. It sets the
// enabled flag of every keybinding attached to that extension to the same
// value.
//
// Both writes run in one transaction, so a failure in either leaves the
// extension and its keybindings unchanged. It returns an error if the
// extension does not exist.
func (s *ExtensionService) UpdateExtensionEnabled(ctx context.Context, id int, enabled bool) error {
	ext, err := s.GetExtensionByID(ctx, id)
	if err != nil {
		return err
	}
	if ext == nil {
		return fmt.Errorf("extension not found: %d", id)
	}

	return s.withTx(ctx, func(tx *ent.Tx) error {
		if err := tx.Extension.UpdateOneID(ext.ID).
			SetEnabled(enabled).
			Exec(ctx); err != nil {
			return fmt.Errorf("update extension error: %w", err)
		}
		return syncKeyBindingsEnabled(ctx, tx, ext.Name, enabled)
	})
}

// UpdateExtensionConfig replaces the stored config of an extension. It
// returns an error if the extension does not exist.
func (s *ExtensionService) UpdateExtensionConfig(ctx context.Context, id int, config map[string]interface{}) error {
	ext, err := s.GetExtensionByID(ctx, id)
	if err != nil {
		return err
	}
	if ext == nil {
		return fmt.Errorf("extension not found: %d", id)
	}
	if err := s.db.Client.Extension.UpdateOneID(ext.ID).
		SetConfig(config).
		Exec(ctx); err != nil {
		return fmt.Errorf("update extension config error: %w", err)
	}
	return nil
}

// ResetExtensionConfig restores one extension's enabled flag and config to
// the values in models.NewDefaultExtensions().
//
// Like UpdateExtensionEnabled, it applies the restored enabled flag to the
// extension's keybindings in the same transaction. It returns an error if
// the extension does not exist or has no default definition.
func (s *ExtensionService) ResetExtensionConfig(ctx context.Context, id int) error {
	ext, err := s.GetExtensionByID(ctx, id)
	if err != nil {
		return err
	}
	if ext == nil {
		return fmt.Errorf("extension not found: %d", id)
	}

	defaults := models.NewDefaultExtensions()
	var defaultExt *models.Extension
	for i := range defaults {
		if string(defaults[i].Name) == ext.Name {
			defaultExt = &defaults[i]
			break
		}
	}
	if defaultExt == nil {
		return fmt.Errorf("default extension not found: %s", ext.Name)
	}

	return s.withTx(ctx, func(tx *ent.Tx) error {
		if err := tx.Extension.UpdateOneID(ext.ID).
			SetEnabled(defaultExt.Enabled).
			SetConfig(defaultExt.Config).
			Exec(ctx); err != nil {
			return fmt.Errorf("reset extension error: %w", err)
		}
		return syncKeyBindingsEnabled(ctx, tx, ext.Name, defaultExt.Enabled)
	})
}

// GetExtensionConfig returns the stored config of an extension, or an empty
// (non-nil) map when none is stored. It returns an error if the extension
// does not exist.
func (s *ExtensionService) GetExtensionConfig(ctx context.Context, id int) (map[string]interface{}, error) {
	ext, err := s.GetExtensionByID(ctx, id)
	if err != nil {
		return nil, err
	}
	if ext == nil {
		return nil, fmt.Errorf("extension not found: %d", id)
	}
	if ext.Config == nil {
		return make(map[string]interface{}), nil
	}
	return ext.Config, nil
}

// GetDefaultExtensions returns the built-in default extension definitions
// (used for frontend binding generation).
func (s *ExtensionService) GetDefaultExtensions() []models.Extension {
	return models.NewDefaultExtensions()
}

// syncKeyBindingsEnabled sets the enabled flag of every keybinding whose
// extension field equals name. It uses the given transaction.
func syncKeyBindingsEnabled(ctx context.Context, tx *ent.Tx, name string, enabled bool) error {
	if _, err := tx.KeyBinding.Update().
		Where(keybinding.Extension(name)).
		SetEnabled(enabled).
		Save(ctx); err != nil {
		return fmt.Errorf("update keybindings for extension %s error: %w", name, err)
	}
	return nil
}

// withTx runs fn inside a database transaction. It commits when fn returns
// nil, rolls back when fn returns an error, and rolls back then re-panics if
// fn panics. Only the tx clients should be used inside fn.
func (s *ExtensionService) withTx(ctx context.Context, fn func(tx *ent.Tx) error) error {
	tx, err := s.db.Client.Tx(ctx)
	if err != nil {
		return fmt.Errorf("begin transaction error: %w", err)
	}
	defer func() {
		if v := recover(); v != nil {
			// Best-effort rollback; the panic is the error worth surfacing.
			_ = tx.Rollback()
			panic(v)
		}
	}()
	if err = fn(tx); err != nil {
		if rerr := tx.Rollback(); rerr != nil {
			return fmt.Errorf("%w (rollback error: %v)", err, rerr)
		}
		return err
	}
	if err = tx.Commit(); err != nil {
		return fmt.Errorf("commit transaction error: %w", err)
	}
	return nil
}