79 lines
2.1 KiB
Python
79 lines
2.1 KiB
Python
"""
|
||
JWT 工具函数
|
||
"""
|
||
import jwt
|
||
from datetime import datetime, timedelta
|
||
from typing import Optional
|
||
from fastapi import HTTPException, Depends, Header
|
||
from app.config import get_settings
|
||
|
||
|
||
def create_token(user_id: int, openid: str) -> str:
    """Create a signed JWT for a user.

    The payload carries ``user_id`` and ``openid`` together with the
    ``iat`` (issued-at) and ``exp`` (expiry) claims. The token is signed
    with HS256 using ``settings.jwt_secret_key`` and expires
    ``settings.jwt_expire_hours`` hours after it is issued.

    Args:
        user_id: Value stored under the ``user_id`` claim.
        openid: Value stored under the ``openid`` claim.

    Returns:
        The encoded token as produced by ``jwt.encode``.
    """
    # Function-scope import: the module-level import only brings in
    # ``datetime`` and ``timedelta``.
    from datetime import timezone

    settings = get_settings()

    # Take one timezone-aware UTC timestamp so ``iat`` and ``exp`` are derived
    # from the same instant. ``datetime.utcnow()`` returns a naive value and
    # is deprecated since Python 3.12.
    issued_at = datetime.now(timezone.utc)
    expires_at = issued_at + timedelta(hours=settings.jwt_expire_hours)

    payload = {
        "user_id": user_id,
        "openid": openid,
        "exp": expires_at,
        "iat": issued_at,
    }
    return jwt.encode(payload, settings.jwt_secret_key, algorithm="HS256")
|
||
|
||
|
||
def verify_token(token: str) -> dict:
    """Verify a JWT and return its payload.

    The token is decoded with ``settings.jwt_secret_key``; only the HS256
    algorithm is accepted.

    Args:
        token: The encoded JWT string.

    Returns:
        The decoded payload.

    Raises:
        HTTPException: 401 when the token has expired or is otherwise
            invalid. The underlying PyJWT error is attached as the cause
            so it stays visible in tracebacks and logs.
    """
    settings = get_settings()

    try:
        return jwt.decode(token, settings.jwt_secret_key, algorithms=["HS256"])
    except jwt.ExpiredSignatureError as exc:
        # Handled before InvalidTokenError (its base class in PyJWT) so that
        # an expired token gets its own message.
        raise HTTPException(status_code=401, detail="Token已过期,请重新登录") from exc
    except jwt.InvalidTokenError as exc:
        raise HTTPException(status_code=401, detail="无效的Token") from exc
|
||
|
||
|
||
def get_current_user_id(authorization: Optional[str] = Header(None, alias="Authorization")) -> int:
    """Resolve the authenticated user's id from the ``Authorization`` header.

    Intended for use as a FastAPI dependency. The header may carry either
    ``Bearer <token>`` (scheme matched case-insensitively) or the bare token.

    Args:
        authorization: Raw ``Authorization`` header value, or ``None`` when
            the header is absent.

    Returns:
        The ``user_id`` claim from the verified token payload.

    Raises:
        HTTPException: 401 when the header is missing or empty, when
            ``verify_token`` rejects the token, or when the payload has no
            ``user_id`` claim.
    """
    if not authorization:
        raise HTTPException(status_code=401, detail="未提供身份令牌")

    # Accept the "Bearer xxx" form; anything else is treated as the bare token.
    scheme, separator, credentials = authorization.partition(" ")
    if separator and scheme.lower() == "bearer":
        token = credentials.strip()
    else:
        token = authorization

    payload = verify_token(token)

    user_id = payload.get("user_id")
    if user_id is None:
        # A correctly signed token without a user_id cannot identify anyone;
        # reject it rather than return None from a function typed as int.
        raise HTTPException(status_code=401, detail="无效的Token")
    return user_id
|
||
|
||
|
||
def get_optional_user_id(authorization: Optional[str] = Header(None, alias="Authorization")) -> Optional[int]:
    """Resolve the user id when a usable token is supplied, else ``None``.

    For endpoints that do not require login: a missing header, or a token
    that fails verification, yields ``None`` instead of a 401 response.

    Args:
        authorization: Raw ``Authorization`` header value, or ``None`` when
            the header is absent.

    Returns:
        Whatever ``get_current_user_id`` returns for this header, or ``None``
        when the header is missing/empty or an ``HTTPException`` is raised
        while resolving it.
    """
    if not authorization:
        return None

    try:
        # Delegate so header parsing and token verification cannot drift
        # from the mandatory-login dependency.
        return get_current_user_id(authorization)
    except HTTPException:
        # Best-effort by design: an unusable token means "anonymous".
        return None
|