package config import ( "fmt" "log" "os" "path/filepath" "github.com/glebarez/sqlite" "gorm.io/driver/mysql" "gorm.io/gorm" "gorm.io/gorm/logger" "github.com/spf13/viper" ) // 获取项目根目录 func getProjectRoot() string { // 从当前工作目录向上查找 .env 文件 dir, _ := os.Getwd() for i := 0; i < 5; i++ { if _, err := os.Stat(filepath.Join(dir, ".env")); err == nil { return dir } dir = filepath.Dir(dir) } // 默认返回当前目录 return "." } type Config struct { Port string JWTSecret string DatabaseType string // 数据库类型: mysql 或 sqlite DatabaseHost string DatabasePort string DatabaseUser string DatabasePassword string DatabaseName string DatabaseURL string // 拼接后的完整连接字符串 SQLitePath string // SQLite 数据库文件路径 PythonServiceURL string AICoreServiceAddr string // AI-Core gRPC 服务地址,如 "localhost:50051" // 文件上传配置 UploadMode string // "local" 或 "minio" UploadLocalPath string // 本地存储路径,如 "resource/files" ServerBaseURL string // 服务器基础 URL,用于生成本地文件 URL MarkdownLocalPath string // Markdown 文件存储路径,如 "resource/markdown" // MinIO 配置 MinIOEndpoint string MinIOAccessKey string MinIOSecretKey string MinIOBucket string MinIOUseSSL bool } func Load() *Config { // 重新初始化 viper,避免之前的状态影响 viper.Reset() // 第一步:设置默认值 viper.SetDefault("port", "8080") viper.SetDefault("jwt_secret", "your-secret-key-change-in-production") viper.SetDefault("python_service_url", "http://localhost:8081") viper.SetDefault("ai_core_service_addr", "localhost:50051") // 数据库默认配置 viper.SetDefault("database_type", "mysql") viper.SetDefault("database_host", "localhost") viper.SetDefault("database_port", "3306") viper.SetDefault("database_user", "root") viper.SetDefault("database_password", "root") viper.SetDefault("database_name", "x_agents") viper.SetDefault("sqlite_path", "./data/x_agents.db") // 文件上传默认配置 viper.SetDefault("upload_mode", "local") viper.SetDefault("upload_local_path", "resource/files") viper.SetDefault("server_base_url", "http://localhost:8080") viper.SetDefault("markdown_local_path", "resource/markdown") viper.SetDefault("minio_endpoint", "localhost:9000") viper.SetDefault("minio_access_key", "") viper.SetDefault("minio_secret_key", "") viper.SetDefault("minio_bucket", "x-agents") viper.SetDefault("minio_use_ssl", false) // 第二步:读取 config.yaml(优先级低) viper.SetConfigName("config") viper.SetConfigType("yaml") viper.AddConfigPath("./config") viper.AddConfigPath("../config") viper.AddConfigPath("../../config") _ = viper.MergeInConfig() // 忽略错误,可能没有 config.yaml // 第三步:读取 .env 文件(优先级最高) projectRoot := getProjectRoot() log.Printf("Project root: %s", projectRoot) viper.SetConfigName(".env") viper.SetConfigType("env") viper.AddConfigPath(projectRoot) // 项目根目录 (X-Agents) viper.AddConfigPath(".") // 当前目录 viper.AddConfigPath("..") // 父目录 viper.AddConfigPath("../..") // 上两级目录 viper.SetEnvPrefix("GO") // 环境变量前缀 GO_xxx (仅对环境变量生效) viper.AutomaticEnv() _ = viper.MergeInConfig() // 忽略错误,可能没有 .env // 处理 .env 文件中的键名(去掉 GO_ 前缀映射) envToConfig := map[string]string{ "GO_PORT": "port", "GO_DATABASE_TYPE": "database_type", "GO_DATABASE_HOST": "database_host", "GO_DATABASE_PORT": "database_port", "GO_DATABASE_NAME": "database_name", "GO_DATABASE_USER": "database_user", "GO_DATABASE_PASSWORD": "database_password", "GO_SQLITE_PATH": "sqlite_path", } for envKey, configKey := range envToConfig { if val := viper.GetString(envKey); val != "" { viper.Set(configKey, val) } } log.Printf("Loaded config: database_type=%s, port=%s", viper.GetString("database_type"), viper.GetString("port")) // 获取数据库类型 dbType := viper.GetString("database_type") var databaseURL string if dbType == "sqlite" { sqlitePath := viper.GetString("sqlite_path") // 确保 SQLite 数据目录存在 (跨平台处理) dir := filepath.Dir(sqlitePath) if dir != "." && dir != "" { os.MkdirAll(dir, 0755) } databaseURL = sqlitePath } else { // MySQL 连接字符串 dbHost := viper.GetString("database_host") dbPort := viper.GetString("database_port") dbUser := viper.GetString("database_user") dbPassword := viper.GetString("database_password") dbName := viper.GetString("database_name") databaseURL = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", dbUser, dbPassword, dbHost, dbPort, dbName) } return &Config{ Port: viper.GetString("port"), JWTSecret: viper.GetString("jwt_secret"), DatabaseType: dbType, DatabaseURL: databaseURL, DatabaseHost: viper.GetString("database_host"), DatabasePort: viper.GetString("database_port"), DatabaseUser: viper.GetString("database_user"), DatabasePassword: viper.GetString("database_password"), DatabaseName: viper.GetString("database_name"), SQLitePath: viper.GetString("sqlite_path"), PythonServiceURL: viper.GetString("python_service_url"), AICoreServiceAddr: viper.GetString("ai_core_service_addr"), // 文件上传配置 UploadMode: viper.GetString("upload_mode"), UploadLocalPath: viper.GetString("upload_local_path"), ServerBaseURL: viper.GetString("server_base_url"), MarkdownLocalPath: viper.GetString("markdown_local_path"), // MinIO 配置 MinIOEndpoint: viper.GetString("minio_endpoint"), MinIOAccessKey: viper.GetString("minio_access_key"), MinIOSecretKey: viper.GetString("minio_secret_key"), MinIOBucket: viper.GetString("minio_bucket"), MinIOUseSSL: viper.GetBool("minio_use_ssl"), } } func InitDB(cfg *Config) (*gorm.DB, error) { dsn := cfg.DatabaseURL if dsn == "" { return nil, fmt.Errorf("database URL is empty") } var db *gorm.DB var err error if cfg.DatabaseType == "sqlite" { // SQLite 不需要创建目录逻辑,因为 Load 函数已经处理了 db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), }) log.Printf("Using SQLite database: %s", dsn) } else { db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), }) log.Printf("Using MySQL database: %s:%s/%s", cfg.DatabaseHost, cfg.DatabasePort, cfg.DatabaseName) } if err != nil { return nil, fmt.Errorf("failed to connect database: %w", err) } log.Println("Database connected successfully") return db, nil }