Files
ai_wht_wechat/backend/config.py
2026-01-23 16:27:47 +08:00

170 lines
6.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
配置管理模块
支持从YAML文件加载配置,支持环境变量覆盖
"""
import os
import yaml
from typing import Dict, Any
class Config:
    """Read-only accessor over a (possibly nested) configuration dictionary.

    Lookup keys may be dotted paths such as ``"database.host"`` that walk
    into nested dicts one segment at a time.
    """

    def __init__(self, config_dict: Dict[str, Any]):
        # Stored by reference (no copy); every lookup reads this object.
        self._config = config_dict

    def get(self, key: str, default=None):
        """Return the value at dotted path ``key``.

        ``default`` is returned when a segment is missing, when a segment's
        value is None, or when the path tries to descend into a non-dict.
        """
        node = self._config
        for segment in key.split('.'):
            if not isinstance(node, dict):
                return default
            node = node.get(segment)
            if node is None:
                return default
        return node

    def get_dict(self, key: str) -> Dict[str, Any]:
        """Return the dict at ``key``; any missing or non-dict value yields ``{}``."""
        candidate = self.get(key)
        if isinstance(candidate, dict):
            return candidate
        return {}

    def get_int(self, key: str, default: int = 0) -> int:
        """Return the value at ``key`` converted with ``int()``.

        Falls back to ``default`` when the conversion raises ValueError/TypeError.
        """
        raw = self.get(key, default)
        try:
            converted = int(raw)
        except (TypeError, ValueError):
            converted = default
        return converted

    def get_bool(self, key: str, default: bool = False) -> bool:
        """Return the value at ``key`` as a bool.

        Strings count as True only if, lower-cased, they are one of
        'true', 'yes', '1' or 'on'; any other value uses Python truthiness.
        """
        raw = self.get(key, default)
        if isinstance(raw, str):
            return raw.lower() in ('true', 'yes', '1', 'on')
        return bool(raw)

    def get_str(self, key: str, default: str = '') -> str:
        """Return the value at ``key`` passed through ``str()``; None yields ``default``."""
        raw = self.get(key, default)
        if raw is None:
            return default
        return str(raw)
# Lower-cased words accepted as "true" for boolean environment overrides
# (the same set Config.get_bool accepts for string values).
_TRUTHY_ENV_VALUES = ('true', 'yes', '1', 'on')


def _env_bool(raw: str) -> bool:
    """Parse a boolean environment value (surrounding whitespace ignored)."""
    return raw.strip().lower() in _TRUTHY_ENV_VALUES


# Environment-variable overrides applied on top of the YAML file, in order.
# Each entry: (environment variable, top-level section, key in section, converter).
_ENV_OVERRIDES = (
    # Database
    ('DB_HOST', 'database', 'host', str),
    ('DB_PORT', 'database', 'port', int),
    ('DB_USER', 'database', 'username', str),
    ('DB_PASSWORD', 'database', 'password', str),
    ('DB_NAME', 'database', 'dbname', str),
    # Redis
    ('REDIS_HOST', 'redis', 'host', str),
    ('REDIS_PORT', 'redis', 'port', int),
    ('REDIS_PASSWORD', 'redis', 'password', str),
    ('REDIS_DB', 'redis', 'db', int),
    # Scheduler
    ('SCHEDULER_ENABLED', 'scheduler', 'enabled', _env_bool),
    ('SCHEDULER_CRON', 'scheduler', 'cron', str),
    ('SCHEDULER_MAX_CONCURRENT', 'scheduler', 'max_concurrent', int),
    ('SCHEDULER_PUBLISH_TIMEOUT', 'scheduler', 'publish_timeout', int),
    ('SCHEDULER_MAX_ARTICLES_PER_USER_PER_RUN', 'scheduler', 'max_articles_per_user_per_run', int),
    ('SCHEDULER_MAX_FAILURES_PER_USER_PER_RUN', 'scheduler', 'max_failures_per_user_per_run', int),
    ('SCHEDULER_MAX_DAILY_ARTICLES_PER_USER', 'scheduler', 'max_daily_articles_per_user', int),
    ('SCHEDULER_MAX_HOURLY_ARTICLES_PER_USER', 'scheduler', 'max_hourly_articles_per_user', int),
    # Proxy pool
    ('PROXY_POOL_ENABLED', 'proxy_pool', 'enabled', _env_bool),
    ('PROXY_POOL_API_URL', 'proxy_pool', 'api_url', str),
    # AdsPower fingerprint browser
    ('ADSPOWER_ENABLED', 'adspower', 'enabled', _env_bool),
    ('ADSPOWER_API_BASE', 'adspower', 'api_base', str),
    ('ADSPOWER_API_KEY', 'adspower', 'api_key', str),
    ('ADSPOWER_USER_ID', 'adspower', 'user_id', str),
    ('ADSPOWER_DEFAULT_GROUP_ID', 'adspower', 'default_group_id', str),
)


def _config_section(config_dict: Dict[str, Any], name: str) -> Dict[str, Any]:
    """Return the dict stored under ``name``, creating it if missing or null.

    Raises:
        TypeError: the section exists but is not a mapping.
    """
    section = config_dict.get(name)
    if section is None:
        # Covers both an absent key and a YAML key with no value (``database:``).
        section = {}
        config_dict[name] = section
    elif not isinstance(section, dict):
        raise TypeError(
            f"Config section '{name}' must be a mapping, got {type(section).__name__}"
        )
    return section


def _apply_env_overrides(config_dict: Dict[str, Any]) -> None:
    """Overwrite values in ``config_dict`` from the variables in ``_ENV_OVERRIDES``.

    Unset or empty variables are skipped, so the YAML value is kept.

    Raises:
        ValueError: a variable's value cannot be converted (e.g. non-numeric DB_PORT).
    """
    for env_name, section_name, key, convert in _ENV_OVERRIDES:
        raw = os.getenv(env_name)
        if not raw:
            continue
        try:
            value = convert(raw)
        except ValueError as exc:
            raise ValueError(
                f"Invalid value for environment variable {env_name}: {raw!r}"
            ) from exc
        _config_section(config_dict, section_name)[key] = value


def _as_mapping(value: Any) -> Dict[str, Any]:
    """Return ``value`` if it is a dict, else an empty dict (used for safe summaries)."""
    return value if isinstance(value, dict) else {}


def load_config(env: str = None) -> Config:
    """
    Load the YAML configuration for an environment and apply env-var overrides.

    Args:
        env: environment name, e.g. "dev" or "prod". When omitted, it is read
            from the ENV environment variable, defaulting to "dev". The file
            loaded is ``config.<env>.yaml`` next to this module.

    Returns:
        A Config wrapping the merged settings.

    Raises:
        FileNotFoundError: the config file for ``env`` does not exist.
        TypeError: the file's top level, or a section being overridden, is not a mapping.
        ValueError: an integer override variable holds a non-integer value.
    """
    if env is None:
        env = os.getenv('ENV', 'dev')

    config_file = f'config.{env}.yaml'
    config_path = os.path.join(os.path.dirname(__file__), config_file)
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"配置文件不存在: {config_path}")

    with open(config_path, 'r', encoding='utf-8') as f:
        config_dict = yaml.safe_load(f)
    # An empty YAML document loads as None; treat it as an empty configuration
    # instead of crashing later on .setdefault()/.get().
    if config_dict is None:
        config_dict = {}
    elif not isinstance(config_dict, dict):
        raise TypeError(
            f"Config file {config_path} must contain a mapping at the top level, "
            f"got {type(config_dict).__name__}"
        )

    _apply_env_overrides(config_dict)

    database_cfg = _as_mapping(config_dict.get('database'))
    redis_cfg = _as_mapping(config_dict.get('redis'))
    scheduler_cfg = _as_mapping(config_dict.get('scheduler'))
    print(f"[配置] 已加载配置文件: {config_file}")
    print(f"[配置] 环境: {env}")
    print(f"[配置] 数据库: {database_cfg.get('host')}:{database_cfg.get('port')}")
    print(f"[配置] Redis: {redis_cfg.get('host')}:{redis_cfg.get('port')}")
    print(f"[配置] 调度器: {'启用' if scheduler_cfg.get('enabled') else '禁用'}")
    return Config(config_dict)
# Process-wide Config instance; stays None until init_config() or get_config() loads it.
app_config: Config = None


def init_config(env: str = None):
    """Load configuration for ``env`` and install it as the global config.

    Returns the freshly loaded Config. If loading raises, the previous
    global value is left in place.
    """
    global app_config
    loaded = load_config(env)
    app_config = loaded
    return loaded


def get_config() -> Config:
    """Return the global Config, loading the default environment on first use."""
    if app_config is not None:
        return app_config
    # First access: init_config(None) resolves the environment exactly as load_config() does.
    return init_config()