import_util.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. # -*- coding: utf-8 -*-
  2. import importlib
  3. import inspect
  4. import os
  5. from pathlib import Path
  6. from functools import lru_cache
  7. from sqlalchemy import inspect as sa_inspect
  8. from typing import Any, Type
  9. from app.config.path_conf import BASE_DIR
  10. class ImportUtil:
  11. @classmethod
  12. def find_project_root(cls) -> Path:
  13. """
  14. 查找项目根目录
  15. :return: 项目根目录路径
  16. """
  17. return BASE_DIR
  18. @classmethod
  19. def is_valid_model(cls, obj: Any, base_class: Type) -> bool:
  20. """
  21. 验证是否为有效的SQLAlchemy模型类
  22. :param obj: 待验证的对象
  23. :param base_class: SQLAlchemy的基类
  24. :return: 验证结果
  25. """
  26. # 必须继承自base_class且不是base_class本身
  27. if not (inspect.isclass(obj) and issubclass(obj, base_class) and obj is not base_class):
  28. return False
  29. # 必须有表名定义(排除抽象基类)
  30. if not hasattr(obj, '__tablename__') or obj.__tablename__ is None:
  31. return False
  32. # 必须有至少一个列定义
  33. try:
  34. return len(sa_inspect(obj).columns) > 0
  35. except Exception:
  36. return False
  37. @classmethod
  38. @lru_cache(maxsize=256)
  39. def find_models(cls, base_class: Type) -> list[Any]:
  40. """
  41. 查找并过滤有效的模型类,避免重复和无效定义
  42. :param base_class: SQLAlchemy的Base类,用于验证模型类
  43. :return: 有效模型类列表
  44. """
  45. models = []
  46. # 按类对象去重
  47. seen_models = set()
  48. # 按表名去重(防止同表名冲突)
  49. seen_tables = set()
  50. # 记录已经处理过的model.py文件路径
  51. processed_model_files = set()
  52. project_root = cls.find_project_root()
  53. print(f"⏰️ 开始在项目根目录 {project_root} 中查找模型...")
  54. # 排除目录扩展
  55. exclude_dirs = {
  56. 'venv',
  57. '.env',
  58. '.git',
  59. '__pycache__',
  60. 'migrations',
  61. 'alembic',
  62. 'tests',
  63. 'test',
  64. 'docs',
  65. 'examples',
  66. 'scripts',
  67. '.venv',
  68. '__pycache__',
  69. 'static',
  70. 'templates',
  71. 'sql',
  72. 'env'
  73. }
  74. # 定义要搜索的模型目录模式
  75. model_dir_patterns = [
  76. 'model.py',
  77. 'models.py'
  78. ]
  79. # 使用一个更高效的方法来查找所有model.py文件
  80. model_files = []
  81. for root, dirs, files in os.walk(project_root):
  82. # 过滤排除目录
  83. dirs[:] = [d for d in dirs if d not in exclude_dirs]
  84. for file in files:
  85. if file in model_dir_patterns:
  86. file_path = Path(root) / file
  87. # 构建相对于项目根的模块路径
  88. relative_path = file_path.relative_to(project_root)
  89. model_files.append((file_path, relative_path))
  90. print(f"🔍 找到 {len(model_files)} 个模型文件")
  91. # 按模块路径排序,确保先导入基础模块
  92. model_files.sort(key=lambda x: str(x[1]))
  93. for file_path, relative_path in model_files:
  94. # 确保文件路径没有被处理过
  95. if str(file_path) in processed_model_files:
  96. continue
  97. processed_model_files.add(str(file_path))
  98. # 构建模块名(将路径分隔符转换为点)
  99. module_parts = relative_path.parts[:-1] + (relative_path.stem,)
  100. module_name = '.'.join(module_parts)
  101. try:
  102. # 导入模块
  103. module = importlib.import_module(module_name)
  104. # 获取模块中的所有类
  105. for name, obj in inspect.getmembers(module, inspect.isclass):
  106. # 验证模型有效性
  107. if not cls.is_valid_model(obj, base_class):
  108. continue
  109. # 检查类对象重复
  110. if obj in seen_models:
  111. continue
  112. # 检查表名重复
  113. table_name = obj.__tablename__
  114. if table_name in seen_tables:
  115. continue
  116. # 添加到已处理集合
  117. seen_models.add(obj)
  118. seen_tables.add(table_name)
  119. models.append(obj)
  120. print(f'✅️ 找到有效模型: {obj.__module__}.{obj.__name__} (表: {table_name})')
  121. except ImportError as e:
  122. if 'cannot import name' not in str(e):
  123. print(f'❗️ 警告: 无法导入模块 {module_name}: {e}')
  124. except Exception as e:
  125. print(f'❌️ 处理模块 {module_name} 时出错: {e}')
  126. # 查找apscheduler_jobs表的模型(如果存在)
  127. cls._find_apscheduler_model(base_class, models, seen_models, seen_tables)
  128. return models
  129. @classmethod
  130. def _find_apscheduler_model(cls, base_class: Type, models: list[Any], seen_models: set[Any], seen_tables: set[str]):
  131. """
  132. 专门查找APScheduler相关的模型
  133. :param base_class: SQLAlchemy的Base类
  134. :param models: 模型列表
  135. :param seen_models: 已处理的模型集合
  136. :param seen_tables: 已处理的表名集合
  137. """
  138. # 尝试从apscheduler相关模块导入
  139. try:
  140. # 检查是否有自定义的apscheduler模型
  141. for module_name in ['app.core.ap_scheduler', 'app.module_task.scheduler_test']:
  142. try:
  143. module = importlib.import_module(module_name)
  144. for name, obj in inspect.getmembers(module, inspect.isclass):
  145. if cls.is_valid_model(obj, base_class) and hasattr(obj, '__tablename__') and obj.__tablename__ == 'apscheduler_jobs':
  146. if obj not in seen_models and 'apscheduler_jobs' not in seen_tables:
  147. seen_models.add(obj)
  148. seen_tables.add('apscheduler_jobs')
  149. models.append(obj)
  150. print(f'✅️ 找到有效模型: {obj.__module__}.{obj.__name__} (表: apscheduler_jobs)')
  151. except ImportError:
  152. pass
  153. except Exception as e:
  154. print(f'❗️ 查找APScheduler模型时出错: {e}')