import logging
import time
from typing import Any, Callable, Optional

from .database import Database

logger = logging.getLogger(__name__)


class TimeoutError(Exception):
    """Raised when a wait helper in this module exceeds its timeout.

    NOTE: this intentionally keeps the original name, which shadows the
    builtin ``TimeoutError`` inside this module. It derives from
    ``Exception`` only, so ``except OSError`` will not catch it.
    """


def wait_for_condition(
    condition: Callable[[], bool],
    timeout: float = 30,
    poll_interval: float = 1,
    message: str = "等待条件满足",
) -> bool:
    """Poll ``condition`` until it returns a truthy value or ``timeout`` elapses.

    The condition is always evaluated at least once, and one last time when
    the deadline is reached, so a condition that becomes true during the
    final sleep is still observed.

    Args:
        condition: Zero-argument callable; a truthy result ends the wait.
        timeout: Maximum number of seconds to wait.
        poll_interval: Seconds to sleep between two evaluations.
        message: Prefix used in log lines and in the timeout error.

    Returns:
        ``True`` once the condition holds.

    Raises:
        TimeoutError: The condition did not hold within ``timeout`` seconds.
    """
    # Monotonic clock: immune to wall-clock adjustments (NTP, DST, manual set).
    start = time.monotonic()
    deadline = start + timeout
    while True:
        try:
            satisfied = condition()
        except Exception as e:
            # Deliberate best-effort: a failing check counts as "not yet".
            logger.debug(f"{message}: 检查失败 - {e}")
            satisfied = False
        if satisfied:
            logger.info(f"{message}: 成功 (耗时 {time.monotonic() - start:.1f}s)")
            return True

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline.
        time.sleep(min(poll_interval, remaining))
    raise TimeoutError(f"{message}: 超时 ({timeout}s)")


def wait_for_task(
    task_type: str,
    task_id: int,
    timeout: float = 60,
    poll_interval: float = 2,
    db: Optional[Database] = None,
) -> dict:
    """Poll ``tb_async_task`` until the task reaches a terminal status.

    Args:
        task_type: Value matched against the ``task_type`` column.
        task_id: Value matched against the ``id`` column.
        timeout: Maximum number of seconds to wait.
        poll_interval: Seconds to sleep between two queries.
        db: Database handle to use; a new ``Database()`` is created if omitted.

    Returns:
        The row (``status``, ``result``, ``error``) as a plain dict once the
        status is ``"completed"`` or ``"failed"``.

    Raises:
        ValueError: No row matches ``task_type`` / ``task_id``.
        TimeoutError: The task did not finish within ``timeout`` seconds.
    """
    # ``is None`` rather than ``or``: a caller-supplied handle must be used
    # even if the object happens to evaluate as falsy.
    db = Database() if db is None else db
    start = time.monotonic()
    deadline = start + timeout
    while True:
        row = db.query_one(
            "SELECT status, result, error FROM tb_async_task WHERE task_type = %s AND id = %s",
            (task_type, task_id),
        )
        if not row:
            raise ValueError(f"任务不存在: {task_type}#{task_id}")
        if row["status"] in ("completed", "failed"):
            elapsed = time.monotonic() - start
            logger.info(
                f"任务完成: {task_type}#{task_id}, status={row['status']}, 耗时 {elapsed:.1f}s"
            )
            return dict(row)
        logger.debug(f"等待任务: {task_type}#{task_id}, 当前状态={row['status']}")

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline.
        time.sleep(min(poll_interval, remaining))
    raise TimeoutError(f"任务超时: {task_type}#{task_id} ({timeout}s)")


def wait_for_db_condition(
    sql: str,
    params: tuple = (),
    expected: Any = True,
    timeout: float = 30,
    poll_interval: float = 1,
    db: Optional[Database] = None,
) -> Any:
    """Poll a scalar SQL query until its result matches ``expected``.

    Args:
        sql: Query passed to ``db.scalar``.
        params: Parameters passed alongside ``sql``.
        expected: Either a value compared with ``==`` against the query
            result, or a callable that receives the result and returns a
            truthy value when the condition is met.
        timeout: Maximum number of seconds to wait.
        poll_interval: Seconds to sleep between two queries.
        db: Database handle to use; a new ``Database()`` is created if omitted.

    Returns:
        The query result that satisfied ``expected``.

    Raises:
        TimeoutError: The condition did not hold within ``timeout`` seconds.
    """
    db = Database() if db is None else db
    observed: Any = None

    def check() -> bool:
        nonlocal observed
        observed = db.scalar(sql, params)
        if callable(expected):
            return expected(observed)
        return observed == expected

    wait_for_condition(check, timeout, poll_interval, f"等待 SQL 条件: {sql[:50]}...")
    # Return the value that actually passed the check instead of re-querying:
    # a second query could observe a different (non-matching) value.
    return observed


def wait_for_record_count(
    table: str,
    where_clause: str,
    expected_count: int,
    timeout: float = 30,
    db: Optional[Database] = None,
    poll_interval: float = 1,
) -> int:
    """Poll until ``table`` holds at least ``expected_count`` matching rows.

    Args:
        table: Table name, interpolated into the SQL text as-is.
        where_clause: WHERE condition, interpolated into the SQL text as-is.
        expected_count: Minimum row count that ends the wait.
        timeout: Maximum number of seconds to wait.
        db: Database handle to use; a new ``Database()`` is created if omitted.
        poll_interval: Seconds to sleep between two queries.

    Returns:
        The row count that satisfied the condition.

    Raises:
        TimeoutError: The count stayed below ``expected_count`` for
            ``timeout`` seconds.
    """
    db = Database() if db is None else db
    # SECURITY: ``table`` and ``where_clause`` are interpolated verbatim
    # (identifiers cannot be bound as parameters). Only pass trusted,
    # code-defined values here - never user input.
    sql = f"SELECT COUNT(*) FROM {table} WHERE {where_clause}"
    observed: Any = None

    def check() -> bool:
        nonlocal observed
        observed = db.scalar(sql)
        return observed >= expected_count

    wait_for_condition(
        check, timeout, poll_interval, f"等待 {table} 记录数 >= {expected_count}"
    )
    # Return the count that passed the check; re-querying could race with
    # concurrent deletes and report a smaller number.
    return observed