Enhance interview prompt handling and update README.md

- Introduced a prefix to optimize interview prompts, ensuring agents respond directly with text without invoking tools.
- Updated the simulation API to utilize the optimized prompts for individual and batch interviews.
- Modified the `get_interview_history` function to allow for flexible platform querying, returning results from both Reddit and Twitter when no platform is specified.
- Enhanced README.md to include new prompt optimization details and updated API usage examples for clarity.
This commit is contained in:
666ghj
2025-12-08 16:08:33 +08:00
parent 1042d50306
commit 1f191cb21e
3 changed files with 120 additions and 33 deletions

View File

@@ -1279,30 +1279,16 @@ class SimulationRunner:
}
@classmethod
def get_interview_history(
def _get_interview_history_from_db(
cls,
simulation_id: str,
platform: str = "reddit",
db_path: str,
platform_name: str,
agent_id: Optional[int] = None,
limit: int = 100
) -> List[Dict[str, Any]]:
"""
获取Interview历史记录从数据库读取
Args:
simulation_id: 模拟ID
platform: 平台类型reddit/twitter
agent_id: 过滤Agent ID可选
limit: 返回数量限制
Returns:
Interview历史记录列表
"""
"""从单个数据库获取Interview历史"""
import sqlite3
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
db_path = os.path.join(sim_dir, f"{platform}_simulation.db")
if not os.path.exists(db_path):
return []
@@ -1312,8 +1298,6 @@ class SimulationRunner:
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
# 构建查询
# 注意ActionType.INTERVIEW.value 应该是字符串形式
if agent_id is not None:
cursor.execute("""
SELECT user_id, info, created_at
@@ -1342,13 +1326,66 @@ class SimulationRunner:
"response": info.get("response", info),
"prompt": info.get("prompt", ""),
"timestamp": created_at,
"platform": platform
"platform": platform_name
})
conn.close()
except Exception as e:
logger.error(f"读取Interview历史失败: {e}")
logger.error(f"读取Interview历史失败 ({platform_name}): {e}")
return results
@classmethod
def get_interview_history(
cls,
simulation_id: str,
platform: str = None,
agent_id: Optional[int] = None,
limit: int = 100
) -> List[Dict[str, Any]]:
"""
获取Interview历史记录从数据库读取
Args:
simulation_id: 模拟ID
platform: 平台类型reddit/twitter/None
- "reddit": 只获取Reddit平台的历史
- "twitter": 只获取Twitter平台的历史
- None: 获取两个平台的所有历史
agent_id: 指定Agent ID可选只获取该Agent的历史
limit: 每个平台返回数量限制
Returns:
Interview历史记录列表
"""
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
results = []
# 确定要查询的平台
if platform in ("reddit", "twitter"):
platforms = [platform]
else:
# 不指定platform时查询两个平台
platforms = ["twitter", "reddit"]
for p in platforms:
db_path = os.path.join(sim_dir, f"{p}_simulation.db")
platform_results = cls._get_interview_history_from_db(
db_path=db_path,
platform_name=p,
agent_id=agent_id,
limit=limit
)
results.extend(platform_results)
# 按时间降序排序
results.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
# 如果查询了多个平台,限制总数
if len(platforms) > 1 and len(results) > limit:
results = results[:limit]
return results