import logging
from collections import defaultdict
from typing import Dict, List, Optional, Set, Tuple

from .database import Database

logger = logging.getLogger(__name__)

# Maps each table to the tables it depends on (presumably through foreign
# keys -- confirm against the schema).  ``CleanupTracker`` cleans a tracked
# table before any tracked table it depends on, i.e. dependents go first.
TABLE_DEPENDENCIES: Dict[str, List[str]] = {
    "tb_account": ["tb_shop", "tb_enterprise"],
    "tb_account_role": ["tb_account", "tb_role"],
    "tb_shop": [],
    "tb_enterprise": ["tb_shop"],
    "tb_role": [],
    "tb_permission": [],
    "tb_role_permission": ["tb_role", "tb_permission"],
    "tb_device": ["tb_shop"],
    "tb_iot_card": ["tb_shop", "tb_device"],
    "tb_package": [],
    "tb_package_series": [],
    "tb_order": ["tb_iot_card", "tb_package"],
    "tb_shop_package": ["tb_shop", "tb_package"],
    "tb_shop_package_series": ["tb_shop", "tb_package_series"],
}

# Tables cleaned by stamping ``deleted_at = NOW()`` instead of issuing DELETE.
SOFT_DELETE_TABLES = {
    "tb_account",
    "tb_shop",
    "tb_enterprise",
    "tb_role",
    "tb_device",
    "tb_iot_card",
    "tb_package",
    "tb_order",
}


