base_crud.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473
  1. # -*- coding: utf-8 -*-
  2. from pydantic import BaseModel
  3. from typing import TypeVar, Sequence, Generic, Dict, Any, List, Optional, Type, Union
  4. from sqlalchemy.sql.elements import ColumnElement
  5. from sqlalchemy.orm import selectinload
  6. from sqlalchemy.engine import Result
  7. from sqlalchemy import asc, func, select, delete, Select, desc, update
  8. from sqlalchemy import inspect as sa_inspect
  9. from app.core.base_model import MappedBase
  10. from app.core.exceptions import CustomException
  11. from app.core.permission import Permission
  12. from app.api.v1.module_system.auth.schema import AuthSchema
  13. ModelType = TypeVar("ModelType", bound=MappedBase)
  14. CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
  15. UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
  16. OutSchemaType = TypeVar("OutSchemaType", bound=BaseModel)
  17. class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
  18. """基础数据层"""
  19. def __init__(self, model: Type[ModelType], auth: AuthSchema) -> None:
  20. """
  21. 初始化CRUDBase类
  22. 参数:
  23. - model (Type[ModelType]): 数据模型类。
  24. - auth (AuthSchema): 认证信息。
  25. 返回:
  26. - None
  27. """
  28. self.model = model
  29. self.auth = auth
  30. async def get(self, preload: Optional[List[Union[str, Any]]] = None, **kwargs) -> Optional[ModelType]:
  31. """
  32. 根据条件获取单个对象
  33. 参数:
  34. - preload (Optional[List[Union[str, Any]]]): 预加载关系,支持关系名字符串或SQLAlchemy loader option
  35. - **kwargs: 查询条件
  36. 返回:
  37. - Optional[ModelType]: 对象实例
  38. 异常:
  39. - CustomException: 查询失败时抛出异常
  40. """
  41. try:
  42. conditions = await self.__build_conditions(**kwargs)
  43. sql = select(self.model).where(*conditions)
  44. # 应用可配置的预加载选项
  45. for opt in self.__loader_options(preload):
  46. sql = sql.options(opt)
  47. sql = await self.__filter_permissions(sql)
  48. result: Result = await self.auth.db.execute(sql)
  49. obj = result.scalars().first()
  50. return obj
  51. except Exception as e:
  52. raise CustomException(msg=f"获取查询失败: {str(e)}")
  53. async def list(self, search: Optional[Dict] = None, order_by: Optional[List[Dict[str, str]]] = None, preload: Optional[List[Union[str, Any]]] = None) -> Sequence[ModelType]:
  54. """
  55. 根据条件获取对象列表
  56. 参数:
  57. - search (Optional[Dict]): 查询条件,格式为 {'id': value, 'name': value}
  58. - order_by (Optional[List[Dict[str, str]]]): 排序字段,格式为 [{'id': 'asc'}, {'name': 'desc'}]
  59. - preload (Optional[List[Union[str, Any]]]): 预加载关系,支持关系名字符串或SQLAlchemy loader option
  60. 返回:
  61. - Sequence[ModelType]: 对象列表
  62. 异常:
  63. - CustomException: 查询失败时抛出异常
  64. """
  65. try:
  66. conditions = await self.__build_conditions(**search) if search else []
  67. order = order_by or [{'id': 'asc'}]
  68. sql = select(self.model).where(*conditions).order_by(*self.__order_by(order))
  69. # 应用可配置的预加载选项
  70. for opt in self.__loader_options(preload):
  71. sql = sql.options(opt)
  72. sql = await self.__filter_permissions(sql)
  73. result: Result = await self.auth.db.execute(sql)
  74. return result.scalars().all()
  75. except Exception as e:
  76. raise CustomException(msg=f"列表查询失败: {str(e)}")
  77. async def tree_list(self, search: Optional[Dict] = None, order_by: Optional[List[Dict[str, str]]] = None, children_attr: str = 'children', preload: Optional[List[Union[str, Any]]] = None) -> Sequence[ModelType]:
  78. """
  79. 获取树形结构数据列表
  80. 参数:
  81. - search (Optional[Dict]): 查询条件
  82. - order_by (Optional[List[Dict[str, str]]]): 排序字段
  83. - children_attr (str): 子节点属性名
  84. - preload (Optional[List[Union[str, Any]]]): 额外预加载关系,若为None则默认包含children_attr
  85. 返回:
  86. - Sequence[ModelType]: 树形结构数据列表
  87. 异常:
  88. - CustomException: 查询失败时抛出异常
  89. """
  90. try:
  91. conditions = await self.__build_conditions(**search) if search else []
  92. order = order_by or [{'id': 'asc'}]
  93. sql = select(self.model).where(*conditions).order_by(*self.__order_by(order))
  94. # 处理预加载选项
  95. final_preload = preload
  96. # 如果没有提供preload且children_attr存在,则添加到预加载选项中
  97. if preload is None and children_attr and hasattr(self.model, children_attr):
  98. # 获取模型默认预加载选项
  99. model_defaults = getattr(self.model, "__loader_options__", [])
  100. # 将children_attr添加到默认预加载选项中
  101. final_preload = list(model_defaults) + [children_attr]
  102. # 应用预加载选项
  103. for opt in self.__loader_options(final_preload):
  104. sql = sql.options(opt)
  105. sql = await self.__filter_permissions(sql)
  106. result: Result = await self.auth.db.execute(sql)
  107. return result.scalars().all()
  108. except Exception as e:
  109. raise CustomException(msg=f"树形列表查询失败: {str(e)}")
  110. async def page(self, offset: int, limit: int, order_by: List[Dict[str, str]], search: Dict, out_schema: Type[OutSchemaType], preload: Optional[List[Union[str, Any]]] = None) -> Dict:
  111. """
  112. 获取分页数据
  113. 参数:
  114. - offset (int): 偏移量
  115. - limit (int): 每页数量
  116. - order_by (List[Dict[str, str]]): 排序字段
  117. - search (Dict): 查询条件
  118. - out_schema (Type[OutSchemaType]): 输出数据模型
  119. - preload (Optional[List[Union[str, Any]]]): 预加载关系
  120. 返回:
  121. - Dict: 分页数据
  122. 异常:
  123. - CustomException: 查询失败时抛出异常
  124. """
  125. try:
  126. conditions = await self.__build_conditions(**search) if search else []
  127. order = order_by or [{'id': 'asc'}]
  128. sql = select(self.model).where(*conditions).order_by(*self.__order_by(order))
  129. # 应用预加载选项
  130. for opt in self.__loader_options(preload):
  131. sql = sql.options(opt)
  132. sql = await self.__filter_permissions(sql)
  133. # 优化count查询:使用主键计数而非全表扫描
  134. mapper = sa_inspect(self.model)
  135. pk_cols = list(getattr(mapper, "primary_key", []))
  136. if pk_cols:
  137. # 使用主键的第一列进行计数(主键必定非NULL,性能更好)
  138. count_sql = select(func.count(pk_cols[0])).select_from(self.model)
  139. else:
  140. # 降级方案:使用count(*)
  141. count_sql = select(func.count()).select_from(self.model)
  142. if conditions:
  143. count_sql = count_sql.where(*conditions)
  144. count_sql = await self.__filter_permissions(count_sql)
  145. total_result = await self.auth.db.execute(count_sql)
  146. total = total_result.scalar() or 0
  147. result: Result = await self.auth.db.execute(sql.offset(offset).limit(limit))
  148. objs = result.scalars().all()
  149. return {
  150. "page_no": offset // limit + 1 if limit else 1,
  151. "page_size": limit if limit else 10,
  152. "total": total,
  153. "has_next": offset + limit < total,
  154. "items": [out_schema.model_validate(obj).model_dump() for obj in objs]
  155. }
  156. except Exception as e:
  157. raise CustomException(msg=f"分页查询失败: {str(e)}")
  158. async def create(self, data: Union[CreateSchemaType, Dict]) -> ModelType:
  159. """
  160. 创建新对象
  161. 参数:
  162. - data (Union[CreateSchemaType, Dict]): 对象属性
  163. 返回:
  164. - ModelType: 新创建的对象实例
  165. 异常:
  166. - CustomException: 创建失败时抛出异常
  167. """
  168. try:
  169. obj_dict = data if isinstance(data, dict) else data.model_dump()
  170. obj = self.model(**obj_dict)
  171. # 设置字段值(只检查一次current_user)
  172. if self.auth.user:
  173. if hasattr(obj, "created_id"):
  174. setattr(obj, "created_id", self.auth.user.id)
  175. if hasattr(obj, "updated_id"):
  176. setattr(obj, "updated_id", self.auth.user.id)
  177. self.auth.db.add(obj)
  178. await self.auth.db.flush()
  179. await self.auth.db.refresh(obj)
  180. return obj
  181. except Exception as e:
  182. raise CustomException(msg=f"创建失败: {str(e)}")
  183. async def update(self, id: int, data: Union[UpdateSchemaType, Dict]) -> ModelType:
  184. """
  185. 更新对象
  186. 参数:
  187. - id (int): 对象ID
  188. - data (Union[UpdateSchemaType, Dict]): 更新的属性及值
  189. 返回:
  190. - ModelType: 更新后的对象实例
  191. 异常:
  192. - CustomException: 更新失败时抛出异常
  193. """
  194. try:
  195. obj_dict = data if isinstance(data, dict) else data.model_dump(exclude_unset=True, exclude={"id"})
  196. obj = await self.get(id=id)
  197. if not obj:
  198. raise CustomException(msg="更新对象不存在")
  199. # 设置字段值(只检查一次current_user)
  200. if self.auth.user:
  201. if hasattr(obj, "updated_id"):
  202. setattr(obj, "updated_id", self.auth.user.id)
  203. for key, value in obj_dict.items():
  204. if hasattr(obj, key):
  205. setattr(obj, key, value)
  206. await self.auth.db.flush()
  207. await self.auth.db.refresh(obj)
  208. # 权限二次确认:flush后再次验证对象仍在权限范围内
  209. # 防止并发修改导致的权限逃逸(如其他事务修改了created_id)
  210. verify_obj = await self.get(id=id)
  211. if not verify_obj:
  212. # 对象已被删除或权限已失效
  213. raise CustomException(msg="更新失败,对象不存在或无权限访问")
  214. return obj
  215. except Exception as e:
  216. raise CustomException(msg=f"更新失败: {str(e)}")
  217. async def delete(self, ids: List[int]) -> None:
  218. """
  219. 删除对象
  220. 参数:
  221. - ids (List[int]): 对象ID列表
  222. 异常:
  223. - CustomException: 删除失败时抛出异常
  224. """
  225. try:
  226. # 先查询确认权限,避免删除无权限的数据
  227. objs = await self.list(search={"id": ("in", ids)})
  228. accessible_ids = [obj.id for obj in objs]
  229. # 检查是否所有ID都有权限访问
  230. inaccessible_count = len(ids) - len(accessible_ids)
  231. if inaccessible_count > 0:
  232. raise CustomException(msg=f"无权限删除{inaccessible_count}条数据")
  233. if not accessible_ids:
  234. return # 没有可删除的数据
  235. mapper = sa_inspect(self.model)
  236. pk_cols = list(getattr(mapper, "primary_key", []))
  237. if not pk_cols:
  238. raise CustomException(msg="模型缺少主键,无法删除")
  239. if len(pk_cols) > 1:
  240. raise CustomException(msg="暂不支持复合主键的批量删除")
  241. # 只删除有权限的数据
  242. sql = delete(self.model).where(pk_cols[0].in_(accessible_ids))
  243. await self.auth.db.execute(sql)
  244. await self.auth.db.flush()
  245. except Exception as e:
  246. raise CustomException(msg=f"删除失败: {str(e)}")
  247. async def clear(self) -> None:
  248. """
  249. 清空对象表
  250. 异常:
  251. - CustomException: 清空失败时抛出异常
  252. """
  253. try:
  254. sql = delete(self.model)
  255. await self.auth.db.execute(sql)
  256. await self.auth.db.flush()
  257. except Exception as e:
  258. raise CustomException(msg=f"清空失败: {str(e)}")
  259. async def set(self, ids: List[int], **kwargs) -> None:
  260. """
  261. 批量更新对象
  262. 参数:
  263. - ids (List[int]): 对象ID列表
  264. - **kwargs: 更新的属性及值
  265. 异常:
  266. - CustomException: 更新失败时抛出异常
  267. """
  268. try:
  269. # 先查询确认权限,避免更新无权限的数据
  270. objs = await self.list(search={"id": ("in", ids)})
  271. accessible_ids = [obj.id for obj in objs]
  272. # 检查是否所有ID都有权限访问
  273. inaccessible_count = len(ids) - len(accessible_ids)
  274. if inaccessible_count > 0:
  275. raise CustomException(msg=f"无权限更新{inaccessible_count}条数据")
  276. if not accessible_ids:
  277. return # 没有可更新的数据
  278. mapper = sa_inspect(self.model)
  279. pk_cols = list(getattr(mapper, "primary_key", []))
  280. if not pk_cols:
  281. raise CustomException(msg="模型缺少主键,无法更新")
  282. if len(pk_cols) > 1:
  283. raise CustomException(msg="暂不支持复合主键的批量更新")
  284. # 只更新有权限的数据
  285. sql = update(self.model).where(pk_cols[0].in_(accessible_ids)).values(**kwargs)
  286. await self.auth.db.execute(sql)
  287. await self.auth.db.flush()
  288. except CustomException:
  289. raise
  290. except Exception as e:
  291. raise CustomException(msg=f"批量更新失败: {str(e)}")
  292. async def __filter_permissions(self, sql: Select) -> Select:
  293. """
  294. 过滤数据权限(仅用于Select)。
  295. """
  296. filter = Permission(
  297. model=self.model,
  298. auth=self.auth
  299. )
  300. return await filter.filter_query(sql)
  301. async def __build_conditions(self, **kwargs) -> List[ColumnElement]:
  302. """
  303. 构建查询条件
  304. 参数:
  305. - **kwargs: 查询参数
  306. 返回:
  307. - List[ColumnElement]: SQL条件表达式列表
  308. 异常:
  309. - CustomException: 查询参数不存在时抛出异常
  310. """
  311. conditions = []
  312. for key, value in kwargs.items():
  313. if value is None or value == "":
  314. continue
  315. attr = getattr(self.model, key)
  316. if isinstance(value, tuple):
  317. seq, val = value
  318. if seq == "None":
  319. conditions.append(attr.is_(None))
  320. elif seq == "not None":
  321. conditions.append(attr.isnot(None))
  322. elif seq == "date" and val:
  323. conditions.append(func.date_format(attr, "%Y-%m-%d") == val)
  324. elif seq == "month" and val:
  325. conditions.append(func.date_format(attr, "%Y-%m") == val)
  326. elif seq == "like" and val:
  327. conditions.append(attr.like(f"%{val}%"))
  328. elif seq == "in" and val:
  329. conditions.append(attr.in_(val))
  330. elif seq == "between" and isinstance(val, (list, tuple)) and len(val) == 2:
  331. conditions.append(attr.between(val[0], val[1]))
  332. elif seq == "!=" and val:
  333. conditions.append(attr != val)
  334. elif seq == ">" and val:
  335. conditions.append(attr > val)
  336. elif seq == ">=" and val:
  337. conditions.append(attr >= val)
  338. elif seq == "<" and val:
  339. conditions.append(attr < val)
  340. elif seq == "<=" and val:
  341. conditions.append(attr <= val)
  342. elif seq == "==" and val:
  343. conditions.append(attr == val)
  344. else:
  345. conditions.append(attr == value)
  346. return conditions
  347. def __order_by(self, order_by: List[Dict[str, str]]) -> List[ColumnElement]:
  348. """
  349. 获取排序字段
  350. 参数:
  351. - order_by (List[Dict[str, str]]): 排序字段列表,格式为 [{'id': 'asc'}, {'name': 'desc'}]
  352. 返回:
  353. - List[ColumnElement]: 排序字段列表
  354. 异常:
  355. - CustomException: 排序字段不存在时抛出异常
  356. """
  357. columns = []
  358. for order in order_by:
  359. for field, direction in order.items():
  360. column = getattr(self.model, field)
  361. columns.append(desc(column) if direction.lower() == 'desc' else asc(column))
  362. return columns
  363. def __loader_options(self, preload: Optional[List[Union[str, Any]]] = None) -> List[Any]:
  364. """
  365. 构建预加载选项
  366. 参数:
  367. - preload (Optional[List[Union[str, Any]]]): 预加载关系,支持关系名字符串或SQLAlchemy loader option
  368. 返回:
  369. - List[Any]: 预加载选项列表
  370. """
  371. options = []
  372. # 获取模型定义的默认加载选项
  373. model_loader_options = getattr(self.model, '__loader_options__', [])
  374. # 合并所有需要预加载的选项
  375. all_preloads = set(model_loader_options)
  376. if preload:
  377. for opt in preload:
  378. if isinstance(opt, str):
  379. all_preloads.add(opt)
  380. elif preload == []:
  381. # 如果明确指定空列表,则不使用任何预加载
  382. all_preloads = set()
  383. # 处理所有预加载选项
  384. for opt in all_preloads:
  385. if isinstance(opt, str):
  386. # 使用selectinload来避免在异步环境中的MissingGreenlet错误
  387. if hasattr(self.model, opt):
  388. options.append(selectinload(getattr(self.model, opt)))
  389. else:
  390. # 直接使用非字符串的加载选项
  391. options.append(opt)
  392. return options