2025-12-2genxin
This commit is contained in:
556
database/user_database.py
Normal file
556
database/user_database.py
Normal file
@@ -0,0 +1,556 @@
|
||||
import sqlite3
|
||||
import datetime
|
||||
import os
|
||||
import bcrypt
|
||||
import re
|
||||
import logging
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
filename=os.path.join(os.path.dirname(__file__), 'user_db.log')
|
||||
)
|
||||
logger = logging.getLogger('UserDatabase')
|
||||
|
||||
class UserDatabase:
|
||||
def __init__(self, db_path='users.db'):
|
||||
"""
|
||||
初始化数据库配置
|
||||
:param db_path: 数据库文件路径
|
||||
"""
|
||||
self.db_path = os.path.join(os.path.dirname(__file__), db_path)
|
||||
# 初始化时只创建表,不保持连接
|
||||
self._setup_adapters() # 设置全局适配器
|
||||
self._create_tables_at_init() # 初始化时创建表
|
||||
|
||||
def _setup_adapters(self):
|
||||
"""设置SQLite适配器和转换器(全局设置)"""
|
||||
# 添加自定义datetime适配器以解决Python 3.12弃用警告
|
||||
def adapt_datetime(dt):
|
||||
return dt.isoformat()
|
||||
|
||||
# 注册适配器
|
||||
sqlite3.register_adapter(datetime.datetime, adapt_datetime)
|
||||
|
||||
# 添加转换器(从字符串转换回datetime)
|
||||
def convert_datetime(val):
|
||||
try:
|
||||
return datetime.datetime.fromisoformat(val.decode())
|
||||
except (ValueError, AttributeError):
|
||||
return val
|
||||
|
||||
# 注册转换器
|
||||
sqlite3.register_converter('TIMESTAMP', convert_datetime)
|
||||
|
||||
def _get_connection(self):
|
||||
"""
|
||||
获取新的数据库连接(线程安全)
|
||||
:return: 数据库连接对象
|
||||
"""
|
||||
try:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row # 允许通过列名访问结果
|
||||
return conn
|
||||
except sqlite3.Error as e:
|
||||
print(f"数据库连接错误: {e}")
|
||||
raise
|
||||
|
||||
def _create_tables_at_init(self):
|
||||
"""
|
||||
初始化时创建用户表
|
||||
"""
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
# 创建用户表,添加适当的主键和索引
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
user_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
email TEXT NOT NULL UNIQUE,
|
||||
password_hash TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_login_at TIMESTAMP,
|
||||
status INTEGER DEFAULT 1,
|
||||
profile_image TEXT,
|
||||
bio TEXT,
|
||||
UNIQUE(username),
|
||||
UNIQUE(email)
|
||||
)
|
||||
''')
|
||||
|
||||
# 创建索引以提高查询效率
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_username ON users(username)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_email ON users(email)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON users(status)')
|
||||
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
conn.rollback()
|
||||
print(f"创建表错误: {e}")
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _hash_password(self, password):
|
||||
"""
|
||||
对密码进行哈希加密
|
||||
:param password: 原始密码
|
||||
:return: 加密后的密码哈希
|
||||
"""
|
||||
# 使用bcrypt进行哈希,更安全的密码存储方式
|
||||
salt = bcrypt.gensalt()
|
||||
return bcrypt.hashpw(password.encode(), salt).decode()
|
||||
|
||||
# 定义自定义异常类
|
||||
class DatabaseError(Exception):
|
||||
"""数据库操作异常基类"""
|
||||
pass
|
||||
|
||||
class UserExistsError(DatabaseError):
|
||||
"""用户已存在异常"""
|
||||
pass
|
||||
|
||||
class UserNotFoundError(DatabaseError):
|
||||
"""用户不存在异常"""
|
||||
pass
|
||||
|
||||
class ValidationError(DatabaseError):
|
||||
"""数据验证异常"""
|
||||
pass
|
||||
|
||||
def _validate_input(self, username, email, password=None):
|
||||
"""
|
||||
验证输入数据的有效性
|
||||
:param username: 用户名
|
||||
:param email: 邮箱
|
||||
:param password: 密码(可选)
|
||||
:return: 验证通过返回True,否则抛出异常
|
||||
"""
|
||||
# 验证用户名
|
||||
if not username or not isinstance(username, str):
|
||||
raise self.ValidationError("用户名不能为空")
|
||||
if len(username) < 3:
|
||||
raise self.ValidationError("用户名必须至少包含3个字符")
|
||||
if len(username) > 50:
|
||||
raise self.ValidationError("用户名不能超过50个字符")
|
||||
if not re.match(r'^[a-zA-Z0-9_]+$', username):
|
||||
raise self.ValidationError("用户名只能包含字母、数字和下划线")
|
||||
|
||||
# 验证邮箱 - 使用更严格的正则表达式
|
||||
if not email or not isinstance(email, str):
|
||||
raise self.ValidationError("邮箱不能为空")
|
||||
email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
|
||||
if not re.match(email_pattern, email):
|
||||
raise self.ValidationError("邮箱格式不正确")
|
||||
if len(email) > 255:
|
||||
raise self.ValidationError("邮箱不能超过255个字符")
|
||||
|
||||
# 验证密码
|
||||
if password is not None:
|
||||
if not password:
|
||||
raise self.ValidationError("密码不能为空")
|
||||
if len(password) < 6:
|
||||
raise self.ValidationError("密码必须至少包含6个字符")
|
||||
if len(password) > 128:
|
||||
raise self.ValidationError("密码不能超过128个字符")
|
||||
|
||||
return True
|
||||
|
||||
def create_user(self, username, password, email, nickname=''):
|
||||
"""
|
||||
创建新用户
|
||||
:param username: 用户名
|
||||
:param password: 密码
|
||||
:param email: 邮箱
|
||||
:param nickname: 昵称(可选)
|
||||
:return: 创建成功返回用户ID,否则抛出异常
|
||||
"""
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
# 验证输入
|
||||
self._validate_input(username, email, password)
|
||||
|
||||
# 检查用户名是否已存在
|
||||
cursor.execute("SELECT user_id FROM users WHERE username = ?", (username,))
|
||||
if cursor.fetchone():
|
||||
raise self.UserExistsError("用户名已存在")
|
||||
|
||||
# 检查邮箱是否已存在
|
||||
cursor.execute("SELECT user_id FROM users WHERE email = ?", (email,))
|
||||
if cursor.fetchone():
|
||||
raise self.UserExistsError("邮箱已被注册")
|
||||
|
||||
# 哈希密码
|
||||
password_hash = self._hash_password(password)
|
||||
|
||||
# 插入用户记录
|
||||
cursor.execute('''
|
||||
INSERT INTO users (username, email, password_hash, created_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
''', (username, email, password_hash, datetime.datetime.now()))
|
||||
|
||||
conn.commit()
|
||||
user_id = cursor.lastrowid
|
||||
logger.info(f"创建用户成功,ID: {user_id}, 用户名: {username}")
|
||||
return user_id
|
||||
except self.DatabaseError:
|
||||
raise
|
||||
except sqlite3.Error as e:
|
||||
conn.rollback()
|
||||
error_msg = f"创建用户失败: {str(e)}"
|
||||
logger.error(f"创建用户错误: {e}, 用户名: {username}, 邮箱: {email}")
|
||||
raise self.DatabaseError(error_msg)
|
||||
|
||||
def get_user_by_id(self, user_id):
|
||||
"""
|
||||
根据用户ID获取用户信息
|
||||
:param user_id: 用户ID
|
||||
:return: 用户信息字典,如果不存在返回None
|
||||
"""
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM users WHERE user_id = ?", (user_id,))
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
return dict(row)
|
||||
return None
|
||||
except sqlite3.Error as e:
|
||||
print(f"查询用户错误: {e}")
|
||||
return None
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_user_by_username(self, username):
|
||||
"""
|
||||
根据用户名获取用户信息
|
||||
:param username: 用户名
|
||||
:return: 用户信息字典,如果不存在返回None
|
||||
"""
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
return dict(row)
|
||||
return None
|
||||
except sqlite3.Error as e:
|
||||
print(f"查询用户错误: {e}")
|
||||
return None
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_user_by_email(self, email):
|
||||
"""
|
||||
根据邮箱获取用户信息
|
||||
:param email: 邮箱
|
||||
:return: 用户信息字典,如果不存在返回None
|
||||
"""
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM users WHERE email = ?", (email,))
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
return dict(row)
|
||||
return None
|
||||
except sqlite3.Error as e:
|
||||
print(f"查询用户错误: {e}")
|
||||
return None
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def verify_password(self, username, password):
|
||||
"""
|
||||
验证用户密码
|
||||
:param username: 用户名
|
||||
:param password: 密码
|
||||
:return: 验证成功返回用户信息字典,失败返回None
|
||||
"""
|
||||
try:
|
||||
# 获取用户信息(get_user_by_username已经使用独立连接)
|
||||
user = self.get_user_by_username(username)
|
||||
if not user:
|
||||
logger.warning(f"密码验证失败: 用户不存在 - {username}")
|
||||
return None
|
||||
|
||||
# 检查账号状态
|
||||
if user['status'] != 1:
|
||||
logger.warning(f"密码验证失败: 账号状态异常 - 用户名: {username}, 状态: {user['status']}")
|
||||
return None
|
||||
|
||||
# 验证密码 - 使用bcrypt的checkpw方法
|
||||
is_valid = bcrypt.checkpw(password.encode('utf-8'), user['password_hash'].encode('utf-8'))
|
||||
|
||||
if is_valid:
|
||||
logger.info(f"密码验证成功: {username}")
|
||||
return user # 验证成功返回用户信息
|
||||
else:
|
||||
logger.warning(f"密码验证失败: 密码错误 - {username}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"密码验证错误: {e}, 用户名: {username}")
|
||||
return None
|
||||
|
||||
def update_login_time(self, user_id):
|
||||
"""
|
||||
更新用户最后登录时间
|
||||
:param user_id: 用户ID
|
||||
:return: 更新成功返回True,否则返回False
|
||||
"""
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"UPDATE users SET last_login_at = ? WHERE user_id = ?",
|
||||
(datetime.datetime.now(), user_id)
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
except sqlite3.Error as e:
|
||||
conn.rollback()
|
||||
print(f"更新登录时间错误: {e}")
|
||||
return False
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def update_user(self, user_id, **kwargs):
|
||||
"""
|
||||
更新用户信息
|
||||
:param user_id: 用户ID
|
||||
:param kwargs: 要更新的字段
|
||||
:return: 更新成功返回True,否则返回False
|
||||
"""
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
# 检查用户是否存在
|
||||
if not self.get_user_by_id(user_id):
|
||||
raise self.UserNotFoundError("用户不存在")
|
||||
|
||||
# 准备更新字段
|
||||
update_fields = []
|
||||
update_values = []
|
||||
|
||||
# 记录要更新的字段用于日志
|
||||
updated_fields_log = []
|
||||
|
||||
if 'username' in kwargs:
|
||||
update_fields.append("username = ?")
|
||||
update_values.append(kwargs['username'])
|
||||
updated_fields_log.append('username')
|
||||
# 验证用户名
|
||||
self._validate_input(kwargs['username'], "dummy@example.com")
|
||||
# 检查用户名是否被其他用户使用
|
||||
cursor.execute(
|
||||
"SELECT user_id FROM users WHERE username = ? AND user_id != ?",
|
||||
(kwargs['username'], user_id)
|
||||
)
|
||||
if cursor.fetchone():
|
||||
raise self.UserExistsError("用户名已存在")
|
||||
|
||||
if 'email' in kwargs:
|
||||
update_fields.append("email = ?")
|
||||
update_values.append(kwargs['email'])
|
||||
updated_fields_log.append('email')
|
||||
# 验证邮箱
|
||||
self._validate_input("dummy", kwargs['email'])
|
||||
# 检查邮箱是否被其他用户使用
|
||||
cursor.execute(
|
||||
"SELECT user_id FROM users WHERE email = ? AND user_id != ?",
|
||||
(kwargs['email'], user_id)
|
||||
)
|
||||
if cursor.fetchone():
|
||||
raise self.UserExistsError("邮箱已被注册")
|
||||
|
||||
if 'password' in kwargs:
|
||||
# 验证密码
|
||||
self._validate_input("dummy", "dummy@example.com", kwargs['password'])
|
||||
# 哈希密码
|
||||
password_hash = self._hash_password(kwargs['password'])
|
||||
update_fields.append("password_hash = ?")
|
||||
update_values.append(password_hash)
|
||||
updated_fields_log.append('password')
|
||||
|
||||
if 'status' in kwargs:
|
||||
# 验证状态值
|
||||
if not isinstance(kwargs['status'], int) or kwargs['status'] not in (0, 1, 2):
|
||||
raise self.ValidationError("无效的用户状态值")
|
||||
update_fields.append("status = ?")
|
||||
update_values.append(kwargs['status'])
|
||||
updated_fields_log.append('status')
|
||||
|
||||
if 'profile_image' in kwargs:
|
||||
# 验证头像路径
|
||||
if kwargs['profile_image'] and len(kwargs['profile_image']) > 255:
|
||||
raise self.ValidationError("头像路径过长")
|
||||
update_fields.append("profile_image = ?")
|
||||
update_values.append(kwargs['profile_image'])
|
||||
updated_fields_log.append('profile_image')
|
||||
|
||||
if 'bio' in kwargs:
|
||||
# 验证个人简介
|
||||
if kwargs['bio'] and len(kwargs['bio']) > 500:
|
||||
raise self.ValidationError("个人简介不能超过500个字符")
|
||||
update_fields.append("bio = ?")
|
||||
update_values.append(kwargs['bio'])
|
||||
updated_fields_log.append('bio')
|
||||
|
||||
# 执行更新
|
||||
if update_fields:
|
||||
update_values.append(user_id)
|
||||
update_sql = f"UPDATE users SET {', '.join(update_fields)} WHERE user_id = ?"
|
||||
cursor.execute(update_sql, update_values)
|
||||
conn.commit()
|
||||
logger.info(f"更新用户信息成功,ID: {user_id}, 更新字段: {', '.join(updated_fields_log)}")
|
||||
return True
|
||||
else:
|
||||
# 没有要更新的字段
|
||||
logger.warning(f"没有更新字段提供,用户ID: {user_id}")
|
||||
return True
|
||||
except self.DatabaseError:
|
||||
raise
|
||||
except sqlite3.Error as e:
|
||||
conn.rollback()
|
||||
logger.error(f"更新用户信息错误: {e}, 用户ID: {user_id}")
|
||||
raise self.DatabaseError(f"更新用户信息失败: {str(e)}")
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
|
||||
def delete_user(self, user_id):
|
||||
"""
|
||||
删除用户
|
||||
:param user_id: 用户ID
|
||||
:return: 删除成功返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
# 检查用户是否存在
|
||||
user = self.get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise self.UserNotFoundError("用户不存在")
|
||||
|
||||
# 开始事务
|
||||
self.cursor.execute("DELETE FROM users WHERE user_id = ?", (user_id,))
|
||||
self.conn.commit()
|
||||
|
||||
if self.cursor.rowcount > 0:
|
||||
logger.info(f"删除用户成功,ID: {user_id}, 用户名: {user['username']}")
|
||||
return True
|
||||
return False
|
||||
except self.DatabaseError:
|
||||
raise
|
||||
except sqlite3.Error as e:
|
||||
self.conn.rollback()
|
||||
error_msg = f"删除用户失败: {str(e)}"
|
||||
logger.error(f"删除用户错误: {e}, 用户ID: {user_id}")
|
||||
raise self.DatabaseError(error_msg)
|
||||
|
||||
def list_users(self, limit=100, offset=0):
|
||||
"""
|
||||
列出用户(支持分页)
|
||||
:param limit: 每页数量
|
||||
:param offset: 偏移量
|
||||
:return: 用户列表
|
||||
"""
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT * FROM users ORDER BY created_at DESC LIMIT ? OFFSET ?",
|
||||
(limit, offset)
|
||||
)
|
||||
users = []
|
||||
for row in cursor.fetchall():
|
||||
users.append(dict(row))
|
||||
return users
|
||||
except sqlite3.Error as e:
|
||||
print(f"列出用户错误: {e}")
|
||||
return []
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
关闭数据库连接(保持兼容性)
|
||||
"""
|
||||
# 现在每个操作都使用独立连接,不需要关闭持久连接
|
||||
print("数据库连接管理已更新,每个操作使用独立连接")
|
||||
|
||||
# 示例用法
|
||||
if __name__ == "__main__":
|
||||
print("=== 用户数据库示例程序 ===")
|
||||
# 创建数据库实例
|
||||
db = UserDatabase()
|
||||
|
||||
try:
|
||||
print("\n1. 创建测试用户")
|
||||
# 尝试创建已存在的用户
|
||||
try:
|
||||
user_id = db.create_user(
|
||||
username="demo_user",
|
||||
email="demo@example.com",
|
||||
password="SecurePass123"
|
||||
)
|
||||
print(f"✅ 创建用户成功,ID: {user_id}")
|
||||
except db.UserExistsError as e:
|
||||
print(f"ℹ️ {e},继续使用现有用户")
|
||||
user = db.get_user_by_username("demo_user")
|
||||
if user:
|
||||
user_id = user['user_id']
|
||||
|
||||
print("\n2. 查询用户信息")
|
||||
user = db.get_user_by_id(user_id)
|
||||
if user:
|
||||
# 不打印密码哈希
|
||||
safe_user = {k: v for k, v in user.items() if k != 'password_hash'}
|
||||
print(f"✅ 查询用户成功: {safe_user}")
|
||||
|
||||
print("\n3. 验证密码")
|
||||
is_valid = db.verify_password("demo_user", "SecurePass123")
|
||||
print(f"✅ 密码验证: {'通过' if is_valid else '失败'}")
|
||||
|
||||
# 测试密码验证失败的情况
|
||||
is_valid_wrong = db.verify_password("demo_user", "WrongPassword")
|
||||
print(f"✅ 错误密码验证: {'通过' if is_valid_wrong else '失败'}")
|
||||
|
||||
print("\n4. 更新用户信息")
|
||||
success = db.update_user(user_id,
|
||||
bio="这是一个测试用户账号",
|
||||
status=1)
|
||||
if success:
|
||||
updated_user = db.get_user_by_id(user_id)
|
||||
safe_updated = {k: v for k, v in updated_user.items() if k != 'password_hash'}
|
||||
print(f"✅ 更新用户成功: {safe_updated}")
|
||||
|
||||
print("\n5. 列出用户")
|
||||
users = db.list_users(limit=10)
|
||||
safe_users = [{k: v for k, v in u.items() if k != 'password_hash'} for u in users]
|
||||
print(f"✅ 共找到 {len(safe_users)} 个用户")
|
||||
for u in safe_users:
|
||||
print(f" - {u['username']} ({u['email']})")
|
||||
|
||||
print("\n6. 测试数据验证")
|
||||
try:
|
||||
db.create_user(username="ab", email="invalid-email", password="123")
|
||||
except db.ValidationError as e:
|
||||
print(f"✅ 验证错误捕获成功: {e}")
|
||||
|
||||
print("\n7. 更新登录时间")
|
||||
if db.update_login_time(user_id):
|
||||
print("✅ 登录时间更新成功")
|
||||
|
||||
print("\n=== 测试完成 ===")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 发生错误: {e}")
|
||||
finally:
|
||||
# 关闭数据库连接
|
||||
print("\n关闭数据库连接...")
|
||||
db.close()
|
||||
print("数据库连接已关闭")
|
||||
Reference in New Issue
Block a user