base_crud.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559
  1. # -*- coding: utf-8 -*-
  2. from pydantic import BaseModel
  3. from typing import TypeVar, Sequence, Generic, Dict, Any, List, Optional, Type, Union
  4. from sqlalchemy.sql.elements import ColumnElement
  5. from sqlalchemy.orm import selectinload
  6. from sqlalchemy.engine import Result
  7. from sqlalchemy import asc, func, select, delete, Select, desc, update, text
  8. from sqlalchemy import inspect as sa_inspect
  9. from app.core.base_model import MappedBase
  10. from app.core.exceptions import CustomException
  11. from app.core.permission import Permission
  12. from app.api.v1.module_system.auth.schema import AuthSchema
  13. ModelType = TypeVar("ModelType", bound=MappedBase)
  14. CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
  15. UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
  16. OutSchemaType = TypeVar("OutSchemaType", bound=BaseModel)
  17. class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
  18. """基础数据层"""
  19. def __init__(self, model: Type[ModelType], auth: AuthSchema) -> None:
  20. """
  21. 初始化CRUDBase类
  22. 参数:
  23. - model (Type[ModelType]): 数据模型类。
  24. - auth (AuthSchema): 认证信息。
  25. 返回:
  26. - None
  27. """
  28. self.model = model
  29. self.auth = auth
  30. async def get(self, preload: Optional[List[Union[str, Any]]] = None, **kwargs) -> Optional[ModelType]:
  31. """
  32. 根据条件获取单个对象
  33. 参数:
  34. - preload (Optional[List[Union[str, Any]]]): 预加载关系,支持关系名字符串或SQLAlchemy loader option
  35. - **kwargs: 查询条件
  36. 返回:
  37. - Optional[ModelType]: 对象实例
  38. 异常:
  39. - CustomException: 查询失败时抛出异常
  40. """
  41. try:
  42. conditions = await self.__build_conditions(**kwargs)
  43. sql = select(self.model).where(*conditions)
  44. # 应用可配置的预加载选项
  45. for opt in self.__loader_options(preload):
  46. sql = sql.options(opt)
  47. sql = await self.__filter_permissions(sql)
  48. result: Result = await self.auth.db.execute(sql)
  49. obj = result.scalars().first()
  50. return obj
  51. except Exception as e:
  52. raise CustomException(msg=f"获取查询失败: {str(e)}")
  53. async def list(self, search: Optional[Dict] = None, order_by: Optional[List[Dict[str, str]]] = None, preload: Optional[List[Union[str, Any]]] = None) -> Sequence[ModelType]:
  54. """
  55. 根据条件获取对象列表
  56. 参数:
  57. - search (Optional[Dict]): 查询条件,格式为 {'id': value, 'name': value}
  58. - order_by (Optional[List[Dict[str, str]]]): 排序字段,格式为 [{'id': 'asc'}, {'name': 'desc'}]
  59. - preload (Optional[List[Union[str, Any]]]): 预加载关系,支持关系名字符串或SQLAlchemy loader option
  60. 返回:
  61. - Sequence[ModelType]: 对象列表
  62. 异常:
  63. - CustomException: 查询失败时抛出异常
  64. """
  65. try:
  66. conditions = await self.__build_conditions(**search) if search else []
  67. order = order_by or [{'id': 'asc'}]
  68. sql = select(self.model).where(*conditions).order_by(*self.__order_by(order))
  69. # 应用可配置的预加载选项
  70. for opt in self.__loader_options(preload):
  71. sql = sql.options(opt)
  72. sql = await self.__filter_permissions(sql)
  73. result: Result = await self.auth.db.execute(sql)
  74. return result.scalars().all()
  75. except Exception as e:
  76. raise CustomException(msg=f"列表查询失败: {str(e)}")
  77. async def list_sql(self, sql:str,params: Optional[Dict[str, Any]] = None) -> List[Dict]:
  78. """
  79. 根据条件获取对象列表
  80. 参数:
  81. - search (Optional[Dict]): 查询条件,格式为 {'id': value, 'name': value}
  82. - order_by (Optional[List[Dict[str, str]]]): 排序字段,格式为 [{'id': 'asc'}, {'name': 'desc'}]
  83. - preload (Optional[List[Union[str, Any]]]): 预加载关系,支持关系名字符串或SQLAlchemy loader option
  84. 返回:
  85. - Sequence[ModelType]: 对象列表
  86. 异常:
  87. - CustomException: 查询失败时抛出异常
  88. """
  89. try:
  90. business_params = params.copy() if params else {}
  91. result = await self.execute_raw_sql(
  92. sql=sql,
  93. params=business_params,
  94. fetch_one=False,
  95. scalar=False
  96. )
  97. return result
  98. except Exception as e:
  99. raise CustomException(msg=f"列表查询失败: {str(e)}")
  100. async def tree_list(self, search: Optional[Dict] = None, order_by: Optional[List[Dict[str, str]]] = None, children_attr: str = 'children', preload: Optional[List[Union[str, Any]]] = None) -> Sequence[ModelType]:
  101. """
  102. 获取树形结构数据列表
  103. 参数:
  104. - search (Optional[Dict]): 查询条件
  105. - order_by (Optional[List[Dict[str, str]]]): 排序字段
  106. - children_attr (str): 子节点属性名
  107. - preload (Optional[List[Union[str, Any]]]): 额外预加载关系,若为None则默认包含children_attr
  108. 返回:
  109. - Sequence[ModelType]: 树形结构数据列表
  110. 异常:
  111. - CustomException: 查询失败时抛出异常
  112. """
  113. try:
  114. conditions = await self.__build_conditions(**search) if search else []
  115. order = order_by or [{'id': 'asc'}]
  116. sql = select(self.model).where(*conditions).order_by(*self.__order_by(order))
  117. # 处理预加载选项
  118. final_preload = preload
  119. # 如果没有提供preload且children_attr存在,则添加到预加载选项中
  120. if preload is None and children_attr and hasattr(self.model, children_attr):
  121. # 获取模型默认预加载选项
  122. model_defaults = getattr(self.model, "__loader_options__", [])
  123. # 将children_attr添加到默认预加载选项中
  124. final_preload = list(model_defaults) + [children_attr]
  125. # 应用预加载选项
  126. for opt in self.__loader_options(final_preload):
  127. sql = sql.options(opt)
  128. sql = await self.__filter_permissions(sql)
  129. result: Result = await self.auth.db.execute(sql)
  130. return result.scalars().all()
  131. except Exception as e:
  132. raise CustomException(msg=f"树形列表查询失败: {str(e)}")
  133. async def page(self, offset: int, limit: int, order_by: List[Dict[str, str]], search: Dict, out_schema: Type[OutSchemaType], preload: Optional[List[Union[str, Any]]] = None) -> Dict:
  134. """
  135. 获取分页数据
  136. 参数:
  137. - offset (int): 偏移量
  138. - limit (int): 每页数量
  139. - order_by (List[Dict[str, str]]): 排序字段
  140. - search (Dict): 查询条件
  141. - out_schema (Type[OutSchemaType]): 输出数据模型
  142. - preload (Optional[List[Union[str, Any]]]): 预加载关系
  143. 返回:
  144. - Dict: 分页数据
  145. 异常:
  146. - CustomException: 查询失败时抛出异常
  147. """
  148. try:
  149. conditions = await self.__build_conditions(**search) if search else []
  150. order = order_by or [{'id': 'asc'}]
  151. sql = select(self.model).where(*conditions).order_by(*self.__order_by(order))
  152. # 应用预加载选项
  153. for opt in self.__loader_options(preload):
  154. sql = sql.options(opt)
  155. sql = await self.__filter_permissions(sql)
  156. # 优化count查询:使用主键计数而非全表扫描
  157. mapper = sa_inspect(self.model)
  158. pk_cols = list(getattr(mapper, "primary_key", []))
  159. if pk_cols:
  160. # 使用主键的第一列进行计数(主键必定非NULL,性能更好)
  161. count_sql = select(func.count(pk_cols[0])).select_from(self.model)
  162. else:
  163. # 降级方案:使用count(*)
  164. count_sql = select(func.count()).select_from(self.model)
  165. if conditions:
  166. count_sql = count_sql.where(*conditions)
  167. count_sql = await self.__filter_permissions(count_sql)
  168. total_result = await self.auth.db.execute(count_sql)
  169. total = total_result.scalar() or 0
  170. result: Result = await self.auth.db.execute(sql.offset(offset).limit(limit))
  171. objs = result.scalars().all()
  172. return {
  173. "page_no": offset // limit + 1 if limit else 1,
  174. "page_size": limit if limit else 10,
  175. "total": total,
  176. "has_next": offset + limit < total,
  177. "items": [out_schema.model_validate(obj).model_dump() for obj in objs]
  178. }
  179. except Exception as e:
  180. raise CustomException(msg=f"分页查询失败: {str(e)}")
  181. async def create(self, data: Union[CreateSchemaType, Dict]) -> ModelType:
  182. """
  183. 创建新对象
  184. 参数:
  185. - data (Union[CreateSchemaType, Dict]): 对象属性
  186. 返回:
  187. - ModelType: 新创建的对象实例
  188. 异常:
  189. - CustomException: 创建失败时抛出异常
  190. """
  191. try:
  192. obj_dict = data if isinstance(data, dict) else data.model_dump()
  193. obj = self.model(**obj_dict)
  194. # 设置字段值(只检查一次current_user)
  195. if self.auth.user:
  196. if hasattr(obj, "created_id"):
  197. setattr(obj, "created_id", self.auth.user.id)
  198. if hasattr(obj, "updated_id"):
  199. setattr(obj, "updated_id", self.auth.user.id)
  200. self.auth.db.add(obj)
  201. await self.auth.db.flush()
  202. await self.auth.db.refresh(obj)
  203. return obj
  204. except Exception as e:
  205. raise CustomException(msg=f"创建失败: {str(e)}")
  206. async def update(self, id: int, data: Union[UpdateSchemaType, Dict]) -> ModelType:
  207. """
  208. 更新对象
  209. 参数:
  210. - id (int): 对象ID
  211. - data (Union[UpdateSchemaType, Dict]): 更新的属性及值
  212. 返回:
  213. - ModelType: 更新后的对象实例
  214. 异常:
  215. - CustomException: 更新失败时抛出异常
  216. """
  217. try:
  218. obj_dict = data if isinstance(data, dict) else data.model_dump(exclude_unset=True, exclude={"id"})
  219. obj = await self.get(id=id)
  220. if not obj:
  221. raise CustomException(msg="更新对象不存在")
  222. # 设置字段值(只检查一次current_user)
  223. if self.auth.user:
  224. if hasattr(obj, "updated_id"):
  225. setattr(obj, "updated_id", self.auth.user.id)
  226. for key, value in obj_dict.items():
  227. if hasattr(obj, key):
  228. setattr(obj, key, value)
  229. await self.auth.db.flush()
  230. await self.auth.db.refresh(obj)
  231. # 权限二次确认:flush后再次验证对象仍在权限范围内
  232. # 防止并发修改导致的权限逃逸(如其他事务修改了created_id)
  233. verify_obj = await self.get(id=id)
  234. if not verify_obj:
  235. # 对象已被删除或权限已失效
  236. raise CustomException(msg="更新失败,对象不存在或无权限访问")
  237. return obj
  238. except Exception as e:
  239. raise CustomException(msg=f"更新失败: {str(e)}")
  240. async def delete(self, ids: List[int]) -> None:
  241. """
  242. 删除对象
  243. 参数:
  244. - ids (List[int]): 对象ID列表
  245. 异常:
  246. - CustomException: 删除失败时抛出异常
  247. """
  248. try:
  249. # 先查询确认权限,避免删除无权限的数据
  250. objs = await self.list(search={"id": ("in", ids)})
  251. accessible_ids = [obj.id for obj in objs]
  252. # 检查是否所有ID都有权限访问
  253. inaccessible_count = len(ids) - len(accessible_ids)
  254. if inaccessible_count > 0:
  255. raise CustomException(msg=f"无权限删除{inaccessible_count}条数据")
  256. if not accessible_ids:
  257. return # 没有可删除的数据
  258. mapper = sa_inspect(self.model)
  259. pk_cols = list(getattr(mapper, "primary_key", []))
  260. if not pk_cols:
  261. raise CustomException(msg="模型缺少主键,无法删除")
  262. if len(pk_cols) > 1:
  263. raise CustomException(msg="暂不支持复合主键的批量删除")
  264. # 只删除有权限的数据
  265. sql = delete(self.model).where(pk_cols[0].in_(accessible_ids))
  266. await self.auth.db.execute(sql)
  267. await self.auth.db.flush()
  268. except Exception as e:
  269. raise CustomException(msg=f"删除失败: {str(e)}")
  270. async def clear(self) -> None:
  271. """
  272. 清空对象表
  273. 异常:
  274. - CustomException: 清空失败时抛出异常
  275. """
  276. try:
  277. sql = delete(self.model)
  278. await self.auth.db.execute(sql)
  279. await self.auth.db.flush()
  280. except Exception as e:
  281. raise CustomException(msg=f"清空失败: {str(e)}")
  282. async def execute_raw_sql(
  283. self,
  284. sql: str,
  285. params: Optional[Dict[str, Any]] = None,
  286. fetch_one: bool = False,
  287. scalar: bool = False
  288. ) -> Optional[Union[Dict, List[Dict], Any, Sequence[Any]]]:
  289. try:
  290. # ---------------------- 1. 严格校验输入:避免空SQL、非字符串等错误 ----------------------
  291. # 校验SQL是否为非空字符串
  292. if not isinstance(sql, str) or len(sql.strip()) == 0:
  293. raise CustomException(msg="传入的原始SQL不能为空且必须为字符串类型")
  294. # 初始化最终SQL(仅对字符串操作,规避TextClause错误)
  295. final_sql = sql.strip()
  296. # 初始化最终参数:兜底空字典,避免None报错
  297. final_params = params.copy() if (params and isinstance(params, Dict)) else {}
  298. # ---------------------- 3. 核心:执行SQL(仅包装一次TextClause,不进行后续字符串操作) ----------------------
  299. # 包装为TextClause(SQLAlchemy执行原始SQL必需,仅执行此步骤,无额外操作)
  300. raw_sql_clause = text(final_sql)
  301. # 执行SQL:传入上层已合并(业务+权限)的参数,确保所有占位符绑定完成
  302. db_result: Result = await self.auth.db.execute(raw_sql_clause, final_params)
  303. # ---------------------- 4. 处理结果:格式统一,兼容上层list_sql的返回需求 ----------------------
  304. # 处理SELECT查询(返回字典/字典列表,兼容Sequence[ModelType])
  305. if final_sql.upper().startswith("SELECT"):
  306. # 标量结果(单个值,如COUNT(*))
  307. if scalar:
  308. return db_result.scalar() if fetch_one else db_result.scalars().all()
  309. # 完整行结果:转换为字典列表,避免字段名丢失,兜底空列表
  310. row_mappings = db_result.mappings().all()
  311. result_list = [dict(row) for row in row_mappings] if row_mappings else []
  312. # 单条/多条结果返回,确保返回序列类型
  313. if fetch_one:
  314. return result_list[0] if len(result_list) > 0 else None
  315. else:
  316. return result_list
  317. # 处理DML操作(INSERT/UPDATE/DELETE):刷新会话,返回None
  318. await self.auth.db.flush()
  319. return None
  320. # ---------------------- 5. 异常捕获:补充上下文,方便排查上层拼接后的问题 ----------------------
  321. except CustomException as ce:
  322. raise ce
  323. except Exception as e:
  324. error_msg = (
  325. f"执行原始SQL失败:{str(e)}"
  326. f"\n 执行的SQL:{final_sql[:500]}..."
  327. f"\n 传入的参数:{final_params}"
  328. )
  329. raise CustomException(msg=error_msg)
  330. async def set(self, ids: List[int], **kwargs) -> None:
  331. """
  332. 批量更新对象
  333. 参数:
  334. - ids (List[int]): 对象ID列表
  335. - **kwargs: 更新的属性及值
  336. 异常:
  337. - CustomException: 更新失败时抛出异常
  338. """
  339. try:
  340. # 先查询确认权限,避免更新无权限的数据
  341. objs = await self.list(search={"id": ("in", ids)})
  342. accessible_ids = [obj.id for obj in objs]
  343. # 检查是否所有ID都有权限访问
  344. inaccessible_count = len(ids) - len(accessible_ids)
  345. if inaccessible_count > 0:
  346. raise CustomException(msg=f"无权限更新{inaccessible_count}条数据")
  347. if not accessible_ids:
  348. return # 没有可更新的数据
  349. mapper = sa_inspect(self.model)
  350. pk_cols = list(getattr(mapper, "primary_key", []))
  351. if not pk_cols:
  352. raise CustomException(msg="模型缺少主键,无法更新")
  353. if len(pk_cols) > 1:
  354. raise CustomException(msg="暂不支持复合主键的批量更新")
  355. # 只更新有权限的数据
  356. sql = update(self.model).where(pk_cols[0].in_(accessible_ids)).values(**kwargs)
  357. await self.auth.db.execute(sql)
  358. await self.auth.db.flush()
  359. except CustomException:
  360. raise
  361. except Exception as e:
  362. raise CustomException(msg=f"批量更新失败: {str(e)}")
  363. async def __filter_permissions(self, sql: Select) -> Select:
  364. """
  365. 过滤数据权限(仅用于Select)。
  366. """
  367. filter = Permission(
  368. model=self.model,
  369. auth=self.auth
  370. )
  371. return await filter.filter_query(sql)
  372. async def __build_conditions(self, **kwargs) -> List[ColumnElement]:
  373. """
  374. 构建查询条件
  375. 参数:
  376. - **kwargs: 查询参数
  377. 返回:
  378. - List[ColumnElement]: SQL条件表达式列表
  379. 异常:
  380. - CustomException: 查询参数不存在时抛出异常
  381. """
  382. conditions = []
  383. for key, value in kwargs.items():
  384. if value is None or value == "":
  385. continue
  386. attr = getattr(self.model, key)
  387. if isinstance(value, tuple):
  388. seq, val = value
  389. if seq == "None":
  390. conditions.append(attr.is_(None))
  391. elif seq == "not None":
  392. conditions.append(attr.isnot(None))
  393. elif seq == "date" and val:
  394. conditions.append(func.date_format(attr, "%Y-%m-%d") == val)
  395. elif seq == "month" and val:
  396. conditions.append(func.date_format(attr, "%Y-%m") == val)
  397. elif seq == "like" and val:
  398. conditions.append(attr.like(f"%{val}%"))
  399. elif seq == "in" and val:
  400. conditions.append(attr.in_(val))
  401. elif seq == "between" and isinstance(val, (list, tuple)) and len(val) == 2:
  402. conditions.append(attr.between(val[0], val[1]))
  403. elif seq == "!=" and val:
  404. conditions.append(attr != val)
  405. elif seq == ">" and val:
  406. conditions.append(attr > val)
  407. elif seq == ">=" and val:
  408. conditions.append(attr >= val)
  409. elif seq == "<" and val:
  410. conditions.append(attr < val)
  411. elif seq == "<=" and val:
  412. conditions.append(attr <= val)
  413. elif seq == "==" and val:
  414. conditions.append(attr == val)
  415. else:
  416. conditions.append(attr == value)
  417. return conditions
  418. def __order_by(self, order_by: List[Dict[str, str]]) -> List[ColumnElement]:
  419. """
  420. 获取排序字段
  421. 参数:
  422. - order_by (List[Dict[str, str]]): 排序字段列表,格式为 [{'id': 'asc'}, {'name': 'desc'}]
  423. 返回:
  424. - List[ColumnElement]: 排序字段列表
  425. 异常:
  426. - CustomException: 排序字段不存在时抛出异常
  427. """
  428. columns = []
  429. for order in order_by:
  430. for field, direction in order.items():
  431. column = getattr(self.model, field)
  432. columns.append(desc(column) if direction.lower() == 'desc' else asc(column))
  433. return columns
  434. def __loader_options(self, preload: Optional[List[Union[str, Any]]] = None) -> List[Any]:
  435. """
  436. 构建预加载选项
  437. 参数:
  438. - preload (Optional[List[Union[str, Any]]]): 预加载关系,支持关系名字符串或SQLAlchemy loader option
  439. 返回:
  440. - List[Any]: 预加载选项列表
  441. """
  442. options = []
  443. # 获取模型定义的默认加载选项
  444. model_loader_options = getattr(self.model, '__loader_options__', [])
  445. # 合并所有需要预加载的选项
  446. all_preloads = set(model_loader_options)
  447. if preload:
  448. for opt in preload:
  449. if isinstance(opt, str):
  450. all_preloads.add(opt)
  451. elif preload == []:
  452. # 如果明确指定空列表,则不使用任何预加载
  453. all_preloads = set()
  454. # 处理所有预加载选项
  455. for opt in all_preloads:
  456. if isinstance(opt, str):
  457. # 使用selectinload来避免在异步环境中的MissingGreenlet错误
  458. if hasattr(self.model, opt):
  459. options.append(selectinload(getattr(self.model, opt)))
  460. else:
  461. # 直接使用非字符串的加载选项
  462. options.append(opt)
  463. return options