Files
yixiaogao/database/user_database.py

557 lines
20 KiB
Python
Raw Normal View History

2025-12-02 14:58:52 +08:00
import sqlite3
import datetime
import os
import bcrypt
import re
import logging
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
filename=os.path.join(os.path.dirname(__file__), 'user_db.log')
)
logger = logging.getLogger('UserDatabase')
class UserDatabase:
def __init__(self, db_path='users.db'):
"""
初始化数据库配置
:param db_path: 数据库文件路径
"""
self.db_path = os.path.join(os.path.dirname(__file__), db_path)
# 初始化时只创建表,不保持连接
self._setup_adapters() # 设置全局适配器
self._create_tables_at_init() # 初始化时创建表
def _setup_adapters(self):
"""设置SQLite适配器和转换器全局设置"""
# 添加自定义datetime适配器以解决Python 3.12弃用警告
def adapt_datetime(dt):
return dt.isoformat()
# 注册适配器
sqlite3.register_adapter(datetime.datetime, adapt_datetime)
# 添加转换器从字符串转换回datetime
def convert_datetime(val):
try:
return datetime.datetime.fromisoformat(val.decode())
except (ValueError, AttributeError):
return val
# 注册转换器
sqlite3.register_converter('TIMESTAMP', convert_datetime)
def _get_connection(self):
"""
获取新的数据库连接线程安全
:return: 数据库连接对象
"""
try:
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row # 允许通过列名访问结果
return conn
except sqlite3.Error as e:
print(f"数据库连接错误: {e}")
raise
def _create_tables_at_init(self):
"""
初始化时创建用户表
"""
conn = self._get_connection()
try:
cursor = conn.cursor()
# 创建用户表,添加适当的主键和索引
cursor.execute('''
CREATE TABLE IF NOT EXISTS users (
user_id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
email TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_login_at TIMESTAMP,
status INTEGER DEFAULT 1,
profile_image TEXT,
bio TEXT,
UNIQUE(username),
UNIQUE(email)
)
''')
# 创建索引以提高查询效率
cursor.execute('CREATE INDEX IF NOT EXISTS idx_username ON users(username)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_email ON users(email)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON users(status)')
conn.commit()
except sqlite3.Error as e:
conn.rollback()
print(f"创建表错误: {e}")
raise
finally:
conn.close()
def _hash_password(self, password):
"""
对密码进行哈希加密
:param password: 原始密码
:return: 加密后的密码哈希
"""
# 使用bcrypt进行哈希更安全的密码存储方式
salt = bcrypt.gensalt()
return bcrypt.hashpw(password.encode(), salt).decode()
# 定义自定义异常类
class DatabaseError(Exception):
"""数据库操作异常基类"""
pass
class UserExistsError(DatabaseError):
"""用户已存在异常"""
pass
class UserNotFoundError(DatabaseError):
"""用户不存在异常"""
pass
class ValidationError(DatabaseError):
"""数据验证异常"""
pass
def _validate_input(self, username, email, password=None):
"""
验证输入数据的有效性
:param username: 用户名
:param email: 邮箱
:param password: 密码可选
:return: 验证通过返回True否则抛出异常
"""
# 验证用户名
if not username or not isinstance(username, str):
raise self.ValidationError("用户名不能为空")
if len(username) < 3:
raise self.ValidationError("用户名必须至少包含3个字符")
if len(username) > 50:
raise self.ValidationError("用户名不能超过50个字符")
if not re.match(r'^[a-zA-Z0-9_]+$', username):
raise self.ValidationError("用户名只能包含字母、数字和下划线")
# 验证邮箱 - 使用更严格的正则表达式
if not email or not isinstance(email, str):
raise self.ValidationError("邮箱不能为空")
email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
if not re.match(email_pattern, email):
raise self.ValidationError("邮箱格式不正确")
if len(email) > 255:
raise self.ValidationError("邮箱不能超过255个字符")
# 验证密码
if password is not None:
if not password:
raise self.ValidationError("密码不能为空")
if len(password) < 6:
raise self.ValidationError("密码必须至少包含6个字符")
if len(password) > 128:
raise self.ValidationError("密码不能超过128个字符")
return True
def create_user(self, username, password, email, nickname=''):
"""
创建新用户
:param username: 用户名
:param password: 密码
:param email: 邮箱
:param nickname: 昵称可选
:return: 创建成功返回用户ID否则抛出异常
"""
conn = self._get_connection()
try:
cursor = conn.cursor()
# 验证输入
self._validate_input(username, email, password)
# 检查用户名是否已存在
cursor.execute("SELECT user_id FROM users WHERE username = ?", (username,))
if cursor.fetchone():
raise self.UserExistsError("用户名已存在")
# 检查邮箱是否已存在
cursor.execute("SELECT user_id FROM users WHERE email = ?", (email,))
if cursor.fetchone():
raise self.UserExistsError("邮箱已被注册")
# 哈希密码
password_hash = self._hash_password(password)
# 插入用户记录
cursor.execute('''
INSERT INTO users (username, email, password_hash, created_at)
VALUES (?, ?, ?, ?)
''', (username, email, password_hash, datetime.datetime.now()))
conn.commit()
user_id = cursor.lastrowid
logger.info(f"创建用户成功ID: {user_id}, 用户名: {username}")
return user_id
except self.DatabaseError:
raise
except sqlite3.Error as e:
conn.rollback()
error_msg = f"创建用户失败: {str(e)}"
logger.error(f"创建用户错误: {e}, 用户名: {username}, 邮箱: {email}")
raise self.DatabaseError(error_msg)
def get_user_by_id(self, user_id):
"""
根据用户ID获取用户信息
:param user_id: 用户ID
:return: 用户信息字典如果不存在返回None
"""
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE user_id = ?", (user_id,))
row = cursor.fetchone()
if row:
return dict(row)
return None
except sqlite3.Error as e:
print(f"查询用户错误: {e}")
return None
finally:
conn.close()
def get_user_by_username(self, username):
"""
根据用户名获取用户信息
:param username: 用户名
:return: 用户信息字典如果不存在返回None
"""
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
row = cursor.fetchone()
if row:
return dict(row)
return None
except sqlite3.Error as e:
print(f"查询用户错误: {e}")
return None
finally:
conn.close()
def get_user_by_email(self, email):
"""
根据邮箱获取用户信息
:param email: 邮箱
:return: 用户信息字典如果不存在返回None
"""
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE email = ?", (email,))
row = cursor.fetchone()
if row:
return dict(row)
return None
except sqlite3.Error as e:
print(f"查询用户错误: {e}")
return None
finally:
conn.close()
def verify_password(self, username, password):
"""
验证用户密码
:param username: 用户名
:param password: 密码
:return: 验证成功返回用户信息字典失败返回None
"""
try:
# 获取用户信息get_user_by_username已经使用独立连接
user = self.get_user_by_username(username)
if not user:
logger.warning(f"密码验证失败: 用户不存在 - {username}")
return None
# 检查账号状态
if user['status'] != 1:
logger.warning(f"密码验证失败: 账号状态异常 - 用户名: {username}, 状态: {user['status']}")
return None
# 验证密码 - 使用bcrypt的checkpw方法
is_valid = bcrypt.checkpw(password.encode('utf-8'), user['password_hash'].encode('utf-8'))
if is_valid:
logger.info(f"密码验证成功: {username}")
return user # 验证成功返回用户信息
else:
logger.warning(f"密码验证失败: 密码错误 - {username}")
return None
except Exception as e:
logger.error(f"密码验证错误: {e}, 用户名: {username}")
return None
def update_login_time(self, user_id):
"""
更新用户最后登录时间
:param user_id: 用户ID
:return: 更新成功返回True否则返回False
"""
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"UPDATE users SET last_login_at = ? WHERE user_id = ?",
(datetime.datetime.now(), user_id)
)
conn.commit()
return cursor.rowcount > 0
except sqlite3.Error as e:
conn.rollback()
print(f"更新登录时间错误: {e}")
return False
finally:
conn.close()
def update_user(self, user_id, **kwargs):
"""
更新用户信息
:param user_id: 用户ID
:param kwargs: 要更新的字段
:return: 更新成功返回True否则返回False
"""
conn = self._get_connection()
try:
cursor = conn.cursor()
# 检查用户是否存在
if not self.get_user_by_id(user_id):
raise self.UserNotFoundError("用户不存在")
# 准备更新字段
update_fields = []
update_values = []
# 记录要更新的字段用于日志
updated_fields_log = []
if 'username' in kwargs:
update_fields.append("username = ?")
update_values.append(kwargs['username'])
updated_fields_log.append('username')
# 验证用户名
self._validate_input(kwargs['username'], "dummy@example.com")
# 检查用户名是否被其他用户使用
cursor.execute(
"SELECT user_id FROM users WHERE username = ? AND user_id != ?",
(kwargs['username'], user_id)
)
if cursor.fetchone():
raise self.UserExistsError("用户名已存在")
if 'email' in kwargs:
update_fields.append("email = ?")
update_values.append(kwargs['email'])
updated_fields_log.append('email')
# 验证邮箱
self._validate_input("dummy", kwargs['email'])
# 检查邮箱是否被其他用户使用
cursor.execute(
"SELECT user_id FROM users WHERE email = ? AND user_id != ?",
(kwargs['email'], user_id)
)
if cursor.fetchone():
raise self.UserExistsError("邮箱已被注册")
if 'password' in kwargs:
# 验证密码
self._validate_input("dummy", "dummy@example.com", kwargs['password'])
# 哈希密码
password_hash = self._hash_password(kwargs['password'])
update_fields.append("password_hash = ?")
update_values.append(password_hash)
updated_fields_log.append('password')
if 'status' in kwargs:
# 验证状态值
if not isinstance(kwargs['status'], int) or kwargs['status'] not in (0, 1, 2):
raise self.ValidationError("无效的用户状态值")
update_fields.append("status = ?")
update_values.append(kwargs['status'])
updated_fields_log.append('status')
if 'profile_image' in kwargs:
# 验证头像路径
if kwargs['profile_image'] and len(kwargs['profile_image']) > 255:
raise self.ValidationError("头像路径过长")
update_fields.append("profile_image = ?")
update_values.append(kwargs['profile_image'])
updated_fields_log.append('profile_image')
if 'bio' in kwargs:
# 验证个人简介
if kwargs['bio'] and len(kwargs['bio']) > 500:
raise self.ValidationError("个人简介不能超过500个字符")
update_fields.append("bio = ?")
update_values.append(kwargs['bio'])
updated_fields_log.append('bio')
# 执行更新
if update_fields:
update_values.append(user_id)
update_sql = f"UPDATE users SET {', '.join(update_fields)} WHERE user_id = ?"
cursor.execute(update_sql, update_values)
conn.commit()
logger.info(f"更新用户信息成功ID: {user_id}, 更新字段: {', '.join(updated_fields_log)}")
return True
else:
# 没有要更新的字段
logger.warning(f"没有更新字段提供用户ID: {user_id}")
return True
except self.DatabaseError:
raise
except sqlite3.Error as e:
conn.rollback()
logger.error(f"更新用户信息错误: {e}, 用户ID: {user_id}")
raise self.DatabaseError(f"更新用户信息失败: {str(e)}")
finally:
conn.close()
def delete_user(self, user_id):
"""
删除用户
:param user_id: 用户ID
:return: 删除成功返回True否则返回False
"""
try:
# 检查用户是否存在
user = self.get_user_by_id(user_id)
if not user:
raise self.UserNotFoundError("用户不存在")
# 开始事务
self.cursor.execute("DELETE FROM users WHERE user_id = ?", (user_id,))
self.conn.commit()
if self.cursor.rowcount > 0:
logger.info(f"删除用户成功ID: {user_id}, 用户名: {user['username']}")
return True
return False
except self.DatabaseError:
raise
except sqlite3.Error as e:
self.conn.rollback()
error_msg = f"删除用户失败: {str(e)}"
logger.error(f"删除用户错误: {e}, 用户ID: {user_id}")
raise self.DatabaseError(error_msg)
def list_users(self, limit=100, offset=0):
"""
列出用户支持分页
:param limit: 每页数量
:param offset: 偏移量
:return: 用户列表
"""
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"SELECT * FROM users ORDER BY created_at DESC LIMIT ? OFFSET ?",
(limit, offset)
)
users = []
for row in cursor.fetchall():
users.append(dict(row))
return users
except sqlite3.Error as e:
print(f"列出用户错误: {e}")
return []
finally:
conn.close()
def close(self):
"""
关闭数据库连接保持兼容性
"""
# 现在每个操作都使用独立连接,不需要关闭持久连接
print("数据库连接管理已更新,每个操作使用独立连接")
# 示例用法
if __name__ == "__main__":
print("=== 用户数据库示例程序 ===")
# 创建数据库实例
db = UserDatabase()
try:
print("\n1. 创建测试用户")
# 尝试创建已存在的用户
try:
user_id = db.create_user(
username="demo_user",
email="demo@example.com",
password="SecurePass123"
)
print(f"✅ 创建用户成功ID: {user_id}")
except db.UserExistsError as e:
print(f" {e},继续使用现有用户")
user = db.get_user_by_username("demo_user")
if user:
user_id = user['user_id']
print("\n2. 查询用户信息")
user = db.get_user_by_id(user_id)
if user:
# 不打印密码哈希
safe_user = {k: v for k, v in user.items() if k != 'password_hash'}
print(f"✅ 查询用户成功: {safe_user}")
print("\n3. 验证密码")
is_valid = db.verify_password("demo_user", "SecurePass123")
print(f"✅ 密码验证: {'通过' if is_valid else '失败'}")
# 测试密码验证失败的情况
is_valid_wrong = db.verify_password("demo_user", "WrongPassword")
print(f"✅ 错误密码验证: {'通过' if is_valid_wrong else '失败'}")
print("\n4. 更新用户信息")
success = db.update_user(user_id,
bio="这是一个测试用户账号",
status=1)
if success:
updated_user = db.get_user_by_id(user_id)
safe_updated = {k: v for k, v in updated_user.items() if k != 'password_hash'}
print(f"✅ 更新用户成功: {safe_updated}")
print("\n5. 列出用户")
users = db.list_users(limit=10)
safe_users = [{k: v for k, v in u.items() if k != 'password_hash'} for u in users]
print(f"✅ 共找到 {len(safe_users)} 个用户")
for u in safe_users:
print(f" - {u['username']} ({u['email']})")
print("\n6. 测试数据验证")
try:
db.create_user(username="ab", email="invalid-email", password="123")
except db.ValidationError as e:
print(f"✅ 验证错误捕获成功: {e}")
print("\n7. 更新登录时间")
if db.update_login_time(user_id):
print("✅ 登录时间更新成功")
print("\n=== 测试完成 ===")
except Exception as e:
print(f"❌ 发生错误: {e}")
finally:
# 关闭数据库连接
print("\n关闭数据库连接...")
db.close()
print("数据库连接已关闭")