Enhance interview prompt handling and update README.md
- Introduced a prefix to optimize interview prompts, ensuring agents respond directly with text without invoking tools. - Updated the simulation API to utilize the optimized prompts for individual and batch interviews. - Modified the `get_interview_history` function to allow for flexible platform querying, returning results from both Reddit and Twitter when no platform is specified. - Enhanced README.md to include new prompt optimization details and updated API usage examples for clarity.
This commit is contained in:
@@ -1279,30 +1279,16 @@ class SimulationRunner:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_interview_history(
|
||||
def _get_interview_history_from_db(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
platform: str = "reddit",
|
||||
db_path: str,
|
||||
platform_name: str,
|
||||
agent_id: Optional[int] = None,
|
||||
limit: int = 100
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取Interview历史记录(从数据库读取)
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
platform: 平台类型(reddit/twitter)
|
||||
agent_id: 过滤Agent ID(可选)
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
Interview历史记录列表
|
||||
"""
|
||||
"""从单个数据库获取Interview历史"""
|
||||
import sqlite3
|
||||
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
db_path = os.path.join(sim_dir, f"{platform}_simulation.db")
|
||||
|
||||
if not os.path.exists(db_path):
|
||||
return []
|
||||
|
||||
@@ -1312,8 +1298,6 @@ class SimulationRunner:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 构建查询
|
||||
# 注意:ActionType.INTERVIEW.value 应该是字符串形式
|
||||
if agent_id is not None:
|
||||
cursor.execute("""
|
||||
SELECT user_id, info, created_at
|
||||
@@ -1342,13 +1326,66 @@ class SimulationRunner:
|
||||
"response": info.get("response", info),
|
||||
"prompt": info.get("prompt", ""),
|
||||
"timestamp": created_at,
|
||||
"platform": platform
|
||||
"platform": platform_name
|
||||
})
|
||||
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"读取Interview历史失败: {e}")
|
||||
logger.error(f"读取Interview历史失败 ({platform_name}): {e}")
|
||||
|
||||
return results
|
||||
|
||||
@classmethod
|
||||
def get_interview_history(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
platform: str = None,
|
||||
agent_id: Optional[int] = None,
|
||||
limit: int = 100
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取Interview历史记录(从数据库读取)
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
platform: 平台类型(reddit/twitter/None)
|
||||
- "reddit": 只获取Reddit平台的历史
|
||||
- "twitter": 只获取Twitter平台的历史
|
||||
- None: 获取两个平台的所有历史
|
||||
agent_id: 指定Agent ID(可选,只获取该Agent的历史)
|
||||
limit: 每个平台返回数量限制
|
||||
|
||||
Returns:
|
||||
Interview历史记录列表
|
||||
"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
|
||||
results = []
|
||||
|
||||
# 确定要查询的平台
|
||||
if platform in ("reddit", "twitter"):
|
||||
platforms = [platform]
|
||||
else:
|
||||
# 不指定platform时,查询两个平台
|
||||
platforms = ["twitter", "reddit"]
|
||||
|
||||
for p in platforms:
|
||||
db_path = os.path.join(sim_dir, f"{p}_simulation.db")
|
||||
platform_results = cls._get_interview_history_from_db(
|
||||
db_path=db_path,
|
||||
platform_name=p,
|
||||
agent_id=agent_id,
|
||||
limit=limit
|
||||
)
|
||||
results.extend(platform_results)
|
||||
|
||||
# 按时间降序排序
|
||||
results.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
|
||||
|
||||
# 如果查询了多个平台,限制总数
|
||||
if len(platforms) > 1 and len(results) > limit:
|
||||
results = results[:limit]
|
||||
|
||||
return results
|
||||
|
||||
|
||||
Reference in New Issue
Block a user