| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573 |
- # -*- coding: utf-8 -*-
- import io
- import os
- from pathlib import Path
- import zipfile
- from typing import Any
- from sqlglot.expressions import Add, Alter, Create, Delete, Drop, Expression, Insert, Table, TruncateTable, Update
- from sqlglot import parse as sqlglot_parse
- from app.config.path_conf import BASE_DIR
- from app.config.setting import settings
- from app.core.logger import log
- from app.core.exceptions import CustomException
- from app.api.v1.module_system.auth.schema import AuthSchema
- from .tools.jinja2_template_util import Jinja2TemplateUtil
- from .tools.gen_util import GenUtils
- from .schema import GenTableSchema, GenTableOutSchema, GenTableColumnSchema, GenTableColumnOutSchema, GenTableQueryParam
- from .crud import GenTableColumnCRUD, GenTableCRUD
- def handle_service_exception(func):
- async def wrapper(*args, **kwargs):
- try:
- return await func(*args, **kwargs)
- except CustomException:
- raise
- except Exception as e:
- raise CustomException(msg=f'{func.__name__}执行失败: {str(e)}')
- return wrapper
- class GenTableService:
- """代码生成业务表服务层"""
- @classmethod
- @handle_service_exception
- async def get_gen_table_detail_service(cls, auth: AuthSchema, table_id: int) -> dict:
- """获取业务表详细信息(含字段与其他表列表)。
- - 备注:优先解析`options`为`GenTableOptionSchema`,设置`parent_menu_id`等选项;保证`columns`与`tables`结构完整。
- """
- gen_table = await cls.get_gen_table_by_id_service(auth, table_id)
- return GenTableOutSchema.model_validate(gen_table).model_dump()
- @classmethod
- @handle_service_exception
- async def get_gen_table_list_service(cls, auth: AuthSchema, search: GenTableQueryParam) -> list[dict]:
- """
- 获取代码生成业务表列表信息。
- 参数:
- - auth (AuthSchema): 认证信息。
- - search (GenTableQueryParam): 查询参数模型。
- 返回:
- - list[dict]: 包含业务表列表信息的字典列表。
- """
- gen_table_list_result = await GenTableCRUD(auth=auth).get_gen_table_list(search)
- return [GenTableOutSchema.model_validate(obj).model_dump() for obj in gen_table_list_result]
- @classmethod
- @handle_service_exception
- async def get_gen_db_table_list_service(cls, auth: AuthSchema, search: GenTableQueryParam) -> list[Any]:
- """获取数据库表列表(跨方言)。
- - 备注:返回已转换为字典的结构,适用于前端直接展示;排序参数保留扩展位但当前未使用。
- """
- gen_db_table_list_result = await GenTableCRUD(auth=auth).get_db_table_list(search)
- return gen_db_table_list_result
- @classmethod
- @handle_service_exception
- async def get_gen_db_table_list_by_name_service(cls, auth: AuthSchema, table_names: list[str]) -> list[GenTableOutSchema]:
- """根据表名称组获取数据库表信息。
- - 校验:如有不存在的表名,抛出明确异常;返回统一的`GenTableOutSchema`列表。
- """
- # 验证输入参数
- if not table_names:
- raise CustomException(msg="表名列表不能为空")
-
- gen_db_table_list_result = await GenTableCRUD(auth).get_db_table_list_by_names(table_names)
- # 修复:将GenDBTableSchema对象转换为字典后再传递给GenTableOutSchema
- result = []
- for gen_table in gen_db_table_list_result:
- # 确保table_name不为None
- if gen_table.table_name is not None:
- result.append(GenTableOutSchema(**gen_table.model_dump()))
-
- return result
- @classmethod
- @handle_service_exception
- async def import_gen_table_service(cls, auth: AuthSchema, gen_table_list: list[GenTableOutSchema]) -> bool | None:
- """导入表结构到生成器(持久化并初始化列)。
- - 备注:避免重复导入;为每列调用`GenUtils.init_column_field`填充默认属性,保留语义一致性。
- """
- # 检查是否有表需要导入
- if not gen_table_list:
- raise CustomException(msg="导入的表结构不能为空")
- try:
- for table in gen_table_list:
- table_name = table.table_name
- # 检查表是否已存在
- existing_table = await GenTableCRUD(auth).get_gen_table_by_name(table_name)
- if existing_table:
- raise CustomException(msg=f"以下表已存在,不能重复导入: {table_name}")
- GenUtils.init_table(table)
- if not table.columns:
- table.columns = []
- add_gen_table = await GenTableCRUD(auth).add_gen_table(GenTableSchema.model_validate(table.model_dump()))
- gen_table_columns = await GenTableColumnCRUD(auth).get_gen_db_table_columns_by_name(table_name)
- if len(gen_table_columns) > 0:
- table.id = add_gen_table.id
- for column in gen_table_columns:
- column_schema = GenTableColumnSchema(
- table_id=table.id,
- column_name=column.column_name,
- column_comment=column.column_comment,
- column_type=column.column_type,
- column_length=column.column_length,
- column_default=column.column_default,
- is_pk=column.is_pk,
- is_increment=column.is_increment,
- is_nullable=column.is_nullable,
- is_unique=column.is_unique,
- sort=column.sort,
- python_type=column.python_type,
- python_field=column.python_field,
- )
- GenUtils.init_column_field(column_schema, table)
- await GenTableColumnCRUD(auth).create_gen_table_column_crud(column_schema)
- return True
- except Exception as e:
- raise CustomException(msg=f'导入失败, {str(e)}')
- @classmethod
- @handle_service_exception
- async def create_table_service(cls, auth: AuthSchema, sql: str) -> bool | None:
- """创建表结构并导入至代码生成模块。
- - 校验:使用`sqlglot`确保仅包含`CREATE TABLE`语句;失败抛出明确异常。
- - 唯一性检查:在创建前检查该表是否已存在于数据库中。
- """
- # 验证SQL非空
- if not sql or not sql.strip():
- raise CustomException(msg='SQL语句不能为空')
-
- try:
- # 解析SQL语句
- sql_statements = sqlglot_parse(sql, dialect=settings.DATABASE_TYPE)
- if not sql_statements:
- raise CustomException(msg='无法解析SQL语句,请检查SQL语法')
-
- # 校验sql语句是否为合法的建表语句
- if not cls.__is_valid_create_table(sql_statements):
- raise CustomException(msg='sql语句不是合法的建表语句')
-
- # 获取要创建的表名
- table_names = cls.__get_table_names(sql_statements)
- # 创建CRUD实例
- gen_table_crud = GenTableCRUD(auth=auth)
-
- # 检查每个表是否已存在
- for table_name in table_names:
- # 检查数据库中是否已存在该表
- if await gen_table_crud.check_table_exists(table_name):
- raise CustomException(msg=f'表 {table_name} 已存在,请检查并修改表名后重试')
-
- # 检查代码生成模块中是否已导入该表
- existing_table = await gen_table_crud.get_gen_table_by_name(table_name)
- if existing_table:
- raise CustomException(msg=f'表 {table_name} 已在代码生成模块中存在,请检查并修改表名后重试')
-
- # 表不存在,执行SQL语句创建表
- result = await gen_table_crud.create_table_by_sql(sql_statements)
- if not result:
- raise CustomException(msg=f'创建表 {table_names} 失败,请检查SQL语句')
-
- # 导入表结构到代码生成模块 - 简化逻辑,移除多余的None检查
- gen_table_list = await cls.get_gen_db_table_list_by_name_service(auth, table_names)
-
- import_result = await cls.import_gen_table_service(auth, gen_table_list)
- return import_result
-
- except Exception as e:
- raise CustomException(msg=f'创建表结构失败: {str(e)}')
-
- @classmethod
- @handle_service_exception
- async def execute_sql_service(cls, auth: AuthSchema, gen_table: GenTableOutSchema) -> bool:
- """
- 执行菜单 SQL(INSERT / DO 块)并写入 sys_menu。
- - 仅处理菜单 SQL,不再混杂建表逻辑;
- - 文件不存在时给出友好提示;
- - 统一异常信息,日志与业务提示分离。
- """
- sql_path = f'{BASE_DIR}/sql/menu/{gen_table.module_name}/{gen_table.business_name}.sql'
- # 文件存在性前置检查,避免多余解析开销
- if not os.path.isfile(sql_path):
- raise CustomException(msg=f'菜单 SQL 文件不存在: {sql_path}')
- sql = Path(sql_path).read_text(encoding='utf-8').strip()
- if not sql:
- raise CustomException(msg='菜单 SQL 文件内容为空')
- # 仅做语法校验,不限制关键字;真正的语义安全由数据库权限控制
- try:
- statements = sqlglot_parse(sql, dialect=settings.DATABASE_TYPE)
- if not statements:
- raise CustomException(msg='菜单 SQL 语法解析失败,请检查文件内容')
- except Exception as e:
- log.error(f'菜单 SQL 解析异常: {e}')
- raise CustomException(msg='菜单 SQL 语法错误,请检查文件内容')
- # 执行 SQL
- try:
- await GenTableCRUD(auth).execute_sql(sql)
- log.info(f'成功执行菜单 SQL: {sql_path}')
- return True
- except Exception as e:
- log.error(f'菜单 SQL 执行失败: {e}')
- raise CustomException(msg='菜单 SQL 执行失败,请确认语句及数据库状态')
-
- @classmethod
- def __is_valid_create_table(cls, sql_statements: list[Expression | None]) -> bool:
- """
- 校验SQL语句是否为合法的建表语句。
-
- 参数:
- - sql_statements (list[Expression | None]): SQL的AST列表。
-
- 返回:
- - bool: 校验结果。
- """
- validate_create = [isinstance(sql_statement, Create) for sql_statement in sql_statements]
- validate_forbidden_keywords = [
- isinstance(
- sql_statement,
- (Add, Alter, Delete, Drop, Insert, TruncateTable, Update),
- )
- for sql_statement in sql_statements
- ]
- if not any(validate_create) or any(validate_forbidden_keywords):
- return False
- return True
-
- @classmethod
- def __get_table_names(cls, sql_statements: list[Expression | None]) -> list[str]:
- """
- 获取SQL语句中所有的建表表名。
-
- 参数:
- - sql_statements (list[Expression | None]): SQL的AST列表。
-
- 返回:
- - list[str]: 建表表名列表。
- """
- table_names = []
- for sql_statement in sql_statements:
- if isinstance(sql_statement, Create):
- table = sql_statement.find(Table)
- if table and table.name:
- table_names.append(table.name)
- return list(set(table_names))
- @classmethod
- @handle_service_exception
- async def update_gen_table_service(cls, auth: AuthSchema, data: GenTableSchema, table_id: int) -> dict[str, Any]:
- """编辑业务表信息(含选项与字段)。
- - 备注:将`params`序列化写入`options`以持久化;仅更新存在`id`的列,避免误创建。
- """
- # 处理params为None的情况
- gen_table_info = await cls.get_gen_table_by_id_service(auth, table_id)
- if gen_table_info.id:
- try:
- # 直接调用edit_gen_table方法,它会在内部处理排除嵌套字段的逻辑
- result = await GenTableCRUD(auth).edit_gen_table(table_id, data)
-
- # 处理data.columns为None的情况
- if data.columns:
- for gen_table_column in data.columns:
- # 确保column有id字段
- if hasattr(gen_table_column, 'id') and gen_table_column.id:
- column_schema = GenTableColumnSchema(**gen_table_column.model_dump())
- await GenTableColumnCRUD(auth).update_gen_table_column_crud(gen_table_column.id, column_schema)
- return GenTableOutSchema.model_validate(result).model_dump()
- except Exception as e:
- raise CustomException(msg=str(e))
- else:
- raise CustomException(msg='业务表不存在')
- @classmethod
- @handle_service_exception
- async def delete_gen_table_service(cls, auth: AuthSchema, ids: list[int]) -> None:
- """删除业务表信息(先删字段,再删表)。"""
- # 验证ID列表非空
- if not ids:
- raise CustomException(msg="ID列表不能为空")
-
- try:
- # 先删除相关的字段信息
- await GenTableColumnCRUD(auth=auth).delete_gen_table_column_by_table_id_crud(ids)
- # 再删除表信息
- await GenTableCRUD(auth=auth).delete_gen_table(ids)
- except Exception as e:
- raise CustomException(msg=str(e))
- @classmethod
- @handle_service_exception
- async def get_gen_table_by_id_service(cls, auth: AuthSchema, table_id: int) -> GenTableOutSchema:
- """获取需要生成代码的业务表详细信息。
- - 备注:去除SQLAlchemy内部状态;将`None`值转为适配前端的默认值;解析`options`补充选项。
- """
- gen_table = await GenTableCRUD(auth=auth).get_gen_table_by_id(table_id)
- if not gen_table:
- raise CustomException(msg='业务表不存在')
-
- result = GenTableOutSchema.model_validate(gen_table)
- return result
- @classmethod
- @handle_service_exception
- async def get_gen_table_all_service(cls, auth: AuthSchema) -> list[GenTableOutSchema]:
- """获取所有业务表信息(列表)。"""
- gen_table_all = await GenTableCRUD(auth=auth).get_gen_table_all() or []
- result = []
- for gen_table in gen_table_all:
- try:
- table_out = GenTableOutSchema.model_validate(gen_table)
- result.append(table_out)
- except Exception as e:
- log.error(f"转换业务表时出错: {str(e)}")
- continue
- return result
- @classmethod
- @handle_service_exception
- async def preview_code_service(cls, auth: AuthSchema, table_id: int) -> dict[str, Any]:
- """
- 预览代码(根据模板渲染内存结果)。
- - 备注:构建Jinja2上下文;根据模板类型与前端类型选择模板清单;返回文件名到内容映射。
- """
- gen_table = GenTableOutSchema.model_validate(
- await GenTableCRUD(auth).get_gen_table_by_id(table_id)
- )
- await cls.set_pk_column(gen_table)
- env = Jinja2TemplateUtil.get_env()
- context = Jinja2TemplateUtil.prepare_context(gen_table)
- template_list = Jinja2TemplateUtil.get_template_list()
- preview_code_result = {}
- for template in template_list:
- try:
- render_content = await env.get_template(template).render_async(**context)
- preview_code_result[template] = render_content
- except Exception as e:
- log.error(f"渲染模板 {template} 时出错: {str(e)}")
- # 即使某个模板渲染失败,也继续处理其他模板
- preview_code_result[template] = f"渲染错误: {str(e)}"
- return preview_code_result
- @classmethod
- @handle_service_exception
- async def generate_code_service(cls, auth: AuthSchema, table_name: str) -> bool:
- """生成代码至指定路径(安全写入+可跳过覆盖)。
- - 安全:限制写入在项目根目录内;越界路径自动回退到项目根目录。
- """
- # 验证表名非空
- if not table_name or not table_name.strip():
- raise CustomException(msg='表名不能为空')
-
- env = Jinja2TemplateUtil.get_env()
- render_info = await cls.__get_gen_render_info(auth, table_name)
- gen_table_schema = render_info[3]
- for template in render_info[0]:
- try:
- render_content = await env.get_template(template).render_async(**render_info[2])
- gen_path = cls.__get_gen_path(gen_table_schema, template)
- if not gen_path:
- raise CustomException(msg='【代码生成】生成路径为空')
- # 确保目录存在
- os.makedirs(os.path.dirname(gen_path), exist_ok=True)
- with open(gen_path, 'w', encoding='utf-8') as f:
- f.write(render_content)
- except Exception as e:
- raise CustomException(msg=f'渲染模板失败,表名:{gen_table_schema.table_name},详细错误信息:{str(e)}')
-
- await cls.execute_sql_service(auth, gen_table_schema)
- return True
- @classmethod
- @handle_service_exception
- async def batch_gen_code_service(cls, auth: AuthSchema, table_names: list[str]) -> bytes:
- """
- 批量生成代码并打包为ZIP。
- - 备注:内存生成并压缩,兼容多模板类型;供下载使用。
- """
- # 验证表名列表非空
- if not table_names:
- raise CustomException(msg="表名列表不能为空")
-
- zip_buffer = io.BytesIO()
- with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
- for table_name in table_names:
- if not table_name.strip():
- continue
-
- try:
- env = Jinja2TemplateUtil.get_env()
- render_info = await cls.__get_gen_render_info(auth, table_name)
- for template_file, output_file in zip(render_info[0], render_info[1]):
- render_content = await env.get_template(template_file).render_async(**render_info[2])
- zip_file.writestr(output_file, render_content)
- except Exception as e:
- log.error(f"批量生成代码时处理表 {table_name} 出错: {str(e)}")
- # 继续处理其他表,不中断整个过程
- continue
- zip_data = zip_buffer.getvalue()
- zip_buffer.close()
- return zip_data
- @classmethod
- @handle_service_exception
- async def sync_db_service(cls, auth: AuthSchema, table_name: str) -> None:
- """同步数据库表结构至生成器(保留用户配置)。
- - 备注:按数据库实际字段重建或更新生成器字段;保留字典/查询/展示等用户自定义属性;清理已删除字段。
- """
- # 验证表名非空
- if not table_name or not table_name.strip():
- raise CustomException(msg='表名不能为空')
-
- gen_table = await GenTableCRUD(auth).get_gen_table_by_name(table_name)
- if not gen_table:
- raise CustomException(msg='业务表不存在')
- table = GenTableOutSchema.model_validate(gen_table)
- if not table.id:
- raise CustomException(msg='业务表ID不能为空')
- table_columns = table.columns or []
- table_column_map = {column.column_name: column for column in table_columns}
- # 确保db_table_columns始终是列表类型,避免None值
- db_table_columns = await GenTableColumnCRUD(auth).get_gen_db_table_columns_by_name(table_name) or []
- db_table_columns = [col for col in db_table_columns if col is not None]
- db_table_column_names = [column.column_name for column in db_table_columns]
- try:
- for column in db_table_columns:
- # 仅在缺省时初始化默认属性(包含 table_id 关联)
- GenUtils.init_column_field(column, table)
- # 利用schema层的默认值,移除多余的None检查
- if column.column_name in table_column_map:
- prev_column = table_column_map[column.column_name]
- # 复用旧记录ID,确保执行更新
- if hasattr(prev_column, 'id') and prev_column.id:
- column.id = prev_column.id
- # 保留用户配置的显示与查询属性 - 使用getattr确保安全访问
- if hasattr(prev_column, 'dict_type') and prev_column.dict_type:
- column.dict_type = prev_column.dict_type
- if hasattr(prev_column, 'query_type') and prev_column.query_type:
- column.query_type = prev_column.query_type
- if hasattr(prev_column, 'html_type') and prev_column.html_type:
- column.html_type = prev_column.html_type
- # 保留关键用户自定义属性 - 安全处理is_pk
- is_pk_bool = False
- if hasattr(prev_column, 'is_pk'):
- # 处理不同类型的is_pk值
- if isinstance(prev_column.is_pk, bool):
- is_pk_bool = prev_column.is_pk
- else:
- is_pk_bool = str(prev_column.is_pk) == '1'
-
- # 安全处理nullable属性
- if hasattr(prev_column, 'is_nullable') and not is_pk_bool:
- column.is_nullable = prev_column.is_nullable
-
- # 保留其他重要用户设置
- if hasattr(prev_column, 'python_field'):
- column.python_field = prev_column.python_field or column.python_field
- if hasattr(column, 'id') and column.id:
- await GenTableColumnCRUD(auth).update_gen_table_column_crud(column.id, column)
- else:
- await GenTableColumnCRUD(auth).create_gen_table_column_crud(column)
- else:
- # 设置table_id以确保新字段能正确关联到表
- column.table_id = table.id
- await GenTableColumnCRUD(auth).create_gen_table_column_crud(column)
- del_columns = [column for column in table_columns if column.column_name not in db_table_column_names]
- if del_columns:
- for column in del_columns:
- if hasattr(column, 'id') and column.id:
- await GenTableColumnCRUD(auth).delete_gen_table_column_by_column_id_crud([column.id])
- except Exception as e:
- raise CustomException(msg=f'同步失败: {str(e)}')
- @classmethod
- async def set_pk_column(cls, gen_table: GenTableOutSchema) -> None:
- """设置主键列信息(主表/子表)。
- - 备注:同时兼容`pk`布尔与`is_pk == '1'`字符串两种标识。
- """
- if gen_table.columns:
- for column in gen_table.columns:
- # 修复:确保正确检查主键标识
- if getattr(column, 'pk', False) or getattr(column, 'is_pk', '') == '1':
- gen_table.pk_column = column
- break
- # 如果没有找到主键列且有列存在,使用第一个列作为主键
- if gen_table.pk_column is None and gen_table.columns:
- gen_table.pk_column = gen_table.columns[0]
- @classmethod
- async def __get_gen_render_info(cls, auth: AuthSchema, table_name: str) -> list[Any]:
- """
- 获取生成代码渲染模板相关信息。
-
- 参数:
- - auth (AuthSchema): 认证对象。
- - table_name (str): 业务表名称。
-
- 返回:
- - list[Any]: [模板列表, 输出文件名列表, 渲染上下文, 业务表对象]。
-
- 异常:
- - CustomException: 当业务表不存在或数据转换失败时抛出。
- """
- gen_table_model = await GenTableCRUD(auth=auth).get_gen_table_by_name(table_name)
- # 检查表是否存在
- if gen_table_model is None:
- raise CustomException(msg=f"业务表 {table_name} 不存在")
-
- gen_table = GenTableOutSchema.model_validate(gen_table_model)
- await cls.set_pk_column(gen_table)
- context = Jinja2TemplateUtil.prepare_context(gen_table)
- template_list = Jinja2TemplateUtil.get_template_list()
- output_files = [Jinja2TemplateUtil.get_file_name(template, gen_table) for template in template_list]
- return [template_list, output_files, context, gen_table]
- @classmethod
- def __get_gen_path(cls, gen_table: GenTableOutSchema, template: str) -> str | None:
- """根据GenTableOutSchema对象和模板名称生成路径。"""
- try:
- file_name = Jinja2TemplateUtil.get_file_name(template, gen_table)
- # 默认写入到项目根目录(backend的上一级)
- project_root = str(BASE_DIR.parent)
- full_path = os.path.join(project_root, file_name)
-
- # 确保路径在项目根目录内,防止路径遍历攻击
- if not os.path.abspath(full_path).startswith(os.path.abspath(project_root)):
- log.error(f"路径越界,回退到项目根目录: {file_name}")
- # 回退到项目根目录下的generated文件夹
- full_path = os.path.join(project_root, "generated", os.path.basename(file_name))
-
- return full_path
- except Exception as e:
- log.error(f"生成路径时出错: {str(e)}")
- return None
- class GenTableColumnService:
- """代码生成业务表字段服务层"""
- @classmethod
- @handle_service_exception
- async def get_gen_table_column_list_by_table_id_service(cls, auth: AuthSchema, table_id: int) -> list[dict[str, Any]]:
- """获取业务表字段列表信息(输出模型)。"""
- gen_table_column_list_result = await GenTableColumnCRUD(auth).list_gen_table_column_crud({"table_id": table_id})
- result = [GenTableColumnOutSchema.model_validate(gen_table_column).model_dump() for gen_table_column in gen_table_column_list_result]
- return result
|