middlewares.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. # -*- coding: utf-8 -*-
  2. import json
  3. import time
  4. from starlette.middleware.cors import CORSMiddleware
  5. from starlette.types import ASGIApp
  6. from starlette.requests import Request
  7. from starlette.middleware.gzip import GZipMiddleware
  8. from starlette.responses import Response
  9. from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
  10. from app.common.response import ErrorResponse
  11. from app.config.setting import settings
  12. from app.core.logger import log
  13. from app.core.exceptions import CustomException
  14. from app.core.security import decode_access_token
  15. from app.api.v1.module_system.params.service import ParamsService
  16. class CustomCORSMiddleware(CORSMiddleware):
  17. """CORS跨域中间件"""
  18. def __init__(self, app: ASGIApp) -> None:
  19. super().__init__(
  20. app,
  21. allow_origins=settings.ALLOW_ORIGINS,
  22. allow_methods=settings.ALLOW_METHODS,
  23. allow_headers=settings.ALLOW_HEADERS,
  24. allow_credentials=settings.ALLOW_CREDENTIALS,
  25. expose_headers=settings.CORS_EXPOSE_HEADERS,
  26. )
  27. class RequestLogMiddleware(BaseHTTPMiddleware):
  28. """
  29. 记录请求日志中间件: 提供一个基础的中间件类,允许你自定义请求和响应处理逻辑。
  30. """
  31. def __init__(self, app: ASGIApp) -> None:
  32. super().__init__(app)
  33. @staticmethod
  34. def _extract_session_id_from_request(request: Request) -> str | None:
  35. """
  36. 从请求中提取session_id(支持从Token或已设置的scope中获取)
  37. 参数:
  38. - request (Request): 请求对象
  39. 返回:
  40. - str | None: 会话ID,如果无法提取则返回None
  41. """
  42. # 1. 先检查 scope 中是否已经有 session_id(登录接口会设置)
  43. session_id = request.scope.get('session_id')
  44. if session_id:
  45. return session_id
  46. # 2. 尝试从 Authorization Header 中提取
  47. try:
  48. authorization = request.headers.get("Authorization")
  49. if not authorization:
  50. return None
  51. # 处理Bearer token
  52. token = authorization.replace('Bearer ', '').strip()
  53. # 解码token
  54. payload = decode_access_token(token)
  55. if not payload or not hasattr(payload, 'sub'):
  56. return None
  57. # 从payload中提取session_id
  58. user_info = json.loads(payload.sub)
  59. session_id = user_info.get("session_id")
  60. # 同时设置到request.scope中,避免后续重复解析
  61. if session_id:
  62. request.scope["session_id"] = session_id
  63. return session_id
  64. except Exception:
  65. # 解析失败静默处理,返回None(可能是未认证请求)
  66. return None
  67. async def dispatch(
  68. self, request: Request, call_next: RequestResponseEndpoint
  69. ) -> Response:
  70. start_time = time.time()
  71. # 尝试提取session_id
  72. session_id = self._extract_session_id_from_request(request)
  73. # 组装请求日志字段
  74. log_fields = [
  75. f"请求来源: {request.client.host if request.client else '未知'}",
  76. f"请求方法: {request.method}",
  77. f"请求路径: {request.url.path}",
  78. ]
  79. log.info(log_fields)
  80. try:
  81. # 初始化响应变量
  82. response = None
  83. # 获取请求路径
  84. path = request.scope.get("path")
  85. # 尝试获取客户端真实IP
  86. request_ip = None
  87. x_forwarded_for = request.headers.get('X-Forwarded-For')
  88. if x_forwarded_for:
  89. # 取第一个 IP 地址,通常为客户端真实 IP
  90. request_ip = x_forwarded_for.split(',')[0].strip()
  91. else:
  92. # 若没有 X-Forwarded-For 头,则使用 request.client.host
  93. request_ip = request.client.host if request.client else None
  94. # 检查是否启用演示模式
  95. demo_enable = False
  96. ip_white_list = []
  97. white_api_list_path = []
  98. ip_black_list = []
  99. try:
  100. # 从应用实例获取Redis连接
  101. redis = request.app.state.redis
  102. if not redis:
  103. raise Exception("无法获取Redis连接")
  104. # 使用ParamsService获取系统配置
  105. system_config = await ParamsService.get_system_config_for_middleware(redis)
  106. # 提取配置值
  107. demo_enable = system_config["demo_enable"]
  108. ip_white_list = system_config["ip_white_list"]
  109. white_api_list_path = system_config["white_api_list_path"]
  110. ip_black_list = system_config["ip_black_list"]
  111. except Exception as e:
  112. log.error(f"获取系统配置失败: {e}")
  113. # 检查是否需要拦截请求
  114. should_block = False
  115. block_reason = ""
  116. # 1. 首先检查IP是否在黑名单中
  117. if request_ip and request_ip in ip_black_list:
  118. should_block = True
  119. block_reason = f"IP地址 {request_ip} 在黑名单中"
  120. # 2. 如果不在黑名单中,检查是否在演示模式下需要拦截
  121. elif demo_enable in ["true", "True"] and request.method != "GET":
  122. # 在演示模式下,非GET请求需要检查白名单
  123. is_ip_whitelisted = request_ip in ip_white_list
  124. is_path_whitelisted = path in white_api_list_path
  125. if not is_ip_whitelisted and not is_path_whitelisted:
  126. should_block = True
  127. block_reason = f"演示模式下拦截非GET请求,IP: {request_ip}, 路径: {path}"
  128. if should_block:
  129. # 增强安全审计:记录详细的拦截日志
  130. log.warning([
  131. f"会话ID: {session_id or '未认证'}",
  132. f"请求被拦截: {block_reason}",
  133. f"请求来源: {request_ip}",
  134. f"请求方法: {request.method}",
  135. f"请求路径: {path}",
  136. f"用户代理: {request.headers.get('user-agent', '未知')}",
  137. f"演示模式: {demo_enable}"
  138. ])
  139. # 拦截请求
  140. return ErrorResponse(msg="演示环境,禁止操作")
  141. else:
  142. # 正常处理请求
  143. response = await call_next(request)
  144. # 计算处理时间并添加到响应头
  145. process_time = round(time.time() - start_time, 5)
  146. response.headers["X-Process-Time"] = str(process_time)
  147. # 构建响应日志信息
  148. content_length = response.headers.get('content-length', '0')
  149. response_info = (
  150. f"响应状态: {response.status_code}, "
  151. f"响应内容长度: {content_length}, "
  152. f"处理时间: {round(process_time * 1000, 3)}ms"
  153. )
  154. log.info(response_info)
  155. return response
  156. except CustomException as e:
  157. log.error(f"中间件处理异常: {str(e)}")
  158. return ErrorResponse(msg=f"系统异常,请联系管理员", data=str(e))
  159. class CustomGZipMiddleware(GZipMiddleware):
  160. """GZip压缩中间件"""
  161. def __init__(self, app: ASGIApp) -> None:
  162. super().__init__(
  163. app,
  164. minimum_size=settings.GZIP_MIN_SIZE,
  165. compresslevel=settings.GZIP_COMPRESS_LEVEL
  166. )