class CleanupTracker:
    """Record rows created during a test and remove them afterwards.

    Rows are registered either by id (``track`` / ``track_many`` /
    ``track_with_relations``) or by a raw SQL ``WHERE`` clause
    (``track_by_query``).  ``cleanup`` then removes them: tables listed in
    ``SOFT_DELETE_TABLES`` are soft-deleted by setting ``deleted_at``; every
    other table is hard-deleted.

    NOTE: table names, column names and ``where_clause`` strings are
    interpolated directly into SQL text.  They must come from trusted test
    code, never from external input.
    """

    def __init__(self, db: Optional[Database] = None):
        """Create a tracker.

        Args:
            db: Database handle used for lookups and deletes.  A new
                ``Database()`` is constructed when ``None`` is passed.
        """
        # ``is not None`` so a handle that happens to be falsy is still used.
        self.db = db if db is not None else Database()
        # table name -> ids registered for cleanup, in registration order
        self._tracked: Dict[str, List[int]] = defaultdict(list)
        # (table, where_clause) pairs registered via ``track_by_query``
        self._tracked_queries: List[Tuple[str, str]] = []

    def track(self, table: str, record_id: int) -> None:
        """Register the single row ``table#record_id`` for cleanup."""
        self._tracked[table].append(record_id)
        logger.debug("追踪: %s#%s", table, record_id)

    def track_many(self, table: str, record_ids: List[int]) -> None:
        """Register several rows of ``table`` for cleanup."""
        self._tracked[table].extend(record_ids)
        logger.debug("追踪: %s#%s", table, record_ids)

    def track_by_query(self, table: str, where_clause: str) -> None:
        """Register every row of ``table`` matching a raw SQL ``WHERE`` clause.

        The clause is only stored here; it is evaluated when ``cleanup`` runs,
        so it also matches rows created after this call.
        """
        self._tracked_queries.append((table, where_clause))
        logger.debug("追踪查询: %s WHERE %s", table, where_clause)

    def track_with_relations(
        self, table: str, record_id: int, relations: List[Tuple[str, str]]
    ) -> None:
        """Register a row plus the rows that reference it right now.

        Args:
            table: Table holding the parent row.
            record_id: Id of the parent row.
            relations: ``(related_table, fk_column)`` pairs.  Every row of
                ``related_table`` whose ``fk_column`` equals ``record_id`` at
                call time is tracked too (one level only, not recursive).
        """
        self.track(table, record_id)
        for rel_table, fk_column in relations:
            rows = self.db.query(
                f"SELECT id FROM {rel_table} WHERE {fk_column} = %s",
                (record_id,),
            )
            # presumably ``query`` returns mapping-like rows keyed by column
            for row in rows:
                self.track(rel_table, row["id"])

    def cleanup(self) -> None:
        """Remove all tracked data on a best-effort basis.

        Query-tracked rows are removed first, most recently registered clause
        first.  Id-tracked rows are removed next, table by table, in
        dependency order.  Database errors are logged per table and do not
        stop the remaining work.
        """
        logger.info("开始清理测试数据...")
        for table, where_clause in reversed(self._tracked_queries):
            self._delete_by_query(table, where_clause)
        for table in self._sort_by_dependency():
            ids = self._tracked.get(table, [])
            if ids:
                self._delete_records(table, ids)
        logger.info("测试数据清理完成")

    def _sort_by_dependency(self) -> List[str]:
        """Return the tracked tables ordered so dependents precede dependencies.

        A table's depth is 0 when none of its ``TABLE_DEPENDENCIES`` entries
        are tracked; otherwise it is one more than the deepest tracked
        dependency.  Tables are returned deepest first, and ties keep their
        registration order (``sorted`` is stable).

        Raises:
            ValueError: if the tracked tables form a dependency cycle.
        """
        tables = list(self._tracked)
        tracked = set(tables)
        depth: Dict[str, int] = {}
        visiting: Set[str] = set()

        def get_depth(name: str) -> int:
            if name in depth:
                return depth[name]
            if name in visiting:
                # Without this guard a cycle recursed until RecursionError.
                raise ValueError(
                    f"cyclic dependency in TABLE_DEPENDENCIES involving {name!r}"
                )
            visiting.add(name)
            tracked_deps = [
                dep for dep in TABLE_DEPENDENCIES.get(name, []) if dep in tracked
            ]
            # default=-1 gives depth 0 when no dependency is tracked
            result = max((get_depth(dep) for dep in tracked_deps), default=-1) + 1
            visiting.discard(name)
            depth[name] = result
            return result

        return sorted(tables, key=get_depth, reverse=True)

    def _delete_records(self, table: str, ids: List[int]) -> None:
        """Soft- or hard-delete the rows of ``table`` whose id is in ``ids``.

        Errors are logged and swallowed so one failing table does not abort
        the rest of ``cleanup``.
        """
        if not ids:
            return
        # Drop ids that were tracked more than once, keeping first-seen order.
        unique_ids = tuple(dict.fromkeys(ids))
        placeholders = ",".join(["%s"] * len(unique_ids))
        if table in SOFT_DELETE_TABLES:
            sql = (
                f"UPDATE {table} SET deleted_at = NOW() "
                f"WHERE id IN ({placeholders}) AND deleted_at IS NULL"
            )
        else:
            sql = f"DELETE FROM {table} WHERE id IN ({placeholders})"
        try:
            count = self.db.execute(sql, unique_ids)
        except Exception as e:  # deliberate best-effort cleanup
            logger.error("清理 %s 失败: %s", table, e)
        else:
            # presumably ``execute`` returns the affected row count
            logger.info("清理 %s: %s 条记录", table, count)

    def _delete_by_query(self, table: str, where_clause: str) -> None:
        """Soft- or hard-delete the rows of ``table`` matching ``where_clause``.

        Errors are logged and swallowed so one failing query does not abort
        the rest of ``cleanup``.
        """
        if table in SOFT_DELETE_TABLES:
            # Parenthesize the caller's clause.  Otherwise a clause containing
            # OR binds looser than the appended AND, and rows that were
            # already soft-deleted get their deleted_at overwritten.
            sql = (
                f"UPDATE {table} SET deleted_at = NOW() "
                f"WHERE ({where_clause}) AND deleted_at IS NULL"
            )
        else:
            sql = f"DELETE FROM {table} WHERE {where_clause}"
        try:
            count = self.db.execute(sql)
        except Exception as e:  # deliberate best-effort cleanup
            logger.error("清理 %s (查询) 失败: %s", table, e)
        else:
            logger.info("清理 %s (查询): %s 条记录", table, count)