| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204 |
- # -*- coding: utf-8 -*-
- import json
- import time
- from starlette.middleware.cors import CORSMiddleware
- from starlette.types import ASGIApp
- from starlette.requests import Request
- from starlette.middleware.gzip import GZipMiddleware
- from starlette.responses import Response
- from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
- from app.common.response import ErrorResponse
- from app.config.setting import settings
- from app.core.logger import log
- from app.core.exceptions import CustomException
- from app.core.security import decode_access_token
- from app.api.v1.module_system.params.service import ParamsService
- class CustomCORSMiddleware(CORSMiddleware):
- """CORS跨域中间件"""
- def __init__(self, app: ASGIApp) -> None:
- super().__init__(
- app,
- allow_origins=settings.ALLOW_ORIGINS,
- allow_methods=settings.ALLOW_METHODS,
- allow_headers=settings.ALLOW_HEADERS,
- allow_credentials=settings.ALLOW_CREDENTIALS,
- expose_headers=settings.CORS_EXPOSE_HEADERS,
- )
- class RequestLogMiddleware(BaseHTTPMiddleware):
- """
- 记录请求日志中间件: 提供一个基础的中间件类,允许你自定义请求和响应处理逻辑。
- """
- def __init__(self, app: ASGIApp) -> None:
- super().__init__(app)
- @staticmethod
- def _extract_session_id_from_request(request: Request) -> str | None:
- """
- 从请求中提取session_id(支持从Token或已设置的scope中获取)
-
- 参数:
- - request (Request): 请求对象
-
- 返回:
- - str | None: 会话ID,如果无法提取则返回None
- """
- # 1. 先检查 scope 中是否已经有 session_id(登录接口会设置)
- session_id = request.scope.get('session_id')
- if session_id:
- return session_id
-
- # 2. 尝试从 Authorization Header 中提取
- try:
- authorization = request.headers.get("Authorization")
- if not authorization:
- return None
-
- # 处理Bearer token
- token = authorization.replace('Bearer ', '').strip()
-
- # 解码token
- payload = decode_access_token(token)
- if not payload or not hasattr(payload, 'sub'):
- return None
-
- # 从payload中提取session_id
- user_info = json.loads(payload.sub)
- session_id = user_info.get("session_id")
-
- # 同时设置到request.scope中,避免后续重复解析
- if session_id:
- request.scope["session_id"] = session_id
-
- return session_id
- except Exception:
- # 解析失败静默处理,返回None(可能是未认证请求)
- return None
- async def dispatch(
- self, request: Request, call_next: RequestResponseEndpoint
- ) -> Response:
- start_time = time.time()
-
- # 尝试提取session_id
- session_id = self._extract_session_id_from_request(request)
-
- # 组装请求日志字段
- log_fields = [
-
- f"请求来源: {request.client.host if request.client else '未知'}",
- f"请求方法: {request.method}",
- f"请求路径: {request.url.path}",
- ]
- log.info(log_fields)
- try:
- # 初始化响应变量
- response = None
-
- # 获取请求路径
- path = request.scope.get("path")
-
- # 尝试获取客户端真实IP
- request_ip = None
- x_forwarded_for = request.headers.get('X-Forwarded-For')
- if x_forwarded_for:
- # 取第一个 IP 地址,通常为客户端真实 IP
- request_ip = x_forwarded_for.split(',')[0].strip()
- else:
- # 若没有 X-Forwarded-For 头,则使用 request.client.host
- request_ip = request.client.host if request.client else None
-
- # 检查是否启用演示模式
- demo_enable = False
- ip_white_list = []
- white_api_list_path = []
- ip_black_list = []
-
- try:
- # 从应用实例获取Redis连接
- redis = request.app.state.redis
- if not redis:
- raise Exception("无法获取Redis连接")
-
- # 使用ParamsService获取系统配置
- system_config = await ParamsService.get_system_config_for_middleware(redis)
- # 提取配置值
- demo_enable = system_config["demo_enable"]
- ip_white_list = system_config["ip_white_list"]
- white_api_list_path = system_config["white_api_list_path"]
- ip_black_list = system_config["ip_black_list"]
-
- except Exception as e:
- log.error(f"获取系统配置失败: {e}")
-
- # 检查是否需要拦截请求
- should_block = False
- block_reason = ""
-
- # 1. 首先检查IP是否在黑名单中
- if request_ip and request_ip in ip_black_list:
- should_block = True
- block_reason = f"IP地址 {request_ip} 在黑名单中"
-
- # 2. 如果不在黑名单中,检查是否在演示模式下需要拦截
- elif demo_enable in ["true", "True"] and request.method != "GET":
- # 在演示模式下,非GET请求需要检查白名单
- is_ip_whitelisted = request_ip in ip_white_list
- is_path_whitelisted = path in white_api_list_path
-
- if not is_ip_whitelisted and not is_path_whitelisted:
- should_block = True
- block_reason = f"演示模式下拦截非GET请求,IP: {request_ip}, 路径: {path}"
-
- if should_block:
- # 增强安全审计:记录详细的拦截日志
- log.warning([
- f"会话ID: {session_id or '未认证'}",
- f"请求被拦截: {block_reason}",
- f"请求来源: {request_ip}",
- f"请求方法: {request.method}",
- f"请求路径: {path}",
- f"用户代理: {request.headers.get('user-agent', '未知')}",
- f"演示模式: {demo_enable}"
- ])
- # 拦截请求
- return ErrorResponse(msg="演示环境,禁止操作")
- else:
- # 正常处理请求
- response = await call_next(request)
-
- # 计算处理时间并添加到响应头
- process_time = round(time.time() - start_time, 5)
- response.headers["X-Process-Time"] = str(process_time)
-
- # 构建响应日志信息
-
- content_length = response.headers.get('content-length', '0')
- response_info = (
- f"响应状态: {response.status_code}, "
- f"响应内容长度: {content_length}, "
- f"处理时间: {round(process_time * 1000, 3)}ms"
- )
- log.info(response_info)
-
- return response
-
- except CustomException as e:
- log.error(f"中间件处理异常: {str(e)}")
- return ErrorResponse(msg=f"系统异常,请联系管理员", data=str(e))
- class CustomGZipMiddleware(GZipMiddleware):
- """GZip压缩中间件"""
- def __init__(self, app: ASGIApp) -> None:
- super().__init__(
- app,
- minimum_size=settings.GZIP_MIN_SIZE,
- compresslevel=settings.GZIP_COMPRESS_LEVEL
- )
|