feat(shop-role): 实现店铺角色继承功能和权限检查优化
All checks were successful
构建并部署到测试环境(无 SSH) / build-and-deploy (push) Successful in 6m39s
All checks were successful
构建并部署到测试环境(无 SSH) / build-and-deploy (push) Successful in 6m39s
- 新增店铺角色管理 API 和数据模型 - 实现角色继承和权限检查逻辑 - 添加流程测试框架和集成测试 - 更新权限服务和账号管理逻辑 - 添加数据库迁移脚本 - 归档 OpenSpec 变更文档 Ultraworked with Sisyphus
This commit is contained in:
17
flow_tests/core/__init__.py
Normal file
17
flow_tests/core/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from .client import APIClient, APIResponse
|
||||
from .auth import AuthManager
|
||||
from .database import Database
|
||||
from .cleanup import CleanupTracker
|
||||
from .mock import MockService
|
||||
from .wait import wait_for_task, wait_for_condition
|
||||
|
||||
__all__ = [
|
||||
"APIClient",
|
||||
"APIResponse",
|
||||
"AuthManager",
|
||||
"Database",
|
||||
"CleanupTracker",
|
||||
"MockService",
|
||||
"wait_for_task",
|
||||
"wait_for_condition",
|
||||
]
|
||||
71
flow_tests/core/auth.py
Normal file
71
flow_tests/core/auth.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from config.settings import settings
|
||||
from .client import APIClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthManager:
|
||||
def __init__(self, client: APIClient):
|
||||
self.client = client
|
||||
self._current_role: Optional[str] = None
|
||||
|
||||
@property
|
||||
def current_role(self) -> Optional[str]:
|
||||
return self._current_role
|
||||
|
||||
def login(self, username: str, password: str) -> bool:
|
||||
resp = self.client.login(username, password)
|
||||
if resp.ok():
|
||||
self._current_role = "custom"
|
||||
return True
|
||||
logger.error(f"登录失败: {resp.msg}")
|
||||
return False
|
||||
|
||||
def logout(self):
|
||||
self.client.clear_token()
|
||||
self._current_role = None
|
||||
|
||||
def _login_preset_account(self, role: str) -> bool:
|
||||
account = settings.get_account(role)
|
||||
if not account:
|
||||
raise ValueError(f"未配置 {role} 账号,请检查配置文件")
|
||||
|
||||
if self.login(account["username"], account["password"]):
|
||||
self._current_role = role
|
||||
return True
|
||||
return False
|
||||
|
||||
def as_super_admin(self) -> 'AuthManager':
|
||||
self._login_preset_account("super_admin")
|
||||
return self
|
||||
|
||||
def as_platform_admin(self) -> 'AuthManager':
|
||||
self._login_preset_account("platform_admin")
|
||||
return self
|
||||
|
||||
def as_agent(self, shop_id: int, username: Optional[str] = None, password: Optional[str] = None) -> 'AuthManager':
|
||||
if username and password:
|
||||
self.login(username, password)
|
||||
else:
|
||||
account = settings.get_account(f"agent_{shop_id}")
|
||||
if account:
|
||||
self.login(account["username"], account["password"])
|
||||
else:
|
||||
raise ValueError(f"未配置 agent_{shop_id} 账号,请提供用户名密码或在配置文件中添加")
|
||||
self._current_role = f"agent_{shop_id}"
|
||||
return self
|
||||
|
||||
def as_enterprise(self, enterprise_id: int, username: Optional[str] = None, password: Optional[str] = None) -> 'AuthManager':
|
||||
if username and password:
|
||||
self.login(username, password)
|
||||
else:
|
||||
account = settings.get_account(f"enterprise_{enterprise_id}")
|
||||
if account:
|
||||
self.login(account["username"], account["password"])
|
||||
else:
|
||||
raise ValueError(f"未配置 enterprise_{enterprise_id} 账号,请提供用户名密码或在配置文件中添加")
|
||||
self._current_role = f"enterprise_{enterprise_id}"
|
||||
return self
|
||||
113
flow_tests/core/cleanup.py
Normal file
113
flow_tests/core/cleanup.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from .database import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TABLE_DEPENDENCIES: Dict[str, List[str]] = {
|
||||
"tb_account": ["tb_shop", "tb_enterprise"],
|
||||
"tb_account_role": ["tb_account", "tb_role"],
|
||||
"tb_shop": [],
|
||||
"tb_enterprise": ["tb_shop"],
|
||||
"tb_role": [],
|
||||
"tb_permission": [],
|
||||
"tb_role_permission": ["tb_role", "tb_permission"],
|
||||
"tb_device": ["tb_shop"],
|
||||
"tb_iot_card": ["tb_shop", "tb_device"],
|
||||
"tb_package": [],
|
||||
"tb_package_series": [],
|
||||
"tb_order": ["tb_iot_card", "tb_package"],
|
||||
"tb_shop_package": ["tb_shop", "tb_package"],
|
||||
"tb_shop_package_series": ["tb_shop", "tb_package_series"],
|
||||
}
|
||||
|
||||
SOFT_DELETE_TABLES = {
|
||||
"tb_account", "tb_shop", "tb_enterprise", "tb_role",
|
||||
"tb_device", "tb_iot_card", "tb_package", "tb_order",
|
||||
}
|
||||
|
||||
|
||||
class CleanupTracker:
|
||||
def __init__(self, db: Optional[Database] = None):
|
||||
self.db = db or Database()
|
||||
self._tracked: Dict[str, List[int]] = defaultdict(list)
|
||||
self._tracked_queries: List[Tuple[str, str]] = []
|
||||
|
||||
def track(self, table: str, record_id: int):
|
||||
self._tracked[table].append(record_id)
|
||||
logger.debug(f"追踪: {table}#{record_id}")
|
||||
|
||||
def track_many(self, table: str, record_ids: List[int]):
|
||||
self._tracked[table].extend(record_ids)
|
||||
logger.debug(f"追踪: {table}#{record_ids}")
|
||||
|
||||
def track_by_query(self, table: str, where_clause: str):
|
||||
self._tracked_queries.append((table, where_clause))
|
||||
logger.debug(f"追踪查询: {table} WHERE {where_clause}")
|
||||
|
||||
def track_with_relations(self, table: str, record_id: int, relations: List[Tuple[str, str]]):
|
||||
self.track(table, record_id)
|
||||
for rel_table, fk_column in relations:
|
||||
ids = self.db.query(
|
||||
f"SELECT id FROM {rel_table} WHERE {fk_column} = %s",
|
||||
(record_id,)
|
||||
)
|
||||
for row in ids:
|
||||
self.track(rel_table, row["id"])
|
||||
|
||||
def cleanup(self):
|
||||
logger.info("开始清理测试数据...")
|
||||
|
||||
for table, where_clause in reversed(self._tracked_queries):
|
||||
self._delete_by_query(table, where_clause)
|
||||
|
||||
sorted_tables = self._sort_by_dependency()
|
||||
|
||||
for table in sorted_tables:
|
||||
ids = self._tracked.get(table, [])
|
||||
if ids:
|
||||
self._delete_records(table, ids)
|
||||
|
||||
logger.info("测试数据清理完成")
|
||||
|
||||
def _sort_by_dependency(self) -> List[str]:
|
||||
tables = list(self._tracked.keys())
|
||||
|
||||
def get_order(t: str) -> int:
|
||||
deps = TABLE_DEPENDENCIES.get(t, [])
|
||||
if not deps:
|
||||
return 0
|
||||
return max(get_order(d) for d in deps if d in tables) + 1 if any(d in tables for d in deps) else 0
|
||||
|
||||
return sorted(tables, key=get_order, reverse=True)
|
||||
|
||||
def _delete_records(self, table: str, ids: List[int]):
|
||||
if not ids:
|
||||
return
|
||||
|
||||
placeholders = ",".join(["%s"] * len(ids))
|
||||
|
||||
if table in SOFT_DELETE_TABLES:
|
||||
sql = f"UPDATE {table} SET deleted_at = NOW() WHERE id IN ({placeholders}) AND deleted_at IS NULL"
|
||||
else:
|
||||
sql = f"DELETE FROM {table} WHERE id IN ({placeholders})"
|
||||
|
||||
try:
|
||||
count = self.db.execute(sql, tuple(ids))
|
||||
logger.info(f"清理 {table}: {count} 条记录")
|
||||
except Exception as e:
|
||||
logger.error(f"清理 {table} 失败: {e}")
|
||||
|
||||
def _delete_by_query(self, table: str, where_clause: str):
|
||||
if table in SOFT_DELETE_TABLES:
|
||||
sql = f"UPDATE {table} SET deleted_at = NOW() WHERE {where_clause} AND deleted_at IS NULL"
|
||||
else:
|
||||
sql = f"DELETE FROM {table} WHERE {where_clause}"
|
||||
|
||||
try:
|
||||
count = self.db.execute(sql)
|
||||
logger.info(f"清理 {table} (查询): {count} 条记录")
|
||||
except Exception as e:
|
||||
logger.error(f"清理 {table} (查询) 失败: {e}")
|
||||
100
flow_tests/core/client.py
Normal file
100
flow_tests/core/client.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from config.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class APIResponse:
|
||||
status_code: int
|
||||
code: int
|
||||
msg: str
|
||||
data: Any
|
||||
raw: dict
|
||||
|
||||
def ok(self) -> bool:
|
||||
return self.code == 0
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return self.ok()
|
||||
|
||||
|
||||
class APIClient:
|
||||
def __init__(self, base_url: Optional[str] = None):
|
||||
self.base_url = base_url or settings.api_base_url
|
||||
self.timeout = settings.api_timeout
|
||||
self.token: Optional[str] = None
|
||||
self.session = requests.Session()
|
||||
|
||||
def set_token(self, token: str):
|
||||
self.token = token
|
||||
self.session.headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
def clear_token(self):
|
||||
self.token = None
|
||||
self.session.headers.pop("Authorization", None)
|
||||
|
||||
def _request(self, method: str, path: str, **kwargs) -> APIResponse:
|
||||
url = f"{self.base_url}{path}"
|
||||
kwargs.setdefault("timeout", self.timeout)
|
||||
logger.info(f"{method} {path}")
|
||||
|
||||
try:
|
||||
resp = self.session.request(method, url, **kwargs)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"请求失败: {e}")
|
||||
return APIResponse(status_code=0, code=-1, msg=str(e), data=None, raw={})
|
||||
|
||||
try:
|
||||
raw = resp.json()
|
||||
except ValueError:
|
||||
return APIResponse(
|
||||
status_code=resp.status_code, code=-1,
|
||||
msg="响应不是有效的 JSON", data=None, raw={}
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
status_code=resp.status_code,
|
||||
code=raw.get("code", -1),
|
||||
msg=raw.get("msg", ""),
|
||||
data=raw.get("data"),
|
||||
raw=raw,
|
||||
)
|
||||
|
||||
def get(self, path: str, params: Optional[dict] = None, **kwargs) -> APIResponse:
|
||||
return self._request("GET", path, params=params, **kwargs)
|
||||
|
||||
def post(self, path: str, json: Optional[dict] = None, **kwargs) -> APIResponse:
|
||||
return self._request("POST", path, json=json, **kwargs)
|
||||
|
||||
def put(self, path: str, json: Optional[dict] = None, **kwargs) -> APIResponse:
|
||||
return self._request("PUT", path, json=json, **kwargs)
|
||||
|
||||
def delete(self, path: str, **kwargs) -> APIResponse:
|
||||
return self._request("DELETE", path, **kwargs)
|
||||
|
||||
def patch(self, path: str, json: Optional[dict] = None, **kwargs) -> APIResponse:
|
||||
return self._request("PATCH", path, json=json, **kwargs)
|
||||
|
||||
def upload(self, path: str, file, field_name: str = "file", **kwargs) -> APIResponse:
|
||||
files = {field_name: file}
|
||||
return self._request("POST", path, files=files, **kwargs)
|
||||
|
||||
def login(self, username: str, password: str, login_path: str = "/api/admin/auth/login") -> APIResponse:
|
||||
resp = self.post(login_path, json={
|
||||
"username": username,
|
||||
"password": password,
|
||||
})
|
||||
|
||||
if resp.ok() and resp.data:
|
||||
token = resp.data.get("token") or resp.data.get("access_token")
|
||||
if token:
|
||||
self.set_token(token)
|
||||
logger.info(f"登录成功: {username}")
|
||||
|
||||
return resp
|
||||
69
flow_tests/core/database.py
Normal file
69
flow_tests/core/database.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import logging
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extras import RealDictCursor
|
||||
|
||||
from config.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Database:
|
||||
_instance: Optional['Database'] = None
|
||||
_conn = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._connect()
|
||||
return cls._instance
|
||||
|
||||
def _connect(self):
|
||||
config = settings.db_config
|
||||
self._conn = psycopg2.connect(
|
||||
host=config["host"],
|
||||
port=config["port"],
|
||||
database=config["database"],
|
||||
user=config["user"],
|
||||
password=config["password"],
|
||||
cursor_factory=RealDictCursor,
|
||||
)
|
||||
self._conn.autocommit = True
|
||||
logger.info(f"数据库连接成功: {config['host']}:{config['port']}/{config['database']}")
|
||||
|
||||
def query(self, sql: str, params: tuple = ()) -> List[dict]:
|
||||
with self._conn.cursor() as cur:
|
||||
cur.execute(sql, params)
|
||||
return cur.fetchall()
|
||||
|
||||
def query_one(self, sql: str, params: tuple = ()) -> Optional[dict]:
|
||||
rows = self.query(sql, params)
|
||||
return rows[0] if rows else None
|
||||
|
||||
def scalar(self, sql: str, params: tuple = ()) -> Any:
|
||||
with self._conn.cursor() as cur:
|
||||
cur.execute(sql, params)
|
||||
row = cur.fetchone()
|
||||
if row:
|
||||
return list(row.values())[0]
|
||||
return None
|
||||
|
||||
def execute(self, sql: str, params: tuple = ()) -> int:
|
||||
with self._conn.cursor() as cur:
|
||||
cur.execute(sql, params)
|
||||
return cur.rowcount
|
||||
|
||||
def execute_many(self, sql: str, params_list: List[tuple]) -> int:
|
||||
with self._conn.cursor() as cur:
|
||||
cur.executemany(sql, params_list)
|
||||
return cur.rowcount
|
||||
|
||||
def close(self):
|
||||
if self._conn:
|
||||
self._conn.close()
|
||||
logger.info("数据库连接已关闭")
|
||||
|
||||
|
||||
def get_db() -> Database:
|
||||
return Database()
|
||||
74
flow_tests/core/mock.py
Normal file
74
flow_tests/core/mock.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
import redis
|
||||
|
||||
from config.settings import settings
|
||||
from .database import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MockService:
|
||||
def __init__(self, db: Optional[Database] = None):
|
||||
self.db = db or Database()
|
||||
self._init_redis()
|
||||
|
||||
def _init_redis(self):
|
||||
config = settings.redis_config
|
||||
self.redis = redis.Redis(
|
||||
host=config["host"],
|
||||
port=config["port"],
|
||||
db=config["db"],
|
||||
decode_responses=True,
|
||||
)
|
||||
|
||||
def payment_success(self, order_id: int, amount: float, delay: float = 0):
|
||||
if delay > 0:
|
||||
time.sleep(delay)
|
||||
|
||||
self.db.execute(
|
||||
"UPDATE tb_order SET status = %s, paid_at = NOW(), paid_amount = %s WHERE id = %s",
|
||||
("paid", int(amount * 100), order_id)
|
||||
)
|
||||
logger.info(f"模拟支付成功: order_id={order_id}, amount={amount}")
|
||||
|
||||
def payment_failed(self, order_id: int, reason: str = "支付失败"):
|
||||
self.db.execute(
|
||||
"UPDATE tb_order SET status = %s, fail_reason = %s WHERE id = %s",
|
||||
("failed", reason, order_id)
|
||||
)
|
||||
logger.info(f"模拟支付失败: order_id={order_id}, reason={reason}")
|
||||
|
||||
def sms_code(self, phone: str, code: str, expire_seconds: int = 300):
|
||||
key = f"sms:code:{phone}"
|
||||
self.redis.setex(key, expire_seconds, code)
|
||||
logger.info(f"模拟短信验证码: phone={phone}, code={code}")
|
||||
|
||||
def task_complete(self, task_type: str, task_id: int, result: Any = None):
|
||||
self.db.execute(
|
||||
"UPDATE tb_async_task SET status = %s, result = %s, completed_at = NOW() WHERE task_type = %s AND id = %s",
|
||||
("completed", str(result) if result else None, task_type, task_id)
|
||||
)
|
||||
logger.info(f"模拟任务完成: {task_type}#{task_id}")
|
||||
|
||||
def task_failed(self, task_type: str, task_id: int, error: str):
|
||||
self.db.execute(
|
||||
"UPDATE tb_async_task SET status = %s, error = %s, completed_at = NOW() WHERE task_type = %s AND id = %s",
|
||||
("failed", error, task_type, task_id)
|
||||
)
|
||||
logger.info(f"模拟任务失败: {task_type}#{task_id}, error={error}")
|
||||
|
||||
def card_data_balance(self, card_id: int, balance_mb: int):
|
||||
self.db.execute(
|
||||
"UPDATE tb_iot_card SET data_balance = %s WHERE id = %s",
|
||||
(balance_mb, card_id)
|
||||
)
|
||||
logger.info(f"模拟卡片流量: card_id={card_id}, balance={balance_mb}MB")
|
||||
|
||||
def external_api_response(self, api_name: str, response: dict):
|
||||
key = f"mock:api:{api_name}"
|
||||
import json
|
||||
self.redis.setex(key, 300, json.dumps(response))
|
||||
logger.info(f"模拟外部 API: {api_name}")
|
||||
99
flow_tests/core/wait.py
Normal file
99
flow_tests/core/wait.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from .database import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def wait_for_condition(
|
||||
condition: Callable[[], bool],
|
||||
timeout: float = 30,
|
||||
poll_interval: float = 1,
|
||||
message: str = "等待条件满足",
|
||||
) -> bool:
|
||||
start = time.time()
|
||||
|
||||
while time.time() - start < timeout:
|
||||
try:
|
||||
if condition():
|
||||
logger.info(f"{message}: 成功 (耗时 {time.time() - start:.1f}s)")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug(f"{message}: 检查失败 - {e}")
|
||||
|
||||
time.sleep(poll_interval)
|
||||
|
||||
raise TimeoutError(f"{message}: 超时 ({timeout}s)")
|
||||
|
||||
|
||||
def wait_for_task(
|
||||
task_type: str,
|
||||
task_id: int,
|
||||
timeout: float = 60,
|
||||
poll_interval: float = 2,
|
||||
db: Optional[Database] = None,
|
||||
) -> dict:
|
||||
db = db or Database()
|
||||
start = time.time()
|
||||
|
||||
while time.time() - start < timeout:
|
||||
row = db.query_one(
|
||||
"SELECT status, result, error FROM tb_async_task WHERE task_type = %s AND id = %s",
|
||||
(task_type, task_id)
|
||||
)
|
||||
|
||||
if not row:
|
||||
raise ValueError(f"任务不存在: {task_type}#{task_id}")
|
||||
|
||||
if row["status"] in ("completed", "failed"):
|
||||
logger.info(f"任务完成: {task_type}#{task_id}, status={row['status']}, 耗时 {time.time() - start:.1f}s")
|
||||
return dict(row)
|
||||
|
||||
logger.debug(f"等待任务: {task_type}#{task_id}, 当前状态={row['status']}")
|
||||
time.sleep(poll_interval)
|
||||
|
||||
raise TimeoutError(f"任务超时: {task_type}#{task_id} ({timeout}s)")
|
||||
|
||||
|
||||
def wait_for_db_condition(
|
||||
sql: str,
|
||||
params: tuple = (),
|
||||
expected: Any = True,
|
||||
timeout: float = 30,
|
||||
poll_interval: float = 1,
|
||||
db: Optional[Database] = None,
|
||||
) -> Any:
|
||||
db = db or Database()
|
||||
|
||||
def check():
|
||||
result = db.scalar(sql, params)
|
||||
if callable(expected):
|
||||
return expected(result)
|
||||
return result == expected
|
||||
|
||||
wait_for_condition(check, timeout, poll_interval, f"等待 SQL 条件: {sql[:50]}...")
|
||||
return db.scalar(sql, params)
|
||||
|
||||
|
||||
def wait_for_record_count(
|
||||
table: str,
|
||||
where_clause: str,
|
||||
expected_count: int,
|
||||
timeout: float = 30,
|
||||
db: Optional[Database] = None,
|
||||
) -> int:
|
||||
db = db or Database()
|
||||
sql = f"SELECT COUNT(*) FROM {table} WHERE {where_clause}"
|
||||
|
||||
def check():
|
||||
count = db.scalar(sql)
|
||||
return count >= expected_count
|
||||
|
||||
wait_for_condition(check, timeout, 1, f"等待 {table} 记录数 >= {expected_count}")
|
||||
return db.scalar(sql)
|
||||
Reference in New Issue
Block a user