initialize.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. # -*- coding: utf-8 -*-
  2. import asyncio
  3. import json
  4. from sqlalchemy import select, func
  5. from sqlalchemy.ext.asyncio import AsyncSession
  6. from app.config.path_conf import SCRIPT_DIR
  7. from app.core.logger import log
  8. from app.core.database import async_db_session, async_engine
  9. from app.core.base_model import MappedBase
  10. from app.api.v1.module_system.user.model import UserModel, UserRolesModel
  11. from app.api.v1.module_system.role.model import RoleModel
  12. from app.api.v1.module_system.dept.model import DeptModel
  13. from app.api.v1.module_system.menu.model import MenuModel
  14. from app.api.v1.module_system.params.model import ParamsModel
  15. from app.api.v1.module_system.dict.model import DictTypeModel, DictDataModel
  16. class InitializeData:
  17. """
  18. 初始化数据库和基础数据
  19. """
  20. def __init__(self) -> None:
  21. """
  22. 初始化数据库和基础数据
  23. """
  24. # 按照依赖关系排序:先创建基础表,再创建关联表
  25. self.prepare_init_models = [
  26. MenuModel,
  27. ParamsModel,
  28. DeptModel,
  29. RoleModel,
  30. DictTypeModel,
  31. DictDataModel,
  32. UserModel,
  33. UserRolesModel,
  34. ]
  35. async def __init_create_table(self) -> None:
  36. """
  37. 初始化表结构(第一阶段)
  38. """
  39. try:
  40. # 使用引擎创建所有表
  41. async with async_engine.begin() as conn:
  42. await conn.run_sync(MappedBase.metadata.create_all)
  43. log.info("✅️ 数据库表结构初始化完成")
  44. except asyncio.exceptions.TimeoutError:
  45. log.error("❌️ 数据库表结构初始化超时")
  46. raise
  47. except Exception as e:
  48. log.error(f"❌️ 数据库表结构初始化失败: {str(e)}")
  49. raise
  50. async def __init_data(self, db: AsyncSession) -> None:
  51. """
  52. 初始化基础数据
  53. 参数:
  54. - db (AsyncSession): 异步数据库会话。
  55. """
  56. # 存储字典类型数据的映射,用于后续字典数据的初始化
  57. dict_type_mapping = {}
  58. for model in self.prepare_init_models:
  59. table_name = model.__tablename__
  60. # 检查表中是否已经有数据
  61. count_result = await db.execute(select(func.count()).select_from(model))
  62. existing_count = count_result.scalar()
  63. if existing_count and existing_count > 0:
  64. log.warning(f"⚠️ 跳过 {table_name} 表数据初始化(表已存在 {existing_count} 条记录)")
  65. continue
  66. data = await self.__get_data(table_name)
  67. if not data:
  68. log.warning(f"⚠️ 跳过 {table_name} 表,无初始化数据")
  69. continue
  70. try:
  71. # 特殊处理具有嵌套 children 数据的表
  72. if table_name in ["sys_dept", "sys_menu"]:
  73. # 获取对应的模型类
  74. model_class = DeptModel if table_name == "sys_dept" else MenuModel
  75. objs = self.__create_objects_with_children(data, model_class)
  76. # 处理字典类型表,保存类型映射
  77. elif table_name == "sys_dict_type":
  78. objs = []
  79. for item in data:
  80. obj = model(**item)
  81. objs.append(obj)
  82. dict_type_mapping[item['dict_type']] = obj
  83. # 处理字典数据表,添加dict_type_id关联
  84. elif table_name == "sys_dict_data":
  85. objs = []
  86. for item in data:
  87. dict_type = item.get('dict_type')
  88. if dict_type in dict_type_mapping:
  89. # 添加dict_type_id关联
  90. item['dict_type_id'] = dict_type_mapping[dict_type].id
  91. else:
  92. log.warning(f"⚠️ 未找到字典类型 {dict_type},跳过该字典数据")
  93. continue
  94. objs.append(model(**item))
  95. else:
  96. # 表为空,直接插入全部数据
  97. objs = [model(**item) for item in data]
  98. db.add_all(objs)
  99. await db.flush()
  100. log.info(f"✅️ 已向 {table_name} 表写入初始化数据")
  101. except Exception as e:
  102. log.error(f"❌️ 初始化 {table_name} 表数据失败: {str(e)}")
  103. raise
  104. def __create_objects_with_children(self, data: list[dict], model_class) -> list:
  105. """
  106. 通用递归创建对象函数,处理嵌套的 children 数据
  107. 参数:
  108. - data (list[dict]): 包含嵌套 children 数据的列表。
  109. - model_class: 对应的 SQLAlchemy 模型类。
  110. 返回:
  111. - list: 包含创建的对象的列表。
  112. """
  113. objs = []
  114. def create_object(obj_data: dict):
  115. # 分离 children 数据
  116. children_data = obj_data.pop('children', [])
  117. # 创建当前对象
  118. obj = model_class(**obj_data)
  119. # 递归处理子对象
  120. if children_data:
  121. obj.children = [create_object(child) for child in children_data]
  122. return obj
  123. for item in data:
  124. objs.append(create_object(item))
  125. return objs
  126. async def __get_data(self, filename: str) -> list[dict]:
  127. """
  128. 读取初始化数据文件
  129. 参数:
  130. - filename (str): 文件名(不包含扩展名)。
  131. 返回:
  132. - list[dict]: 解析后的 JSON 数据列表。
  133. """
  134. json_path = SCRIPT_DIR / f'{filename}.json'
  135. if not json_path.exists():
  136. return []
  137. try:
  138. with open(json_path, 'r', encoding='utf-8') as f:
  139. return json.loads(f.read())
  140. except json.JSONDecodeError as e:
  141. log.error(f"❌️ 解析 {json_path} 失败: {str(e)}")
  142. raise
  143. except Exception as e:
  144. log.error(f"❌️ 读取 {json_path} 失败: {str(e)}")
  145. raise
  146. async def init_db(self) -> None:
  147. """
  148. 执行完整初始化流程
  149. """
  150. # 先创建表结构
  151. await self.__init_create_table()
  152. # 再初始化数据
  153. async with async_db_session() as session:
  154. async with session.begin():
  155. await self.__init_data(session)
  156. # session.add_all(objs)
  157. # 确保提交事务
  158. await session.commit()