diff --git a/.DS_Store b/.DS_Store index 0789565..c0395df 100644 Binary files a/.DS_Store and b/.DS_Store differ diff --git a/.env.example b/.env.example index dece3c2..a22e6c5 100644 --- a/.env.example +++ b/.env.example @@ -10,6 +10,11 @@ MINIMAX_API_KEY= MINIMAX_API_BASE=https://api.minimax.io/v1 +# =========================================== +# FAL AI (picture-it image generation) +# =========================================== +FAL_KEY= + # =========================================== # GITEA (Optional - Git sync) # =========================================== @@ -38,14 +43,6 @@ UMAMI_PASSWORD= ADMIN_PASSWORD= UMAMI_DOMAIN=analytics.example.com -# =========================================== -# SHODH MEMORY (Optional - Persistent context) -# =========================================== -SHODH_API_KEY= -SHODH_HOST=http://localhost -SHODH_PORT=3030 -SHODH_USER_ID=default - # =========================================== # GOOGLE ANALYTICS 4 (Optional) # =========================================== @@ -71,6 +68,11 @@ DATAFORSEO_BASE_URL=https://api.dataforseo.com # JINA API - Content extraction JINA_API_KEY= +# =========================================== +# DESIGN SKILLS (Logo, CIP, Icon generation) +# =========================================== +GEMINI_API_KEY= + # LLM Config (MiniMax default, OpenAI compatible) LLM_PROVIDER=minimax LLM_MODEL=MiniMax-Text-01 diff --git a/.opencode/package-lock.json b/.opencode/package-lock.json new file mode 100644 index 0000000..6aa38cb --- /dev/null +++ b/.opencode/package-lock.json @@ -0,0 +1,115 @@ +{ + "name": ".opencode", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "dependencies": { + "@opencode-ai/plugin": "1.3.15" + } + }, + "node_modules/@opencode-ai/plugin": { + "version": "1.3.15", + "resolved": "https://registry.npmjs.org/@opencode-ai/plugin/-/plugin-1.3.15.tgz", + "integrity": "sha512-jZJbuvUXc5Limz8pacQl+ffATjjKGlq+xaA4wTUeW+/spwOf7Yr5Ryyvan8eNlYM8wy6h5SLfznl1rlFpjYC8w==", + "license": "MIT", + "dependencies": { + "@opencode-ai/sdk": "1.3.15", + "zod": "4.1.8" + }, + "peerDependencies": { + "@opentui/core": ">=0.1.96", + "@opentui/solid": ">=0.1.96" + }, + "peerDependenciesMeta": { + "@opentui/core": { + "optional": true + }, + "@opentui/solid": { + "optional": true + } + } + }, + "node_modules/@opencode-ai/sdk": { + "version": "1.3.15", + "resolved": "https://registry.npmjs.org/@opencode-ai/sdk/-/sdk-1.3.15.tgz", + "integrity": "sha512-Uk59C7wsK20wpdr277yx7Xz7TqG5jGqlZUpSW3wDH/7a2K2iBg0lXc2wskHuCXLRXMhXpPZtb4a3SOpPENkkbg==", + "license": "MIT", + "dependencies": { + "cross-spawn": "7.0.6" + } + }, + "node_modules/cross-spawn": { + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", + "license": "MIT", + "dependencies": { + "path-key": "^3.1.0", + "shebang-command": "^2.0.0", + "which": "^2.0.1" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/isexe": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", + "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", + "license": "ISC" + }, + "node_modules/path-key": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", + "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", + "license": "MIT", + "engines": { + 
"node": ">=8" + } + }, + "node_modules/shebang-command": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", + "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", + "license": "MIT", + "dependencies": { + "shebang-regex": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/shebang-regex": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", + "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/which": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", + "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", + "license": "ISC", + "dependencies": { + "isexe": "^2.0.0" + }, + "bin": { + "node-which": "bin/node-which" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/zod": { + "version": "4.1.8", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } + } + } +} diff --git a/output/test/results.json b/output/test/results.json deleted file mode 100644 index 954a584..0000000 --- a/output/test/results.json +++ /dev/null @@ -1,90 +0,0 @@ -{ - "topic": "test", - "generated_at": "2026-03-10T10:41:26.339482", - "channels": { - "facebook": { - "channel": "facebook", - "language": "th", - "variations": [ - { - "id": "facebook_var_1", - "created_at": "2026-03-10T10:41:26.339500", - "primary_text": "[Facebook Post 1] test...", - "headline": "[Headline] test", - "cta": "เรียนรู้เพิ่มเติม", - "hashtags": [ - "#test" - ], - "image": { - "path": "output/test/facebook/images/generated_20260310_104126.png" - } - }, - { - "id": "facebook_var_2", - "created_at": "2026-03-10T10:41:26.339584", - "primary_text": "[Facebook Post 2] test...", - "headline": "[Headline] test", - "cta": "เรียนรู้เพิ่มเติม", - "hashtags": [ - "#test" - ], - "image": { - "path": "output/test/facebook/images/generated_20260310_104126.png" - } - }, - { - "id": "facebook_var_3", - "created_at": "2026-03-10T10:41:26.339605", - "primary_text": "[Facebook Post 3] test...", - "headline": "[Headline] test", - "cta": "เรียนรู้เพิ่มเติม", - "hashtags": [ - "#test" - ], - "image": { - "path": "output/test/facebook/images/generated_20260310_104126.png" - } - }, - { - "id": "facebook_var_4", - "created_at": "2026-03-10T10:41:26.339620", - "primary_text": "[Facebook Post 4] test...", - "headline": "[Headline] test", - "cta": "เรียนรู้เพิ่มเติม", - "hashtags": [ - "#test" - ], - "image": { - "path": "output/test/facebook/images/generated_20260310_104126.png" - } - }, - { - "id": "facebook_var_5", - "created_at": "2026-03-10T10:41:26.339633", - "primary_text": "[Facebook Post 5] test...", - "headline": "[Headline] test", - "cta": "เรียนรู้เพิ่มเติม", - "hashtags": [ - "#test" - ], - "image": { - "path": "output/test/facebook/images/generated_20260310_104126.png" - } - } - ], - "api_ready": { - "platform": "meta", - "api_version": "v18.0", - "endpoint": "/act_{ad_account_id}/adcreatives", - "method": "POST", - "field_mapping": { - "primary_text": "body", - "headline": "title", - "cta": "call_to_action.type", - "image": "story_id or link_data.picture" - } - } - } - }, - "summary": {} -} \ No newline at end of file diff --git a/output/บรการ-podcast-hosting/results.json 
b/output/บรการ-podcast-hosting/results.json deleted file mode 100644 index 8991f09..0000000 --- a/output/บรการ-podcast-hosting/results.json +++ /dev/null @@ -1,437 +0,0 @@ -{ - "topic": "บริการ podcast hosting", - "generated_at": "2026-03-08T22:51:11.780847", - "channels": { - "facebook": { - "channel": "facebook", - "language": "th", - "variations": [ - { - "id": "facebook_var_1", - "created_at": "2026-03-08T22:51:11.780865", - "primary_text": "[Facebook Post 1] บริการ podcast hosting...", - "headline": "[Headline] บริการ podcast hosting", - "cta": "เรียนรู้เพิ่มเติม", - "hashtags": [ - "#บริการpodcasthosting" - ], - "image": { - "path": "output/บรการ-podcast-hosting/facebook/images/generated_20260308_225111.png" - } - }, - { - "id": "facebook_var_2", - "created_at": "2026-03-08T22:51:11.781143", - "primary_text": "[Facebook Post 2] บริการ podcast hosting...", - "headline": "[Headline] บริการ podcast hosting", - "cta": "เรียนรู้เพิ่มเติม", - "hashtags": [ - "#บริการpodcasthosting" - ], - "image": { - "path": "output/บรการ-podcast-hosting/facebook/images/generated_20260308_225111.png" - } - }, - { - "id": "facebook_var_3", - "created_at": "2026-03-08T22:51:11.781169", - "primary_text": "[Facebook Post 3] บริการ podcast hosting...", - "headline": "[Headline] บริการ podcast hosting", - "cta": "เรียนรู้เพิ่มเติม", - "hashtags": [ - "#บริการpodcasthosting" - ], - "image": { - "path": "output/บรการ-podcast-hosting/facebook/images/generated_20260308_225111.png" - } - }, - { - "id": "facebook_var_4", - "created_at": "2026-03-08T22:51:11.781186", - "primary_text": "[Facebook Post 4] บริการ podcast hosting...", - "headline": "[Headline] บริการ podcast hosting", - "cta": "เรียนรู้เพิ่มเติม", - "hashtags": [ - "#บริการpodcasthosting" - ], - "image": { - "path": "output/บรการ-podcast-hosting/facebook/images/generated_20260308_225111.png" - } - }, - { - "id": "facebook_var_5", - "created_at": "2026-03-08T22:51:11.781204", - "primary_text": "[Facebook Post 5] บริการ podcast hosting...", - "headline": "[Headline] บริการ podcast hosting", - "cta": "เรียนรู้เพิ่มเติม", - "hashtags": [ - "#บริการpodcasthosting" - ], - "image": { - "path": "output/บรการ-podcast-hosting/facebook/images/generated_20260308_225111.png" - } - } - ], - "api_ready": { - "platform": "meta", - "api_version": "v18.0", - "endpoint": "/act_{ad_account_id}/adcreatives", - "method": "POST", - "field_mapping": { - "primary_text": "body", - "headline": "title", - "cta": "call_to_action.type", - "image": "story_id or link_data.picture" - } - } - }, - "google_ads": { - "channel": "google_ads", - "language": "th", - "variations": [ - { - "id": "google_ads_var_1", - "created_at": "2026-03-08T22:51:11.781221", - "headlines": [ - { - "text": "[Headline 1] บริการ podcast hosting" - }, - { - "text": "[Headline 2] บริการ podcast hosting" - }, - { - "text": "[Headline 3] บริการ podcast hosting" - }, - { - "text": "[Headline 4] บริการ podcast hosting" - }, - { - "text": "[Headline 5] บริการ podcast hosting" - }, - { - "text": "[Headline 6] บริการ podcast hosting" - }, - { - "text": "[Headline 7] บริการ podcast hosting" - }, - { - "text": "[Headline 8] บริการ podcast hosting" - }, - { - "text": "[Headline 9] บริการ podcast hosting" - }, - { - "text": "[Headline 10] บริการ podcast hosting" - }, - { - "text": "[Headline 11] บริการ podcast hosting" - }, - { - "text": "[Headline 12] บริการ podcast hosting" - }, - { - "text": "[Headline 13] บริการ podcast hosting" - }, - { - "text": "[Headline 14] บริการ podcast hosting" - }, - { - "text": "[Headline 15] 
บริการ podcast hosting" - } - ], - "descriptions": [ - { - "text": "[Description 1] Learn more about บริการ podcast hosting" - }, - { - "text": "[Description 2] Learn more about บริการ podcast hosting" - }, - { - "text": "[Description 3] Learn more about บริการ podcast hosting" - }, - { - "text": "[Description 4] Learn more about บริการ podcast hosting" - } - ], - "keywords": [ - "บริการ podcast hosting", - "บริการ บริการ podcast hosting" - ], - "api_ready": { - "platform": "google", - "api_version": "v15.0", - "endpoint": "/google.ads.googleads.v15.services/GoogleAdsService:Mutate" - } - }, - { - "id": "google_ads_var_2", - "created_at": "2026-03-08T22:51:11.781228", - "headlines": [ - { - "text": "[Headline 1] บริการ podcast hosting" - }, - { - "text": "[Headline 2] บริการ podcast hosting" - }, - { - "text": "[Headline 3] บริการ podcast hosting" - }, - { - "text": "[Headline 4] บริการ podcast hosting" - }, - { - "text": "[Headline 5] บริการ podcast hosting" - }, - { - "text": "[Headline 6] บริการ podcast hosting" - }, - { - "text": "[Headline 7] บริการ podcast hosting" - }, - { - "text": "[Headline 8] บริการ podcast hosting" - }, - { - "text": "[Headline 9] บริการ podcast hosting" - }, - { - "text": "[Headline 10] บริการ podcast hosting" - }, - { - "text": "[Headline 11] บริการ podcast hosting" - }, - { - "text": "[Headline 12] บริการ podcast hosting" - }, - { - "text": "[Headline 13] บริการ podcast hosting" - }, - { - "text": "[Headline 14] บริการ podcast hosting" - }, - { - "text": "[Headline 15] บริการ podcast hosting" - } - ], - "descriptions": [ - { - "text": "[Description 1] Learn more about บริการ podcast hosting" - }, - { - "text": "[Description 2] Learn more about บริการ podcast hosting" - }, - { - "text": "[Description 3] Learn more about บริการ podcast hosting" - }, - { - "text": "[Description 4] Learn more about บริการ podcast hosting" - } - ], - "keywords": [ - "บริการ podcast hosting", - "บริการ บริการ podcast hosting" - ], - "api_ready": { - "platform": "google", - "api_version": "v15.0", - "endpoint": "/google.ads.googleads.v15.services/GoogleAdsService:Mutate" - } - }, - { - "id": "google_ads_var_3", - "created_at": "2026-03-08T22:51:11.781232", - "headlines": [ - { - "text": "[Headline 1] บริการ podcast hosting" - }, - { - "text": "[Headline 2] บริการ podcast hosting" - }, - { - "text": "[Headline 3] บริการ podcast hosting" - }, - { - "text": "[Headline 4] บริการ podcast hosting" - }, - { - "text": "[Headline 5] บริการ podcast hosting" - }, - { - "text": "[Headline 6] บริการ podcast hosting" - }, - { - "text": "[Headline 7] บริการ podcast hosting" - }, - { - "text": "[Headline 8] บริการ podcast hosting" - }, - { - "text": "[Headline 9] บริการ podcast hosting" - }, - { - "text": "[Headline 10] บริการ podcast hosting" - }, - { - "text": "[Headline 11] บริการ podcast hosting" - }, - { - "text": "[Headline 12] บริการ podcast hosting" - }, - { - "text": "[Headline 13] บริการ podcast hosting" - }, - { - "text": "[Headline 14] บริการ podcast hosting" - }, - { - "text": "[Headline 15] บริการ podcast hosting" - } - ], - "descriptions": [ - { - "text": "[Description 1] Learn more about บริการ podcast hosting" - }, - { - "text": "[Description 2] Learn more about บริการ podcast hosting" - }, - { - "text": "[Description 3] Learn more about บริการ podcast hosting" - }, - { - "text": "[Description 4] Learn more about บริการ podcast hosting" - } - ], - "keywords": [ - "บริการ podcast hosting", - "บริการ บริการ podcast hosting" - ], - "api_ready": { - "platform": "google", - 
"api_version": "v15.0", - "endpoint": "/google.ads.googleads.v15.services/GoogleAdsService:Mutate" - } - } - ], - "api_ready": { - "platform": "google", - "api_version": "v15.0", - "service": "GoogleAdsService", - "endpoint": "/google.ads.googleads.v15.services/GoogleAdsService:Mutate", - "resource_hierarchy": [ - "customer", - "campaign", - "ad_group", - "ad_group_ad", - "ad (RESPONSIVE_SEARCH_AD)" - ], - "field_mapping": { - "headlines": "responsive_search_ad.headlines", - "descriptions": "responsive_search_ad.descriptions", - "final_url": "responsive_search_ad.final_urls", - "display_path": "responsive_search_ad.path1, path2", - "keywords": "ad_group_criterion", - "bid_modifier": "ad_group_criterion.cpc_bid_modifier" - }, - "future_integration_notes": [ - "Add conversion_tracking_setup", - "Add value_track_parameters", - "Add ad_schedule_bid_modifiers", - "Add device_bid_modifiers", - "Add location_bid_modifiers", - "Setup enhanced conversions" - ] - } - }, - "blog": { - "channel": "blog", - "language": "th", - "variations": [ - { - "id": "blog_var_1", - "created_at": "2026-03-08T22:51:11.781238", - "markdown": "---\ntitle: \"บริการ podcast hosting - Complete Guide\"\ndescription: \"Learn everything about บริการ podcast hosting in this comprehensive guide\"\nkeywords: [\"บริการ podcast hosting\", \"บริการ บริการ podcast hosting\", \"guide\"]\nslug: บรการ-podcast-hosting\nlang: th\ncategory: guides\ntags: [\"บริการ podcast hosting\", \"guide\"]\ncreated: 2026-03-08\n---\n\n# บริการ podcast hosting: Complete Guide\n\n## Introduction\n\n[Opening hook about บริการ podcast hosting...]\n\n## What is บริการ podcast hosting?\n\n[Definition and explanation...]\n\n## Why บริการ podcast hosting Matters\n\n[Importance and benefits...]\n\n## How to Get Started with บริการ podcast hosting\n\n[Step-by-step guide...]\n\n## Best Practices for บริการ podcast hosting\n\n[Tips and recommendations...]\n\n## Conclusion\n\n[Summary and call-to-action...]\n", - "frontmatter": { - "title": "บริการ podcast hosting - Complete Guide", - "description": "Learn about บริการ podcast hosting", - "slug": "บรการ-podcast-hosting", - "lang": "th" - }, - "word_count": 1500, - "publish_status": "draft" - }, - { - "id": "blog_var_2", - "created_at": "2026-03-08T22:51:11.781250", - "markdown": "---\ntitle: \"บริการ podcast hosting - Complete Guide\"\ndescription: \"Learn everything about บริการ podcast hosting in this comprehensive guide\"\nkeywords: [\"บริการ podcast hosting\", \"บริการ บริการ podcast hosting\", \"guide\"]\nslug: บรการ-podcast-hosting\nlang: th\ncategory: guides\ntags: [\"บริการ podcast hosting\", \"guide\"]\ncreated: 2026-03-08\n---\n\n# บริการ podcast hosting: Complete Guide\n\n## Introduction\n\n[Opening hook about บริการ podcast hosting...]\n\n## What is บริการ podcast hosting?\n\n[Definition and explanation...]\n\n## Why บริการ podcast hosting Matters\n\n[Importance and benefits...]\n\n## How to Get Started with บริการ podcast hosting\n\n[Step-by-step guide...]\n\n## Best Practices for บริการ podcast hosting\n\n[Tips and recommendations...]\n\n## Conclusion\n\n[Summary and call-to-action...]\n", - "frontmatter": { - "title": "บริการ podcast hosting - Complete Guide", - "description": "Learn about บริการ podcast hosting", - "slug": "บรการ-podcast-hosting", - "lang": "th" - }, - "word_count": 1500, - "publish_status": "draft" - }, - { - "id": "blog_var_3", - "created_at": "2026-03-08T22:51:11.781259", - "markdown": "---\ntitle: \"บริการ podcast hosting - Complete Guide\"\ndescription: \"Learn everything 
about บริการ podcast hosting in this comprehensive guide\"\nkeywords: [\"บริการ podcast hosting\", \"บริการ บริการ podcast hosting\", \"guide\"]\nslug: บรการ-podcast-hosting\nlang: th\ncategory: guides\ntags: [\"บริการ podcast hosting\", \"guide\"]\ncreated: 2026-03-08\n---\n\n# บริการ podcast hosting: Complete Guide\n\n## Introduction\n\n[Opening hook about บริการ podcast hosting...]\n\n## What is บริการ podcast hosting?\n\n[Definition and explanation...]\n\n## Why บริการ podcast hosting Matters\n\n[Importance and benefits...]\n\n## How to Get Started with บริการ podcast hosting\n\n[Step-by-step guide...]\n\n## Best Practices for บริการ podcast hosting\n\n[Tips and recommendations...]\n\n## Conclusion\n\n[Summary and call-to-action...]\n", - "frontmatter": { - "title": "บริการ podcast hosting - Complete Guide", - "description": "Learn about บริการ podcast hosting", - "slug": "บรการ-podcast-hosting", - "lang": "th" - }, - "word_count": 1500, - "publish_status": "draft" - }, - { - "id": "blog_var_4", - "created_at": "2026-03-08T22:51:11.781272", - "markdown": "---\ntitle: \"บริการ podcast hosting - Complete Guide\"\ndescription: \"Learn everything about บริการ podcast hosting in this comprehensive guide\"\nkeywords: [\"บริการ podcast hosting\", \"บริการ บริการ podcast hosting\", \"guide\"]\nslug: บรการ-podcast-hosting\nlang: th\ncategory: guides\ntags: [\"บริการ podcast hosting\", \"guide\"]\ncreated: 2026-03-08\n---\n\n# บริการ podcast hosting: Complete Guide\n\n## Introduction\n\n[Opening hook about บริการ podcast hosting...]\n\n## What is บริการ podcast hosting?\n\n[Definition and explanation...]\n\n## Why บริการ podcast hosting Matters\n\n[Importance and benefits...]\n\n## How to Get Started with บริการ podcast hosting\n\n[Step-by-step guide...]\n\n## Best Practices for บริการ podcast hosting\n\n[Tips and recommendations...]\n\n## Conclusion\n\n[Summary and call-to-action...]\n", - "frontmatter": { - "title": "บริการ podcast hosting - Complete Guide", - "description": "Learn about บริการ podcast hosting", - "slug": "บรการ-podcast-hosting", - "lang": "th" - }, - "word_count": 1500, - "publish_status": "draft" - }, - { - "id": "blog_var_5", - "created_at": "2026-03-08T22:51:11.781279", - "markdown": "---\ntitle: \"บริการ podcast hosting - Complete Guide\"\ndescription: \"Learn everything about บริการ podcast hosting in this comprehensive guide\"\nkeywords: [\"บริการ podcast hosting\", \"บริการ บริการ podcast hosting\", \"guide\"]\nslug: บรการ-podcast-hosting\nlang: th\ncategory: guides\ntags: [\"บริการ podcast hosting\", \"guide\"]\ncreated: 2026-03-08\n---\n\n# บริการ podcast hosting: Complete Guide\n\n## Introduction\n\n[Opening hook about บริการ podcast hosting...]\n\n## What is บริการ podcast hosting?\n\n[Definition and explanation...]\n\n## Why บริการ podcast hosting Matters\n\n[Importance and benefits...]\n\n## How to Get Started with บริการ podcast hosting\n\n[Step-by-step guide...]\n\n## Best Practices for บริการ podcast hosting\n\n[Tips and recommendations...]\n\n## Conclusion\n\n[Summary and call-to-action...]\n", - "frontmatter": { - "title": "บริการ podcast hosting - Complete Guide", - "description": "Learn about บริการ podcast hosting", - "slug": "บรการ-podcast-hosting", - "lang": "th" - }, - "word_count": 1500, - "publish_status": "draft" - } - ], - "api_ready": { - "cms_compatible": [ - "WordPress", - "Contentful", - "Sanity", - "Strapi" - ], - "schema_org": { - "type": "BlogPosting", - "required_fields": [ - "headline", - "description", - "image", - "datePublished", - 
"author", - "publisher" - ] - } - } - } - }, - "summary": {} -} \ No newline at end of file diff --git a/scripts/install-skills.sh b/scripts/install-skills.sh index bb87747..81e7778 100755 --- a/scripts/install-skills.sh +++ b/scripts/install-skills.sh @@ -51,6 +51,15 @@ setup_unified_env() { [ -f "$env_example" ] || return + # Check if .env already exists in repo - skip interactive setup if it does + if [ -f "$env_file" ]; then + line + print_success "Using existing .env file in project" + line + echo "" + return + fi + line print_info "Unified Configuration Setup" line diff --git a/skills/.DS_Store b/skills/.DS_Store index 172f1a4..9ffc065 100644 Binary files a/skills/.DS_Store and b/skills/.DS_Store differ diff --git a/skills/alphaear-deepear-lite/SKILL.md b/skills/alphaear-deepear-lite/SKILL.md deleted file mode 100644 index 791405c..0000000 --- a/skills/alphaear-deepear-lite/SKILL.md +++ /dev/null @@ -1,32 +0,0 @@ ---- -name: alphaear-deepear-lite -description: Fetch the latest financial signals and transmission-chain analyses from DeepEar Lite. Use when the user needs immediate insights into financial market trends, stock performance factors, and reasoning from the DeepEar Lite dashboard. ---- - -# DeepEar Lite Skill - -## Overview - -Fetch high-frequency financial signals, including titles, summaries, confidence scores, and reasoning directly from the DeepEar Lite platform's real-time data source. - -## Capabilities - -### 1. Fetch Latest Financial Signals - -Use `scripts/deepear_lite.py` via `DeepEarLiteTools`. - -- **Fetch Signals**: `fetch_latest_signals()` - - Retrieves all latest signals from `https://deepear.vercel.app/latest.json`. - - Returns a formatted report of signal titles, sentiment/confidence metrics, summaries, and source links. - -## Dependencies - -- `requests`, `loguru` -- No local database required for this skill. - -## Testing - -Run the test script to verify the connection and data fetching: -```bash -python scripts/deepear_lite.py -``` diff --git a/skills/alphaear-deepear-lite/scripts/deepear_lite.py b/skills/alphaear-deepear-lite/scripts/deepear_lite.py deleted file mode 100644 index 7476fe9..0000000 --- a/skills/alphaear-deepear-lite/scripts/deepear_lite.py +++ /dev/null @@ -1,112 +0,0 @@ -import requests -import json -from loguru import logger -from datetime import datetime - -class DeepEarLiteTools: - """ - Tools for fetching signals from DeepEar Lite (https://deepear.vercel.app/lite). - """ - - LATEST_JSON_URL = "https://deepear.vercel.app/latest.json" - POSTHOG_API_KEY = "phc_1sa745Su9oyVDyCwqHxvYUXyvvWpamcqofuibEOZpJp" - POSTHOG_URL = "https://us.i.posthog.com/capture/" - - def _record_telemetry(self, event_name="skill_deepear_lite_called"): - """ - Record the skill usage to multiple tracking layers. - 1. Hits the HTML page to trigger server-side logs. - 2. Sends a PostHog event to simulate client-side tracking. - """ - headers = { - "User-Agent": "DeepEar-Skill-Agent/1.0 (Awesome-Finance-Skills)", - "Referer": "https://deepear.vercel.app/lite" - } - - # Layer 1: Server-side hit (Vercel Logs, etc.) 
- try: - requests.get("https://deepear.vercel.app/lite", headers=headers, timeout=5) - logger.debug("Server-side hit recorded.") - except Exception as e: - logger.debug(f"Server-side hit failed: {e}") - - # Layer 2: PostHog (Analytics Dashboard) - try: - import uuid - payload = { - "api_key": self.POSTHOG_API_KEY, - "event": event_name, - "properties": { - "distinct_id": str(uuid.uuid4()), - "app": "awesome-finance-skills", - "skill": "alphaear-deepear-lite", - "timestamp": datetime.now().isoformat(), - "$current_url": "https://deepear.vercel.app/lite", - "lib": "python-requests" - } - } - requests.post(self.POSTHOG_URL, json=payload, timeout=5) - logger.debug(f"PostHog telemetry recorded: {event_name}") - except Exception as e: - logger.debug(f"PostHog telemetry failed: {e}") - - def fetch_latest_signals(self): - """ - Fetch the newest financial signals from DeepEar Lite. - Returns a formatted summary of the latest signals. - """ - # Record telemetry before fetching - self._record_telemetry() - - try: - logger.info(f"Fetching data from {self.LATEST_JSON_URL}") - headers = { - "User-Agent": "DeepEar-Skill-Agent/1.0 (Awesome-Finance-Skills)", - "Referer": "https://deepear.vercel.app/lite" - } - response = requests.get(self.LATEST_JSON_URL, headers=headers, timeout=10) - response.raise_for_status() - data = response.json() - - generated_at = data.get("generated_at", "Unknown") - signals = data.get("signals", []) - - if not signals: - return "No signals found in the latest data." - - report = [f"### DeepEar Lite Signal Report (Updated: {generated_at})\n"] - - for i, signal in enumerate(signals, 1): - title = signal.get("title", "No Title") - summary = signal.get("summary", "No Summary") - sentiment = signal.get("sentiment_score", 0) - confidence = signal.get("confidence", 0) - intensity = signal.get("intensity", 0) - reasoning = signal.get("reasoning", "No Reasoning") - - report.append(f"#### {i}. {title}") - report.append(f"**Sentiment**: {sentiment} | **Confidence**: {confidence} | **Intensity**: {intensity}") - report.append(f"\n**Summary**: {summary}") - report.append(f"\n**Reasoning**: {reasoning}") - - # Check for sources/links - sources = signal.get("sources", []) - if sources: - report.append("\n**Sources**:") - for src in sources: - name = src.get("name", "Link") - url = src.get("url", "#") - report.append(f"- [{name}]({url})") - - report.append("\n" + "-"*40 + "\n") - - return "\n".join(report) - - except Exception as e: - error_msg = f"Error fetching DeepEar Lite data: {str(e)}" - logger.error(error_msg) - return error_msg - -if __name__ == "__main__": - tools = DeepEarLiteTools() - print(tools.fetch_latest_signals()) diff --git a/skills/alphaear-logic-visualizer/SKILL.md b/skills/alphaear-logic-visualizer/SKILL.md deleted file mode 100644 index 82c177d..0000000 --- a/skills/alphaear-logic-visualizer/SKILL.md +++ /dev/null @@ -1,31 +0,0 @@ ---- -name: alphaear-logic-visualizer -description: Create visualize finance logic diagrams (e.g., Draw.io XML) to explain complex finance transmission chains or finance logic flows. ---- - -# AlphaEar Logic Visualizer Skill - -## Overview - -This skill specializes in creating visual representations of logic flows, specifically generating Draw.io XML compatible diagrams. It is useful for visualizing investment theses or signal transmission chains. - -## Capabilities - -### 1. Generate Draw.io Diagrams - -### 1. Generate Draw.io Diagrams (Agentic Workflow) - -**YOU (the Agent)** are the Visualizer. 
Use the prompts in `references/PROMPTS.md` to generate the XML. - -**Workflow:** -1. **Generate XML**: Use the **Draw.io XML Generation Prompt** from `references/PROMPTS.md` to convert your logical chain into XML. -2. **Save/Render**: Use `scripts/visualizer.py` method `render_drawio_to_html(xml_content, filename)` to save the XML into a viewable HTML file for the user. - -**Example Usage (Conceptual):** -- **Agent Action**: "I will now generate a Draw.io XML for the transmission chain..." -- **Tool Call**: `visualizer.render_drawio_to_html(xml_content="...", filename="chain_visual.html")` - - -## Dependencies - -- None (Standard Library for string manipulation). diff --git a/skills/alphaear-logic-visualizer/references/PROMPTS.md b/skills/alphaear-logic-visualizer/references/PROMPTS.md deleted file mode 100644 index 99d6882..0000000 --- a/skills/alphaear-logic-visualizer/references/PROMPTS.md +++ /dev/null @@ -1,52 +0,0 @@ -# AlphaEar Logic Visualizer Prompts - -## Draw.io XML Generation - -**Prompt:** - -```markdown -You are an expert at creating Draw.io (MxGraph) diagrams in XML format. -Your task is to generate a valid MXGraphModel XML based on the logic description. - -### Rules: -1. Output ONLY the XML code. Start with `` and end with ``. -2. Do not use compressed XML. Use plain XML. -3. Use standard shapes: `rounded=1;whiteSpace=wrap;html=1;` for boxes. -4. **Auto-layout Strategy**: - - Identify "layers" or "stages" in the logic. - - Assign X coordinates based on layers (e.g., 0, 200, 400). - - Assign Y coordinates to distribute nodes vertically (e.g., 0, 100, 200). - - Ensure nodes do not overlap. -5. **Edges**: Connect nodes logically using ``. - -### Template: - - - - - - - - - - - - - - - - -``` - -**Task Input:** -```markdown -Please generate a Draw.io XML diagram for the following logic flow: - -**Title**: {title} - -**Nodes and Logic**: -{nodes_json} - -Ensure the layout flows logically from Left to Right (or Top to Bottom for hierarchies). -Use different colors for 'Positive' (Green/fillColor=#d5e8d4), 'Negative' (Red/fillColor=#f8cecc), and 'Neutral' (Grey/fillColor=#f5f5f5) impacts. 
-``` diff --git a/skills/alphaear-logic-visualizer/scripts/__init__.py b/skills/alphaear-logic-visualizer/scripts/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/alphaear-logic-visualizer/scripts/visualizer.py b/skills/alphaear-logic-visualizer/scripts/visualizer.py deleted file mode 100644 index 85a38cd..0000000 --- a/skills/alphaear-logic-visualizer/scripts/visualizer.py +++ /dev/null @@ -1,472 +0,0 @@ -import os -from typing import Dict, List, Any, Optional -import pandas as pd -from loguru import logger -from pyecharts.charts import Kline, Line, Bar, Grid, Radar, Graph -from pyecharts import options as opts -from pyecharts.globals import ThemeType -from datetime import datetime, timedelta - -class VisualizerTools: - """可视化工具库 - 使用 Pyecharts 生成 HTML 图表""" - - @staticmethod - def generate_stock_chart( - df: pd.DataFrame, - ticker: str, - title: str = None, - prediction: Optional[List[float]] = None, - forecast: Optional[Any] = None, # ForecastResult instance - ground_truth: Optional[pd.DataFrame] = None # For training visualization - ) -> Grid: - """ - 生成股票 K 线图 + 成交量 + 预测趋势 (支持多状态 K 线) - """ - if df.empty: - return None - - # 数据预处理 - df = df.sort_values('date') - dates = [str(d)[:10] for d in df['date'].tolist()] - k_data = df[['open', 'close', 'low', 'high']].values.tolist() - volumes = df['volume'].tolist() - - if not title: - title = f"{ticker} 股价走势与预测" - - legend_items = ["日K"] - - # 1. 处理传统的简单预测线 (Line) - pred_line = None - if prediction and not forecast: - try: - last_date_str = dates[-1] - last_date = datetime.strptime(last_date_str, "%Y-%m-%d") - - pred_dates = [] - for i in range(1, len(prediction) + 1): - pred_dates.append((last_date + timedelta(days=i)).strftime("%Y-%m-%d")) - - ext_dates = dates + pred_dates - last_close = df.iloc[-1]['close'] - pred_values = [None] * (len(df) - 1) + [float(last_close)] + prediction - - pred_line = ( - Line() - .add_xaxis(ext_dates) - .add_yaxis( - "AI预测趋势", - pred_values, - is_connect_nones=True, - is_symbol_show=True, - linestyle_opts=opts.LineStyleOpts(width=2, type_="dashed", color="#FF8C00"), - label_opts=opts.LabelOpts(is_show=False) - ) - ) - dates = ext_dates - legend_items.append("AI预测趋势") - except Exception as e: - logger.error(f"Failed to process simple prediction: {e}") - - # 2. 
处理复杂的 Kronos 预测 (Kline) - base_kline = None - adj_kline = None - - if forecast: - try: - # 获取预测数据点 - base_points = forecast.base_forecast # List[KLinePoint] - adj_points = forecast.adjusted_forecast # List[KLinePoint] - - # 提取日期 - pred_dates = [str(p.date)[:10] for p in (adj_points or base_points)] - - # 检查日期是否已经包含在主 dates 中,如果没有则扩展 - if pred_dates and pred_dates[0] not in dates: - dates = dates + pred_dates - - # 构建 Baseline 预测 K 线数据 - if base_points: - # 前面填充 None - base_k_data = [[None]*4] * len(df) + [[p.open, p.close, p.low, p.high] for p in base_points] - base_kline = ( - Kline() - .add_xaxis(dates) - .add_yaxis( - "模型原始预测", - base_k_data, - itemstyle_opts=opts.ItemStyleOpts( - color="transparent", - color0="transparent", - border_color="#FF8C00", # 橙色 - border_color0="#FF8C00", - opacity=0.6, - border_type="dashed" - ), - ) - ) - legend_items.append("模型原始预测") - - # 构建 Adjusted 调优 K 线数据 - if adj_points: - adj_k_data = [[None]*4] * len(df) + [[p.open, p.close, p.low, p.high] for p in adj_points] - adj_kline = ( - Kline() - .add_xaxis(dates) - .add_yaxis( - "LLM调优预测", - adj_k_data, - itemstyle_opts=opts.ItemStyleOpts( - color="#9333ea", # 紫色 - color0="#9333ea", - border_color="#9333ea", - border_color0="#9333ea", - opacity=0.8 - ), - ) - ) - legend_items.append("LLM调优预测") - - except Exception as e: - logger.error(f"Failed to process complex forecast: {e}") - - # 2.5 处理 Ground Truth (用于训练评估可视化) - gt_line = None - if ground_truth is not None and not ground_truth.empty: - try: - gt_dates = [str(d)[:10] for d in ground_truth['date'].tolist()] - # 确保日期包含在 dates 中 - for d in gt_dates: - if d not in dates: - dates.append(d) - dates = sorted(list(set(dates))) # Re-sort to maintain order - - gt_values = [None] * len(dates) - for _, row in ground_truth.iterrows(): - d_str = str(row['date'])[:10] - if d_str in dates: - idx = dates.index(d_str) - gt_values[idx] = float(row['close']) - - gt_line = ( - Line() - .add_xaxis(dates) - .add_yaxis( - "真实走势 (GT)", - gt_values, - is_connect_nones=True, - linestyle_opts=opts.LineStyleOpts(width=3, color="#2ecc71"), # 绿色粗线 - label_opts=opts.LabelOpts(is_show=False) - ) - ) - legend_items.append("真实走势 (GT)") - except Exception as e: - logger.error(f"Failed to process ground truth: {e}") - - # 3. 主 K 线图 - # 为了展示预测,也需要对主 K 线数据进行填充 - main_k_data = k_data + [[None]*4] * (len(dates) - len(df)) - - kline = ( - Kline() - .add_xaxis(dates) - .add_yaxis( - "日K", - main_k_data, - itemstyle_opts=opts.ItemStyleOpts( - color="#ef4444", # 跌 - color0="#22c55e", # 涨 - border_color="#ef4444", - border_color0="#22c55e", - ), - ) - .set_global_opts( - title_opts=opts.TitleOpts(title=title, pos_left="center"), - xaxis_opts=opts.AxisOpts(is_scale=True), - yaxis_opts=opts.AxisOpts( - is_scale=True, - splitarea_opts=opts.SplitAreaOpts( - is_show=True, areastyle_opts=opts.AreaStyleOpts(opacity=1) - ), - ), - legend_opts=opts.LegendOpts(is_show=True, pos_top="5%"), - datazoom_opts=[opts.DataZoomOpts(type_="inside", range_start=50)], - tooltip_opts=opts.TooltipOpts(trigger="axis", axis_pointer_type="cross"), - ) - ) - - # Overlap all series - if pred_line: kline.overlap(pred_line) - if base_kline: kline.overlap(base_kline) - if adj_kline: kline.overlap(adj_kline) - if gt_line: kline.overlap(gt_line) - - # 4. 
成交量柱状图 - # 同理扩展成交量数据 - ext_volumes = volumes + [0] * (len(dates) - len(df)) - - bar = ( - Bar() - .add_xaxis(dates) - .add_yaxis( - "成交量", - ext_volumes, - xaxis_index=1, - yaxis_index=1, - label_opts=opts.LabelOpts(is_show=False), - itemstyle_opts=opts.ItemStyleOpts(color="#7fbe9e"), - ) - .set_global_opts( - xaxis_opts=opts.AxisOpts( - type_="category", - grid_index=1, - axislabel_opts=opts.LabelOpts(is_show=False), - ), - legend_opts=opts.LegendOpts(is_show=False), - ) - ) - - # 5. 组合 Grid - grid_chart = Grid(init_opts=opts.InitOpts(width="100%", height="450px", theme=ThemeType.LIGHT)) - grid_chart.add( - kline, - grid_opts=opts.GridOpts(pos_left="10%", pos_right="8%", height="50%"), - ) - grid_chart.add( - bar, - grid_opts=opts.GridOpts( - pos_left="10%", pos_right="8%", pos_top="65%", height="20%" - ), - ) - - return grid_chart - - @staticmethod - def generate_loss_chart(losses: List[float], title: str = "训练损失收敛曲线") -> Line: - """生成 Loss 下降曲线图""" - line = ( - Line(init_opts=opts.InitOpts(width="100%", height="400px", theme=ThemeType.LIGHT)) - .add_xaxis(list(range(1, len(losses) + 1))) - .add_yaxis( - "Training Loss", - losses, - is_smooth=True, - linestyle_opts=opts.LineStyleOpts(width=2, color="#3b82f6"), - label_opts=opts.LabelOpts(is_show=False), - markpoint_opts=opts.MarkPointOpts(data=[opts.MarkPointItem(type_="min", name="最小值")]) - ) - .set_global_opts( - title_opts=opts.TitleOpts(title=title, pos_left="center"), - xaxis_opts=opts.AxisOpts(name="Epoch", is_scale=True), - yaxis_opts=opts.AxisOpts(name="Loss", is_scale=True), - tooltip_opts=opts.TooltipOpts(trigger="axis"), - ) - ) - return line - - @staticmethod - def generate_sentiment_trend_chart(sentiment_history: List[Dict[str, Any]]) -> Line: - """ - 生成舆情情绪趋势图 - :param sentiment_history: [{"date": "2024-01-01", "score": 0.8}, ...] 
- """ - dates = [item['date'] for item in sentiment_history] - scores = [item['score'] for item in sentiment_history] - - line = ( - Line(init_opts=opts.InitOpts(width="100%", height="300px", theme=ThemeType.LIGHT)) - .add_xaxis(dates) - .add_yaxis( - "情绪指数", - scores, - is_smooth=True, - markline_opts=opts.MarkLineOpts(data=[opts.MarkLineItem(y=0, name="中性线")]), - itemstyle_opts=opts.ItemStyleOpts(color="#5470c6"), - areastyle_opts=opts.AreaStyleOpts(opacity=0.3, color="#5470c6") - ) - .set_global_opts( - title_opts=opts.TitleOpts(title="舆情情绪趋势", pos_left="center"), - legend_opts=opts.LegendOpts(pos_top="8%"), - yaxis_opts=opts.AxisOpts(min_=-1, max_=1, name="Sentiment"), - tooltip_opts=opts.TooltipOpts(trigger="axis"), - ) - ) - return line - - @staticmethod - def generate_isq_radar_chart(sentiment: float, confidence: float, intensity: int, - expectation_gap: float = 0.5, timeliness: float = 0.8, - title: str = "信号质量 ISQ 评估") -> Radar: - """生成信号质量雷达图""" - # 标准化数据 (0-100) - # sentiment 强度: 绝对值越大强度越高 - sent_val = min(100, abs(sentiment) * 100) - # confidence: 0 to 1 -> 0 to 100 - conf_val = confidence * 100 - # intensity: 1 to 5 -> 20 to 100 - int_val = intensity * 20 - # gap & time: 0 to 1 -> 0 to 100 - gap_val = expectation_gap * 100 - time_val = timeliness * 100 - - schema = [ - opts.RadarIndicatorItem(name="情绪强度", max_=100), - opts.RadarIndicatorItem(name="确定性", max_=100), - opts.RadarIndicatorItem(name="影响力", max_=100), - opts.RadarIndicatorItem(name="预期差", max_=100), - opts.RadarIndicatorItem(name="时效性", max_=100), - ] - - radar = ( - Radar(init_opts=opts.InitOpts(width="100%", height="400px", theme=ThemeType.LIGHT)) - .add_schema(schema=schema) - .add( - "信号特征", - [[sent_val, conf_val, int_val, gap_val, time_val]], - color="#f97316", - areastyle_opts=opts.AreaStyleOpts(opacity=0.3, color="#fb923c"), - ) - .set_global_opts( - title_opts=opts.TitleOpts(title=title, pos_left="center"), - legend_opts=opts.LegendOpts(is_show=False), - ) - ) - return radar - - @staticmethod - def generate_transmission_graph(nodes_data: List[Dict[str, str]], title: str = "投资逻辑传导链条") -> Graph: - """生成逻辑传导拓扑图 (支持分支结构)""" - nodes = [] - links = [] - - # Helper for text wrapping - def wrap_text(text, width=6): - return '\n'.join([text[i:i+width] for i in range(0, len(text), width)]) - - # Map original names to wrapped names to handle links - name_map = {} - - for i, item in enumerate(nodes_data): - # 节点样式 - color = "#ef4444" if "利空" in item.get("impact_type", "") else "#22c55e" - if "中性" in item.get("impact_type", ""): color = "#6b7280" - - original_name = item.get("node_name", f"节点{i}") - wrapped_name = wrap_text(original_name) - name_map[original_name] = wrapped_name - name_map[str(item.get("id", ""))] = wrapped_name # Map ID if present - - nodes.append({ - "name": wrapped_name, - "symbolSize": 60 if i == 0 else 50, - "value": item.get("logic", ""), - "itemStyle": {"color": color}, - # Improve label readability - "label": {"show": True, "formatter": "{b}"} - }) - - # Logic for Links - source_key = item.get("source") or item.get("parent") or item.get("parent_id") - if source_key: - # Branching logic: Link from specified source - # Source needs to be resolved to its (wrapped) name - target_source_name = name_map.get(source_key) - if not target_source_name and source_key in name_map.values(): - target_source_name = source_key # It was already a mapped name? 
- - # If we found the source in our map (meaning it appeared before this node) - if target_source_name: - links.append({"source": target_source_name, "target": wrapped_name}) - elif i > 0: - # Fallback: Linear chain - links.append({"source": nodes[i-1]["name"], "target": wrapped_name}) - - graph = ( - Graph(init_opts=opts.InitOpts(width="100%", height="400px", theme=ThemeType.LIGHT)) - .add( - "", - nodes, - links, - repulsion=5000, - layout="force", - is_roam=True, - is_draggable=True, - symbol="circle", - edge_symbol=['circle', 'arrow'], # Add arrows - edge_symbol_size=[4, 10], - linestyle_opts=opts.LineStyleOpts(width=2, curve=0.2, opacity=0.9), - label_opts=opts.LabelOpts(is_show=True, position="inside", color="white", font_size=10), - edge_label=opts.LabelOpts(is_show=False), - ) - .set_global_opts( - title_opts=opts.TitleOpts(title=title, pos_left="center"), - tooltip_opts=opts.TooltipOpts(formatter="{b}: {c}") - ) - ) - return graph - - @staticmethod - def render_drawio_to_html(xml_content: str, filename: str, title: str = "Logic Diagram") -> str: - """ - 将 Draw.io XML 渲染为包含 Viewer 的 HTML 文件 - """ - import json - - # 构造配置字典 - config = { - "highlight": "#0000ff", - "nav": True, - "resize": True, - "toolbar": "zoom", - "xml": xml_content - } - - # 1. 转为 JSON 字符串 (自动处理内部的引号转义、换行符转义等) - json_str = json.dumps(config) - - # 2. 转为 HTML 属性安全的字符串 (主要是转义单引号,因为我们在 HTML 中用单引号包裹) - import html - safe_json_str = html.escape(json_str, quote=True) - - html_template = f""" - - - - - {title} - - - -

-        [HTML template body lost in extraction — it rendered "{title}" as the page title and heading, and embedded safe_json_str (the escaped Draw.io viewer config) in the markup of the removed drawio viewer page]
- - - - """ - - try: - os.makedirs(os.path.dirname(filename), exist_ok=True) - # Use 'w' mode with utf-8 encoding - with open(filename, 'w', encoding='utf-8') as f: - f.write(html_template) - logger.info(f"✅ Draw.io chart rendered to {filename}") - return filename - except Exception as e: - logger.error(f"Failed to render drawio chart: {e}") - return "" - - @staticmethod - def render_chart_to_file(chart: Any, filename: str) -> str: - """渲染并保存 HTML""" - try: - # 确保目录存在 - os.makedirs(os.path.dirname(filename), exist_ok=True) - chart.render(filename) - logger.info(f"✅ Chart rendered to {filename}") - return filename - except Exception as e: - logger.error(f"Failed to render chart: {e}") - return "" diff --git a/skills/alphaear-logic-visualizer/scripts/visualizer_prompt.py b/skills/alphaear-logic-visualizer/scripts/visualizer_prompt.py deleted file mode 100644 index f0b2933..0000000 --- a/skills/alphaear-logic-visualizer/scripts/visualizer_prompt.py +++ /dev/null @@ -1,47 +0,0 @@ -def get_drawio_system_prompt(): - return """You are an expert at creating Draw.io (MxGraph) diagrams in XML format. -Your task is to generate a valid MXGraphModel XML based on the user's description. - -### Rules: -1. Output ONLY the XML code. Start with and end with . -2. Do not use compressed XML. Use plain XML. -3. Use standard shapes: 'rounded=1;whiteSpace=wrap;html=1;' for boxes. -4. Auto-layout Strategy: - - Identify "layers" or "stages" in the logic. - - Assign X coordinates based on layers (e.g., 0, 200, 400). - - Assign Y coordinates to distribute nodes vertically (e.g., 0, 100, 200). - - Ensure nodes do not overlap. -5. Edges: Connect nodes logically using . - -### Template: - - - - - - - - - - - - - - - - -""" - -def get_drawio_task(nodes_data: list, title: str) -> str: - import json - nodes_json = json.dumps(nodes_data, ensure_ascii=False, indent=2) - return f"""Please generate a Draw.io XML diagram for the following logic flow: - -**Title**: {title} - -**Nodes and Logic**: -{nodes_json} - -Ensure the layout flows logically from Left to Right (or Top to Bottom for hierarchies). -Use different colors for 'Positive' (Greenish), 'Negative' (Reddish), and 'Neutral' (Grey/Blue) impacts if described. -""" diff --git a/skills/alphaear-logic-visualizer/tests/test_visualizer.py b/skills/alphaear-logic-visualizer/tests/test_visualizer.py deleted file mode 100644 index 9b9731d..0000000 --- a/skills/alphaear-logic-visualizer/tests/test_visualizer.py +++ /dev/null @@ -1,21 +0,0 @@ -import sys -import os -import unittest - -# Add skill root to path -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) - -try: - from scripts.visualizer import VisualizerTools -except ImportError as e: - print(f"Import Error: {e}") - sys.exit(1) - -class TestLogicViz(unittest.TestCase): - def test_init(self): - print("Testing VisualizerTools Iteration...") - viz = VisualizerTools() - self.assertIsNotNone(viz) - -if __name__ == '__main__': - unittest.main() diff --git a/skills/alphaear-news/SKILL.md b/skills/alphaear-news/SKILL.md deleted file mode 100644 index e4130b4..0000000 --- a/skills/alphaear-news/SKILL.md +++ /dev/null @@ -1,33 +0,0 @@ ---- -name: alphaear-news -description: Fetch hot finance news, unified trends, and prediction financial market data. Use when the user needs real-time financial news, trend reports from multiple finance sources (Weibo, Zhihu, WallstreetCN, etc.), or Polymarket finance market prediction data. 
---- - -# AlphaEar News Skill - -## Overview - -Fetch real-time hot news, generate unified trend reports, and retrieve Polymarket prediction data. - -## Capabilities - -### 1. Fetch Hot News & Trends - -Use `scripts/news_tools.py` via `NewsNowTools`. - -- **Fetch News**: `fetch_hot_news(source_id, count)` - - See [sources.md](references/sources.md) for valid `source_id`s (e.g., `cls`, `weibo`). -- **Unified Report**: `get_unified_trends(sources)` - - Aggregates top news from multiple sources. - -### 2. Fetch Prediction Markets - -Use `scripts/news_tools.py` via `PolymarketTools`. - -- **Market Summary**: `get_market_summary(limit)` - - Returns a formatted report of active prediction markets. - -## Dependencies - -- `requests`, `loguru` -- `scripts/database_manager.py` (Local DB) diff --git a/skills/alphaear-news/references/sources.md b/skills/alphaear-news/references/sources.md deleted file mode 100644 index d2c2677..0000000 --- a/skills/alphaear-news/references/sources.md +++ /dev/null @@ -1,26 +0,0 @@ -# News Sources Reference - -## Supported News Sources - -| Source ID | Name | Category | Description | -|:----------|:-----|:---------|:------------| -| `cls` | 财联社 | Finance | Real-time financial news, focus on A-shares and macro. | -| `wallstreetcn` | 华尔街见闻 | Finance | Global markets, macroeconomics, and detailed analysis. | -| `xueqiu` | 雪球热榜 | Finance | Community-driven stock discussions and hot topics. | -| `weibo` | 微博热搜 | General | Trending social topics, good for public sentiment. | -| `zhihu` | 知乎热榜 | General | In-depth discussions and Q&A on trending topics. | -| `baidu` | 百度热搜 | General | General public search trends. | -| `toutiao` | 今日头条 | General | Algorithmic news recommendations. | -| `douyin` | 抖音热榜 | General | Short video trends (titles only). | -| `thepaper` | 澎湃新闻 | General | Serious journalism and current affairs. | -| `36kr` | 36氪 | Tech | Startup, venture capital, and tech industry news. | -| `ithome` | IT之家 | Tech | Consumer electronics and tech gadgets. | -| `v2ex` | V2EX | Tech | Developer community trends. | -| `juejin` | 掘金 | Tech | Developer blogs and tutorials. | -| `hackernews` | Hacker News | Tech | Global tech and startup news (English). | - -## Polymarket - -- **Base URL**: `https://gamma-api.polymarket.com` -- **Data**: Prediction markets (e.g., "Will Fed cut rates?"). -- **Usage**: Use `get_active_markets` to retrieve top active markets by volume. 
diff --git a/skills/alphaear-news/scripts/__init__.py b/skills/alphaear-news/scripts/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/alphaear-news/scripts/content_extractor.py b/skills/alphaear-news/scripts/content_extractor.py deleted file mode 100644 index 133207a..0000000 --- a/skills/alphaear-news/scripts/content_extractor.py +++ /dev/null @@ -1,122 +0,0 @@ -import requests -from requests.exceptions import RequestException, Timeout, ConnectionError -import os -import time -import json -import threading -from typing import Optional -from loguru import logger - - -class ContentExtractor: - """内容提取工具 - 主要接入 Jina Reader API""" - - JINA_BASE_URL = "https://r.jina.ai/" - - # 速率限制配置 (无 API Key 时:20 次/分钟) - _rate_limit_no_key = 20 # 每分钟最大请求数 - _rate_window = 60.0 # 时间窗口(秒) - _min_interval = 3.0 # 请求最小间隔(秒) - - # 类级别的速率限制状态 - _request_times = [] - _last_request_time = 0.0 - _lock = threading.Lock() - - @classmethod - def _wait_for_rate_limit(cls, has_api_key: bool) -> None: - """等待以满足速率限制要求""" - if has_api_key: - # 有 API Key 时,只需保持最小间隔 - time.sleep(0.5) - return - - with cls._lock: - current_time = time.time() - - # 1. 清理过期的请求记录 - cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window] - - # 2. 检查是否达到速率限制 - if len(cls._request_times) >= cls._rate_limit_no_key: - # 需要等待最旧的请求过期 - oldest = cls._request_times[0] - wait_time = cls._rate_window - (current_time - oldest) + 1.0 - if wait_time > 0: - logger.warning(f"⏳ Jina rate limit reached, waiting {wait_time:.1f}s...") - time.sleep(wait_time) - current_time = time.time() - cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window] - - # 3. 确保请求间隔不太快 - time_since_last = current_time - cls._last_request_time - if time_since_last < cls._min_interval: - sleep_time = cls._min_interval - time_since_last - time.sleep(sleep_time) - - # 4. 
记录本次请求 - cls._request_times.append(time.time()) - cls._last_request_time = time.time() - - @classmethod - def extract_with_jina(cls, url: str, timeout: int = 30) -> Optional[str]: - """ - 使用 Jina Reader 提取网页正文内容 (Markdown 格式) - - 无 API Key 时自动限速:每分钟最多 20 次请求,每次间隔至少 3 秒 - """ - if not url or not url.startswith("http"): - return None - - logger.info(f"🕸️ Extracting content from: {url} via Jina...") - - headers = { - "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", - "Accept": "application/json" - } - - # 使用统一的 JINA_API_KEY - api_key = os.getenv("JINA_API_KEY") - has_api_key = bool(api_key and api_key.strip()) - - if has_api_key: - headers["Authorization"] = f"Bearer {api_key}" - - # 等待速率限制 - cls._wait_for_rate_limit(has_api_key) - - try: - # Jina Reader API - full_url = f"{cls.JINA_BASE_URL}{url}" - response = requests.get(full_url, headers=headers, timeout=timeout) - - if response.status_code == 200: - try: - data = response.json() - # Jina JSON 响应格式通常在 data.content - if isinstance(data, dict) and "data" in data: - return data["data"].get("content", "") - return data.get("content", response.text) - except (json.JSONDecodeError, TypeError): - return response.text - elif response.status_code == 429: - # 触发速率限制,等待后重试一次 - logger.warning(f"⚠️ Jina rate limit (429), waiting 60s before retry...") - time.sleep(60) - return cls.extract_with_jina(url, timeout) - else: - logger.warning(f"Jina extraction failed (Status {response.status_code}) for {url}") - return None - - except Timeout: - logger.error(f"Timeout during Jina extraction for {url}") - return None - except ConnectionError: - logger.error(f"Connection error during Jina extraction for {url}") - return None - except RequestException as e: - logger.error(f"Request error during Jina extraction: {e}") - return None - except Exception as e: - logger.error(f"Unexpected error during Jina extraction: {e}") - return None diff --git a/skills/alphaear-news/scripts/database_manager.py b/skills/alphaear-news/scripts/database_manager.py deleted file mode 100644 index f5aa2a7..0000000 --- a/skills/alphaear-news/scripts/database_manager.py +++ /dev/null @@ -1,131 +0,0 @@ -import sqlite3 -import json -from datetime import datetime -from pathlib import Path -from typing import List, Dict, Optional -from loguru import logger - -class DatabaseManager: - """ - AlphaEar News Database Manager - Reduced version for alphaear-news skill - """ - - def __init__(self, db_path: str = "data/signal_flux.db"): - self.db_path = Path(db_path) - self.db_path.parent.mkdir(parents=True, exist_ok=True) - self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False) - self.conn.row_factory = sqlite3.Row - self._init_db() - logger.debug(f"💾 Database initialized at {self.db_path}") - - def _init_db(self): - """Initialize news-related tables only""" - cursor = self.conn.cursor() - - # Daily News Table - cursor.execute(""" - CREATE TABLE IF NOT EXISTS daily_news ( - id TEXT PRIMARY KEY, - source TEXT, - rank INTEGER, - title TEXT, - url TEXT, - content TEXT, - publish_time TEXT, - crawl_time TEXT, - sentiment_score REAL, - analysis TEXT, - meta_data TEXT - ) - """) - - # Indexes - cursor.execute("CREATE INDEX IF NOT EXISTS idx_news_crawl_time ON daily_news(crawl_time)") - cursor.execute("CREATE INDEX IF NOT EXISTS idx_news_source ON daily_news(source)") - - self.conn.commit() - - # --- News Operations --- - - def save_daily_news(self, news_list: List[Dict]) -> int: - """Save hot news 
items""" - cursor = self.conn.cursor() - count = 0 - crawl_time = datetime.now().isoformat() - - for news in news_list: - try: - news_id = news.get('id') or f"{news.get('source')}_{news.get('rank')}_{crawl_time[:10]}" - cursor.execute(""" - INSERT OR REPLACE INTO daily_news - (id, source, rank, title, url, content, publish_time, crawl_time, sentiment_score, meta_data) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - news_id, - news.get('source'), - news.get('rank'), - news.get('title'), - news.get('url'), - news.get('content', ''), - news.get('publish_time'), - crawl_time, - news.get('sentiment_score'), - json.dumps(news.get('meta_data', {})) - )) - count += 1 - except Exception as e: - logger.error(f"Error saving news item {news.get('title')}: {e}") - - self.conn.commit() - return count - - def get_daily_news(self, source: Optional[str] = None, limit: int = 100, days: int = 1) -> List[Dict]: - """Get recent news""" - cursor = self.conn.cursor() - time_threshold = (datetime.now().timestamp() - days * 86400) - time_threshold_str = datetime.fromtimestamp(time_threshold).isoformat() - - query = "SELECT * FROM daily_news WHERE crawl_time >= ?" - params = [time_threshold_str] - - if source: - query += " AND source = ?" - params.append(source) - - query += " ORDER BY crawl_time DESC, rank LIMIT ?" - params.append(limit) - - cursor.execute(query, params) - return [dict(row) for row in cursor.fetchall()] - - def delete_news(self, news_id: str) -> bool: - cursor = self.conn.cursor() - cursor.execute("DELETE FROM daily_news WHERE id = ?", (news_id,)) - self.conn.commit() - return cursor.rowcount > 0 - - def update_news_content(self, news_id: str, content: str = None, analysis: str = None) -> bool: - cursor = self.conn.cursor() - updates = [] - params = [] - - if content is not None: - updates.append("content = ?") - params.append(content) - if analysis is not None: - updates.append("analysis = ?") - params.append(analysis) - - if not updates: - return False - - params.append(news_id) - query = f"UPDATE daily_news SET {', '.join(updates)} WHERE id = ?" 
- cursor.execute(query, params) - self.conn.commit() - return cursor.rowcount > 0 - - def close(self): - if self.conn: - self.conn.close() diff --git a/skills/alphaear-news/scripts/news_tools.py b/skills/alphaear-news/scripts/news_tools.py deleted file mode 100644 index e833e2e..0000000 --- a/skills/alphaear-news/scripts/news_tools.py +++ /dev/null @@ -1,256 +0,0 @@ -import requests -from requests.exceptions import RequestException, Timeout -import json -import time -from datetime import datetime -from typing import List, Dict, Optional -from loguru import logger -from .database_manager import DatabaseManager -from .content_extractor import ContentExtractor - -class NewsNowTools: - """热点新闻获取工具 - 接入 NewsNow API 与 Jina 内容提取""" - - BASE_URL = "https://newsnow.busiyi.world" - SOURCES = { - # 金融类 - "cls": "财联社", - "wallstreetcn": "华尔街见闻", - "xueqiu": "雪球热榜", - # 综合/社交 - "weibo": "微博热搜", - "zhihu": "知乎热榜", - "baidu": "百度热搜", - "toutiao": "今日头条", - "douyin": "抖音热榜", - "thepaper": "澎湃新闻", - # 科技类 - "36kr": "36氪", - "ithome": "IT之家", - "v2ex": "V2EX", - "juejin": "掘金", - "hackernews": "Hacker News", - } - - - def __init__(self, db: DatabaseManager): - self.db = db - self.user_agent = ( - "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " - "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36" - ) - self.extractor = ContentExtractor() - # Simple in-memory cache: source_id -> {"time": timestamp, "data": []} - self._cache = {} - - def fetch_hot_news(self, source_id: str, count: int = 15, fetch_content: bool = False) -> List[Dict]: - """ - 从指定新闻源获取热点新闻列表(支持5分钟缓存)。 - """ - # 1. Check cache validity (5 minutes) - cache_key = f"{source_id}_{count}" - cached = self._cache.get(cache_key) - now = time.time() - - if cached and (now - cached["time"] < 300): - logger.info(f"⚡ Using cached news for {source_id} (Age: {int(now - cached['time'])}s)") - return cached["data"] - - try: - url = f"{self.BASE_URL}/api/s?id={source_id}" - response = requests.get(url, headers={"User-Agent": self.user_agent}, timeout=30) - if response.status_code == 200: - data = response.json() - items = data.get("items", [])[:count] - processed_items = [] - for i, item in enumerate(items, 1): - item_url = item.get("url", "") - content = "" - if fetch_content and item_url: - content = self.extractor.extract_with_jina(item_url) or "" - - processed_items.append({ - "id": item.get("id") or f"{source_id}_{int(time.time())}_{i}", - "source": source_id, - "rank": i, - "title": item.get("title", ""), - "url": item_url, - "content": content, - "publish_time": item.get("publish_time"), - "meta_data": item.get("extra", {}) - }) - - # Update Cache - self._cache[cache_key] = {"time": now, "data": processed_items} - logger.info(f"✅ Fetched and cached news for {source_id}") - - self.db.save_daily_news(processed_items) - return processed_items - else: - logger.error(f"NewsNow API Error: {response.status_code}") - # Fallback to stale cache if available - if cached: - logger.warning(f"⚠️ API failed, using stale cache for {source_id}") - return cached["data"] - return [] - except Timeout: - logger.error(f"Timeout fetching hot news from {source_id}") - if cached: - logger.warning(f"⚠️ Timeout, using stale cache for {source_id}") - return cached["data"] - return [] - except RequestException as e: - logger.error(f"Network error fetching hot news from {source_id}: {e}") - if cached: - logger.warning(f"⚠️ Network check failed, using stale cache for {source_id}") - return cached["data"] - return [] - except json.JSONDecodeError: - 
logger.error(f"Failed to parse JSON response from NewsNow for {source_id}") - return [] - except Exception as e: - logger.error(f"Unexpected error fetching hot news from {source_id}: {e}") - return [] - - def fetch_news_content(self, url: str) -> Optional[str]: - """ - 使用 Jina Reader 抓取指定 URL 的网页正文内容。 - - Args: - url: 需要抓取内容的完整网页 URL,必须以 http:// 或 https:// 开头。 - - Returns: - 提取的网页正文内容 (Markdown 格式),如果失败则返回 None。 - """ - return self.extractor.extract_with_jina(url) - - def get_unified_trends(self, sources: Optional[List[str]] = None) -> str: - """ - 获取多平台综合热点报告,自动聚合多个新闻源的热门内容。 - - Args: - sources: 要扫描的新闻源列表。可选值按类别: - **金融类**: "cls", "wallstreetcn", "xueqiu" - **综合类**: "weibo", "zhihu", "baidu", "toutiao", "douyin", "thepaper" - **科技类**: "36kr", "ithome", "v2ex", "juejin", "hackernews" - - Returns: - 格式化的 Markdown 热点汇总报告,包含各平台 Top 10 热点标题和链接。 - """ - sources = sources or ["weibo", "zhihu", "wallstreetcn"] - all_news = [] - for src in sources: - all_news.extend(self.fetch_hot_news(src)) - time.sleep(0.2) - - if not all_news: - return "❌ 未能获取到热点数据" - - report = f"# 实时全网热点汇总 ({datetime.now().strftime('%Y-%m-%d %H:%M')})\n\n" - for src in sources: - - src_name = self.SOURCES.get(src, src) - report += f"### 🔥 {src_name}\n" - src_news = [n for n in all_news if n['source'] == src] - for n in src_news[:10]: - report += f"- {n['title']} ([链接]({n['url']}))\n" - report += "\n" - - return report - - -class PolymarketTools: - """Polymarket 预测市场数据工具 - 获取热门预测市场反映公众情绪和预期""" - - BASE_URL = "https://gamma-api.polymarket.com" - - def __init__(self, db: DatabaseManager): - self.db = db - self.user_agent = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36" - - def get_active_markets(self, limit: int = 20) -> List[Dict]: - """ - 获取活跃的预测市场,用于分析公众情绪和预期。 - - 预测市场数据可以反映: - - 公众对重大事件的预期概率 - - 市场情绪和风险偏好 - - 热门话题的关注度 - - Args: - limit: 获取的市场数量,默认 20 个。 - - Returns: - 包含预测市场信息的列表,每个市场包含: - - question: 预测问题 - - outcomes: 可能的结果 - - outcomePrices: 各结果的概率价格 - - volume: 交易量 - """ - try: - response = requests.get( - f"{self.BASE_URL}/markets", - params={"active": "true", "closed": "false", "limit": limit}, - headers={"User-Agent": self.user_agent, "Accept": "application/json"}, - timeout=30 - ) - - if response.status_code == 200: - markets = response.json() - result = [] - for m in markets: - result.append({ - "id": m.get("id"), - "question": m.get("question"), - "slug": m.get("slug"), - "outcomes": m.get("outcomes"), - "outcomePrices": m.get("outcomePrices"), - "volume": m.get("volume"), - "liquidity": m.get("liquidity"), - }) - logger.info(f"✅ 获取 {len(result)} 个预测市场") - return result - else: - logger.warning(f"⚠️ Polymarket API 返回 {response.status_code}") - return [] - except Timeout: - logger.error("Timeout fetching Polymarket markets") - return [] - except RequestException as e: - logger.error(f"Network error fetching Polymarket markets: {e}") - return [] - except json.JSONDecodeError: - logger.error("Failed to parse JSON response from Polymarket") - return [] - except Exception as e: - logger.error(f"Unexpected error fetching Polymarket markets: {e}") - return [] - - def get_market_summary(self, limit: int = 10) -> str: - """ - 获取预测市场摘要报告,用于了解当前热门话题和公众预期。 - - Args: - limit: 获取的市场数量 - - Returns: - 格式化的预测市场报告 - """ - markets = self.get_active_markets(limit) - if not markets: - return "❌ 无法获取 Polymarket 数据" - - report = f"# 🔮 Polymarket 热门预测 ({datetime.now().strftime('%Y-%m-%d %H:%M')})\n\n" - for i, m in enumerate(markets, 1): - question = m.get("question", "Unknown") - prices = m.get("outcomePrices", []) 
- volume = m.get("volume", 0) - - report += f"**{i}. {question}**\n" - if prices: - report += f" 概率: {prices}\n" - if volume: - report += f" 交易量: ${float(volume):,.0f}\n" - report += "\n" - - return report diff --git a/skills/alphaear-news/tests/test_news.py b/skills/alphaear-news/tests/test_news.py deleted file mode 100644 index 9f5ce1c..0000000 --- a/skills/alphaear-news/tests/test_news.py +++ /dev/null @@ -1,24 +0,0 @@ -import sys -import os -import unittest - -# Add skill root to path -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) - -try: - from scripts.news_tools import NewsNowTools - from scripts.database_manager import DatabaseManager -except ImportError as e: - print(f"Import Error: {e}") - sys.exit(1) - -class TestNews(unittest.TestCase): - def test_init(self): - print("Testing NewsNowTools Iteration...") - db = DatabaseManager(":memory:") - tools = NewsNowTools(db) - self.assertIsNotNone(tools) - print("NewsNowTools Initialized.") - -if __name__ == '__main__': - unittest.main() diff --git a/skills/alphaear-predictor/SKILL.md b/skills/alphaear-predictor/SKILL.md deleted file mode 100644 index 95aabf7..0000000 --- a/skills/alphaear-predictor/SKILL.md +++ /dev/null @@ -1,60 +0,0 @@ ---- -name: alphaear-predictor -description: Market prediction skill using Kronos. Use when user needs finance market time-series forecasting or news-aware finance market adjustments. ---- - -# AlphaEar Predictor Skill - -## Overview - -This skill utilizes the Kronos model (via `KronosPredictorUtility`) to perform time-series forecasting and adjust predictions based on news sentiment. - -## Capabilities - -### 1. Forecast Market Trends - -### 1. Forecast Market Trends - -**Workflow:** -1. **Generate Base Forecast**: Use `scripts/kronos_predictor.py` (via `KronosPredictorUtility`) to generate the technical/quantitative forecast. -2. **Adjust Forecast (Agentic)**: Use the **Forecast Adjustment Prompt** in `references/PROMPTS.md` to subjectively adjust the numbers based on latest news/logic. - -**Key Tools:** -- `KronosPredictorUtility.get_base_forecast(df, lookback, pred_len, news_text)`: Returns `List[KLinePoint]`. - -**Example Usage (Python):** - -```python -from scripts.utils.kronos_predictor import KronosPredictorUtility -from scripts.utils.database_manager import DatabaseManager - -db = DatabaseManager() -predictor = KronosPredictorUtility() - -# Forecast -forecast = predictor.predict("600519", horizon="7d") -print(forecast) -``` - - -## Configuration - -This skill requires the **Kronos** model and an embedding model. - -1. **Kronos Model**: - - Ensure `exports/models` directory exists in the project root. - - Place trained news projector weights (e.g., `kronos_news_v1.pt`) in `exports/models/`. - - Or depend on the base model (automatically downloaded). - -2. **Environment Variables**: - - `EMBEDDING_MODEL`: Path or name of the embedding model (default: `sentence-transformers/all-MiniLM-L6-v2`). - - `KRONOS_MODEL_PATH`: Optional path to override model loading. - -## Dependencies - -- `torch` -- `transformers` -- `sentence-transformers` -- `pandas` -- `numpy` -- `scikit-learn` diff --git a/skills/alphaear-predictor/references/PROMPTS.md b/skills/alphaear-predictor/references/PROMPTS.md deleted file mode 100644 index 02fe9c5..0000000 --- a/skills/alphaear-predictor/references/PROMPTS.md +++ /dev/null @@ -1,43 +0,0 @@ -# AlphaEar Predictor Prompts - -## Forecast Adjustment (Analyst) - -**Prompt:** - -```markdown -You are a senior quantitative strategy analyst. 
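The removed SKILL.md lists `KronosPredictorUtility.get_base_forecast(df, lookback, pred_len, news_text)` as the key tool, yet its example calls `predictor.predict("600519", horizon="7d")`, which does not match that signature. A usage sketch against the documented interface; the import path follows the file layout in this diff, the OHLCV columns follow the predictor code, and the synthetic DataFrame and news text are placeholders:

```python
import pandas as pd

from scripts.kronos_predictor import KronosPredictorUtility  # path assumed from this diff

# Historical OHLCV frame; in the real skill this comes from StockTools.get_stock_price.
df = pd.DataFrame({
    "date": pd.bdate_range("2025-01-01", periods=30),
    "open": 100.0, "high": 101.0, "low": 99.0, "close": 100.5, "volume": 1_000_000,
})

predictor = KronosPredictorUtility()
forecast = predictor.get_base_forecast(
    df,
    lookback=20,
    pred_len=5,
    news_text="Q4 earnings beat expectations; guidance raised.",  # optional news context
)
for point in forecast:          # List[KLinePoint]
    print(point.date, point.close)
```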
-Your task is to subjectively/logically adjust the given [Kronos Model Forecast] based on the [Latest Intelligence/News Context]. - -Ticker: {ticker} - -【Kronos Base Forecast (OHLC)】: -{forecast_str} - -【Latest Intelligence Context】: -{news_context} - -**Adjustment Principles:** -1. Base forecast is technical-only. -2. Context may contain a "Quantitative Correction" from a news-aware model. **Highly respect** this unless logic is flawed. -3. Use qualitative analysis (news logic) to verify or fine-tune. -4. If no quantitative correction exists, verify trend manually against news sentiment. - -**Output (Strict JSON):** -```json -{ - "adjusted_forecast": [ - { - "date": "YYYY-MM-DD", - "open": , - "high": , - "low": , - "close": , - "volume": - }, - ... - ], - "rationale": "Detailed logic..." -} -``` -Ensure same number of data points as base forecast. -``` diff --git a/skills/alphaear-predictor/scripts/__init__.py b/skills/alphaear-predictor/scripts/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/alphaear-predictor/scripts/forecast_agent.py b/skills/alphaear-predictor/scripts/forecast_agent.py deleted file mode 100644 index 4bbf67e..0000000 --- a/skills/alphaear-predictor/scripts/forecast_agent.py +++ /dev/null @@ -1,76 +0,0 @@ -import json -from typing import List, Optional, Dict, Any -from datetime import datetime -from loguru import logger -import pandas as pd - -from .kronos_predictor import KronosPredictorUtility -from .utils.database_manager import DatabaseManager -from .schema.models import ForecastResult, KLinePoint, InvestmentSignal - -class ForecastUtils: - """ - 预测辅助工具 (ForecastUtils) - 提供数据准备、基础模型预测等功能。 - LLM 调整逻辑已移交 Agent 执行 (参考 scripts/prompts/PROMPTS.md)。 - """ - - def __init__(self, db: DatabaseManager): - self.db = db - self.predictor_util = KronosPredictorUtility() # Singleton - - def get_base_forecast( - self, - ticker: str, - signals: List[Dict] = None, - lookback: int = 20, - pred_len: int = 5, - ) -> Optional[List[KLinePoint]]: - """ - 获取基础预测数据 (技术面 + 新闻模型定量修正)。 - Agent 应随后使用 PROMPTS.md 中的指令进行定性调整。 - """ - logger.info(f"🔮 Generating base forecast for {ticker}...") - - # 1. 获取历史数据 - from .stock_tools import StockTools - stock_tools = StockTools(self.db, auto_update=False) - - end_date = datetime.now().strftime("%Y-%m-%d") - # 宽放一点时间以确保有足够的交易日 - start_date = (datetime.now() - pd.Timedelta(days=max(lookback * 4, 90))).strftime("%Y-%m-%d") - df = stock_tools.get_stock_price(ticker, start_date=start_date, end_date=end_date) - - if df.empty or len(df) < lookback: - # Try force sync - df = stock_tools.get_stock_price(ticker, start_date=start_date, end_date=end_date, force_sync=True) - - if df.empty: - logger.warning(f"⚠️ No history data for {ticker}") - return None - - effective_lookback = lookback - if len(df) < lookback: - if len(df) < 10: - logger.warning(f"⚠️ Insufficient history for {ticker}") - return None - effective_lookback = len(df) - - # 2. 准备信号上下文 - signal_lines = [] - for s in (signals or []): - try: - title = s.get('title', '') if isinstance(s, dict) else getattr(s, 'title', '') - summary = s.get('summary', '') if isinstance(s, dict) else getattr(s, 'summary', '') - if title or summary: - signal_lines.append(f"- {title}: {summary}") - except Exception: - continue - - signals_context = "\n".join(signal_lines).strip() - - # 3. 
模型预测 (News-Adjusted if context exists) - if signals_context: - return self.predictor_util.get_base_forecast(df, lookback=effective_lookback, pred_len=pred_len, news_text=signals_context) - else: - return self.predictor_util.get_base_forecast(df, lookback=effective_lookback, pred_len=pred_len, news_text=None) diff --git a/skills/alphaear-predictor/scripts/json_utils.py b/skills/alphaear-predictor/scripts/json_utils.py deleted file mode 100644 index c29aab2..0000000 --- a/skills/alphaear-predictor/scripts/json_utils.py +++ /dev/null @@ -1,180 +0,0 @@ -import ast -import json -import re -from typing import Optional, Any -from loguru import logger - -def _strip_comments(text: str) -> str: - """ - Safely remove C-style comments (// and /* */) from JSON-like text, - preserving strings (including URLs like http://). - """ - result = [] - i = 0 - n = len(text) - in_string = False - escape = False - - while i < n: - char = text[i] - - if in_string: - if char == '\\': - escape = not escape - elif char == '"' and not escape: - in_string = False - else: - escape = False - result.append(char) - i += 1 - continue - - # Not in string - if char == '"': - in_string = True - result.append(char) - i += 1 - continue - - # Check for // comment - if i + 1 < n and text[i:i+2] == '//': - i += 2 - while i < n and text[i] != '\n': - i += 1 - continue - - # Check for /* comment - if i + 1 < n and text[i:i+2] == '/*': - i += 2 - while i + 1 < n and text[i:i+2] != '*/': - i += 1 - i += 2 - continue - - result.append(char) - i += 1 - - return ''.join(result) - -def extract_json(text: str) -> Optional[Any]: - """ - 更加鲁棒的 JSON 提取工具。 - 处理: - 1. Markdown 代码块 (```json ... ```) - 2. 首尾多余字符 - 3. 同一个文本中多个 JSON 对象 (仅提取第一个) - 4. 简单的 JSON 修复 (末尾逗号等) - 5. C 风格注释 (// 和 /* */) - """ - if not text: - return None - - # 1. 清理明显的 Markdown 包装 - text = text.strip() - - # 先尝试精确匹配 ```json ... ``` 或 ```...``` - md_match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', text, re.DOTALL) - if md_match: - text = md_match.group(1).strip() - elif text.startswith("```"): - # 回退:如果开头有 ``` 但没完整匹配 - text = re.sub(r'^```[a-z]*\n?', '', text) - text = re.sub(r'\n?```\s*$', '', text) - - # 2. 寻找第一个 JSON 起始符 { 或 [ - start_brace = text.find('{') - start_bracket = text.find('[') - - if start_brace == -1 and start_bracket == -1: - return None - - start_idx = start_brace if (start_bracket == -1 or (start_brace != -1 and start_brace < start_bracket)) else start_bracket - - # 2.5 预处理:修复一些极其常见的 LLM 错误 - potential_json = text[start_idx:].strip() - - # remove comments safely - potential_json = _strip_comments(potential_json) - - # b. 修复缺失开头引号的键: nodes": [ -> "nodes": [ - # 匹配模式: (空白或换行) 单词 紧跟引号和冒号 - potential_json = re.sub(r'([\{\,]\s*)([a-zA-Z_]\w*)\"\s*:', r'\1"\2":', potential_json) - - # c. 修复缺失末尾引号的键: "nodes: [ -> "nodes": [ - potential_json = re.sub(r'([\{\,]\s*)\"([a-zA-Z_]\w*)\s*:', r'\1"\2":', potential_json) - - # d. 修复完全缺失引号的键: nodes: [ -> "nodes": [ - # 注意避免匹配到像 http:// 这种内容,所以限定在 { 或 , 之后 - potential_json = re.sub(r'([\{\,]\s*)([a-zA-Z_]\w*)\s*:', r'\1"\2":', potential_json) - - # 3. 使用 raw_decode 尝试解析 - decoder = json.JSONDecoder() - - # 首先尝试直接解析(不做任何预处理) - try: - obj = json.loads(potential_json) - return obj - except json.JSONDecodeError: - pass - - # 简单预处理:移除对象/列表末位多余逗号 - processed_json = re.sub(r',\s*([\]}])', r'\1', potential_json) - - try: - obj, end_pos = decoder.raw_decode(processed_json) - return obj - except json.JSONDecodeError: - pass - - # e. 
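The recovery steps in the removed `extract_json` (fence stripping, comment removal, key re-quoting, trailing-comma cleanup) are easiest to see on a concrete malformed reply. A small usage example; the import path is assumed from this diff's layout and the input string is made up:

```python
from scripts.json_utils import extract_json  # path assumed from this diff

fence = "`" * 3
messy_reply = (
    "Here is the result:\n"
    f"{fence}json\n"
    "{\n"
    "  // model-added comment\n"
    '  nodes": [1, 2, 3],\n'
    '  "rationale": "trend looks intact",\n'
    "}\n"
    f"{fence}"
)

parsed = extract_json(messy_reply)
print(parsed)  # expected: {'nodes': [1, 2, 3], 'rationale': 'trend looks intact'}
```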
修复未终止的字符串字面量问题:移除值中的实际换行符 - # LLM 可能在字符串值中生成包含真实 newline 的内容,导致 JSON 非法 - def fix_multiline_strings(s): - # 简单策略:将字符串值内的换行替换为空格 - lines = s.split('\n') - result = [] - in_string = False - for line in lines: - # 计算未转义的引号数 - quote_count = line.count('"') - line.count('\\"') - if in_string: - result[-1] += ' ' + line.strip() - else: - result.append(line) - - if quote_count % 2 == 1: - in_string = not in_string - return '\n'.join(result) - - fixed_json = fix_multiline_strings(processed_json) - - try: - obj, end_pos = decoder.raw_decode(fixed_json) - return obj - except json.JSONDecodeError: - try: - # 4. 尝试处理单引号问题 (JSON 规范要求双引号,但 LLM 常输出单引号) - # 这是一个简单的替换技巧,仅针对像 {'key': 'value'} 这样的结构 - # 注意:这可能会破坏包含单引号的字符串值,所以作为较后的回退 - fix_quotes = re.sub(r"'(.*?)':", r'"\1":', processed_json) # 修复键 - fix_quotes = re.sub(r":\s*'(.*?)'", r': "\1"', fix_quotes) # 修复简单值 - obj, end_pos = decoder.raw_decode(fix_quotes) - return obj - except (json.JSONDecodeError, TypeError): - try: - # 5. 使用 ast.literal_eval 作为终极回退 (处理 Python 字典格式) - # 提取第一个匹配的括号对内容 - # 寻找匹配的 { } - stack = [] - for i, char in enumerate(potential_json): - if char == '{': stack.append('{') - elif char == '}': - if stack: stack.pop() - if not stack: - content = potential_json[:i+1] - return ast.literal_eval(content) - except (ValueError, SyntaxError, MemoryError) as e: - logger.warning(f"All JSON extraction attempts failed: {e}") - except Exception as e: - logger.error(f"Unexpected error during JSON extraction: {e}") - - return None diff --git a/skills/alphaear-predictor/scripts/kronos_predictor.py b/skills/alphaear-predictor/scripts/kronos_predictor.py deleted file mode 100644 index b46ee6e..0000000 --- a/skills/alphaear-predictor/scripts/kronos_predictor.py +++ /dev/null @@ -1,218 +0,0 @@ -import torch -import pandas as pd -import numpy as np -from datetime import datetime -from typing import List, Optional -from loguru import logger -from pandas.tseries.offsets import BusinessDay -import os -import sys - -KRONOS_DIR = os.path.join(os.path.dirname(__file__), "predictor") -if KRONOS_DIR not in sys.path: - sys.path.append(KRONOS_DIR) - -from skills._env_loader import load_unified_env - -load_unified_env() - -import glob -from sentence_transformers import SentenceTransformer - -from .predictor.model import Kronos, KronosTokenizer, KronosPredictor -from .schema.models import KLinePoint - - -class KronosPredictorUtility: - """ - Kronos 时序预测工具类 - 负责模型加载、推理以及数据结构转换 - """ - - _instance = None - _predictor = None - - def __new__(cls, *args, **kwargs): - if not cls._instance: - cls._instance = super(KronosPredictorUtility, cls).__new__(cls) - return cls._instance - - def __init__(self, device: Optional[str] = None): - if self._predictor is not None: - return - - try: - if not device: - device = ( - "cuda" - if torch.cuda.is_available() - else "mps" - if torch.backends.mps.is_available() - else "cpu" - ) - - logger.info(f"🔮 Loading Kronos Model on {device}...") - - # 1. Load Embedder (SentenceTransformer) - model_name = os.getenv( - "EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2" - ) # Match training - try: - self.embedder = SentenceTransformer( - model_name, device=device, local_files_only=True - ) - except Exception: - logger.warning( - f"⚠️ Local embedder {model_name} not found. Downloading..." - ) - self.embedder = SentenceTransformer(model_name, device=device) - - # 2. 
Load Kronos Base - try: - tokenizer = KronosTokenizer.from_pretrained( - "NeoQuasar/Kronos-Tokenizer-base", local_files_only=True - ) - model = Kronos.from_pretrained( - "NeoQuasar/Kronos-base", local_files_only=True - ) - except Exception: - logger.warning( - "⚠️ Local Kronos cache not found. Attempting to download..." - ) - tokenizer = KronosTokenizer.from_pretrained( - "NeoQuasar/Kronos-Tokenizer-base" - ) - model = Kronos.from_pretrained("NeoQuasar/Kronos-base") - - # 3. Load Trained News Projector Weights - # Check predictor/exports/models directory - models_dir = os.path.join(KRONOS_DIR, "exports/models") - model_files = glob.glob(os.path.join(models_dir, "*.pt")) - - if model_files: - latest_model = max(model_files, key=os.path.getctime) - logger.info(f"🔄 Loading trained news weights from {latest_model}...") - try: - checkpoint = torch.load(latest_model, map_location=device) - # The checkpoint contains 'news_proj_state_dict' - if "news_proj_state_dict" in checkpoint: - if not hasattr(model, "news_proj") or model.news_proj is None: - import torch.nn as nn - - news_dim = checkpoint.get("news_dim", 384) - model.news_proj = nn.Linear(news_dim, model.d_model).to( - device - ) - - model.news_proj.load_state_dict( - checkpoint["news_proj_state_dict"] - ) - logger.success("✅ News-Aware Projection Layer loaded!") - self.has_news_model = True - else: - logger.warning( - "⚠️ Checkpoint found but missing 'news_proj_state_dict'. Using base model." - ) - self.has_news_model = False - except Exception as e: - logger.error( - f"❌ Failed to load trained weights: {e}. Using base model." - ) - self.has_news_model = False - else: - logger.info("ℹ️ No trained news models found. Using base model.") - self.has_news_model = False - - tokenizer = tokenizer.to(device) - model = model.to(device) - - self._predictor = KronosPredictor( - model, tokenizer, device=device, max_context=512 - ) - logger.info("✅ Kronos Model loaded successfully.") - except Exception as e: - logger.error(f"❌ Failed to load Kronos Model: {e}") - self._predictor = None - self.has_news_model = False - - def get_base_forecast( - self, - df: pd.DataFrame, - lookback: int = 20, - pred_len: int = 5, - news_text: Optional[str] = None, - ) -> List[KLinePoint]: - """ - 生成原始模型预测 - """ - if self._predictor is None: - logger.error("Predictor not initialized.") - return [] - - if len(df) < lookback: - logger.warning( - f"Insufficient historical data ({len(df)}) for lookback ({lookback})." 
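The loader above expects a `.pt` checkpoint carrying a `news_proj_state_dict` (and optionally `news_dim`, defaulting to 384). A minimal sketch of producing a checkpoint in that shape so the expected keys are explicit; the dimensions and file name are illustrative only:

```python
import torch
import torch.nn as nn

news_dim, d_model = 384, 512          # illustrative sizes; d_model must match the Kronos model
news_proj = nn.Linear(news_dim, d_model)

checkpoint = {
    "news_proj_state_dict": news_proj.state_dict(),  # key the loader looks for
    "news_dim": news_dim,                            # optional; loader falls back to 384
}
torch.save(checkpoint, "kronos_news_v1_demo.pt")

# Round-trip check, mirroring the loading path in KronosPredictorUtility.__init__
restored = torch.load("kronos_news_v1_demo.pt", map_location="cpu")
news_proj.load_state_dict(restored["news_proj_state_dict"])
```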
- ) - return [] - - # 获取最后 lookback 条数据 - x_df = df.iloc[-lookback:].copy() - x_timestamp = pd.to_datetime(x_df["date"]) # Ensure datetime - last_date = x_timestamp.iloc[-1] - - # 生成未来时间戳 - future_dates = pd.date_range( - start=last_date + BusinessDay(1), periods=pred_len, freq="B" - ) - y_timestamp = pd.Series(future_dates) - - # Embedding News if available - news_emb = None - if ( - news_text - and getattr(self, "has_news_model", False) - and hasattr(self, "embedder") - ): - try: - # Truncate to avoid too long text - emb = self.embedder.encode(news_text[:1000]) - news_emb = emb # KronosPredictor expects numpy array or tensor - except Exception as e: - logger.error(f"Failed to encode news: {e}") - - try: - # 预测所需的列 - cols = ["open", "high", "low", "close", "volume"] - pred_df = self._predictor.predict( - df=x_df[cols], - x_timestamp=x_timestamp, - y_timestamp=y_timestamp, - pred_len=pred_len, - T=1.0, - top_p=0.9, - sample_count=1, - verbose=False, - news_emb=news_emb, - ) - - # 转换为 KLinePoint - results = [] - for date, row in pred_df.iterrows(): - results.append( - KLinePoint( - date=date.strftime("%Y-%m-%d"), - open=float(row["open"]), - high=float(row["high"]), - low=float(row["low"]), - close=float(row["close"]), - volume=float(row["volume"]), - ) - ) - return results - except Exception as e: - logger.error(f"Forecast generation failed: {e}") - return [] - - -# Singleton instance for easy access -# Usage: predictor = KronosPredictorUtility() diff --git a/skills/alphaear-predictor/scripts/predictor/exports/models/kronos_news_v1_20260101_0015.pt b/skills/alphaear-predictor/scripts/predictor/exports/models/kronos_news_v1_20260101_0015.pt deleted file mode 100644 index 097a60b..0000000 Binary files a/skills/alphaear-predictor/scripts/predictor/exports/models/kronos_news_v1_20260101_0015.pt and /dev/null differ diff --git a/skills/alphaear-predictor/scripts/predictor/model/__init__.py b/skills/alphaear-predictor/scripts/predictor/model/__init__.py deleted file mode 100644 index d10e200..0000000 --- a/skills/alphaear-predictor/scripts/predictor/model/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -from .kronos import KronosTokenizer, Kronos, KronosPredictor - -model_dict = { - 'kronos_tokenizer': KronosTokenizer, - 'kronos': Kronos, - 'kronos_predictor': KronosPredictor -} - - -def get_model_class(model_name): - if model_name in model_dict: - return model_dict[model_name] - else: - print(f"Model {model_name} not found in model_dict") - raise NotImplementedError - diff --git a/skills/alphaear-predictor/scripts/predictor/model/kronos.py b/skills/alphaear-predictor/scripts/predictor/model/kronos.py deleted file mode 100644 index cf8bece..0000000 --- a/skills/alphaear-predictor/scripts/predictor/model/kronos.py +++ /dev/null @@ -1,676 +0,0 @@ -import numpy as np -import pandas as pd -import torch -from huggingface_hub import PyTorchModelHubMixin -import sys - -from tqdm import trange - -sys.path.append("../") -from model.module import * - - -class KronosTokenizer(nn.Module, PyTorchModelHubMixin): - """ - KronosTokenizer module for tokenizing input data using a hybrid quantization approach. - - This tokenizer utilizes a combination of encoder and decoder Transformer blocks - along with the Binary Spherical Quantization (BSQuantizer) to compress and decompress input data. - - Args: - d_in (int): Input dimension. - d_model (int): Model dimension. - n_heads (int): Number of attention heads. - ff_dim (int): Feed-forward dimension. - n_enc_layers (int): Number of encoder layers. 
- n_dec_layers (int): Number of decoder layers. - ffn_dropout_p (float): Dropout probability for feed-forward networks. - attn_dropout_p (float): Dropout probability for attention mechanisms. - resid_dropout_p (float): Dropout probability for residual connections. - s1_bits (int): Number of bits for the pre token in BSQuantizer. - s2_bits (int): Number of bits for the post token in BSQuantizer. - beta (float): Beta parameter for BSQuantizer. - gamma0 (float): Gamma0 parameter for BSQuantizer. - gamma (float): Gamma parameter for BSQuantizer. - zeta (float): Zeta parameter for BSQuantizer. - group_size (int): Group size parameter for BSQuantizer. - - """ - - def __init__(self, d_in, d_model, n_heads, ff_dim, n_enc_layers, n_dec_layers, ffn_dropout_p, attn_dropout_p, resid_dropout_p, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size): - - super().__init__() - self.d_in = d_in - self.d_model = d_model - self.n_heads = n_heads - self.ff_dim = ff_dim - self.enc_layers = n_enc_layers - self.dec_layers = n_dec_layers - self.ffn_dropout_p = ffn_dropout_p - self.attn_dropout_p = attn_dropout_p - self.resid_dropout_p = resid_dropout_p - - self.s1_bits = s1_bits - self.s2_bits = s2_bits - self.codebook_dim = s1_bits + s2_bits # Total dimension of the codebook after quantization - self.embed = nn.Linear(self.d_in, self.d_model) - self.head = nn.Linear(self.d_model, self.d_in) - - # Encoder Transformer Blocks - self.encoder = nn.ModuleList([ - TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p) - for _ in range(self.enc_layers - 1) - ]) - # Decoder Transformer Blocks - self.decoder = nn.ModuleList([ - TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p) - for _ in range(self.dec_layers - 1) - ]) - self.quant_embed = nn.Linear(in_features=self.d_model, out_features=self.codebook_dim) # Linear layer before quantization - self.post_quant_embed_pre = nn.Linear(in_features=self.s1_bits, out_features=self.d_model) # Linear layer after quantization (pre part - s1 bits) - self.post_quant_embed = nn.Linear(in_features=self.codebook_dim, out_features=self.d_model) # Linear layer after quantization (full codebook) - self.tokenizer = BSQuantizer(self.s1_bits, self.s2_bits, beta, gamma0, gamma, zeta, group_size) # BSQuantizer module - - def forward(self, x): - """ - Forward pass of the KronosTokenizer. - - Args: - x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in). - - Returns: - tuple: A tuple containing: - - tuple: (z_pre, z) - Reconstructed outputs from decoder with s1_bits and full codebook respectively, - both of shape (batch_size, seq_len, d_in). - - torch.Tensor: bsq_loss - Loss from the BSQuantizer. - - torch.Tensor: quantized - Quantized representation from BSQuantizer. - - torch.Tensor: z_indices - Indices from the BSQuantizer. 
- """ - z = self.embed(x) - - for layer in self.encoder: - z = layer(z) - - z = self.quant_embed(z) # (B, T, codebook) - - bsq_loss, quantized, z_indices = self.tokenizer(z) - - quantized_pre = quantized[:, :, :self.s1_bits] # Extract the first part of quantized representation (s1_bits) - z_pre = self.post_quant_embed_pre(quantized_pre) - - z = self.post_quant_embed(quantized) - - # Decoder layers (for pre part - s1 bits) - for layer in self.decoder: - z_pre = layer(z_pre) - z_pre = self.head(z_pre) - - # Decoder layers (for full codebook) - for layer in self.decoder: - z = layer(z) - z = self.head(z) - - return (z_pre, z), bsq_loss, quantized, z_indices - - def indices_to_bits(self, x, half=False): - """ - Converts indices to bit representations and scales them. - - Args: - x (torch.Tensor): Indices tensor. - half (bool, optional): Whether to process only half of the codebook dimension. Defaults to False. - - Returns: - torch.Tensor: Bit representation tensor. - """ - if half: - x1 = x[0] # Assuming x is a tuple of indices if half is True - x2 = x[1] - mask = 2 ** torch.arange(self.codebook_dim//2, device=x1.device, dtype=torch.long) # Create a mask for bit extraction - x1 = (x1.unsqueeze(-1) & mask) != 0 # Extract bits for the first half - x2 = (x2.unsqueeze(-1) & mask) != 0 # Extract bits for the second half - x = torch.cat([x1, x2], dim=-1) # Concatenate the bit representations - else: - mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long) # Create a mask for bit extraction - x = (x.unsqueeze(-1) & mask) != 0 # Extract bits - - x = x.float() * 2 - 1 # Convert boolean to bipolar (-1, 1) - q_scale = 1. / (self.codebook_dim ** 0.5) # Scaling factor - x = x * q_scale - return x - - def encode(self, x, half=False): - """ - Encodes the input data into quantized indices. - - Args: - x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in). - half (bool, optional): Whether to use half quantization in BSQuantizer. Defaults to False. - - Returns: - torch.Tensor: Quantized indices from BSQuantizer. - """ - z = self.embed(x) - for layer in self.encoder: - z = layer(z) - z = self.quant_embed(z) - - bsq_loss, quantized, z_indices = self.tokenizer(z, half=half, collect_metrics=False) - return z_indices - - def decode(self, x, half=False): - """ - Decodes quantized indices back to the input data space. - - Args: - x (torch.Tensor): Quantized indices tensor. - half (bool, optional): Whether the indices were generated with half quantization. Defaults to False. - - Returns: - torch.Tensor: Reconstructed output tensor of shape (batch_size, seq_len, d_in). - """ - quantized = self.indices_to_bits(x, half) - z = self.post_quant_embed(quantized) - for layer in self.decoder: - z = layer(z) - z = self.head(z) - return z - - -class Kronos(nn.Module, PyTorchModelHubMixin): - """ - Kronos Model. - - Args: - s1_bits (int): Number of bits for pre tokens. - s2_bits (int): Number of bits for post tokens. - n_layers (int): Number of Transformer blocks. - d_model (int): Dimension of the model's embeddings and hidden states. - n_heads (int): Number of attention heads in the MultiheadAttention layers. - ff_dim (int): Dimension of the feedforward network in the Transformer blocks. - ffn_dropout_p (float): Dropout probability for the feedforward network. - attn_dropout_p (float): Dropout probability for the attention layers. - resid_dropout_p (float): Dropout probability for residual connections. - token_dropout_p (float): Dropout probability for token embeddings. 
- learn_te (bool): Whether to use learnable temporal embeddings. - """ - - def __init__(self, s1_bits, s2_bits, n_layers, d_model, n_heads, ff_dim, ffn_dropout_p, attn_dropout_p, resid_dropout_p, token_dropout_p, learn_te, news_dim=None): - super().__init__() - self.s1_bits = s1_bits - self.s2_bits = s2_bits - self.n_layers = n_layers - self.d_model = d_model - self.n_heads = n_heads - self.learn_te = learn_te - self.ff_dim = ff_dim - self.ffn_dropout_p = ffn_dropout_p - self.attn_dropout_p = attn_dropout_p - self.resid_dropout_p = resid_dropout_p - self.token_dropout_p = token_dropout_p - self.news_dim = news_dim - - self.s1_vocab_size = 2 ** self.s1_bits - self.token_drop = nn.Dropout(self.token_dropout_p) - self.embedding = HierarchicalEmbedding(self.s1_bits, self.s2_bits, self.d_model) - self.time_emb = TemporalEmbedding(self.d_model, self.learn_te) - self.transformer = nn.ModuleList([ - TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p) - for _ in range(self.n_layers) - ]) - self.norm = RMSNorm(self.d_model) - self.dep_layer = DependencyAwareLayer(self.d_model) - self.head = DualHead(self.s1_bits, self.s2_bits, self.d_model) - - if self.news_dim is not None: - self.news_proj = nn.Linear(self.news_dim, self.d_model) - else: - self.news_proj = None - - self.apply(self._init_weights) - - def _init_weights(self, module): - - if isinstance(module, nn.Linear): - nn.init.xavier_normal_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - nn.init.normal_(module.weight, mean=0, std=self.embedding.d_model ** -0.5) - elif isinstance(module, nn.LayerNorm): - nn.init.ones_(module.weight) - nn.init.zeros_(module.bias) - elif isinstance(module, RMSNorm): - nn.init.ones_(module.weight) - - def forward(self, s1_ids, s2_ids, stamp=None, padding_mask=None, use_teacher_forcing=False, s1_targets=None, news_emb=None): - """ - Args: - s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len] - s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len] - stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None. - padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None. - use_teacher_forcing (bool, optional): Whether to use teacher forcing for s1 decoding. Defaults to False. - s1_targets (torch.Tensor, optional): Target s1 token IDs for teacher forcing. Shape: [batch_size, seq_len]. Defaults to None. - news_emb (torch.Tensor, optional): News embedding tensor. Shape: [batch_size, news_dim]. Defaults to None. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: - - s1 logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size] - - s2_logits: Logits for s2 token predictions, conditioned on s1. 
Shape: [batch_size, seq_len, s2_vocab_size] - """ - x = self.embedding([s1_ids, s2_ids]) - if stamp is not None: - time_embedding = self.time_emb(stamp) - x = x + time_embedding - x = self.token_drop(x) - - for layer in self.transformer: - x = layer(x, key_padding_mask=padding_mask) - - x = self.norm(x) - - if news_emb is not None and self.news_proj is not None: - news_bias = self.news_proj(news_emb).unsqueeze(1) # [B, 1, d_model] - x = x + news_bias - - s1_logits = self.head(x) - - if use_teacher_forcing: - sibling_embed = self.embedding.emb_s1(s1_targets) - else: - s1_probs = F.softmax(s1_logits.detach(), dim=-1) - sample_s1_ids = torch.multinomial(s1_probs.view(-1, self.s1_vocab_size), 1).view(s1_ids.shape) - sibling_embed = self.embedding.emb_s1(sample_s1_ids) - - x2 = self.dep_layer(x, sibling_embed, key_padding_mask=padding_mask) # Dependency Aware Layer: Condition on s1 embeddings - s2_logits = self.head.cond_forward(x2) - return s1_logits, s2_logits - - def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None, news_emb=None): - """ - Decodes only the s1 tokens. - - This method performs a forward pass to predict only s1 tokens. It returns the s1 logits - and the context representation from the Transformer, which can be used for subsequent s2 decoding. - - Args: - s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len] - s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len] - stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None. - padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None. - news_emb (torch.Tensor, optional): News embedding tensor. Shape: [batch_size, news_dim]. Defaults to None. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: - - s1 logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size] - - context: Context representation from the Transformer. Shape: [batch_size, seq_len, d_model] - """ - x = self.embedding([s1_ids, s2_ids]) - if stamp is not None: - time_embedding = self.time_emb(stamp) - x = x + time_embedding - x = self.token_drop(x) - - for layer in self.transformer: - x = layer(x, key_padding_mask=padding_mask) - - x = self.norm(x) - - if news_emb is not None and self.news_proj is not None: - news_bias = self.news_proj(news_emb).unsqueeze(1) # [B, 1, d_model] - x = x + news_bias - - s1_logits = self.head(x) - return s1_logits, x - - def decode_s2(self, context, s1_ids, padding_mask=None): - """ - Decodes the s2 tokens, conditioned on the context and s1 tokens. - - This method decodes s2 tokens based on a pre-computed context representation (typically from `decode_s1`) - and the s1 token IDs. It uses the dependency-aware layer and the conditional s2 head to predict s2 tokens. - - Args: - context (torch.Tensor): Context representation from the transformer (output of decode_s1). - Shape: [batch_size, seq_len, d_model] - s1_ids (torch.torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len] - padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None. - - Returns: - torch.Tensor: s2 logits. 
Shape: [batch_size, seq_len, s2_vocab_size] - """ - sibling_embed = self.embedding.emb_s1(s1_ids) - x2 = self.dep_layer(context, sibling_embed, key_padding_mask=padding_mask) - return self.head.cond_forward(x2) - - -def top_k_top_p_filtering( - logits, - top_k: int = 0, - top_p: float = 1.0, - filter_value: float = -float("Inf"), - min_tokens_to_keep: int = 1, -): - """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering - Args: - logits: logits distribution shape (batch size, vocabulary size) - if top_k > 0: keep only top k tokens with highest probability (top-k filtering). - if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). - Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) - Make sure we keep at least min_tokens_to_keep per batch example in the output - From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 - """ - if top_k > 0: - top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check - # Remove all tokens with a probability less than the last token of the top-k - indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] - logits[indices_to_remove] = filter_value - return logits - - if top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True) - cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) - - # Remove tokens with cumulative probability above the threshold (token with 0 are kept) - sorted_indices_to_remove = cumulative_probs > top_p - if min_tokens_to_keep > 1: - # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) - sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 - # Shift the indices to the right to keep also the first token above the threshold - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 - - # scatter sorted tensors to original indexing - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) - logits[indices_to_remove] = filter_value - return logits - - -def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_logits=True): - logits = logits / temperature - if top_k is not None or top_p is not None: - if top_k > 0 or top_p < 1.0: - logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) - - probs = F.softmax(logits, dim=-1) - - if not sample_logits: - _, x = top_k(probs, k=1, dim=-1) - else: - x = torch.multinomial(probs, num_samples=1) - - return x - - -def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False, news_emb=None): - with torch.no_grad(): - x = torch.clip(x, -clip, clip) - - device = x.device - x = x.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x.size(1), x.size(2)).to(device) - x_stamp = x_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x_stamp.size(1), x_stamp.size(2)).to(device) - y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), y_stamp.size(2)).to(device) - - x_token = tokenizer.encode(x, half=True) - - initial_seq_len = x.size(1) - batch_size = x_token[0].size(0) - total_seq_len = initial_seq_len + pred_len - full_stamp = torch.cat([x_stamp, y_stamp], dim=1) - - generated_pre = x_token[0].new_empty(batch_size, pred_len) - generated_post = x_token[1].new_empty(batch_size, pred_len) - - pre_buffer = 
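`top_k_top_p_filtering` and `sample_from_logits` above implement standard truncated sampling. Two observations on the transcribed code: there is an extra `return logits` after the top-k block, so the top-p branch may never run, and the greedy branch of `sample_from_logits` calls `top_k(probs, k=1, dim=-1)` even though `top_k` is an integer there (`torch.topk` is presumably intended). A self-contained sketch that applies both filters before sampling, on a toy distribution:

```python
import torch
import torch.nn.functional as F

def filter_logits(logits: torch.Tensor, top_k: int = 0, top_p: float = 1.0) -> torch.Tensor:
    """Mask logits outside the top-k set and outside the top-p (nucleus) mass."""
    logits = logits.clone()
    if top_k > 0:
        kth = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth] = -float("inf")
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()   # keep the first token above the threshold
        remove[..., 0] = False
        logits[remove.scatter(-1, sorted_idx, remove)] = -float("inf")
    return logits

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
filtered = filter_logits(logits, top_k=4, top_p=0.9)
probs = F.softmax(filtered, dim=-1)
sample = torch.multinomial(probs, num_samples=1)
print(probs, sample)
```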
x_token[0].new_zeros(batch_size, max_context) - post_buffer = x_token[1].new_zeros(batch_size, max_context) - buffer_len = min(initial_seq_len, max_context) - if buffer_len > 0: - start_idx = max(0, initial_seq_len - max_context) - pre_buffer[:, :buffer_len] = x_token[0][:, start_idx:start_idx + buffer_len] - post_buffer[:, :buffer_len] = x_token[1][:, start_idx:start_idx + buffer_len] - - if verbose: - ran = trange - else: - ran = range - for i in ran(pred_len): - current_seq_len = initial_seq_len + i - window_len = min(current_seq_len, max_context) - - if current_seq_len <= max_context: - input_tokens = [ - pre_buffer[:, :window_len], - post_buffer[:, :window_len] - ] - else: - input_tokens = [pre_buffer, post_buffer] - - context_end = current_seq_len - context_start = max(0, context_end - max_context) - current_stamp = full_stamp[:, context_start:context_end, :].contiguous() - - s1_logits, context = model.decode_s1(input_tokens[0], input_tokens[1], current_stamp, news_emb=news_emb) - s1_logits = s1_logits[:, -1, :] - sample_pre = sample_from_logits(s1_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True) - - s2_logits = model.decode_s2(context, sample_pre) - s2_logits = s2_logits[:, -1, :] - sample_post = sample_from_logits(s2_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True) - - generated_pre[:, i] = sample_pre.squeeze(-1) - generated_post[:, i] = sample_post.squeeze(-1) - - if current_seq_len < max_context: - pre_buffer[:, current_seq_len] = sample_pre.squeeze(-1) - post_buffer[:, current_seq_len] = sample_post.squeeze(-1) - else: - pre_buffer.copy_(torch.roll(pre_buffer, shifts=-1, dims=1)) - post_buffer.copy_(torch.roll(post_buffer, shifts=-1, dims=1)) - pre_buffer[:, -1] = sample_pre.squeeze(-1) - post_buffer[:, -1] = sample_post.squeeze(-1) - - full_pre = torch.cat([x_token[0], generated_pre], dim=1) - full_post = torch.cat([x_token[1], generated_post], dim=1) - - context_start = max(0, total_seq_len - max_context) - input_tokens = [ - full_pre[:, context_start:total_seq_len].contiguous(), - full_post[:, context_start:total_seq_len].contiguous() - ] - z = tokenizer.decode(input_tokens, half=True) - z = z.reshape(-1, sample_count, z.size(1), z.size(2)) - preds = z.cpu().numpy() - preds = np.mean(preds, axis=1) - - return preds - - -def calc_time_stamps(x_timestamp): - time_df = pd.DataFrame() - time_df['minute'] = x_timestamp.dt.minute - time_df['hour'] = x_timestamp.dt.hour - time_df['weekday'] = x_timestamp.dt.weekday - time_df['day'] = x_timestamp.dt.day - time_df['month'] = x_timestamp.dt.month - return time_df - - -class KronosPredictor: - - def __init__(self, model, tokenizer, device="cuda:0", max_context=512, clip=5): - self.tokenizer = tokenizer - self.model = model - self.max_context = max_context - self.clip = clip - self.price_cols = ['open', 'high', 'low', 'close'] - self.vol_col = 'volume' - self.amt_vol = 'amount' - self.time_cols = ['minute', 'hour', 'weekday', 'day', 'month'] - self.device = device - - self.tokenizer = self.tokenizer.to(self.device) - self.model = self.model.to(self.device) - - def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, news_emb=None): - - x_tensor = torch.from_numpy(np.array(x).astype(np.float32)).to(self.device) - x_stamp_tensor = torch.from_numpy(np.array(x_stamp).astype(np.float32)).to(self.device) - y_stamp_tensor = torch.from_numpy(np.array(y_stamp).astype(np.float32)).to(self.device) - - preds = auto_regressive_inference(self.tokenizer, self.model, x_tensor, 
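The token buffers in `auto_regressive_inference` above keep at most `max_context` steps: new samples are written in place while the window is filling, after which `torch.roll` shifts everything left by one and the newest token overwrites the last slot. The same sliding-window bookkeeping in isolation:

```python
import torch

max_context = 4
buffer = torch.zeros(1, max_context, dtype=torch.long)
seq_len = 0

for step, token in enumerate([11, 12, 13, 14, 15, 16]):
    if seq_len < max_context:
        buffer[:, seq_len] = token          # still filling the window
    else:
        buffer.copy_(torch.roll(buffer, shifts=-1, dims=1))
        buffer[:, -1] = token               # newest token replaces the oldest
    seq_len += 1
    print(step, buffer.tolist())

# final window holds the last max_context tokens: [[13, 14, 15, 16]]
```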
x_stamp_tensor, y_stamp_tensor, self.max_context, pred_len, - self.clip, T, top_k, top_p, sample_count, verbose, news_emb=news_emb) - preds = preds[:, -pred_len:, :] - return preds - - def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True, news_emb=None): - - if not isinstance(df, pd.DataFrame): - raise ValueError("Input must be a pandas DataFrame.") - - if not all(col in df.columns for col in self.price_cols): - raise ValueError(f"Price columns {self.price_cols} not found in DataFrame.") - - df = df.copy() - if self.vol_col not in df.columns: - df[self.vol_col] = 0.0 # Fill missing volume with zeros - df[self.amt_vol] = 0.0 # Fill missing amount with zeros - if self.amt_vol not in df.columns and self.vol_col in df.columns: - df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1) - - if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any(): - raise ValueError("Input DataFrame contains NaN values in price or volume columns.") - - x_time_df = calc_time_stamps(x_timestamp) - y_time_df = calc_time_stamps(y_timestamp) - - x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(np.float32) - x_stamp = x_time_df.values.astype(np.float32) - y_stamp = y_time_df.values.astype(np.float32) - - x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0) - - x = (x - x_mean) / (x_std + 1e-5) - x = np.clip(x, -self.clip, self.clip) - - x = x[np.newaxis, :] - x_stamp = x_stamp[np.newaxis, :] - y_stamp = y_stamp[np.newaxis, :] - - if news_emb is not None: - news_emb_tensor = torch.from_numpy(np.array(news_emb).astype(np.float32)).to(self.device) - # Ensure batch dimension for news_emb if only one sample - if news_emb_tensor.ndim == 1: - news_emb_tensor = news_emb_tensor.unsqueeze(0) - else: - news_emb_tensor = None - - preds = self.generate(x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, news_emb=news_emb_tensor) - - preds = preds.squeeze(0) - preds = preds * (x_std + 1e-5) + x_mean - - pred_df = pd.DataFrame(preds, columns=self.price_cols + [self.vol_col, self.amt_vol], index=y_timestamp) - return pred_df - - - def predict_batch(self, df_list, x_timestamp_list, y_timestamp_list, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True): - """ - Perform parallel (batch) prediction on multiple time series. All series must have the same historical length and prediction length (pred_len). - - Args: - df_list (List[pd.DataFrame]): List of input DataFrames, each containing price columns and optional volume/amount columns. - x_timestamp_list (List[pd.DatetimeIndex or Series]): List of timestamps corresponding to historical data, length should match the number of rows in each DataFrame. - y_timestamp_list (List[pd.DatetimeIndex or Series]): List of future prediction timestamps, length should equal pred_len. - pred_len (int): Number of prediction steps. - T (float): Sampling temperature. - top_k (int): Top-k filtering threshold. - top_p (float): Top-p (nucleus sampling) threshold. - sample_count (int): Number of parallel samples per series, automatically averaged internally. - verbose (bool): Whether to display autoregressive progress. - - Returns: - List[pd.DataFrame]: List of prediction results in the same order as input, each DataFrame contains - `open, high, low, close, volume, amount` columns, indexed by corresponding `y_timestamp`. 
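`predict` above z-scores each feature column with the history's mean and std, clips to ±`clip`, and maps the sampled outputs back with the same statistics. A tiny numeric illustration of that round trip, with made-up values:

```python
import numpy as np

clip = 5
x = np.array([[100.0, 1_000.0],
              [102.0, 1_200.0],
              [101.0,   900.0]])            # toy (close, volume) history

x_mean, x_std = x.mean(axis=0), x.std(axis=0)
x_norm = np.clip((x - x_mean) / (x_std + 1e-5), -clip, clip)    # model input space

preds_norm = np.array([[0.3, -0.2]])        # pretend model output (normalised)
preds = preds_norm * (x_std + 1e-5) + x_mean                    # back to price/volume units
print(x_norm.round(3), preds.round(2))
```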
- """ - # Basic validation - if not isinstance(df_list, (list, tuple)) or not isinstance(x_timestamp_list, (list, tuple)) or not isinstance(y_timestamp_list, (list, tuple)): - raise ValueError("df_list, x_timestamp_list, y_timestamp_list must be list or tuple types.") - if not (len(df_list) == len(x_timestamp_list) == len(y_timestamp_list)): - raise ValueError("df_list, x_timestamp_list, y_timestamp_list must have consistent lengths.") - - num_series = len(df_list) - - x_list = [] - x_stamp_list = [] - y_stamp_list = [] - means = [] - stds = [] - seq_lens = [] - y_lens = [] - - for i in range(num_series): - df = df_list[i] - if not isinstance(df, pd.DataFrame): - raise ValueError(f"Input at index {i} is not a pandas DataFrame.") - if not all(col in df.columns for col in self.price_cols): - raise ValueError(f"DataFrame at index {i} is missing price columns {self.price_cols}.") - - df = df.copy() - if self.vol_col not in df.columns: - df[self.vol_col] = 0.0 - df[self.amt_vol] = 0.0 - if self.amt_vol not in df.columns and self.vol_col in df.columns: - df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1) - - if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any(): - raise ValueError(f"DataFrame at index {i} contains NaN values in price or volume columns.") - - x_timestamp = x_timestamp_list[i] - y_timestamp = y_timestamp_list[i] - - x_time_df = calc_time_stamps(x_timestamp) - y_time_df = calc_time_stamps(y_timestamp) - - x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(np.float32) - x_stamp = x_time_df.values.astype(np.float32) - y_stamp = y_time_df.values.astype(np.float32) - - if x.shape[0] != x_stamp.shape[0]: - raise ValueError(f"Inconsistent lengths at index {i}: x has {x.shape[0]} vs x_stamp has {x_stamp.shape[0]}.") - if y_stamp.shape[0] != pred_len: - raise ValueError(f"y_timestamp length at index {i} should equal pred_len={pred_len}, got {y_stamp.shape[0]}.") - - x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0) - x_norm = (x - x_mean) / (x_std + 1e-5) - x_norm = np.clip(x_norm, -self.clip, self.clip) - - x_list.append(x_norm) - x_stamp_list.append(x_stamp) - y_stamp_list.append(y_stamp) - means.append(x_mean) - stds.append(x_std) - - seq_lens.append(x_norm.shape[0]) - y_lens.append(y_stamp.shape[0]) - - # Require all series to have consistent historical and prediction lengths for batch processing - if len(set(seq_lens)) != 1: - raise ValueError(f"Parallel prediction requires all series to have consistent historical lengths, got: {seq_lens}") - if len(set(y_lens)) != 1: - raise ValueError(f"Parallel prediction requires all series to have consistent prediction lengths, got: {y_lens}") - - x_batch = np.stack(x_list, axis=0).astype(np.float32) # (B, seq_len, feat) - x_stamp_batch = np.stack(x_stamp_list, axis=0).astype(np.float32) # (B, seq_len, time_feat) - y_stamp_batch = np.stack(y_stamp_list, axis=0).astype(np.float32) # (B, pred_len, time_feat) - - preds = self.generate(x_batch, x_stamp_batch, y_stamp_batch, pred_len, T, top_k, top_p, sample_count, verbose) - # preds: (B, pred_len, feat) - - pred_dfs = [] - for i in range(num_series): - preds_i = preds[i] * (stds[i] + 1e-5) + means[i] - pred_df = pd.DataFrame(preds_i, columns=self.price_cols + [self.vol_col, self.amt_vol], index=y_timestamp_list[i]) - pred_dfs.append(pred_df) - - return pred_dfs diff --git a/skills/alphaear-predictor/scripts/predictor/model/module.py b/skills/alphaear-predictor/scripts/predictor/model/module.py deleted file mode 100644 index 
20b29b5..0000000 --- a/skills/alphaear-predictor/scripts/predictor/model/module.py +++ /dev/null @@ -1,562 +0,0 @@ -import math - -from einops import rearrange, reduce -import torch -import torch.nn as nn -from torch.autograd import Function -import torch.nn.functional as F - - -class DifferentiableEntropyFunction(Function): - @staticmethod - def forward(ctx, zq, basis, K, eps): - zb = (zq + 1) / 2 - zi = ((zb * basis).sum(-1)).to(torch.int64) - cnt = torch.scatter_reduce(torch.zeros(2 ** K, device=zq.device, dtype=zq.dtype), - 0, - zi.flatten(), - torch.ones_like(zi.flatten()).to(zq.dtype), - 'sum') - prob = (cnt + eps) / (cnt + eps).sum() - H = -(prob * torch.log(prob)).sum() - ctx.save_for_backward(zq, zi, prob) - ctx.K = K - return H - - @staticmethod - def backward(ctx, grad_output): - zq, zi, prob = ctx.saved_tensors - grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K - reord_grad = grad_array[zi.flatten()].reshape(zi.shape) - grad_input = reord_grad.unsqueeze(-1) * zq - return grad_input, None, None, None, None - - -def codebook_entropy(zq, basis, K, eps=1e-4): - return DifferentiableEntropyFunction.apply(zq, basis, K, eps) - - -class BinarySphericalQuantizer(nn.Module): - def __init__(self, embed_dim, beta, gamma0, gamma, zeta, - input_format='bchw', - soft_entropy=True, group_size=9, - persample_entropy_compute='analytical', - cb_entropy_compute='group', - l2_norm=True, - inv_temperature=1): - """ - Paper link: https://arxiv.org/pdf/2406.07548.pdf - Here we use the official implementation of the BinarySphericalQuantizer. - """ - super().__init__() - self.embed_dim = embed_dim - self.beta = beta # loss weight for commit loss - self.gamma0 = gamma0 # loss weight for entropy penalty - self.gamma = gamma # loss weight for entropy penalty - self.zeta = zeta # loss weight for entire entropy penalty - self.input_format = input_format - assert self.embed_dim % group_size == 0, "embed_dim must be divisible by group_size" - self.num_groups = self.embed_dim // group_size - self.group_size = group_size - assert persample_entropy_compute in ['group', 'analytical'], "persample_entropy_compute must be either 'group' or 'analytical'" - assert cb_entropy_compute in ['group', 'nce'], "cb_entropy_compute must be either 'group' or 'nce'" - self.persample_entropy_compute = persample_entropy_compute - self.cb_entropy_compute = cb_entropy_compute - self.l2_norm = l2_norm - self.inv_temperature = inv_temperature - - self.register_buffer('basis', 2 ** torch.arange(embed_dim - 1, -1, -1)) - self.register_buffer('group_basis', 2 ** torch.arange(group_size - 1, -1, -1)) - - self.num_dimensions = 2 ** embed_dim - self.bits_per_index = embed_dim - - # we only need to keep the codebook portion up to the group size - # because we approximate the H loss with this subcode - group_codes = torch.arange(2 ** self.group_size) - group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:] - self.register_buffer('group_codebook', group_codebook, persistent=False) - - self.soft_entropy = soft_entropy # soft_entropy: Sec 3.2 of https://arxiv.org/pdf/1911.05894.pdf - - def quantize(self, z): - assert z.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}" - - zhat = torch.where(z > 0, - torch.tensor(1, dtype=z.dtype, device=z.device), - torch.tensor(-1, dtype=z.dtype, device=z.device)) - return z + (zhat - z).detach() - - def forward(self, z, collect_metrics=True): - # if self.input_format == 'bchw': - # z = rearrange(z, 'b c h w -> b h w c') - zq = 
self.quantize(z) - - q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. - - zq = zq * q_scale - - if not collect_metrics: - return zq, zq.new_zeros(()), {} - - indices = self.codes_to_indexes(zq.detach()) - group_indices = self.codes_to_group_indexes(zq.detach()) - if not self.training: - used_codes = torch.unique(indices, return_counts=False) - else: - used_codes = None - - if self.soft_entropy: - persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z) - entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy - else: - zb_by_sample = ((zq + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32) - persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample) - cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim) - entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy - - # commit loss - commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1)) - - # if self.input_format == 'bchw': - # zq = rearrange(zq, 'b h w c -> b c h w') - - return ( - zq, - commit_loss + self.zeta * entropy_penalty / self.inv_temperature, - {"H": cb_entropy, "used_codes": used_codes, "indices": indices, "group_indices": group_indices, - "avg_prob": avg_prob} - ) - - def soft_entropy_loss(self, z): - # if we divide the code in subgroups of size group_size, the codebook will be of size 2 ** group_size - # the sub-code is the last group_size bits of the full code - group_code_book = self.group_codebook / (self.embed_dim ** 0.5 if self.l2_norm else 1) - divided_z = rearrange(z, '... (g c) -> ... g c', c=self.group_size) - - # we calculate the distance between the divided_z and the codebook for each subgroup - distance = - 2 * torch.einsum('... g c, d c ->... g d', divided_z, group_code_book) - prob = (-distance * self.inv_temperature).softmax(dim=-1) - if self.persample_entropy_compute == 'analytical': - if self.l2_norm: - p = torch.sigmoid(-4 * z / (self.embed_dim ** 0.5) * self.inv_temperature) - else: - p = torch.sigmoid(-4 * z * self.inv_temperature) - prob = torch.stack([p, 1 - p], dim=-1) - per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() - else: - per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() - - # macro average of the probability of each subgroup - avg_prob = reduce(prob, '... g d ->g d', 'mean') - codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False) - - # the approximation of the entropy is the sum of the entropy of each subgroup - return per_sample_entropy, codebook_entropy.sum(), avg_prob - - def get_hard_per_sample_entropy(self, zb_by_sample): - probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1] - persample_entropy = - probs_per_dim * torch.log(probs_per_dim + 1e-8) - (1 - probs_per_dim) * torch.log(1 - probs_per_dim + 1e-8) - persample_entropy = persample_entropy.sum(-1) - return persample_entropy.mean() - - def codes_to_indexes(self, zhat): - """Converts a `code` to an index in the codebook. - Args: - zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1} - """ - assert zhat.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}" - return ((zhat + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64) - - def codes_to_group_indexes(self, zhat): - """Converts a `code` to a list of indexes (in groups) in the codebook. - Args: - zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1} - """ - zhat_in_group = rearrange(zhat, 'b ... 
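`BinarySphericalQuantizer.quantize` above snaps each latent dimension to ±1 via its sign, then uses the straight-through trick `z + (zhat - z).detach()` so gradients reach the encoder as if no rounding had happened. A minimal demonstration of that estimator, including the `1/sqrt(embed_dim)` scaling used when `l2_norm` is enabled:

```python
import torch

z = torch.tensor([[0.7, -0.2, 0.05]], requires_grad=True)

zhat = torch.where(z > 0, torch.ones_like(z), -torch.ones_like(z))
zq = z + (zhat - z).detach()        # forward: exactly ±1; backward: identity w.r.t. z
zq = zq / zq.shape[-1] ** 0.5       # 1/sqrt(embed_dim) scaling, as in the quantizer

zq.sum().backward()
print(zq)        # values are ±1/sqrt(3) ≈ ±0.5774
print(z.grad)    # 1/sqrt(3) everywhere: the gradient passes straight through the sign
```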
(g c) -> b ... g c', c=self.group_size) - return ((zhat_in_group + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64) - - def indexes_to_codes(self, indices): - """Inverse of `indexes_to_codes`.""" - indices = indices.unsqueeze(-1) - codes_non_centered = torch.remainder( - torch.floor_divide(indices, self.basis), 2 - ) - return codes_non_centered * 2 - 1 - - def group_indexes_to_codes(self, group_indices): - """Inverse of `group_indexes_to_codes`.""" - group_indices = group_indices.unsqueeze(-1) - codes_non_centered = torch.remainder( - torch.floor_divide(group_indices, self.group_basis), 2 - ) - codes_non_centered = rearrange(codes_non_centered, 'b ... g c -> b ... (g c)') - return codes_non_centered * 2 - 1 - - def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True): - if normalize: - probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True) - else: - probs = count - H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim) - return H - - def get_group_codebook_entry(self, group_indices): - z_q = self.group_indexes_to_codes(group_indices) - q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. - z_q = z_q * q_scale - if self.input_format == 'bchw': - h, w = int(z_q.shape[1] ** 0.5) - assert h * w == z_q.shape[1], 'Invalid sequence length' - z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h) - return z_q - - def get_codebook_entry(self, indices): - z_q = self.indexes_to_codes(indices) - q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. - z_q = z_q * q_scale - if self.input_format == 'bchw': - h, w = int(z_q.shape[1] ** 0.5) - assert h * w == z_q.shape[1], 'Invalid sequence length' - z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h) - return z_q - - -class BSQuantizer(nn.Module): - - def __init__(self, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size): - super().__init__() - self.codebook_dim = s1_bits + s2_bits - self.s1_bits = s1_bits - self.s2_bits = s2_bits - self.bsq = BinarySphericalQuantizer(self.codebook_dim, beta, gamma0, gamma, zeta, group_size=group_size) - - def bits_to_indices(self, bits): - bits = (bits >= 0).to(torch.long) - indices = 2 ** torch.arange( - 0, - bits.shape[-1], - 1, - dtype=torch.long, - device=bits.device, - ) - return (bits * indices).sum(-1) - - def forward(self, z, half=False, collect_metrics=True): - z = F.normalize(z, dim=-1) - quantized, bsq_loss, metrics = self.bsq(z, collect_metrics=collect_metrics) - if half: - q_pre = quantized[:, :, :self.s1_bits] - q_post = quantized[:, :, self.s1_bits:] - z_indices = [self.bits_to_indices(q_pre), self.bits_to_indices(q_post)] - else: - z_indices = self.bits_to_indices(quantized) - return bsq_loss, quantized, z_indices - - -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-5): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()).type_as(x) - return output * self.weight - - -class FeedForward(nn.Module): - def __init__(self, d_model, ff_dim, ffn_dropout_p=0.0): - super().__init__() - - self.w1 = nn.Linear(d_model, ff_dim, bias=False) - self.w3 = nn.Linear(d_model, ff_dim, bias=False) - self.w2 = nn.Linear(ff_dim, d_model, bias=False) - self.ffn_dropout = nn.Dropout(ffn_dropout_p) - - def forward(self, x): - return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) - - -class RotaryPositionalEmbedding(nn.Module): - def __init__(self, dim): - 
super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer("inv_freq", inv_freq) - self.seq_len_cached = None - self.cos_cached = None - self.sin_cached = None - - def _update_cos_sin_cache(self, x, seq_len): - if seq_len != self.seq_len_cached: - self.seq_len_cached = seq_len - t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) - freqs = torch.einsum('i,j->ij', t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - self.cos_cached = emb.cos()[None, None, :, :] - self.sin_cached = emb.sin()[None, None, :, :] - return self.cos_cached, self.sin_cached - - def forward(self, q, k): - cos, sin = self._update_cos_sin_cache(q, q.shape[-2]) - return ( - (q * cos) + (self._rotate_half(q) * sin), - (k * cos) + (self._rotate_half(k) * sin), - ) - - def _rotate_half(self, x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - -class MultiHeadAttentionWithRoPE(nn.Module): - def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout_p=0.0): - super().__init__() - self.d_model = d_model - self.n_heads = n_heads - self.head_dim = d_model // n_heads - - self.q_proj = nn.Linear(d_model, d_model) - self.k_proj = nn.Linear(d_model, d_model) - self.v_proj = nn.Linear(d_model, d_model) - self.out_proj = nn.Linear(d_model, d_model) - self.rotary = RotaryPositionalEmbedding(self.head_dim) - self.attn_dropout_p = attn_dropout_p - self.resid_dropout = nn.Dropout(resid_dropout_p) - - def forward(self, x, key_padding_mask=None): - batch_size, seq_len, _ = x.shape - - q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - - q, k = self.rotary(q, k) - - if key_padding_mask is not None: - attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) # [batch, 1, 1, seq_len] - attn_mask = attn_mask.expand(-1, self.n_heads, seq_len, -1) # [batch, n_heads, q_len, k_len] - else: - attn_mask = None - - attn_output = F.scaled_dot_product_attention( - q, k, v, - attn_mask=attn_mask, - dropout_p=self.attn_dropout_p if self.training else 0.0, - is_causal=True - ) - - attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model) - return self.resid_dropout(self.out_proj(attn_output)) - - -class MultiHeadCrossAttentionWithRoPE(nn.Module): - def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout=0.0): - super().__init__() - self.d_model = d_model - self.n_heads = n_heads - self.head_dim = d_model // n_heads - - self.q_proj = nn.Linear(d_model, d_model) - self.k_proj = nn.Linear(d_model, d_model) - self.v_proj = nn.Linear(d_model, d_model) - self.out_proj = nn.Linear(d_model, d_model) - self.rotary = RotaryPositionalEmbedding(self.head_dim) - self.attn_dropout_p = attn_dropout_p - self.resid_dropout = nn.Dropout(resid_dropout) - - def forward(self, query, key, value, key_padding_mask=None): - batch_size, q_len, _ = query.shape - _, seq_len, _ = key.shape - - q = self.q_proj(query).view(batch_size, q_len, self.n_heads, self.head_dim).transpose(1, 2) - k = self.k_proj(key).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - v = self.v_proj(value).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - - q, k = self.rotary(q, k) - - if key_padding_mask is not None: - attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) 
- attn_mask = attn_mask.expand(-1, self.n_heads, q_len, -1) - else: - attn_mask = None - - is_causal_flag = self.training - - attn_output = F.scaled_dot_product_attention( - q, k, v, - attn_mask=attn_mask, - dropout_p=self.attn_dropout_p if self.training else 0.0, - is_causal=is_causal_flag - ) - - attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model) - return self.resid_dropout(self.out_proj(attn_output)) - - -class HierarchicalEmbedding(nn.Module): - def __init__(self, s1_bits, s2_bits, d_model=256): - super().__init__() - self.s1_bits = s1_bits - self.s2_bits = s2_bits - - vocab_s1 = 2 ** s1_bits - vocab_s2 = 2 ** s2_bits - - self.emb_s1 = nn.Embedding(vocab_s1, d_model) - self.emb_s2 = nn.Embedding(vocab_s2, d_model) - self.d_model = d_model - self.fusion_proj = nn.Linear(d_model * 2, d_model) - - nn.init.normal_(self.emb_s1.weight, mean=0, std=d_model ** -0.5) - nn.init.normal_(self.emb_s2.weight, mean=0, std=d_model ** -0.5) - - def split_token(self, token_ids: torch.Tensor, s2_bits: int): - """Inputs: - token_ids (torch.Tensor): Composite token IDs of shape [batch_size, seq_len] or [N], each in range [0, 2^(s1_bits + s2_bits) - 1]. - s2_bits (int): Number of low bits used for the fine token (s2). - """ - assert isinstance(s2_bits, int) and s2_bits > 0, "s2_bits must be a positive integer" - - t = token_ids.long() - mask = (1 << s2_bits) - 1 - s2_ids = t & mask # extract low bits - s1_ids = t >> s2_bits # extract high bits - return s1_ids, s2_ids - - def forward(self, token_ids): - """Inputs: - token_ids: - - tuple or list: (s1_ids, s2_ids), each of shape [batch_size, seq_len], or - - torch.Tensor: composite token IDs of shape [batch_size, seq_len], which will be split into (s1_ids, s2_ids) internally. - Output: [batch_size, seq_len, d_model] - """ - if isinstance(token_ids, tuple) or isinstance(token_ids, list): - s1_ids, s2_ids = token_ids - else: - s1_ids, s2_ids = self.split_token(token_ids, self.s2_bits) - s1_emb = self.emb_s1(s1_ids) * math.sqrt(self.d_model) - s2_emb = self.emb_s2(s2_ids) * math.sqrt(self.d_model) - return self.fusion_proj(torch.cat([s1_emb, s2_emb], dim=-1)) - - -class DependencyAwareLayer(nn.Module): - def __init__(self, d_model, n_heads=4, attn_dropout_p=0.0, resid_dropout=0.0): - super().__init__() - self.cross_attn = MultiHeadCrossAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout) - self.norm = RMSNorm(d_model) - - def forward(self, hidden_states, sibling_embed, key_padding_mask=None): - """hidden_states: [batch, seq_len, d_model] - sibling_embed: Embedding from another subtoken - """ - attn_out = self.cross_attn( - query=sibling_embed, - key=hidden_states, - value=hidden_states, - key_padding_mask=key_padding_mask - ) - return self.norm(hidden_states + attn_out) - - -class TransformerBlock(nn.Module): - def __init__(self, d_model, n_heads, ff_dim=1024, ffn_dropout_p=0.0, attn_dropout_p=0.0, resid_dropout_p=0.0): - super().__init__() - self.norm1 = RMSNorm(d_model) - self.self_attn = MultiHeadAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout_p) - self.norm2 = RMSNorm(d_model) - self.ffn = FeedForward(d_model, ff_dim, ffn_dropout_p) - - def forward(self, x, key_padding_mask=None): - residual = x - x = self.norm1(x) - attn_out = self.self_attn(x, key_padding_mask=key_padding_mask) - x = residual + attn_out - - residual = x - x = self.norm2(x) - ffn_out = self.ffn(x) - x = residual + ffn_out - return x - - -class DualHead(nn.Module): - def __init__(self, s1_bits, s2_bits, d_model): - 
super().__init__() - self.vocab_s1 = 2 ** s1_bits - self.vocab_s2 = 2 ** s2_bits - self.proj_s1 = nn.Linear(d_model, self.vocab_s1) - self.proj_s2 = nn.Linear(d_model, self.vocab_s2) - - def compute_loss(self, s1_logits, s2_logits, s1_targets, s2_targets, padding_mask=None): - if padding_mask is not None: - valid_mask = (padding_mask == 0) - s1_logits = s1_logits[valid_mask] - s2_logits = s2_logits[valid_mask] - s1_targets = s1_targets[valid_mask] - s2_targets = s2_targets[valid_mask] - ce_s1 = F.cross_entropy(s1_logits, s1_targets) - ce_s2 = F.cross_entropy(s2_logits, s2_targets) - else: - ce_s1 = F.cross_entropy(s1_logits.reshape(-1, self.vocab_s1), s1_targets.reshape(-1)) - ce_s2 = F.cross_entropy(s2_logits.reshape(-1, self.vocab_s2), s2_targets.reshape(-1)) - ce_loss = (ce_s1 + ce_s2) / 2 - return ce_loss, ce_s1, ce_s2 - - def forward(self, x): - return self.proj_s1(x) - - def cond_forward(self, x2): - return self.proj_s2(x2) - - -class FixedEmbedding(nn.Module): - def __init__(self, c_in, d_model): - super(FixedEmbedding, self).__init__() - - w = torch.zeros(c_in, d_model).float() - w.require_grad = False - - position = torch.arange(0, c_in).float().unsqueeze(1) - div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() - - w[:, 0::2] = torch.sin(position * div_term) - w[:, 1::2] = torch.cos(position * div_term) - - self.emb = nn.Embedding(c_in, d_model) - self.emb.weight = nn.Parameter(w, requires_grad=False) - - def forward(self, x): - return self.emb(x).detach() - - -class TemporalEmbedding(nn.Module): - def __init__(self, d_model, learn_pe): - super(TemporalEmbedding, self).__init__() - - minute_size = 60 - hour_size = 24 - weekday_size = 7 - day_size = 32 - month_size = 13 - - Embed = FixedEmbedding if not learn_pe else nn.Embedding - self.minute_embed = Embed(minute_size, d_model) - self.hour_embed = Embed(hour_size, d_model) - self.weekday_embed = Embed(weekday_size, d_model) - self.day_embed = Embed(day_size, d_model) - self.month_embed = Embed(month_size, d_model) - - def forward(self, x): - x = x.long() - - minute_x = self.minute_embed(x[:, :, 0]) - hour_x = self.hour_embed(x[:, :, 1]) - weekday_x = self.weekday_embed(x[:, :, 2]) - day_x = self.day_embed(x[:, :, 3]) - month_x = self.month_embed(x[:, :, 4]) - - return hour_x + weekday_x + day_x + month_x + minute_x \ No newline at end of file diff --git a/skills/alphaear-predictor/scripts/prompts/fin_agent.py b/skills/alphaear-predictor/scripts/prompts/fin_agent.py deleted file mode 100644 index 83386af..0000000 --- a/skills/alphaear-predictor/scripts/prompts/fin_agent.py +++ /dev/null @@ -1,127 +0,0 @@ -from datetime import datetime -from .isq_prompt_generator import generate_isq_prompt_section - -def get_fin_researcher_instructions() -> str: - """生成金融研究员 (Researcher) 的系统指令""" - current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') - return f"""你是一名资深金融研究员,当前时间是 {current_time}。 -你的任务是针对给定的“原始信号”进行详尽的背景调查,为后续的深度分析提供素材。 - -### 1. 核心职责 -1. **标的识别**: 识别信号中涉及的具体上市公司。必须调用 `search_ticker` 确认代码,并调用 `get_stock_price` 获取最新价格和近 30 天走势。 -2. **事实核查**: 使用 `web_search` 或 `fetch_news_content` 验证信号的真实性,并寻找更多细节(如公告原文、行业研报摘要)。 -3. **产业链梳理**: 补充该信号涉及的上下游环节及竞争格局。 - -### 2. 工具使用规范 (CRITICAL) -- **每个提到的公司都需要调用工具**: 不能依赖记忆,必须实时查询。 -- **完整呈现工具结果**: 包括具体的股价数字、代码、技术面数据等,不要缩略。 -- **股价数据必需**: 当前价格、近期最高最低、技术面支撑阻力等数据是后续预测的基础。 -- **信息交叉验证**: 多个来源验证关键事实。 - -### 3. 
输出要求 -你必须输出结构化的研究报告,涵盖标的基本面、股价走势、行业背景及最新进展。 -""" - -def get_fin_analyst_instructions(template_id: str = "default_isq_v1") -> str: - """生成金融分析师 (Analyst) 的系统指令 - - Args: - template_id: 使用的 ISQ 模板 ID - """ - current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') - isq_block = generate_isq_prompt_section(template_id=template_id) - - return f"""你是一位深耕二级市场的资深金融分析师 (FinAgent),当前时间是 {current_time}。 -你的核心任务是执行“信号解析”,将研究员搜集的素材转化为具有可操作性的投资情报(ISQ 框架)。 - -{isq_block} - -### 2. 分析约束 -- **严格基于具体数据**: 必须使用研究员提供的股价、技术面、新闻等具体数据进行分析。 -- **数据驱动的预测**: impact_tickers 中的权重应基于事件影响程度,不能随意赋值。 -- **逻辑严密**: 传导链条必须符合金融常识,能够自圆其说。 -- **技术面参考**: 如果研究员提供了股价走势,请分析当前位置相对于支撑/阻力位的关系。 - -### 3. 关键要求 -- **title**: 必须生成一个简练、准确概括信号核心内容的标题(不超过 15 字)。 -- **impact_tickers**: 必须填充具体的公司代码(6位数字)和名称,权重应该有区分。 -- **transmission_chain**: 必须是对象列表,每个对象包含: - - `node_name`: 节点名称(如“上游原材料”、“中游制造”) - - `impact_type`: 影响类型(“利好”、“利空”、“中性”) - - `logic`: 具体的传导逻辑描述 -- **summary**: 基于分析结果总结核心观点,包含具体数字(如股价目标、预期涨跌幅等)。 -- **reasoning**: 必须详细阐述推演逻辑,解释为什么得出上述结论(<200字)。 - -### 4. 输出格式 (严格 JSON 块) -你必须输出一个符合 InvestmentSignal 结构的 JSON 块,包含所有必需字段。 -""" - -def get_fin_agent_instructions() -> str: - # 保持兼容性,但内部调用 analyst 指令 - return get_fin_analyst_instructions() - -def get_fin_research_task(signal_text: str) -> str: - """生成研究员的任务描述""" - return f"请针对以下信号进行背景调查,搜集相关标的的股价、最新进展和行业背景:\n\n{signal_text}" - -def format_research_context(research_data: dict) -> str: - """将研究员搜集的结构化数据格式化为分析师可读的文本""" - if not research_data: - return "(未能搜集到额外背景信息)" - - return f""" -### 研究背景 -- **相关标的**: {research_data.get('tickers_found', [])} -- **行业背景**: {research_data.get('industry_background', '未知')} -- **最新进展**: {', '.join(research_data.get('latest_developments', []))} -- **关键风险**: {', '.join(research_data.get('key_risks', []))} -- **综合摘要**: {research_data.get('search_results_summary', '无')} -""" - -def get_fin_analysis_task(signal_text: str, research_context_str: str) -> str: - """生成分析师的任务描述""" - return f"""请基于以下信息进行深度 ISQ 分析。关键是:必须使用研究员搜集的具体数据(股价、技术面、新闻、代码等)进行分析。 - -=== 原始信号 === -{signal_text} - -=== 研究员搜集的背景信息 (CRITICAL DATA) === -{research_context_str} - -=== 分析要求 === -1. 必须生成 title:简练概括信号核心(<15字) -2. 基于研究员提供的具体股价数据,分析当前定价状态(已定价/未定价/部分定价) -3. impact_tickers 中填充具体的公司代码和权重,权重基于事件影响程度 -4. transmission_chain 必须是包含 node_name, impact_type, logic 的对象列表 -5. summary 中包含具体数字(预期目标价、涨跌幅范围等) -6. reasoning 必须详细解释推演逻辑,不要空泛,要言之有物 - -请严格按 InvestmentSignal JSON 格式输出。""" - -def get_tracking_analysis_task(old_signal: dict, new_research_str: str) -> str: - """生成信号追踪更新的任务描述""" - import json - old_sig_str = json.dumps(old_signal, ensure_ascii=False, indent=2) - return f"""你正在执行“信号逻辑演变追踪”任务。请基于最新的市场信息,重新评估之前的投资信号。 - -=== 基准信号 (上次分析) === -{old_sig_str} - -=== 最新市场追踪 (NEWS & PRICE) === -{new_research_str} - -=== 追踪分析要求 === -1. **逻辑演变检测**: - - 对比新旧信息,判断原逻辑 (`transmission_chain` 和 `reasoning`) 是否依然成立? - - 如果逻辑发生变化(如利好落空、逻辑证伪、新利好出现),请在新的 `reasoning` 中明确指出“逻辑演变:...” - - 如果逻辑未变且得到验证,请标记“逻辑维持:...” - -2. **参数修正**: - - 根据最新股价和新闻,更新 `sentiment_score` (情绪)、`confidence` (置信度) 和 `expectation_gap` (预期差)。 - - 例如:如果股价已经大涨反映了利好,`expectation_gap` 应该显著降低。 - -3. 
**输出更新后的信号**: - - 保留原 `signal_id` 和 `title`(除非有重大变化需要改名)。 - - 输出完整的 InvestmentSignal JSON。 - -请重点关注:为什么变了?还是为什么没变?理由要充分。""" diff --git a/skills/alphaear-predictor/scripts/prompts/forecast_analyst.py b/skills/alphaear-predictor/scripts/prompts/forecast_analyst.py deleted file mode 100644 index d6c7202..0000000 --- a/skills/alphaear-predictor/scripts/prompts/forecast_analyst.py +++ /dev/null @@ -1,49 +0,0 @@ -from typing import List, Dict, Any -from ..schema.models import KLinePoint - -def get_forecast_adjustment_instructions(ticker: str, news_context: str, model_forecast: List[KLinePoint]): - """ - 生成 LLM 预测调整指令 - """ - forecast_str = "\n".join([f"- {p.date}: O:{p.open}, C:{p.close}" for p in model_forecast]) - - return f"""你是一位资深的量化策略分析师。 -你的任务是:根据给定的【Kronos 模型预测结果】和【最新的基本面/新闻背景】,对模型预测进行“主观/逻辑调整”。 - -股票代码: {ticker} - -【Kronos 模型原始预测 (OHLC)】: -{forecast_str} - -【最新情报背景】: -{news_context} - -调整原则: -1. 原始预测是基于历史的技术面推演。 -2. 情报背景中可能包含【Kronos模型定量修正预测】,这是基于历史新闻训练的专用模型计算出的量化结果。 -3. 如果存在“定量修正预测”,请**高度参考**该数值作为基础,除非你有非常确凿的逻辑认为该量化模型失效(例如遇到模型未见过的极端黑天鹅)。 -4. 你的核心任务是:结合定性分析(新闻及其逻辑)来验证或微调这些数字,并给出合理的解释(Rationale)。 -5. 如果没有“定量修正预测”,则你需要根据新闻信号手动大幅调整趋势。 - -输出要求 (严格 JSON 格式): -```json -{{ - "adjusted_forecast": [ - {{ - "date": "YYYY-MM-DD", - "open": float, - "high": float, - "low": float, - "close": float, - "volume": float - }}, - ... - ], - "rationale": "详细说明调整的逻辑依据,例如:考虑到[事件A],预期短线将突破压力位..." -}} -``` -注意:必须输出与原始预测相同数量的数据点,且日期一一对应。 -""" - -def get_forecast_task(): - return "请根据以上背景和模型预测,给出调整后的 K 线数据并说明理由。" diff --git a/skills/alphaear-predictor/scripts/prompts/intent_agent.py b/skills/alphaear-predictor/scripts/prompts/intent_agent.py deleted file mode 100644 index a8397d2..0000000 --- a/skills/alphaear-predictor/scripts/prompts/intent_agent.py +++ /dev/null @@ -1,45 +0,0 @@ -def get_intent_analysis_instructions() -> str: - """生成意图分析 Agent 的系统指令,专注于金融市场影响分析""" - return """你是一个资深的金融市场意图分析专家。你的任务是将用户的自然语言查询转化为结构化的 JSON 分析结果,重点挖掘该查询与金融市场(尤其是股市)的潜在关联。 - -### 核心任务: -深入分析用户查询,识别核心金融实体、行业板块及潜在的市场影响点,生成利于搜索引擎抓取深度金融分析信息的查询词。 - -### 输出格式(严格 JSON): -```json -{ - "keywords": ["实体/行业/事件"], - "search_queries": ["针对市场影响的搜索词1", "针对行业变动的搜索词2"], - "affected_sectors": ["相关板块1", "相关板块2"], - "is_market_moving": true/false, - "time_range": "recent/all/specific_date", - "intent_summary": "一句话描述其金融市场分析意图" -} -``` - -### 字段说明: -1. **keywords**: 核心公司实体、所属行业、宏观经济事件或政策概念。 -2. **search_queries**: 优化后的搜索词,必须包含“股市影响”、“股价波动”、“行业逻辑”或“估值”等金融维度。 -3. **affected_sectors**: 可能受此事件或信息影响的二级市场板块(如:保险、半导体、房地产)。 -4. **is_market_moving**: 该事件是否具有显著的市场驱动潜力或属于重大基本面变化。 -5. **intent_summary**: 简述用户查询背后的金融研究目的。 - -### 示例: -用户输入:"帮我研究一下香港火灾的影响" -输出: -```json -{ - "keywords": ["香港", "火灾", "保险行业", "房地产"], - "search_queries": ["香港火灾对当地保险股股价影响", "香港大火对相关上市物业公司估值冲击", "近期香港火灾带来的市场避险情绪分析"], - "affected_sectors": ["保险", "房地产", "物业管理"], - "is_market_moving": true, - "time_range": "recent", - "intent_summary": "评估香港近期火灾对相关板块上市公司的潜在经济损失及股价冲击" -} -``` -""" - -def get_intent_task(query: str) -> str: - """生成意图分析任务描述""" - return f"Process this query and extract financial market intent: {query}" - diff --git a/skills/alphaear-predictor/scripts/prompts/isq_prompt_generator.py b/skills/alphaear-predictor/scripts/prompts/isq_prompt_generator.py deleted file mode 100644 index 007461b..0000000 --- a/skills/alphaear-predictor/scripts/prompts/isq_prompt_generator.py +++ /dev/null @@ -1,43 +0,0 @@ -""" -ISQ prompt helpers to render dimension guidance directly from the template. -Any change in the template propagates to prompts automatically. 
-""" - -from typing import List, Optional -from ..schema.isq_template import get_isq_template, ISQTemplate - - -def _ordered_dimension_keys(template: ISQTemplate, order: Optional[List[str]] = None) -> List[str]: - if order: - return [k for k in order if k in template.dimensions] - # fallback to template insertion order - return list(template.dimensions.keys()) - - -def generate_isq_prompt_section(template_id: str = "default_isq_v1", order: Optional[List[str]] = None, include_header: bool = True) -> str: - """Render ISQ dimension text block based on the template. - This allows prompt text to stay in sync with template edits. - """ - template = get_isq_template(template_id) - keys = _ordered_dimension_keys(template, order) - - lines: List[str] = [] - if include_header: - lines.append("### 1. ISQ 评估框架 (Investment Signal Quality)") - lines.append(f"参考模板: {template.template_name} (id: {template.template_id})") - lines.append("") - lines.append("你需要对信号进行以下维度的评分:") - lines.append("") - - for idx, key in enumerate(keys, start=1): - spec = template.dimensions[key] - examples = ";".join([f"{k}: {v}" for k, v in spec.examples.items()]) if spec.examples else "" - lines.append(f"{idx}. **{spec.key} ({spec.name})**: {spec.range_type}") - lines.append(f" - 描述: {spec.description}") - if spec.scale_factor and spec.scale_factor != 1.0: - lines.append(f" - 缩放因子: {spec.scale_factor}") - if examples: - lines.append(f" - 示例: {examples}") - lines.append("") - - return "\n".join(lines).rstrip() diff --git a/skills/alphaear-predictor/scripts/prompts/report_agent.py b/skills/alphaear-predictor/scripts/prompts/report_agent.py deleted file mode 100644 index 6f25c3f..0000000 --- a/skills/alphaear-predictor/scripts/prompts/report_agent.py +++ /dev/null @@ -1,415 +0,0 @@ -# src/prompts/report_agent.py -from datetime import datetime -from typing import Optional -from .isq_prompt_generator import generate_isq_prompt_section - -def get_report_planner_base_instructions() -> str: - """生成报告策划员 (Planner) 的基础系统指令""" - return """你是一名资深的金融研报主编。你的任务是规划报告的结构,将零散的信号聚类成有逻辑的主题。 -你拥有 RAG 搜索工具,可以检索已生成的章节内容以确保逻辑连贯性。 -在规划时,应重点关注信号之间的关联性、产业链的完整性以及用户特定的关注点。""" - -def get_report_writer_base_instructions() -> str: - """生成报告撰写员 (Writer) 的基础系统指令""" - return """你是一名资深金融分析师。你的任务是根据策划员提供的信号簇撰写深度研报章节。 -你应当运用专业的金融知识,将信号转化为深刻的洞察。 -注意:你没有外部搜索工具,你的分析必须基于提供给你的信号内容和行情数据。""" - -def get_report_editor_base_instructions() -> str: - """生成报告编辑 (Editor) 的基础系统指令""" - return """你是一名严谨的金融研报编辑。你的任务是审核和润色撰写员生成的章节。 -你拥有 RAG 搜索工具,可以检索其他章节的内容,以消除重复、修正逻辑冲突并确保术语一致性。 -你应当确保报告符合专业的金融写作规范,且标题层级正确。""" - -# 1. 
策划阶段 (Structural Planning) -def format_signal_for_report(signal: any, index: int, cite_keys: Optional[list] = None) -> str: - """格式化单个信号供研报生成使用""" - # 这里的逻辑从 ReportAgent._format_signal_input 迁移过来 - from ..schema.models import InvestmentSignal - - if isinstance(signal, dict): - try: - sig_obj = InvestmentSignal(**signal) - except: - return f"--- 信号 [{index}] ---\n标题: {signal.get('title')}\n内容: {signal.get('content', '')[:500]}" - else: - sig_obj = signal - - chain_str = " -> ".join([f"{n.node_name}({n.impact_type})" for n in sig_obj.transmission_chain]) - - text = f"--- 信号 [{index}] ---\n" - text += f"标题: {sig_obj.title}\n" - text += f"逻辑摘要: {sig_obj.summary}\n" - text += f"传导链条: {chain_str}\n" - text += f"ISQ 评分: 情绪({sig_obj.sentiment_score}), 确定性({sig_obj.confidence}), 强度({sig_obj.intensity})\n" - text += f"预期博弈: 时窗({sig_obj.expected_horizon}), 预期差({sig_obj.price_in_status})\n" - - tickers = ", ".join([f"{t.get('name')}({t.get('ticker')})" for t in sig_obj.impact_tickers]) - if tickers: - text += f"受影响标的: {tickers}\n" - - # Stable bibliography-style citation keys (LaTeX/BibTeX-like) - if cite_keys: - joined = " ".join([f"[@{k}]" for k in cite_keys if k]) - if joined: - text += f"引用: {joined}\n" - - return text - -def get_cluster_planner_instructions(signals_text: str, user_query: str = None) -> str: - """生成信号聚类指令 - 将零散信号组织成逻辑主题""" - query_context = f"用户重点关注:{user_query}" if user_query else "" - return f"""你是一位资深的金融研报主编。你的任务是将以下零散的金融信号聚类成 3-5 个核心逻辑主题,以便撰写一份结构清晰的研报。 - - {query_context} - - ### 输入信号列表 - {signals_text} - - ### 聚类要求 - 1. **主题聚合**: 将相关性强的信号归为一组(例如:都涉及“建筑安全法规”或“某产业链上下游”)。 - 2. **叙事逻辑**: 只需要生成主题名称和包含的信号 ID。 - 3. **控制数量**: 将所有信号归类到 3-5 个主要主题中,不要遗漏。 - - ### 输出格式 (JSON) - 请仅输出以下 JSON 格式,不要包含 Markdown 标记: - {{ - "clusters": [ - {{ - "theme_title": "主题名称(如:建筑安全法规收紧引发的产业链重构)", - "signal_ids": [1, 3, 5], - "rationale": "这些信号都指向政府对高层建筑防火标准的政策调整..." - }}, - ... - ] - }} - """ - -def get_report_planner_instructions(toc: str, signal_count: int, user_query: str = None) -> str: - """生成报告规划指令 - 重点在于逻辑关联与分歧识别""" - # ... (原有逻辑保持不变,但实际在新的聚类流程后这个可能作为备用或二次优化) - query_context = f"用户重点关注:{user_query}" if user_query else "" - return f"""你是一位资深的金融研报主编。你的任务是根据现有的草稿章节,规划出一份逻辑严密、穿透力强的终稿结构。 - - ### 任务核心: - 1. **识别主线**: 从草稿中识别出贯穿多个章节的“核心逻辑主线”(如:产业链共振、货币政策转向)。 - 2. **分歧评估 (Entropy)**: 识别各章节中观点冲突或确定性不一之处,规划如何在正文中呈现这些“分歧点”。 - 3. **结构蓝图**: - - 定义一级标题(逻辑主题)。 - - 归类章节:哪些信号应放入同一主题下深度解析? - - 排序:将 ISQ 强度最高、与{query_context}最相关的信号置前。 - - ### 现有草稿目录 (TOC) - {toc} - - 请输出你的【终稿修订大纲】(Markdown 格式)。 - """ - -# 2. 撰写阶段 (Section Writing) -def get_report_writer_instructions(theme_title: str, signal_cluster_text: str, signal_indices: list, price_context: str = "", user_query: str = None) -> str: - """生成 Writer Agent 指令 - 基于主题聚类撰写综合分析""" - - price_info = f"\n### 近期价格参考\n{price_context}\n" if price_context else "" - query_context = f"\n**用户意图**: \"{user_query}\"\n请确保分析内容回应了用户的关注点。\n" if user_query else "" - isq_block = generate_isq_prompt_section(include_header=False) - - # Keep citation scheme stable across re-ordering / edits. - # Cite keys are provided in each signal block as: 引用: [@KEY] - - return f"""你是一位资深金融分析师。请针对核心主题 **"{theme_title}"** 撰写一篇深度研报章节。 - {query_context} - - ### 输入信号集 (本章节需综合的信号) - {signal_cluster_text} - {price_info} - - ### ISQ 评分说明 - {isq_block} - - ### 写作要求 - 1. **叙事逻辑**: 不要罗列信号,要将这些信号编织成一个连贯的故事。先讲宏观/行业背景,再讲具体事件传导,最后落脚到个股/标的影响。 - 2. **量化支撑**: 引用 ISQ 评分(确定性、强度、预期差)来佐证你的观点。关键观点必须关联相应的 ISQ 分值。 - 3. 
**引用规范(稳定 CiteKey)**: 关键论断必须标注来源引用,使用 `[@CITE_KEY]` 格式。 - - CiteKey 已在输入信号块中以 `引用: [@KEY]` 提供,请直接复制使用。 - - 不要使用 `[[1]]` 这类不稳定编号。 - 4. **关联标的预测**: **必须**在章节末尾明确给出受影响标的的预测分析,包括: - - 至少列出 1-2 个相关上市公司代码(如 600519.SH) - - 给出短期(T+3或T+5)的方向性判断 - - 如果可能,给出预期价格区间或涨跌幅预测 - - ### 【重要】标题层级规范 - - ❌ **错误示例**(绝对不要这样): - ```markdown - # {theme_title} - - ### 宏观背景 - ... - ``` - - ✅ **正确示例**(必须这样): - ```markdown - ## {theme_title} - - ### 宏观背景 - - 近期全球经济环境... - - ### 具体传导机制分析 - - ... - - ### 核心标的分析 - - 建议关注:贵州茅台(600519.SH)... - ``` - - **关键要求**: - - 章节主标题使用 `##` (H2) - - 章节子标题使用 `###` (H3) - - **绝对禁止**使用 `#` (H1) - - 第一行必须是 `## {theme_title}` 开头 - - ### 核心:图表叙事 (Visual Storytelling) - **必须**在文中插入至少 1-2 个图表,且图表必须与上下文紧密结合(不要堆砌在末尾)。 - - ### 宏观背景 - ... - ``` - - ✅ **正确示例**(必须这样): - ```markdown - ## {theme_title} - - ### 宏观背景 - - 近期全球经济环境... - - ### 具体传导机制分析 - - ... - - ### 核心标的分析 - - 建议关注:贵州茅台(600519.SH)... - ``` - - **关键要求**: - - 章节主标题使用 `##` (H2) - - 章节子标题使用 `###` (H3) - - **绝对禁止**使用 `#` (H1) - - 第一行必须是 `## {theme_title}` 开头 - - ### 核心:图表叙事 (Visual Storytelling) - **必须**在文中插入至少 1-2 个图表,且图表必须与上下文紧密结合(不要堆砌在末尾)。 - - **可选图表类型 (请根据内容选择最合适的 1-2 种):** - - **A. AI 预测 + 走势 (Forecast) - 【强烈推荐 / 最新规范】** - *适用*: 当文中明确提及某上市公司时,**必须**使用此图表展示股价走势与 AI 预测。 - *必填字段*: - - `ticker`: 股票代码,A股 6 位 / 港股 5 位,允许带后缀(如 "002371.SZ"、"9868.HK") - - `pred_len`: 预测交易日长度(建议 3 或 5) - *代码示例*: - ```json-chart - {{"type": "forecast", "ticker": "002371.SZ", "title": "北方华创(002371)T+5 预测", "pred_len": 5}} - ``` - **重要**:禁止手写 `prediction` 数组(预测由系统自动生成并渲染)。 - *注意*: 如果提及多只股票,应为每只生成独立的 forecast 图表。 - - **【推荐写法:多情景 → 最终归因 → 产出唯一预测图】** - 你可以在正文里描述多种情景(如:基准/乐观/悲观),但在插入预测图之前,必须明确给出“本报告最终选择的最可能情景”及其归因,然后用 `forecast` 图表做最终总结。 - 为了让系统把“最终归因”可靠地传递给预测模块,请在 `forecast` JSON 中可选补充以下字段(字段均为可选,越完整越好): - - `selected_scenario`: 最可能情景名称(如 "基准" / "乐观" / "悲观") - - `selection_reason`: 选择该情景的归因理由(1-3 句) - - `scenarios`: 情景列表(数组),每个元素可包含 `name`、`description`、`probability`(0-1) - *示例*: - ```json-chart - {{ - "type": "forecast", - "ticker": "002371.SZ", - "title": "北方华创(002371)T+5 预测(基准情景)", - "pred_len": 5, - "selected_scenario": "基准", - "selection_reason": "结合订单能见度与行业景气,基准情景概率最高;短期扰动主要来自估值与市场风险偏好。", - "scenarios": [ - {{"name": "乐观", "description": "国产替代与资本开支超预期", "probability": 0.25}}, - {{"name": "基准", "description": "订单稳健、利润率小幅波动", "probability": 0.55}}, - {{"name": "悲观", "description": "需求回落或交付节奏放缓", "probability": 0.20}} - ] - }} - ``` - - **B. 历史走势 (Stock) - 仅作为兼容兜底** - *适用*: 当你无法给出预测时(例如无法确定标的),可仅展示历史走势。 - *代码示例*: - ```json-chart - {{"type": "stock", "ticker": "002371", "title": "北方华创历史走势"}} - ``` - - **C. 舆情情绪演变 (Sentiment Trend)** - *适用*: 当讨论行业政策、突发事件(如“火灾”、“新规”)的民意变化时。 - *注意*: `keywords` 必须是事件核心词。 - *代码*: - ```json-chart - {{"type": "sentiment", "keywords": ["建筑安全", "防火标准"], "title": "市场对防火新规的情绪演变"}} - ``` - - **D. 逻辑传导链条 (Transmission Chain)** - *适用*: 复杂的蝴蝶效应分析(支持分支结构)。 - *代码*: - ```json-chart - {{ - "type": "transmission", - "nodes": [ - {{"node_name": "突发火灾", "impact_type": "中性", "logic": "事件发端"}}, - {{"node_name": "监管收紧", "impact_type": "利空", "logic": "合规成本上升", "source": "突发火灾"}}, - {{"node_name": "设备升级", "impact_type": "利好", "logic": "采购需求释放", "source": "突发火灾"}}, - {{"node_name": "龙头受益", "impact_type": "利好", "logic": "市占率提升", "source": "设备升级"}} - ], - "title": "火灾事件的逻辑传导与分支" - }} - ``` - *说明*: 使用 `source` 字段指定父节点名称以创建分支结构。 - - **E. 
信号质量评估 (ISQ Radar)** - *适用*: 对某个关键信号进行多维度(确定性、预期差等)定性评估时。 - *代码*: - ```json-chart - {{"type": "isq", "sentiment": 0.8, "confidence": 0.9, "intensity": 4, "expectation_gap": 0.7, "timeliness": 0.9, "title": "核心信号质量评估"}} - ``` - """ - -# 3. 整合阶段 (Final Assembly) - 原版,保留用于 fallback -def get_report_editor_instructions(draft_sections: str, plan: str, sources_list: str) -> str: - """生成最终编辑指令 - 根据规划蓝图重组内容""" - return f"""你是一位专业的研报编辑。请将以下基于主题撰写的草稿章节整合成最终研报。 - - ### 原始草稿内容 - {draft_sections} - - ### 原始引用来源 - {sources_list} - - ### 任务与要求 - 1. **结构化**: 为每个草稿章节添加合适的 Markdown 标题 (## 级别)。 - 2. **连贯性**: 确保章节之间过渡自然。 - 3. **完整性**: - - 必须保留所有 `json-chart` 代码块(图表配置)。 - - 必须保留引用标注 `[@CITE_KEY]`。 - - 生成 `## 核心观点摘要`、`## 参考文献` 和 `## 风险提示`。 - - ### 输出 - 只输出最终的 Markdown 研报内容。 - """ - - -# 4. 单节编辑 (Incremental Section Editing with RAG) -def get_section_editor_instructions(section_index: int, total_sections: int, toc: str) -> str: - """生成单节编辑 prompt,支持 RAG 工具调用""" - return f"""你是一位研报编辑。你正在编辑报告的第 {section_index}/{total_sections} 节。 - - ### 当前目录 (TOC) - {toc} - - ### 你的任务 - 1. 润色当前章节内容,确保逻辑清晰、语言专业。 - 2. 保留所有 `[@CITE_KEY](#ref-CITE_KEY)` 或 `[@CITE_KEY]` 格式的引用。 - 3. 保留所有 `json-chart` 代码块,不做修改。 - 4. 如果需要参考其他章节内容,使用 `search_context` 工具搜索。 - 5. 只输出编辑后的章节内容,不要输出其他章节。 - - ### 【关键】标题层级规范 - **严格遵守以下规则:** - - 章节主标题使用 `##` (H2) - - 章节子标题使用 `###` (H3) - - **禁止使用** `#` (H1) - 只有报告大标题可以使用 H1 - - 如果原文中有 H1,必须将其降级为 H2 - - 不要输出与 "参考文献"、"风险提示" 相同的标题 - - 直接输出编辑后的 Markdown 内容。 - """ - - -# 5. 摘要生成 (Summary Generation) -def get_summary_generator_instructions(toc: str, section_summaries: str) -> str: - """生成报告摘要指令 - 包含市场分歧度分析""" - return f"""你是一位资深研报主笔。请生成今日报告的核心观点摘要的**正文内容**。 - - ### 章节摘要 - {section_summaries} - - ### 任务: - 1. **核心逻辑提炼**: 用 150 字以内总结今日最核心的投资主线。 - 2. **分歧识别**: 如果不同信号对同一板块有冲突观点,请明确指出"市场分歧点"。 - 3. **确定性排序**: 标记出今日确定性最高的前两个机会(需列出具体标的代码)。 - - ### 【重要】输出格式规范: - - ❌ **错误示例**(不要遗漏二级标题): - ```markdown - ### 核心逻辑提炼 - ... - ``` - - ✅ **正确示例**(应该这样输出): - ```markdown - ## 核心观点摘要 - - ### 核心逻辑提炼 - - 科技自立战略加速半导体设备国产化,叠加AI算力需求爆发... - - ### 市场分歧点 - - 资本市场波动显示医药、新能源等板块估值逻辑受政策敏感性增强... - - ### 确定性排序 - - 1. **网络安全替代需求**(ISQ确定性0.85,推荐标的:深信服 300454.SZ) - 2. **半导体设备材料**(ISQ确定性0.75,推荐标的:北方华创 002371.SZ) - ``` - - ### 关键要求: - - 第一行必须是 `## 核心观点摘要` - - 主体部分使用 H3 (`###`) 和 H4 (`####`) 级别标题 - - **必须**包含 `## 核心观点摘要` 这一级标题 - - 现在请按照正确示例的格式输出摘要内容。 - """ - - -# 6. 最终组装 (Final Assembly with Sections) -def get_final_assembly_instructions(sources_list: str) -> str: - """生成最终报告组装的 prompt""" - return f"""你是一位研报主笔。请完成以下任务: - - ### 任务 - 1. 生成 "## 参考文献" 章节(需要按照顺序,顺序不对时进行调整): - - 原始来源: - {sources_list} - - 格式:`[@CITE_KEY] 标题 (来源), [链接地址]` - 2. 生成 "## 风险提示" (标准免责声明)。 - 3. 
生成 "## 快速扫描" 表格,汇总各主题的核心观点。 - - 表格列:**主题**, **核心观点**, **强度(Intensity)**, **确定性(Confidence)**。 - - 强度和确定性请参考原章节中的 ISQ 评分。 - - 只输出上述三个章节的 Markdown 内容。 - """ - -def get_cluster_task(signals_preview: str) -> str: - """生成聚类任务描述""" - return f"请对以下信号进行主题聚类:\n\n{signals_preview}" - -def get_writer_task(theme_title: str) -> str: - """生成撰写任务描述""" - return f"请依据主题 '{theme_title}' 和 输入信号集 开始撰写深度分析章节。" - -def get_planner_task() -> str: - """生成规划任务描述""" - return "请阅读现有草稿并规划终稿大纲,识别核心逻辑主线和市场分歧点。" - -def get_editor_task() -> str: - """生成编辑任务描述""" - return "请根据规划大纲和草稿内容,生成最终研报。确保逻辑连贯,保留所有图表和引用。" - diff --git a/skills/alphaear-predictor/scripts/prompts/trend_agent.py b/skills/alphaear-predictor/scripts/prompts/trend_agent.py deleted file mode 100644 index 54e6e22..0000000 --- a/skills/alphaear-predictor/scripts/prompts/trend_agent.py +++ /dev/null @@ -1,156 +0,0 @@ -from typing import Any -from datetime import datetime -from .isq_prompt_generator import generate_isq_prompt_section - -def get_trend_scanner_instructions() -> str: - """生成趋势扫描员 (Scanner) 的系统指令""" - current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') - return f"""你是一名专业的数据扫描员,当前时间是 {current_time}。 -你的任务是利用各种工具从互联网和数据库中获取最新的金融新闻、热点趋势和市场数据。 - -### 1. 核心职责 -1. **多源采集**: 使用 `news_toolkit` 获取最新新闻,使用 `stock_toolkit` 获取行情,使用 `polymarket_toolkit` 获取预测市场数据。 -2. **情绪感知**: 使用 `sentiment_toolkit` 对关键新闻进行情绪分析。 -3. **深度搜索**: 针对模糊的热点,使用 `search_toolkit` 进行全网搜索补充细节。 - -### 2. 工具使用规范 -- **广度优先**: 尽可能覆盖多个数据源。 -- **数据新鲜度**: 优先获取最近 24 小时内的信息。 -- **结构化输出**: 整理搜集到的原始数据,为后续评估提供清晰的素材。 -""" - -def get_trend_evaluator_instructions() -> str: - """生成趋势评估员 (Evaluator) 的系统指令""" - current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') - isq_block = generate_isq_prompt_section(include_header=True) - - return f""" - 你是一名顶级的金融情报专家 (TrendAgent),擅长从海量信息中识别具有深度价值的"二级市场投资信号"。 - 当前时间:{current_time} - - ### 核心使命: - 不仅是发现"热点",更要解析"信号"。你需要识别那些能触发**传导链条 (Transmission Chain)** 且具有**高确定性 (Confidence)** 的事件。 - - {isq_block} - - ### 核心能力与标准: - 1. **信号识别 (Signal Discovery)**: 基于扫描员提供的素材,识别具有投资价值的信号。优先关注政策、产业变革、重大诉求及跨境套利机会。 - 2. **逻辑相干性**: 是否具备清晰的"原因-结果"传导? - 3. **影响力系数**: 是否会引发板块性的联动或财务指标的实质性扰动? - 4. **市场认知差**: 市场是否已提前消化(Price-in)?寻找尚未被充分交易的"Alpha"。 - 5. 
**实体穿透**: 必须关联到具体的 Ticker 或核心产业链节点。 - - ### 严禁事项: - - 严禁编造数据。 - - 严禁仅输出情绪极性(Positive/Negative),必须带有逻辑依据。 - - 严禁将纯娱乐或单纯的社会负面事件(除非具有宏观破坏性)视为金融信号。 - - ### 输出要求: - 你发现的每个信号应包含: - - **核心摘要**: 穿透表象的逻辑总结。 - - **传导节点**: A -> B -> C 的逻辑推导。 - - **推荐关注**: 板块或 Ticker。 - - **ISQ 评估**: 基于模板的 5 个维度进行初步评分(具体评分由后续 FinAgent 完成)。 - """ - -def get_trend_agent_instructions() -> str: - # 保持兼容性 - return get_trend_evaluator_instructions() - -def get_trend_scan_task(task_description: str) -> str: - """生成扫描员的任务描述""" - return f"请根据以下任务描述,搜集相关的原始数据和新闻:\n\n{task_description}" - -def format_scan_context(scan_data: dict) -> str: - """将扫描员搜集的结构化数据格式化为评估员可读的文本""" - if not scan_data: - return "(未能搜集到原始数据)" - - return f""" -### 扫描数据概览 -- **热点话题**: {', '.join(scan_data.get('hot_topics', []))} -- **情绪概览**: {scan_data.get('sentiment_overview', '未知')} -- **关键新闻**: {len(scan_data.get('news_summaries', []))} 条 -- **数据摘要**: {scan_data.get('raw_data_summary', '无')} -""" - -def get_trend_eval_task(task_description: str, raw_data_str: str) -> str: - """生成评估员的任务描述""" - return f"""请基于以下搜集到的原始数据,完成最终的分析任务: - -任务描述: {task_description} - -原始数据: -{raw_data_str} - -请识别出最具金融价值的信号,并给出评估理由。""" - -def get_news_filter_instructions(news_count: int, depth: Any, user_query: str = None) -> str: - """生成新闻筛选 prompt,使用 FilterResult schema 加快推理并减少 token 消耗 - - Args: - news_count: 输入新闻总数 - depth: 目标筛选数量,若为 auto 则由 LLM 自主判断 - user_query: 用户输入的查询/关注点(可选) - """ - - # 1. 深度控制逻辑 - if str(depth).lower() == 'auto': - depth_guide = "的数量不设固定限制(建议 3-10 条),根据新闻含金量自动判断" - limit_instruction = "宁缺毋滥,如果高价值信息很少,可以只选 1-2 条;如果都很重要,可以多选。" - else: - try: - d_int = int(depth) - depth_guide = f"约 {d_int} 条" - limit_instruction = f"请尽量凑满 {d_int} 条,但如果剩余新闻全是噪音,则不必强行凑数。" - except: - depth_guide = "适量" - limit_instruction = "根据内容价值判断。" - - target_desc = f"筛选出最具投资分析价值的新闻({depth_guide})。" - - # 2. 用户意图逻辑 - query_instruction = "" - if user_query: - target_desc = f"筛选出与用户意图【{user_query}】最相关的新闻。" - query_instruction = f""" - ### 核心任务(High Priority): - 用户明确关注:"{user_query}"。 - 1. **第一优先级**:必须包含所有与"{user_query}"直接或间接相关的新闻,不要遗漏。 - - 即使这些新闻看起来"价值不高",只要相关都要保留。 - 2. **第二优先级**:在满足第一优先级后,如果名额未满,再补充其他重大的市场热点。 - """ - - return f"""你是一名专业的金融情报精排师。你需要从给定的 {news_count} 条原始新闻流中,{target_desc} - - {query_instruction} - - ### FSD (Financial Signal Density) 筛选准则: - 1. **逻辑传导性 (Transmission)**: 该新闻是否预示着一个明确的产业链传导逻辑?(如:上游涨价 -> 中游成本压力 -> 下游提价预期) - 2. **预期差 (Alpha Potential)**: 是否包含尚未被市场充分Price-in的新突发情况? - 3. **确定性 (Confidence)**: 信息来源是否权威?是否包含具体的财务数据、订单金额或明确的政策日期? - 4. **排除噪音**: 坚决剔除明星八卦、鸡汤文、以及无实质增量的"口号式"新闻。 - - ### {limit_instruction} - - ### 快速有效性检查(TOKEN 优化): - 在开始详细筛选前,先快速判断:这 {news_count} 条新闻中是否至少包含 1 条有效的金融信号? - - 如果全是无关内容(如体育、娱乐、纯生活信息),直接返回 "has_valid_signals": false - - 如果有至少 1 条金融相关的新闻,再进行详细 FSD 筛选 - - ### 输出格式(必须为 JSON,使用 FilterResult schema): - ```json - {{ - "has_valid_signals": true/false, - "selected_ids": ["id_1", "id_2", ...], - "themes": [ - {{ - "name": "高概括性主题", - "news_ids": ["相关id_1", ...], - "fsd_reason": "基于 FSD 准则的筛选理由,重点描述传导逻辑和预期差。" - }} - ], - "reason": "如果 has_valid_signals=false,简要说明原因。否则可为空。" - }} - ``` - """ diff --git a/skills/alphaear-predictor/scripts/prompts/visualizer.py b/skills/alphaear-predictor/scripts/prompts/visualizer.py deleted file mode 100644 index f0b2933..0000000 --- a/skills/alphaear-predictor/scripts/prompts/visualizer.py +++ /dev/null @@ -1,47 +0,0 @@ -def get_drawio_system_prompt(): - return """You are an expert at creating Draw.io (MxGraph) diagrams in XML format. -Your task is to generate a valid MXGraphModel XML based on the user's description. 
- -### Rules: -1. Output ONLY the XML code. Start with and end with . -2. Do not use compressed XML. Use plain XML. -3. Use standard shapes: 'rounded=1;whiteSpace=wrap;html=1;' for boxes. -4. Auto-layout Strategy: - - Identify "layers" or "stages" in the logic. - - Assign X coordinates based on layers (e.g., 0, 200, 400). - - Assign Y coordinates to distribute nodes vertically (e.g., 0, 100, 200). - - Ensure nodes do not overlap. -5. Edges: Connect nodes logically using . - -### Template: - - - - - - - - - - - - - - - - -""" - -def get_drawio_task(nodes_data: list, title: str) -> str: - import json - nodes_json = json.dumps(nodes_data, ensure_ascii=False, indent=2) - return f"""Please generate a Draw.io XML diagram for the following logic flow: - -**Title**: {title} - -**Nodes and Logic**: -{nodes_json} - -Ensure the layout flows logically from Left to Right (or Top to Bottom for hierarchies). -Use different colors for 'Positive' (Greenish), 'Negative' (Reddish), and 'Neutral' (Grey/Blue) impacts if described. -""" diff --git a/skills/alphaear-predictor/scripts/schema/isq_template.py b/skills/alphaear-predictor/scripts/schema/isq_template.py deleted file mode 100644 index 2709019..0000000 --- a/skills/alphaear-predictor/scripts/schema/isq_template.py +++ /dev/null @@ -1,381 +0,0 @@ -""" -ISQ (Investment Signal Quality) 评估框架 Template - -统一定义 ISQ 的各个维度、评分标准、和使用方法。 -支持默认 template 和自定义 template。 -""" - -from typing import Dict, List, Any, Optional -from pydantic import BaseModel, Field -from enum import Enum -from pathlib import Path -import json - - -class ISQDimension(str, Enum): - """ISQ 评估维度""" - SENTIMENT = "sentiment" # 情绪/走势方向 - CONFIDENCE = "confidence" # 确定性/可信度 - INTENSITY = "intensity" # 强度/影响量级 - EXPECTATION_GAP = "expectation_gap" # 预期差/市场认知差 - TIMELINESS = "timeliness" # 时效性/窗口紧迫度 - TRANSMISSION = "transmission" # 逻辑传导清晰度 - - -class ISQDimensionSpec(BaseModel): - """ISQ 单个维度的定义规范""" - name: str = Field(..., description="维度名称") - key: str = Field(..., description="维度键名") - description: str = Field(..., description="维度描述") - range_type: str = Field(default="0-1", description="取值范围 (0-1 或 1-5 等)") - scale_factor: float = Field(default=1.0, description="显示时的缩放因子") - examples: Dict[str, str] = Field(default_factory=dict, description="不同分值的示例解释") - visualization_color: Optional[str] = Field(default=None, description="可视化颜色") - - -class ISQTemplate(BaseModel): - """ISQ 评估框架 Template""" - template_id: str = Field(..., description="模板 ID") - template_name: str = Field(..., description="模板名称") - description: str = Field(..., description="模板描述") - - # 核心维度定义 - dimensions: Dict[str, ISQDimensionSpec] = Field(..., description="维度定义字典") - - # 评分指导 - scoring_guide: str = Field(..., description="评分指导说明") - - # 应用场景 - applicable_scenarios: List[str] = Field(default_factory=list, description="适用场景") - - # 聚合算法 - aggregation_method: str = Field(default="weighted_average", description="聚合方法 (weighted_average, product 等)") - dimension_weights: Dict[str, float] = Field(default_factory=dict, description="维度权重") - - -class ISQScore(BaseModel): - """单个信号的 ISQ 评分结果""" - signal_id: str = Field(..., description="信号 ID") - template_id: str = Field(..., description="使用的模板 ID") - - # 各维度评分 - scores: Dict[str, float] = Field(..., description="各维度评分") - - # 总分 - overall_score: float = Field(..., description="综合评分") - - # 评分理由 - rationale: Dict[str, str] = Field(default_factory=dict, description="各维度评分理由") - - # 时间戳 - timestamp: str = Field(..., description="评分时间") - - -# 
===================================================== -# 默认 Template -# ===================================================== - -DEFAULT_ISQ_TEMPLATE = ISQTemplate( - template_id="default_isq_v1", - template_name="标准投资信号质量评估框架 (ISQ v1.0)", - description="AlphaEar 默认的 ISQ 评估框架,用于标准化评估投资信号的质量维度", - - dimensions={ - "sentiment": ISQDimensionSpec( - name="情绪/走势", - key="sentiment", - description="基础情绪偏向和市场走势判断", - range_type="-1.0 到 1.0", - scale_factor=1.0, - examples={ - "-1.0": "极度悲观/极度看空", - "-0.5": "明显看空", - "0.0": "中性/没有明确方向", - "0.5": "明显看多", - "1.0": "极度乐观/极度看多" - }, - visualization_color="#ef4444" # 红色表示负面,绿色表示正面 - ), - - "confidence": ISQDimensionSpec( - name="确定性", - key="confidence", - description="信号的可信度和确定性程度", - range_type="0.0 到 1.0", - scale_factor=1.0, - examples={ - "0.0-0.3": "信息来源不可靠/传言多/逻辑推导牵强", - "0.3-0.6": "信息相对可靠/有一定逻辑/但仍有不确定性", - "0.6-0.8": "信息来源权威/逻辑清晰/高度可信", - "0.8-1.0": "官方确认/数据明确/完全确定" - }, - visualization_color="#3b82f6" # 蓝色 - ), - - "intensity": ISQDimensionSpec( - name="强度/影响量级", - key="intensity", - description="信号对相关板块/个股的潜在影响程度", - range_type="1 到 5", - scale_factor=20.0, # 用于雷达图缩放 (5 -> 100) - examples={ - "1": "影响微弱,可能被市场忽略", - "2": "小幅影响,短期可能有波动", - "3": "中等影响,值得重点关注", - "4": "强烈影响,可能成为市场焦点", - "5": "极强影响,市场预期明显变化" - }, - visualization_color="#f97316" # 橙色 - ), - - "expectation_gap": ISQDimensionSpec( - name="预期差", - key="expectation_gap", - description="市场预期与现实之间的差距", - range_type="0.0 到 1.0", - scale_factor=1.0, - examples={ - "0.0-0.2": "市场充分认知,预期差小", - "0.2-0.5": "市场部分认知,存在一定预期差", - "0.5-0.8": "市场认知不足,预期差较大,存在博弈空间", - "0.8-1.0": "市场严重低估/高估,巨大预期差" - }, - visualization_color="#22c55e" # 绿色 - ), - - "timeliness": ISQDimensionSpec( - name="时效性", - key="timeliness", - description="信号的时间窗口紧迫度", - range_type="0.0 到 1.0", - scale_factor=1.0, - examples={ - "0.0-0.2": "长期信号,反应窗口 > 3 月", - "0.2-0.5": "中期信号,反应窗口 1-3 月", - "0.5-0.8": "短期信号,反应窗口 1 周 - 1 月", - "0.8-1.0": "超短期信号,反应窗口 < 1 周(需立即行动)" - }, - visualization_color="#a855f7" # 紫色 - ), - }, - - scoring_guide=""" - ### ISQ 评分指导 (Investment Signal Quality) - - ISQ 框架用于多维度评估投资信号的质量。每个信号由 5 个维度组成: - - 1. **情绪 (Sentiment)**: -1.0 到 1.0,表示看空(-)/中性(0)/看多(+) - 2. **确定性 (Confidence)**: 0.0 到 1.0,数值越高越确定 - 3. **强度 (Intensity)**: 1 到 5,数值越高影响越大 - 4. **预期差 (Expectation Gap)**: 0.0 到 1.0,市场预期与现实的差距 - 5. 
**时效性 (Timeliness)**: 0.0 到 1.0,反应窗口的紧迫程度 - - ### 综合评分算法 - - 综合评分 = 确定性 × 0.35 + 强度/5 × 0.30 + 预期差 × 0.20 + 时效性 × 0.15 - - 范围: 0.0 到 1.0 - - 0.0-0.3: 信号质量较差,不建议跟进 - - 0.3-0.6: 信号质量一般,可作参考 - - 0.6-0.8: 信号质量良好,值得跟进 - - 0.8-1.0: 信号质量优异,强烈推荐 - - ### 评分时的注意事项 - - - **不要混淆方向和强度**:情绪可以是看空,但确定性和强度仍可能很高 - - **预期差往往是 Alpha 来源**:高预期差 + 高确定性 = 最佳博弈机会 - - **考虑时间成本**:长期信号需要更高的确定性才值得跟进 - - **数据为王**:所有评分必须有具体数据支撑 - """, - - applicable_scenarios=[ - "上市公司基本面变化分析", - "产业政策与监管事件评估", - "地缘政治与宏观经济影响", - "技术进步与产业升级", - "突发事件与应急响应" - ], - - aggregation_method="weighted_average", - dimension_weights={ - "confidence": 0.35, - "intensity": 0.30, - "expectation_gap": 0.20, - "timeliness": 0.15 - } -) - - -# ===================================================== -# ISQ Template 管理系统 -# ===================================================== - -class ISQTemplateManager: - """ISQ Template 管理器""" - - def __init__(self): - self.templates: Dict[str, ISQTemplate] = { - DEFAULT_ISQ_TEMPLATE.template_id: DEFAULT_ISQ_TEMPLATE - } - - def register_template(self, template: ISQTemplate) -> None: - """注册新的 template""" - self.templates[template.template_id] = template - - def register_template_dict(self, template_dict: Dict[str, Any]) -> ISQTemplate: - """从 dict 注册模板,返回实例。""" - tpl = ISQTemplate(**template_dict) - self.register_template(tpl) - return tpl - - def get_template(self, template_id: str) -> ISQTemplate: - """获取指定 template""" - if template_id not in self.templates: - return DEFAULT_ISQ_TEMPLATE - return self.templates[template_id] - - def list_templates(self) -> List[Dict[str, str]]: - """列出所有可用 template""" - return [ - { - "id": t.template_id, - "name": t.template_name, - "description": t.description, - "dimensions": list(t.dimensions.keys()) - } - for t in self.templates.values() - ] - - def get_dimension(self, template_id: str, dimension_key: str) -> ISQDimensionSpec: - """获取指定 template 的某个维度定义""" - template = self.get_template(template_id) - return template.dimensions.get(dimension_key) - - def get_scoring_prompt(self, template_id: str) -> str: - """获取用于 LLM 的评分 prompt""" - template = self.get_template(template_id) - - dimensions_desc = "\n".join([ - f"- **{d.name} ({d.key})**\n" - f" 范围: {d.range_type}\n" - f" 说明: {d.description}\n" - f" 示例: {', '.join(f'{k}={v}' for k, v in list(d.examples.items())[:3])}" - for d in template.dimensions.values() - ]) - - return f""" -### ISQ 评估指导 ({template.template_name}) - -使用以下 {len(template.dimensions)} 个维度评估信号质量: - -{dimensions_desc} - -### 评分标准 -{template.scoring_guide} - -### 输出格式 (JSON) -请输出以下 JSON 格式的评分结果: -{{ - "sentiment": , - "confidence": , - "intensity": , - "expectation_gap": , - "timeliness": , - "rationale": {{ - "sentiment": "评分理由", - "confidence": "评分理由", - "intensity": "评分理由", - "expectation_gap": "评分理由", - "timeliness": "评分理由" - }} -}} -""" - - -# 全局 template 管理器实例 -isq_template_manager = ISQTemplateManager() - - -# ===================================================== -# 配置加载 -# ===================================================== - -def load_templates_from_config(config_path: Optional[str] = None) -> None: - """从配置目录加载所有 JSON 模板文件,未找到则跳过,不影响默认模板。 - 支持单个 JSON 文件或目录(目录下的所有 .json 文件)。 - """ - if config_path: - path = Path(config_path) - else: - # 默认目录:config/isq_templates/ - # __file__ = src/schema/isq_template.py - # parent = src/schema, parent.parent = src, parent.parent.parent = 项目根目录 - path = Path(__file__).resolve().parent.parent.parent / "config" - - if not path.exists(): - return - - # 如果是目录,扫描所有 .json 文件 - if path.is_dir(): - json_files = 
list(path.glob("*.json")) - else: - json_files = [path] - - for json_file in json_files: - try: - data = json.loads(json_file.read_text(encoding="utf-8")) - - # 如果是单个模板对象,转为列表 - if isinstance(data, dict): - templates = [data] - elif isinstance(data, list): - templates = data - else: - continue - - # 注册所有模板 - for tpl_dict in templates: - if not isinstance(tpl_dict, dict): - continue - try: - isq_template_manager.register_template_dict(tpl_dict) - except Exception: - # 忽略单个模板的加载错误,继续其他模板 - continue - except Exception: - # JSON 解析失败,跳过该文件 - continue - - -# 在模块加载时自动尝试加载配置模板 -load_templates_from_config() - - -# ===================================================== -# 便利函数 -# ===================================================== - -def get_isq_template(template_id: str = "default_isq_v1") -> ISQTemplate: - """获取 ISQ template""" - return isq_template_manager.get_template(template_id) - - -def get_isq_scoring_prompt(template_id: str = "default_isq_v1") -> str: - """获取用于 LLM 的 ISQ 评分 prompt""" - return isq_template_manager.get_scoring_prompt(template_id) - - -def calculate_isq_overall_score(scores: Dict[str, float], template_id: str = "default_isq_v1") -> float: - """计算 ISQ 综合评分""" - template = get_isq_template(template_id) - - overall = 0.0 - for dim_key, weight in template.dimension_weights.items(): - if dim_key in scores: - score = scores[dim_key] - # 处理强度维度的特殊缩放 (1-5 -> 0-1) - if dim_key == "intensity": - score = score / 5.0 - overall += score * weight - - return min(1.0, max(0.0, overall)) # 限制在 0-1 之间 diff --git a/skills/alphaear-predictor/scripts/schema/models.py b/skills/alphaear-predictor/scripts/schema/models.py deleted file mode 100644 index 422ca9c..0000000 --- a/skills/alphaear-predictor/scripts/schema/models.py +++ /dev/null @@ -1,100 +0,0 @@ -from pydantic import BaseModel, Field -from typing import List, Optional, Dict, Any -from datetime import datetime - -class TransmissionNode(BaseModel): - node_name: str = Field(..., description="产业链节点名称") - impact_type: str = Field(..., description="利好/利空/中性") - logic: str = Field(..., description="该节点的传导逻辑") - -class IntentAnalysis(BaseModel): - keywords: List[str] = Field(..., description="核心实体、事件或概念关键词") - search_queries: List[str] = Field(..., description="优化后的搜索引擎查询词") - is_specific_event: bool = Field(..., description="是否查询特定突发事件") - time_range: str = Field(..., description="时间范围 (recent/all/specific_date)") - intent_summary: str = Field(..., description="一句话意图描述") - -class FilterResult(BaseModel): - """LLM 筛选结果 - 快速判断是否有有效信号""" - has_valid_signals: bool = Field(..., description="列表中是否包含有效的金融信号") - selected_ids: List[int] = Field(default_factory=list, description="筛选出的有效信号 ID 列表") - themes: List[str] = Field(default_factory=list, description="信号涉及的主题") - reason: Optional[str] = Field(default=None, description="如果无有效信号,说明原因") - -class InvestmentSignal(BaseModel): - # 核心元数据 - signal_id: str = Field(default="unknown_sig", description="唯一信号 ID") - title: str = Field(..., description="信号标题") - summary: str = Field(default="暂无摘要分析", description="100 字核心观点快报") - reasoning: str = Field(default="", description="详细的推演逻辑和理由") - - # 逻辑传导 (ISQ Key 1) - transmission_chain: List[TransmissionNode] = Field(default_factory=list, description="产业链传导逻辑链条") - - # 信号质量 (ISQ Key 2) - 来自 isq_template.DEFAULT_ISQ_TEMPLATE - # 参考: src/schema/isq_template.py 的 DEFAULT_ISQ_TEMPLATE 定义 - sentiment_score: float = Field(default=0.0, description="[ISQ] 情绪/走势 (-1.0=极度看空 ~ 0.0=中性 ~ 1.0=极度看多)") - confidence: float = Field(default=0.5, description="[ISQ] 确定性 (0.0=不可信 ~ 
1.0=完全确定)") - intensity: int = Field(default=3, description="[ISQ] 强度/影响量级 (1=微弱 ~ 5=极强)") - expectation_gap: float = Field(default=0.5, description="[ISQ] 预期差/博弈空间 (0.0=充分定价 ~ 1.0=巨大预期差)") - timeliness: float = Field(default=0.8, description="[ISQ] 时效性 (0.0=长期 ~ 1.0=超短期)") - - # 预测与博弈 (ISQ Key 3) - expected_horizon: str = Field(default="T+N", description="预期的反应时窗 (如: T+0, T+3, Long-term)") - price_in_status: str = Field(default="未知", description="市场预期消化程度 (未定价/部分定价/充分定价)") - - # 关联实体 - impact_tickers: List[Dict[str, Any]] = Field(default_factory=list, description="受影响的代码列表及其权重") - industry_tags: List[str] = Field(default_factory=list, description="关联行业标签") - - # 溯源 - sources: List[Dict[str, str]] = Field(default_factory=list, description="来源详情 (包含 title, url, source_name)") - -class ResearchContext(BaseModel): - """研究员搜集的背景信息结构""" - raw_signal: str = Field(..., description="原始信号内容") - tickers_found: List[Dict[str, Any]] = Field(default_factory=list, description="找到的相关标的及其基本面/股价信息") - industry_background: str = Field(..., description="行业背景及产业链现状") - latest_developments: List[str] = Field(default_factory=list, description="相关事件的最新进展") - key_risks: List[str] = Field(default_factory=list, description="潜在风险点") - search_results_summary: str = Field(..., description="搜索结果的综合摘要") - -class ScanContext(BaseModel): - """扫描员搜集的原始数据结构""" - hot_topics: List[str] = Field(..., description="当前市场热点话题") - news_summaries: List[Dict[str, Any]] = Field(..., description="关键新闻摘要列表") - market_data: Dict[str, Any] = Field(default_factory=dict, description="相关的市场行情数据") - sentiment_overview: str = Field(..., description="整体市场情绪概览") - raw_data_summary: str = Field(..., description="原始数据的综合摘要") - -class SignalCluster(BaseModel): - theme_title: str = Field(..., description="主题名称") - signal_ids: List[int] = Field(..., description="包含的信号 ID 列表") - rationale: str = Field(..., description="聚类理由") - -class ClusterContext(BaseModel): - """信号聚类结果结构""" - clusters: List[SignalCluster] = Field(..., description="聚类列表") - -class KLinePoint(BaseModel): - date: str = Field(..., description="日期") - open: float = Field(..., description="开盘价") - high: float = Field(..., description="最高价") - low: float = Field(..., description="最低价") - close: float = Field(..., description="收盘价") - volume: float = Field(..., description="成交量") - -class ForecastResult(BaseModel): - ticker: str = Field(..., description="股票代码") - base_forecast: List[KLinePoint] = Field(default_factory=list, description="Kronos 模型原始预测") - adjusted_forecast: List[KLinePoint] = Field(default_factory=list, description="LLM 调整后的预测") - rationale: str = Field(default="", description="预测调整理由及逻辑说明") - timestamp: str = Field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S"), description="生成时间") - -class InvestmentReport(BaseModel): - overall_sentiment: str = Field(..., description="整体市场情绪评价") - market_entropy: float = Field(..., description="市场分歧度 (0-1, 1代表极高分歧)") - signals: List[InvestmentSignal] = Field(..., description="深度解析的投资信号列表") - forecasts: List[ForecastResult] = Field(default_factory=list, description="相关标的的预测结果") - timestamp: str = Field(..., description="报告生成时间") - meta_info: Optional[Dict[str, Any]] = Field(default_factory=dict, description="其他元数据") diff --git a/skills/alphaear-predictor/scripts/utils/__init__.py b/skills/alphaear-predictor/scripts/utils/__init__.py deleted file mode 100644 index 27e1961..0000000 --- a/skills/alphaear-predictor/scripts/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# AlphaEar utils package diff --git 
a/skills/alphaear-predictor/scripts/utils/database_manager.py b/skills/alphaear-predictor/scripts/utils/database_manager.py deleted file mode 100644 index cfc362b..0000000 --- a/skills/alphaear-predictor/scripts/utils/database_manager.py +++ /dev/null @@ -1,581 +0,0 @@ -import sqlite3 -import json -from datetime import datetime, date -from pathlib import Path -from typing import List, Dict, Optional, Any, Union -import pandas as pd -from loguru import logger - -class DatabaseManager: - """ - AlphaEar 数据库管理器 - 负责存储热点数据、搜索缓存和股价数据 - 使用 SQLite 进行持久化存储 - """ - - def __init__(self, db_path: str = "data/signal_flux.db"): - self.db_path = Path(db_path) - self.db_path.parent.mkdir(parents=True, exist_ok=True) - self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False) - self.conn.row_factory = sqlite3.Row - self._init_db() - logger.info(f"💾 Database initialized at {self.db_path}") - - def _init_db(self): - """初始化表结构""" - cursor = self.conn.cursor() - - # 1. 每日热点新闻表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS daily_news ( - id TEXT PRIMARY KEY, - source TEXT, - rank INTEGER, - title TEXT, - url TEXT, - content TEXT, - publish_time TEXT, - crawl_time TEXT, - sentiment_score REAL, - analysis TEXT, - meta_data TEXT - ) - """) - - # 尝试添加 analysis 列(如果表已存在但没有该列) - try: - cursor.execute("ALTER TABLE daily_news ADD COLUMN analysis TEXT") - except: - pass # 列已存在 - - - # 2. 搜索缓存表 (原有 JSON 缓存) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS search_cache ( - query_hash TEXT PRIMARY KEY, - query TEXT, - engine TEXT, - results TEXT, - timestamp TEXT - ) - """) - - # 2.5 搜索详情表 (展开的搜索结果) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS search_detail ( - id TEXT, - query_hash TEXT, - rank INTEGER, - title TEXT, - url TEXT, - content TEXT, - publish_time TEXT, - crawl_time TEXT, - sentiment_score REAL, - source TEXT, - meta_data TEXT, - PRIMARY KEY (query_hash, id) - ) - """) - - # 3. 股价数据表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS stock_prices ( - ticker TEXT, - date TEXT, - open REAL, - close REAL, - high REAL, - low REAL, - volume REAL, - change_pct REAL, - PRIMARY KEY (ticker, date) - ) - """) - - # 4. 股票列表表 (用于检索) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS stock_list ( - code TEXT PRIMARY KEY, - name TEXT - ) - """) - - # 5. 投资信号表 (ISQ Framework) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS signals ( - signal_id TEXT PRIMARY KEY, - title TEXT, - summary TEXT, - transmission_chain TEXT, - sentiment_score REAL, - confidence REAL, - intensity INTEGER, - expected_horizon TEXT, - price_in_status TEXT, - impact_tickers TEXT, - industry_tags TEXT, - sources TEXT, - user_id TEXT, - created_at TEXT - ) - """) - - - - # 6. 
创建索引以优化查询性能 - cursor.execute("CREATE INDEX IF NOT EXISTS idx_news_crawl_time ON daily_news(crawl_time)") - cursor.execute("CREATE INDEX IF NOT EXISTS idx_news_source ON daily_news(source)") - cursor.execute("CREATE INDEX IF NOT EXISTS idx_search_cache_timestamp ON search_cache(timestamp)") - cursor.execute("CREATE INDEX IF NOT EXISTS idx_stock_prices_ticker_date ON stock_prices(ticker, date)") - # 尝试添加 user_id 列到 signals 表 - try: - cursor.execute("ALTER TABLE signals ADD COLUMN user_id TEXT") - except: - pass - - cursor.execute("CREATE INDEX IF NOT EXISTS idx_signals_user_id ON signals(user_id)") - - self.conn.commit() - - # - # self.conn.commit() - - - # --- 新闻数据操作 --- - - def save_daily_news(self, news_list: List[Dict]) -> int: - """保存热点新闻,包含发布时间与抓取时间""" - cursor = self.conn.cursor() - count = 0 - crawl_time = datetime.now().isoformat() - - for news in news_list: - try: - # 兼容不同来源的 ID 生成逻辑 - news_id = news.get('id') or f"{news.get('source')}_{news.get('rank')}_{crawl_time[:10]}" - cursor.execute(""" - INSERT OR REPLACE INTO daily_news - (id, source, rank, title, url, content, publish_time, crawl_time, sentiment_score, meta_data) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - news_id, - news.get('source'), - news.get('rank'), - news.get('title'), - news.get('url'), - news.get('content', ''), - news.get('publish_time'), # 新增支持发布时间 - crawl_time, - news.get('sentiment_score'), - json.dumps(news.get('meta_data', {})) - )) - count += 1 - except sqlite3.Error as e: - logger.error(f"Database error saving news item {news.get('title')}: {e}") - except Exception as e: - logger.error(f"Unexpected error saving news item {news.get('title')}: {e}") - - self.conn.commit() - return count - - def get_daily_news(self, source: Optional[str] = None, limit: int = 100, days: int = 1) -> List[Dict]: - """获取最近 N 天的热点新闻""" - cursor = self.conn.cursor() - # 使用 crawl_time 过滤,保证结果的新鲜度 - time_threshold = (datetime.now().timestamp() - days * 86400) - time_threshold_str = datetime.fromtimestamp(time_threshold).isoformat() - - query = "SELECT * FROM daily_news WHERE crawl_time >= ?" - params = [time_threshold_str] - - if source: - query += " AND source = ?" - params.append(source) - - query += " ORDER BY crawl_time DESC, rank LIMIT ?" - params.append(limit) - - cursor.execute(query, params) - return [dict(row) for row in cursor.fetchall()] - - def lookup_reference_by_url(self, url: str) -> Optional[Dict[str, Any]]: - """Best-effort lookup of a source item by URL. - - This is used to render a stable bibliography from DB-backed metadata. - It searches both `daily_news` and `search_detail`. - """ - url = (url or "").strip() - if not url: - return None - - cursor = self.conn.cursor() - - try: - cursor.execute( - """ - SELECT title, source, publish_time, crawl_time, url - FROM daily_news - WHERE url = ? - ORDER BY crawl_time DESC - LIMIT 1 - """, - (url,), - ) - row = cursor.fetchone() - if row: - return dict(row) - except Exception: - pass - - try: - cursor.execute( - """ - SELECT title, source, publish_time, crawl_time, url - FROM search_detail - WHERE url = ? 
- ORDER BY crawl_time DESC - LIMIT 1 - """, - (url,), - ) - row = cursor.fetchone() - if row: - return dict(row) - except Exception: - pass - - return None - - def delete_news(self, news_id: str) -> bool: - """删除特定新闻""" - cursor = self.conn.cursor() - cursor.execute("DELETE FROM daily_news WHERE id = ?", (news_id,)) - self.conn.commit() - return cursor.rowcount > 0 - - def update_news_content(self, news_id: str, content: str = None, analysis: str = None) -> bool: - """更新新闻的内容或分析结果""" - cursor = self.conn.cursor() - updates = [] - params = [] - - if content is not None: - updates.append("content = ?") - params.append(content) - if analysis is not None: - updates.append("analysis = ?") - params.append(analysis) - - if not updates: - return False - - params.append(news_id) - query = f"UPDATE daily_news SET {', '.join(updates)} WHERE id = ?" - cursor.execute(query, params) - self.conn.commit() - return cursor.rowcount > 0 - - # --- 搜索缓存辅助 --- - - def get_search_cache(self, query_hash: str, ttl_seconds: Optional[int] = None) -> Optional[Dict]: - """获取搜索缓存 (优先查 search_detail)""" - cursor = self.conn.cursor() - - # 1. 尝试从 search_detail 获取展开的结构化数据 - cursor.execute(""" - SELECT * FROM search_detail - WHERE query_hash = ? - ORDER BY rank - """, (query_hash,)) - details = [dict(row) for row in cursor.fetchall()] - - if details: - # 检查 TTL (取第一条的时间) - first_time = datetime.fromisoformat(details[0]['crawl_time']) - if ttl_seconds and (datetime.now() - first_time).total_seconds() > ttl_seconds: - logger.info(f"⌛ Detailed cache expired for hash {query_hash}") - pass # Expired, fall through or return None? If Detail expired, Cache likely expired too. - # But let's check basic cache just in case metadata differs? - # Actually if details exist, we prefer them. If expired, we return None. - return None - - logger.info(f"✅ Hit detailed search cache for {query_hash} ({len(details)} items)") - # Reconstruct the expected 'results' list format for SearchTools - # SearchTools expects a list of dicts. - # We return a dict wrapper to match get_search_cache signature returning Dict usually containing 'results' string. - # But SearchTools logic: - # cache = db.get_search_cache(...) - # cached_data = json.loads(cache['results']) - - # To minimize SearchTools changes, we can return a dict mimicking the old structure - # OR Change SearchTools to handle list return. - # Let's return a special dict that SearchTools can recognize or just format it as before. - return {"results": json.dumps(details), "timestamp": details[0]['crawl_time']} - - # 2. Fallback to old table - cursor.execute("SELECT * FROM search_cache WHERE query_hash = ?", (query_hash,)) - row = cursor.fetchone() - - if not row: - return None - - row_dict = dict(row) - if ttl_seconds: - cache_time = datetime.fromisoformat(row_dict['timestamp']) - if (datetime.now() - cache_time).total_seconds() > ttl_seconds: - logger.info(f"⌛ Cache expired for hash {query_hash}") - return None - - return row_dict - - def save_search_cache(self, query_hash: str, query: str, engine: str, results: Union[str, List[Dict]]): - """保存搜索结果 (同时保存到 search_cache 和 search_detail)""" - cursor = self.conn.cursor() - current_time = datetime.now().isoformat() - - results_str = results if isinstance(results, str) else json.dumps(results) - - # 1. Save summary to search_cache - cursor.execute(""" - INSERT OR REPLACE INTO search_cache (query_hash, query, engine, results, timestamp) - VALUES (?, ?, ?, ?, ?) - """, (query_hash, query, engine, results_str, current_time)) - - # 2. 
Save details to search_detail if results is a list - if isinstance(results, list): - for item in results: - try: - item_id = item.get('id') or f"{hash(item.get('url', ''))}" - cursor.execute(""" - INSERT OR REPLACE INTO search_detail - (id, query_hash, rank, title, url, content, publish_time, crawl_time, sentiment_score, source, meta_data) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - str(item_id), - query_hash, - item.get('rank', 0), - item.get('title'), - item.get('url'), - item.get('content', ''), - item.get('publish_time'), - item.get('crawl_time') or current_time, - item.get('sentiment_score'), - item.get('source'), - json.dumps(item.get('meta_data', {})) - )) - except sqlite3.Error as e: - logger.error(f"Database error saving search detail {item.get('title')}: {e}") - except Exception as e: - logger.error(f"Unexpected error saving search detail {item.get('title')}: {e}") - - self.conn.commit() - - def find_similar_queries(self, query: str, limit: int = 5) -> List[Dict]: - """模糊搜索相似的已缓存查询""" - cursor = self.conn.cursor() - - # Simple fuzzy match: query in cached OR cached in query - q_wild = f"%{query}%" - cursor.execute(""" - SELECT query, query_hash, timestamp, results - FROM search_cache - WHERE query LIKE ? OR ? LIKE ('%' || query || '%') - ORDER BY timestamp DESC - LIMIT ? - """, (q_wild, query, limit)) - - return [dict(row) for row in cursor.fetchall()] - - def search_local_news(self, query: str, limit: int = 5) -> List[Dict]: - """从本地 daily_news 搜索相关新闻""" - cursor = self.conn.cursor() - q_wild = f"%{query}%" - # Search title and content - cursor.execute(""" - SELECT * FROM daily_news - WHERE title LIKE ? OR content LIKE ? - ORDER BY crawl_time DESC - LIMIT ? - """, (q_wild, q_wild, limit)) - return [dict(row) for row in cursor.fetchall()] - - # --- 股票数据操作 --- - - def save_stock_list(self, df: pd.DataFrame): - """保存股票列表到 stock_list 表""" - cursor = self.conn.cursor() - try: - # 清空旧表 - cursor.execute("DELETE FROM stock_list") - - # 批量插入 - data = df[['code', 'name']].to_dict('records') - cursor.executemany( - "INSERT INTO stock_list (code, name) VALUES (:code, :name)", - data - ) - self.conn.commit() - except sqlite3.Error as e: - logger.error(f"Database error saving stock list: {e}") - except Exception as e: - logger.error(f"Unexpected error saving stock list: {e}") - - def search_stock(self, query: str, limit: int = 5) -> List[Dict]: - """模糊搜索股票代码或名称""" - cursor = self.conn.cursor() - wild = f"%{query}%" - cursor.execute(""" - SELECT code, name FROM stock_list - WHERE code LIKE ? OR name LIKE ? - LIMIT ? - """, (wild, wild, limit)) - return [dict(row) for row in cursor.fetchall()] - - def get_stock_by_code(self, code: str) -> Optional[Dict[str, str]]: - """精确按代码获取股票信息。 - - Args: - code: 股票代码(A股6位 / 港股5位),必须为纯数字字符串。 - - Returns: - dict: {"code": str, "name": str} 或 None。 - """ - if not code: - return None - clean = "".join([c for c in str(code).strip() if c.isdigit()]) - if not clean: - return None - - cursor = self.conn.cursor() - cursor.execute("SELECT code, name FROM stock_list WHERE code = ? 
LIMIT 1", (clean,)) - row = cursor.fetchone() - return dict(row) if row else None - - def save_stock_prices(self, ticker: str, df: pd.DataFrame): - """保存股价历史数据""" - if df.empty: - return - - cursor = self.conn.cursor() - - # 确保 DataFrame 有必要的列 - required_cols = ['date', 'open', 'close', 'high', 'low', 'volume', 'change_pct'] - for col in required_cols: - if col not in df.columns: - logger.warning(f"Missing column {col} in stock data for {ticker}") - return - - try: - for _, row in df.iterrows(): - cursor.execute(""" - INSERT OR REPLACE INTO stock_prices - (ticker, date, open, close, high, low, volume, change_pct) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, ( - ticker, - row['date'], - row['open'], - row['close'], - row['high'], - row['low'], - row['volume'], - row['change_pct'] - )) - self.conn.commit() - except sqlite3.Error as e: - logger.error(f"Database error saving stock prices for {ticker}: {e}") - except Exception as e: - logger.error(f"Unexpected error saving stock prices for {ticker}: {e}") - - def get_stock_prices(self, ticker: str, start_date: str, end_date: str) -> pd.DataFrame: - """获取指定日期范围的股价数据""" - cursor = self.conn.cursor() - - cursor.execute(""" - SELECT * FROM stock_prices - WHERE ticker = ? AND date >= ? AND date <= ? - ORDER BY date - """, (ticker, start_date, end_date)) - - rows = cursor.fetchall() - if not rows: - return pd.DataFrame() - - columns = ['ticker', 'date', 'open', 'close', 'high', 'low', 'volume', 'change_pct'] - return pd.DataFrame([dict(row) for row in rows], columns=columns) - - def execute_query(self, query: str, params: tuple = ()) -> List[Any]: - """执行自定义 SQL 查询""" - try: - cursor = self.conn.cursor() - cursor.execute(query, params) - if query.strip().upper().startswith("SELECT"): - return cursor.fetchall() - else: - self.conn.commit() - return [] - except sqlite3.Error as e: - logger.error(f"SQL execution failed (Database error): {e}") - return [] - except Exception as e: - logger.error(f"SQL execution failed (Unexpected error): {e}") - return [] - - # --- 投资信号操作 (ISQ Framework) --- - - def save_signal(self, signal: Dict[str, Any]): - """保存投资信号""" - cursor = self.conn.cursor() - created_at = datetime.now().isoformat() - - cursor.execute(""" - INSERT OR REPLACE INTO signals - (signal_id, title, summary, transmission_chain, sentiment_score, - confidence, intensity, expected_horizon, price_in_status, - impact_tickers, industry_tags, sources, user_id, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - signal.get('signal_id'), - signal.get('title'), - signal.get('summary'), - json.dumps(signal.get('transmission_chain', [])), - signal.get('sentiment_score', 0.0), - signal.get('confidence', 0.0), - signal.get('intensity', 1), - signal.get('expected_horizon', 'T+0'), - signal.get('price_in_status', '未知'), - json.dumps(signal.get('impact_tickers', [])), - json.dumps(signal.get('industry_tags', [])), - json.dumps(signal.get('sources', [])), - signal.get('user_id'), - created_at - )) - self.conn.commit() - - def get_recent_signals(self, limit: int = 20, user_id: Optional[str] = None) -> List[Dict]: - """获取最近的投资信号""" - cursor = self.conn.cursor() - if user_id: - cursor.execute("SELECT * FROM signals WHERE user_id = ? 
ORDER BY created_at DESC LIMIT ?", (user_id, limit)) - else: - cursor.execute("SELECT * FROM signals ORDER BY created_at DESC LIMIT ?", (limit,)) - rows = cursor.fetchall() - - signals = [] - for row in rows: - d = dict(row) - # 解析 JSON 字段 - for field in ['transmission_chain', 'impact_tickers', 'industry_tags', 'sources']: - if d.get(field): - try: - d[field] = json.loads(d[field]) - except: - pass - signals.append(d) - return signals - - def close(self): - if self.conn: - self.conn.close() - logger.info("Database connection closed.") - diff --git a/skills/alphaear-predictor/scripts/utils/json_utils.py b/skills/alphaear-predictor/scripts/utils/json_utils.py deleted file mode 100644 index c29aab2..0000000 --- a/skills/alphaear-predictor/scripts/utils/json_utils.py +++ /dev/null @@ -1,180 +0,0 @@ -import ast -import json -import re -from typing import Optional, Any -from loguru import logger - -def _strip_comments(text: str) -> str: - """ - Safely remove C-style comments (// and /* */) from JSON-like text, - preserving strings (including URLs like http://). - """ - result = [] - i = 0 - n = len(text) - in_string = False - escape = False - - while i < n: - char = text[i] - - if in_string: - if char == '\\': - escape = not escape - elif char == '"' and not escape: - in_string = False - else: - escape = False - result.append(char) - i += 1 - continue - - # Not in string - if char == '"': - in_string = True - result.append(char) - i += 1 - continue - - # Check for // comment - if i + 1 < n and text[i:i+2] == '//': - i += 2 - while i < n and text[i] != '\n': - i += 1 - continue - - # Check for /* comment - if i + 1 < n and text[i:i+2] == '/*': - i += 2 - while i + 1 < n and text[i:i+2] != '*/': - i += 1 - i += 2 - continue - - result.append(char) - i += 1 - - return ''.join(result) - -def extract_json(text: str) -> Optional[Any]: - """ - 更加鲁棒的 JSON 提取工具。 - 处理: - 1. Markdown 代码块 (```json ... ```) - 2. 首尾多余字符 - 3. 同一个文本中多个 JSON 对象 (仅提取第一个) - 4. 简单的 JSON 修复 (末尾逗号等) - 5. C 风格注释 (// 和 /* */) - """ - if not text: - return None - - # 1. 清理明显的 Markdown 包装 - text = text.strip() - - # 先尝试精确匹配 ```json ... ``` 或 ```...``` - md_match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', text, re.DOTALL) - if md_match: - text = md_match.group(1).strip() - elif text.startswith("```"): - # 回退:如果开头有 ``` 但没完整匹配 - text = re.sub(r'^```[a-z]*\n?', '', text) - text = re.sub(r'\n?```\s*$', '', text) - - # 2. 寻找第一个 JSON 起始符 { 或 [ - start_brace = text.find('{') - start_bracket = text.find('[') - - if start_brace == -1 and start_bracket == -1: - return None - - start_idx = start_brace if (start_bracket == -1 or (start_brace != -1 and start_brace < start_bracket)) else start_bracket - - # 2.5 预处理:修复一些极其常见的 LLM 错误 - potential_json = text[start_idx:].strip() - - # remove comments safely - potential_json = _strip_comments(potential_json) - - # b. 修复缺失开头引号的键: nodes": [ -> "nodes": [ - # 匹配模式: (空白或换行) 单词 紧跟引号和冒号 - potential_json = re.sub(r'([\{\,]\s*)([a-zA-Z_]\w*)\"\s*:', r'\1"\2":', potential_json) - - # c. 修复缺失末尾引号的键: "nodes: [ -> "nodes": [ - potential_json = re.sub(r'([\{\,]\s*)\"([a-zA-Z_]\w*)\s*:', r'\1"\2":', potential_json) - - # d. 修复完全缺失引号的键: nodes: [ -> "nodes": [ - # 注意避免匹配到像 http:// 这种内容,所以限定在 { 或 , 之后 - potential_json = re.sub(r'([\{\,]\s*)([a-zA-Z_]\w*)\s*:', r'\1"\2":', potential_json) - - # 3. 
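extract_json above is built to survive typical LLM formatting noise: it strips Markdown code fences, removes C-style comments, and repairs unquoted keys and trailing commas before parsing. A standalone sketch of just the fence-stripping step, using the same regular expression (the sample reply is made up):

import json
import re

reply = '```json\n{"ticker": "002111", "signal": "bullish"}\n```'
match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', reply, re.DOTALL)
payload = match.group(1).strip() if match else reply.strip()
print(json.loads(payload))   # {'ticker': '002111', 'signal': 'bullish'}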
使用 raw_decode 尝试解析 - decoder = json.JSONDecoder() - - # 首先尝试直接解析(不做任何预处理) - try: - obj = json.loads(potential_json) - return obj - except json.JSONDecodeError: - pass - - # 简单预处理:移除对象/列表末位多余逗号 - processed_json = re.sub(r',\s*([\]}])', r'\1', potential_json) - - try: - obj, end_pos = decoder.raw_decode(processed_json) - return obj - except json.JSONDecodeError: - pass - - # e. 修复未终止的字符串字面量问题:移除值中的实际换行符 - # LLM 可能在字符串值中生成包含真实 newline 的内容,导致 JSON 非法 - def fix_multiline_strings(s): - # 简单策略:将字符串值内的换行替换为空格 - lines = s.split('\n') - result = [] - in_string = False - for line in lines: - # 计算未转义的引号数 - quote_count = line.count('"') - line.count('\\"') - if in_string: - result[-1] += ' ' + line.strip() - else: - result.append(line) - - if quote_count % 2 == 1: - in_string = not in_string - return '\n'.join(result) - - fixed_json = fix_multiline_strings(processed_json) - - try: - obj, end_pos = decoder.raw_decode(fixed_json) - return obj - except json.JSONDecodeError: - try: - # 4. 尝试处理单引号问题 (JSON 规范要求双引号,但 LLM 常输出单引号) - # 这是一个简单的替换技巧,仅针对像 {'key': 'value'} 这样的结构 - # 注意:这可能会破坏包含单引号的字符串值,所以作为较后的回退 - fix_quotes = re.sub(r"'(.*?)':", r'"\1":', processed_json) # 修复键 - fix_quotes = re.sub(r":\s*'(.*?)'", r': "\1"', fix_quotes) # 修复简单值 - obj, end_pos = decoder.raw_decode(fix_quotes) - return obj - except (json.JSONDecodeError, TypeError): - try: - # 5. 使用 ast.literal_eval 作为终极回退 (处理 Python 字典格式) - # 提取第一个匹配的括号对内容 - # 寻找匹配的 { } - stack = [] - for i, char in enumerate(potential_json): - if char == '{': stack.append('{') - elif char == '}': - if stack: stack.pop() - if not stack: - content = potential_json[:i+1] - return ast.literal_eval(content) - except (ValueError, SyntaxError, MemoryError) as e: - logger.warning(f"All JSON extraction attempts failed: {e}") - except Exception as e: - logger.error(f"Unexpected error during JSON extraction: {e}") - - return None diff --git a/skills/alphaear-predictor/scripts/utils/llm/capability.py b/skills/alphaear-predictor/scripts/utils/llm/capability.py deleted file mode 100644 index d07ca4f..0000000 --- a/skills/alphaear-predictor/scripts/utils/llm/capability.py +++ /dev/null @@ -1,85 +0,0 @@ -import os -from typing import Optional, List, Dict, Any -from agno.agent import Agent -from agno.models.base import Model -from loguru import logger -from ..llm.factory import get_model - - -def test_tool_call_support(model: Model) -> bool: - """ - 测试模型是否支持原生的 Tool Call (Function Calling)。 - 通过尝试执行一个简单的加法工具来验证。 - """ - - def get_current_weather(location: str): - """获取指定地点的天气""" - return f"{location} 的天气是晴天,25度。" - - test_agent = Agent( - model=model, - tools=[get_current_weather], - instructions="请调用工具查询北京的天气,并直接返回工具的输出结果。", - ) - - try: - # 运行一个简单的任务,观察是否触发了 tool_call - response = test_agent.run("北京天气怎么样?") - - # 检查 response 中是否包含 tool_calls - # Agno 的 RunResponse 对象通常包含 messages,我们可以检查最后几条消息 - has_tool_call = False - for msg in response.messages: - if hasattr(msg, "tool_calls") and msg.tool_calls: - has_tool_call = True - break - - if has_tool_call: - logger.info(f"✅ Model {model.id} supports native tool calling.") - return True - else: - # 如果没有 tool_calls 但返回了正确答案,可能是模型通过纯文本模拟了工具调用(ReAct) - # 或者根本没用工具。对于原生支持的判断,我们坚持要求有 tool_calls 结构。 - logger.warning( - f"⚠️ Model {model.id} did NOT use native tool calling structure." 
- ) - return False - - except Exception as e: - logger.error(f"❌ Error testing tool call for {model.id}: {e}") - return False - - -class ModelCapabilityRegistry: - """ - 模型能力注册表,用于缓存和管理不同模型的能力测试结果。 - """ - - _cache = {} - - @classmethod - def get_capabilities( - cls, provider: str, model_id: str, **kwargs - ) -> Dict[str, bool]: - key = f"{provider}:{model_id}" - if key not in cls._cache: - logger.info(f"🔍 Testing capabilities for {key}...") - model = get_model(provider, model_id, **kwargs) - supports_tool_call = test_tool_call_support(model) - cls._cache[key] = {"supports_tool_call": supports_tool_call} - return cls._cache[key] - - -if __name__ == "__main__": - import os - from skills._env_loader import load_unified_env - - load_unified_env() - - # 测试当前配置的模型 - p = os.getenv("LLM_PROVIDER", "minimax") - m = os.getenv("LLM_MODEL", "Qwen") - - print(f"Testing {p}/{m}...") - res = ModelCapabilityRegistry.get_capabilities(p, m) - print(f"Result: {res}") diff --git a/skills/alphaear-predictor/scripts/utils/llm/factory.py b/skills/alphaear-predictor/scripts/utils/llm/factory.py deleted file mode 100644 index 449e5b8..0000000 --- a/skills/alphaear-predictor/scripts/utils/llm/factory.py +++ /dev/null @@ -1,122 +0,0 @@ -import os -from agno.models.openai import OpenAIChat -from agno.models.ollama import Ollama -from agno.models.dashscope import DashScope -from agno.models.deepseek import DeepSeek -from agno.models.openrouter import OpenRouter - - -def get_model(model_provider: str, model_id: str, **kwargs): - """ - Factory to get the appropriate LLM model. - - Args: - model_provider: "openai", "ollama", "deepseek" - model_id: The specific model ID (e.g., "gpt-4o", "llama3", "deepseek-chat") - **kwargs: Additional arguments for the model constructor - """ - if model_provider == "openai": - return OpenAIChat(id=model_id, **kwargs) - - elif model_provider == "ollama": - return Ollama(id=model_id, **kwargs) - - elif model_provider == "minimax": - api_key = os.getenv("MINIMAX_API_KEY") - if not api_key: - print("Warning: MINIMAX_API_KEY not set.") - - return OpenAIChat( - id=model_id, - base_url=os.getenv("MINIMAX_API_BASE", "https://api.minimax.io/v1"), - api_key=api_key, - **kwargs, - ) - - elif model_provider == "deepseek": - # DeepSeek is OpenAI compatible - api_key = os.getenv("DEEPSEEK_API_KEY") - if not api_key: - print("Warning: DEEPSEEK_API_KEY not set.") - - return DeepSeek(id=model_id, api_key=api_key, **kwargs) - elif model_provider == "dashscope": - api_key = os.getenv("DASHSCOPE_API_KEY") - if not api_key: - print("Warning: DASHSCOPE_API_KEY not set.") - - return DashScope( - id=model_id, - base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", - api_key=api_key, - **kwargs, - ) - elif model_provider == "openrouter": - api_key = os.getenv("OPENROUTER_API_KEY") - if not api_key: - print("Warning: OPENROUTER_API_KEY not set.") - - return OpenRouter(id=model_id, api_key=api_key, **kwargs) - - elif model_provider == "zai": - api_key = os.getenv("ZAI_KEY_API") - if not api_key: - print("Warning: ZAI_KEY_API not set.") - - # role_map to ensure compatibility. 
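get_model above is a thin factory: given a provider name it returns the matching Agno model, routing OpenAI-compatible vendors such as MiniMax, ZAI and UST through OpenAIChat with the appropriate base_url and role_map. A usage sketch driven by the same environment variables the skills read (the import path is an assumption based on the deleted file layout):

import os

# Assumed import path; the deleted module lived at
# skills/alphaear-predictor/scripts/utils/llm/factory.py
from utils.llm.factory import get_model

provider = os.getenv("LLM_PROVIDER", "minimax")
model_id = os.environ["LLM_MODEL"]      # e.g. the chat model configured in .env
model = get_model(provider, model_id)   # extra kwargs pass through to the model constructor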
- default_role_map = { - "system": "system", - "user": "user", - "assistant": "assistant", - "tool": "tool", - "model": "assistant", - } - - # Allow callers to override role_map via kwargs, otherwise use default - role_map = kwargs.pop("role_map", default_role_map) - - return OpenAIChat( - id=model_id, - base_url="https://api.z.ai/api/paas/v4", - api_key=api_key, - timeout=60, - role_map=role_map, - extra_body={ - "enable_thinking": False - }, # TODO: one more setting for thinking - **kwargs, - ) - - elif model_provider == "ust": - api_key = os.getenv("UST_KEY_API") - if not api_key: - print("Warning: UST_KEY_API not set.") - - # Some UST-compatible endpoints expect the standard OpenAI role names - # (e.g. "system", "user", "assistant") rather than Agno's default - # mapping which maps "system" -> "developer". Provide an explicit - # role_map to ensure compatibility. - default_role_map = { - "system": "system", - "user": "user", - "assistant": "assistant", - "tool": "tool", - "model": "assistant", - } - - # Allow callers to override role_map via kwargs, otherwise use default - role_map = kwargs.pop("role_map", default_role_map) - - return OpenAIChat( - id=model_id, - api_key=api_key, - base_url=os.getenv("UST_URL"), - role_map=role_map, - extra_body={ - "enable_thinking": False - }, # TODO: one more setting for thinking - **kwargs, - ) - - else: - raise ValueError(f"Unknown model provider: {model_provider}") diff --git a/skills/alphaear-predictor/scripts/utils/llm/router.py b/skills/alphaear-predictor/scripts/utils/llm/router.py deleted file mode 100644 index 8c69958..0000000 --- a/skills/alphaear-predictor/scripts/utils/llm/router.py +++ /dev/null @@ -1,80 +0,0 @@ -import os -from typing import Optional, List, Dict, Any, Union -from agno.models.base import Model -from loguru import logger -from ..llm.factory import get_model -from ..llm.capability import ModelCapabilityRegistry -from skills._env_loader import load_unified_env - -load_unified_env() - - -class ModelRouter: - """ - 模型路由管理器 - - 功能: - 1. 管理“推理/写作模型” (Reasoning Model) 和“工具调用模型” (Tool Model)。 - 2. 
根据任务需求自动选择合适的模型。 - """ - - def __init__(self): - # 默认从环境变量读取 - self.reasoning_provider = os.getenv( - "REASONING_MODEL_PROVIDER", os.getenv("LLM_PROVIDER", "openai") - ) - self.reasoning_id = os.getenv( - "REASONING_MODEL_ID", os.getenv("LLM_MODEL", "gpt-4o") - ) - self.reasoning_host = os.getenv("REASONING_MODEL_HOST", os.getenv("LLM_HOST")) - - self.tool_provider = os.getenv("TOOL_MODEL_PROVIDER", self.reasoning_provider) - self.tool_id = os.getenv("TOOL_MODEL_ID", self.reasoning_id) - self.tool_host = os.getenv("TOOL_MODEL_HOST", self.reasoning_host) - - self._reasoning_model = None - self._tool_model = None - - logger.info( - f"🤖 ModelRouter initialized: Reasoning={self.reasoning_id} ({self.reasoning_host or 'default'}), Tool={self.tool_id} ({self.tool_host or 'default'})" - ) - - def get_reasoning_model(self, **kwargs) -> Model: - if not self._reasoning_model: - # 优先使用路由配置的 host - if self.reasoning_host and "host" not in kwargs: - kwargs["host"] = self.reasoning_host - self._reasoning_model = get_model( - self.reasoning_provider, self.reasoning_id, **kwargs - ) - return self._reasoning_model - - def get_tool_model(self, **kwargs) -> Model: - if not self._tool_model: - # 优先使用路由配置的 host - if self.tool_host and "host" not in kwargs: - kwargs["host"] = self.tool_host - - # 检查 tool_model 是否真的支持 tool call - caps = ModelCapabilityRegistry.get_capabilities( - self.tool_provider, self.tool_id, **kwargs - ) - if not caps["supports_tool_call"]: - logger.warning( - f"⚠️ Configured tool model {self.tool_id} might not support native tool calls! Consider using ReAct mode or a different model." - ) - - self._tool_model = get_model(self.tool_provider, self.tool_id, **kwargs) - return self._tool_model - - def get_model_for_agent(self, has_tools: bool = False, **kwargs) -> Model: - """ - 根据 Agent 是否包含工具来返回合适的模型。 - """ - if has_tools: - return self.get_tool_model(**kwargs) - return self.get_reasoning_model(**kwargs) - - -# 全局单例 -router = ModelRouter() diff --git a/skills/alphaear-predictor/scripts/utils/logging_setup.py b/skills/alphaear-predictor/scripts/utils/logging_setup.py deleted file mode 100644 index 9a2ca62..0000000 --- a/skills/alphaear-predictor/scripts/utils/logging_setup.py +++ /dev/null @@ -1,45 +0,0 @@ -import os -import sys -from datetime import datetime -from typing import Optional - -from loguru import logger - - -def setup_file_logging( - run_id: str, - log_dir: str = "logs", - level: str = "INFO", - retention: str = "10 days", - rotation: str = "20 MB", -) -> str: - """Configure Loguru to log to stderr + a per-run file. - - Returns the log file path. - """ - os.makedirs(log_dir, exist_ok=True) - - # Remove default handler to avoid duplicate logs. 
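setup_file_logging above wires Loguru to both stderr and a per-run rotating log file, stamped with the run id produced by the make_run_id helper defined just below it. A usage sketch (the import path is an assumption based on the deleted layout):

from loguru import logger

# Assumed import path; the deleted module lived at
# skills/alphaear-predictor/scripts/utils/logging_setup.py
from utils.logging_setup import make_run_id, setup_file_logging

run_id = make_run_id("scan")            # e.g. "scan_20250101_120000"
log_path = setup_file_logging(run_id, log_dir="logs", level="DEBUG")
logger.info(f"Run {run_id} writing logs to {log_path}")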
- logger.remove() - - # Console - logger.add(sys.stderr, level=level, backtrace=False, diagnose=False) - - # File (safe for multi-thread via enqueue) - log_path = os.path.join(log_dir, f"signalflux_{run_id}.log") - logger.add( - log_path, - level=level, - rotation=rotation, - retention=retention, - enqueue=True, - backtrace=True, - diagnose=False, - encoding="utf-8", - ) - return log_path - - -def make_run_id(prefix: Optional[str] = None) -> str: - ts = datetime.now().strftime("%Y%m%d_%H%M%S") - return f"{prefix}_{ts}" if prefix else ts diff --git a/skills/alphaear-predictor/scripts/utils/predictor/evaluation.py b/skills/alphaear-predictor/scripts/utils/predictor/evaluation.py deleted file mode 100644 index 26c5df7..0000000 --- a/skills/alphaear-predictor/scripts/utils/predictor/evaluation.py +++ /dev/null @@ -1,137 +0,0 @@ -import os -import sys -import torch -import pandas as pd -import numpy as np -import glob -from loguru import logger -from datetime import datetime, timedelta - -# Setup paths -KRONOS_DIR = os.path.dirname(os.path.abspath(__file__)) -SRC_DIR = os.path.dirname(os.path.dirname(KRONOS_DIR)) -if SRC_DIR not in sys.path: - sys.path.insert(0, SRC_DIR) - -from ..kronos.auto_synthesis_training import AutoSynthesisTrainer -from ..kronos.model import KronosPredictor -from ..visualizer import VisualizerTools -from ..schema.models import ForecastResult, KLinePoint - -class NewsModelEvaluator: - def __init__(self, model_path=None): - self.trainer = AutoSynthesisTrainer() - self.device = self.trainer.device - - if model_path is None: - # Try to find the latest model in exports/models - model_files = glob.glob(os.path.join(SRC_DIR, "exports/models/*.pt")) - if not model_files: - logger.warning("⚠️ No trained models found in exports/models/. Using base model (zero-init proj).") - else: - model_path = max(model_files, key=os.path.getctime) - - if model_path: - self.load_weights(model_path) - - def load_weights(self, path): - logger.info(f"🔄 Loading model weights from {path}...") - checkpoint = torch.load(path, map_location=self.device) - self.trainer.model.news_proj.load_state_dict(checkpoint['news_proj_state_dict']) - logger.success("✅ News projection layer loaded.") - - def evaluate_range(self, start_idx=100, end_idx=200, pred_len=5): - # 1. Fetch Tickers - res = self.trainer.db.execute_query("SELECT code FROM stock_list") - all_tickers = [row['code'] for row in res] - test_tickers = all_tickers[start_idx:end_idx] - - if not test_tickers: - logger.error(f"No tickers found in range {start_idx}-{end_idx}") - return - - logger.info(f"🚀 Evaluating News Model on stocks {start_idx} to {end_idx}...") - - # 2. Discover Shocks - shocks = self.trainer.discover_shocks(test_tickers, pred_len=pred_len) - - # 3. 
Associate News & Predict - self.trainer.model.eval() - predictor = KronosPredictor(self.trainer.model, self.trainer.tokenizer, device=self.device) - - save_dir = os.path.join(SRC_DIR, "exports/evaluation_results") - os.makedirs(save_dir, exist_ok=True) - - count = 0 - for shock in shocks: - summary = self.trainer.find_reason_and_verify(shock) - if not summary: - continue - - logger.info(f"📈 Testing shock: {shock['ticker']} on {shock['date']}") - - # Embedding news - news_emb = self.trainer.embedder.encode(summary) - - # Prediction - h = shock['history'] - t = shock['target'] - actuals = t['close'].values[:pred_len] - - x_ts = pd.to_datetime(h['date']) - future_dates = pd.date_range(start=x_ts.iloc[-1] + timedelta(days=1), periods=pred_len, freq='B') - y_ts = pd.Series(future_dates) - - # A. Base Prediction (No news) - p_base = predictor.predict(h, x_ts, y_ts, pred_len=pred_len, news_emb=None, verbose=False) - - # B. News-Aware Prediction - p_news = predictor.predict(h, x_ts, y_ts, pred_len=pred_len, news_emb=news_emb, verbose=False) - - # Calculate Improvement - b_preds = p_base['close'].values[:len(actuals)] - n_preds = p_news['close'].values[:len(actuals)] - b_mae = np.mean(np.abs(b_preds - actuals)) - n_mae = np.mean(np.abs(n_preds - actuals)) - improvement = (b_mae - n_mae) / (b_mae + 1e-6) * 100 - - # C. Visualize - try: - def to_kp_list(preds_df): - points = [] - for idx, row in preds_df.iterrows(): - points.append(KLinePoint( - date=str(idx)[:10], open=row['open'], high=row['high'], - low=row['low'], close=row['close'], volume=row.get('volume', 0) - )) - return points - - forecast_obj = ForecastResult( - ticker=shock['ticker'], - base_forecast=to_kp_list(p_base), - adjusted_forecast=to_kp_list(p_news), - rationale=summary - ) - - chart = VisualizerTools.generate_stock_chart( - df=h, ticker=shock['ticker'], - title=f"Test Eval: {shock['ticker']} ({shock['date']}) Imp: {improvement:.1f}%", - forecast=forecast_obj, - ground_truth=t[['date', 'open', 'high', 'low', 'close', 'volume']] - ) - - safe_date = shock['date'].replace("-", "") - filename = f"test_{shock['ticker']}_{safe_date}.html" - VisualizerTools.render_chart_to_file(chart, os.path.join(save_dir, filename)) - - logger.success(f"📊 Result for {shock['ticker']} saved. Base MAE: {b_mae:.4f}, News MAE: {n_mae:.4f}") - count += 1 - except Exception as e: - logger.error(f"Visualization failed: {e}") - - logger.info(f"🏁 Finished evaluation. {count} cases visualized in {save_dir}") - -if __name__ == "__main__": - # If you have a specific model, pass the path here. Otherwise it picks the latest. 
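The evaluator above scores each price shock by comparing the mean absolute error of the plain Kronos forecast against the news-conditioned one and reporting the relative improvement. A self-contained numeric illustration of that metric (all numbers are made up):

import numpy as np

actuals    = np.array([10.0, 10.2, 10.1])   # ground-truth closes
base_preds = np.array([10.4, 10.6, 10.5])   # forecast without the news embedding
news_preds = np.array([10.1, 10.3, 10.2])   # forecast conditioned on the news embedding

b_mae = np.mean(np.abs(base_preds - actuals))           # 0.4
n_mae = np.mean(np.abs(news_preds - actuals))           # 0.1
improvement = (b_mae - n_mae) / (b_mae + 1e-6) * 100    # roughly 75% lower error
print(f"Base MAE: {b_mae:.4f}, News MAE: {n_mae:.4f}, Imp: {improvement:.1f}%")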
- evaluator = NewsModelEvaluator() - evaluator.evaluate_range(start_idx=100, end_idx=200, pred_len=1) diff --git a/skills/alphaear-predictor/scripts/utils/predictor/kline_generate.py b/skills/alphaear-predictor/scripts/utils/predictor/kline_generate.py deleted file mode 100644 index 3224c21..0000000 --- a/skills/alphaear-predictor/scripts/utils/predictor/kline_generate.py +++ /dev/null @@ -1,196 +0,0 @@ -# Ref: https://github.com/shiyu-coder/Kronos - -from model import Kronos, KronosTokenizer, KronosPredictor -import pandas as pd -import sqlite3 -import torch -import matplotlib.pyplot as plt -import matplotlib.gridspec as gridspec -from pandas.tseries.offsets import BusinessDay -import numpy as np - -def get_device(): - device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" - print(f"Using device: {device}") - return device - -def load_predictor(): - tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base") - model = Kronos.from_pretrained("NeoQuasar/Kronos-base") - device = get_device() - tokenizer = tokenizer.to(device) - model = model.to(device) - return KronosPredictor(model, tokenizer, device=device, max_context=512) - -def load_data(ticker="002111", db_path="AlphaEar/data/signal_flux.db"): - with sqlite3.connect(db_path) as conn: - df = pd.read_sql_query(f"SELECT * FROM stock_prices WHERE ticker = '{ticker}'", conn) - df['date'] = pd.to_datetime(df['date']) - df = df.sort_values('date').reset_index(drop=True) - return df - -def plot_kline_matplotlib(ax, ax_vol, dates, df, label_suffix="", color_up='#ef4444', color_down='#22c55e', alpha=1.0, is_prediction=False): - """ - 绘制 K 线图和成交量 - """ - # X axis mapping to integers for consistent spacing - x = np.arange(len(dates)) - - # K-line data - opens = df['open'].values - closes = df['close'].values - highs = df['high'].values - lows = df['low'].values - volumes = df['volume'].values - - # Width of the candlestick - width = 0.6 - - for i in range(len(x)): - color = color_up if closes[i] >= opens[i] else color_down - linestyle = '--' if is_prediction else '-' - - # Wick - ax.vlines(x[i], lows[i], highs[i], color=color, linewidth=1, alpha=alpha, linestyle=linestyle) - - # Body - rect_bottom = min(opens[i], closes[i]) - rect_height = abs(opens[i] - closes[i]) - if rect_height == 0: rect_height = 0.001 # Visual hair - - ax.add_patch(plt.Rectangle((x[i] - width/2, rect_bottom), width, rect_height, - edgecolor=color, facecolor=color if not is_prediction else 'none', - alpha=alpha, linewidth=1, linestyle=linestyle)) - - # Volume - ax_vol.bar(x[i], volumes[i], color=color, alpha=alpha * 0.5, width=width) - -def render_comparison_chart(history_df, actual_df, pred_df, title): - """ - 渲染组合图:历史 K 线 + 真值 K 线 + 预测 K 线 - """ - # Combine all dates for X axis - all_dates = pd.concat([history_df['date'], actual_df['date'] if actual_df is not None else pred_df.index.to_series()]).unique() - all_dates = sorted(all_dates) - date_to_idx = {date: i for i, date in enumerate(all_dates)} - - fig = plt.figure(figsize=(14, 8), facecolor='white') - gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.1) - ax_main = fig.add_subplot(gs[0]) - ax_vol = fig.add_subplot(gs[1], sharex=ax_main) - - # 1. Plot History - hist_indices = [date_to_idx[d] for d in history_df['date']] - # We use a custom x for plotting to ensure continuity - plot_kline_matplotlib(ax_main, ax_vol, history_df['date'], history_df, alpha=0.8) - - offset = len(history_df) - - # 2. 
Plot Actual if exists - if actual_df is not None: - # Shift indices - actual_x = np.arange(len(actual_df)) + offset - # Plotting manually to handle offset - for i in range(len(actual_df)): - idx = actual_x[i] - row = actual_df.iloc[i] - color = '#ef4444' if row['close'] >= row['open'] else '#22c55e' - ax_main.vlines(idx, row['low'], row['high'], color=color, linewidth=1, alpha=0.9) - ax_main.add_patch(plt.Rectangle((idx - 0.3, min(row['open'], row['close'])), 0.6, abs(row['open']-row['close']), - edgecolor=color, facecolor=color, alpha=0.9)) - ax_vol.bar(idx, row['volume'], color=color, alpha=0.4) - - # 3. Plot Prediction - pred_x = np.arange(len(pred_df)) + offset - for i in range(len(pred_df)): - idx = pred_x[i] - row = pred_df.iloc[i] - color = '#ff8c00' # Orange for prediction to distinguish - ax_main.vlines(idx, row['low'], row['high'], color=color, linewidth=1.5, linestyle='--') - ax_main.add_patch(plt.Rectangle((idx - 0.3, min(row['open'], row['close'])), 0.6, abs(row['open']-row['close']), - edgecolor=color, facecolor='none', linewidth=1.5, linestyle='--')) - # Plot secondary prediction line for close - if i == 0: - # Connect to history - ax_main.plot([offset-1, idx], [history_df['close'].iloc[-1], row['close']], color=color, linestyle='--', alpha=0.6) - elif i > 0: - ax_main.plot([idx-1, idx], [pred_df['close'].iloc[i-1], row['close']], color=color, linestyle='--', alpha=0.6) - - # Styling - ax_main.set_title(title, fontsize=14, fontweight='bold') - ax_main.grid(True, linestyle=':', alpha=0.6) - ax_vol.grid(True, linestyle=':', alpha=0.6) - ax_vol.set_ylabel('Volume') - ax_main.set_ylabel('Price') - - # Set X ticks - step = max(1, len(all_dates) // 10) - ax_vol.set_xticks(np.arange(0, len(all_dates), step)) - ax_vol.set_xticklabels([all_dates[i].strftime('%Y-%m-%d') for i in range(0, len(all_dates), step)], rotation=45) - - plt.tight_layout() - plt.show() - plt.close() - -def run_backtest(df, predictor, lookback, pred_len, start_index=0): - total_len = len(df) - history_start = start_index - history_end = start_index + lookback - pred_start = history_end - - available_pred_len = total_len - pred_start - if available_pred_len <= 0: return - actual_pred_len = min(pred_len, available_pred_len) - pred_end = pred_start + actual_pred_len - - x_df = df.iloc[history_start : history_end].copy() - y_true_df = df.iloc[pred_start : pred_end].copy() - y_timestamp = y_true_df['date'] - - print(f"Backtesting: {x_df['date'].iloc[0].date()} to {y_timestamp.iloc[-1].date()}") - - pred_df = predictor.predict( - df=x_df[['open', 'high', 'low', 'close', 'volume']], - x_timestamp=x_df['date'], - y_timestamp=y_timestamp, - pred_len=actual_pred_len, - T=1.0, top_p=0.9, sample_count=1 - ) - - render_comparison_chart(x_df, y_true_df, pred_df, f"Backtest: {TICKER} K-Line Comparison") - -def run_forecast(df, predictor, lookback, pred_len): - if len(df) < lookback: return - x_df = df.iloc[-lookback:].copy() - last_date = x_df['date'].iloc[-1] - future_dates = pd.date_range(start=last_date + BusinessDay(1), periods=pred_len, freq='B') - future_dates = pd.Series(future_dates) - - print(f"Forecasting: Starting from {future_dates.iloc[0].date()}") - - pred_df = predictor.predict( - df=x_df[['open', 'high', 'low', 'close', 'volume']], - x_timestamp=x_df['date'], - y_timestamp=future_dates, - pred_len=pred_len, - T=1.0, top_p=0.9, sample_count=1 - ) - - render_comparison_chart(x_df, None, pred_df, f"Forecast: {TICKER} Future K-Line") - -if __name__ == "__main__": - LOOKBACK = 20 - PRED_LEN = 10 - TICKER = '002111' 
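run_forecast above builds the future index by stepping forward in business days from the last observed bar, so weekends never appear in the prediction window. A minimal sketch of that date handling:

import pandas as pd
from pandas.tseries.offsets import BusinessDay

last_date = pd.Timestamp("2025-01-10")   # a Friday
future_dates = pd.date_range(start=last_date + BusinessDay(1), periods=5, freq="B")
print(list(future_dates.strftime("%Y-%m-%d")))
# ['2025-01-13', '2025-01-14', '2025-01-15', '2025-01-16', '2025-01-17']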
- - pred_model = load_predictor() - stock_data = load_data(TICKER) - - total_rows = len(stock_data) - backtest_start = max(0, total_rows - LOOKBACK - PRED_LEN - 10) # Leave some space to see trend - - print("\n--- Running Backtest ---") - run_backtest(stock_data, pred_model, LOOKBACK, PRED_LEN, start_index=backtest_start) - - print("\n--- Running Forecast ---") - run_forecast(stock_data, pred_model, LOOKBACK, PRED_LEN) \ No newline at end of file diff --git a/skills/alphaear-predictor/scripts/utils/predictor/model/__init__.py b/skills/alphaear-predictor/scripts/utils/predictor/model/__init__.py deleted file mode 100644 index d10e200..0000000 --- a/skills/alphaear-predictor/scripts/utils/predictor/model/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -from .kronos import KronosTokenizer, Kronos, KronosPredictor - -model_dict = { - 'kronos_tokenizer': KronosTokenizer, - 'kronos': Kronos, - 'kronos_predictor': KronosPredictor -} - - -def get_model_class(model_name): - if model_name in model_dict: - return model_dict[model_name] - else: - print(f"Model {model_name} not found in model_dict") - raise NotImplementedError - diff --git a/skills/alphaear-predictor/scripts/utils/predictor/model/kronos.py b/skills/alphaear-predictor/scripts/utils/predictor/model/kronos.py deleted file mode 100644 index cf8bece..0000000 --- a/skills/alphaear-predictor/scripts/utils/predictor/model/kronos.py +++ /dev/null @@ -1,676 +0,0 @@ -import numpy as np -import pandas as pd -import torch -from huggingface_hub import PyTorchModelHubMixin -import sys - -from tqdm import trange - -sys.path.append("../") -from model.module import * - - -class KronosTokenizer(nn.Module, PyTorchModelHubMixin): - """ - KronosTokenizer module for tokenizing input data using a hybrid quantization approach. - - This tokenizer utilizes a combination of encoder and decoder Transformer blocks - along with the Binary Spherical Quantization (BSQuantizer) to compress and decompress input data. - - Args: - d_in (int): Input dimension. - d_model (int): Model dimension. - n_heads (int): Number of attention heads. - ff_dim (int): Feed-forward dimension. - n_enc_layers (int): Number of encoder layers. - n_dec_layers (int): Number of decoder layers. - ffn_dropout_p (float): Dropout probability for feed-forward networks. - attn_dropout_p (float): Dropout probability for attention mechanisms. - resid_dropout_p (float): Dropout probability for residual connections. - s1_bits (int): Number of bits for the pre token in BSQuantizer. - s2_bits (int): Number of bits for the post token in BSQuantizer. - beta (float): Beta parameter for BSQuantizer. - gamma0 (float): Gamma0 parameter for BSQuantizer. - gamma (float): Gamma parameter for BSQuantizer. - zeta (float): Zeta parameter for BSQuantizer. - group_size (int): Group size parameter for BSQuantizer. 
- - """ - - def __init__(self, d_in, d_model, n_heads, ff_dim, n_enc_layers, n_dec_layers, ffn_dropout_p, attn_dropout_p, resid_dropout_p, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size): - - super().__init__() - self.d_in = d_in - self.d_model = d_model - self.n_heads = n_heads - self.ff_dim = ff_dim - self.enc_layers = n_enc_layers - self.dec_layers = n_dec_layers - self.ffn_dropout_p = ffn_dropout_p - self.attn_dropout_p = attn_dropout_p - self.resid_dropout_p = resid_dropout_p - - self.s1_bits = s1_bits - self.s2_bits = s2_bits - self.codebook_dim = s1_bits + s2_bits # Total dimension of the codebook after quantization - self.embed = nn.Linear(self.d_in, self.d_model) - self.head = nn.Linear(self.d_model, self.d_in) - - # Encoder Transformer Blocks - self.encoder = nn.ModuleList([ - TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p) - for _ in range(self.enc_layers - 1) - ]) - # Decoder Transformer Blocks - self.decoder = nn.ModuleList([ - TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p) - for _ in range(self.dec_layers - 1) - ]) - self.quant_embed = nn.Linear(in_features=self.d_model, out_features=self.codebook_dim) # Linear layer before quantization - self.post_quant_embed_pre = nn.Linear(in_features=self.s1_bits, out_features=self.d_model) # Linear layer after quantization (pre part - s1 bits) - self.post_quant_embed = nn.Linear(in_features=self.codebook_dim, out_features=self.d_model) # Linear layer after quantization (full codebook) - self.tokenizer = BSQuantizer(self.s1_bits, self.s2_bits, beta, gamma0, gamma, zeta, group_size) # BSQuantizer module - - def forward(self, x): - """ - Forward pass of the KronosTokenizer. - - Args: - x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in). - - Returns: - tuple: A tuple containing: - - tuple: (z_pre, z) - Reconstructed outputs from decoder with s1_bits and full codebook respectively, - both of shape (batch_size, seq_len, d_in). - - torch.Tensor: bsq_loss - Loss from the BSQuantizer. - - torch.Tensor: quantized - Quantized representation from BSQuantizer. - - torch.Tensor: z_indices - Indices from the BSQuantizer. - """ - z = self.embed(x) - - for layer in self.encoder: - z = layer(z) - - z = self.quant_embed(z) # (B, T, codebook) - - bsq_loss, quantized, z_indices = self.tokenizer(z) - - quantized_pre = quantized[:, :, :self.s1_bits] # Extract the first part of quantized representation (s1_bits) - z_pre = self.post_quant_embed_pre(quantized_pre) - - z = self.post_quant_embed(quantized) - - # Decoder layers (for pre part - s1 bits) - for layer in self.decoder: - z_pre = layer(z_pre) - z_pre = self.head(z_pre) - - # Decoder layers (for full codebook) - for layer in self.decoder: - z = layer(z) - z = self.head(z) - - return (z_pre, z), bsq_loss, quantized, z_indices - - def indices_to_bits(self, x, half=False): - """ - Converts indices to bit representations and scales them. - - Args: - x (torch.Tensor): Indices tensor. - half (bool, optional): Whether to process only half of the codebook dimension. Defaults to False. - - Returns: - torch.Tensor: Bit representation tensor. 
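Inside the tokenizer's forward pass above, the quantized code of width s1_bits + s2_bits is split: the first s1_bits alone drive the coarse "pre" reconstruction, while the full code drives the final one. A tiny sketch of that slice (sizes are illustrative, not the checkpoint's real configuration):

import torch

s1_bits, s2_bits = 4, 4
codebook_dim = s1_bits + s2_bits
quantized = torch.randn(2, 16, codebook_dim)    # [batch, seq_len, codebook_dim]
quantized_pre = quantized[:, :, :s1_bits]       # coarse half fed to post_quant_embed_pre
print(quantized_pre.shape)                      # torch.Size([2, 16, 4])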
- """ - if half: - x1 = x[0] # Assuming x is a tuple of indices if half is True - x2 = x[1] - mask = 2 ** torch.arange(self.codebook_dim//2, device=x1.device, dtype=torch.long) # Create a mask for bit extraction - x1 = (x1.unsqueeze(-1) & mask) != 0 # Extract bits for the first half - x2 = (x2.unsqueeze(-1) & mask) != 0 # Extract bits for the second half - x = torch.cat([x1, x2], dim=-1) # Concatenate the bit representations - else: - mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long) # Create a mask for bit extraction - x = (x.unsqueeze(-1) & mask) != 0 # Extract bits - - x = x.float() * 2 - 1 # Convert boolean to bipolar (-1, 1) - q_scale = 1. / (self.codebook_dim ** 0.5) # Scaling factor - x = x * q_scale - return x - - def encode(self, x, half=False): - """ - Encodes the input data into quantized indices. - - Args: - x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in). - half (bool, optional): Whether to use half quantization in BSQuantizer. Defaults to False. - - Returns: - torch.Tensor: Quantized indices from BSQuantizer. - """ - z = self.embed(x) - for layer in self.encoder: - z = layer(z) - z = self.quant_embed(z) - - bsq_loss, quantized, z_indices = self.tokenizer(z, half=half, collect_metrics=False) - return z_indices - - def decode(self, x, half=False): - """ - Decodes quantized indices back to the input data space. - - Args: - x (torch.Tensor): Quantized indices tensor. - half (bool, optional): Whether the indices were generated with half quantization. Defaults to False. - - Returns: - torch.Tensor: Reconstructed output tensor of shape (batch_size, seq_len, d_in). - """ - quantized = self.indices_to_bits(x, half) - z = self.post_quant_embed(quantized) - for layer in self.decoder: - z = layer(z) - z = self.head(z) - return z - - -class Kronos(nn.Module, PyTorchModelHubMixin): - """ - Kronos Model. - - Args: - s1_bits (int): Number of bits for pre tokens. - s2_bits (int): Number of bits for post tokens. - n_layers (int): Number of Transformer blocks. - d_model (int): Dimension of the model's embeddings and hidden states. - n_heads (int): Number of attention heads in the MultiheadAttention layers. - ff_dim (int): Dimension of the feedforward network in the Transformer blocks. - ffn_dropout_p (float): Dropout probability for the feedforward network. - attn_dropout_p (float): Dropout probability for the attention layers. - resid_dropout_p (float): Dropout probability for residual connections. - token_dropout_p (float): Dropout probability for token embeddings. - learn_te (bool): Whether to use learnable temporal embeddings. 
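indices_to_bits above recovers the bipolar code vector from an integer token id: each bit of the id becomes +1 or -1 and the result is rescaled by 1/sqrt(codebook_dim), matching the BSQuantizer convention. A worked example with a 4-bit codebook:

import torch

codebook_dim = 4
idx = torch.tensor([5])                       # 5 == 0b0101
mask = 2 ** torch.arange(codebook_dim)        # tensor([1, 2, 4, 8])
bits = (idx.unsqueeze(-1) & mask) != 0        # tensor([[ True, False,  True, False]])
bipolar = bits.float() * 2 - 1                # tensor([[ 1., -1.,  1., -1.]])
code = bipolar / (codebook_dim ** 0.5)        # scaled exactly as in indices_to_bits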
- """ - - def __init__(self, s1_bits, s2_bits, n_layers, d_model, n_heads, ff_dim, ffn_dropout_p, attn_dropout_p, resid_dropout_p, token_dropout_p, learn_te, news_dim=None): - super().__init__() - self.s1_bits = s1_bits - self.s2_bits = s2_bits - self.n_layers = n_layers - self.d_model = d_model - self.n_heads = n_heads - self.learn_te = learn_te - self.ff_dim = ff_dim - self.ffn_dropout_p = ffn_dropout_p - self.attn_dropout_p = attn_dropout_p - self.resid_dropout_p = resid_dropout_p - self.token_dropout_p = token_dropout_p - self.news_dim = news_dim - - self.s1_vocab_size = 2 ** self.s1_bits - self.token_drop = nn.Dropout(self.token_dropout_p) - self.embedding = HierarchicalEmbedding(self.s1_bits, self.s2_bits, self.d_model) - self.time_emb = TemporalEmbedding(self.d_model, self.learn_te) - self.transformer = nn.ModuleList([ - TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p) - for _ in range(self.n_layers) - ]) - self.norm = RMSNorm(self.d_model) - self.dep_layer = DependencyAwareLayer(self.d_model) - self.head = DualHead(self.s1_bits, self.s2_bits, self.d_model) - - if self.news_dim is not None: - self.news_proj = nn.Linear(self.news_dim, self.d_model) - else: - self.news_proj = None - - self.apply(self._init_weights) - - def _init_weights(self, module): - - if isinstance(module, nn.Linear): - nn.init.xavier_normal_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - nn.init.normal_(module.weight, mean=0, std=self.embedding.d_model ** -0.5) - elif isinstance(module, nn.LayerNorm): - nn.init.ones_(module.weight) - nn.init.zeros_(module.bias) - elif isinstance(module, RMSNorm): - nn.init.ones_(module.weight) - - def forward(self, s1_ids, s2_ids, stamp=None, padding_mask=None, use_teacher_forcing=False, s1_targets=None, news_emb=None): - """ - Args: - s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len] - s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len] - stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None. - padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None. - use_teacher_forcing (bool, optional): Whether to use teacher forcing for s1 decoding. Defaults to False. - s1_targets (torch.Tensor, optional): Target s1 token IDs for teacher forcing. Shape: [batch_size, seq_len]. Defaults to None. - news_emb (torch.Tensor, optional): News embedding tensor. Shape: [batch_size, news_dim]. Defaults to None. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: - - s1 logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size] - - s2_logits: Logits for s2 token predictions, conditioned on s1. 
Shape: [batch_size, seq_len, s2_vocab_size] - """ - x = self.embedding([s1_ids, s2_ids]) - if stamp is not None: - time_embedding = self.time_emb(stamp) - x = x + time_embedding - x = self.token_drop(x) - - for layer in self.transformer: - x = layer(x, key_padding_mask=padding_mask) - - x = self.norm(x) - - if news_emb is not None and self.news_proj is not None: - news_bias = self.news_proj(news_emb).unsqueeze(1) # [B, 1, d_model] - x = x + news_bias - - s1_logits = self.head(x) - - if use_teacher_forcing: - sibling_embed = self.embedding.emb_s1(s1_targets) - else: - s1_probs = F.softmax(s1_logits.detach(), dim=-1) - sample_s1_ids = torch.multinomial(s1_probs.view(-1, self.s1_vocab_size), 1).view(s1_ids.shape) - sibling_embed = self.embedding.emb_s1(sample_s1_ids) - - x2 = self.dep_layer(x, sibling_embed, key_padding_mask=padding_mask) # Dependency Aware Layer: Condition on s1 embeddings - s2_logits = self.head.cond_forward(x2) - return s1_logits, s2_logits - - def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None, news_emb=None): - """ - Decodes only the s1 tokens. - - This method performs a forward pass to predict only s1 tokens. It returns the s1 logits - and the context representation from the Transformer, which can be used for subsequent s2 decoding. - - Args: - s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len] - s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len] - stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None. - padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None. - news_emb (torch.Tensor, optional): News embedding tensor. Shape: [batch_size, news_dim]. Defaults to None. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: - - s1 logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size] - - context: Context representation from the Transformer. Shape: [batch_size, seq_len, d_model] - """ - x = self.embedding([s1_ids, s2_ids]) - if stamp is not None: - time_embedding = self.time_emb(stamp) - x = x + time_embedding - x = self.token_drop(x) - - for layer in self.transformer: - x = layer(x, key_padding_mask=padding_mask) - - x = self.norm(x) - - if news_emb is not None and self.news_proj is not None: - news_bias = self.news_proj(news_emb).unsqueeze(1) # [B, 1, d_model] - x = x + news_bias - - s1_logits = self.head(x) - return s1_logits, x - - def decode_s2(self, context, s1_ids, padding_mask=None): - """ - Decodes the s2 tokens, conditioned on the context and s1 tokens. - - This method decodes s2 tokens based on a pre-computed context representation (typically from `decode_s1`) - and the s1 token IDs. It uses the dependency-aware layer and the conditional s2 head to predict s2 tokens. - - Args: - context (torch.Tensor): Context representation from the transformer (output of decode_s1). - Shape: [batch_size, seq_len, d_model] - s1_ids (torch.torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len] - padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None. - - Returns: - torch.Tensor: s2 logits. 
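The news conditioning above is deliberately simple: the news embedding is projected to d_model by news_proj and added as a per-sequence bias to the transformer output before the prediction heads. A self-contained sketch of that additive step (sizes are illustrative):

import torch
import torch.nn as nn

d_model, news_dim = 8, 16
news_proj = nn.Linear(news_dim, d_model)

x = torch.randn(2, 10, d_model)                # transformer output [batch, seq_len, d_model]
news_emb = torch.randn(2, news_dim)            # one news embedding per sequence
news_bias = news_proj(news_emb).unsqueeze(1)   # [batch, 1, d_model], broadcast over time
x = x + news_bias                              # same additive conditioning as Kronos.forward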
Shape: [batch_size, seq_len, s2_vocab_size] - """ - sibling_embed = self.embedding.emb_s1(s1_ids) - x2 = self.dep_layer(context, sibling_embed, key_padding_mask=padding_mask) - return self.head.cond_forward(x2) - - -def top_k_top_p_filtering( - logits, - top_k: int = 0, - top_p: float = 1.0, - filter_value: float = -float("Inf"), - min_tokens_to_keep: int = 1, -): - """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering - Args: - logits: logits distribution shape (batch size, vocabulary size) - if top_k > 0: keep only top k tokens with highest probability (top-k filtering). - if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). - Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) - Make sure we keep at least min_tokens_to_keep per batch example in the output - From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 - """ - if top_k > 0: - top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check - # Remove all tokens with a probability less than the last token of the top-k - indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] - logits[indices_to_remove] = filter_value - return logits - - if top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True) - cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) - - # Remove tokens with cumulative probability above the threshold (token with 0 are kept) - sorted_indices_to_remove = cumulative_probs > top_p - if min_tokens_to_keep > 1: - # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) - sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 - # Shift the indices to the right to keep also the first token above the threshold - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 - - # scatter sorted tensors to original indexing - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) - logits[indices_to_remove] = filter_value - return logits - - -def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_logits=True): - logits = logits / temperature - if top_k is not None or top_p is not None: - if top_k > 0 or top_p < 1.0: - logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) - - probs = F.softmax(logits, dim=-1) - - if not sample_logits: - _, x = top_k(probs, k=1, dim=-1) - else: - x = torch.multinomial(probs, num_samples=1) - - return x - - -def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False, news_emb=None): - with torch.no_grad(): - x = torch.clip(x, -clip, clip) - - device = x.device - x = x.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x.size(1), x.size(2)).to(device) - x_stamp = x_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x_stamp.size(1), x_stamp.size(2)).to(device) - y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), y_stamp.size(2)).to(device) - - x_token = tokenizer.encode(x, half=True) - - initial_seq_len = x.size(1) - batch_size = x_token[0].size(0) - total_seq_len = initial_seq_len + pred_len - full_stamp = torch.cat([x_stamp, y_stamp], dim=1) - - generated_pre = x_token[0].new_empty(batch_size, pred_len) - generated_post = x_token[1].new_empty(batch_size, pred_len) - - pre_buffer = 
x_token[0].new_zeros(batch_size, max_context) - post_buffer = x_token[1].new_zeros(batch_size, max_context) - buffer_len = min(initial_seq_len, max_context) - if buffer_len > 0: - start_idx = max(0, initial_seq_len - max_context) - pre_buffer[:, :buffer_len] = x_token[0][:, start_idx:start_idx + buffer_len] - post_buffer[:, :buffer_len] = x_token[1][:, start_idx:start_idx + buffer_len] - - if verbose: - ran = trange - else: - ran = range - for i in ran(pred_len): - current_seq_len = initial_seq_len + i - window_len = min(current_seq_len, max_context) - - if current_seq_len <= max_context: - input_tokens = [ - pre_buffer[:, :window_len], - post_buffer[:, :window_len] - ] - else: - input_tokens = [pre_buffer, post_buffer] - - context_end = current_seq_len - context_start = max(0, context_end - max_context) - current_stamp = full_stamp[:, context_start:context_end, :].contiguous() - - s1_logits, context = model.decode_s1(input_tokens[0], input_tokens[1], current_stamp, news_emb=news_emb) - s1_logits = s1_logits[:, -1, :] - sample_pre = sample_from_logits(s1_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True) - - s2_logits = model.decode_s2(context, sample_pre) - s2_logits = s2_logits[:, -1, :] - sample_post = sample_from_logits(s2_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True) - - generated_pre[:, i] = sample_pre.squeeze(-1) - generated_post[:, i] = sample_post.squeeze(-1) - - if current_seq_len < max_context: - pre_buffer[:, current_seq_len] = sample_pre.squeeze(-1) - post_buffer[:, current_seq_len] = sample_post.squeeze(-1) - else: - pre_buffer.copy_(torch.roll(pre_buffer, shifts=-1, dims=1)) - post_buffer.copy_(torch.roll(post_buffer, shifts=-1, dims=1)) - pre_buffer[:, -1] = sample_pre.squeeze(-1) - post_buffer[:, -1] = sample_post.squeeze(-1) - - full_pre = torch.cat([x_token[0], generated_pre], dim=1) - full_post = torch.cat([x_token[1], generated_post], dim=1) - - context_start = max(0, total_seq_len - max_context) - input_tokens = [ - full_pre[:, context_start:total_seq_len].contiguous(), - full_post[:, context_start:total_seq_len].contiguous() - ] - z = tokenizer.decode(input_tokens, half=True) - z = z.reshape(-1, sample_count, z.size(1), z.size(2)) - preds = z.cpu().numpy() - preds = np.mean(preds, axis=1) - - return preds - - -def calc_time_stamps(x_timestamp): - time_df = pd.DataFrame() - time_df['minute'] = x_timestamp.dt.minute - time_df['hour'] = x_timestamp.dt.hour - time_df['weekday'] = x_timestamp.dt.weekday - time_df['day'] = x_timestamp.dt.day - time_df['month'] = x_timestamp.dt.month - return time_df - - -class KronosPredictor: - - def __init__(self, model, tokenizer, device="cuda:0", max_context=512, clip=5): - self.tokenizer = tokenizer - self.model = model - self.max_context = max_context - self.clip = clip - self.price_cols = ['open', 'high', 'low', 'close'] - self.vol_col = 'volume' - self.amt_vol = 'amount' - self.time_cols = ['minute', 'hour', 'weekday', 'day', 'month'] - self.device = device - - self.tokenizer = self.tokenizer.to(self.device) - self.model = self.model.to(self.device) - - def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, news_emb=None): - - x_tensor = torch.from_numpy(np.array(x).astype(np.float32)).to(self.device) - x_stamp_tensor = torch.from_numpy(np.array(x_stamp).astype(np.float32)).to(self.device) - y_stamp_tensor = torch.from_numpy(np.array(y_stamp).astype(np.float32)).to(self.device) - - preds = auto_regressive_inference(self.tokenizer, self.model, x_tensor, 
x_stamp_tensor, y_stamp_tensor, self.max_context, pred_len, - self.clip, T, top_k, top_p, sample_count, verbose, news_emb=news_emb) - preds = preds[:, -pred_len:, :] - return preds - - def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True, news_emb=None): - - if not isinstance(df, pd.DataFrame): - raise ValueError("Input must be a pandas DataFrame.") - - if not all(col in df.columns for col in self.price_cols): - raise ValueError(f"Price columns {self.price_cols} not found in DataFrame.") - - df = df.copy() - if self.vol_col not in df.columns: - df[self.vol_col] = 0.0 # Fill missing volume with zeros - df[self.amt_vol] = 0.0 # Fill missing amount with zeros - if self.amt_vol not in df.columns and self.vol_col in df.columns: - df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1) - - if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any(): - raise ValueError("Input DataFrame contains NaN values in price or volume columns.") - - x_time_df = calc_time_stamps(x_timestamp) - y_time_df = calc_time_stamps(y_timestamp) - - x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(np.float32) - x_stamp = x_time_df.values.astype(np.float32) - y_stamp = y_time_df.values.astype(np.float32) - - x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0) - - x = (x - x_mean) / (x_std + 1e-5) - x = np.clip(x, -self.clip, self.clip) - - x = x[np.newaxis, :] - x_stamp = x_stamp[np.newaxis, :] - y_stamp = y_stamp[np.newaxis, :] - - if news_emb is not None: - news_emb_tensor = torch.from_numpy(np.array(news_emb).astype(np.float32)).to(self.device) - # Ensure batch dimension for news_emb if only one sample - if news_emb_tensor.ndim == 1: - news_emb_tensor = news_emb_tensor.unsqueeze(0) - else: - news_emb_tensor = None - - preds = self.generate(x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, news_emb=news_emb_tensor) - - preds = preds.squeeze(0) - preds = preds * (x_std + 1e-5) + x_mean - - pred_df = pd.DataFrame(preds, columns=self.price_cols + [self.vol_col, self.amt_vol], index=y_timestamp) - return pred_df - - - def predict_batch(self, df_list, x_timestamp_list, y_timestamp_list, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True): - """ - Perform parallel (batch) prediction on multiple time series. All series must have the same historical length and prediction length (pred_len). - - Args: - df_list (List[pd.DataFrame]): List of input DataFrames, each containing price columns and optional volume/amount columns. - x_timestamp_list (List[pd.DatetimeIndex or Series]): List of timestamps corresponding to historical data, length should match the number of rows in each DataFrame. - y_timestamp_list (List[pd.DatetimeIndex or Series]): List of future prediction timestamps, length should equal pred_len. - pred_len (int): Number of prediction steps. - T (float): Sampling temperature. - top_k (int): Top-k filtering threshold. - top_p (float): Top-p (nucleus sampling) threshold. - sample_count (int): Number of parallel samples per series, automatically averaged internally. - verbose (bool): Whether to display autoregressive progress. - - Returns: - List[pd.DataFrame]: List of prediction results in the same order as input, each DataFrame contains - `open, high, low, close, volume, amount` columns, indexed by corresponding `y_timestamp`. 
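The `predict` / `predict_batch` docstrings above describe a specific calling convention: OHLC columns (volume/amount optional), a timestamp series for the history, and a separate one for the horizon. A minimal usage sketch, assuming an already loaded `model` and `tokenizer` and a hypothetical `bars.csv` with a `date` column:

```python
import pandas as pd

# Historical bars with at least open/high/low/close (volume/amount are optional)
df = pd.read_csv("bars.csv")                     # hypothetical input file
x_ts = pd.to_datetime(df["date"])                # timestamps of the history
y_ts = pd.Series(pd.date_range(x_ts.iloc[-1] + pd.Timedelta(days=1),
                               periods=5, freq="B"))  # 5 future business days

predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512)
pred_df = predictor.predict(df, x_ts, y_ts, pred_len=5,
                            T=1.0, top_k=0, top_p=0.9, sample_count=3)
print(pred_df[["open", "high", "low", "close"]])  # indexed by y_ts
```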
- """ - # Basic validation - if not isinstance(df_list, (list, tuple)) or not isinstance(x_timestamp_list, (list, tuple)) or not isinstance(y_timestamp_list, (list, tuple)): - raise ValueError("df_list, x_timestamp_list, y_timestamp_list must be list or tuple types.") - if not (len(df_list) == len(x_timestamp_list) == len(y_timestamp_list)): - raise ValueError("df_list, x_timestamp_list, y_timestamp_list must have consistent lengths.") - - num_series = len(df_list) - - x_list = [] - x_stamp_list = [] - y_stamp_list = [] - means = [] - stds = [] - seq_lens = [] - y_lens = [] - - for i in range(num_series): - df = df_list[i] - if not isinstance(df, pd.DataFrame): - raise ValueError(f"Input at index {i} is not a pandas DataFrame.") - if not all(col in df.columns for col in self.price_cols): - raise ValueError(f"DataFrame at index {i} is missing price columns {self.price_cols}.") - - df = df.copy() - if self.vol_col not in df.columns: - df[self.vol_col] = 0.0 - df[self.amt_vol] = 0.0 - if self.amt_vol not in df.columns and self.vol_col in df.columns: - df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1) - - if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any(): - raise ValueError(f"DataFrame at index {i} contains NaN values in price or volume columns.") - - x_timestamp = x_timestamp_list[i] - y_timestamp = y_timestamp_list[i] - - x_time_df = calc_time_stamps(x_timestamp) - y_time_df = calc_time_stamps(y_timestamp) - - x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(np.float32) - x_stamp = x_time_df.values.astype(np.float32) - y_stamp = y_time_df.values.astype(np.float32) - - if x.shape[0] != x_stamp.shape[0]: - raise ValueError(f"Inconsistent lengths at index {i}: x has {x.shape[0]} vs x_stamp has {x_stamp.shape[0]}.") - if y_stamp.shape[0] != pred_len: - raise ValueError(f"y_timestamp length at index {i} should equal pred_len={pred_len}, got {y_stamp.shape[0]}.") - - x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0) - x_norm = (x - x_mean) / (x_std + 1e-5) - x_norm = np.clip(x_norm, -self.clip, self.clip) - - x_list.append(x_norm) - x_stamp_list.append(x_stamp) - y_stamp_list.append(y_stamp) - means.append(x_mean) - stds.append(x_std) - - seq_lens.append(x_norm.shape[0]) - y_lens.append(y_stamp.shape[0]) - - # Require all series to have consistent historical and prediction lengths for batch processing - if len(set(seq_lens)) != 1: - raise ValueError(f"Parallel prediction requires all series to have consistent historical lengths, got: {seq_lens}") - if len(set(y_lens)) != 1: - raise ValueError(f"Parallel prediction requires all series to have consistent prediction lengths, got: {y_lens}") - - x_batch = np.stack(x_list, axis=0).astype(np.float32) # (B, seq_len, feat) - x_stamp_batch = np.stack(x_stamp_list, axis=0).astype(np.float32) # (B, seq_len, time_feat) - y_stamp_batch = np.stack(y_stamp_list, axis=0).astype(np.float32) # (B, pred_len, time_feat) - - preds = self.generate(x_batch, x_stamp_batch, y_stamp_batch, pred_len, T, top_k, top_p, sample_count, verbose) - # preds: (B, pred_len, feat) - - pred_dfs = [] - for i in range(num_series): - preds_i = preds[i] * (stds[i] + 1e-5) + means[i] - pred_df = pd.DataFrame(preds_i, columns=self.price_cols + [self.vol_col, self.amt_vol], index=y_timestamp_list[i]) - pred_dfs.append(pred_df) - - return pred_dfs diff --git a/skills/alphaear-predictor/scripts/utils/predictor/model/module.py b/skills/alphaear-predictor/scripts/utils/predictor/model/module.py deleted file mode 100644 index 
20b29b5..0000000 --- a/skills/alphaear-predictor/scripts/utils/predictor/model/module.py +++ /dev/null @@ -1,562 +0,0 @@ -import math - -from einops import rearrange, reduce -import torch -import torch.nn as nn -from torch.autograd import Function -import torch.nn.functional as F - - -class DifferentiableEntropyFunction(Function): - @staticmethod - def forward(ctx, zq, basis, K, eps): - zb = (zq + 1) / 2 - zi = ((zb * basis).sum(-1)).to(torch.int64) - cnt = torch.scatter_reduce(torch.zeros(2 ** K, device=zq.device, dtype=zq.dtype), - 0, - zi.flatten(), - torch.ones_like(zi.flatten()).to(zq.dtype), - 'sum') - prob = (cnt + eps) / (cnt + eps).sum() - H = -(prob * torch.log(prob)).sum() - ctx.save_for_backward(zq, zi, prob) - ctx.K = K - return H - - @staticmethod - def backward(ctx, grad_output): - zq, zi, prob = ctx.saved_tensors - grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K - reord_grad = grad_array[zi.flatten()].reshape(zi.shape) - grad_input = reord_grad.unsqueeze(-1) * zq - return grad_input, None, None, None, None - - -def codebook_entropy(zq, basis, K, eps=1e-4): - return DifferentiableEntropyFunction.apply(zq, basis, K, eps) - - -class BinarySphericalQuantizer(nn.Module): - def __init__(self, embed_dim, beta, gamma0, gamma, zeta, - input_format='bchw', - soft_entropy=True, group_size=9, - persample_entropy_compute='analytical', - cb_entropy_compute='group', - l2_norm=True, - inv_temperature=1): - """ - Paper link: https://arxiv.org/pdf/2406.07548.pdf - Here we use the official implementation of the BinarySphericalQuantizer. - """ - super().__init__() - self.embed_dim = embed_dim - self.beta = beta # loss weight for commit loss - self.gamma0 = gamma0 # loss weight for entropy penalty - self.gamma = gamma # loss weight for entropy penalty - self.zeta = zeta # loss weight for entire entropy penalty - self.input_format = input_format - assert self.embed_dim % group_size == 0, "embed_dim must be divisible by group_size" - self.num_groups = self.embed_dim // group_size - self.group_size = group_size - assert persample_entropy_compute in ['group', 'analytical'], "persample_entropy_compute must be either 'group' or 'analytical'" - assert cb_entropy_compute in ['group', 'nce'], "cb_entropy_compute must be either 'group' or 'nce'" - self.persample_entropy_compute = persample_entropy_compute - self.cb_entropy_compute = cb_entropy_compute - self.l2_norm = l2_norm - self.inv_temperature = inv_temperature - - self.register_buffer('basis', 2 ** torch.arange(embed_dim - 1, -1, -1)) - self.register_buffer('group_basis', 2 ** torch.arange(group_size - 1, -1, -1)) - - self.num_dimensions = 2 ** embed_dim - self.bits_per_index = embed_dim - - # we only need to keep the codebook portion up to the group size - # because we approximate the H loss with this subcode - group_codes = torch.arange(2 ** self.group_size) - group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:] - self.register_buffer('group_codebook', group_codebook, persistent=False) - - self.soft_entropy = soft_entropy # soft_entropy: Sec 3.2 of https://arxiv.org/pdf/1911.05894.pdf - - def quantize(self, z): - assert z.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}" - - zhat = torch.where(z > 0, - torch.tensor(1, dtype=z.dtype, device=z.device), - torch.tensor(-1, dtype=z.dtype, device=z.device)) - return z + (zhat - z).detach() - - def forward(self, z, collect_metrics=True): - # if self.input_format == 'bchw': - # z = rearrange(z, 'b c h w -> b h w c') - zq 
= self.quantize(z) - - q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. - - zq = zq * q_scale - - if not collect_metrics: - return zq, zq.new_zeros(()), {} - - indices = self.codes_to_indexes(zq.detach()) - group_indices = self.codes_to_group_indexes(zq.detach()) - if not self.training: - used_codes = torch.unique(indices, return_counts=False) - else: - used_codes = None - - if self.soft_entropy: - persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z) - entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy - else: - zb_by_sample = ((zq + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32) - persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample) - cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim) - entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy - - # commit loss - commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1)) - - # if self.input_format == 'bchw': - # zq = rearrange(zq, 'b h w c -> b c h w') - - return ( - zq, - commit_loss + self.zeta * entropy_penalty / self.inv_temperature, - {"H": cb_entropy, "used_codes": used_codes, "indices": indices, "group_indices": group_indices, - "avg_prob": avg_prob} - ) - - def soft_entropy_loss(self, z): - # if we divide the code in subgroups of size group_size, the codebook will be of size 2 ** group_size - # the sub-code is the last group_size bits of the full code - group_code_book = self.group_codebook / (self.embed_dim ** 0.5 if self.l2_norm else 1) - divided_z = rearrange(z, '... (g c) -> ... g c', c=self.group_size) - - # we calculate the distance between the divided_z and the codebook for each subgroup - distance = - 2 * torch.einsum('... g c, d c ->... g d', divided_z, group_code_book) - prob = (-distance * self.inv_temperature).softmax(dim=-1) - if self.persample_entropy_compute == 'analytical': - if self.l2_norm: - p = torch.sigmoid(-4 * z / (self.embed_dim ** 0.5) * self.inv_temperature) - else: - p = torch.sigmoid(-4 * z * self.inv_temperature) - prob = torch.stack([p, 1 - p], dim=-1) - per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() - else: - per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() - - # macro average of the probability of each subgroup - avg_prob = reduce(prob, '... g d ->g d', 'mean') - codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False) - - # the approximation of the entropy is the sum of the entropy of each subgroup - return per_sample_entropy, codebook_entropy.sum(), avg_prob - - def get_hard_per_sample_entropy(self, zb_by_sample): - probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1] - persample_entropy = - probs_per_dim * torch.log(probs_per_dim + 1e-8) - (1 - probs_per_dim) * torch.log(1 - probs_per_dim + 1e-8) - persample_entropy = persample_entropy.sum(-1) - return persample_entropy.mean() - - def codes_to_indexes(self, zhat): - """Converts a `code` to an index in the codebook. - Args: - zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1} - """ - assert zhat.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}" - return ((zhat + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64) - - def codes_to_group_indexes(self, zhat): - """Converts a `code` to a list of indexes (in groups) in the codebook. - Args: - zhat: A tensor of shape (B, ..., C) containing the codes. 
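`codes_to_indexes` packs a {-1, +1} code into an integer index by treating each dimension as one bit, most-significant first, via the `basis` buffer, and `indexes_to_codes` inverts it. A small worked sketch of that arithmetic outside the module, assuming `embed_dim = 4`:

```python
import torch

embed_dim = 4
basis = 2 ** torch.arange(embed_dim - 1, -1, -1)   # tensor([8, 4, 2, 1])

code = torch.tensor([1., -1., 1., 1.])             # one quantized vector in {-1, +1}
bits = (code + 1) / 2                              # -> tensor([1., 0., 1., 1.])
index = int((bits * basis).sum().item())           # 8 + 2 + 1 = 11

# indexes_to_codes inverts this: floor-divide by each basis element, mod 2, rescale to {-1, +1}
recovered = (torch.div(torch.tensor(index), basis, rounding_mode="floor") % 2) * 2 - 1
assert torch.equal(recovered.float(), code)
```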
must be in {-1, 1} - """ - zhat_in_group = rearrange(zhat, 'b ... (g c) -> b ... g c', c=self.group_size) - return ((zhat_in_group + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64) - - def indexes_to_codes(self, indices): - """Inverse of `indexes_to_codes`.""" - indices = indices.unsqueeze(-1) - codes_non_centered = torch.remainder( - torch.floor_divide(indices, self.basis), 2 - ) - return codes_non_centered * 2 - 1 - - def group_indexes_to_codes(self, group_indices): - """Inverse of `group_indexes_to_codes`.""" - group_indices = group_indices.unsqueeze(-1) - codes_non_centered = torch.remainder( - torch.floor_divide(group_indices, self.group_basis), 2 - ) - codes_non_centered = rearrange(codes_non_centered, 'b ... g c -> b ... (g c)') - return codes_non_centered * 2 - 1 - - def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True): - if normalize: - probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True) - else: - probs = count - H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim) - return H - - def get_group_codebook_entry(self, group_indices): - z_q = self.group_indexes_to_codes(group_indices) - q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. - z_q = z_q * q_scale - if self.input_format == 'bchw': - h, w = int(z_q.shape[1] ** 0.5) - assert h * w == z_q.shape[1], 'Invalid sequence length' - z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h) - return z_q - - def get_codebook_entry(self, indices): - z_q = self.indexes_to_codes(indices) - q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. - z_q = z_q * q_scale - if self.input_format == 'bchw': - h, w = int(z_q.shape[1] ** 0.5) - assert h * w == z_q.shape[1], 'Invalid sequence length' - z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h) - return z_q - - -class BSQuantizer(nn.Module): - - def __init__(self, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size): - super().__init__() - self.codebook_dim = s1_bits + s2_bits - self.s1_bits = s1_bits - self.s2_bits = s2_bits - self.bsq = BinarySphericalQuantizer(self.codebook_dim, beta, gamma0, gamma, zeta, group_size=group_size) - - def bits_to_indices(self, bits): - bits = (bits >= 0).to(torch.long) - indices = 2 ** torch.arange( - 0, - bits.shape[-1], - 1, - dtype=torch.long, - device=bits.device, - ) - return (bits * indices).sum(-1) - - def forward(self, z, half=False, collect_metrics=True): - z = F.normalize(z, dim=-1) - quantized, bsq_loss, metrics = self.bsq(z, collect_metrics=collect_metrics) - if half: - q_pre = quantized[:, :, :self.s1_bits] - q_post = quantized[:, :, self.s1_bits:] - z_indices = [self.bits_to_indices(q_pre), self.bits_to_indices(q_post)] - else: - z_indices = self.bits_to_indices(quantized) - return bsq_loss, quantized, z_indices - - -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-5): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()).type_as(x) - return output * self.weight - - -class FeedForward(nn.Module): - def __init__(self, d_model, ff_dim, ffn_dropout_p=0.0): - super().__init__() - - self.w1 = nn.Linear(d_model, ff_dim, bias=False) - self.w3 = nn.Linear(d_model, ff_dim, bias=False) - self.w2 = nn.Linear(ff_dim, d_model, bias=False) - self.ffn_dropout = nn.Dropout(ffn_dropout_p) - - def forward(self, x): - return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) - - -class 
RotaryPositionalEmbedding(nn.Module): - def __init__(self, dim): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer("inv_freq", inv_freq) - self.seq_len_cached = None - self.cos_cached = None - self.sin_cached = None - - def _update_cos_sin_cache(self, x, seq_len): - if seq_len != self.seq_len_cached: - self.seq_len_cached = seq_len - t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) - freqs = torch.einsum('i,j->ij', t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - self.cos_cached = emb.cos()[None, None, :, :] - self.sin_cached = emb.sin()[None, None, :, :] - return self.cos_cached, self.sin_cached - - def forward(self, q, k): - cos, sin = self._update_cos_sin_cache(q, q.shape[-2]) - return ( - (q * cos) + (self._rotate_half(q) * sin), - (k * cos) + (self._rotate_half(k) * sin), - ) - - def _rotate_half(self, x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - -class MultiHeadAttentionWithRoPE(nn.Module): - def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout_p=0.0): - super().__init__() - self.d_model = d_model - self.n_heads = n_heads - self.head_dim = d_model // n_heads - - self.q_proj = nn.Linear(d_model, d_model) - self.k_proj = nn.Linear(d_model, d_model) - self.v_proj = nn.Linear(d_model, d_model) - self.out_proj = nn.Linear(d_model, d_model) - self.rotary = RotaryPositionalEmbedding(self.head_dim) - self.attn_dropout_p = attn_dropout_p - self.resid_dropout = nn.Dropout(resid_dropout_p) - - def forward(self, x, key_padding_mask=None): - batch_size, seq_len, _ = x.shape - - q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - - q, k = self.rotary(q, k) - - if key_padding_mask is not None: - attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) # [batch, 1, 1, seq_len] - attn_mask = attn_mask.expand(-1, self.n_heads, seq_len, -1) # [batch, n_heads, q_len, k_len] - else: - attn_mask = None - - attn_output = F.scaled_dot_product_attention( - q, k, v, - attn_mask=attn_mask, - dropout_p=self.attn_dropout_p if self.training else 0.0, - is_causal=True - ) - - attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model) - return self.resid_dropout(self.out_proj(attn_output)) - - -class MultiHeadCrossAttentionWithRoPE(nn.Module): - def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout=0.0): - super().__init__() - self.d_model = d_model - self.n_heads = n_heads - self.head_dim = d_model // n_heads - - self.q_proj = nn.Linear(d_model, d_model) - self.k_proj = nn.Linear(d_model, d_model) - self.v_proj = nn.Linear(d_model, d_model) - self.out_proj = nn.Linear(d_model, d_model) - self.rotary = RotaryPositionalEmbedding(self.head_dim) - self.attn_dropout_p = attn_dropout_p - self.resid_dropout = nn.Dropout(resid_dropout) - - def forward(self, query, key, value, key_padding_mask=None): - batch_size, q_len, _ = query.shape - _, seq_len, _ = key.shape - - q = self.q_proj(query).view(batch_size, q_len, self.n_heads, self.head_dim).transpose(1, 2) - k = self.k_proj(key).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - v = self.v_proj(value).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - - q, k = self.rotary(q, k) - - if key_padding_mask is 
not None:
-            attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)
-            attn_mask = attn_mask.expand(-1, self.n_heads, q_len, -1)
-        else:
-            attn_mask = None
-
-        is_causal_flag = self.training
-
-        attn_output = F.scaled_dot_product_attention(
-            q, k, v,
-            attn_mask=attn_mask,
-            dropout_p=self.attn_dropout_p if self.training else 0.0,
-            is_causal=is_causal_flag
-        )
-
-        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
-        return self.resid_dropout(self.out_proj(attn_output))
-
-
-class HierarchicalEmbedding(nn.Module):
-    def __init__(self, s1_bits, s2_bits, d_model=256):
-        super().__init__()
-        self.s1_bits = s1_bits
-        self.s2_bits = s2_bits
-
-        vocab_s1 = 2 ** s1_bits
-        vocab_s2 = 2 ** s2_bits
-
-        self.emb_s1 = nn.Embedding(vocab_s1, d_model)
-        self.emb_s2 = nn.Embedding(vocab_s2, d_model)
-        self.d_model = d_model
-        self.fusion_proj = nn.Linear(d_model * 2, d_model)
-
-        nn.init.normal_(self.emb_s1.weight, mean=0, std=d_model ** -0.5)
-        nn.init.normal_(self.emb_s2.weight, mean=0, std=d_model ** -0.5)
-
-    def split_token(self, token_ids: torch.Tensor, s2_bits: int):
-        """Inputs:
-            token_ids (torch.Tensor): Composite token IDs of shape [batch_size, seq_len] or [N], each in range [0, 2^(s1_bits + s2_bits) - 1].
-            s2_bits (int): Number of low bits used for the fine token (s2).
-        """
-        assert isinstance(s2_bits, int) and s2_bits > 0, "s2_bits must be a positive integer"
-
-        t = token_ids.long()
-        mask = (1 << s2_bits) - 1
-        s2_ids = t & mask  # extract low bits
-        s1_ids = t >> s2_bits  # extract high bits
-        return s1_ids, s2_ids
-
-    def forward(self, token_ids):
-        """Inputs:
-            token_ids:
-                - tuple or list: (s1_ids, s2_ids), each of shape [batch_size, seq_len], or
-                - torch.Tensor: composite token IDs of shape [batch_size, seq_len], which will be split into (s1_ids, s2_ids) internally.
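`split_token` recovers the coarse (s1) and fine (s2) ids from a composite token with plain bit masking; equivalently, the composite id is `(s1 << s2_bits) | s2`. A tiny round-trip sketch, with `s1_bits = s2_bits = 10` chosen only for illustration (not the checkpoint's values):

```python
import torch

s1_bits, s2_bits = 10, 10                      # example sizes

s1 = torch.tensor([3, 1023])                   # coarse ids in [0, 2**s1_bits)
s2 = torch.tensor([7, 0])                      # fine ids in [0, 2**s2_bits)
composite = (s1 << s2_bits) | s2               # pack: high bits = s1, low bits = s2

mask = (1 << s2_bits) - 1
assert torch.equal(composite & mask, s2)       # low bits give s2 back
assert torch.equal(composite >> s2_bits, s1)   # high bits give s1 back
```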
- Output: [batch_size, seq_len, d_model] - """ - if isinstance(token_ids, tuple) or isinstance(token_ids, list): - s1_ids, s2_ids = token_ids - else: - s1_ids, s2_ids = self.split_token(token_ids, self.s2_bits) - s1_emb = self.emb_s1(s1_ids) * math.sqrt(self.d_model) - s2_emb = self.emb_s2(s2_ids) * math.sqrt(self.d_model) - return self.fusion_proj(torch.cat([s1_emb, s2_emb], dim=-1)) - - -class DependencyAwareLayer(nn.Module): - def __init__(self, d_model, n_heads=4, attn_dropout_p=0.0, resid_dropout=0.0): - super().__init__() - self.cross_attn = MultiHeadCrossAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout) - self.norm = RMSNorm(d_model) - - def forward(self, hidden_states, sibling_embed, key_padding_mask=None): - """hidden_states: [batch, seq_len, d_model] - sibling_embed: Embedding from another subtoken - """ - attn_out = self.cross_attn( - query=sibling_embed, - key=hidden_states, - value=hidden_states, - key_padding_mask=key_padding_mask - ) - return self.norm(hidden_states + attn_out) - - -class TransformerBlock(nn.Module): - def __init__(self, d_model, n_heads, ff_dim=1024, ffn_dropout_p=0.0, attn_dropout_p=0.0, resid_dropout_p=0.0): - super().__init__() - self.norm1 = RMSNorm(d_model) - self.self_attn = MultiHeadAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout_p) - self.norm2 = RMSNorm(d_model) - self.ffn = FeedForward(d_model, ff_dim, ffn_dropout_p) - - def forward(self, x, key_padding_mask=None): - residual = x - x = self.norm1(x) - attn_out = self.self_attn(x, key_padding_mask=key_padding_mask) - x = residual + attn_out - - residual = x - x = self.norm2(x) - ffn_out = self.ffn(x) - x = residual + ffn_out - return x - - -class DualHead(nn.Module): - def __init__(self, s1_bits, s2_bits, d_model): - super().__init__() - self.vocab_s1 = 2 ** s1_bits - self.vocab_s2 = 2 ** s2_bits - self.proj_s1 = nn.Linear(d_model, self.vocab_s1) - self.proj_s2 = nn.Linear(d_model, self.vocab_s2) - - def compute_loss(self, s1_logits, s2_logits, s1_targets, s2_targets, padding_mask=None): - if padding_mask is not None: - valid_mask = (padding_mask == 0) - s1_logits = s1_logits[valid_mask] - s2_logits = s2_logits[valid_mask] - s1_targets = s1_targets[valid_mask] - s2_targets = s2_targets[valid_mask] - ce_s1 = F.cross_entropy(s1_logits, s1_targets) - ce_s2 = F.cross_entropy(s2_logits, s2_targets) - else: - ce_s1 = F.cross_entropy(s1_logits.reshape(-1, self.vocab_s1), s1_targets.reshape(-1)) - ce_s2 = F.cross_entropy(s2_logits.reshape(-1, self.vocab_s2), s2_targets.reshape(-1)) - ce_loss = (ce_s1 + ce_s2) / 2 - return ce_loss, ce_s1, ce_s2 - - def forward(self, x): - return self.proj_s1(x) - - def cond_forward(self, x2): - return self.proj_s2(x2) - - -class FixedEmbedding(nn.Module): - def __init__(self, c_in, d_model): - super(FixedEmbedding, self).__init__() - - w = torch.zeros(c_in, d_model).float() - w.require_grad = False - - position = torch.arange(0, c_in).float().unsqueeze(1) - div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() - - w[:, 0::2] = torch.sin(position * div_term) - w[:, 1::2] = torch.cos(position * div_term) - - self.emb = nn.Embedding(c_in, d_model) - self.emb.weight = nn.Parameter(w, requires_grad=False) - - def forward(self, x): - return self.emb(x).detach() - - -class TemporalEmbedding(nn.Module): - def __init__(self, d_model, learn_pe): - super(TemporalEmbedding, self).__init__() - - minute_size = 60 - hour_size = 24 - weekday_size = 7 - day_size = 32 - month_size = 13 - - Embed = FixedEmbedding if not 
learn_pe else nn.Embedding - self.minute_embed = Embed(minute_size, d_model) - self.hour_embed = Embed(hour_size, d_model) - self.weekday_embed = Embed(weekday_size, d_model) - self.day_embed = Embed(day_size, d_model) - self.month_embed = Embed(month_size, d_model) - - def forward(self, x): - x = x.long() - - minute_x = self.minute_embed(x[:, :, 0]) - hour_x = self.hour_embed(x[:, :, 1]) - weekday_x = self.weekday_embed(x[:, :, 2]) - day_x = self.day_embed(x[:, :, 3]) - month_x = self.month_embed(x[:, :, 4]) - - return hour_x + weekday_x + day_x + month_x + minute_x \ No newline at end of file diff --git a/skills/alphaear-predictor/scripts/utils/predictor/training.py b/skills/alphaear-predictor/scripts/utils/predictor/training.py deleted file mode 100644 index 3b41724..0000000 --- a/skills/alphaear-predictor/scripts/utils/predictor/training.py +++ /dev/null @@ -1,539 +0,0 @@ -import os -import sys -import time -import torch -import torch.nn as nn -import pandas as pd -import numpy as np -import json -import random -from loguru import logger -from datetime import datetime, timedelta -from sentence_transformers import SentenceTransformer -from skills._env_loader import load_unified_env - -load_unified_env() - -# Setup paths -KRONOS_DIR = os.path.dirname(os.path.abspath(__file__)) -SRC_DIR = os.path.dirname(os.path.dirname(KRONOS_DIR)) -if SRC_DIR not in sys.path: - sys.path.insert(0, SRC_DIR) - -from ..kronos.model import Kronos, KronosTokenizer, KronosPredictor -from ..database_manager import DatabaseManager -from ..stock_tools import StockTools -from ..search_tools import SearchTools -from ..llm.factory import get_model -from ..visualizer import VisualizerTools -from ..schema.models import ForecastResult, KLinePoint -from agno.agent import Agent - - -class AutoSynthesisTrainer: - def __init__(self, news_dim=384): - self.device = ( - "cuda" - if torch.cuda.is_available() - else "mps" - if torch.backends.mps.is_available() - else "cpu" - ) - self.db = DatabaseManager() - self.tools = StockTools(self.db) - self.searcher = SearchTools(self.db) - # Try loading from local cache first to avoid network timeouts - model_name = os.getenv( - "EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2" - ) - try: - logger.info(f"🔄 Attempting to load {model_name} from local cache...") - self.embedder = SentenceTransformer( - model_name, device=self.device, local_files_only=True - ) - logger.success("✅ Model loaded from local cache.") - except Exception: - logger.warning( - "⚠️ Local cache not found or incomplete. Attempting to download..." - ) - self.embedder = SentenceTransformer(model_name, device=self.device) - self.news_dim = news_dim - - # Try loading from local cache first to avoid network timeouts - try: - logger.info( - "🔄 Attempting to load Kronos and Tokenizer from local cache..." - ) - self.tokenizer = KronosTokenizer.from_pretrained( - "NeoQuasar/Kronos-Tokenizer-base", local_files_only=True - ).to(self.device) - base_model = Kronos.from_pretrained( - "NeoQuasar/Kronos-base", local_files_only=True - ) - logger.success("✅ Kronos and Tokenizer loaded from local cache.") - except Exception: - logger.warning( - "⚠️ Local Kronos/Tokenizer not found or incomplete. Attempting to download..." 
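`TemporalEmbedding` assumes the stamp tensor's last dimension is ordered `[minute, hour, weekday, day, month]`, which is the same order `calc_time_stamps` produces in the predictor module. A short sketch of building such a stamp from pandas timestamps (mirroring that helper rather than importing it):

```python
import numpy as np
import pandas as pd

ts = pd.Series(pd.date_range("2024-01-02 09:30", periods=4, freq="h"))
stamp = np.stack([ts.dt.minute, ts.dt.hour, ts.dt.weekday,
                  ts.dt.day, ts.dt.month], axis=1).astype(np.float32)
print(stamp.shape)  # (4, 5): one row per bar, columns = minute/hour/weekday/day/month
```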
- ) - self.tokenizer = KronosTokenizer.from_pretrained( - "NeoQuasar/Kronos-Tokenizer-base" - ).to(self.device) - base_model = Kronos.from_pretrained("NeoQuasar/Kronos-base") - - self.model = Kronos( - base_model.s1_bits, - base_model.s2_bits, - base_model.n_layers, - base_model.d_model, - base_model.n_heads, - base_model.ff_dim, - base_model.ffn_dropout_p, - base_model.attn_dropout_p, - base_model.resid_dropout_p, - base_model.token_dropout_p, - base_model.learn_te, - news_dim=self.news_dim, - ).to(self.device) - self.model.load_state_dict(base_model.state_dict(), strict=False) - - # LLM for causality verification - provider = os.getenv("LLM_PROVIDER", "minimax") - model_id = os.getenv("LLM_MODEL", "Qwen") - self.llm_agent = Agent(model=get_model(provider, model_id)) - - def discover_shocks( - self, ticker_list, threshold=2.0, limit_per_stock=5, days=365, pred_len=5 - ): - """1. Find days with significant price movements (Look back 1 year)""" - shocks = [] - end_date = datetime.now().strftime("%Y-%m-%d") - start_date = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d") - - for ticker in ticker_list: - df = self.tools.get_stock_price( - ticker, start_date=start_date, end_date=end_date - ) - if df.empty or len(df) < 60: - continue - - # Look for big moves - moves = df[df["change_pct"].abs() > threshold].copy() - if moves.empty: - continue - - count = 0 - for idx, row in moves.iterrows(): - # Ensure we have history before this day AND enough future days for eval - date_idx = df.index.get_loc(idx) - if date_idx < 50 or date_idx + pred_len > len(df): - continue - - shocks.append( - { - "ticker": ticker, - "date": row["date"], - "change": row["change_pct"], - "history": df.iloc[date_idx - 50 : date_idx], - "target": df.iloc[ - date_idx : date_idx + pred_len - ], # Now capturing pred_len days - } - ) - count += 1 - if count >= limit_per_stock: - break - - logger.info( - f"✨ Discovered {len(shocks)} potential price shocks over the last {days} days." - ) - return shocks - - def find_reason_and_verify(self, shock): - """2. Search for reasons and verify causality using LLM""" - ticker_info = self.db.get_stock_by_code(shock["ticker"]) - name = ticker_info["name"] if ticker_info else shock["ticker"] - date_str = shock["date"] - - # Try multiple query variations and engines - queries = [ - f"{name} ({shock['ticker']}) {date_str} 为什么涨跌 原因", - f"{name} {date_str} 异动 原因", - f"{shock['ticker']} {date_str} 新闻", - ] - - search_results = [] - for query in queries: - logger.info(f"🔍 Searching for reason: {query}") - # Try alternate engines - for engine in ["baidu"]: - try: - results = self.searcher.search_list( - query, engine=engine, max_results=3, enrich=False - ) - if results: - search_results = results - break - except Exception as e: - logger.warning(f"Search failed for {query} on {engine}: {e}") - - if search_results: - break - time.sleep(random.uniform(1.0, 2.0)) - - if not search_results: - logger.warning( - f"⚠️ No search results found for {name} on {date_str} after multiple attempts." - ) - return None - - context = "\n".join( - [f"- {r['title']}: {r.get('content', '')[:300]}" for r in search_results] - ) - - prompt = f""" - 任务:判断以下新闻是否解释了该股票在 {date_str} 的 {shock["change"]:.2f}% 价格变动。 - - 股票:{name} - 日期:{date_str} - 变动:{shock["change"]:.2f}% - - 搜索结果: - {context} - - 要求: - 1. 该新闻是否在该日期左右发生? - 2. 该新闻是否能逻辑上解释这种大幅波动(如财报、利好政策、重组、大环境暴跌等)? - 3. 如果是,请总结一段 100 字以内的“核心推动原因”。 - 4. 
返回 JSON: {{"is_causal": true/false, "summary": "原因摘要"}} - """ - - try: - res = self.llm_agent.run(prompt) - data = json.loads( - res.content.replace("```json", "").replace("```", "").strip() - ) - if data.get("is_causal"): - logger.success( - f"✅ Verified cause for {name} on {date_str}: {data['summary']}" - ) - return data["summary"] - else: - logger.warning( - f"❌ Verified cause for {name} on {date_str}: {data['summary']}" - ) - return None - except Exception as e: - logger.warning(f"Verification failed: {e}") - return None - - def save_model(self, path=None): - """Save the news_proj weights""" - if path is None: - save_dir = os.path.join(SRC_DIR, "exports/models") - os.makedirs(save_dir, exist_ok=True) - path = os.path.join( - save_dir, f"kronos_news_v1_{datetime.now().strftime('%Y%m%d_%H%M')}.pt" - ) - - # We only really need to save the news_proj part as it's the only one we train - torch.save( - { - "news_proj_state_dict": self.model.news_proj.state_dict(), - "news_dim": self.news_dim, - "d_model": self.model.d_model, - }, - path, - ) - logger.success(f"💾 Model weights saved to {path}") - return path - - def run_synthesis_and_train(self, tickers, pred_len=5): - # 1. Discovery - shocks = self.discover_shocks(tickers, pred_len=pred_len) - print(f"find {len(shocks)} shocks") - - # 2. News Association & Verification - dataset = [] - max_news_items = 200 # Limit to 200 news items per session to avoid search bans - - logger.info( - f"🧬 Starting News Association for {len(shocks)} shocks (Max limit: {max_news_items})" - ) - - for i, shock in enumerate(shocks): - if len(dataset) >= max_news_items: - logger.info("Reached maximum news items limit for this session.") - break - - summary = self.find_reason_and_verify(shock) - if summary: - # 3. Embedding news - emb = self.embedder.encode(summary) - dataset.append( - { - "history": shock["history"], - "target": shock["target"], - "news_emb": emb, - "summary": summary, - } - ) - - # Add delay after search with randomness to avoid being blocked - if i < len(shocks) - 1: - delay = random.uniform(2.0, 4.0) - time.sleep(delay) - - if not dataset: - logger.error( - "❌ No verified news-price pairs found. Adjust threshold or check if news is available in that period." - ) - return - - # 4. Train/Val Split - random.seed(42) - random.shuffle(dataset) - - if len(dataset) < 2: - train_set = dataset - val_set = [] - logger.warning( - f"⚠️ Only {len(dataset)} sample(s) found. Training on all, skipping validation." - ) - else: - split_idx = max(1, int(len(dataset) * 0.8)) - if split_idx >= len(dataset): - split_idx = len(dataset) - 1 - - train_set = dataset[:split_idx] - val_set = dataset[split_idx:] - logger.info( - f"🏗️ Dataset Split: {len(train_set)} samples for training, {len(val_set)} for validation." - ) - - if not train_set: - logger.error("❌ No samples for training.") - return - - # 5. 
Training (Few-shot) - optimizer = torch.optim.Adam(self.model.news_proj.parameters(), lr=1e-3) - criterion = nn.CrossEntropyLoss() - self.model.train() - - loss_history = [] - logger.info(f"🚀 Training for 30 epochs...") - for epoch in range(30): - total_loss = 0 - for item in train_set: - optimizer.zero_grad() - - # Prep Data - hist_df = item["history"] - # For training, we still focus on the immediate next point (teacher forcing) - target_df = item["target"].iloc[:1] - - hist_raw = hist_df[ - ["open", "high", "low", "close", "volume"] - ].values.astype(np.float32) - hist_raw = np.column_stack([hist_raw, hist_raw[:, 3] * hist_raw[:, 4]]) - - mean, std = hist_raw.mean(axis=0), hist_raw.std(axis=0) + 1e-5 - hist_norm = ( - torch.from_numpy((hist_raw - mean) / std) - .unsqueeze(0) - .to(self.device) - ) - - target_raw = target_df[ - ["open", "high", "low", "close", "volume"] - ].values.astype(np.float32) - target_raw = np.column_stack( - [target_raw, target_raw[:, 3] * target_raw[:, 4]] - ) - target_norm = ( - torch.from_numpy((target_raw - mean) / std) - .unsqueeze(0) - .to(self.device) - ) - - with torch.no_grad(): - z_indices = self.tokenizer.encode(hist_norm, half=True) - t_indices = self.tokenizer.encode(target_norm, half=True) - s1_ids, s2_ids = z_indices[0], z_indices[1] - t_s1, t_s2 = t_indices[0], t_indices[1] - - news_t = torch.from_numpy(item["news_emb"]).unsqueeze(0).to(self.device) - s1_logits, s2_logits = self.model( - s1_ids, - s2_ids, - news_emb=news_t, - use_teacher_forcing=True, - s1_targets=t_s1, - ) - - loss = ( - criterion(s1_logits[:, -1, :], t_s1[:, 0]) - + criterion(s2_logits[:, -1, :], t_s2[:, 0]) - ) / 2 - loss.backward() - optimizer.step() - total_loss += loss.item() - - avg_epoch_loss = total_loss / max(1, len(train_set)) - loss_history.append(avg_epoch_loss) - - if (epoch + 1) % 10 == 0: - logger.info(f"Epoch {epoch + 1} Loss: {avg_epoch_loss:.4f}") - - # 5.1 Visualize Loss Curve - loss_chart = VisualizerTools.generate_loss_chart(loss_history) - VisualizerTools.render_chart_to_file( - loss_chart, - os.path.join(SRC_DIR, "exports/training_results/loss_curve.html"), - ) - - # 5.2 Save final model - self.save_model() - - # 6. Final Evaluation on Validation Set - if not val_set: - logger.warning("⚠️ Validation set is empty. Skipping statistical analysis.") - return - - logger.info( - f"🧪 Final Evaluation: Base vs News-Integrated ({pred_len}-day Window)" - ) - self.model.eval() - predictor = KronosPredictor(self.model, self.tokenizer, device=self.device) - - base_maes = [] - news_maes = [] - - print("\n" + "=" * 90) - print( - f"{'Date':<12} | {'Ticker':<8} | {'Base MAE':<15} | {'News MAE':<15} | {'Improvement'}" - ) - print("-" * 90) - - for item in val_set: - h = item["history"] - t = item["target"] - actuals = t["close"].values[:pred_len] - - x_ts = pd.to_datetime(h["date"]) - # Future timestamps: handle business days if possible, or just simple offset - future_dates = pd.date_range( - start=x_ts.iloc[-1] + timedelta(days=1), periods=pred_len, freq="B" - ) - y_ts = pd.Series(future_dates) - - # A. Base Prediction - p_base = predictor.predict( - h, x_ts, y_ts, pred_len=pred_len, news_emb=None, verbose=False - ) - b_preds = p_base["close"].values[: len(actuals)] - - # B. 
News-Aware Prediction - p_news = predictor.predict( - h, - x_ts, - y_ts, - pred_len=pred_len, - news_emb=item["news_emb"], - verbose=False, - ) - n_preds = p_news["close"].values[: len(actuals)] - - # Calculate MAE over the window - b_mae = np.mean(np.abs(b_preds - actuals)) - n_mae = np.mean(np.abs(n_preds - actuals)) - - base_maes.append(b_mae) - news_maes.append(n_mae) - - improvement = (b_mae - n_mae) / (b_mae + 1e-6) * 100 - - date_str = str(t["date"].values[0])[:10] - ticker = h.iloc[-1]["ticker"] if "ticker" in h.columns else "Stock" - print( - f"{date_str:<12} | {ticker:<8} | {b_mae:<15.4f} | {n_mae:<15.4f} | {improvement:>+7.1f}%" - ) - - # C. Generate Visualization for this case - try: - # Helper to convert DF to KLinePoints - def to_kp_list(preds_df): - points = [] - for idx, row in preds_df.iterrows(): - points.append( - KLinePoint( - date=str(idx)[:10], - open=row["open"], - high=row["high"], - low=row["low"], - close=row["close"], - volume=row["volume"] if "volume" in row else 0, - ) - ) - return points - - forecast_obj = ForecastResult( - ticker=ticker, - base_forecast=to_kp_list(p_base), - adjusted_forecast=to_kp_list(p_news), - rationale=item["summary"], - ) - - # Ground truth for visualizer expects a DataFrame with 'date' and 'close' - gt_df = t[["date", "open", "high", "low", "close", "volume"]] - - chart = VisualizerTools.generate_stock_chart( - df=h, - ticker=ticker, - title=f"Training Eval: {ticker} ({date_str}) Improvement: {improvement:.1f}%", - forecast=forecast_obj, - ground_truth=gt_df, - ) - - safe_date = date_str.replace("-", "") - filename = f"eval_{ticker}_{safe_date}.html" - VisualizerTools.render_chart_to_file( - chart, os.path.join(SRC_DIR, f"exports/training_results/{filename}") - ) - except Exception as e: - logger.error(f"Failed to generate eval chart for {ticker}: {e}") - - # Summary Statistics - avg_base_err = sum(base_maes) / max(1, len(base_maes)) - avg_news_err = sum(news_maes) / max(1, len(news_maes)) - overall_imp = (avg_base_err - avg_news_err) / (avg_base_err + 1e-6) * 100 - - print("-" * 90) - print( - f"{'AVERAGE':<12} | {'-':<8} | {avg_base_err:<15.4f} | {avg_news_err:<15.4f} | {overall_imp:>+7.1f}%" - ) - print("=" * 90 + "\n") - - logger.success( - f"🏁 Statistical Analysis Complete. Avg Error Reduction ({pred_len}-day): {overall_imp:.2f}%" - ) - logger.info( - f"📊 Visualization results saved to: {os.path.join(SRC_DIR, 'exports/training_results/')}" - ) - - -if __name__ == "__main__": - trainer = AutoSynthesisTrainer() - - logger.info("📂 Fetching all stock codes from database...") - res = trainer.db.execute_query("SELECT code FROM stock_list") - all_tickers = [row["code"] for row in res] - - if not all_tickers: - logger.warning("⚠️ No tickers found in stock_list table. 
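Since `save_model` above persists only the `news_proj` weights (plus `news_dim` and `d_model` for sanity checks), restoring a fine-tuned run later would presumably look like the sketch below; the checkpoint path and the `trainer` instance are assumptions, not part of the removed code:

```python
import torch

ckpt = torch.load("exports/models/kronos_news_v1_20250101_0000.pt",  # hypothetical path
                  map_location="cpu")
assert ckpt["news_dim"] == trainer.news_dim
assert ckpt["d_model"] == trainer.model.d_model
trainer.model.news_proj.load_state_dict(ckpt["news_proj_state_dict"])
trainer.model.eval()
```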
Trying to sync...") - trainer.tools._check_and_update_stock_list(force=True) - res = trainer.db.execute_query("SELECT code FROM stock_list") - all_tickers = [row["code"] for row in res] - - logger.info(f"🚀 Starting training on potential stocks (1-year scan)...") - # 为了演示,我们扫描前 100 个股票,寻找最近一年的冲击点 - trainer.run_synthesis_and_train(all_tickers[:100], pred_len=1) diff --git a/skills/alphaear-predictor/scripts/utils/search_tools.py b/skills/alphaear-predictor/scripts/utils/search_tools.py deleted file mode 100644 index 50b08f3..0000000 --- a/skills/alphaear-predictor/scripts/utils/search_tools.py +++ /dev/null @@ -1,611 +0,0 @@ -import os -import hashlib -import json -import re -import requests -import time -import threading -from typing import List, Dict, Optional, Any -from agno.tools.duckduckgo import DuckDuckGoTools -from agno.tools.baidusearch import BaiduSearchTools -from agno.agent import Agent -from loguru import logger -from datetime import datetime -from .database_manager import DatabaseManager -from .content_extractor import ContentExtractor -from .llm.factory import get_model -from .hybrid_search import LocalNewsSearch - -# 默认搜索缓存 TTL(秒),可通过环境变量覆盖 -DEFAULT_SEARCH_TTL = int(os.getenv("SEARCH_CACHE_TTL", "3600")) # 默认 1 小时 - - -class JinaSearchEngine: - """Jina Search API 封装 - 使用 s.jina.ai 进行网络搜索""" - - JINA_SEARCH_URL = "https://s.jina.ai/" - - # 速率限制配置 - _rate_limit_no_key = 10 # 无 key 时每分钟最大请求数 - _rate_window = 60.0 - _min_interval = 2.0 - _request_times = [] - _last_request_time = 0.0 - _lock = threading.Lock() - - def __init__(self): - self.api_key = os.getenv("JINA_API_KEY", "").strip() - self.has_api_key = bool(self.api_key) - if self.has_api_key: - logger.info("✅ Jina Search API key configured") - - @classmethod - def _wait_for_rate_limit(cls, has_api_key: bool) -> None: - """等待以满足速率限制""" - if has_api_key: - time.sleep(0.3) - return - - with cls._lock: - current_time = time.time() - cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window] - - if len(cls._request_times) >= cls._rate_limit_no_key: - oldest = cls._request_times[0] - wait_time = cls._rate_window - (current_time - oldest) + 1.0 - if wait_time > 0: - logger.warning(f"⏳ Jina Search rate limit, waiting {wait_time:.1f}s...") - time.sleep(wait_time) - current_time = time.time() - cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window] - - time_since_last = current_time - cls._last_request_time - if time_since_last < cls._min_interval: - time.sleep(cls._min_interval - time_since_last) - - cls._request_times.append(time.time()) - cls._last_request_time = time.time() - - def search(self, query: str, max_results: int = 5) -> List[Dict]: - """ - 使用 Jina Search API 执行搜索 - - Args: - query: 搜索关键词 - max_results: 返回结果数量 - - Returns: - 搜索结果列表,每个结果包含 title, url, content - """ - if not query: - return [] - - logger.info(f"🔍 Jina Search: {query}") - - # 等待速率限制 - self._wait_for_rate_limit(self.has_api_key) - - headers = { - "Accept": "application/json", - "X-Retain-Images": "none", - } - - if self.has_api_key: - headers["Authorization"] = f"Bearer {self.api_key}" - - try: - # Jina Search API: https://s.jina.ai/{query} - import urllib.parse - encoded_query = urllib.parse.quote(query) - url = f"{self.JINA_SEARCH_URL}{encoded_query}" - - response = requests.get(url, headers=headers, timeout=30) - - if response.status_code == 429: - logger.warning("⚠️ Jina Search rate limited (429), waiting 30s...") - time.sleep(30) - return self.search(query, max_results) - - if 
response.status_code != 200: - logger.warning(f"Jina Search failed (Status {response.status_code})") - return [] - - # 解析响应 - try: - data = response.json() - except json.JSONDecodeError: - # 如果返回纯文本,尝试解析 - data = {"data": [{"title": "Search Result", "url": "", "content": response.text}]} - - results = [] - - # Jina 返回格式可能是 {"data": [...]} 或直接是列表 - items = data.get("data", []) if isinstance(data, dict) else data - if not isinstance(items, list): - items = [items] if items else [] - - for i, item in enumerate(items[:max_results]): - if isinstance(item, dict): - results.append({ - "title": item.get("title", f"Result {i+1}"), - "url": item.get("url", ""), - "href": item.get("url", ""), # 兼容性 - "content": item.get("content", item.get("description", "")), - "body": item.get("content", item.get("description", "")), # 兼容性 - }) - elif isinstance(item, str): - results.append({ - "title": f"Result {i+1}", - "url": "", - "content": item - }) - - logger.info(f"✅ Jina Search returned {len(results)} results") - return results - - except requests.exceptions.Timeout: - logger.error("Jina Search timeout") - return [] - except requests.exceptions.RequestException as e: - logger.error(f"Jina Search request error: {e}") - return [] - except Exception as e: - logger.error(f"Jina Search unexpected error: {e}") - return [] - -class SearchTools: - """扩展性搜索工具库 - 支持多引擎聚合与内容缓存""" - - def __init__(self, db: DatabaseManager): - self.db = db - - # 检查 Jina API Key 是否配置 - jina_api_key = os.getenv("JINA_API_KEY", "").strip() - self._jina_enabled = bool(jina_api_key) - - self._engines = { - "ddg": DuckDuckGoTools(), - "baidu": BaiduSearchTools(), - "local": LocalNewsSearch(db) - } - - # 如果配置了 Jina API Key,添加 Jina 引擎 - if self._jina_enabled: - self._engines["jina"] = JinaSearchEngine() - logger.info("🚀 Jina Search engine enabled (JINA_API_KEY configured)") - - # 确定默认搜索引擎 - self._default_engine = "jina" if self._jina_enabled else "ddg" - - def _generate_hash(self, query: str, engine: str, max_results: int) -> str: - return hashlib.md5(f"{engine}:{query}:{max_results}".encode()).hexdigest() - - def search(self, query: str, engine: str = None, max_results: int = 5, ttl: Optional[int] = None) -> str: - """ - 使用指定搜索引擎执行网络搜索,结果会被缓存以提高效率。 - - Args: - query: 搜索关键词,如 "英伟达财报" 或 "光伏行业政策"。 - engine: 搜索引擎选择。可选值: - "jina" (Jina Search,需配置 JINA_API_KEY,LLM友好输出), - "ddg" (DuckDuckGo,推荐英文/国际搜索), - "baidu" (百度,推荐中文/国内搜索), - "local" (本地历史新闻搜索,基于向量+BM25)。 - 默认: 若配置了 JINA_API_KEY 则使用 "jina",否则 "ddg"。 - max_results: 期望返回的结果数量,默认 5 条。 - ttl: 缓存有效期(秒)。如果缓存超过此时间会重新搜索。 - 默认使用环境变量 SEARCH_CACHE_TTL 或 3600 秒。 - 设为 0 可强制刷新。 - - Returns: - 搜索结果的文本描述,包含标题、摘要和链接。 - """ - # 使用默认引擎(如果配置了 Jina 则优先使用 Jina) - if engine is None: - engine = self._default_engine - - if engine not in self._engines: - return f"Error: Unsupported engine '{engine}'. Available: {list(self._engines.keys())}" - - query_hash = self._generate_hash(query, engine, max_results) - effective_ttl = ttl if ttl is not None else DEFAULT_SEARCH_TTL - - # 1. 尝试从缓存读取 (local 引擎不缓存,因为它本身就是查库) - if engine != "local": - cache = self.db.get_search_cache(query_hash, ttl_seconds=effective_ttl if effective_ttl > 0 else None) - if cache and effective_ttl != 0: - logger.info(f"ℹ️ Found search results in cache for: {query} ({engine})") - return cache['results'] - - # 2. 
执行真实搜索 - logger.info(f"📡 Searching {engine} for: {query}") - try: - tool = self._engines[engine] - if engine == "jina": - # Jina Search 返回 List[Dict] - jina_results = tool.search(query, max_results=max_results) - results = [] - for r in jina_results: - results.append({ - "title": r.get("title", ""), - "href": r.get("url", ""), - "body": r.get("content", "") - }) - elif engine == "ddg": - results = tool.duckduckgo_search(query, max_results=max_results) - elif engine == "baidu": - results = tool.baidu_search(query, max_results=max_results) - elif engine == "local": - # LocalNewsSearch 返回的是 List[Dict] - local_results = tool.search(query, top_n=max_results) - results = [] - for r in local_results: - results.append({ - "title": r.get("title"), - "href": r.get("url", "local"), - "body": r.get("content", "") - }) - else: - results = "Search not implemented for this engine." - - results_str = str(results) - if engine != "local": - self.db.save_search_cache(query_hash, query, engine, results_str) - return results_str - - except Exception as e: - # 搜索失败时的降级策略 - if engine == "jina": - logger.warning(f"⚠️ Jina search failed, falling back to ddg: {query} ({e})") - try: - return self.search(query, engine="ddg", max_results=max_results, ttl=ttl) - except Exception as e2: - logger.error(f"❌ DDG fallback also failed for {query}: {e2}") - elif engine == "ddg": - logger.warning(f"⚠️ DDG search failed, falling back to baidu: {query} ({e})") - try: - return self.search(query, engine="baidu", max_results=max_results, ttl=ttl) - except Exception as e2: - logger.error(f"❌ Baidu fallback also failed for {query}: {e2}") - - logger.error(f"❌ Search failed for {query}: {e}") - return f"Error occurred during search: {str(e)}" - - def search_list(self, query: str, engine: str = None, max_results: int = 5, ttl: Optional[int] = None, enrich: bool = True) -> List[Dict]: - """ - 执行搜索并返回结构化列表 (List[Dict])。 - Dict 包含: title, href (or url), body (or snippet) - - Args: - engine: 搜索引擎,默认使用配置的默认引擎(Jina 优先) - enrich: 是否抓取正文内容 (默认 True) - """ - # 使用默认引擎 - if engine is None: - engine = self._default_engine - - if engine not in self._engines: - logger.error(f"Unsupported engine {engine}") - return [] - - # 不同的 hash 以区分是否 enrichment - enrich_suffix = ":enriched" if enrich else "" - query_hash = self._generate_hash(query, engine + enrich_suffix, max_results) - effective_ttl = ttl if ttl is not None else DEFAULT_SEARCH_TTL - - # 1. 尝试从缓存读取 - cache = self.db.get_search_cache(query_hash, ttl_seconds=effective_ttl if effective_ttl > 0 else None) - if cache and effective_ttl != 0: - try: - cached_data = json.loads(cache['results']) - if isinstance(cached_data, list): - logger.info(f"ℹ️ Found structured search cache for: {query}") - return cached_data - except: - pass - - # 1.5 Smart Cache (Fuzzy + LLM) - if effective_ttl != 0: - try: - # 1. Similar cached queries - similar_queries = self.db.find_similar_queries(query, limit=3) - # Filter by TTL - valid_candidates = [] - for q in similar_queries: - if q['query'] == query: continue - q_time = datetime.fromisoformat(q['timestamp']) - if effective_ttl and (datetime.now() - q_time).total_seconds() > effective_ttl: - continue - q['type'] = 'cached_search' - valid_candidates.append(q) - - # 2. Relevant local news (as search results) - local_news = self.db.search_local_news(query, limit=3) - if local_news: - # Group local news as a single "candidate" source? Or individual? - # Better to treat "Local News Database" as one candidate source that contains X items. 
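For reference, the removed `SearchTools` API is typically driven as below; the import paths are assumptions (the module itself used relative imports), `ttl=0` forces a fresh search past the cache, and `enrich=True` triggers the content/sentiment enrichment pass described further down:

```python
from utils.database_manager import DatabaseManager  # hypothetical import path
from utils.search_tools import SearchTools          # hypothetical import path

db = DatabaseManager()
tools = SearchTools(db)

# Plain-text result, cached for SEARCH_CACHE_TTL seconds (default 3600)
text = tools.search("英伟达财报", engine="baidu", max_results=5)

# Structured, content-enriched results; ttl=0 bypasses the cache
items = tools.search_list("光伏行业政策", max_results=3, ttl=0, enrich=True)
for it in items:
    print(it["title"], it.get("sentiment_score"))
```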
- # Or just add them to candidates list? - # Let's package strictly relevant news as a "local_news_bundle" - valid_candidates.append({ - 'type': 'local_news', - 'query': 'Local Database News', - 'items': local_news, - 'timestamp': datetime.now().isoformat() - }) - - if valid_candidates: - logger.info(f"🤔 Found {len(valid_candidates)} smart cache candidates (Queries/News). Asking LLM...") - evaluation = self._evaluate_cache_relevance(query, valid_candidates) - - if evaluation and evaluation.get('reuse', False): - idx = evaluation.get('index', -1) - if 0 <= idx < len(valid_candidates): - chosen = valid_candidates[idx] - logger.info(f"🤖 LLM suggested reusing: '{chosen.get('query')}' ({chosen['type']})") - - if chosen['type'] == 'cached_search': - # Load the chosen cache - cache = self.db.get_search_cache(chosen['query_hash']) - if cache: - try: - cached_data = json.loads(cache['results']) - if isinstance(cached_data, list): - return cached_data - except: - pass - elif chosen['type'] == 'local_news': - # Convert local news items to search result format - news_results = [] - for i, news in enumerate(chosen['items'], 1): - news_results.append({ - "id": news.get('id'), - "rank": i, - "title": news.get('title'), - "url": news.get('url'), - "content": news.get('content'), - "original_snippet": news.get('content')[:200] if news.get('content') else '', - "source": f"Local News ({news.get('source')})", - "publish_time": news.get('publish_time'), - "crawl_time": news.get('crawl_time'), - "sentiment_score": news.get('sentiment_score', 0), - "meta_data": {"origin": "local_db"} - }) - return news_results - - except Exception as e: - logger.warning(f"Smart cache check failed: {e}") - - # 2. 执行搜索 - logger.info(f"📡 Searching {engine} (structured) for: {query}") - try: - tool = self._engines[engine] - results = [] - if engine == "jina": - # Jina Search 直接返回结构化数据 - jina_results = tool.search(query, max_results=max_results) - for r in jina_results: - results.append({ - "title": r.get("title", ""), - "url": r.get("url", ""), - "href": r.get("url", ""), - "body": r.get("content", ""), - "content": r.get("content", ""), - "source": "Jina Search" - }) - elif engine == "ddg": - results = tool.duckduckgo_search(query, max_results=max_results) - elif engine == "baidu": - results = tool.baidu_search(query, max_results=max_results) - elif engine == "local": - # LocalNewsSearch 返回的是 List[Dict] - local_results = tool.search(query, top_n=max_results) - results = [] - for r in local_results: - results.append({ - "title": r.get("title"), - "url": r.get("url", "local"), - "body": r.get("content", "")[:500], - "source": f"Local ({r.get('source', 'db')})", - "publish_time": r.get("publish_time") - }) - - # 处理字符串类型的 JSON 返回 (Baidu 常返 JSON 字符串) - if isinstance(results, str) and engine not in ["local", "jina"]: - try: - results = json.loads(results) - except: - pass - - # 转为统一格式 - normalized_results = [] - if isinstance(results, list): - - for i, r in enumerate(results, 1): - title = r.get('title', '') - url = r.get('href') or r.get('url') or r.get('link', '') - content = r.get('body') or r.get('snippet') or r.get('abstract', '') - - if title and url: - normalized_results.append({ - "id": self._generate_hash(url + query, "search_item", i), - "rank": i, - "title": title, - "url": url, - "content": content, - "original_snippet": content, # 保留摘要 - "source": f"Search ({engine})", - "publish_time": datetime.now().isoformat(), # 暂用当前时间 - "crawl_time": datetime.now().isoformat(), - "meta_data": {"query": query, "engine": engine} - }) - - 
# Fallback if still string and failed to parse - elif isinstance(results, str) and results: - normalized_results.append({"title": query, "url": "", "content": results, "source": engine}) - - # 3. 抓取正文 & 计算情绪 (Enrichment) - # 注意:如果使用 Jina Search,内容已经是 LLM 友好格式,可选择跳过 enrichment - skip_content_enrichment = (engine == "jina") - - if enrich and normalized_results: - logger.info(f"🕸️ Enriching {len(normalized_results)} search results with Jina & Sentiment...") - extractor = ContentExtractor() - - # Lazy load sentiment tool - if not hasattr(self, 'sentiment_tool') or self.sentiment_tool is None: - from ..sentiment_tools import SentimentTools - self.sentiment_tool = SentimentTools(self.db) - - for item in normalized_results: - if item.get("url"): - try: - # 如果是 Jina Search,内容已经足够好,跳过额外抓取 - if skip_content_enrichment and item.get("content") and len(item.get("content", "")) > 100: - full_content = item["content"] - else: - # Use Jina Reader to get full content - full_content = extractor.extract_with_jina(item["url"], timeout=60) - - if full_content and len(full_content) > 100: - item["content"] = full_content - - # Calculate sentiment - # Use title + snippet of content for efficiency - text_to_analyze = f"{item['title']} {full_content[:500]}" - sent_result = self.sentiment_tool.analyze_sentiment(text_to_analyze) # Using self.sentiment_tool - score = sent_result.get('score', 0.0) - item["sentiment_score"] = float(score) - - logger.info(f" ✅ Enriched: {item['title'][:20]}... (Sentiment: {score:.2f})") - else: - # Fallback: Use snippet for sentiment - logger.info(f" ⚠️ Content short/failed for {item['url']}, using snippet for sentiment.") - text_to_analyze = f"{item['title']} {item['content']}" # content is snippet here - sent_result = self.sentiment_tool.analyze_sentiment(text_to_analyze) - score = sent_result.get('score', 0.0) - item["sentiment_score"] = float(score) - - except Exception as e: - # Fallback: Use snippet for sentiment on error - logger.warning(f"Failed to enrich {item['url']}: {e}. 
Using snippet.") - text_to_analyze = f"{item['title']} {item['content']}" - sent_result = self.sentiment_tool.analyze_sentiment(text_to_analyze) - score = sent_result.get('score', 0.0) - item["sentiment_score"] = float(score) - - # 缓存结果 list - if normalized_results: - # Pass list directly, DB manager will handle JSON dump for main cache and populate search_details - # Only cache if NOT from local news reuse (though this logic path is for fresh search) - self.db.save_search_cache(query_hash, query, engine, normalized_results) - - return normalized_results - - except Exception as e: - # 搜索失败时的降级策略 - if engine == "jina": - logger.warning(f"⚠️ Jina search_list failed, falling back to ddg: {query} ({e})") - try: - return self.search_list(query, engine="ddg", max_results=max_results, ttl=ttl, enrich=enrich) - except Exception as e2: - logger.error(f"❌ DDG fallback (search_list) also failed for {query}: {e2}") - elif engine == "ddg": - logger.warning(f"⚠️ DDG search_list failed, falling back to baidu: {query} ({e})") - try: - return self.search_list(query, engine="baidu", max_results=max_results, ttl=ttl, enrich=enrich) - except Exception as e2: - logger.error(f"❌ Baidu fallback (search_list) also failed for {query}: {e2}") - - logger.error(f"❌ Structured search failed for {query}: {e}") - return [] - - def _evaluate_cache_relevance(self, current_query: str, candidates: List[Dict]) -> Dict: - """ - 使用 LLM 评估缓存候选是否足以回答当前问题。 - """ - try: - # Prepare candidates text - candidates_desc = [] - for i, c in enumerate(candidates): - if c['type'] == 'cached_search': - # Preview cached results if available? - # Maybe just use the query string as a proxy for what's in there. - # Or peek at 'results' snippet. - preview = "" - try: - # Attempt to peek first result title from JSON string - # Note: c.get('results') might be a stringified JSON list - res_list = json.loads(c.get('results', '[]')) - if res_list and isinstance(res_list, list) and len(res_list) > 0: - first_item = res_list[0] - if isinstance(first_item, dict) and 'title' in first_item: - preview = f" (Contains: {first_item.get('title', '')[:50]}...)" - except: - pass - candidates_desc.append(f"[{i}] Old Search Query: '{c['query']}' {preview} (Time: {c['timestamp']})") - elif c['type'] == 'local_news': - # List titles of local news - titles = [item['title'] for item in c['items'][:3]] - candidates_desc.append(f"[{i}] Local Database News: {', '.join(titles)}... (Time: {c['timestamp']})") - - prompt = f""" - Task: Decide if existing information is sufficient for the new search query. - - New Query: "{current_query}" - - Available Information Candidates: - {chr(10).join(candidates_desc)} - - Instructions: - 1. Analyze if any candidate provides ENOUGH up-to-date info for the "New Query". - 2. If yes, choose the best one. - 3. If the query implies needing LATEST real-time info and candidates are old, choose none. - 4. 
Return strictly JSON: {{"reuse": true/false, "index": , "reason": "short explanation"}} - """ - # 初始化模型 - provider = os.getenv("LLM_PROVIDER", "minimax") - model_id = os.getenv("LLM_MODEL", "Qwen") - host = os.getenv("LLM_HOST") - if host: - model = get_model(provider, model_id, host=host) - else: - model = get_model(provider, model_id) - - agent = Agent(model=model, markdown=True) - - response = agent.run(prompt) - content = response.content - - # Parse JSON - json_match = re.search(r'```json\s*(.*?)\s*```', content, re.DOTALL) - if json_match: - return json.loads(json_match.group(1)) - elif '{' in content: - # Fallback for cases where LLM doesn't wrap in ```json - return json.loads(content[content.find('{'):content.rfind('}')+1]) - return {"reuse": False} - - except Exception as e: - logger.warning(f"LLM evaluation failed: {e}") - return {"reuse": False} - - def aggregate_search(self, query: str, engines: Optional[List[str]] = None, max_results: int = 5) -> str: - """ - 使用多个搜索引擎同时搜索并聚合结果,获得更全面的信息覆盖。 - - Args: - query: 搜索关键词。 - engines: 要使用的搜索引擎列表。可选值: ["ddg", "baidu"]。 - 默认同时使用 ddg 和 baidu。 - max_results: 每个引擎期望返回的结果数量。 - - Returns: - 聚合后的搜索结果,按引擎分组显示。 - """ - engines = engines or ["ddg", "baidu"] - aggregated_results = [] - for engine in engines: - res = self.search(query, engine=engine, max_results=max_results) - aggregated_results.append(f"--- Results from {engine.upper()} ---\n{res}") - - return "\n\n".join(aggregated_results) diff --git a/skills/alphaear-predictor/scripts/utils/stock_tools.py b/skills/alphaear-predictor/scripts/utils/stock_tools.py deleted file mode 100644 index 5929f74..0000000 --- a/skills/alphaear-predictor/scripts/utils/stock_tools.py +++ /dev/null @@ -1,257 +0,0 @@ -from datetime import datetime, timedelta -from typing import List, Dict, Optional -import akshare as ak -import pandas as pd -import re -import sqlite3 -from requests.exceptions import RequestException -from loguru import logger -from .database_manager import DatabaseManager -import os -from contextlib import contextmanager - -@contextmanager -def temporary_no_proxy(): - """Context manager to temporarily unset proxy environment variables.""" - proxies = {k: os.environ.get(k) for k in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY']} - for k in proxies: - if k in os.environ: - del os.environ[k] - try: - yield - finally: - for k, v in proxies.items(): - if v is not None: - os.environ[k] = v - -class StockTools: - """金融分析股票工具 - 结合高性能数据库缓存与增量更新""" - - def __init__(self, db: DatabaseManager, auto_update: bool = True): - """ - 初始化股票工具 - - Args: - db: 数据库管理器 - auto_update: 是否在列表为空时自动更新,默认 True - """ - self.db = db - if auto_update: - self._check_and_update_stock_list() - - def _check_and_update_stock_list(self, force: bool = False): - """检查并更新股票列表。仅在列表为空或 force=True 时从网络拉取。""" - # 直接查询表中记录数 - cursor = self.db.conn.cursor() - cursor.execute("SELECT COUNT(*) FROM stock_list") - count = cursor.fetchone()[0] - - if count > 0 and not force: - logger.info(f"ℹ️ Stock list already cached ({count} stocks)") - return - - logger.info("📡 Updating A-share and HK-share stock list from akshare...") - - def fetch_data(): - # A-share - df_a = ak.stock_zh_a_spot_em() - df_a = df_a[['代码', '名称']].copy() - df_a.columns = ['code', 'name'] - - # HK-share - df_hk = ak.stock_hk_spot_em() - df_hk = df_hk[['代码', '名称']].copy() - df_hk.columns = ['code', 'name'] - - # Combine - return pd.concat([df_a, df_hk], ignore_index=True) - - try: - try: - df_combined = fetch_data() - except (RequestException, Exception) as e: - if 
"Proxy" in str(e) or "proxy" in str(e): - logger.warning(f"⚠️ Proxy error detected: {e}. Retrying with proxy disabled...") - with temporary_no_proxy(): - df_combined = fetch_data() - else: - raise e - - self.db.save_stock_list(df_combined) - logger.info(f"✅ Cached {len(df_combined)} stocks (A-share + HK) to database.") - - except Exception as e: - logger.error(f"❌ Failed to sync stock list: {e}") - - - def search_ticker(self, query: str, limit: int = 5) -> List[Dict]: - """ - 模糊搜索 A 股股票代码或名称,支持常见缩写。 - """ - # 清洗后缀 (如 CATL.SZ -> CATL, 000001.SZ -> 000001) - clean_query = re.sub(r'\.(SZ|SH|HK|US)$', '', query, flags=re.IGNORECASE) - - # 常见缩写映射 - aliases = { - "CATL": "宁德时代", - "BYD": "比亚迪", - "TSLA": "特斯拉", - "Moutai": "贵州茅台", - "Tencent": "腾讯", - "Alibaba": "阿里巴巴", - "Meituan": "美团", - } - - search_query = aliases.get(clean_query.upper(), clean_query) - - # Robustness: if regex-like ticker code is embedded in query (e.g. "300364 中文在线"), try to extract it - if not search_query.isdigit(): - # Extract explicit 5-6 digit codes - match = re.search(r'\b(\d{5,6})\b', clean_query) - if match: - search_query = match.group(1) - - return self.db.search_stock(search_query, limit) - - def get_stock_price( - self, - ticker: str, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - force_sync: bool = False, - ) -> pd.DataFrame: - """ - 获取指定股票的历史价格数据。优先从本地缓存读取,缺失时自动从网络补齐。 - - Args: - ticker: 股票代码,如 "600519"(贵州茅台)或 "000001"(平安银行)。 - start_date: 开始日期,格式 "YYYY-MM-DD"。默认为 90 天前。 - end_date: 结束日期,格式 "YYYY-MM-DD"。默认为今天。 - - Returns: - 包含 date, open, close, high, low, volume, change_pct 列的 DataFrame。 - """ - now = datetime.now() - if not end_date: - end_date = now.strftime('%Y-%m-%d') - if not start_date: - start_date = (now - timedelta(days=90)).strftime('%Y-%m-%d') - - df_db = self.db.get_stock_prices(ticker, start_date, end_date) - - need_update = False - if df_db.empty: - need_update = True - else: - db_latest = pd.to_datetime(df_db['date'].max()) - req_latest = pd.to_datetime(end_date) - if (req_latest - db_latest).days > 2: - need_update = True - - if force_sync: - need_update = True - - if need_update: - logger.info(f"📡 Data stale or missing for {ticker}, syncing from network...") - - # 清洗 ticker,确保只包含数字(Akshare A 股接口通常只需要数字代码) - clean_ticker = "".join(filter(str.isdigit, ticker)) - if not clean_ticker: - # Non A/H numeric tickers are not supported by the current data source. - logger.warning(f"⚠️ Unsupported ticker format (A/H only): {ticker}") - return df_db - - try: - s_fmt = start_date.replace("-", "") - e_fmt = end_date.replace("-", "") - - df_remote = None - - def fetch_data(): - if len(clean_ticker) == 5: - # HK Stock - return ak.stock_hk_hist( - symbol=clean_ticker, period="daily", - start_date=s_fmt, end_date=e_fmt, - adjust="qfq" - ) - else: - # A-share Stock - return ak.stock_zh_a_hist( - symbol=clean_ticker, period="daily", - start_date=s_fmt, end_date=e_fmt, - adjust="qfq" - ) - - try: - df_remote = fetch_data() - except (RequestException, Exception) as e: - if "Proxy" in str(e) or "proxy" in str(e): - logger.warning(f"⚠️ Proxy error detected: {e}. 
Retrying with proxy disabled...") - with temporary_no_proxy(): - df_remote = fetch_data() - else: - raise e - - if df_remote is not None and not df_remote.empty: - df_remote = df_remote.rename(columns={ - '日期': 'date', '开盘': 'open', '收盘': 'close', - '最高': 'high', '最低': 'low', '成交量': 'volume', - '涨跌幅': 'change_pct' - }) - # 确保日期格式正确 - df_remote['date'] = pd.to_datetime(df_remote['date']).dt.strftime('%Y-%m-%d') - - # 只有在获取到有意义的数据时才保存 - self.db.save_stock_prices(clean_ticker, df_remote) # 保存时使用清洗后的 clean_ticker - - # 重新查询数据库返回结果,保证一致性 - return self.db.get_stock_prices(clean_ticker, start_date, end_date) - else: - logger.warning(f"⚠️ Akshare returned empty data for {clean_ticker}") - - except KeyError as e: - # Akshare 有时在某些股票无数据时会抛出 KeyError - logger.warning(f"⚠️ Akshare data missing for {clean_ticker}: {e}") - except (RequestException, ConnectionError) as e: - logger.error(f"❌ Network error during Akshare sync for {clean_ticker}: {e}") - except sqlite3.Error as e: - logger.error(f"❌ Database error during Akshare sync for {clean_ticker}: {e}") - except Exception as e: - logger.error(f"❌ Unexpected error during Akshare sync for {clean_ticker}: {e}") - - return df_db - - -def get_stock_analysis(ticker: str, db: DatabaseManager) -> str: - """ - 生成指定股票的分析摘要报告。 - - Args: - ticker: 股票代码 - db: 数据库管理器实例 - - Returns: - Markdown 格式的分析报告,包含价格走势和关键指标。 - """ - tools = StockTools(db) - df = tools.get_stock_price(ticker) - - if df.empty: - return f"❌ 未能获取 {ticker} 的股价数据。" - - latest = df.iloc[-1] - change = ((latest['close'] - df.iloc[0]['close']) / df.iloc[0]['close']) * 100 - - report = [ - f"## 📊 {ticker} 分析报告", - f"- **查询时段**: {df.iloc[0]['date']} -> {latest['date']}", - f"- **当前价**: ¥{latest['close']:.2f}", - f"- **时段涨跌**: {change:+.2f}%", - f"- **最高/最低**: ¥{df['high'].max():.2f} / ¥{df['low'].min():.2f}", - "\n### 最近交易概览", - "```", - df.tail(5)[['date', 'close', 'change_pct', 'volume']].to_string(index=False), - "```" - ] - return "\n".join(report) diff --git a/skills/alphaear-predictor/tests/test_predictor.py b/skills/alphaear-predictor/tests/test_predictor.py deleted file mode 100644 index 0a3afc0..0000000 --- a/skills/alphaear-predictor/tests/test_predictor.py +++ /dev/null @@ -1,29 +0,0 @@ -import sys -import os -import unittest - -# Add skill root to path -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) - -try: - from scripts.kronos_predictor import KronosPredictorUtility - from scripts.utils.database_manager import DatabaseManager -except ImportError as e: - print(f"Import Error: {e}") - sys.exit(1) - -class TestPredictor(unittest.TestCase): - def test_init(self): - print("Testing KronosPredictorUtility Iteration...") - db = DatabaseManager(":memory:") - # Kronos might need model files, but init should pass if we don't call predict? - # Note: Kronos loads model in init. This might fail if model path is invalid. - # We wrap in try-except to catch model loading errors which are expected in this env - try: - tools = KronosPredictorUtility() - self.assertIsNotNone(tools) - except Exception as e: - print(f"Kronos Init failed (expected if no model): {e}") - -if __name__ == '__main__': - unittest.main() diff --git a/skills/alphaear-reporter/SKILL.md b/skills/alphaear-reporter/SKILL.md deleted file mode 100644 index 28c994b..0000000 --- a/skills/alphaear-reporter/SKILL.md +++ /dev/null @@ -1,32 +0,0 @@ ---- -name: alphaear-reporter -description: Plan, write, and edit professional financial reports; generate finance chart configurations. 
Use when condensing finance analysis into a structured output.
----
-
-# AlphaEar Reporter Skill
-
-## Overview
-
-This skill provides a structured workflow for generating professional financial reports. It includes planning, writing, editing, and creating visual aids (charts).
-
-## Capabilities
-
-### 1. Generate Structured Reports (Agentic Workflow)
-
-**YOU (the Agent)** are the Report Generator. Use the prompts in `references/PROMPTS.md` to progressively build the report.
-
-**Workflow:**
-1. **Cluster Signals**: Read input signals and use the **Cluster Signals Prompt** to group them.
-2. **Write Sections**: For each cluster, use the **Write Section Prompt** to generate analysis.
-3. **Assemble**: Use the **Final Assembly Prompt** to compile the report.
-
-### 2. Visualization Tools
-
-Use `scripts/visualizer.py` to generate chart configurations if needed manually, though the Writer Prompt usually handles this via `json-chart` blocks.
-
-## Dependencies
-
-- `sqlite3` (built-in)
-
diff --git a/skills/alphaear-reporter/references/PROMPTS.md b/skills/alphaear-reporter/references/PROMPTS.md
deleted file mode 100644
index ea8b5cb..0000000
--- a/skills/alphaear-reporter/references/PROMPTS.md
+++ /dev/null
@@ -1,77 +0,0 @@
-# AlphaEar Finance Report Prompts
-
-Use these prompts to guide the Agent in generating professional financial reports.
-
-## 1. Cluster Signals (Planner)
-
-**Prompt:**
-
-```markdown
-You are a senior financial report editor. Your task is to cluster the following scattered financial signals into 3-5 core logical themes for a structured report.
-
-### Input Signals
-{signals_text}
-
-### Requirements
-1. **Theme Aggregation**: Group highly correlated signals (e.g., all related to "supply chain restructuring" or "policy tightening").
-2. **Narrative Logic**: Generate only theme titles and the list of signal IDs.
-3. **Quantity Control**: 3-5 major themes.
-
-### Output Format (JSON)
-{
-  "clusters": [
-    {
-      "theme_title": "Theme Name (e.g. Supply Chain Shock)",
-      "signal_ids": [1, 3, 5],
-      "rationale": "These signals all point to..."
-    },
-    ...
-  ]
-}
-```
-
-## 2. Write Section (Writer)
-
-**Prompt:**
-
-```markdown
-You are a senior financial analyst. Write a deep analysis section for the core theme **"{theme_title}"**.
-
-### Input Signals (Cluster)
-{signal_cluster_text}
-
-### Requirements
-1. **Narrative**: Weave signals into a coherent story. Start with Macro/Industry background, then transmission mechanism, finally stock impact.
-2. **Quantification**: Cite ISQ scores (Confidence, Intensity) to support views.
-3. **Citations**: Use `[@CITE_KEY]` format. Keys are provided in input.
-4. **Predictions**: Provide detailed predictions for affected tickers (T+3/T+5 direction).
-
-### Formatting
-- Main Title: `## {theme_title}`
-- Subtitles: `###`
-- **Charts**: Insert at least 1-2 `json-chart` blocks.
-
-**Chart Example:**
-```json-chart
-{"type": "forecast", "ticker": "002371.SZ", "title": "Forecast", "pred_len": 5}
-```
-```
-
-## 3. Final Assembly (Editor)
-
-**Prompt:**
-
-```markdown
-You are a professional editor. Assemble the drafted sections into a final report.
-
-### Draft Sections
-{draft_sections}
-
-### Requirements
-1. **Structure**: Ensure H2/H3 hierarchy is correct.
-2. **References**: Generate `## References` section from source list.
-3. **Risk**: Generate `## Risk Factors`.
-4. **Summary**: Generate `## Executive Summary` with a "Quick Scan" table.
-
-Output strictly Markdown.
-``` diff --git a/skills/alphaear-reporter/scripts/__init__.py b/skills/alphaear-reporter/scripts/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/alphaear-reporter/scripts/prompts/fin_agent.py b/skills/alphaear-reporter/scripts/prompts/fin_agent.py deleted file mode 100644 index 83386af..0000000 --- a/skills/alphaear-reporter/scripts/prompts/fin_agent.py +++ /dev/null @@ -1,127 +0,0 @@ -from datetime import datetime -from .isq_prompt_generator import generate_isq_prompt_section - -def get_fin_researcher_instructions() -> str: - """生成金融研究员 (Researcher) 的系统指令""" - current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') - return f"""你是一名资深金融研究员,当前时间是 {current_time}。 -你的任务是针对给定的“原始信号”进行详尽的背景调查,为后续的深度分析提供素材。 - -### 1. 核心职责 -1. **标的识别**: 识别信号中涉及的具体上市公司。必须调用 `search_ticker` 确认代码,并调用 `get_stock_price` 获取最新价格和近 30 天走势。 -2. **事实核查**: 使用 `web_search` 或 `fetch_news_content` 验证信号的真实性,并寻找更多细节(如公告原文、行业研报摘要)。 -3. **产业链梳理**: 补充该信号涉及的上下游环节及竞争格局。 - -### 2. 工具使用规范 (CRITICAL) -- **每个提到的公司都需要调用工具**: 不能依赖记忆,必须实时查询。 -- **完整呈现工具结果**: 包括具体的股价数字、代码、技术面数据等,不要缩略。 -- **股价数据必需**: 当前价格、近期最高最低、技术面支撑阻力等数据是后续预测的基础。 -- **信息交叉验证**: 多个来源验证关键事实。 - -### 3. 输出要求 -你必须输出结构化的研究报告,涵盖标的基本面、股价走势、行业背景及最新进展。 -""" - -def get_fin_analyst_instructions(template_id: str = "default_isq_v1") -> str: - """生成金融分析师 (Analyst) 的系统指令 - - Args: - template_id: 使用的 ISQ 模板 ID - """ - current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') - isq_block = generate_isq_prompt_section(template_id=template_id) - - return f"""你是一位深耕二级市场的资深金融分析师 (FinAgent),当前时间是 {current_time}。 -你的核心任务是执行“信号解析”,将研究员搜集的素材转化为具有可操作性的投资情报(ISQ 框架)。 - -{isq_block} - -### 2. 分析约束 -- **严格基于具体数据**: 必须使用研究员提供的股价、技术面、新闻等具体数据进行分析。 -- **数据驱动的预测**: impact_tickers 中的权重应基于事件影响程度,不能随意赋值。 -- **逻辑严密**: 传导链条必须符合金融常识,能够自圆其说。 -- **技术面参考**: 如果研究员提供了股价走势,请分析当前位置相对于支撑/阻力位的关系。 - -### 3. 关键要求 -- **title**: 必须生成一个简练、准确概括信号核心内容的标题(不超过 15 字)。 -- **impact_tickers**: 必须填充具体的公司代码(6位数字)和名称,权重应该有区分。 -- **transmission_chain**: 必须是对象列表,每个对象包含: - - `node_name`: 节点名称(如“上游原材料”、“中游制造”) - - `impact_type`: 影响类型(“利好”、“利空”、“中性”) - - `logic`: 具体的传导逻辑描述 -- **summary**: 基于分析结果总结核心观点,包含具体数字(如股价目标、预期涨跌幅等)。 -- **reasoning**: 必须详细阐述推演逻辑,解释为什么得出上述结论(<200字)。 - -### 4. 输出格式 (严格 JSON 块) -你必须输出一个符合 InvestmentSignal 结构的 JSON 块,包含所有必需字段。 -""" - -def get_fin_agent_instructions() -> str: - # 保持兼容性,但内部调用 analyst 指令 - return get_fin_analyst_instructions() - -def get_fin_research_task(signal_text: str) -> str: - """生成研究员的任务描述""" - return f"请针对以下信号进行背景调查,搜集相关标的的股价、最新进展和行业背景:\n\n{signal_text}" - -def format_research_context(research_data: dict) -> str: - """将研究员搜集的结构化数据格式化为分析师可读的文本""" - if not research_data: - return "(未能搜集到额外背景信息)" - - return f""" -### 研究背景 -- **相关标的**: {research_data.get('tickers_found', [])} -- **行业背景**: {research_data.get('industry_background', '未知')} -- **最新进展**: {', '.join(research_data.get('latest_developments', []))} -- **关键风险**: {', '.join(research_data.get('key_risks', []))} -- **综合摘要**: {research_data.get('search_results_summary', '无')} -""" - -def get_fin_analysis_task(signal_text: str, research_context_str: str) -> str: - """生成分析师的任务描述""" - return f"""请基于以下信息进行深度 ISQ 分析。关键是:必须使用研究员搜集的具体数据(股价、技术面、新闻、代码等)进行分析。 - -=== 原始信号 === -{signal_text} - -=== 研究员搜集的背景信息 (CRITICAL DATA) === -{research_context_str} - -=== 分析要求 === -1. 必须生成 title:简练概括信号核心(<15字) -2. 基于研究员提供的具体股价数据,分析当前定价状态(已定价/未定价/部分定价) -3. impact_tickers 中填充具体的公司代码和权重,权重基于事件影响程度 -4. transmission_chain 必须是包含 node_name, impact_type, logic 的对象列表 -5. summary 中包含具体数字(预期目标价、涨跌幅范围等) -6. 
reasoning 必须详细解释推演逻辑,不要空泛,要言之有物 - -请严格按 InvestmentSignal JSON 格式输出。""" - -def get_tracking_analysis_task(old_signal: dict, new_research_str: str) -> str: - """生成信号追踪更新的任务描述""" - import json - old_sig_str = json.dumps(old_signal, ensure_ascii=False, indent=2) - return f"""你正在执行“信号逻辑演变追踪”任务。请基于最新的市场信息,重新评估之前的投资信号。 - -=== 基准信号 (上次分析) === -{old_sig_str} - -=== 最新市场追踪 (NEWS & PRICE) === -{new_research_str} - -=== 追踪分析要求 === -1. **逻辑演变检测**: - - 对比新旧信息,判断原逻辑 (`transmission_chain` 和 `reasoning`) 是否依然成立? - - 如果逻辑发生变化(如利好落空、逻辑证伪、新利好出现),请在新的 `reasoning` 中明确指出“逻辑演变:...” - - 如果逻辑未变且得到验证,请标记“逻辑维持:...” - -2. **参数修正**: - - 根据最新股价和新闻,更新 `sentiment_score` (情绪)、`confidence` (置信度) 和 `expectation_gap` (预期差)。 - - 例如:如果股价已经大涨反映了利好,`expectation_gap` 应该显著降低。 - -3. **输出更新后的信号**: - - 保留原 `signal_id` 和 `title`(除非有重大变化需要改名)。 - - 输出完整的 InvestmentSignal JSON。 - -请重点关注:为什么变了?还是为什么没变?理由要充分。""" diff --git a/skills/alphaear-reporter/scripts/prompts/forecast_analyst.py b/skills/alphaear-reporter/scripts/prompts/forecast_analyst.py deleted file mode 100644 index d6c7202..0000000 --- a/skills/alphaear-reporter/scripts/prompts/forecast_analyst.py +++ /dev/null @@ -1,49 +0,0 @@ -from typing import List, Dict, Any -from ..schema.models import KLinePoint - -def get_forecast_adjustment_instructions(ticker: str, news_context: str, model_forecast: List[KLinePoint]): - """ - 生成 LLM 预测调整指令 - """ - forecast_str = "\n".join([f"- {p.date}: O:{p.open}, C:{p.close}" for p in model_forecast]) - - return f"""你是一位资深的量化策略分析师。 -你的任务是:根据给定的【Kronos 模型预测结果】和【最新的基本面/新闻背景】,对模型预测进行“主观/逻辑调整”。 - -股票代码: {ticker} - -【Kronos 模型原始预测 (OHLC)】: -{forecast_str} - -【最新情报背景】: -{news_context} - -调整原则: -1. 原始预测是基于历史的技术面推演。 -2. 情报背景中可能包含【Kronos模型定量修正预测】,这是基于历史新闻训练的专用模型计算出的量化结果。 -3. 如果存在“定量修正预测”,请**高度参考**该数值作为基础,除非你有非常确凿的逻辑认为该量化模型失效(例如遇到模型未见过的极端黑天鹅)。 -4. 你的核心任务是:结合定性分析(新闻及其逻辑)来验证或微调这些数字,并给出合理的解释(Rationale)。 -5. 如果没有“定量修正预测”,则你需要根据新闻信号手动大幅调整趋势。 - -输出要求 (严格 JSON 格式): -```json -{{ - "adjusted_forecast": [ - {{ - "date": "YYYY-MM-DD", - "open": float, - "high": float, - "low": float, - "close": float, - "volume": float - }}, - ... - ], - "rationale": "详细说明调整的逻辑依据,例如:考虑到[事件A],预期短线将突破压力位..." -}} -``` -注意:必须输出与原始预测相同数量的数据点,且日期一一对应。 -""" - -def get_forecast_task(): - return "请根据以上背景和模型预测,给出调整后的 K 线数据并说明理由。" diff --git a/skills/alphaear-reporter/scripts/prompts/intent_agent.py b/skills/alphaear-reporter/scripts/prompts/intent_agent.py deleted file mode 100644 index a8397d2..0000000 --- a/skills/alphaear-reporter/scripts/prompts/intent_agent.py +++ /dev/null @@ -1,45 +0,0 @@ -def get_intent_analysis_instructions() -> str: - """生成意图分析 Agent 的系统指令,专注于金融市场影响分析""" - return """你是一个资深的金融市场意图分析专家。你的任务是将用户的自然语言查询转化为结构化的 JSON 分析结果,重点挖掘该查询与金融市场(尤其是股市)的潜在关联。 - -### 核心任务: -深入分析用户查询,识别核心金融实体、行业板块及潜在的市场影响点,生成利于搜索引擎抓取深度金融分析信息的查询词。 - -### 输出格式(严格 JSON): -```json -{ - "keywords": ["实体/行业/事件"], - "search_queries": ["针对市场影响的搜索词1", "针对行业变动的搜索词2"], - "affected_sectors": ["相关板块1", "相关板块2"], - "is_market_moving": true/false, - "time_range": "recent/all/specific_date", - "intent_summary": "一句话描述其金融市场分析意图" -} -``` - -### 字段说明: -1. **keywords**: 核心公司实体、所属行业、宏观经济事件或政策概念。 -2. **search_queries**: 优化后的搜索词,必须包含“股市影响”、“股价波动”、“行业逻辑”或“估值”等金融维度。 -3. **affected_sectors**: 可能受此事件或信息影响的二级市场板块(如:保险、半导体、房地产)。 -4. **is_market_moving**: 该事件是否具有显著的市场驱动潜力或属于重大基本面变化。 -5. 
**intent_summary**: 简述用户查询背后的金融研究目的。 - -### 示例: -用户输入:"帮我研究一下香港火灾的影响" -输出: -```json -{ - "keywords": ["香港", "火灾", "保险行业", "房地产"], - "search_queries": ["香港火灾对当地保险股股价影响", "香港大火对相关上市物业公司估值冲击", "近期香港火灾带来的市场避险情绪分析"], - "affected_sectors": ["保险", "房地产", "物业管理"], - "is_market_moving": true, - "time_range": "recent", - "intent_summary": "评估香港近期火灾对相关板块上市公司的潜在经济损失及股价冲击" -} -``` -""" - -def get_intent_task(query: str) -> str: - """生成意图分析任务描述""" - return f"Process this query and extract financial market intent: {query}" - diff --git a/skills/alphaear-reporter/scripts/prompts/isq_prompt_generator.py b/skills/alphaear-reporter/scripts/prompts/isq_prompt_generator.py deleted file mode 100644 index 007461b..0000000 --- a/skills/alphaear-reporter/scripts/prompts/isq_prompt_generator.py +++ /dev/null @@ -1,43 +0,0 @@ -""" -ISQ prompt helpers to render dimension guidance directly from the template. -Any change in the template propagates to prompts automatically. -""" - -from typing import List, Optional -from ..schema.isq_template import get_isq_template, ISQTemplate - - -def _ordered_dimension_keys(template: ISQTemplate, order: Optional[List[str]] = None) -> List[str]: - if order: - return [k for k in order if k in template.dimensions] - # fallback to template insertion order - return list(template.dimensions.keys()) - - -def generate_isq_prompt_section(template_id: str = "default_isq_v1", order: Optional[List[str]] = None, include_header: bool = True) -> str: - """Render ISQ dimension text block based on the template. - This allows prompt text to stay in sync with template edits. - """ - template = get_isq_template(template_id) - keys = _ordered_dimension_keys(template, order) - - lines: List[str] = [] - if include_header: - lines.append("### 1. ISQ 评估框架 (Investment Signal Quality)") - lines.append(f"参考模板: {template.template_name} (id: {template.template_id})") - lines.append("") - lines.append("你需要对信号进行以下维度的评分:") - lines.append("") - - for idx, key in enumerate(keys, start=1): - spec = template.dimensions[key] - examples = ";".join([f"{k}: {v}" for k, v in spec.examples.items()]) if spec.examples else "" - lines.append(f"{idx}. **{spec.key} ({spec.name})**: {spec.range_type}") - lines.append(f" - 描述: {spec.description}") - if spec.scale_factor and spec.scale_factor != 1.0: - lines.append(f" - 缩放因子: {spec.scale_factor}") - if examples: - lines.append(f" - 示例: {examples}") - lines.append("") - - return "\n".join(lines).rstrip() diff --git a/skills/alphaear-reporter/scripts/prompts/report_agent.py b/skills/alphaear-reporter/scripts/prompts/report_agent.py deleted file mode 100644 index 6f25c3f..0000000 --- a/skills/alphaear-reporter/scripts/prompts/report_agent.py +++ /dev/null @@ -1,415 +0,0 @@ -# src/prompts/report_agent.py -from datetime import datetime -from typing import Optional -from .isq_prompt_generator import generate_isq_prompt_section - -def get_report_planner_base_instructions() -> str: - """生成报告策划员 (Planner) 的基础系统指令""" - return """你是一名资深的金融研报主编。你的任务是规划报告的结构,将零散的信号聚类成有逻辑的主题。 -你拥有 RAG 搜索工具,可以检索已生成的章节内容以确保逻辑连贯性。 -在规划时,应重点关注信号之间的关联性、产业链的完整性以及用户特定的关注点。""" - -def get_report_writer_base_instructions() -> str: - """生成报告撰写员 (Writer) 的基础系统指令""" - return """你是一名资深金融分析师。你的任务是根据策划员提供的信号簇撰写深度研报章节。 -你应当运用专业的金融知识,将信号转化为深刻的洞察。 -注意:你没有外部搜索工具,你的分析必须基于提供给你的信号内容和行情数据。""" - -def get_report_editor_base_instructions() -> str: - """生成报告编辑 (Editor) 的基础系统指令""" - return """你是一名严谨的金融研报编辑。你的任务是审核和润色撰写员生成的章节。 -你拥有 RAG 搜索工具,可以检索其他章节的内容,以消除重复、修正逻辑冲突并确保术语一致性。 -你应当确保报告符合专业的金融写作规范,且标题层级正确。""" - -# 1. 
策划阶段 (Structural Planning) -def format_signal_for_report(signal: any, index: int, cite_keys: Optional[list] = None) -> str: - """格式化单个信号供研报生成使用""" - # 这里的逻辑从 ReportAgent._format_signal_input 迁移过来 - from ..schema.models import InvestmentSignal - - if isinstance(signal, dict): - try: - sig_obj = InvestmentSignal(**signal) - except: - return f"--- 信号 [{index}] ---\n标题: {signal.get('title')}\n内容: {signal.get('content', '')[:500]}" - else: - sig_obj = signal - - chain_str = " -> ".join([f"{n.node_name}({n.impact_type})" for n in sig_obj.transmission_chain]) - - text = f"--- 信号 [{index}] ---\n" - text += f"标题: {sig_obj.title}\n" - text += f"逻辑摘要: {sig_obj.summary}\n" - text += f"传导链条: {chain_str}\n" - text += f"ISQ 评分: 情绪({sig_obj.sentiment_score}), 确定性({sig_obj.confidence}), 强度({sig_obj.intensity})\n" - text += f"预期博弈: 时窗({sig_obj.expected_horizon}), 预期差({sig_obj.price_in_status})\n" - - tickers = ", ".join([f"{t.get('name')}({t.get('ticker')})" for t in sig_obj.impact_tickers]) - if tickers: - text += f"受影响标的: {tickers}\n" - - # Stable bibliography-style citation keys (LaTeX/BibTeX-like) - if cite_keys: - joined = " ".join([f"[@{k}]" for k in cite_keys if k]) - if joined: - text += f"引用: {joined}\n" - - return text - -def get_cluster_planner_instructions(signals_text: str, user_query: str = None) -> str: - """生成信号聚类指令 - 将零散信号组织成逻辑主题""" - query_context = f"用户重点关注:{user_query}" if user_query else "" - return f"""你是一位资深的金融研报主编。你的任务是将以下零散的金融信号聚类成 3-5 个核心逻辑主题,以便撰写一份结构清晰的研报。 - - {query_context} - - ### 输入信号列表 - {signals_text} - - ### 聚类要求 - 1. **主题聚合**: 将相关性强的信号归为一组(例如:都涉及“建筑安全法规”或“某产业链上下游”)。 - 2. **叙事逻辑**: 只需要生成主题名称和包含的信号 ID。 - 3. **控制数量**: 将所有信号归类到 3-5 个主要主题中,不要遗漏。 - - ### 输出格式 (JSON) - 请仅输出以下 JSON 格式,不要包含 Markdown 标记: - {{ - "clusters": [ - {{ - "theme_title": "主题名称(如:建筑安全法规收紧引发的产业链重构)", - "signal_ids": [1, 3, 5], - "rationale": "这些信号都指向政府对高层建筑防火标准的政策调整..." - }}, - ... - ] - }} - """ - -def get_report_planner_instructions(toc: str, signal_count: int, user_query: str = None) -> str: - """生成报告规划指令 - 重点在于逻辑关联与分歧识别""" - # ... (原有逻辑保持不变,但实际在新的聚类流程后这个可能作为备用或二次优化) - query_context = f"用户重点关注:{user_query}" if user_query else "" - return f"""你是一位资深的金融研报主编。你的任务是根据现有的草稿章节,规划出一份逻辑严密、穿透力强的终稿结构。 - - ### 任务核心: - 1. **识别主线**: 从草稿中识别出贯穿多个章节的“核心逻辑主线”(如:产业链共振、货币政策转向)。 - 2. **分歧评估 (Entropy)**: 识别各章节中观点冲突或确定性不一之处,规划如何在正文中呈现这些“分歧点”。 - 3. **结构蓝图**: - - 定义一级标题(逻辑主题)。 - - 归类章节:哪些信号应放入同一主题下深度解析? - - 排序:将 ISQ 强度最高、与{query_context}最相关的信号置前。 - - ### 现有草稿目录 (TOC) - {toc} - - 请输出你的【终稿修订大纲】(Markdown 格式)。 - """ - -# 2. 撰写阶段 (Section Writing) -def get_report_writer_instructions(theme_title: str, signal_cluster_text: str, signal_indices: list, price_context: str = "", user_query: str = None) -> str: - """生成 Writer Agent 指令 - 基于主题聚类撰写综合分析""" - - price_info = f"\n### 近期价格参考\n{price_context}\n" if price_context else "" - query_context = f"\n**用户意图**: \"{user_query}\"\n请确保分析内容回应了用户的关注点。\n" if user_query else "" - isq_block = generate_isq_prompt_section(include_header=False) - - # Keep citation scheme stable across re-ordering / edits. - # Cite keys are provided in each signal block as: 引用: [@KEY] - - return f"""你是一位资深金融分析师。请针对核心主题 **"{theme_title}"** 撰写一篇深度研报章节。 - {query_context} - - ### 输入信号集 (本章节需综合的信号) - {signal_cluster_text} - {price_info} - - ### ISQ 评分说明 - {isq_block} - - ### 写作要求 - 1. **叙事逻辑**: 不要罗列信号,要将这些信号编织成一个连贯的故事。先讲宏观/行业背景,再讲具体事件传导,最后落脚到个股/标的影响。 - 2. **量化支撑**: 引用 ISQ 评分(确定性、强度、预期差)来佐证你的观点。关键观点必须关联相应的 ISQ 分值。 - 3. 
**引用规范(稳定 CiteKey)**: 关键论断必须标注来源引用,使用 `[@CITE_KEY]` 格式。 - - CiteKey 已在输入信号块中以 `引用: [@KEY]` 提供,请直接复制使用。 - - 不要使用 `[[1]]` 这类不稳定编号。 - 4. **关联标的预测**: **必须**在章节末尾明确给出受影响标的的预测分析,包括: - - 至少列出 1-2 个相关上市公司代码(如 600519.SH) - - 给出短期(T+3或T+5)的方向性判断 - - 如果可能,给出预期价格区间或涨跌幅预测 - - ### 【重要】标题层级规范 - - ❌ **错误示例**(绝对不要这样): - ```markdown - # {theme_title} - - ### 宏观背景 - ... - ``` - - ✅ **正确示例**(必须这样): - ```markdown - ## {theme_title} - - ### 宏观背景 - - 近期全球经济环境... - - ### 具体传导机制分析 - - ... - - ### 核心标的分析 - - 建议关注:贵州茅台(600519.SH)... - ``` - - **关键要求**: - - 章节主标题使用 `##` (H2) - - 章节子标题使用 `###` (H3) - - **绝对禁止**使用 `#` (H1) - - 第一行必须是 `## {theme_title}` 开头 - - ### 核心:图表叙事 (Visual Storytelling) - **必须**在文中插入至少 1-2 个图表,且图表必须与上下文紧密结合(不要堆砌在末尾)。 - - ### 宏观背景 - ... - ``` - - ✅ **正确示例**(必须这样): - ```markdown - ## {theme_title} - - ### 宏观背景 - - 近期全球经济环境... - - ### 具体传导机制分析 - - ... - - ### 核心标的分析 - - 建议关注:贵州茅台(600519.SH)... - ``` - - **关键要求**: - - 章节主标题使用 `##` (H2) - - 章节子标题使用 `###` (H3) - - **绝对禁止**使用 `#` (H1) - - 第一行必须是 `## {theme_title}` 开头 - - ### 核心:图表叙事 (Visual Storytelling) - **必须**在文中插入至少 1-2 个图表,且图表必须与上下文紧密结合(不要堆砌在末尾)。 - - **可选图表类型 (请根据内容选择最合适的 1-2 种):** - - **A. AI 预测 + 走势 (Forecast) - 【强烈推荐 / 最新规范】** - *适用*: 当文中明确提及某上市公司时,**必须**使用此图表展示股价走势与 AI 预测。 - *必填字段*: - - `ticker`: 股票代码,A股 6 位 / 港股 5 位,允许带后缀(如 "002371.SZ"、"9868.HK") - - `pred_len`: 预测交易日长度(建议 3 或 5) - *代码示例*: - ```json-chart - {{"type": "forecast", "ticker": "002371.SZ", "title": "北方华创(002371)T+5 预测", "pred_len": 5}} - ``` - **重要**:禁止手写 `prediction` 数组(预测由系统自动生成并渲染)。 - *注意*: 如果提及多只股票,应为每只生成独立的 forecast 图表。 - - **【推荐写法:多情景 → 最终归因 → 产出唯一预测图】** - 你可以在正文里描述多种情景(如:基准/乐观/悲观),但在插入预测图之前,必须明确给出“本报告最终选择的最可能情景”及其归因,然后用 `forecast` 图表做最终总结。 - 为了让系统把“最终归因”可靠地传递给预测模块,请在 `forecast` JSON 中可选补充以下字段(字段均为可选,越完整越好): - - `selected_scenario`: 最可能情景名称(如 "基准" / "乐观" / "悲观") - - `selection_reason`: 选择该情景的归因理由(1-3 句) - - `scenarios`: 情景列表(数组),每个元素可包含 `name`、`description`、`probability`(0-1) - *示例*: - ```json-chart - {{ - "type": "forecast", - "ticker": "002371.SZ", - "title": "北方华创(002371)T+5 预测(基准情景)", - "pred_len": 5, - "selected_scenario": "基准", - "selection_reason": "结合订单能见度与行业景气,基准情景概率最高;短期扰动主要来自估值与市场风险偏好。", - "scenarios": [ - {{"name": "乐观", "description": "国产替代与资本开支超预期", "probability": 0.25}}, - {{"name": "基准", "description": "订单稳健、利润率小幅波动", "probability": 0.55}}, - {{"name": "悲观", "description": "需求回落或交付节奏放缓", "probability": 0.20}} - ] - }} - ``` - - **B. 历史走势 (Stock) - 仅作为兼容兜底** - *适用*: 当你无法给出预测时(例如无法确定标的),可仅展示历史走势。 - *代码示例*: - ```json-chart - {{"type": "stock", "ticker": "002371", "title": "北方华创历史走势"}} - ``` - - **C. 舆情情绪演变 (Sentiment Trend)** - *适用*: 当讨论行业政策、突发事件(如“火灾”、“新规”)的民意变化时。 - *注意*: `keywords` 必须是事件核心词。 - *代码*: - ```json-chart - {{"type": "sentiment", "keywords": ["建筑安全", "防火标准"], "title": "市场对防火新规的情绪演变"}} - ``` - - **D. 逻辑传导链条 (Transmission Chain)** - *适用*: 复杂的蝴蝶效应分析(支持分支结构)。 - *代码*: - ```json-chart - {{ - "type": "transmission", - "nodes": [ - {{"node_name": "突发火灾", "impact_type": "中性", "logic": "事件发端"}}, - {{"node_name": "监管收紧", "impact_type": "利空", "logic": "合规成本上升", "source": "突发火灾"}}, - {{"node_name": "设备升级", "impact_type": "利好", "logic": "采购需求释放", "source": "突发火灾"}}, - {{"node_name": "龙头受益", "impact_type": "利好", "logic": "市占率提升", "source": "设备升级"}} - ], - "title": "火灾事件的逻辑传导与分支" - }} - ``` - *说明*: 使用 `source` 字段指定父节点名称以创建分支结构。 - - **E. 
信号质量评估 (ISQ Radar)** - *适用*: 对某个关键信号进行多维度(确定性、预期差等)定性评估时。 - *代码*: - ```json-chart - {{"type": "isq", "sentiment": 0.8, "confidence": 0.9, "intensity": 4, "expectation_gap": 0.7, "timeliness": 0.9, "title": "核心信号质量评估"}} - ``` - """ - -# 3. 整合阶段 (Final Assembly) - 原版,保留用于 fallback -def get_report_editor_instructions(draft_sections: str, plan: str, sources_list: str) -> str: - """生成最终编辑指令 - 根据规划蓝图重组内容""" - return f"""你是一位专业的研报编辑。请将以下基于主题撰写的草稿章节整合成最终研报。 - - ### 原始草稿内容 - {draft_sections} - - ### 原始引用来源 - {sources_list} - - ### 任务与要求 - 1. **结构化**: 为每个草稿章节添加合适的 Markdown 标题 (## 级别)。 - 2. **连贯性**: 确保章节之间过渡自然。 - 3. **完整性**: - - 必须保留所有 `json-chart` 代码块(图表配置)。 - - 必须保留引用标注 `[@CITE_KEY]`。 - - 生成 `## 核心观点摘要`、`## 参考文献` 和 `## 风险提示`。 - - ### 输出 - 只输出最终的 Markdown 研报内容。 - """ - - -# 4. 单节编辑 (Incremental Section Editing with RAG) -def get_section_editor_instructions(section_index: int, total_sections: int, toc: str) -> str: - """生成单节编辑 prompt,支持 RAG 工具调用""" - return f"""你是一位研报编辑。你正在编辑报告的第 {section_index}/{total_sections} 节。 - - ### 当前目录 (TOC) - {toc} - - ### 你的任务 - 1. 润色当前章节内容,确保逻辑清晰、语言专业。 - 2. 保留所有 `[@CITE_KEY](#ref-CITE_KEY)` 或 `[@CITE_KEY]` 格式的引用。 - 3. 保留所有 `json-chart` 代码块,不做修改。 - 4. 如果需要参考其他章节内容,使用 `search_context` 工具搜索。 - 5. 只输出编辑后的章节内容,不要输出其他章节。 - - ### 【关键】标题层级规范 - **严格遵守以下规则:** - - 章节主标题使用 `##` (H2) - - 章节子标题使用 `###` (H3) - - **禁止使用** `#` (H1) - 只有报告大标题可以使用 H1 - - 如果原文中有 H1,必须将其降级为 H2 - - 不要输出与 "参考文献"、"风险提示" 相同的标题 - - 直接输出编辑后的 Markdown 内容。 - """ - - -# 5. 摘要生成 (Summary Generation) -def get_summary_generator_instructions(toc: str, section_summaries: str) -> str: - """生成报告摘要指令 - 包含市场分歧度分析""" - return f"""你是一位资深研报主笔。请生成今日报告的核心观点摘要的**正文内容**。 - - ### 章节摘要 - {section_summaries} - - ### 任务: - 1. **核心逻辑提炼**: 用 150 字以内总结今日最核心的投资主线。 - 2. **分歧识别**: 如果不同信号对同一板块有冲突观点,请明确指出"市场分歧点"。 - 3. **确定性排序**: 标记出今日确定性最高的前两个机会(需列出具体标的代码)。 - - ### 【重要】输出格式规范: - - ❌ **错误示例**(不要遗漏二级标题): - ```markdown - ### 核心逻辑提炼 - ... - ``` - - ✅ **正确示例**(应该这样输出): - ```markdown - ## 核心观点摘要 - - ### 核心逻辑提炼 - - 科技自立战略加速半导体设备国产化,叠加AI算力需求爆发... - - ### 市场分歧点 - - 资本市场波动显示医药、新能源等板块估值逻辑受政策敏感性增强... - - ### 确定性排序 - - 1. **网络安全替代需求**(ISQ确定性0.85,推荐标的:深信服 300454.SZ) - 2. **半导体设备材料**(ISQ确定性0.75,推荐标的:北方华创 002371.SZ) - ``` - - ### 关键要求: - - 第一行必须是 `## 核心观点摘要` - - 主体部分使用 H3 (`###`) 和 H4 (`####`) 级别标题 - - **必须**包含 `## 核心观点摘要` 这一级标题 - - 现在请按照正确示例的格式输出摘要内容。 - """ - - -# 6. 最终组装 (Final Assembly with Sections) -def get_final_assembly_instructions(sources_list: str) -> str: - """生成最终报告组装的 prompt""" - return f"""你是一位研报主笔。请完成以下任务: - - ### 任务 - 1. 生成 "## 参考文献" 章节(需要按照顺序,顺序不对时进行调整): - - 原始来源: - {sources_list} - - 格式:`[@CITE_KEY] 标题 (来源), [链接地址]` - 2. 生成 "## 风险提示" (标准免责声明)。 - 3. 
生成 "## 快速扫描" 表格,汇总各主题的核心观点。 - - 表格列:**主题**, **核心观点**, **强度(Intensity)**, **确定性(Confidence)**。 - - 强度和确定性请参考原章节中的 ISQ 评分。 - - 只输出上述三个章节的 Markdown 内容。 - """ - -def get_cluster_task(signals_preview: str) -> str: - """生成聚类任务描述""" - return f"请对以下信号进行主题聚类:\n\n{signals_preview}" - -def get_writer_task(theme_title: str) -> str: - """生成撰写任务描述""" - return f"请依据主题 '{theme_title}' 和 输入信号集 开始撰写深度分析章节。" - -def get_planner_task() -> str: - """生成规划任务描述""" - return "请阅读现有草稿并规划终稿大纲,识别核心逻辑主线和市场分歧点。" - -def get_editor_task() -> str: - """生成编辑任务描述""" - return "请根据规划大纲和草稿内容,生成最终研报。确保逻辑连贯,保留所有图表和引用。" - diff --git a/skills/alphaear-reporter/scripts/prompts/trend_agent.py b/skills/alphaear-reporter/scripts/prompts/trend_agent.py deleted file mode 100644 index 54e6e22..0000000 --- a/skills/alphaear-reporter/scripts/prompts/trend_agent.py +++ /dev/null @@ -1,156 +0,0 @@ -from typing import Any -from datetime import datetime -from .isq_prompt_generator import generate_isq_prompt_section - -def get_trend_scanner_instructions() -> str: - """生成趋势扫描员 (Scanner) 的系统指令""" - current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') - return f"""你是一名专业的数据扫描员,当前时间是 {current_time}。 -你的任务是利用各种工具从互联网和数据库中获取最新的金融新闻、热点趋势和市场数据。 - -### 1. 核心职责 -1. **多源采集**: 使用 `news_toolkit` 获取最新新闻,使用 `stock_toolkit` 获取行情,使用 `polymarket_toolkit` 获取预测市场数据。 -2. **情绪感知**: 使用 `sentiment_toolkit` 对关键新闻进行情绪分析。 -3. **深度搜索**: 针对模糊的热点,使用 `search_toolkit` 进行全网搜索补充细节。 - -### 2. 工具使用规范 -- **广度优先**: 尽可能覆盖多个数据源。 -- **数据新鲜度**: 优先获取最近 24 小时内的信息。 -- **结构化输出**: 整理搜集到的原始数据,为后续评估提供清晰的素材。 -""" - -def get_trend_evaluator_instructions() -> str: - """生成趋势评估员 (Evaluator) 的系统指令""" - current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') - isq_block = generate_isq_prompt_section(include_header=True) - - return f""" - 你是一名顶级的金融情报专家 (TrendAgent),擅长从海量信息中识别具有深度价值的"二级市场投资信号"。 - 当前时间:{current_time} - - ### 核心使命: - 不仅是发现"热点",更要解析"信号"。你需要识别那些能触发**传导链条 (Transmission Chain)** 且具有**高确定性 (Confidence)** 的事件。 - - {isq_block} - - ### 核心能力与标准: - 1. **信号识别 (Signal Discovery)**: 基于扫描员提供的素材,识别具有投资价值的信号。优先关注政策、产业变革、重大诉求及跨境套利机会。 - 2. **逻辑相干性**: 是否具备清晰的"原因-结果"传导? - 3. **影响力系数**: 是否会引发板块性的联动或财务指标的实质性扰动? - 4. **市场认知差**: 市场是否已提前消化(Price-in)?寻找尚未被充分交易的"Alpha"。 - 5. 
**实体穿透**: 必须关联到具体的 Ticker 或核心产业链节点。 - - ### 严禁事项: - - 严禁编造数据。 - - 严禁仅输出情绪极性(Positive/Negative),必须带有逻辑依据。 - - 严禁将纯娱乐或单纯的社会负面事件(除非具有宏观破坏性)视为金融信号。 - - ### 输出要求: - 你发现的每个信号应包含: - - **核心摘要**: 穿透表象的逻辑总结。 - - **传导节点**: A -> B -> C 的逻辑推导。 - - **推荐关注**: 板块或 Ticker。 - - **ISQ 评估**: 基于模板的 5 个维度进行初步评分(具体评分由后续 FinAgent 完成)。 - """ - -def get_trend_agent_instructions() -> str: - # 保持兼容性 - return get_trend_evaluator_instructions() - -def get_trend_scan_task(task_description: str) -> str: - """生成扫描员的任务描述""" - return f"请根据以下任务描述,搜集相关的原始数据和新闻:\n\n{task_description}" - -def format_scan_context(scan_data: dict) -> str: - """将扫描员搜集的结构化数据格式化为评估员可读的文本""" - if not scan_data: - return "(未能搜集到原始数据)" - - return f""" -### 扫描数据概览 -- **热点话题**: {', '.join(scan_data.get('hot_topics', []))} -- **情绪概览**: {scan_data.get('sentiment_overview', '未知')} -- **关键新闻**: {len(scan_data.get('news_summaries', []))} 条 -- **数据摘要**: {scan_data.get('raw_data_summary', '无')} -""" - -def get_trend_eval_task(task_description: str, raw_data_str: str) -> str: - """生成评估员的任务描述""" - return f"""请基于以下搜集到的原始数据,完成最终的分析任务: - -任务描述: {task_description} - -原始数据: -{raw_data_str} - -请识别出最具金融价值的信号,并给出评估理由。""" - -def get_news_filter_instructions(news_count: int, depth: Any, user_query: str = None) -> str: - """生成新闻筛选 prompt,使用 FilterResult schema 加快推理并减少 token 消耗 - - Args: - news_count: 输入新闻总数 - depth: 目标筛选数量,若为 auto 则由 LLM 自主判断 - user_query: 用户输入的查询/关注点(可选) - """ - - # 1. 深度控制逻辑 - if str(depth).lower() == 'auto': - depth_guide = "的数量不设固定限制(建议 3-10 条),根据新闻含金量自动判断" - limit_instruction = "宁缺毋滥,如果高价值信息很少,可以只选 1-2 条;如果都很重要,可以多选。" - else: - try: - d_int = int(depth) - depth_guide = f"约 {d_int} 条" - limit_instruction = f"请尽量凑满 {d_int} 条,但如果剩余新闻全是噪音,则不必强行凑数。" - except: - depth_guide = "适量" - limit_instruction = "根据内容价值判断。" - - target_desc = f"筛选出最具投资分析价值的新闻({depth_guide})。" - - # 2. 用户意图逻辑 - query_instruction = "" - if user_query: - target_desc = f"筛选出与用户意图【{user_query}】最相关的新闻。" - query_instruction = f""" - ### 核心任务(High Priority): - 用户明确关注:"{user_query}"。 - 1. **第一优先级**:必须包含所有与"{user_query}"直接或间接相关的新闻,不要遗漏。 - - 即使这些新闻看起来"价值不高",只要相关都要保留。 - 2. **第二优先级**:在满足第一优先级后,如果名额未满,再补充其他重大的市场热点。 - """ - - return f"""你是一名专业的金融情报精排师。你需要从给定的 {news_count} 条原始新闻流中,{target_desc} - - {query_instruction} - - ### FSD (Financial Signal Density) 筛选准则: - 1. **逻辑传导性 (Transmission)**: 该新闻是否预示着一个明确的产业链传导逻辑?(如:上游涨价 -> 中游成本压力 -> 下游提价预期) - 2. **预期差 (Alpha Potential)**: 是否包含尚未被市场充分Price-in的新突发情况? - 3. **确定性 (Confidence)**: 信息来源是否权威?是否包含具体的财务数据、订单金额或明确的政策日期? - 4. **排除噪音**: 坚决剔除明星八卦、鸡汤文、以及无实质增量的"口号式"新闻。 - - ### {limit_instruction} - - ### 快速有效性检查(TOKEN 优化): - 在开始详细筛选前,先快速判断:这 {news_count} 条新闻中是否至少包含 1 条有效的金融信号? - - 如果全是无关内容(如体育、娱乐、纯生活信息),直接返回 "has_valid_signals": false - - 如果有至少 1 条金融相关的新闻,再进行详细 FSD 筛选 - - ### 输出格式(必须为 JSON,使用 FilterResult schema): - ```json - {{ - "has_valid_signals": true/false, - "selected_ids": ["id_1", "id_2", ...], - "themes": [ - {{ - "name": "高概括性主题", - "news_ids": ["相关id_1", ...], - "fsd_reason": "基于 FSD 准则的筛选理由,重点描述传导逻辑和预期差。" - }} - ], - "reason": "如果 has_valid_signals=false,简要说明原因。否则可为空。" - }} - ``` - """ diff --git a/skills/alphaear-reporter/scripts/prompts/visualizer.py b/skills/alphaear-reporter/scripts/prompts/visualizer.py deleted file mode 100644 index f0b2933..0000000 --- a/skills/alphaear-reporter/scripts/prompts/visualizer.py +++ /dev/null @@ -1,47 +0,0 @@ -def get_drawio_system_prompt(): - return """You are an expert at creating Draw.io (MxGraph) diagrams in XML format. -Your task is to generate a valid MXGraphModel XML based on the user's description. 
-
-### Rules:
-1. Output ONLY the XML code. Start with <mxGraphModel> and end with </mxGraphModel>.
-2. Do not use compressed XML. Use plain XML.
-3. Use standard shapes: 'rounded=1;whiteSpace=wrap;html=1;' for boxes.
-4. Auto-layout Strategy:
-   - Identify "layers" or "stages" in the logic.
-   - Assign X coordinates based on layers (e.g., 0, 200, 400).
-   - Assign Y coordinates to distribute nodes vertically (e.g., 0, 100, 200).
-   - Ensure nodes do not overlap.
-5. Edges: Connect nodes logically using <mxCell> edge elements.
-
-### Template:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-"""
-
-def get_drawio_task(nodes_data: list, title: str) -> str:
-    import json
-    nodes_json = json.dumps(nodes_data, ensure_ascii=False, indent=2)
-    return f"""Please generate a Draw.io XML diagram for the following logic flow:
-
-**Title**: {title}
-
-**Nodes and Logic**:
-{nodes_json}
-
-Ensure the layout flows logically from Left to Right (or Top to Bottom for hierarchies).
-Use different colors for 'Positive' (Greenish), 'Negative' (Reddish), and 'Neutral' (Grey/Blue) impacts if described.
-"""
diff --git a/skills/alphaear-reporter/scripts/report_agent.py b/skills/alphaear-reporter/scripts/report_agent.py
deleted file mode 100644
index 60751f5..0000000
--- a/skills/alphaear-reporter/scripts/report_agent.py
+++ /dev/null
@@ -1,167 +0,0 @@
-import hashlib
-import json
-import re
-import pandas as pd
-from typing import List, Dict, Any, Optional
-from loguru import logger
-from types import SimpleNamespace
-
-from .utils.database_manager import DatabaseManager
-from .utils.json_utils import extract_json
-
-class ReportUtils:
-    """
-    研报辅助工具集 (ReportUtils)
-    提供格式化、引用管理、 JSON 提取等辅助功能。
-    核心生成逻辑(聚类、写作)已移交 Agent 执行。
-    """
-
-    def __init__(self, db: DatabaseManager):
-        self.db = db
-        logger.info("📝 ReportUtils initialized")
-
-    @staticmethod
-    def _make_cite_key(url: str, title: str = "", source_name: str = "") -> str:
-        basis = (url or "").strip() or f"{(title or '').strip()}|{(source_name or '').strip()}"
-        digest = hashlib.sha1(basis.encode("utf-8")).hexdigest()[:8]
-        return f"SF-{digest}"
-
-    def build_bibliography(self, signals: List[Any]) -> tuple[list[Dict[str, Any]], Dict[int, list[str]]]:
-        """Build stable bibliography entries and per-signal cite key mapping."""
-        bib_by_key: Dict[str, Dict[str, Any]] = {}
-        signal_to_keys: Dict[int, list[str]] = {}
-
-        for sig_idx, signal in enumerate(signals, 1):
-            source_items: list[Dict[str, Any]] = []
-
-            if hasattr(signal, "sources") and getattr(signal, "sources"):
-                source_items = list(getattr(signal, "sources") or [])
-            elif isinstance(signal, dict) and signal.get("sources"):
-                src_list = signal.get("sources")
-                if isinstance(src_list, list) and src_list:
-                    source_items = list(src_list)
-            elif isinstance(signal, dict):
-                if signal.get("url") or signal.get("title"):
-                    source_items = [
-                        {
-                            "title": signal.get("title"),
-                            "url": signal.get("url"),
-                            "source_name": signal.get("source") or signal.get("source_name"),
-                            "publish_time": signal.get("publish_time"),
-                        }
-                    ]
-
-            if not source_items:
-                continue
-
-            for src in source_items:
-                url = (src.get("url") or "").strip()
-                title = (src.get("title") or "").strip()
-                source_name = (src.get("source_name") or src.get("source") or "").strip()
-                publish_time = (src.get("publish_time") or "").strip() if isinstance(src.get("publish_time"), str) else src.get("publish_time")
-
-                key = self._make_cite_key(url=url, title=title, source_name=source_name)
-                signal_to_keys.setdefault(sig_idx, [])
-                if key not in signal_to_keys[sig_idx]:
-                    signal_to_keys[sig_idx].append(key)
-
-                if key in
bib_by_key: - continue - - # Prefer canonical metadata from DB when possible - enriched = self.db.lookup_reference_by_url(url) if url else None - bib_by_key[key] = { - "key": key, - "url": url or (enriched.get("url") if enriched else ""), - "title": (enriched.get("title") if enriched else None) or title or "(无标题)", - "source": (enriched.get("source") if enriched else None) or source_name or "(未知来源)", - "publish_time": (enriched.get("publish_time") if enriched else None) or publish_time or "", - } - - return list(bib_by_key.values()), signal_to_keys - - @staticmethod - def render_references_section(bib_entries: list[Dict[str, Any]]) -> str: - lines = ["## 参考文献", ""] - if not bib_entries: - lines.append("(无)") - return "\n".join(lines).strip() + "\n" - - for i, entry in enumerate(bib_entries, 1): - key = entry.get("key") - title = entry.get("title") or "(无标题)" - source = entry.get("source") or "(未知来源)" - url = entry.get("url") or "" - publish_time = entry.get("publish_time") or "" - suffix = "" - if publish_time: - suffix = f",{publish_time}" - label = f"[{i}]" - if url: - lines.append(f"{label} {title} ({source}{suffix}), {url}") - else: - lines.append(f"{label} {title} ({source}{suffix})") - - return "\n".join(lines).strip() + "\n" - - @staticmethod - def sanitize_json_chart_blocks(text: str) -> str: - """Best-effort repair for malformed json-chart fenced blocks.""" - if not text: - return text - # (Simplified logic: if closing ``` is missing, append it) - # Full logic omitted for brevity as it was complex regex, but retaining simple closure fix - if "```json-chart" in text and text.count("```") % 2 != 0: - text += "\n```" - return text - - @staticmethod - def build_structured_report(report_md: str, signals: List[Dict[str, Any]], clusters: List[Dict[str, Any]]) -> Dict[str, Any]: - """构建结构化研报输出(便于前端渲染/JSON化)""" - text = (report_md or "").strip() - lines = text.splitlines() if text else [] - - title = "研报" - for line in lines: - if line.startswith("# "): - title = line.replace("# ", "").strip() - break - - sections: List[Dict[str, Any]] = [] - current: Dict[str, Any] | None = None - for line in lines: - heading = re.match(r"^(#{2,4})\s+(.*)$", line.strip()) - if heading: - if current: - sections.append(current) - current = {"title": heading.group(2).strip(), "content": []} - continue - if current is None: - current = {"title": "摘要", "content": []} - current["content"].append(line) - if current: - sections.append(current) - - bullets = [ - re.sub(r"^[-*•]\s+", "", l.strip()) - for l in lines - if l.strip().startswith(("- ", "* ", "• ")) - ] - bullets = [b for b in bullets if b] - - return { - "title": title, - "summary_bullets": bullets[:8], - "sections": [ - {"title": s["title"], "content": "\n".join(s["content"]).strip()} - for s in sections - ] - } - - @staticmethod - def _clean_ticker(ticker_raw: str) -> str: - t = (ticker_raw or "").strip() - if not t: - return "" - digits = "".join([c for c in t if c.isdigit()]) - return digits or t diff --git a/skills/alphaear-reporter/scripts/schema/isq_template.py b/skills/alphaear-reporter/scripts/schema/isq_template.py deleted file mode 100644 index 2709019..0000000 --- a/skills/alphaear-reporter/scripts/schema/isq_template.py +++ /dev/null @@ -1,381 +0,0 @@ -""" -ISQ (Investment Signal Quality) 评估框架 Template - -统一定义 ISQ 的各个维度、评分标准、和使用方法。 -支持默认 template 和自定义 template。 -""" - -from typing import Dict, List, Any, Optional -from pydantic import BaseModel, Field -from enum import Enum -from pathlib import Path -import json - - -class ISQDimension(str, 
Enum): - """ISQ 评估维度""" - SENTIMENT = "sentiment" # 情绪/走势方向 - CONFIDENCE = "confidence" # 确定性/可信度 - INTENSITY = "intensity" # 强度/影响量级 - EXPECTATION_GAP = "expectation_gap" # 预期差/市场认知差 - TIMELINESS = "timeliness" # 时效性/窗口紧迫度 - TRANSMISSION = "transmission" # 逻辑传导清晰度 - - -class ISQDimensionSpec(BaseModel): - """ISQ 单个维度的定义规范""" - name: str = Field(..., description="维度名称") - key: str = Field(..., description="维度键名") - description: str = Field(..., description="维度描述") - range_type: str = Field(default="0-1", description="取值范围 (0-1 或 1-5 等)") - scale_factor: float = Field(default=1.0, description="显示时的缩放因子") - examples: Dict[str, str] = Field(default_factory=dict, description="不同分值的示例解释") - visualization_color: Optional[str] = Field(default=None, description="可视化颜色") - - -class ISQTemplate(BaseModel): - """ISQ 评估框架 Template""" - template_id: str = Field(..., description="模板 ID") - template_name: str = Field(..., description="模板名称") - description: str = Field(..., description="模板描述") - - # 核心维度定义 - dimensions: Dict[str, ISQDimensionSpec] = Field(..., description="维度定义字典") - - # 评分指导 - scoring_guide: str = Field(..., description="评分指导说明") - - # 应用场景 - applicable_scenarios: List[str] = Field(default_factory=list, description="适用场景") - - # 聚合算法 - aggregation_method: str = Field(default="weighted_average", description="聚合方法 (weighted_average, product 等)") - dimension_weights: Dict[str, float] = Field(default_factory=dict, description="维度权重") - - -class ISQScore(BaseModel): - """单个信号的 ISQ 评分结果""" - signal_id: str = Field(..., description="信号 ID") - template_id: str = Field(..., description="使用的模板 ID") - - # 各维度评分 - scores: Dict[str, float] = Field(..., description="各维度评分") - - # 总分 - overall_score: float = Field(..., description="综合评分") - - # 评分理由 - rationale: Dict[str, str] = Field(default_factory=dict, description="各维度评分理由") - - # 时间戳 - timestamp: str = Field(..., description="评分时间") - - -# ===================================================== -# 默认 Template -# ===================================================== - -DEFAULT_ISQ_TEMPLATE = ISQTemplate( - template_id="default_isq_v1", - template_name="标准投资信号质量评估框架 (ISQ v1.0)", - description="AlphaEar 默认的 ISQ 评估框架,用于标准化评估投资信号的质量维度", - - dimensions={ - "sentiment": ISQDimensionSpec( - name="情绪/走势", - key="sentiment", - description="基础情绪偏向和市场走势判断", - range_type="-1.0 到 1.0", - scale_factor=1.0, - examples={ - "-1.0": "极度悲观/极度看空", - "-0.5": "明显看空", - "0.0": "中性/没有明确方向", - "0.5": "明显看多", - "1.0": "极度乐观/极度看多" - }, - visualization_color="#ef4444" # 红色表示负面,绿色表示正面 - ), - - "confidence": ISQDimensionSpec( - name="确定性", - key="confidence", - description="信号的可信度和确定性程度", - range_type="0.0 到 1.0", - scale_factor=1.0, - examples={ - "0.0-0.3": "信息来源不可靠/传言多/逻辑推导牵强", - "0.3-0.6": "信息相对可靠/有一定逻辑/但仍有不确定性", - "0.6-0.8": "信息来源权威/逻辑清晰/高度可信", - "0.8-1.0": "官方确认/数据明确/完全确定" - }, - visualization_color="#3b82f6" # 蓝色 - ), - - "intensity": ISQDimensionSpec( - name="强度/影响量级", - key="intensity", - description="信号对相关板块/个股的潜在影响程度", - range_type="1 到 5", - scale_factor=20.0, # 用于雷达图缩放 (5 -> 100) - examples={ - "1": "影响微弱,可能被市场忽略", - "2": "小幅影响,短期可能有波动", - "3": "中等影响,值得重点关注", - "4": "强烈影响,可能成为市场焦点", - "5": "极强影响,市场预期明显变化" - }, - visualization_color="#f97316" # 橙色 - ), - - "expectation_gap": ISQDimensionSpec( - name="预期差", - key="expectation_gap", - description="市场预期与现实之间的差距", - range_type="0.0 到 1.0", - scale_factor=1.0, - examples={ - "0.0-0.2": "市场充分认知,预期差小", - "0.2-0.5": "市场部分认知,存在一定预期差", - "0.5-0.8": "市场认知不足,预期差较大,存在博弈空间", - "0.8-1.0": "市场严重低估/高估,巨大预期差" - }, - 
visualization_color="#22c55e" # 绿色 - ), - - "timeliness": ISQDimensionSpec( - name="时效性", - key="timeliness", - description="信号的时间窗口紧迫度", - range_type="0.0 到 1.0", - scale_factor=1.0, - examples={ - "0.0-0.2": "长期信号,反应窗口 > 3 月", - "0.2-0.5": "中期信号,反应窗口 1-3 月", - "0.5-0.8": "短期信号,反应窗口 1 周 - 1 月", - "0.8-1.0": "超短期信号,反应窗口 < 1 周(需立即行动)" - }, - visualization_color="#a855f7" # 紫色 - ), - }, - - scoring_guide=""" - ### ISQ 评分指导 (Investment Signal Quality) - - ISQ 框架用于多维度评估投资信号的质量。每个信号由 5 个维度组成: - - 1. **情绪 (Sentiment)**: -1.0 到 1.0,表示看空(-)/中性(0)/看多(+) - 2. **确定性 (Confidence)**: 0.0 到 1.0,数值越高越确定 - 3. **强度 (Intensity)**: 1 到 5,数值越高影响越大 - 4. **预期差 (Expectation Gap)**: 0.0 到 1.0,市场预期与现实的差距 - 5. **时效性 (Timeliness)**: 0.0 到 1.0,反应窗口的紧迫程度 - - ### 综合评分算法 - - 综合评分 = 确定性 × 0.35 + 强度/5 × 0.30 + 预期差 × 0.20 + 时效性 × 0.15 - - 范围: 0.0 到 1.0 - - 0.0-0.3: 信号质量较差,不建议跟进 - - 0.3-0.6: 信号质量一般,可作参考 - - 0.6-0.8: 信号质量良好,值得跟进 - - 0.8-1.0: 信号质量优异,强烈推荐 - - ### 评分时的注意事项 - - - **不要混淆方向和强度**:情绪可以是看空,但确定性和强度仍可能很高 - - **预期差往往是 Alpha 来源**:高预期差 + 高确定性 = 最佳博弈机会 - - **考虑时间成本**:长期信号需要更高的确定性才值得跟进 - - **数据为王**:所有评分必须有具体数据支撑 - """, - - applicable_scenarios=[ - "上市公司基本面变化分析", - "产业政策与监管事件评估", - "地缘政治与宏观经济影响", - "技术进步与产业升级", - "突发事件与应急响应" - ], - - aggregation_method="weighted_average", - dimension_weights={ - "confidence": 0.35, - "intensity": 0.30, - "expectation_gap": 0.20, - "timeliness": 0.15 - } -) - - -# ===================================================== -# ISQ Template 管理系统 -# ===================================================== - -class ISQTemplateManager: - """ISQ Template 管理器""" - - def __init__(self): - self.templates: Dict[str, ISQTemplate] = { - DEFAULT_ISQ_TEMPLATE.template_id: DEFAULT_ISQ_TEMPLATE - } - - def register_template(self, template: ISQTemplate) -> None: - """注册新的 template""" - self.templates[template.template_id] = template - - def register_template_dict(self, template_dict: Dict[str, Any]) -> ISQTemplate: - """从 dict 注册模板,返回实例。""" - tpl = ISQTemplate(**template_dict) - self.register_template(tpl) - return tpl - - def get_template(self, template_id: str) -> ISQTemplate: - """获取指定 template""" - if template_id not in self.templates: - return DEFAULT_ISQ_TEMPLATE - return self.templates[template_id] - - def list_templates(self) -> List[Dict[str, str]]: - """列出所有可用 template""" - return [ - { - "id": t.template_id, - "name": t.template_name, - "description": t.description, - "dimensions": list(t.dimensions.keys()) - } - for t in self.templates.values() - ] - - def get_dimension(self, template_id: str, dimension_key: str) -> ISQDimensionSpec: - """获取指定 template 的某个维度定义""" - template = self.get_template(template_id) - return template.dimensions.get(dimension_key) - - def get_scoring_prompt(self, template_id: str) -> str: - """获取用于 LLM 的评分 prompt""" - template = self.get_template(template_id) - - dimensions_desc = "\n".join([ - f"- **{d.name} ({d.key})**\n" - f" 范围: {d.range_type}\n" - f" 说明: {d.description}\n" - f" 示例: {', '.join(f'{k}={v}' for k, v in list(d.examples.items())[:3])}" - for d in template.dimensions.values() - ]) - - return f""" -### ISQ 评估指导 ({template.template_name}) - -使用以下 {len(template.dimensions)} 个维度评估信号质量: - -{dimensions_desc} - -### 评分标准 -{template.scoring_guide} - -### 输出格式 (JSON) -请输出以下 JSON 格式的评分结果: -{{ - "sentiment": , - "confidence": , - "intensity": , - "expectation_gap": , - "timeliness": , - "rationale": {{ - "sentiment": "评分理由", - "confidence": "评分理由", - "intensity": "评分理由", - "expectation_gap": "评分理由", - "timeliness": "评分理由" - }} -}} -""" - - -# 全局 template 管理器实例 
-isq_template_manager = ISQTemplateManager() - - -# ===================================================== -# 配置加载 -# ===================================================== - -def load_templates_from_config(config_path: Optional[str] = None) -> None: - """从配置目录加载所有 JSON 模板文件,未找到则跳过,不影响默认模板。 - 支持单个 JSON 文件或目录(目录下的所有 .json 文件)。 - """ - if config_path: - path = Path(config_path) - else: - # 默认目录:config/isq_templates/ - # __file__ = src/schema/isq_template.py - # parent = src/schema, parent.parent = src, parent.parent.parent = 项目根目录 - path = Path(__file__).resolve().parent.parent.parent / "config" - - if not path.exists(): - return - - # 如果是目录,扫描所有 .json 文件 - if path.is_dir(): - json_files = list(path.glob("*.json")) - else: - json_files = [path] - - for json_file in json_files: - try: - data = json.loads(json_file.read_text(encoding="utf-8")) - - # 如果是单个模板对象,转为列表 - if isinstance(data, dict): - templates = [data] - elif isinstance(data, list): - templates = data - else: - continue - - # 注册所有模板 - for tpl_dict in templates: - if not isinstance(tpl_dict, dict): - continue - try: - isq_template_manager.register_template_dict(tpl_dict) - except Exception: - # 忽略单个模板的加载错误,继续其他模板 - continue - except Exception: - # JSON 解析失败,跳过该文件 - continue - - -# 在模块加载时自动尝试加载配置模板 -load_templates_from_config() - - -# ===================================================== -# 便利函数 -# ===================================================== - -def get_isq_template(template_id: str = "default_isq_v1") -> ISQTemplate: - """获取 ISQ template""" - return isq_template_manager.get_template(template_id) - - -def get_isq_scoring_prompt(template_id: str = "default_isq_v1") -> str: - """获取用于 LLM 的 ISQ 评分 prompt""" - return isq_template_manager.get_scoring_prompt(template_id) - - -def calculate_isq_overall_score(scores: Dict[str, float], template_id: str = "default_isq_v1") -> float: - """计算 ISQ 综合评分""" - template = get_isq_template(template_id) - - overall = 0.0 - for dim_key, weight in template.dimension_weights.items(): - if dim_key in scores: - score = scores[dim_key] - # 处理强度维度的特殊缩放 (1-5 -> 0-1) - if dim_key == "intensity": - score = score / 5.0 - overall += score * weight - - return min(1.0, max(0.0, overall)) # 限制在 0-1 之间 diff --git a/skills/alphaear-reporter/scripts/schema/models.py b/skills/alphaear-reporter/scripts/schema/models.py deleted file mode 100644 index 422ca9c..0000000 --- a/skills/alphaear-reporter/scripts/schema/models.py +++ /dev/null @@ -1,100 +0,0 @@ -from pydantic import BaseModel, Field -from typing import List, Optional, Dict, Any -from datetime import datetime - -class TransmissionNode(BaseModel): - node_name: str = Field(..., description="产业链节点名称") - impact_type: str = Field(..., description="利好/利空/中性") - logic: str = Field(..., description="该节点的传导逻辑") - -class IntentAnalysis(BaseModel): - keywords: List[str] = Field(..., description="核心实体、事件或概念关键词") - search_queries: List[str] = Field(..., description="优化后的搜索引擎查询词") - is_specific_event: bool = Field(..., description="是否查询特定突发事件") - time_range: str = Field(..., description="时间范围 (recent/all/specific_date)") - intent_summary: str = Field(..., description="一句话意图描述") - -class FilterResult(BaseModel): - """LLM 筛选结果 - 快速判断是否有有效信号""" - has_valid_signals: bool = Field(..., description="列表中是否包含有效的金融信号") - selected_ids: List[int] = Field(default_factory=list, description="筛选出的有效信号 ID 列表") - themes: List[str] = Field(default_factory=list, description="信号涉及的主题") - reason: Optional[str] = Field(default=None, description="如果无有效信号,说明原因") - -class 
InvestmentSignal(BaseModel): - # 核心元数据 - signal_id: str = Field(default="unknown_sig", description="唯一信号 ID") - title: str = Field(..., description="信号标题") - summary: str = Field(default="暂无摘要分析", description="100 字核心观点快报") - reasoning: str = Field(default="", description="详细的推演逻辑和理由") - - # 逻辑传导 (ISQ Key 1) - transmission_chain: List[TransmissionNode] = Field(default_factory=list, description="产业链传导逻辑链条") - - # 信号质量 (ISQ Key 2) - 来自 isq_template.DEFAULT_ISQ_TEMPLATE - # 参考: src/schema/isq_template.py 的 DEFAULT_ISQ_TEMPLATE 定义 - sentiment_score: float = Field(default=0.0, description="[ISQ] 情绪/走势 (-1.0=极度看空 ~ 0.0=中性 ~ 1.0=极度看多)") - confidence: float = Field(default=0.5, description="[ISQ] 确定性 (0.0=不可信 ~ 1.0=完全确定)") - intensity: int = Field(default=3, description="[ISQ] 强度/影响量级 (1=微弱 ~ 5=极强)") - expectation_gap: float = Field(default=0.5, description="[ISQ] 预期差/博弈空间 (0.0=充分定价 ~ 1.0=巨大预期差)") - timeliness: float = Field(default=0.8, description="[ISQ] 时效性 (0.0=长期 ~ 1.0=超短期)") - - # 预测与博弈 (ISQ Key 3) - expected_horizon: str = Field(default="T+N", description="预期的反应时窗 (如: T+0, T+3, Long-term)") - price_in_status: str = Field(default="未知", description="市场预期消化程度 (未定价/部分定价/充分定价)") - - # 关联实体 - impact_tickers: List[Dict[str, Any]] = Field(default_factory=list, description="受影响的代码列表及其权重") - industry_tags: List[str] = Field(default_factory=list, description="关联行业标签") - - # 溯源 - sources: List[Dict[str, str]] = Field(default_factory=list, description="来源详情 (包含 title, url, source_name)") - -class ResearchContext(BaseModel): - """研究员搜集的背景信息结构""" - raw_signal: str = Field(..., description="原始信号内容") - tickers_found: List[Dict[str, Any]] = Field(default_factory=list, description="找到的相关标的及其基本面/股价信息") - industry_background: str = Field(..., description="行业背景及产业链现状") - latest_developments: List[str] = Field(default_factory=list, description="相关事件的最新进展") - key_risks: List[str] = Field(default_factory=list, description="潜在风险点") - search_results_summary: str = Field(..., description="搜索结果的综合摘要") - -class ScanContext(BaseModel): - """扫描员搜集的原始数据结构""" - hot_topics: List[str] = Field(..., description="当前市场热点话题") - news_summaries: List[Dict[str, Any]] = Field(..., description="关键新闻摘要列表") - market_data: Dict[str, Any] = Field(default_factory=dict, description="相关的市场行情数据") - sentiment_overview: str = Field(..., description="整体市场情绪概览") - raw_data_summary: str = Field(..., description="原始数据的综合摘要") - -class SignalCluster(BaseModel): - theme_title: str = Field(..., description="主题名称") - signal_ids: List[int] = Field(..., description="包含的信号 ID 列表") - rationale: str = Field(..., description="聚类理由") - -class ClusterContext(BaseModel): - """信号聚类结果结构""" - clusters: List[SignalCluster] = Field(..., description="聚类列表") - -class KLinePoint(BaseModel): - date: str = Field(..., description="日期") - open: float = Field(..., description="开盘价") - high: float = Field(..., description="最高价") - low: float = Field(..., description="最低价") - close: float = Field(..., description="收盘价") - volume: float = Field(..., description="成交量") - -class ForecastResult(BaseModel): - ticker: str = Field(..., description="股票代码") - base_forecast: List[KLinePoint] = Field(default_factory=list, description="Kronos 模型原始预测") - adjusted_forecast: List[KLinePoint] = Field(default_factory=list, description="LLM 调整后的预测") - rationale: str = Field(default="", description="预测调整理由及逻辑说明") - timestamp: str = Field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S"), description="生成时间") - -class InvestmentReport(BaseModel): - overall_sentiment: str = 
Field(..., description="整体市场情绪评价") - market_entropy: float = Field(..., description="市场分歧度 (0-1, 1代表极高分歧)") - signals: List[InvestmentSignal] = Field(..., description="深度解析的投资信号列表") - forecasts: List[ForecastResult] = Field(default_factory=list, description="相关标的的预测结果") - timestamp: str = Field(..., description="报告生成时间") - meta_info: Optional[Dict[str, Any]] = Field(default_factory=dict, description="其他元数据") diff --git a/skills/alphaear-reporter/scripts/tools/__init__.py b/skills/alphaear-reporter/scripts/tools/__init__.py deleted file mode 100644 index 97fbb5d..0000000 --- a/skills/alphaear-reporter/scripts/tools/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -# src/tools/__init__.py -""" -AlphaEar 工具包层 - Agno Toolkit 适配器 - -提供的 Toolkit 类: -- NewsToolkit: 热点新闻获取 -- StockToolkit: 股票搜索与价格查询 -- SentimentToolkit: 情绪分析 -- SearchToolkit: 网络搜索 -""" - -from .toolkits import ( - NewsToolkit, - StockToolkit, - SentimentToolkit, - SearchToolkit, -) - -__all__ = [ - "NewsToolkit", - "StockToolkit", - "SentimentToolkit", - "SearchToolkit", -] diff --git a/skills/alphaear-reporter/scripts/tools/toolkits.py b/skills/alphaear-reporter/scripts/tools/toolkits.py deleted file mode 100644 index ebd0b69..0000000 --- a/skills/alphaear-reporter/scripts/tools/toolkits.py +++ /dev/null @@ -1,526 +0,0 @@ -""" -AlphaEar 工具包层 - Agno Toolkit 适配器 -复用 utils 中的底层工具实现,提供 Agno Agent 兼容的 Toolkit 接口 -""" -from datetime import datetime -from typing import Optional -from agno.tools import Toolkit -from loguru import logger - -from ..utils.database_manager import DatabaseManager -from ..utils.news_tools import NewsNowTools, PolymarketTools -from ..utils.stock_tools import StockTools -from ..utils.search_tools import SearchTools -from ..utils.sentiment_tools import SentimentTools - - -class NewsToolkit(Toolkit): - """ - 新闻工具包 - 包装 NewsNowTools 为 Agno Toolkit - - 提供热点新闻获取、内容提取等功能 - """ - - def __init__(self, db: DatabaseManager, **kwargs): - self._news_tools = NewsNowTools(db) - self._sources = self._news_tools.SOURCES - - tools = [ - self.fetch_hot_news, - self.fetch_news_content, - self.get_unified_trends, - self.enrich_news_content, - ] - super().__init__(name="news_toolkit", tools=tools, **kwargs) - - - def fetch_hot_news(self, source_id: str, count: int = 10) -> str: - """ - 从指定新闻源获取热点新闻列表。 - - Args: - source_id: 新闻源标识符。可选值按类别: - **金融类**: "cls" (财联社), "wallstreetcn" (华尔街见闻), "xueqiu" (雪球) - **综合类**: "weibo" (微博热搜), "zhihu" (知乎热榜), "baidu" (百度热搜), - "toutiao" (今日头条), "douyin" (抖音), "thepaper" (澎湃新闻) - **科技类**: "36kr" (36氪), "ithome" (IT之家), "v2ex", "juejin" (掘金), - "hackernews" (Hacker News) - 推荐金融分析使用 "cls", "wallstreetcn", "xueqiu"。 - count: 获取的新闻数量,默认 10 条。 - - Returns: - 热点新闻列表的文本描述,包含排名、标题和链接。如果源不可用则返回错误信息。 - """ - logger.info(f"🔧 [TOOL CALLED] fetch_hot_news(source_id={source_id}, count={count})") - - items = self._news_tools.fetch_hot_news(source_id, count=count, fetch_content=False) - - if not items: - return f"获取 {source_id} 热点失败" - - source_name = self._sources.get(source_id, source_id) - result = f"## {source_name} 热点 (获取时间: {datetime.now().strftime('%Y-%m-%d %H:%M')})\n\n" - - for item in items: - result += f"{item['rank']}. 
{item['title']}\n 链接: {item['url']}\n\n" - - logger.info(f"✅ [TOOL SUCCESS] Got {len(items)} news items from {source_id}") - return result - - def fetch_news_content(self, url: str) -> str: - """ - 使用 Jina Reader 抓取指定 URL 的网页正文内容。 - - Args: - url: 需要抓取内容的完整网页 URL,必须以 http:// 或 https:// 开头。 - - Returns: - 提取的网页正文内容,如果失败则返回错误信息。 - """ - content = self._news_tools.fetch_news_content(url) - if content: - return content[:5000] # 限制长度 - return "内容抓取失败" - - def get_unified_trends(self, sources: str = "wallstreetcn,cls") -> str: - """ - 获取多平台综合热点报告。 - - Args: - sources: 要扫描的新闻源,用逗号分隔。 - 可选值: weibo, zhihu, baidu, toutiao, wallstreetcn, cls - 默认: "wallstreetcn,cls" (金融资讯) - - Returns: - 格式化的热点汇总报告。 - """ - source_list = [s.strip() for s in sources.split(",")] - report = self._news_tools.get_unified_trends(source_list) - return report - - def enrich_news_content(self, source: str = None, limit: int = 5) -> str: - """ - 为数据库中缺少正文内容的新闻补充内容。 - - Args: - source: 筛选特定新闻源(如 "cls"),为空则处理所有。 - limit: 最多处理的新闻数量,默认 5 条。 - - Returns: - 处理结果的描述。 - """ - logger.info(f"🔧 [TOOL CALLED] enrich_news_content(source={source}, limit={limit})") - - # 获取需要补充内容的新闻 - news_items = self._news_tools.db.get_daily_news(source=source, limit=limit) - items_without_content = [n for n in news_items if not n.get('content')] - - if not items_without_content: - return "没有需要补充内容的新闻" - - updated_count = 0 - cursor = self._news_tools.db.conn.cursor() - - for item in items_without_content[:limit]: - url = item.get('url') - if url: - content = self._news_tools.fetch_news_content(url) - if content: - cursor.execute( - "UPDATE daily_news SET content = ? WHERE id = ?", - (content[:10000], item['id']) - ) - updated_count += 1 - - self._news_tools.db.conn.commit() - logger.info(f"✅ [TOOL SUCCESS] Enriched {updated_count} news items with content") - - return f"✅ 已为 {updated_count} 条新闻补充正文内容" - - -class PolymarketToolkit(Toolkit): - """ - Polymarket 预测市场工具包 - 获取热门预测市场数据 - - 预测市场数据可反映公众情绪、预期和关注度 - """ - - def __init__(self, db: DatabaseManager, **kwargs): - self._poly_tools = PolymarketTools(db) - - tools = [ - self.get_prediction_markets, - self.get_market_summary, - ] - super().__init__(name="polymarket_toolkit", tools=tools, **kwargs) - - def get_prediction_markets(self, limit: int = 20) -> str: - """ - 获取 Polymarket 活跃预测市场的关键数据。 - - 预测市场反映公众对重大事件的概率预期,可用于: - - 分析市场情绪和风险偏好 - - 了解热门话题的关注度 - - 获取重大事件的概率预期 - - Args: - limit: 获取的市场数量,默认 20 个。 - - Returns: - 预测市场数据列表,包含问题、结果概率和交易量。 - 如果获取失败返回错误信息。 - """ - logger.info(f"🔧 [TOOL CALLED] get_prediction_markets(limit={limit})") - - markets = self._poly_tools.get_active_markets(limit) - if not markets: - return "❌ 无法获取 Polymarket 数据(可能是网络问题)" - - result = f"## 🔮 Polymarket 热门预测 (共 {len(markets)} 个)\n\n" - for i, m in enumerate(markets[:limit], 1): - question = m.get("question", "Unknown") - prices = m.get("outcomePrices", []) - volume = m.get("volume", 0) - - result += f"{i}. 
**{question}**\n" - if prices: - result += f" 概率: {prices}\n" - if volume: - try: - result += f" 交易量: ${float(volume):,.0f}\n" - except: - result += f" 交易量: {volume}\n" - result += "\n" - - logger.info(f"✅ [TOOL SUCCESS] Got {len(markets)} prediction markets") - return result - - def get_market_summary(self, limit: int = 10) -> str: - """ - 获取预测市场摘要报告,了解当前热门话题和公众预期。 - - Args: - limit: 获取的市场数量,默认 10 个。 - - Returns: - 格式化的预测市场报告。 - """ - return self._poly_tools.get_market_summary(limit) - - -class StockToolkit(Toolkit): - - """ - 股票工具包 - 包装 StockTools 为 Agno Toolkit - - 提供股票搜索、价格查询等功能 - """ - - def __init__(self, db: DatabaseManager, **kwargs): - self._stock_tools = StockTools(db) - - tools = [ - self.search_ticker, - self.get_stock_price, - ] - super().__init__(name="stock_toolkit", tools=tools, **kwargs) - - def search_ticker(self, query: str) -> str: - """ - 模糊搜索 A 股股票代码或名称。 - - Args: - query: 搜索关键词,可以是股票代码(如 "600519")或名称关键词(如 "茅台"、"宁德"、"比亚迪")。 - - Returns: - 匹配的股票列表,包含代码和名称。 - """ - q = (query or "").strip() - # Guardrails: prevent overly generic queries that tend to return arbitrary "...股份" matches. - generic_terms = { - "股份", - "有限公司", - "概念股", - "受益股", - "龙头", - "标的", - "相关股票", - "合作概念股", - } - if not q: - return "查询为空,无法搜索股票" - if q in generic_terms: - return f"查询过于泛化({q}),为避免误匹配已拒绝。请提供更具体的公司名或6位代码。" - # If it's not a numeric code, require at least 2 non-space chars. - if not any(ch.isdigit() for ch in q) and len(q.replace(" ", "")) < 2: - return "查询过短,无法搜索股票。请提供更具体的公司名或6位代码。" - - results = self._stock_tools.search_ticker(query) - - if not results: - return f"未找到匹配 '{query}' 的股票" - - output = f"## 股票搜索结果 (关键词: {query})\n\n" - for r in results: - output += f"- {r['code']} - {r['name']}\n" - return output - - def get_stock_price(self, ticker: str, days: int = 30) -> str: - """ - 获取指定股票的近期价格走势。 - - Args: - ticker: 股票代码,如 "600519"(贵州茅台)或 "000001"(平安银行)。 - days: 查询天数,默认 30 天。 - - Returns: - 价格走势的文本摘要。 - """ - from datetime import timedelta - end_date = datetime.now().strftime('%Y-%m-%d') - start_date = (datetime.now() - timedelta(days=days)).strftime('%Y-%m-%d') - - df = self._stock_tools.get_stock_price(ticker, start_date, end_date) - - if df.empty: - return f"未能获取 {ticker} 的股价数据" - - - latest = df.iloc[-1] - change = ((latest['close'] - df.iloc[0]['close']) / df.iloc[0]['close']) * 100 - - # 格式化历史数据供 LLM 分析 (取最近 15 天) - history_df = df.tail(15).copy() - history_df['date'] = history_df['date'].astype(str) - # 简化列名以节省 token - history_cols = ['date', 'open', 'close', 'high', 'low', 'volume'] - - # 尝试使用 markdown 格式,如果失败退回到 string - try: - history_str = history_df[history_cols].to_markdown(index=False, numalign="left", stralign="left") - except ImportError: - history_str = history_df[history_cols].to_string(index=False) - except Exception: - history_str = history_df[history_cols].to_string(index=False) - - return f"""## {ticker} 价格走势 ({days}天) -- 当前价: ¥{latest['close']:.2f} -- 期间涨跌: {change:+.2f}% -- 最高/最低: ¥{df['high'].max():.2f} / ¥{df['low'].min():.2f} -- 数据范围: {df.iloc[0]['date']} -> {latest['date']} - -### 最近 15 个交易日详细数据 (OHLCV): -{history_str} -""" - - - -class SentimentToolkit(Toolkit): - """ - 情绪分析工具包 - 包装 SentimentTools 为 Agno Toolkit - - 提供文本情绪分析功能(支持 BERT 和 LLM 模式) - """ - - def __init__(self, db: DatabaseManager, mode: str = "auto", **kwargs): - self._sentiment_tools = SentimentTools(db, mode=mode) - self._db = db - - tools = [ - self.analyze_sentiment, - self.batch_update_sentiment, - ] - super().__init__(name="sentiment_toolkit", tools=tools, **kwargs) - - def 
analyze_sentiment(self, text: str) -> str: - """ - 分析文本的情绪极性。 - - Args: - text: 需要分析的文本内容,如新闻标题或摘要。 - - Returns: - 情绪分析结果,包含分值(-1.0到1.0)和标签(positive/negative/neutral)。 - """ - result = self._sentiment_tools.analyze_sentiment(text) - - score = result.get('score', 0.0) - label = result.get('label', 'neutral') - reason = result.get('reason', '') - - return f"""情绪分析结果: -- 文本: {text[:100]}{'...' if len(text) > 100 else ''} -- 分值: {score:.2f} -- 标签: {label} -- 分析: {reason}""" - - def batch_update_sentiment(self, source: str = None, limit: int = 20) -> str: - """ - 批量更新数据库中新闻的情绪分数。 - - Args: - source: 筛选特定新闻源(如 "cls", "wallstreetcn"),为空则处理所有。 - limit: 最多处理的新闻数量,默认 20 条。 - - Returns: - 更新结果的描述。 - """ - logger.info(f"🔧 [TOOL CALLED] batch_update_sentiment(source={source}, limit={limit})") - - count = self._sentiment_tools.batch_update_news_sentiment(source=source, limit=limit) - - return f"✅ 已更新 {count} 条新闻的情绪分数" - - - -class SearchToolkit(Toolkit): - """ - 搜索工具包 - 包装 SearchTools 为 Agno Toolkit - - 提供网络搜索功能(支持 Jina、DuckDuckGo 和百度) - - 当环境变量 JINA_API_KEY 设置时,默认使用 Jina Search, - 提供 LLM 友好的搜索结果。 - """ - - def __init__(self, db: DatabaseManager, **kwargs): - self._search_tools = SearchTools(db) - - tools = [ - self.web_search, - self.aggregate_search, - ] - super().__init__(name="search_toolkit", tools=tools, **kwargs) - - def web_search(self, query: str, engine: str = None, max_results: int = 5) -> str: - """ - 使用指定搜索引擎执行网络搜索。 - - Args: - query: 搜索关键词,如 "英伟达财报" 或 "光伏行业政策"。 - engine: 搜索引擎选择。可选值: - "jina" (Jina Search,需配置 JINA_API_KEY,LLM友好输出), - "ddg" (DuckDuckGo,推荐英文/国际搜索), - "baidu" (百度,推荐中文/国内搜索)。 - 默认: 若配置了 JINA_API_KEY 则使用 "jina",否则 "ddg"。 - max_results: 返回结果数量。默认 5。 - - Returns: - 搜索结果的文本描述。 - """ - return self._search_tools.search(query, engine=engine, max_results=max_results) - - def aggregate_search(self, query: str, max_results: int = 5) -> str: - """ - 同时使用多个搜索引擎搜索并聚合结果。 - - Args: - query: 搜索关键词。 - max_results: 每个引擎返回的最大结果数。默认 5。 - - Returns: - 聚合后的搜索结果。 - """ - return self._search_tools.aggregate_search(query, max_results=max_results) - - -class ContextSearchToolkit(Toolkit): - """ - 上下文搜索工具包 - 用于 RAG 场景的文档片段检索 - - 支持在内存中存储文档片段,并通过关键词搜索相关内容。 - 适用于 ReportAgent 的分段编辑场景。 - """ - - def __init__(self, **kwargs): - self._store = {} # {doc_id: {"title": str, "content": str, "summary": str}} - - tools = [ - self.search_context, - self.get_toc, - ] - super().__init__(name="context_search_toolkit", tools=tools, **kwargs) - - def add_document(self, doc_id: str, title: str, content: str, summary: str = ""): - """添加文档到存储(供外部调用,非 LLM 工具)""" - self._store[doc_id] = { - "title": title, - "content": content, - "summary": summary or content[:200] + "..." 
- } - logger.info(f"📄 Added document to context store: {doc_id} - {title[:30]}...") - - def clear(self): - """清空文档存储""" - self._store.clear() - logger.info("🗑️ Context store cleared") - - def search_context(self, query: str, max_results: int = 3) -> str: - """ - 在已存储的文档中搜索与查询相关的内容片段。 - - Args: - query: 搜索关键词,如 "消费板块" 或 "茅台 预测"。 - max_results: 返回的最大结果数,默认 3。 - - Returns: - 匹配的文档片段,按相关性排序。 - """ - logger.info(f"🔍 [TOOL CALLED] search_context(query={query}, max_results={max_results})") - - if not self._store: - return "⚠️ 上下文存储为空,无可搜索内容。" - - # 简单的关键词匹配 + 计分 - query_terms = query.lower().split() - results = [] - - for doc_id, doc in self._store.items(): - score = 0 - content_lower = doc["content"].lower() - title_lower = doc["title"].lower() - - for term in query_terms: - # 标题匹配权重更高 - if term in title_lower: - score += 3 - if term in content_lower: - score += content_lower.count(term) - - if score > 0: - results.append((score, doc_id, doc)) - - # 按分数排序 - results.sort(key=lambda x: x[0], reverse=True) - results = results[:max_results] - - if not results: - return f"未找到与 '{query}' 相关的内容。" - - output = f"## 搜索结果 (查询: {query})\n\n" - for score, doc_id, doc in results: - output += f"### [{doc_id}] {doc['title']}\n" - # 返回摘要而非全文,节省 token - output += f"{doc['summary']}\n\n" - - logger.info(f"✅ [TOOL SUCCESS] Found {len(results)} matching documents") - return output - - def get_toc(self) -> str: - """ - 获取当前存储的所有文档的目录(TOC)。 - - Returns: - 文档目录列表,包含 ID 和标题。 - """ - logger.info("🔍 [TOOL CALLED] get_toc()") - - if not self._store: - return "⚠️ 上下文存储为空。" - - output = "## 文档目录 (TOC)\n\n" - for doc_id, doc in self._store.items(): - output += f"- **[{doc_id}]** {doc['title']}\n" - - return output - diff --git a/skills/alphaear-reporter/scripts/utils/__init__.py b/skills/alphaear-reporter/scripts/utils/__init__.py deleted file mode 100644 index 27e1961..0000000 --- a/skills/alphaear-reporter/scripts/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# AlphaEar utils package diff --git a/skills/alphaear-reporter/scripts/utils/content_extractor.py b/skills/alphaear-reporter/scripts/utils/content_extractor.py deleted file mode 100644 index 133207a..0000000 --- a/skills/alphaear-reporter/scripts/utils/content_extractor.py +++ /dev/null @@ -1,122 +0,0 @@ -import requests -from requests.exceptions import RequestException, Timeout, ConnectionError -import os -import time -import json -import threading -from typing import Optional -from loguru import logger - - -class ContentExtractor: - """内容提取工具 - 主要接入 Jina Reader API""" - - JINA_BASE_URL = "https://r.jina.ai/" - - # 速率限制配置 (无 API Key 时:20 次/分钟) - _rate_limit_no_key = 20 # 每分钟最大请求数 - _rate_window = 60.0 # 时间窗口(秒) - _min_interval = 3.0 # 请求最小间隔(秒) - - # 类级别的速率限制状态 - _request_times = [] - _last_request_time = 0.0 - _lock = threading.Lock() - - @classmethod - def _wait_for_rate_limit(cls, has_api_key: bool) -> None: - """等待以满足速率限制要求""" - if has_api_key: - # 有 API Key 时,只需保持最小间隔 - time.sleep(0.5) - return - - with cls._lock: - current_time = time.time() - - # 1. 清理过期的请求记录 - cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window] - - # 2. 
检查是否达到速率限制 - if len(cls._request_times) >= cls._rate_limit_no_key: - # 需要等待最旧的请求过期 - oldest = cls._request_times[0] - wait_time = cls._rate_window - (current_time - oldest) + 1.0 - if wait_time > 0: - logger.warning(f"⏳ Jina rate limit reached, waiting {wait_time:.1f}s...") - time.sleep(wait_time) - current_time = time.time() - cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window] - - # 3. 确保请求间隔不太快 - time_since_last = current_time - cls._last_request_time - if time_since_last < cls._min_interval: - sleep_time = cls._min_interval - time_since_last - time.sleep(sleep_time) - - # 4. 记录本次请求 - cls._request_times.append(time.time()) - cls._last_request_time = time.time() - - @classmethod - def extract_with_jina(cls, url: str, timeout: int = 30) -> Optional[str]: - """ - 使用 Jina Reader 提取网页正文内容 (Markdown 格式) - - 无 API Key 时自动限速:每分钟最多 20 次请求,每次间隔至少 3 秒 - """ - if not url or not url.startswith("http"): - return None - - logger.info(f"🕸️ Extracting content from: {url} via Jina...") - - headers = { - "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", - "Accept": "application/json" - } - - # 使用统一的 JINA_API_KEY - api_key = os.getenv("JINA_API_KEY") - has_api_key = bool(api_key and api_key.strip()) - - if has_api_key: - headers["Authorization"] = f"Bearer {api_key}" - - # 等待速率限制 - cls._wait_for_rate_limit(has_api_key) - - try: - # Jina Reader API - full_url = f"{cls.JINA_BASE_URL}{url}" - response = requests.get(full_url, headers=headers, timeout=timeout) - - if response.status_code == 200: - try: - data = response.json() - # Jina JSON 响应格式通常在 data.content - if isinstance(data, dict) and "data" in data: - return data["data"].get("content", "") - return data.get("content", response.text) - except (json.JSONDecodeError, TypeError): - return response.text - elif response.status_code == 429: - # 触发速率限制,等待后重试一次 - logger.warning(f"⚠️ Jina rate limit (429), waiting 60s before retry...") - time.sleep(60) - return cls.extract_with_jina(url, timeout) - else: - logger.warning(f"Jina extraction failed (Status {response.status_code}) for {url}") - return None - - except Timeout: - logger.error(f"Timeout during Jina extraction for {url}") - return None - except ConnectionError: - logger.error(f"Connection error during Jina extraction for {url}") - return None - except RequestException as e: - logger.error(f"Request error during Jina extraction: {e}") - return None - except Exception as e: - logger.error(f"Unexpected error during Jina extraction: {e}") - return None diff --git a/skills/alphaear-reporter/scripts/utils/database_manager.py b/skills/alphaear-reporter/scripts/utils/database_manager.py deleted file mode 100644 index cfc362b..0000000 --- a/skills/alphaear-reporter/scripts/utils/database_manager.py +++ /dev/null @@ -1,581 +0,0 @@ -import sqlite3 -import json -from datetime import datetime, date -from pathlib import Path -from typing import List, Dict, Optional, Any, Union -import pandas as pd -from loguru import logger - -class DatabaseManager: - """ - AlphaEar 数据库管理器 - 负责存储热点数据、搜索缓存和股价数据 - 使用 SQLite 进行持久化存储 - """ - - def __init__(self, db_path: str = "data/signal_flux.db"): - self.db_path = Path(db_path) - self.db_path.parent.mkdir(parents=True, exist_ok=True) - self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False) - self.conn.row_factory = sqlite3.Row - self._init_db() - logger.info(f"💾 Database initialized at {self.db_path}") - - def _init_db(self): - """初始化表结构""" - 
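-        # Schema notes: all tables below are created with CREATE TABLE IF NOT EXISTS,
-        # and the try/except ALTER TABLE statements act as lightweight migrations that
-        # are no-ops when the column already exists in an older database file.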
cursor = self.conn.cursor() - - # 1. 每日热点新闻表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS daily_news ( - id TEXT PRIMARY KEY, - source TEXT, - rank INTEGER, - title TEXT, - url TEXT, - content TEXT, - publish_time TEXT, - crawl_time TEXT, - sentiment_score REAL, - analysis TEXT, - meta_data TEXT - ) - """) - - # 尝试添加 analysis 列(如果表已存在但没有该列) - try: - cursor.execute("ALTER TABLE daily_news ADD COLUMN analysis TEXT") - except: - pass # 列已存在 - - - # 2. 搜索缓存表 (原有 JSON 缓存) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS search_cache ( - query_hash TEXT PRIMARY KEY, - query TEXT, - engine TEXT, - results TEXT, - timestamp TEXT - ) - """) - - # 2.5 搜索详情表 (展开的搜索结果) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS search_detail ( - id TEXT, - query_hash TEXT, - rank INTEGER, - title TEXT, - url TEXT, - content TEXT, - publish_time TEXT, - crawl_time TEXT, - sentiment_score REAL, - source TEXT, - meta_data TEXT, - PRIMARY KEY (query_hash, id) - ) - """) - - # 3. 股价数据表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS stock_prices ( - ticker TEXT, - date TEXT, - open REAL, - close REAL, - high REAL, - low REAL, - volume REAL, - change_pct REAL, - PRIMARY KEY (ticker, date) - ) - """) - - # 4. 股票列表表 (用于检索) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS stock_list ( - code TEXT PRIMARY KEY, - name TEXT - ) - """) - - # 5. 投资信号表 (ISQ Framework) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS signals ( - signal_id TEXT PRIMARY KEY, - title TEXT, - summary TEXT, - transmission_chain TEXT, - sentiment_score REAL, - confidence REAL, - intensity INTEGER, - expected_horizon TEXT, - price_in_status TEXT, - impact_tickers TEXT, - industry_tags TEXT, - sources TEXT, - user_id TEXT, - created_at TEXT - ) - """) - - - - # 6. 创建索引以优化查询性能 - cursor.execute("CREATE INDEX IF NOT EXISTS idx_news_crawl_time ON daily_news(crawl_time)") - cursor.execute("CREATE INDEX IF NOT EXISTS idx_news_source ON daily_news(source)") - cursor.execute("CREATE INDEX IF NOT EXISTS idx_search_cache_timestamp ON search_cache(timestamp)") - cursor.execute("CREATE INDEX IF NOT EXISTS idx_stock_prices_ticker_date ON stock_prices(ticker, date)") - # 尝试添加 user_id 列到 signals 表 - try: - cursor.execute("ALTER TABLE signals ADD COLUMN user_id TEXT") - except: - pass - - cursor.execute("CREATE INDEX IF NOT EXISTS idx_signals_user_id ON signals(user_id)") - - self.conn.commit() - - # - # self.conn.commit() - - - # --- 新闻数据操作 --- - - def save_daily_news(self, news_list: List[Dict]) -> int: - """保存热点新闻,包含发布时间与抓取时间""" - cursor = self.conn.cursor() - count = 0 - crawl_time = datetime.now().isoformat() - - for news in news_list: - try: - # 兼容不同来源的 ID 生成逻辑 - news_id = news.get('id') or f"{news.get('source')}_{news.get('rank')}_{crawl_time[:10]}" - cursor.execute(""" - INSERT OR REPLACE INTO daily_news - (id, source, rank, title, url, content, publish_time, crawl_time, sentiment_score, meta_data) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- """, ( - news_id, - news.get('source'), - news.get('rank'), - news.get('title'), - news.get('url'), - news.get('content', ''), - news.get('publish_time'), # 新增支持发布时间 - crawl_time, - news.get('sentiment_score'), - json.dumps(news.get('meta_data', {})) - )) - count += 1 - except sqlite3.Error as e: - logger.error(f"Database error saving news item {news.get('title')}: {e}") - except Exception as e: - logger.error(f"Unexpected error saving news item {news.get('title')}: {e}") - - self.conn.commit() - return count - - def get_daily_news(self, source: Optional[str] = None, limit: int = 100, days: int = 1) -> List[Dict]: - """获取最近 N 天的热点新闻""" - cursor = self.conn.cursor() - # 使用 crawl_time 过滤,保证结果的新鲜度 - time_threshold = (datetime.now().timestamp() - days * 86400) - time_threshold_str = datetime.fromtimestamp(time_threshold).isoformat() - - query = "SELECT * FROM daily_news WHERE crawl_time >= ?" - params = [time_threshold_str] - - if source: - query += " AND source = ?" - params.append(source) - - query += " ORDER BY crawl_time DESC, rank LIMIT ?" - params.append(limit) - - cursor.execute(query, params) - return [dict(row) for row in cursor.fetchall()] - - def lookup_reference_by_url(self, url: str) -> Optional[Dict[str, Any]]: - """Best-effort lookup of a source item by URL. - - This is used to render a stable bibliography from DB-backed metadata. - It searches both `daily_news` and `search_detail`. - """ - url = (url or "").strip() - if not url: - return None - - cursor = self.conn.cursor() - - try: - cursor.execute( - """ - SELECT title, source, publish_time, crawl_time, url - FROM daily_news - WHERE url = ? - ORDER BY crawl_time DESC - LIMIT 1 - """, - (url,), - ) - row = cursor.fetchone() - if row: - return dict(row) - except Exception: - pass - - try: - cursor.execute( - """ - SELECT title, source, publish_time, crawl_time, url - FROM search_detail - WHERE url = ? - ORDER BY crawl_time DESC - LIMIT 1 - """, - (url,), - ) - row = cursor.fetchone() - if row: - return dict(row) - except Exception: - pass - - return None - - def delete_news(self, news_id: str) -> bool: - """删除特定新闻""" - cursor = self.conn.cursor() - cursor.execute("DELETE FROM daily_news WHERE id = ?", (news_id,)) - self.conn.commit() - return cursor.rowcount > 0 - - def update_news_content(self, news_id: str, content: str = None, analysis: str = None) -> bool: - """更新新闻的内容或分析结果""" - cursor = self.conn.cursor() - updates = [] - params = [] - - if content is not None: - updates.append("content = ?") - params.append(content) - if analysis is not None: - updates.append("analysis = ?") - params.append(analysis) - - if not updates: - return False - - params.append(news_id) - query = f"UPDATE daily_news SET {', '.join(updates)} WHERE id = ?" - cursor.execute(query, params) - self.conn.commit() - return cursor.rowcount > 0 - - # --- 搜索缓存辅助 --- - - def get_search_cache(self, query_hash: str, ttl_seconds: Optional[int] = None) -> Optional[Dict]: - """获取搜索缓存 (优先查 search_detail)""" - cursor = self.conn.cursor() - - # 1. 尝试从 search_detail 获取展开的结构化数据 - cursor.execute(""" - SELECT * FROM search_detail - WHERE query_hash = ? - ORDER BY rank - """, (query_hash,)) - details = [dict(row) for row in cursor.fetchall()] - - if details: - # 检查 TTL (取第一条的时间) - first_time = datetime.fromisoformat(details[0]['crawl_time']) - if ttl_seconds and (datetime.now() - first_time).total_seconds() > ttl_seconds: - logger.info(f"⌛ Detailed cache expired for hash {query_hash}") - pass # Expired, fall through or return None? 
If Detail expired, Cache likely expired too. - # But let's check basic cache just in case metadata differs? - # Actually if details exist, we prefer them. If expired, we return None. - return None - - logger.info(f"✅ Hit detailed search cache for {query_hash} ({len(details)} items)") - # Reconstruct the expected 'results' list format for SearchTools - # SearchTools expects a list of dicts. - # We return a dict wrapper to match get_search_cache signature returning Dict usually containing 'results' string. - # But SearchTools logic: - # cache = db.get_search_cache(...) - # cached_data = json.loads(cache['results']) - - # To minimize SearchTools changes, we can return a dict mimicking the old structure - # OR Change SearchTools to handle list return. - # Let's return a special dict that SearchTools can recognize or just format it as before. - return {"results": json.dumps(details), "timestamp": details[0]['crawl_time']} - - # 2. Fallback to old table - cursor.execute("SELECT * FROM search_cache WHERE query_hash = ?", (query_hash,)) - row = cursor.fetchone() - - if not row: - return None - - row_dict = dict(row) - if ttl_seconds: - cache_time = datetime.fromisoformat(row_dict['timestamp']) - if (datetime.now() - cache_time).total_seconds() > ttl_seconds: - logger.info(f"⌛ Cache expired for hash {query_hash}") - return None - - return row_dict - - def save_search_cache(self, query_hash: str, query: str, engine: str, results: Union[str, List[Dict]]): - """保存搜索结果 (同时保存到 search_cache 和 search_detail)""" - cursor = self.conn.cursor() - current_time = datetime.now().isoformat() - - results_str = results if isinstance(results, str) else json.dumps(results) - - # 1. Save summary to search_cache - cursor.execute(""" - INSERT OR REPLACE INTO search_cache (query_hash, query, engine, results, timestamp) - VALUES (?, ?, ?, ?, ?) - """, (query_hash, query, engine, results_str, current_time)) - - # 2. Save details to search_detail if results is a list - if isinstance(results, list): - for item in results: - try: - item_id = item.get('id') or f"{hash(item.get('url', ''))}" - cursor.execute(""" - INSERT OR REPLACE INTO search_detail - (id, query_hash, rank, title, url, content, publish_time, crawl_time, sentiment_score, source, meta_data) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - str(item_id), - query_hash, - item.get('rank', 0), - item.get('title'), - item.get('url'), - item.get('content', ''), - item.get('publish_time'), - item.get('crawl_time') or current_time, - item.get('sentiment_score'), - item.get('source'), - json.dumps(item.get('meta_data', {})) - )) - except sqlite3.Error as e: - logger.error(f"Database error saving search detail {item.get('title')}: {e}") - except Exception as e: - logger.error(f"Unexpected error saving search detail {item.get('title')}: {e}") - - self.conn.commit() - - def find_similar_queries(self, query: str, limit: int = 5) -> List[Dict]: - """模糊搜索相似的已缓存查询""" - cursor = self.conn.cursor() - - # Simple fuzzy match: query in cached OR cached in query - q_wild = f"%{query}%" - cursor.execute(""" - SELECT query, query_hash, timestamp, results - FROM search_cache - WHERE query LIKE ? OR ? LIKE ('%' || query || '%') - ORDER BY timestamp DESC - LIMIT ? 
- """, (q_wild, query, limit)) - - return [dict(row) for row in cursor.fetchall()] - - def search_local_news(self, query: str, limit: int = 5) -> List[Dict]: - """从本地 daily_news 搜索相关新闻""" - cursor = self.conn.cursor() - q_wild = f"%{query}%" - # Search title and content - cursor.execute(""" - SELECT * FROM daily_news - WHERE title LIKE ? OR content LIKE ? - ORDER BY crawl_time DESC - LIMIT ? - """, (q_wild, q_wild, limit)) - return [dict(row) for row in cursor.fetchall()] - - # --- 股票数据操作 --- - - def save_stock_list(self, df: pd.DataFrame): - """保存股票列表到 stock_list 表""" - cursor = self.conn.cursor() - try: - # 清空旧表 - cursor.execute("DELETE FROM stock_list") - - # 批量插入 - data = df[['code', 'name']].to_dict('records') - cursor.executemany( - "INSERT INTO stock_list (code, name) VALUES (:code, :name)", - data - ) - self.conn.commit() - except sqlite3.Error as e: - logger.error(f"Database error saving stock list: {e}") - except Exception as e: - logger.error(f"Unexpected error saving stock list: {e}") - - def search_stock(self, query: str, limit: int = 5) -> List[Dict]: - """模糊搜索股票代码或名称""" - cursor = self.conn.cursor() - wild = f"%{query}%" - cursor.execute(""" - SELECT code, name FROM stock_list - WHERE code LIKE ? OR name LIKE ? - LIMIT ? - """, (wild, wild, limit)) - return [dict(row) for row in cursor.fetchall()] - - def get_stock_by_code(self, code: str) -> Optional[Dict[str, str]]: - """精确按代码获取股票信息。 - - Args: - code: 股票代码(A股6位 / 港股5位),必须为纯数字字符串。 - - Returns: - dict: {"code": str, "name": str} 或 None。 - """ - if not code: - return None - clean = "".join([c for c in str(code).strip() if c.isdigit()]) - if not clean: - return None - - cursor = self.conn.cursor() - cursor.execute("SELECT code, name FROM stock_list WHERE code = ? LIMIT 1", (clean,)) - row = cursor.fetchone() - return dict(row) if row else None - - def save_stock_prices(self, ticker: str, df: pd.DataFrame): - """保存股价历史数据""" - if df.empty: - return - - cursor = self.conn.cursor() - - # 确保 DataFrame 有必要的列 - required_cols = ['date', 'open', 'close', 'high', 'low', 'volume', 'change_pct'] - for col in required_cols: - if col not in df.columns: - logger.warning(f"Missing column {col} in stock data for {ticker}") - return - - try: - for _, row in df.iterrows(): - cursor.execute(""" - INSERT OR REPLACE INTO stock_prices - (ticker, date, open, close, high, low, volume, change_pct) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, ( - ticker, - row['date'], - row['open'], - row['close'], - row['high'], - row['low'], - row['volume'], - row['change_pct'] - )) - self.conn.commit() - except sqlite3.Error as e: - logger.error(f"Database error saving stock prices for {ticker}: {e}") - except Exception as e: - logger.error(f"Unexpected error saving stock prices for {ticker}: {e}") - - def get_stock_prices(self, ticker: str, start_date: str, end_date: str) -> pd.DataFrame: - """获取指定日期范围的股价数据""" - cursor = self.conn.cursor() - - cursor.execute(""" - SELECT * FROM stock_prices - WHERE ticker = ? AND date >= ? AND date <= ? 
- ORDER BY date - """, (ticker, start_date, end_date)) - - rows = cursor.fetchall() - if not rows: - return pd.DataFrame() - - columns = ['ticker', 'date', 'open', 'close', 'high', 'low', 'volume', 'change_pct'] - return pd.DataFrame([dict(row) for row in rows], columns=columns) - - def execute_query(self, query: str, params: tuple = ()) -> List[Any]: - """执行自定义 SQL 查询""" - try: - cursor = self.conn.cursor() - cursor.execute(query, params) - if query.strip().upper().startswith("SELECT"): - return cursor.fetchall() - else: - self.conn.commit() - return [] - except sqlite3.Error as e: - logger.error(f"SQL execution failed (Database error): {e}") - return [] - except Exception as e: - logger.error(f"SQL execution failed (Unexpected error): {e}") - return [] - - # --- 投资信号操作 (ISQ Framework) --- - - def save_signal(self, signal: Dict[str, Any]): - """保存投资信号""" - cursor = self.conn.cursor() - created_at = datetime.now().isoformat() - - cursor.execute(""" - INSERT OR REPLACE INTO signals - (signal_id, title, summary, transmission_chain, sentiment_score, - confidence, intensity, expected_horizon, price_in_status, - impact_tickers, industry_tags, sources, user_id, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - signal.get('signal_id'), - signal.get('title'), - signal.get('summary'), - json.dumps(signal.get('transmission_chain', [])), - signal.get('sentiment_score', 0.0), - signal.get('confidence', 0.0), - signal.get('intensity', 1), - signal.get('expected_horizon', 'T+0'), - signal.get('price_in_status', '未知'), - json.dumps(signal.get('impact_tickers', [])), - json.dumps(signal.get('industry_tags', [])), - json.dumps(signal.get('sources', [])), - signal.get('user_id'), - created_at - )) - self.conn.commit() - - def get_recent_signals(self, limit: int = 20, user_id: Optional[str] = None) -> List[Dict]: - """获取最近的投资信号""" - cursor = self.conn.cursor() - if user_id: - cursor.execute("SELECT * FROM signals WHERE user_id = ? 
ORDER BY created_at DESC LIMIT ?", (user_id, limit)) - else: - cursor.execute("SELECT * FROM signals ORDER BY created_at DESC LIMIT ?", (limit,)) - rows = cursor.fetchall() - - signals = [] - for row in rows: - d = dict(row) - # 解析 JSON 字段 - for field in ['transmission_chain', 'impact_tickers', 'industry_tags', 'sources']: - if d.get(field): - try: - d[field] = json.loads(d[field]) - except: - pass - signals.append(d) - return signals - - def close(self): - if self.conn: - self.conn.close() - logger.info("Database connection closed.") - diff --git a/skills/alphaear-reporter/scripts/utils/hybrid_search.py b/skills/alphaear-reporter/scripts/utils/hybrid_search.py deleted file mode 100644 index c597fee..0000000 --- a/skills/alphaear-reporter/scripts/utils/hybrid_search.py +++ /dev/null @@ -1,216 +0,0 @@ -import numpy as np -import os -from typing import List, Dict, Any, Optional, Union -from rank_bm25 import BM25Okapi -from loguru import logger -from sentence_transformers import SentenceTransformer -from sklearn.metrics.pairwise import cosine_similarity - -class HybridSearcher: - """ - 统一混合检索引擎 (Hybrid RAG) - 实现 BM25 (文本) + 向量 (语义) 的融合搜索 (RRF) - """ - - def __init__(self, data: List[Dict[str, Any]], text_fields: List[str] = ["title", "content"], model_name: str = None): - """ - 初始化搜索器 - - Args: - data: 数据列表,每个元素为 Dict - text_fields: 用于建立索引的文本字段 - model_name: 向量模型名称,默认使用 paraphrase-multilingual-MiniLM-L12-v2 - """ - self.data = data - self.text_fields = text_fields - self._corpus = [] - self._bm25 = None - self._vector_model = None - self._embeddings = None - self._fitted = False - self._vector_fitted = False - - # 默认模型 - self.model_name = model_name or os.getenv("EMBEDDING_MODEL", "paraphrase-multilingual-MiniLM-L12-v2") - - if data: - self._prepare_corpus() - self._fit_bm25() - # 延迟加载向量模型,仅在需要时或初始化时显式调用 - # self._fit_vector() - - def _prepare_corpus(self): - """准备语料库用于分词""" - import jieba # 使用 jieba 进行中文分词 - - self._corpus = [] - self._full_texts = [] - for item in self.data: - text = " ".join([str(item.get(field, "")) for field in self.text_fields]) - self._full_texts.append(text) - # 中文分词优化 - tokens = list(jieba.cut(text)) - self._corpus.append(tokens) - - def _fit_bm25(self): - """训练 BM25 模型""" - if self._corpus: - self._bm25 = BM25Okapi(self._corpus) - self._fitted = True - logger.info(f"✅ BM25 index fitted with {len(self.data)} documents") - - def _fit_vector(self): - """训练向量模型并生成 Embeddings""" - if not self.data: - return - - try: - logger.info(f"📡 Loading embedding model: {self.model_name}...") - self._vector_model = SentenceTransformer(self.model_name) - logger.info(f"🧠 Encoding {len(self._full_texts)} documents...") - self._embeddings = self._vector_model.encode(self._full_texts, show_progress_bar=False) - self._vector_fitted = True - logger.info("✅ Vector index fitted successfully") - except Exception as e: - logger.error(f"❌ Failed to fit vector index: {e}") - self._vector_fitted = False - - def _compute_rrf(self, rank_lists: List[List[int]], k: int = 60) -> List[tuple]: - """ - 计算 Reciprocal Rank Fusion (RRF) - - Args: - rank_lists: 多个排序后的索引列表 - k: RRF 常数,默认 60 - """ - scores = {} - for rank_list in rank_lists: - for rank, idx in enumerate(rank_list): - if idx not in scores: - scores[idx] = 0 - scores[idx] += 1.0 / (k + rank + 1) - - # 按分数排序 - sorted_indices = sorted(scores.items(), key=lambda x: x[1], reverse=True) - return sorted_indices - - def search(self, query: str, top_n: int = 5, use_vector: bool = False) -> List[Dict[str, Any]]: - """ - 执行混合搜索 - - Args: - query: 搜索关键词 - 
top_n: 返回结果数量 - use_vector: 是否启用向量搜索 - """ - if not self._fitted or not query: - return [] - - import jieba - query_tokens = list(jieba.cut(query)) - - # 1. BM25 搜索结果 - bm25_scores = self._bm25.get_scores(query_tokens) - bm25_rank = np.argsort(bm25_scores)[::-1].tolist() - - rank_lists = [bm25_rank] - - # 2. 向量搜索逻辑 - if use_vector: - if not self._vector_fitted: - self._fit_vector() - - if self._vector_fitted: - query_embedding = self._vector_model.encode([query], show_progress_bar=False) - similarities = cosine_similarity(query_embedding, self._embeddings)[0] - vector_rank = np.argsort(similarities)[::-1].tolist() - rank_lists.append(vector_rank) - else: - logger.warning("Vector search requested but model not fitted, falling back to BM25") - - # 3. 融合排序 (RRF) - if len(rank_lists) > 1: - rrf_results = self._compute_rrf(rank_lists) - # RRF 返回 (idx, score) 列表 - final_rank = [idx for idx, score in rrf_results] - else: - final_rank = bm25_rank - - # 返回前 top_n 条结果 - results = [self.data[idx].copy() for idx in final_rank[:top_n]] - - # 为每个结果注入相关性评分 - for i, res in enumerate(results): - try: - original_idx = final_rank[i] - res["_search_score"] = bm25_scores[original_idx] - if use_vector and self._vector_fitted: - res["_vector_score"] = float(similarities[original_idx]) - except: - res["_search_score"] = 0 - - return results - -class InMemoryRAG(HybridSearcher): - """专门用于 ReportAgent 跨章节检索的内存态 RAG""" - - def search(self, query: str, top_n: int = 3, use_vector: bool = True) -> List[Dict[str, Any]]: - """默认开启向量搜索的内存检索""" - return super().search(query, top_n=top_n, use_vector=use_vector) - - def update_data(self, new_data: List[Dict[str, Any]]): - """动态更新数据并重新训练索引""" - self.data = new_data - self._prepare_corpus() - self._fit_bm25() - # 如果之前已经加载过向量模型,则更新向量索引 - if self._vector_model: - self._fit_vector() - logger.info(f"🔄 InMemoryRAG updated with {len(new_data)} items") - -class LocalNewsSearch(HybridSearcher): - """持久态 RAG:检索数据库中的历史新闻""" - - def __init__(self, db_manager): - """ - Args: - db_manager: DatabaseManager 实例 - """ - self.db = db_manager - # 初始时不加载数据,需调用 load_history - super().__init__([], ["title", "content"]) - - def load_history(self, days: int = 30, limit: int = 1000): - """从数据库加载最近 N 天的新闻构建索引""" - try: - # 假设 db_manager 有 execute_query - query = f"SELECT title, content, publish_time, source FROM daily_news ORDER BY publish_time DESC LIMIT ?" 
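-            # Note: only `limit` is bound as a parameter below; the `days` argument is
-            # not applied as a publish_time filter in this query.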
- results = self.db.execute_query(query, (limit,)) - - data = [] - for row in results: - # 转换 Row 为 Dict - if hasattr(row, 'keys'): - item = dict(row) - else: - item = { - "title": row[0], - "content": row[1], - "publish_time": row[2], - "source": row[3] - } - data.append(item) - - self.data = data - self._prepare_corpus() - self._fit_bm25() - # 默认不立即训练向量,等到第一次搜索时按需训练 - logger.info(f"📚 LocalNewsSearch loaded {len(data)} items from history") - except Exception as e: - logger.error(f"Failed to load history for search: {e}") - - def search(self, query: str, top_n: int = 5, use_vector: bool = True) -> List[Dict[str, Any]]: - """执行本地历史搜索,默认开启向量搜索""" - if not self.data: - self.load_history() - return super().search(query, top_n=top_n, use_vector=use_vector) diff --git a/skills/alphaear-reporter/scripts/utils/json_utils.py b/skills/alphaear-reporter/scripts/utils/json_utils.py deleted file mode 100644 index c29aab2..0000000 --- a/skills/alphaear-reporter/scripts/utils/json_utils.py +++ /dev/null @@ -1,180 +0,0 @@ -import ast -import json -import re -from typing import Optional, Any -from loguru import logger - -def _strip_comments(text: str) -> str: - """ - Safely remove C-style comments (// and /* */) from JSON-like text, - preserving strings (including URLs like http://). - """ - result = [] - i = 0 - n = len(text) - in_string = False - escape = False - - while i < n: - char = text[i] - - if in_string: - if char == '\\': - escape = not escape - elif char == '"' and not escape: - in_string = False - else: - escape = False - result.append(char) - i += 1 - continue - - # Not in string - if char == '"': - in_string = True - result.append(char) - i += 1 - continue - - # Check for // comment - if i + 1 < n and text[i:i+2] == '//': - i += 2 - while i < n and text[i] != '\n': - i += 1 - continue - - # Check for /* comment - if i + 1 < n and text[i:i+2] == '/*': - i += 2 - while i + 1 < n and text[i:i+2] != '*/': - i += 1 - i += 2 - continue - - result.append(char) - i += 1 - - return ''.join(result) - -def extract_json(text: str) -> Optional[Any]: - """ - 更加鲁棒的 JSON 提取工具。 - 处理: - 1. Markdown 代码块 (```json ... ```) - 2. 首尾多余字符 - 3. 同一个文本中多个 JSON 对象 (仅提取第一个) - 4. 简单的 JSON 修复 (末尾逗号等) - 5. C 风格注释 (// 和 /* */) - """ - if not text: - return None - - # 1. 清理明显的 Markdown 包装 - text = text.strip() - - # 先尝试精确匹配 ```json ... ``` 或 ```...``` - md_match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', text, re.DOTALL) - if md_match: - text = md_match.group(1).strip() - elif text.startswith("```"): - # 回退:如果开头有 ``` 但没完整匹配 - text = re.sub(r'^```[a-z]*\n?', '', text) - text = re.sub(r'\n?```\s*$', '', text) - - # 2. 寻找第一个 JSON 起始符 { 或 [ - start_brace = text.find('{') - start_bracket = text.find('[') - - if start_brace == -1 and start_bracket == -1: - return None - - start_idx = start_brace if (start_bracket == -1 or (start_brace != -1 and start_brace < start_bracket)) else start_bracket - - # 2.5 预处理:修复一些极其常见的 LLM 错误 - potential_json = text[start_idx:].strip() - - # remove comments safely - potential_json = _strip_comments(potential_json) - - # b. 修复缺失开头引号的键: nodes": [ -> "nodes": [ - # 匹配模式: (空白或换行) 单词 紧跟引号和冒号 - potential_json = re.sub(r'([\{\,]\s*)([a-zA-Z_]\w*)\"\s*:', r'\1"\2":', potential_json) - - # c. 修复缺失末尾引号的键: "nodes: [ -> "nodes": [ - potential_json = re.sub(r'([\{\,]\s*)\"([a-zA-Z_]\w*)\s*:', r'\1"\2":', potential_json) - - # d. 修复完全缺失引号的键: nodes: [ -> "nodes": [ - # 注意避免匹配到像 http:// 这种内容,所以限定在 { 或 , 之后 - potential_json = re.sub(r'([\{\,]\s*)([a-zA-Z_]\w*)\s*:', r'\1"\2":', potential_json) - - # 3. 
使用 raw_decode 尝试解析 - decoder = json.JSONDecoder() - - # 首先尝试直接解析(不做任何预处理) - try: - obj = json.loads(potential_json) - return obj - except json.JSONDecodeError: - pass - - # 简单预处理:移除对象/列表末位多余逗号 - processed_json = re.sub(r',\s*([\]}])', r'\1', potential_json) - - try: - obj, end_pos = decoder.raw_decode(processed_json) - return obj - except json.JSONDecodeError: - pass - - # e. 修复未终止的字符串字面量问题:移除值中的实际换行符 - # LLM 可能在字符串值中生成包含真实 newline 的内容,导致 JSON 非法 - def fix_multiline_strings(s): - # 简单策略:将字符串值内的换行替换为空格 - lines = s.split('\n') - result = [] - in_string = False - for line in lines: - # 计算未转义的引号数 - quote_count = line.count('"') - line.count('\\"') - if in_string: - result[-1] += ' ' + line.strip() - else: - result.append(line) - - if quote_count % 2 == 1: - in_string = not in_string - return '\n'.join(result) - - fixed_json = fix_multiline_strings(processed_json) - - try: - obj, end_pos = decoder.raw_decode(fixed_json) - return obj - except json.JSONDecodeError: - try: - # 4. 尝试处理单引号问题 (JSON 规范要求双引号,但 LLM 常输出单引号) - # 这是一个简单的替换技巧,仅针对像 {'key': 'value'} 这样的结构 - # 注意:这可能会破坏包含单引号的字符串值,所以作为较后的回退 - fix_quotes = re.sub(r"'(.*?)':", r'"\1":', processed_json) # 修复键 - fix_quotes = re.sub(r":\s*'(.*?)'", r': "\1"', fix_quotes) # 修复简单值 - obj, end_pos = decoder.raw_decode(fix_quotes) - return obj - except (json.JSONDecodeError, TypeError): - try: - # 5. 使用 ast.literal_eval 作为终极回退 (处理 Python 字典格式) - # 提取第一个匹配的括号对内容 - # 寻找匹配的 { } - stack = [] - for i, char in enumerate(potential_json): - if char == '{': stack.append('{') - elif char == '}': - if stack: stack.pop() - if not stack: - content = potential_json[:i+1] - return ast.literal_eval(content) - except (ValueError, SyntaxError, MemoryError) as e: - logger.warning(f"All JSON extraction attempts failed: {e}") - except Exception as e: - logger.error(f"Unexpected error during JSON extraction: {e}") - - return None diff --git a/skills/alphaear-reporter/scripts/utils/llm/capability.py b/skills/alphaear-reporter/scripts/utils/llm/capability.py deleted file mode 100644 index d07ca4f..0000000 --- a/skills/alphaear-reporter/scripts/utils/llm/capability.py +++ /dev/null @@ -1,85 +0,0 @@ -import os -from typing import Optional, List, Dict, Any -from agno.agent import Agent -from agno.models.base import Model -from loguru import logger -from ..llm.factory import get_model - - -def test_tool_call_support(model: Model) -> bool: - """ - 测试模型是否支持原生的 Tool Call (Function Calling)。 - 通过尝试执行一个简单的加法工具来验证。 - """ - - def get_current_weather(location: str): - """获取指定地点的天气""" - return f"{location} 的天气是晴天,25度。" - - test_agent = Agent( - model=model, - tools=[get_current_weather], - instructions="请调用工具查询北京的天气,并直接返回工具的输出结果。", - ) - - try: - # 运行一个简单的任务,观察是否触发了 tool_call - response = test_agent.run("北京天气怎么样?") - - # 检查 response 中是否包含 tool_calls - # Agno 的 RunResponse 对象通常包含 messages,我们可以检查最后几条消息 - has_tool_call = False - for msg in response.messages: - if hasattr(msg, "tool_calls") and msg.tool_calls: - has_tool_call = True - break - - if has_tool_call: - logger.info(f"✅ Model {model.id} supports native tool calling.") - return True - else: - # 如果没有 tool_calls 但返回了正确答案,可能是模型通过纯文本模拟了工具调用(ReAct) - # 或者根本没用工具。对于原生支持的判断,我们坚持要求有 tool_calls 结构。 - logger.warning( - f"⚠️ Model {model.id} did NOT use native tool calling structure." 
- ) - return False - - except Exception as e: - logger.error(f"❌ Error testing tool call for {model.id}: {e}") - return False - - -class ModelCapabilityRegistry: - """ - 模型能力注册表,用于缓存和管理不同模型的能力测试结果。 - """ - - _cache = {} - - @classmethod - def get_capabilities( - cls, provider: str, model_id: str, **kwargs - ) -> Dict[str, bool]: - key = f"{provider}:{model_id}" - if key not in cls._cache: - logger.info(f"🔍 Testing capabilities for {key}...") - model = get_model(provider, model_id, **kwargs) - supports_tool_call = test_tool_call_support(model) - cls._cache[key] = {"supports_tool_call": supports_tool_call} - return cls._cache[key] - - -if __name__ == "__main__": - import os - from skills._env_loader import load_unified_env - - load_unified_env() - - # 测试当前配置的模型 - p = os.getenv("LLM_PROVIDER", "minimax") - m = os.getenv("LLM_MODEL", "Qwen") - - print(f"Testing {p}/{m}...") - res = ModelCapabilityRegistry.get_capabilities(p, m) - print(f"Result: {res}") diff --git a/skills/alphaear-reporter/scripts/utils/llm/factory.py b/skills/alphaear-reporter/scripts/utils/llm/factory.py deleted file mode 100644 index 09b6ea5..0000000 --- a/skills/alphaear-reporter/scripts/utils/llm/factory.py +++ /dev/null @@ -1,114 +0,0 @@ -import os -from agno.models.openai import OpenAIChat -from agno.models.ollama import Ollama -from agno.models.dashscope import DashScope -from agno.models.deepseek import DeepSeek -from agno.models.openrouter import OpenRouter - -def get_model(model_provider: str, model_id: str, **kwargs): - """ - Factory to get the appropriate LLM model. - - Args: - model_provider: "openai", "ollama", "deepseek" - model_id: The specific model ID (e.g., "gpt-4o", "llama3", "deepseek-chat") - **kwargs: Additional arguments for the model constructor - """ - if model_provider == "openai": - return OpenAIChat(id=model_id, **kwargs) - - elif model_provider == "ollama": - return Ollama(id=model_id, **kwargs) - - elif model_provider == "deepseek": - # DeepSeek is OpenAI compatible - api_key = os.getenv("DEEPSEEK_API_KEY") - if not api_key: - print("Warning: DEEPSEEK_API_KEY not set.") - - return DeepSeek( - id=model_id, - api_key=api_key, - **kwargs - ) - elif model_provider == "dashscope": - api_key = os.getenv("DASHSCOPE_API_KEY") - if not api_key: - print("Warning: DASHSCOPE_API_KEY not set.") - - return DashScope( - id=model_id, - base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", - api_key=api_key, - **kwargs - ) - elif model_provider == 'openrouter': - api_key = os.getenv("OPENROUTER_API_KEY") - if not api_key: - print('Warning: OPENROUTER_API_KEY not set.') - - return OpenRouter( - id=model_id, - api_key=api_key, - **kwargs - ) - - elif model_provider == 'zai': - api_key = os.getenv("ZAI_KEY_API") - if not api_key: - print('Warning: ZAI_KEY_API not set.') - - # role_map to ensure compatibility. - default_role_map = { - "system": "system", - "user": "user", - "assistant": "assistant", - "tool": "tool", - "model": "assistant", - } - - # Allow callers to override role_map via kwargs, otherwise use default - role_map = kwargs.pop("role_map", default_role_map) - - return OpenAIChat( - id=model_id, - base_url="https://api.z.ai/api/paas/v4", - api_key=api_key, - timeout=60, - role_map=role_map, - extra_body={"enable_thinking": False}, # TODO: one more setting for thinking - **kwargs - ) - - elif model_provider == 'ust': - api_key = os.getenv("UST_KEY_API") - if not api_key: - print('Warning: UST_KEY_API not set.') - - # Some UST-compatible endpoints expect the standard OpenAI role names - # (e.g. 
"system", "user", "assistant") rather than Agno's default - # mapping which maps "system" -> "developer". Provide an explicit - # role_map to ensure compatibility. - default_role_map = { - "system": "system", - "user": "user", - "assistant": "assistant", - "tool": "tool", - "model": "assistant", - } - - # Allow callers to override role_map via kwargs, otherwise use default - role_map = kwargs.pop("role_map", default_role_map) - - return OpenAIChat( - id=model_id, - api_key=api_key, - base_url=os.getenv("UST_URL"), - role_map=role_map, - extra_body={"enable_thinking": False}, # TODO: one more setting for thinking - **kwargs - ) - - else: - raise ValueError(f"Unknown model provider: {model_provider}") - diff --git a/skills/alphaear-reporter/scripts/utils/llm/router.py b/skills/alphaear-reporter/scripts/utils/llm/router.py deleted file mode 100644 index 8c69958..0000000 --- a/skills/alphaear-reporter/scripts/utils/llm/router.py +++ /dev/null @@ -1,80 +0,0 @@ -import os -from typing import Optional, List, Dict, Any, Union -from agno.models.base import Model -from loguru import logger -from ..llm.factory import get_model -from ..llm.capability import ModelCapabilityRegistry -from skills._env_loader import load_unified_env - -load_unified_env() - - -class ModelRouter: - """ - 模型路由管理器 - - 功能: - 1. 管理“推理/写作模型” (Reasoning Model) 和“工具调用模型” (Tool Model)。 - 2. 根据任务需求自动选择合适的模型。 - """ - - def __init__(self): - # 默认从环境变量读取 - self.reasoning_provider = os.getenv( - "REASONING_MODEL_PROVIDER", os.getenv("LLM_PROVIDER", "openai") - ) - self.reasoning_id = os.getenv( - "REASONING_MODEL_ID", os.getenv("LLM_MODEL", "gpt-4o") - ) - self.reasoning_host = os.getenv("REASONING_MODEL_HOST", os.getenv("LLM_HOST")) - - self.tool_provider = os.getenv("TOOL_MODEL_PROVIDER", self.reasoning_provider) - self.tool_id = os.getenv("TOOL_MODEL_ID", self.reasoning_id) - self.tool_host = os.getenv("TOOL_MODEL_HOST", self.reasoning_host) - - self._reasoning_model = None - self._tool_model = None - - logger.info( - f"🤖 ModelRouter initialized: Reasoning={self.reasoning_id} ({self.reasoning_host or 'default'}), Tool={self.tool_id} ({self.tool_host or 'default'})" - ) - - def get_reasoning_model(self, **kwargs) -> Model: - if not self._reasoning_model: - # 优先使用路由配置的 host - if self.reasoning_host and "host" not in kwargs: - kwargs["host"] = self.reasoning_host - self._reasoning_model = get_model( - self.reasoning_provider, self.reasoning_id, **kwargs - ) - return self._reasoning_model - - def get_tool_model(self, **kwargs) -> Model: - if not self._tool_model: - # 优先使用路由配置的 host - if self.tool_host and "host" not in kwargs: - kwargs["host"] = self.tool_host - - # 检查 tool_model 是否真的支持 tool call - caps = ModelCapabilityRegistry.get_capabilities( - self.tool_provider, self.tool_id, **kwargs - ) - if not caps["supports_tool_call"]: - logger.warning( - f"⚠️ Configured tool model {self.tool_id} might not support native tool calls! Consider using ReAct mode or a different model." 
- ) - - self._tool_model = get_model(self.tool_provider, self.tool_id, **kwargs) - return self._tool_model - - def get_model_for_agent(self, has_tools: bool = False, **kwargs) -> Model: - """ - 根据 Agent 是否包含工具来返回合适的模型。 - """ - if has_tools: - return self.get_tool_model(**kwargs) - return self.get_reasoning_model(**kwargs) - - -# 全局单例 -router = ModelRouter() diff --git a/skills/alphaear-reporter/scripts/utils/logging_setup.py b/skills/alphaear-reporter/scripts/utils/logging_setup.py deleted file mode 100644 index 9a2ca62..0000000 --- a/skills/alphaear-reporter/scripts/utils/logging_setup.py +++ /dev/null @@ -1,45 +0,0 @@ -import os -import sys -from datetime import datetime -from typing import Optional - -from loguru import logger - - -def setup_file_logging( - run_id: str, - log_dir: str = "logs", - level: str = "INFO", - retention: str = "10 days", - rotation: str = "20 MB", -) -> str: - """Configure Loguru to log to stderr + a per-run file. - - Returns the log file path. - """ - os.makedirs(log_dir, exist_ok=True) - - # Remove default handler to avoid duplicate logs. - logger.remove() - - # Console - logger.add(sys.stderr, level=level, backtrace=False, diagnose=False) - - # File (safe for multi-thread via enqueue) - log_path = os.path.join(log_dir, f"signalflux_{run_id}.log") - logger.add( - log_path, - level=level, - rotation=rotation, - retention=retention, - enqueue=True, - backtrace=True, - diagnose=False, - encoding="utf-8", - ) - return log_path - - -def make_run_id(prefix: Optional[str] = None) -> str: - ts = datetime.now().strftime("%Y%m%d_%H%M%S") - return f"{prefix}_{ts}" if prefix else ts diff --git a/skills/alphaear-reporter/scripts/utils/news_tools.py b/skills/alphaear-reporter/scripts/utils/news_tools.py deleted file mode 100644 index e833e2e..0000000 --- a/skills/alphaear-reporter/scripts/utils/news_tools.py +++ /dev/null @@ -1,256 +0,0 @@ -import requests -from requests.exceptions import RequestException, Timeout -import json -import time -from datetime import datetime -from typing import List, Dict, Optional -from loguru import logger -from .database_manager import DatabaseManager -from .content_extractor import ContentExtractor - -class NewsNowTools: - """热点新闻获取工具 - 接入 NewsNow API 与 Jina 内容提取""" - - BASE_URL = "https://newsnow.busiyi.world" - SOURCES = { - # 金融类 - "cls": "财联社", - "wallstreetcn": "华尔街见闻", - "xueqiu": "雪球热榜", - # 综合/社交 - "weibo": "微博热搜", - "zhihu": "知乎热榜", - "baidu": "百度热搜", - "toutiao": "今日头条", - "douyin": "抖音热榜", - "thepaper": "澎湃新闻", - # 科技类 - "36kr": "36氪", - "ithome": "IT之家", - "v2ex": "V2EX", - "juejin": "掘金", - "hackernews": "Hacker News", - } - - - def __init__(self, db: DatabaseManager): - self.db = db - self.user_agent = ( - "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " - "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36" - ) - self.extractor = ContentExtractor() - # Simple in-memory cache: source_id -> {"time": timestamp, "data": []} - self._cache = {} - - def fetch_hot_news(self, source_id: str, count: int = 15, fetch_content: bool = False) -> List[Dict]: - """ - 从指定新闻源获取热点新闻列表(支持5分钟缓存)。 - """ - # 1. 
Check cache validity (5 minutes) - cache_key = f"{source_id}_{count}" - cached = self._cache.get(cache_key) - now = time.time() - - if cached and (now - cached["time"] < 300): - logger.info(f"⚡ Using cached news for {source_id} (Age: {int(now - cached['time'])}s)") - return cached["data"] - - try: - url = f"{self.BASE_URL}/api/s?id={source_id}" - response = requests.get(url, headers={"User-Agent": self.user_agent}, timeout=30) - if response.status_code == 200: - data = response.json() - items = data.get("items", [])[:count] - processed_items = [] - for i, item in enumerate(items, 1): - item_url = item.get("url", "") - content = "" - if fetch_content and item_url: - content = self.extractor.extract_with_jina(item_url) or "" - - processed_items.append({ - "id": item.get("id") or f"{source_id}_{int(time.time())}_{i}", - "source": source_id, - "rank": i, - "title": item.get("title", ""), - "url": item_url, - "content": content, - "publish_time": item.get("publish_time"), - "meta_data": item.get("extra", {}) - }) - - # Update Cache - self._cache[cache_key] = {"time": now, "data": processed_items} - logger.info(f"✅ Fetched and cached news for {source_id}") - - self.db.save_daily_news(processed_items) - return processed_items - else: - logger.error(f"NewsNow API Error: {response.status_code}") - # Fallback to stale cache if available - if cached: - logger.warning(f"⚠️ API failed, using stale cache for {source_id}") - return cached["data"] - return [] - except Timeout: - logger.error(f"Timeout fetching hot news from {source_id}") - if cached: - logger.warning(f"⚠️ Timeout, using stale cache for {source_id}") - return cached["data"] - return [] - except RequestException as e: - logger.error(f"Network error fetching hot news from {source_id}: {e}") - if cached: - logger.warning(f"⚠️ Network check failed, using stale cache for {source_id}") - return cached["data"] - return [] - except json.JSONDecodeError: - logger.error(f"Failed to parse JSON response from NewsNow for {source_id}") - return [] - except Exception as e: - logger.error(f"Unexpected error fetching hot news from {source_id}: {e}") - return [] - - def fetch_news_content(self, url: str) -> Optional[str]: - """ - 使用 Jina Reader 抓取指定 URL 的网页正文内容。 - - Args: - url: 需要抓取内容的完整网页 URL,必须以 http:// 或 https:// 开头。 - - Returns: - 提取的网页正文内容 (Markdown 格式),如果失败则返回 None。 - """ - return self.extractor.extract_with_jina(url) - - def get_unified_trends(self, sources: Optional[List[str]] = None) -> str: - """ - 获取多平台综合热点报告,自动聚合多个新闻源的热门内容。 - - Args: - sources: 要扫描的新闻源列表。可选值按类别: - **金融类**: "cls", "wallstreetcn", "xueqiu" - **综合类**: "weibo", "zhihu", "baidu", "toutiao", "douyin", "thepaper" - **科技类**: "36kr", "ithome", "v2ex", "juejin", "hackernews" - - Returns: - 格式化的 Markdown 热点汇总报告,包含各平台 Top 10 热点标题和链接。 - """ - sources = sources or ["weibo", "zhihu", "wallstreetcn"] - all_news = [] - for src in sources: - all_news.extend(self.fetch_hot_news(src)) - time.sleep(0.2) - - if not all_news: - return "❌ 未能获取到热点数据" - - report = f"# 实时全网热点汇总 ({datetime.now().strftime('%Y-%m-%d %H:%M')})\n\n" - for src in sources: - - src_name = self.SOURCES.get(src, src) - report += f"### 🔥 {src_name}\n" - src_news = [n for n in all_news if n['source'] == src] - for n in src_news[:10]: - report += f"- {n['title']} ([链接]({n['url']}))\n" - report += "\n" - - return report - - -class PolymarketTools: - """Polymarket 预测市场数据工具 - 获取热门预测市场反映公众情绪和预期""" - - BASE_URL = "https://gamma-api.polymarket.com" - - def __init__(self, db: DatabaseManager): - self.db = db - self.user_agent = "Mozilla/5.0 
(Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36" - - def get_active_markets(self, limit: int = 20) -> List[Dict]: - """ - 获取活跃的预测市场,用于分析公众情绪和预期。 - - 预测市场数据可以反映: - - 公众对重大事件的预期概率 - - 市场情绪和风险偏好 - - 热门话题的关注度 - - Args: - limit: 获取的市场数量,默认 20 个。 - - Returns: - 包含预测市场信息的列表,每个市场包含: - - question: 预测问题 - - outcomes: 可能的结果 - - outcomePrices: 各结果的概率价格 - - volume: 交易量 - """ - try: - response = requests.get( - f"{self.BASE_URL}/markets", - params={"active": "true", "closed": "false", "limit": limit}, - headers={"User-Agent": self.user_agent, "Accept": "application/json"}, - timeout=30 - ) - - if response.status_code == 200: - markets = response.json() - result = [] - for m in markets: - result.append({ - "id": m.get("id"), - "question": m.get("question"), - "slug": m.get("slug"), - "outcomes": m.get("outcomes"), - "outcomePrices": m.get("outcomePrices"), - "volume": m.get("volume"), - "liquidity": m.get("liquidity"), - }) - logger.info(f"✅ 获取 {len(result)} 个预测市场") - return result - else: - logger.warning(f"⚠️ Polymarket API 返回 {response.status_code}") - return [] - except Timeout: - logger.error("Timeout fetching Polymarket markets") - return [] - except RequestException as e: - logger.error(f"Network error fetching Polymarket markets: {e}") - return [] - except json.JSONDecodeError: - logger.error("Failed to parse JSON response from Polymarket") - return [] - except Exception as e: - logger.error(f"Unexpected error fetching Polymarket markets: {e}") - return [] - - def get_market_summary(self, limit: int = 10) -> str: - """ - 获取预测市场摘要报告,用于了解当前热门话题和公众预期。 - - Args: - limit: 获取的市场数量 - - Returns: - 格式化的预测市场报告 - """ - markets = self.get_active_markets(limit) - if not markets: - return "❌ 无法获取 Polymarket 数据" - - report = f"# 🔮 Polymarket 热门预测 ({datetime.now().strftime('%Y-%m-%d %H:%M')})\n\n" - for i, m in enumerate(markets, 1): - question = m.get("question", "Unknown") - prices = m.get("outcomePrices", []) - volume = m.get("volume", 0) - - report += f"**{i}. {question}**\n" - if prices: - report += f" 概率: {prices}\n" - if volume: - report += f" 交易量: ${float(volume):,.0f}\n" - report += "\n" - - return report diff --git a/skills/alphaear-reporter/scripts/utils/predictor/evaluation.py b/skills/alphaear-reporter/scripts/utils/predictor/evaluation.py deleted file mode 100644 index 26c5df7..0000000 --- a/skills/alphaear-reporter/scripts/utils/predictor/evaluation.py +++ /dev/null @@ -1,137 +0,0 @@ -import os -import sys -import torch -import pandas as pd -import numpy as np -import glob -from loguru import logger -from datetime import datetime, timedelta - -# Setup paths -KRONOS_DIR = os.path.dirname(os.path.abspath(__file__)) -SRC_DIR = os.path.dirname(os.path.dirname(KRONOS_DIR)) -if SRC_DIR not in sys.path: - sys.path.insert(0, SRC_DIR) - -from ..kronos.auto_synthesis_training import AutoSynthesisTrainer -from ..kronos.model import KronosPredictor -from ..visualizer import VisualizerTools -from ..schema.models import ForecastResult, KLinePoint - -class NewsModelEvaluator: - def __init__(self, model_path=None): - self.trainer = AutoSynthesisTrainer() - self.device = self.trainer.device - - if model_path is None: - # Try to find the latest model in exports/models - model_files = glob.glob(os.path.join(SRC_DIR, "exports/models/*.pt")) - if not model_files: - logger.warning("⚠️ No trained models found in exports/models/. 
Using base model (zero-init proj).") - else: - model_path = max(model_files, key=os.path.getctime) - - if model_path: - self.load_weights(model_path) - - def load_weights(self, path): - logger.info(f"🔄 Loading model weights from {path}...") - checkpoint = torch.load(path, map_location=self.device) - self.trainer.model.news_proj.load_state_dict(checkpoint['news_proj_state_dict']) - logger.success("✅ News projection layer loaded.") - - def evaluate_range(self, start_idx=100, end_idx=200, pred_len=5): - # 1. Fetch Tickers - res = self.trainer.db.execute_query("SELECT code FROM stock_list") - all_tickers = [row['code'] for row in res] - test_tickers = all_tickers[start_idx:end_idx] - - if not test_tickers: - logger.error(f"No tickers found in range {start_idx}-{end_idx}") - return - - logger.info(f"🚀 Evaluating News Model on stocks {start_idx} to {end_idx}...") - - # 2. Discover Shocks - shocks = self.trainer.discover_shocks(test_tickers, pred_len=pred_len) - - # 3. Associate News & Predict - self.trainer.model.eval() - predictor = KronosPredictor(self.trainer.model, self.trainer.tokenizer, device=self.device) - - save_dir = os.path.join(SRC_DIR, "exports/evaluation_results") - os.makedirs(save_dir, exist_ok=True) - - count = 0 - for shock in shocks: - summary = self.trainer.find_reason_and_verify(shock) - if not summary: - continue - - logger.info(f"📈 Testing shock: {shock['ticker']} on {shock['date']}") - - # Embedding news - news_emb = self.trainer.embedder.encode(summary) - - # Prediction - h = shock['history'] - t = shock['target'] - actuals = t['close'].values[:pred_len] - - x_ts = pd.to_datetime(h['date']) - future_dates = pd.date_range(start=x_ts.iloc[-1] + timedelta(days=1), periods=pred_len, freq='B') - y_ts = pd.Series(future_dates) - - # A. Base Prediction (No news) - p_base = predictor.predict(h, x_ts, y_ts, pred_len=pred_len, news_emb=None, verbose=False) - - # B. News-Aware Prediction - p_news = predictor.predict(h, x_ts, y_ts, pred_len=pred_len, news_emb=news_emb, verbose=False) - - # Calculate Improvement - b_preds = p_base['close'].values[:len(actuals)] - n_preds = p_news['close'].values[:len(actuals)] - b_mae = np.mean(np.abs(b_preds - actuals)) - n_mae = np.mean(np.abs(n_preds - actuals)) - improvement = (b_mae - n_mae) / (b_mae + 1e-6) * 100 - - # C. Visualize - try: - def to_kp_list(preds_df): - points = [] - for idx, row in preds_df.iterrows(): - points.append(KLinePoint( - date=str(idx)[:10], open=row['open'], high=row['high'], - low=row['low'], close=row['close'], volume=row.get('volume', 0) - )) - return points - - forecast_obj = ForecastResult( - ticker=shock['ticker'], - base_forecast=to_kp_list(p_base), - adjusted_forecast=to_kp_list(p_news), - rationale=summary - ) - - chart = VisualizerTools.generate_stock_chart( - df=h, ticker=shock['ticker'], - title=f"Test Eval: {shock['ticker']} ({shock['date']}) Imp: {improvement:.1f}%", - forecast=forecast_obj, - ground_truth=t[['date', 'open', 'high', 'low', 'close', 'volume']] - ) - - safe_date = shock['date'].replace("-", "") - filename = f"test_{shock['ticker']}_{safe_date}.html" - VisualizerTools.render_chart_to_file(chart, os.path.join(save_dir, filename)) - - logger.success(f"📊 Result for {shock['ticker']} saved. Base MAE: {b_mae:.4f}, News MAE: {n_mae:.4f}") - count += 1 - except Exception as e: - logger.error(f"Visualization failed: {e}") - - logger.info(f"🏁 Finished evaluation. {count} cases visualized in {save_dir}") - -if __name__ == "__main__": - # If you have a specific model, pass the path here. 
Otherwise it picks the latest. - evaluator = NewsModelEvaluator() - evaluator.evaluate_range(start_idx=100, end_idx=200, pred_len=1) diff --git a/skills/alphaear-reporter/scripts/utils/predictor/kline_generate.py b/skills/alphaear-reporter/scripts/utils/predictor/kline_generate.py deleted file mode 100644 index 3224c21..0000000 --- a/skills/alphaear-reporter/scripts/utils/predictor/kline_generate.py +++ /dev/null @@ -1,196 +0,0 @@ -# Ref: https://github.com/shiyu-coder/Kronos - -from model import Kronos, KronosTokenizer, KronosPredictor -import pandas as pd -import sqlite3 -import torch -import matplotlib.pyplot as plt -import matplotlib.gridspec as gridspec -from pandas.tseries.offsets import BusinessDay -import numpy as np - -def get_device(): - device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" - print(f"Using device: {device}") - return device - -def load_predictor(): - tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base") - model = Kronos.from_pretrained("NeoQuasar/Kronos-base") - device = get_device() - tokenizer = tokenizer.to(device) - model = model.to(device) - return KronosPredictor(model, tokenizer, device=device, max_context=512) - -def load_data(ticker="002111", db_path="AlphaEar/data/signal_flux.db"): - with sqlite3.connect(db_path) as conn: - df = pd.read_sql_query(f"SELECT * FROM stock_prices WHERE ticker = '{ticker}'", conn) - df['date'] = pd.to_datetime(df['date']) - df = df.sort_values('date').reset_index(drop=True) - return df - -def plot_kline_matplotlib(ax, ax_vol, dates, df, label_suffix="", color_up='#ef4444', color_down='#22c55e', alpha=1.0, is_prediction=False): - """ - 绘制 K 线图和成交量 - """ - # X axis mapping to integers for consistent spacing - x = np.arange(len(dates)) - - # K-line data - opens = df['open'].values - closes = df['close'].values - highs = df['high'].values - lows = df['low'].values - volumes = df['volume'].values - - # Width of the candlestick - width = 0.6 - - for i in range(len(x)): - color = color_up if closes[i] >= opens[i] else color_down - linestyle = '--' if is_prediction else '-' - - # Wick - ax.vlines(x[i], lows[i], highs[i], color=color, linewidth=1, alpha=alpha, linestyle=linestyle) - - # Body - rect_bottom = min(opens[i], closes[i]) - rect_height = abs(opens[i] - closes[i]) - if rect_height == 0: rect_height = 0.001 # Visual hair - - ax.add_patch(plt.Rectangle((x[i] - width/2, rect_bottom), width, rect_height, - edgecolor=color, facecolor=color if not is_prediction else 'none', - alpha=alpha, linewidth=1, linestyle=linestyle)) - - # Volume - ax_vol.bar(x[i], volumes[i], color=color, alpha=alpha * 0.5, width=width) - -def render_comparison_chart(history_df, actual_df, pred_df, title): - """ - 渲染组合图:历史 K 线 + 真值 K 线 + 预测 K 线 - """ - # Combine all dates for X axis - all_dates = pd.concat([history_df['date'], actual_df['date'] if actual_df is not None else pred_df.index.to_series()]).unique() - all_dates = sorted(all_dates) - date_to_idx = {date: i for i, date in enumerate(all_dates)} - - fig = plt.figure(figsize=(14, 8), facecolor='white') - gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.1) - ax_main = fig.add_subplot(gs[0]) - ax_vol = fig.add_subplot(gs[1], sharex=ax_main) - - # 1. Plot History - hist_indices = [date_to_idx[d] for d in history_df['date']] - # We use a custom x for plotting to ensure continuity - plot_kline_matplotlib(ax_main, ax_vol, history_df['date'], history_df, alpha=0.8) - - offset = len(history_df) - - # 2. 
Plot Actual if exists - if actual_df is not None: - # Shift indices - actual_x = np.arange(len(actual_df)) + offset - # Plotting manually to handle offset - for i in range(len(actual_df)): - idx = actual_x[i] - row = actual_df.iloc[i] - color = '#ef4444' if row['close'] >= row['open'] else '#22c55e' - ax_main.vlines(idx, row['low'], row['high'], color=color, linewidth=1, alpha=0.9) - ax_main.add_patch(plt.Rectangle((idx - 0.3, min(row['open'], row['close'])), 0.6, abs(row['open']-row['close']), - edgecolor=color, facecolor=color, alpha=0.9)) - ax_vol.bar(idx, row['volume'], color=color, alpha=0.4) - - # 3. Plot Prediction - pred_x = np.arange(len(pred_df)) + offset - for i in range(len(pred_df)): - idx = pred_x[i] - row = pred_df.iloc[i] - color = '#ff8c00' # Orange for prediction to distinguish - ax_main.vlines(idx, row['low'], row['high'], color=color, linewidth=1.5, linestyle='--') - ax_main.add_patch(plt.Rectangle((idx - 0.3, min(row['open'], row['close'])), 0.6, abs(row['open']-row['close']), - edgecolor=color, facecolor='none', linewidth=1.5, linestyle='--')) - # Plot secondary prediction line for close - if i == 0: - # Connect to history - ax_main.plot([offset-1, idx], [history_df['close'].iloc[-1], row['close']], color=color, linestyle='--', alpha=0.6) - elif i > 0: - ax_main.plot([idx-1, idx], [pred_df['close'].iloc[i-1], row['close']], color=color, linestyle='--', alpha=0.6) - - # Styling - ax_main.set_title(title, fontsize=14, fontweight='bold') - ax_main.grid(True, linestyle=':', alpha=0.6) - ax_vol.grid(True, linestyle=':', alpha=0.6) - ax_vol.set_ylabel('Volume') - ax_main.set_ylabel('Price') - - # Set X ticks - step = max(1, len(all_dates) // 10) - ax_vol.set_xticks(np.arange(0, len(all_dates), step)) - ax_vol.set_xticklabels([all_dates[i].strftime('%Y-%m-%d') for i in range(0, len(all_dates), step)], rotation=45) - - plt.tight_layout() - plt.show() - plt.close() - -def run_backtest(df, predictor, lookback, pred_len, start_index=0): - total_len = len(df) - history_start = start_index - history_end = start_index + lookback - pred_start = history_end - - available_pred_len = total_len - pred_start - if available_pred_len <= 0: return - actual_pred_len = min(pred_len, available_pred_len) - pred_end = pred_start + actual_pred_len - - x_df = df.iloc[history_start : history_end].copy() - y_true_df = df.iloc[pred_start : pred_end].copy() - y_timestamp = y_true_df['date'] - - print(f"Backtesting: {x_df['date'].iloc[0].date()} to {y_timestamp.iloc[-1].date()}") - - pred_df = predictor.predict( - df=x_df[['open', 'high', 'low', 'close', 'volume']], - x_timestamp=x_df['date'], - y_timestamp=y_timestamp, - pred_len=actual_pred_len, - T=1.0, top_p=0.9, sample_count=1 - ) - - render_comparison_chart(x_df, y_true_df, pred_df, f"Backtest: {TICKER} K-Line Comparison") - -def run_forecast(df, predictor, lookback, pred_len): - if len(df) < lookback: return - x_df = df.iloc[-lookback:].copy() - last_date = x_df['date'].iloc[-1] - future_dates = pd.date_range(start=last_date + BusinessDay(1), periods=pred_len, freq='B') - future_dates = pd.Series(future_dates) - - print(f"Forecasting: Starting from {future_dates.iloc[0].date()}") - - pred_df = predictor.predict( - df=x_df[['open', 'high', 'low', 'close', 'volume']], - x_timestamp=x_df['date'], - y_timestamp=future_dates, - pred_len=pred_len, - T=1.0, top_p=0.9, sample_count=1 - ) - - render_comparison_chart(x_df, None, pred_df, f"Forecast: {TICKER} Future K-Line") - -if __name__ == "__main__": - LOOKBACK = 20 - PRED_LEN = 10 - TICKER = '002111' 
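A minimal sketch, not part of the deleted kline_generate.py, assuming run_backtest were also made to return its y_true_df and pred_df: it shows how the backtest's close-price error could be summarized numerically in addition to the rendered comparison chart. The helper name evaluate_backtest is hypothetical.

import numpy as np
import pandas as pd

def evaluate_backtest(y_true_df: pd.DataFrame, pred_df: pd.DataFrame) -> dict:
    # Sketch only: align predictions to the ground-truth dates before comparing closes.
    actual = y_true_df.set_index('date')['close']
    predicted = pred_df['close'].reindex(actual.index)
    err = predicted - actual
    # MAE in price units, MAPE in percent of the actual close.
    return {
        "mae": float(np.mean(np.abs(err))),
        "mape_pct": float(np.mean(np.abs(err / actual)) * 100),
    }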
- - pred_model = load_predictor() - stock_data = load_data(TICKER) - - total_rows = len(stock_data) - backtest_start = max(0, total_rows - LOOKBACK - PRED_LEN - 10) # Leave some space to see trend - - print("\n--- Running Backtest ---") - run_backtest(stock_data, pred_model, LOOKBACK, PRED_LEN, start_index=backtest_start) - - print("\n--- Running Forecast ---") - run_forecast(stock_data, pred_model, LOOKBACK, PRED_LEN) \ No newline at end of file diff --git a/skills/alphaear-reporter/scripts/utils/predictor/model/__init__.py b/skills/alphaear-reporter/scripts/utils/predictor/model/__init__.py deleted file mode 100644 index d10e200..0000000 --- a/skills/alphaear-reporter/scripts/utils/predictor/model/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -from .kronos import KronosTokenizer, Kronos, KronosPredictor - -model_dict = { - 'kronos_tokenizer': KronosTokenizer, - 'kronos': Kronos, - 'kronos_predictor': KronosPredictor -} - - -def get_model_class(model_name): - if model_name in model_dict: - return model_dict[model_name] - else: - print(f"Model {model_name} not found in model_dict") - raise NotImplementedError - diff --git a/skills/alphaear-reporter/scripts/utils/predictor/model/kronos.py b/skills/alphaear-reporter/scripts/utils/predictor/model/kronos.py deleted file mode 100644 index cf8bece..0000000 --- a/skills/alphaear-reporter/scripts/utils/predictor/model/kronos.py +++ /dev/null @@ -1,676 +0,0 @@ -import numpy as np -import pandas as pd -import torch -from huggingface_hub import PyTorchModelHubMixin -import sys - -from tqdm import trange - -sys.path.append("../") -from model.module import * - - -class KronosTokenizer(nn.Module, PyTorchModelHubMixin): - """ - KronosTokenizer module for tokenizing input data using a hybrid quantization approach. - - This tokenizer utilizes a combination of encoder and decoder Transformer blocks - along with the Binary Spherical Quantization (BSQuantizer) to compress and decompress input data. - - Args: - d_in (int): Input dimension. - d_model (int): Model dimension. - n_heads (int): Number of attention heads. - ff_dim (int): Feed-forward dimension. - n_enc_layers (int): Number of encoder layers. - n_dec_layers (int): Number of decoder layers. - ffn_dropout_p (float): Dropout probability for feed-forward networks. - attn_dropout_p (float): Dropout probability for attention mechanisms. - resid_dropout_p (float): Dropout probability for residual connections. - s1_bits (int): Number of bits for the pre token in BSQuantizer. - s2_bits (int): Number of bits for the post token in BSQuantizer. - beta (float): Beta parameter for BSQuantizer. - gamma0 (float): Gamma0 parameter for BSQuantizer. - gamma (float): Gamma parameter for BSQuantizer. - zeta (float): Zeta parameter for BSQuantizer. - group_size (int): Group size parameter for BSQuantizer. 
- - """ - - def __init__(self, d_in, d_model, n_heads, ff_dim, n_enc_layers, n_dec_layers, ffn_dropout_p, attn_dropout_p, resid_dropout_p, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size): - - super().__init__() - self.d_in = d_in - self.d_model = d_model - self.n_heads = n_heads - self.ff_dim = ff_dim - self.enc_layers = n_enc_layers - self.dec_layers = n_dec_layers - self.ffn_dropout_p = ffn_dropout_p - self.attn_dropout_p = attn_dropout_p - self.resid_dropout_p = resid_dropout_p - - self.s1_bits = s1_bits - self.s2_bits = s2_bits - self.codebook_dim = s1_bits + s2_bits # Total dimension of the codebook after quantization - self.embed = nn.Linear(self.d_in, self.d_model) - self.head = nn.Linear(self.d_model, self.d_in) - - # Encoder Transformer Blocks - self.encoder = nn.ModuleList([ - TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p) - for _ in range(self.enc_layers - 1) - ]) - # Decoder Transformer Blocks - self.decoder = nn.ModuleList([ - TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p) - for _ in range(self.dec_layers - 1) - ]) - self.quant_embed = nn.Linear(in_features=self.d_model, out_features=self.codebook_dim) # Linear layer before quantization - self.post_quant_embed_pre = nn.Linear(in_features=self.s1_bits, out_features=self.d_model) # Linear layer after quantization (pre part - s1 bits) - self.post_quant_embed = nn.Linear(in_features=self.codebook_dim, out_features=self.d_model) # Linear layer after quantization (full codebook) - self.tokenizer = BSQuantizer(self.s1_bits, self.s2_bits, beta, gamma0, gamma, zeta, group_size) # BSQuantizer module - - def forward(self, x): - """ - Forward pass of the KronosTokenizer. - - Args: - x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in). - - Returns: - tuple: A tuple containing: - - tuple: (z_pre, z) - Reconstructed outputs from decoder with s1_bits and full codebook respectively, - both of shape (batch_size, seq_len, d_in). - - torch.Tensor: bsq_loss - Loss from the BSQuantizer. - - torch.Tensor: quantized - Quantized representation from BSQuantizer. - - torch.Tensor: z_indices - Indices from the BSQuantizer. - """ - z = self.embed(x) - - for layer in self.encoder: - z = layer(z) - - z = self.quant_embed(z) # (B, T, codebook) - - bsq_loss, quantized, z_indices = self.tokenizer(z) - - quantized_pre = quantized[:, :, :self.s1_bits] # Extract the first part of quantized representation (s1_bits) - z_pre = self.post_quant_embed_pre(quantized_pre) - - z = self.post_quant_embed(quantized) - - # Decoder layers (for pre part - s1 bits) - for layer in self.decoder: - z_pre = layer(z_pre) - z_pre = self.head(z_pre) - - # Decoder layers (for full codebook) - for layer in self.decoder: - z = layer(z) - z = self.head(z) - - return (z_pre, z), bsq_loss, quantized, z_indices - - def indices_to_bits(self, x, half=False): - """ - Converts indices to bit representations and scales them. - - Args: - x (torch.Tensor): Indices tensor. - half (bool, optional): Whether to process only half of the codebook dimension. Defaults to False. - - Returns: - torch.Tensor: Bit representation tensor. 
- """ - if half: - x1 = x[0] # Assuming x is a tuple of indices if half is True - x2 = x[1] - mask = 2 ** torch.arange(self.codebook_dim//2, device=x1.device, dtype=torch.long) # Create a mask for bit extraction - x1 = (x1.unsqueeze(-1) & mask) != 0 # Extract bits for the first half - x2 = (x2.unsqueeze(-1) & mask) != 0 # Extract bits for the second half - x = torch.cat([x1, x2], dim=-1) # Concatenate the bit representations - else: - mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long) # Create a mask for bit extraction - x = (x.unsqueeze(-1) & mask) != 0 # Extract bits - - x = x.float() * 2 - 1 # Convert boolean to bipolar (-1, 1) - q_scale = 1. / (self.codebook_dim ** 0.5) # Scaling factor - x = x * q_scale - return x - - def encode(self, x, half=False): - """ - Encodes the input data into quantized indices. - - Args: - x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in). - half (bool, optional): Whether to use half quantization in BSQuantizer. Defaults to False. - - Returns: - torch.Tensor: Quantized indices from BSQuantizer. - """ - z = self.embed(x) - for layer in self.encoder: - z = layer(z) - z = self.quant_embed(z) - - bsq_loss, quantized, z_indices = self.tokenizer(z, half=half, collect_metrics=False) - return z_indices - - def decode(self, x, half=False): - """ - Decodes quantized indices back to the input data space. - - Args: - x (torch.Tensor): Quantized indices tensor. - half (bool, optional): Whether the indices were generated with half quantization. Defaults to False. - - Returns: - torch.Tensor: Reconstructed output tensor of shape (batch_size, seq_len, d_in). - """ - quantized = self.indices_to_bits(x, half) - z = self.post_quant_embed(quantized) - for layer in self.decoder: - z = layer(z) - z = self.head(z) - return z - - -class Kronos(nn.Module, PyTorchModelHubMixin): - """ - Kronos Model. - - Args: - s1_bits (int): Number of bits for pre tokens. - s2_bits (int): Number of bits for post tokens. - n_layers (int): Number of Transformer blocks. - d_model (int): Dimension of the model's embeddings and hidden states. - n_heads (int): Number of attention heads in the MultiheadAttention layers. - ff_dim (int): Dimension of the feedforward network in the Transformer blocks. - ffn_dropout_p (float): Dropout probability for the feedforward network. - attn_dropout_p (float): Dropout probability for the attention layers. - resid_dropout_p (float): Dropout probability for residual connections. - token_dropout_p (float): Dropout probability for token embeddings. - learn_te (bool): Whether to use learnable temporal embeddings. 
- """ - - def __init__(self, s1_bits, s2_bits, n_layers, d_model, n_heads, ff_dim, ffn_dropout_p, attn_dropout_p, resid_dropout_p, token_dropout_p, learn_te, news_dim=None): - super().__init__() - self.s1_bits = s1_bits - self.s2_bits = s2_bits - self.n_layers = n_layers - self.d_model = d_model - self.n_heads = n_heads - self.learn_te = learn_te - self.ff_dim = ff_dim - self.ffn_dropout_p = ffn_dropout_p - self.attn_dropout_p = attn_dropout_p - self.resid_dropout_p = resid_dropout_p - self.token_dropout_p = token_dropout_p - self.news_dim = news_dim - - self.s1_vocab_size = 2 ** self.s1_bits - self.token_drop = nn.Dropout(self.token_dropout_p) - self.embedding = HierarchicalEmbedding(self.s1_bits, self.s2_bits, self.d_model) - self.time_emb = TemporalEmbedding(self.d_model, self.learn_te) - self.transformer = nn.ModuleList([ - TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p) - for _ in range(self.n_layers) - ]) - self.norm = RMSNorm(self.d_model) - self.dep_layer = DependencyAwareLayer(self.d_model) - self.head = DualHead(self.s1_bits, self.s2_bits, self.d_model) - - if self.news_dim is not None: - self.news_proj = nn.Linear(self.news_dim, self.d_model) - else: - self.news_proj = None - - self.apply(self._init_weights) - - def _init_weights(self, module): - - if isinstance(module, nn.Linear): - nn.init.xavier_normal_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - nn.init.normal_(module.weight, mean=0, std=self.embedding.d_model ** -0.5) - elif isinstance(module, nn.LayerNorm): - nn.init.ones_(module.weight) - nn.init.zeros_(module.bias) - elif isinstance(module, RMSNorm): - nn.init.ones_(module.weight) - - def forward(self, s1_ids, s2_ids, stamp=None, padding_mask=None, use_teacher_forcing=False, s1_targets=None, news_emb=None): - """ - Args: - s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len] - s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len] - stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None. - padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None. - use_teacher_forcing (bool, optional): Whether to use teacher forcing for s1 decoding. Defaults to False. - s1_targets (torch.Tensor, optional): Target s1 token IDs for teacher forcing. Shape: [batch_size, seq_len]. Defaults to None. - news_emb (torch.Tensor, optional): News embedding tensor. Shape: [batch_size, news_dim]. Defaults to None. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: - - s1 logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size] - - s2_logits: Logits for s2 token predictions, conditioned on s1. 
Shape: [batch_size, seq_len, s2_vocab_size] - """ - x = self.embedding([s1_ids, s2_ids]) - if stamp is not None: - time_embedding = self.time_emb(stamp) - x = x + time_embedding - x = self.token_drop(x) - - for layer in self.transformer: - x = layer(x, key_padding_mask=padding_mask) - - x = self.norm(x) - - if news_emb is not None and self.news_proj is not None: - news_bias = self.news_proj(news_emb).unsqueeze(1) # [B, 1, d_model] - x = x + news_bias - - s1_logits = self.head(x) - - if use_teacher_forcing: - sibling_embed = self.embedding.emb_s1(s1_targets) - else: - s1_probs = F.softmax(s1_logits.detach(), dim=-1) - sample_s1_ids = torch.multinomial(s1_probs.view(-1, self.s1_vocab_size), 1).view(s1_ids.shape) - sibling_embed = self.embedding.emb_s1(sample_s1_ids) - - x2 = self.dep_layer(x, sibling_embed, key_padding_mask=padding_mask) # Dependency Aware Layer: Condition on s1 embeddings - s2_logits = self.head.cond_forward(x2) - return s1_logits, s2_logits - - def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None, news_emb=None): - """ - Decodes only the s1 tokens. - - This method performs a forward pass to predict only s1 tokens. It returns the s1 logits - and the context representation from the Transformer, which can be used for subsequent s2 decoding. - - Args: - s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len] - s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len] - stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None. - padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None. - news_emb (torch.Tensor, optional): News embedding tensor. Shape: [batch_size, news_dim]. Defaults to None. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: - - s1 logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size] - - context: Context representation from the Transformer. Shape: [batch_size, seq_len, d_model] - """ - x = self.embedding([s1_ids, s2_ids]) - if stamp is not None: - time_embedding = self.time_emb(stamp) - x = x + time_embedding - x = self.token_drop(x) - - for layer in self.transformer: - x = layer(x, key_padding_mask=padding_mask) - - x = self.norm(x) - - if news_emb is not None and self.news_proj is not None: - news_bias = self.news_proj(news_emb).unsqueeze(1) # [B, 1, d_model] - x = x + news_bias - - s1_logits = self.head(x) - return s1_logits, x - - def decode_s2(self, context, s1_ids, padding_mask=None): - """ - Decodes the s2 tokens, conditioned on the context and s1 tokens. - - This method decodes s2 tokens based on a pre-computed context representation (typically from `decode_s1`) - and the s1 token IDs. It uses the dependency-aware layer and the conditional s2 head to predict s2 tokens. - - Args: - context (torch.Tensor): Context representation from the transformer (output of decode_s1). - Shape: [batch_size, seq_len, d_model] - s1_ids (torch.torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len] - padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None. - - Returns: - torch.Tensor: s2 logits. 
Shape: [batch_size, seq_len, s2_vocab_size] - """ - sibling_embed = self.embedding.emb_s1(s1_ids) - x2 = self.dep_layer(context, sibling_embed, key_padding_mask=padding_mask) - return self.head.cond_forward(x2) - - -def top_k_top_p_filtering( - logits, - top_k: int = 0, - top_p: float = 1.0, - filter_value: float = -float("Inf"), - min_tokens_to_keep: int = 1, -): - """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering - Args: - logits: logits distribution shape (batch size, vocabulary size) - if top_k > 0: keep only top k tokens with highest probability (top-k filtering). - if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). - Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) - Make sure we keep at least min_tokens_to_keep per batch example in the output - From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 - """ - if top_k > 0: - top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check - # Remove all tokens with a probability less than the last token of the top-k - indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] - logits[indices_to_remove] = filter_value - return logits - - if top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True) - cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) - - # Remove tokens with cumulative probability above the threshold (token with 0 are kept) - sorted_indices_to_remove = cumulative_probs > top_p - if min_tokens_to_keep > 1: - # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) - sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 - # Shift the indices to the right to keep also the first token above the threshold - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 - - # scatter sorted tensors to original indexing - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) - logits[indices_to_remove] = filter_value - return logits - - -def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_logits=True): - logits = logits / temperature - if top_k is not None or top_p is not None: - if top_k > 0 or top_p < 1.0: - logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) - - probs = F.softmax(logits, dim=-1) - - if not sample_logits: - _, x = top_k(probs, k=1, dim=-1) - else: - x = torch.multinomial(probs, num_samples=1) - - return x - - -def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False, news_emb=None): - with torch.no_grad(): - x = torch.clip(x, -clip, clip) - - device = x.device - x = x.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x.size(1), x.size(2)).to(device) - x_stamp = x_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x_stamp.size(1), x_stamp.size(2)).to(device) - y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), y_stamp.size(2)).to(device) - - x_token = tokenizer.encode(x, half=True) - - initial_seq_len = x.size(1) - batch_size = x_token[0].size(0) - total_seq_len = initial_seq_len + pred_len - full_stamp = torch.cat([x_stamp, y_stamp], dim=1) - - generated_pre = x_token[0].new_empty(batch_size, pred_len) - generated_post = x_token[1].new_empty(batch_size, pred_len) - - pre_buffer = 
x_token[0].new_zeros(batch_size, max_context) - post_buffer = x_token[1].new_zeros(batch_size, max_context) - buffer_len = min(initial_seq_len, max_context) - if buffer_len > 0: - start_idx = max(0, initial_seq_len - max_context) - pre_buffer[:, :buffer_len] = x_token[0][:, start_idx:start_idx + buffer_len] - post_buffer[:, :buffer_len] = x_token[1][:, start_idx:start_idx + buffer_len] - - if verbose: - ran = trange - else: - ran = range - for i in ran(pred_len): - current_seq_len = initial_seq_len + i - window_len = min(current_seq_len, max_context) - - if current_seq_len <= max_context: - input_tokens = [ - pre_buffer[:, :window_len], - post_buffer[:, :window_len] - ] - else: - input_tokens = [pre_buffer, post_buffer] - - context_end = current_seq_len - context_start = max(0, context_end - max_context) - current_stamp = full_stamp[:, context_start:context_end, :].contiguous() - - s1_logits, context = model.decode_s1(input_tokens[0], input_tokens[1], current_stamp, news_emb=news_emb) - s1_logits = s1_logits[:, -1, :] - sample_pre = sample_from_logits(s1_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True) - - s2_logits = model.decode_s2(context, sample_pre) - s2_logits = s2_logits[:, -1, :] - sample_post = sample_from_logits(s2_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True) - - generated_pre[:, i] = sample_pre.squeeze(-1) - generated_post[:, i] = sample_post.squeeze(-1) - - if current_seq_len < max_context: - pre_buffer[:, current_seq_len] = sample_pre.squeeze(-1) - post_buffer[:, current_seq_len] = sample_post.squeeze(-1) - else: - pre_buffer.copy_(torch.roll(pre_buffer, shifts=-1, dims=1)) - post_buffer.copy_(torch.roll(post_buffer, shifts=-1, dims=1)) - pre_buffer[:, -1] = sample_pre.squeeze(-1) - post_buffer[:, -1] = sample_post.squeeze(-1) - - full_pre = torch.cat([x_token[0], generated_pre], dim=1) - full_post = torch.cat([x_token[1], generated_post], dim=1) - - context_start = max(0, total_seq_len - max_context) - input_tokens = [ - full_pre[:, context_start:total_seq_len].contiguous(), - full_post[:, context_start:total_seq_len].contiguous() - ] - z = tokenizer.decode(input_tokens, half=True) - z = z.reshape(-1, sample_count, z.size(1), z.size(2)) - preds = z.cpu().numpy() - preds = np.mean(preds, axis=1) - - return preds - - -def calc_time_stamps(x_timestamp): - time_df = pd.DataFrame() - time_df['minute'] = x_timestamp.dt.minute - time_df['hour'] = x_timestamp.dt.hour - time_df['weekday'] = x_timestamp.dt.weekday - time_df['day'] = x_timestamp.dt.day - time_df['month'] = x_timestamp.dt.month - return time_df - - -class KronosPredictor: - - def __init__(self, model, tokenizer, device="cuda:0", max_context=512, clip=5): - self.tokenizer = tokenizer - self.model = model - self.max_context = max_context - self.clip = clip - self.price_cols = ['open', 'high', 'low', 'close'] - self.vol_col = 'volume' - self.amt_vol = 'amount' - self.time_cols = ['minute', 'hour', 'weekday', 'day', 'month'] - self.device = device - - self.tokenizer = self.tokenizer.to(self.device) - self.model = self.model.to(self.device) - - def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, news_emb=None): - - x_tensor = torch.from_numpy(np.array(x).astype(np.float32)).to(self.device) - x_stamp_tensor = torch.from_numpy(np.array(x_stamp).astype(np.float32)).to(self.device) - y_stamp_tensor = torch.from_numpy(np.array(y_stamp).astype(np.float32)).to(self.device) - - preds = auto_regressive_inference(self.tokenizer, self.model, x_tensor, 
x_stamp_tensor, y_stamp_tensor, self.max_context, pred_len, - self.clip, T, top_k, top_p, sample_count, verbose, news_emb=news_emb) - preds = preds[:, -pred_len:, :] - return preds - - def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True, news_emb=None): - - if not isinstance(df, pd.DataFrame): - raise ValueError("Input must be a pandas DataFrame.") - - if not all(col in df.columns for col in self.price_cols): - raise ValueError(f"Price columns {self.price_cols} not found in DataFrame.") - - df = df.copy() - if self.vol_col not in df.columns: - df[self.vol_col] = 0.0 # Fill missing volume with zeros - df[self.amt_vol] = 0.0 # Fill missing amount with zeros - if self.amt_vol not in df.columns and self.vol_col in df.columns: - df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1) - - if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any(): - raise ValueError("Input DataFrame contains NaN values in price or volume columns.") - - x_time_df = calc_time_stamps(x_timestamp) - y_time_df = calc_time_stamps(y_timestamp) - - x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(np.float32) - x_stamp = x_time_df.values.astype(np.float32) - y_stamp = y_time_df.values.astype(np.float32) - - x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0) - - x = (x - x_mean) / (x_std + 1e-5) - x = np.clip(x, -self.clip, self.clip) - - x = x[np.newaxis, :] - x_stamp = x_stamp[np.newaxis, :] - y_stamp = y_stamp[np.newaxis, :] - - if news_emb is not None: - news_emb_tensor = torch.from_numpy(np.array(news_emb).astype(np.float32)).to(self.device) - # Ensure batch dimension for news_emb if only one sample - if news_emb_tensor.ndim == 1: - news_emb_tensor = news_emb_tensor.unsqueeze(0) - else: - news_emb_tensor = None - - preds = self.generate(x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, news_emb=news_emb_tensor) - - preds = preds.squeeze(0) - preds = preds * (x_std + 1e-5) + x_mean - - pred_df = pd.DataFrame(preds, columns=self.price_cols + [self.vol_col, self.amt_vol], index=y_timestamp) - return pred_df - - - def predict_batch(self, df_list, x_timestamp_list, y_timestamp_list, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True): - """ - Perform parallel (batch) prediction on multiple time series. All series must have the same historical length and prediction length (pred_len). - - Args: - df_list (List[pd.DataFrame]): List of input DataFrames, each containing price columns and optional volume/amount columns. - x_timestamp_list (List[pd.DatetimeIndex or Series]): List of timestamps corresponding to historical data, length should match the number of rows in each DataFrame. - y_timestamp_list (List[pd.DatetimeIndex or Series]): List of future prediction timestamps, length should equal pred_len. - pred_len (int): Number of prediction steps. - T (float): Sampling temperature. - top_k (int): Top-k filtering threshold. - top_p (float): Top-p (nucleus sampling) threshold. - sample_count (int): Number of parallel samples per series, automatically averaged internally. - verbose (bool): Whether to display autoregressive progress. - - Returns: - List[pd.DataFrame]: List of prediction results in the same order as input, each DataFrame contains - `open, high, low, close, volume, amount` columns, indexed by corresponding `y_timestamp`. 
- """ - # Basic validation - if not isinstance(df_list, (list, tuple)) or not isinstance(x_timestamp_list, (list, tuple)) or not isinstance(y_timestamp_list, (list, tuple)): - raise ValueError("df_list, x_timestamp_list, y_timestamp_list must be list or tuple types.") - if not (len(df_list) == len(x_timestamp_list) == len(y_timestamp_list)): - raise ValueError("df_list, x_timestamp_list, y_timestamp_list must have consistent lengths.") - - num_series = len(df_list) - - x_list = [] - x_stamp_list = [] - y_stamp_list = [] - means = [] - stds = [] - seq_lens = [] - y_lens = [] - - for i in range(num_series): - df = df_list[i] - if not isinstance(df, pd.DataFrame): - raise ValueError(f"Input at index {i} is not a pandas DataFrame.") - if not all(col in df.columns for col in self.price_cols): - raise ValueError(f"DataFrame at index {i} is missing price columns {self.price_cols}.") - - df = df.copy() - if self.vol_col not in df.columns: - df[self.vol_col] = 0.0 - df[self.amt_vol] = 0.0 - if self.amt_vol not in df.columns and self.vol_col in df.columns: - df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1) - - if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any(): - raise ValueError(f"DataFrame at index {i} contains NaN values in price or volume columns.") - - x_timestamp = x_timestamp_list[i] - y_timestamp = y_timestamp_list[i] - - x_time_df = calc_time_stamps(x_timestamp) - y_time_df = calc_time_stamps(y_timestamp) - - x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(np.float32) - x_stamp = x_time_df.values.astype(np.float32) - y_stamp = y_time_df.values.astype(np.float32) - - if x.shape[0] != x_stamp.shape[0]: - raise ValueError(f"Inconsistent lengths at index {i}: x has {x.shape[0]} vs x_stamp has {x_stamp.shape[0]}.") - if y_stamp.shape[0] != pred_len: - raise ValueError(f"y_timestamp length at index {i} should equal pred_len={pred_len}, got {y_stamp.shape[0]}.") - - x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0) - x_norm = (x - x_mean) / (x_std + 1e-5) - x_norm = np.clip(x_norm, -self.clip, self.clip) - - x_list.append(x_norm) - x_stamp_list.append(x_stamp) - y_stamp_list.append(y_stamp) - means.append(x_mean) - stds.append(x_std) - - seq_lens.append(x_norm.shape[0]) - y_lens.append(y_stamp.shape[0]) - - # Require all series to have consistent historical and prediction lengths for batch processing - if len(set(seq_lens)) != 1: - raise ValueError(f"Parallel prediction requires all series to have consistent historical lengths, got: {seq_lens}") - if len(set(y_lens)) != 1: - raise ValueError(f"Parallel prediction requires all series to have consistent prediction lengths, got: {y_lens}") - - x_batch = np.stack(x_list, axis=0).astype(np.float32) # (B, seq_len, feat) - x_stamp_batch = np.stack(x_stamp_list, axis=0).astype(np.float32) # (B, seq_len, time_feat) - y_stamp_batch = np.stack(y_stamp_list, axis=0).astype(np.float32) # (B, pred_len, time_feat) - - preds = self.generate(x_batch, x_stamp_batch, y_stamp_batch, pred_len, T, top_k, top_p, sample_count, verbose) - # preds: (B, pred_len, feat) - - pred_dfs = [] - for i in range(num_series): - preds_i = preds[i] * (stds[i] + 1e-5) + means[i] - pred_df = pd.DataFrame(preds_i, columns=self.price_cols + [self.vol_col, self.amt_vol], index=y_timestamp_list[i]) - pred_dfs.append(pred_df) - - return pred_dfs diff --git a/skills/alphaear-reporter/scripts/utils/predictor/model/module.py b/skills/alphaear-reporter/scripts/utils/predictor/model/module.py deleted file mode 100644 index 
20b29b5..0000000 --- a/skills/alphaear-reporter/scripts/utils/predictor/model/module.py +++ /dev/null @@ -1,562 +0,0 @@ -import math - -from einops import rearrange, reduce -import torch -import torch.nn as nn -from torch.autograd import Function -import torch.nn.functional as F - - -class DifferentiableEntropyFunction(Function): - @staticmethod - def forward(ctx, zq, basis, K, eps): - zb = (zq + 1) / 2 - zi = ((zb * basis).sum(-1)).to(torch.int64) - cnt = torch.scatter_reduce(torch.zeros(2 ** K, device=zq.device, dtype=zq.dtype), - 0, - zi.flatten(), - torch.ones_like(zi.flatten()).to(zq.dtype), - 'sum') - prob = (cnt + eps) / (cnt + eps).sum() - H = -(prob * torch.log(prob)).sum() - ctx.save_for_backward(zq, zi, prob) - ctx.K = K - return H - - @staticmethod - def backward(ctx, grad_output): - zq, zi, prob = ctx.saved_tensors - grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K - reord_grad = grad_array[zi.flatten()].reshape(zi.shape) - grad_input = reord_grad.unsqueeze(-1) * zq - return grad_input, None, None, None, None - - -def codebook_entropy(zq, basis, K, eps=1e-4): - return DifferentiableEntropyFunction.apply(zq, basis, K, eps) - - -class BinarySphericalQuantizer(nn.Module): - def __init__(self, embed_dim, beta, gamma0, gamma, zeta, - input_format='bchw', - soft_entropy=True, group_size=9, - persample_entropy_compute='analytical', - cb_entropy_compute='group', - l2_norm=True, - inv_temperature=1): - """ - Paper link: https://arxiv.org/pdf/2406.07548.pdf - Here we use the official implementation of the BinarySphericalQuantizer. - """ - super().__init__() - self.embed_dim = embed_dim - self.beta = beta # loss weight for commit loss - self.gamma0 = gamma0 # loss weight for entropy penalty - self.gamma = gamma # loss weight for entropy penalty - self.zeta = zeta # loss weight for entire entropy penalty - self.input_format = input_format - assert self.embed_dim % group_size == 0, "embed_dim must be divisible by group_size" - self.num_groups = self.embed_dim // group_size - self.group_size = group_size - assert persample_entropy_compute in ['group', 'analytical'], "persample_entropy_compute must be either 'group' or 'analytical'" - assert cb_entropy_compute in ['group', 'nce'], "cb_entropy_compute must be either 'group' or 'nce'" - self.persample_entropy_compute = persample_entropy_compute - self.cb_entropy_compute = cb_entropy_compute - self.l2_norm = l2_norm - self.inv_temperature = inv_temperature - - self.register_buffer('basis', 2 ** torch.arange(embed_dim - 1, -1, -1)) - self.register_buffer('group_basis', 2 ** torch.arange(group_size - 1, -1, -1)) - - self.num_dimensions = 2 ** embed_dim - self.bits_per_index = embed_dim - - # we only need to keep the codebook portion up to the group size - # because we approximate the H loss with this subcode - group_codes = torch.arange(2 ** self.group_size) - group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:] - self.register_buffer('group_codebook', group_codebook, persistent=False) - - self.soft_entropy = soft_entropy # soft_entropy: Sec 3.2 of https://arxiv.org/pdf/1911.05894.pdf - - def quantize(self, z): - assert z.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}" - - zhat = torch.where(z > 0, - torch.tensor(1, dtype=z.dtype, device=z.device), - torch.tensor(-1, dtype=z.dtype, device=z.device)) - return z + (zhat - z).detach() - - def forward(self, z, collect_metrics=True): - # if self.input_format == 'bchw': - # z = rearrange(z, 'b c h w -> b h w c') - zq 
= self.quantize(z) - - q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. - - zq = zq * q_scale - - if not collect_metrics: - return zq, zq.new_zeros(()), {} - - indices = self.codes_to_indexes(zq.detach()) - group_indices = self.codes_to_group_indexes(zq.detach()) - if not self.training: - used_codes = torch.unique(indices, return_counts=False) - else: - used_codes = None - - if self.soft_entropy: - persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z) - entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy - else: - zb_by_sample = ((zq + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32) - persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample) - cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim) - entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy - - # commit loss - commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1)) - - # if self.input_format == 'bchw': - # zq = rearrange(zq, 'b h w c -> b c h w') - - return ( - zq, - commit_loss + self.zeta * entropy_penalty / self.inv_temperature, - {"H": cb_entropy, "used_codes": used_codes, "indices": indices, "group_indices": group_indices, - "avg_prob": avg_prob} - ) - - def soft_entropy_loss(self, z): - # if we divide the code in subgroups of size group_size, the codebook will be of size 2 ** group_size - # the sub-code is the last group_size bits of the full code - group_code_book = self.group_codebook / (self.embed_dim ** 0.5 if self.l2_norm else 1) - divided_z = rearrange(z, '... (g c) -> ... g c', c=self.group_size) - - # we calculate the distance between the divided_z and the codebook for each subgroup - distance = - 2 * torch.einsum('... g c, d c ->... g d', divided_z, group_code_book) - prob = (-distance * self.inv_temperature).softmax(dim=-1) - if self.persample_entropy_compute == 'analytical': - if self.l2_norm: - p = torch.sigmoid(-4 * z / (self.embed_dim ** 0.5) * self.inv_temperature) - else: - p = torch.sigmoid(-4 * z * self.inv_temperature) - prob = torch.stack([p, 1 - p], dim=-1) - per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() - else: - per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() - - # macro average of the probability of each subgroup - avg_prob = reduce(prob, '... g d ->g d', 'mean') - codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False) - - # the approximation of the entropy is the sum of the entropy of each subgroup - return per_sample_entropy, codebook_entropy.sum(), avg_prob - - def get_hard_per_sample_entropy(self, zb_by_sample): - probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1] - persample_entropy = - probs_per_dim * torch.log(probs_per_dim + 1e-8) - (1 - probs_per_dim) * torch.log(1 - probs_per_dim + 1e-8) - persample_entropy = persample_entropy.sum(-1) - return persample_entropy.mean() - - def codes_to_indexes(self, zhat): - """Converts a `code` to an index in the codebook. - Args: - zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1} - """ - assert zhat.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}" - return ((zhat + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64) - - def codes_to_group_indexes(self, zhat): - """Converts a `code` to a list of indexes (in groups) in the codebook. - Args: - zhat: A tensor of shape (B, ..., C) containing the codes. 
must be in {-1, 1} - """ - zhat_in_group = rearrange(zhat, 'b ... (g c) -> b ... g c', c=self.group_size) - return ((zhat_in_group + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64) - - def indexes_to_codes(self, indices): - """Inverse of `indexes_to_codes`.""" - indices = indices.unsqueeze(-1) - codes_non_centered = torch.remainder( - torch.floor_divide(indices, self.basis), 2 - ) - return codes_non_centered * 2 - 1 - - def group_indexes_to_codes(self, group_indices): - """Inverse of `group_indexes_to_codes`.""" - group_indices = group_indices.unsqueeze(-1) - codes_non_centered = torch.remainder( - torch.floor_divide(group_indices, self.group_basis), 2 - ) - codes_non_centered = rearrange(codes_non_centered, 'b ... g c -> b ... (g c)') - return codes_non_centered * 2 - 1 - - def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True): - if normalize: - probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True) - else: - probs = count - H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim) - return H - - def get_group_codebook_entry(self, group_indices): - z_q = self.group_indexes_to_codes(group_indices) - q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. - z_q = z_q * q_scale - if self.input_format == 'bchw': - h, w = int(z_q.shape[1] ** 0.5) - assert h * w == z_q.shape[1], 'Invalid sequence length' - z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h) - return z_q - - def get_codebook_entry(self, indices): - z_q = self.indexes_to_codes(indices) - q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. - z_q = z_q * q_scale - if self.input_format == 'bchw': - h, w = int(z_q.shape[1] ** 0.5) - assert h * w == z_q.shape[1], 'Invalid sequence length' - z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h) - return z_q - - -class BSQuantizer(nn.Module): - - def __init__(self, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size): - super().__init__() - self.codebook_dim = s1_bits + s2_bits - self.s1_bits = s1_bits - self.s2_bits = s2_bits - self.bsq = BinarySphericalQuantizer(self.codebook_dim, beta, gamma0, gamma, zeta, group_size=group_size) - - def bits_to_indices(self, bits): - bits = (bits >= 0).to(torch.long) - indices = 2 ** torch.arange( - 0, - bits.shape[-1], - 1, - dtype=torch.long, - device=bits.device, - ) - return (bits * indices).sum(-1) - - def forward(self, z, half=False, collect_metrics=True): - z = F.normalize(z, dim=-1) - quantized, bsq_loss, metrics = self.bsq(z, collect_metrics=collect_metrics) - if half: - q_pre = quantized[:, :, :self.s1_bits] - q_post = quantized[:, :, self.s1_bits:] - z_indices = [self.bits_to_indices(q_pre), self.bits_to_indices(q_post)] - else: - z_indices = self.bits_to_indices(quantized) - return bsq_loss, quantized, z_indices - - -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-5): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()).type_as(x) - return output * self.weight - - -class FeedForward(nn.Module): - def __init__(self, d_model, ff_dim, ffn_dropout_p=0.0): - super().__init__() - - self.w1 = nn.Linear(d_model, ff_dim, bias=False) - self.w3 = nn.Linear(d_model, ff_dim, bias=False) - self.w2 = nn.Linear(ff_dim, d_model, bias=False) - self.ffn_dropout = nn.Dropout(ffn_dropout_p) - - def forward(self, x): - return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) - - -class 
RotaryPositionalEmbedding(nn.Module): - def __init__(self, dim): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer("inv_freq", inv_freq) - self.seq_len_cached = None - self.cos_cached = None - self.sin_cached = None - - def _update_cos_sin_cache(self, x, seq_len): - if seq_len != self.seq_len_cached: - self.seq_len_cached = seq_len - t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) - freqs = torch.einsum('i,j->ij', t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - self.cos_cached = emb.cos()[None, None, :, :] - self.sin_cached = emb.sin()[None, None, :, :] - return self.cos_cached, self.sin_cached - - def forward(self, q, k): - cos, sin = self._update_cos_sin_cache(q, q.shape[-2]) - return ( - (q * cos) + (self._rotate_half(q) * sin), - (k * cos) + (self._rotate_half(k) * sin), - ) - - def _rotate_half(self, x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - -class MultiHeadAttentionWithRoPE(nn.Module): - def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout_p=0.0): - super().__init__() - self.d_model = d_model - self.n_heads = n_heads - self.head_dim = d_model // n_heads - - self.q_proj = nn.Linear(d_model, d_model) - self.k_proj = nn.Linear(d_model, d_model) - self.v_proj = nn.Linear(d_model, d_model) - self.out_proj = nn.Linear(d_model, d_model) - self.rotary = RotaryPositionalEmbedding(self.head_dim) - self.attn_dropout_p = attn_dropout_p - self.resid_dropout = nn.Dropout(resid_dropout_p) - - def forward(self, x, key_padding_mask=None): - batch_size, seq_len, _ = x.shape - - q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - - q, k = self.rotary(q, k) - - if key_padding_mask is not None: - attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) # [batch, 1, 1, seq_len] - attn_mask = attn_mask.expand(-1, self.n_heads, seq_len, -1) # [batch, n_heads, q_len, k_len] - else: - attn_mask = None - - attn_output = F.scaled_dot_product_attention( - q, k, v, - attn_mask=attn_mask, - dropout_p=self.attn_dropout_p if self.training else 0.0, - is_causal=True - ) - - attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model) - return self.resid_dropout(self.out_proj(attn_output)) - - -class MultiHeadCrossAttentionWithRoPE(nn.Module): - def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout=0.0): - super().__init__() - self.d_model = d_model - self.n_heads = n_heads - self.head_dim = d_model // n_heads - - self.q_proj = nn.Linear(d_model, d_model) - self.k_proj = nn.Linear(d_model, d_model) - self.v_proj = nn.Linear(d_model, d_model) - self.out_proj = nn.Linear(d_model, d_model) - self.rotary = RotaryPositionalEmbedding(self.head_dim) - self.attn_dropout_p = attn_dropout_p - self.resid_dropout = nn.Dropout(resid_dropout) - - def forward(self, query, key, value, key_padding_mask=None): - batch_size, q_len, _ = query.shape - _, seq_len, _ = key.shape - - q = self.q_proj(query).view(batch_size, q_len, self.n_heads, self.head_dim).transpose(1, 2) - k = self.k_proj(key).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - v = self.v_proj(value).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - - q, k = self.rotary(q, k) - - if key_padding_mask is 
not None: - attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) - attn_mask = attn_mask.expand(-1, self.n_heads, q_len, -1) - else: - attn_mask = None - - is_causal_flag = self.training - - attn_output = F.scaled_dot_product_attention( - q, k, v, - attn_mask=attn_mask, - dropout_p=self.attn_dropout_p if self.training else 0.0, - is_causal=is_causal_flag - ) - - attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model) - return self.resid_dropout(self.out_proj(attn_output)) - - -class HierarchicalEmbedding(nn.Module): - def __init__(self, s1_bits, s2_bits, d_model=256): - super().__init__() - self.s1_bits = s1_bits - self.s2_bits = s2_bits - - vocab_s1 = 2 ** s1_bits - vocab_s2 = 2 ** s2_bits - - self.emb_s1 = nn.Embedding(vocab_s1, d_model) - self.emb_s2 = nn.Embedding(vocab_s2, d_model) - self.d_model = d_model - self.fusion_proj = nn.Linear(d_model * 2, d_model) - - nn.init.normal_(self.emb_s1.weight, mean=0, std=d_model ** -0.5) - nn.init.normal_(self.emb_s2.weight, mean=0, std=d_model ** -0.5) - - def split_token(self, token_ids: torch.Tensor, s2_bits: int): - """Inputs: - token_ids (torch.Tensor): Composite token IDs of shape [batch_size, seq_len] or [N], each in range [0, 2^(s1_bits + s2_bits) - 1]. - s2_bits (int): Number of low bits used for the fine token (s2). - """ - assert isinstance(s2_bits, int) and s2_bits > 0, "s2_bits must be a positive integer" - - t = token_ids.long() - mask = (1 << s2_bits) - 1 - s2_ids = t & mask # extract low bits - s1_ids = t >> s2_bits # extract high bits - return s1_ids, s2_ids - - def forward(self, token_ids): - """Inputs: - token_ids: - - tuple or list: (s1_ids, s2_ids), each of shape [batch_size, seq_len], or - - torch.Tensor: composite token IDs of shape [batch_size, seq_len], which will be split into (s1_ids, s2_ids) internally. 
- Output: [batch_size, seq_len, d_model] - """ - if isinstance(token_ids, tuple) or isinstance(token_ids, list): - s1_ids, s2_ids = token_ids - else: - s1_ids, s2_ids = self.split_token(token_ids, self.s2_bits) - s1_emb = self.emb_s1(s1_ids) * math.sqrt(self.d_model) - s2_emb = self.emb_s2(s2_ids) * math.sqrt(self.d_model) - return self.fusion_proj(torch.cat([s1_emb, s2_emb], dim=-1)) - - -class DependencyAwareLayer(nn.Module): - def __init__(self, d_model, n_heads=4, attn_dropout_p=0.0, resid_dropout=0.0): - super().__init__() - self.cross_attn = MultiHeadCrossAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout) - self.norm = RMSNorm(d_model) - - def forward(self, hidden_states, sibling_embed, key_padding_mask=None): - """hidden_states: [batch, seq_len, d_model] - sibling_embed: Embedding from another subtoken - """ - attn_out = self.cross_attn( - query=sibling_embed, - key=hidden_states, - value=hidden_states, - key_padding_mask=key_padding_mask - ) - return self.norm(hidden_states + attn_out) - - -class TransformerBlock(nn.Module): - def __init__(self, d_model, n_heads, ff_dim=1024, ffn_dropout_p=0.0, attn_dropout_p=0.0, resid_dropout_p=0.0): - super().__init__() - self.norm1 = RMSNorm(d_model) - self.self_attn = MultiHeadAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout_p) - self.norm2 = RMSNorm(d_model) - self.ffn = FeedForward(d_model, ff_dim, ffn_dropout_p) - - def forward(self, x, key_padding_mask=None): - residual = x - x = self.norm1(x) - attn_out = self.self_attn(x, key_padding_mask=key_padding_mask) - x = residual + attn_out - - residual = x - x = self.norm2(x) - ffn_out = self.ffn(x) - x = residual + ffn_out - return x - - -class DualHead(nn.Module): - def __init__(self, s1_bits, s2_bits, d_model): - super().__init__() - self.vocab_s1 = 2 ** s1_bits - self.vocab_s2 = 2 ** s2_bits - self.proj_s1 = nn.Linear(d_model, self.vocab_s1) - self.proj_s2 = nn.Linear(d_model, self.vocab_s2) - - def compute_loss(self, s1_logits, s2_logits, s1_targets, s2_targets, padding_mask=None): - if padding_mask is not None: - valid_mask = (padding_mask == 0) - s1_logits = s1_logits[valid_mask] - s2_logits = s2_logits[valid_mask] - s1_targets = s1_targets[valid_mask] - s2_targets = s2_targets[valid_mask] - ce_s1 = F.cross_entropy(s1_logits, s1_targets) - ce_s2 = F.cross_entropy(s2_logits, s2_targets) - else: - ce_s1 = F.cross_entropy(s1_logits.reshape(-1, self.vocab_s1), s1_targets.reshape(-1)) - ce_s2 = F.cross_entropy(s2_logits.reshape(-1, self.vocab_s2), s2_targets.reshape(-1)) - ce_loss = (ce_s1 + ce_s2) / 2 - return ce_loss, ce_s1, ce_s2 - - def forward(self, x): - return self.proj_s1(x) - - def cond_forward(self, x2): - return self.proj_s2(x2) - - -class FixedEmbedding(nn.Module): - def __init__(self, c_in, d_model): - super(FixedEmbedding, self).__init__() - - w = torch.zeros(c_in, d_model).float() - w.require_grad = False - - position = torch.arange(0, c_in).float().unsqueeze(1) - div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() - - w[:, 0::2] = torch.sin(position * div_term) - w[:, 1::2] = torch.cos(position * div_term) - - self.emb = nn.Embedding(c_in, d_model) - self.emb.weight = nn.Parameter(w, requires_grad=False) - - def forward(self, x): - return self.emb(x).detach() - - -class TemporalEmbedding(nn.Module): - def __init__(self, d_model, learn_pe): - super(TemporalEmbedding, self).__init__() - - minute_size = 60 - hour_size = 24 - weekday_size = 7 - day_size = 32 - month_size = 13 - - Embed = FixedEmbedding if not 
learn_pe else nn.Embedding - self.minute_embed = Embed(minute_size, d_model) - self.hour_embed = Embed(hour_size, d_model) - self.weekday_embed = Embed(weekday_size, d_model) - self.day_embed = Embed(day_size, d_model) - self.month_embed = Embed(month_size, d_model) - - def forward(self, x): - x = x.long() - - minute_x = self.minute_embed(x[:, :, 0]) - hour_x = self.hour_embed(x[:, :, 1]) - weekday_x = self.weekday_embed(x[:, :, 2]) - day_x = self.day_embed(x[:, :, 3]) - month_x = self.month_embed(x[:, :, 4]) - - return hour_x + weekday_x + day_x + month_x + minute_x \ No newline at end of file diff --git a/skills/alphaear-reporter/scripts/utils/predictor/training.py b/skills/alphaear-reporter/scripts/utils/predictor/training.py deleted file mode 100644 index 3b41724..0000000 --- a/skills/alphaear-reporter/scripts/utils/predictor/training.py +++ /dev/null @@ -1,539 +0,0 @@ -import os -import sys -import time -import torch -import torch.nn as nn -import pandas as pd -import numpy as np -import json -import random -from loguru import logger -from datetime import datetime, timedelta -from sentence_transformers import SentenceTransformer -from skills._env_loader import load_unified_env - -load_unified_env() - -# Setup paths -KRONOS_DIR = os.path.dirname(os.path.abspath(__file__)) -SRC_DIR = os.path.dirname(os.path.dirname(KRONOS_DIR)) -if SRC_DIR not in sys.path: - sys.path.insert(0, SRC_DIR) - -from ..kronos.model import Kronos, KronosTokenizer, KronosPredictor -from ..database_manager import DatabaseManager -from ..stock_tools import StockTools -from ..search_tools import SearchTools -from ..llm.factory import get_model -from ..visualizer import VisualizerTools -from ..schema.models import ForecastResult, KLinePoint -from agno.agent import Agent - - -class AutoSynthesisTrainer: - def __init__(self, news_dim=384): - self.device = ( - "cuda" - if torch.cuda.is_available() - else "mps" - if torch.backends.mps.is_available() - else "cpu" - ) - self.db = DatabaseManager() - self.tools = StockTools(self.db) - self.searcher = SearchTools(self.db) - # Try loading from local cache first to avoid network timeouts - model_name = os.getenv( - "EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2" - ) - try: - logger.info(f"🔄 Attempting to load {model_name} from local cache...") - self.embedder = SentenceTransformer( - model_name, device=self.device, local_files_only=True - ) - logger.success("✅ Model loaded from local cache.") - except Exception: - logger.warning( - "⚠️ Local cache not found or incomplete. Attempting to download..." - ) - self.embedder = SentenceTransformer(model_name, device=self.device) - self.news_dim = news_dim - - # Try loading from local cache first to avoid network timeouts - try: - logger.info( - "🔄 Attempting to load Kronos and Tokenizer from local cache..." - ) - self.tokenizer = KronosTokenizer.from_pretrained( - "NeoQuasar/Kronos-Tokenizer-base", local_files_only=True - ).to(self.device) - base_model = Kronos.from_pretrained( - "NeoQuasar/Kronos-base", local_files_only=True - ) - logger.success("✅ Kronos and Tokenizer loaded from local cache.") - except Exception: - logger.warning( - "⚠️ Local Kronos/Tokenizer not found or incomplete. Attempting to download..." 
- ) - self.tokenizer = KronosTokenizer.from_pretrained( - "NeoQuasar/Kronos-Tokenizer-base" - ).to(self.device) - base_model = Kronos.from_pretrained("NeoQuasar/Kronos-base") - - self.model = Kronos( - base_model.s1_bits, - base_model.s2_bits, - base_model.n_layers, - base_model.d_model, - base_model.n_heads, - base_model.ff_dim, - base_model.ffn_dropout_p, - base_model.attn_dropout_p, - base_model.resid_dropout_p, - base_model.token_dropout_p, - base_model.learn_te, - news_dim=self.news_dim, - ).to(self.device) - self.model.load_state_dict(base_model.state_dict(), strict=False) - - # LLM for causality verification - provider = os.getenv("LLM_PROVIDER", "minimax") - model_id = os.getenv("LLM_MODEL", "Qwen") - self.llm_agent = Agent(model=get_model(provider, model_id)) - - def discover_shocks( - self, ticker_list, threshold=2.0, limit_per_stock=5, days=365, pred_len=5 - ): - """1. Find days with significant price movements (Look back 1 year)""" - shocks = [] - end_date = datetime.now().strftime("%Y-%m-%d") - start_date = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d") - - for ticker in ticker_list: - df = self.tools.get_stock_price( - ticker, start_date=start_date, end_date=end_date - ) - if df.empty or len(df) < 60: - continue - - # Look for big moves - moves = df[df["change_pct"].abs() > threshold].copy() - if moves.empty: - continue - - count = 0 - for idx, row in moves.iterrows(): - # Ensure we have history before this day AND enough future days for eval - date_idx = df.index.get_loc(idx) - if date_idx < 50 or date_idx + pred_len > len(df): - continue - - shocks.append( - { - "ticker": ticker, - "date": row["date"], - "change": row["change_pct"], - "history": df.iloc[date_idx - 50 : date_idx], - "target": df.iloc[ - date_idx : date_idx + pred_len - ], # Now capturing pred_len days - } - ) - count += 1 - if count >= limit_per_stock: - break - - logger.info( - f"✨ Discovered {len(shocks)} potential price shocks over the last {days} days." - ) - return shocks - - def find_reason_and_verify(self, shock): - """2. Search for reasons and verify causality using LLM""" - ticker_info = self.db.get_stock_by_code(shock["ticker"]) - name = ticker_info["name"] if ticker_info else shock["ticker"] - date_str = shock["date"] - - # Try multiple query variations and engines - queries = [ - f"{name} ({shock['ticker']}) {date_str} 为什么涨跌 原因", - f"{name} {date_str} 异动 原因", - f"{shock['ticker']} {date_str} 新闻", - ] - - search_results = [] - for query in queries: - logger.info(f"🔍 Searching for reason: {query}") - # Try alternate engines - for engine in ["baidu"]: - try: - results = self.searcher.search_list( - query, engine=engine, max_results=3, enrich=False - ) - if results: - search_results = results - break - except Exception as e: - logger.warning(f"Search failed for {query} on {engine}: {e}") - - if search_results: - break - time.sleep(random.uniform(1.0, 2.0)) - - if not search_results: - logger.warning( - f"⚠️ No search results found for {name} on {date_str} after multiple attempts." - ) - return None - - context = "\n".join( - [f"- {r['title']}: {r.get('content', '')[:300]}" for r in search_results] - ) - - prompt = f""" - 任务:判断以下新闻是否解释了该股票在 {date_str} 的 {shock["change"]:.2f}% 价格变动。 - - 股票:{name} - 日期:{date_str} - 变动:{shock["change"]:.2f}% - - 搜索结果: - {context} - - 要求: - 1. 该新闻是否在该日期左右发生? - 2. 该新闻是否能逻辑上解释这种大幅波动(如财报、利好政策、重组、大环境暴跌等)? - 3. 如果是,请总结一段 100 字以内的“核心推动原因”。 - 4. 
返回 JSON: {{"is_causal": true/false, "summary": "原因摘要"}} - """ - - try: - res = self.llm_agent.run(prompt) - data = json.loads( - res.content.replace("```json", "").replace("```", "").strip() - ) - if data.get("is_causal"): - logger.success( - f"✅ Verified cause for {name} on {date_str}: {data['summary']}" - ) - return data["summary"] - else: - logger.warning( - f"❌ Verified cause for {name} on {date_str}: {data['summary']}" - ) - return None - except Exception as e: - logger.warning(f"Verification failed: {e}") - return None - - def save_model(self, path=None): - """Save the news_proj weights""" - if path is None: - save_dir = os.path.join(SRC_DIR, "exports/models") - os.makedirs(save_dir, exist_ok=True) - path = os.path.join( - save_dir, f"kronos_news_v1_{datetime.now().strftime('%Y%m%d_%H%M')}.pt" - ) - - # We only really need to save the news_proj part as it's the only one we train - torch.save( - { - "news_proj_state_dict": self.model.news_proj.state_dict(), - "news_dim": self.news_dim, - "d_model": self.model.d_model, - }, - path, - ) - logger.success(f"💾 Model weights saved to {path}") - return path - - def run_synthesis_and_train(self, tickers, pred_len=5): - # 1. Discovery - shocks = self.discover_shocks(tickers, pred_len=pred_len) - print(f"find {len(shocks)} shocks") - - # 2. News Association & Verification - dataset = [] - max_news_items = 200 # Limit to 200 news items per session to avoid search bans - - logger.info( - f"🧬 Starting News Association for {len(shocks)} shocks (Max limit: {max_news_items})" - ) - - for i, shock in enumerate(shocks): - if len(dataset) >= max_news_items: - logger.info("Reached maximum news items limit for this session.") - break - - summary = self.find_reason_and_verify(shock) - if summary: - # 3. Embedding news - emb = self.embedder.encode(summary) - dataset.append( - { - "history": shock["history"], - "target": shock["target"], - "news_emb": emb, - "summary": summary, - } - ) - - # Add delay after search with randomness to avoid being blocked - if i < len(shocks) - 1: - delay = random.uniform(2.0, 4.0) - time.sleep(delay) - - if not dataset: - logger.error( - "❌ No verified news-price pairs found. Adjust threshold or check if news is available in that period." - ) - return - - # 4. Train/Val Split - random.seed(42) - random.shuffle(dataset) - - if len(dataset) < 2: - train_set = dataset - val_set = [] - logger.warning( - f"⚠️ Only {len(dataset)} sample(s) found. Training on all, skipping validation." - ) - else: - split_idx = max(1, int(len(dataset) * 0.8)) - if split_idx >= len(dataset): - split_idx = len(dataset) - 1 - - train_set = dataset[:split_idx] - val_set = dataset[split_idx:] - logger.info( - f"🏗️ Dataset Split: {len(train_set)} samples for training, {len(val_set)} for validation." - ) - - if not train_set: - logger.error("❌ No samples for training.") - return - - # 5. 
Training (Few-shot) - optimizer = torch.optim.Adam(self.model.news_proj.parameters(), lr=1e-3) - criterion = nn.CrossEntropyLoss() - self.model.train() - - loss_history = [] - logger.info(f"🚀 Training for 30 epochs...") - for epoch in range(30): - total_loss = 0 - for item in train_set: - optimizer.zero_grad() - - # Prep Data - hist_df = item["history"] - # For training, we still focus on the immediate next point (teacher forcing) - target_df = item["target"].iloc[:1] - - hist_raw = hist_df[ - ["open", "high", "low", "close", "volume"] - ].values.astype(np.float32) - hist_raw = np.column_stack([hist_raw, hist_raw[:, 3] * hist_raw[:, 4]]) - - mean, std = hist_raw.mean(axis=0), hist_raw.std(axis=0) + 1e-5 - hist_norm = ( - torch.from_numpy((hist_raw - mean) / std) - .unsqueeze(0) - .to(self.device) - ) - - target_raw = target_df[ - ["open", "high", "low", "close", "volume"] - ].values.astype(np.float32) - target_raw = np.column_stack( - [target_raw, target_raw[:, 3] * target_raw[:, 4]] - ) - target_norm = ( - torch.from_numpy((target_raw - mean) / std) - .unsqueeze(0) - .to(self.device) - ) - - with torch.no_grad(): - z_indices = self.tokenizer.encode(hist_norm, half=True) - t_indices = self.tokenizer.encode(target_norm, half=True) - s1_ids, s2_ids = z_indices[0], z_indices[1] - t_s1, t_s2 = t_indices[0], t_indices[1] - - news_t = torch.from_numpy(item["news_emb"]).unsqueeze(0).to(self.device) - s1_logits, s2_logits = self.model( - s1_ids, - s2_ids, - news_emb=news_t, - use_teacher_forcing=True, - s1_targets=t_s1, - ) - - loss = ( - criterion(s1_logits[:, -1, :], t_s1[:, 0]) - + criterion(s2_logits[:, -1, :], t_s2[:, 0]) - ) / 2 - loss.backward() - optimizer.step() - total_loss += loss.item() - - avg_epoch_loss = total_loss / max(1, len(train_set)) - loss_history.append(avg_epoch_loss) - - if (epoch + 1) % 10 == 0: - logger.info(f"Epoch {epoch + 1} Loss: {avg_epoch_loss:.4f}") - - # 5.1 Visualize Loss Curve - loss_chart = VisualizerTools.generate_loss_chart(loss_history) - VisualizerTools.render_chart_to_file( - loss_chart, - os.path.join(SRC_DIR, "exports/training_results/loss_curve.html"), - ) - - # 5.2 Save final model - self.save_model() - - # 6. Final Evaluation on Validation Set - if not val_set: - logger.warning("⚠️ Validation set is empty. Skipping statistical analysis.") - return - - logger.info( - f"🧪 Final Evaluation: Base vs News-Integrated ({pred_len}-day Window)" - ) - self.model.eval() - predictor = KronosPredictor(self.model, self.tokenizer, device=self.device) - - base_maes = [] - news_maes = [] - - print("\n" + "=" * 90) - print( - f"{'Date':<12} | {'Ticker':<8} | {'Base MAE':<15} | {'News MAE':<15} | {'Improvement'}" - ) - print("-" * 90) - - for item in val_set: - h = item["history"] - t = item["target"] - actuals = t["close"].values[:pred_len] - - x_ts = pd.to_datetime(h["date"]) - # Future timestamps: handle business days if possible, or just simple offset - future_dates = pd.date_range( - start=x_ts.iloc[-1] + timedelta(days=1), periods=pred_len, freq="B" - ) - y_ts = pd.Series(future_dates) - - # A. Base Prediction - p_base = predictor.predict( - h, x_ts, y_ts, pred_len=pred_len, news_emb=None, verbose=False - ) - b_preds = p_base["close"].values[: len(actuals)] - - # B. 
News-Aware Prediction - p_news = predictor.predict( - h, - x_ts, - y_ts, - pred_len=pred_len, - news_emb=item["news_emb"], - verbose=False, - ) - n_preds = p_news["close"].values[: len(actuals)] - - # Calculate MAE over the window - b_mae = np.mean(np.abs(b_preds - actuals)) - n_mae = np.mean(np.abs(n_preds - actuals)) - - base_maes.append(b_mae) - news_maes.append(n_mae) - - improvement = (b_mae - n_mae) / (b_mae + 1e-6) * 100 - - date_str = str(t["date"].values[0])[:10] - ticker = h.iloc[-1]["ticker"] if "ticker" in h.columns else "Stock" - print( - f"{date_str:<12} | {ticker:<8} | {b_mae:<15.4f} | {n_mae:<15.4f} | {improvement:>+7.1f}%" - ) - - # C. Generate Visualization for this case - try: - # Helper to convert DF to KLinePoints - def to_kp_list(preds_df): - points = [] - for idx, row in preds_df.iterrows(): - points.append( - KLinePoint( - date=str(idx)[:10], - open=row["open"], - high=row["high"], - low=row["low"], - close=row["close"], - volume=row["volume"] if "volume" in row else 0, - ) - ) - return points - - forecast_obj = ForecastResult( - ticker=ticker, - base_forecast=to_kp_list(p_base), - adjusted_forecast=to_kp_list(p_news), - rationale=item["summary"], - ) - - # Ground truth for visualizer expects a DataFrame with 'date' and 'close' - gt_df = t[["date", "open", "high", "low", "close", "volume"]] - - chart = VisualizerTools.generate_stock_chart( - df=h, - ticker=ticker, - title=f"Training Eval: {ticker} ({date_str}) Improvement: {improvement:.1f}%", - forecast=forecast_obj, - ground_truth=gt_df, - ) - - safe_date = date_str.replace("-", "") - filename = f"eval_{ticker}_{safe_date}.html" - VisualizerTools.render_chart_to_file( - chart, os.path.join(SRC_DIR, f"exports/training_results/{filename}") - ) - except Exception as e: - logger.error(f"Failed to generate eval chart for {ticker}: {e}") - - # Summary Statistics - avg_base_err = sum(base_maes) / max(1, len(base_maes)) - avg_news_err = sum(news_maes) / max(1, len(news_maes)) - overall_imp = (avg_base_err - avg_news_err) / (avg_base_err + 1e-6) * 100 - - print("-" * 90) - print( - f"{'AVERAGE':<12} | {'-':<8} | {avg_base_err:<15.4f} | {avg_news_err:<15.4f} | {overall_imp:>+7.1f}%" - ) - print("=" * 90 + "\n") - - logger.success( - f"🏁 Statistical Analysis Complete. Avg Error Reduction ({pred_len}-day): {overall_imp:.2f}%" - ) - logger.info( - f"📊 Visualization results saved to: {os.path.join(SRC_DIR, 'exports/training_results/')}" - ) - - -if __name__ == "__main__": - trainer = AutoSynthesisTrainer() - - logger.info("📂 Fetching all stock codes from database...") - res = trainer.db.execute_query("SELECT code FROM stock_list") - all_tickers = [row["code"] for row in res] - - if not all_tickers: - logger.warning("⚠️ No tickers found in stock_list table. 
Trying to sync...") - trainer.tools._check_and_update_stock_list(force=True) - res = trainer.db.execute_query("SELECT code FROM stock_list") - all_tickers = [row["code"] for row in res] - - logger.info(f"🚀 Starting training on potential stocks (1-year scan)...") - # 为了演示,我们扫描前 100 个股票,寻找最近一年的冲击点 - trainer.run_synthesis_and_train(all_tickers[:100], pred_len=1) diff --git a/skills/alphaear-reporter/scripts/utils/search_tools.py b/skills/alphaear-reporter/scripts/utils/search_tools.py deleted file mode 100644 index 50b08f3..0000000 --- a/skills/alphaear-reporter/scripts/utils/search_tools.py +++ /dev/null @@ -1,611 +0,0 @@ -import os -import hashlib -import json -import re -import requests -import time -import threading -from typing import List, Dict, Optional, Any -from agno.tools.duckduckgo import DuckDuckGoTools -from agno.tools.baidusearch import BaiduSearchTools -from agno.agent import Agent -from loguru import logger -from datetime import datetime -from .database_manager import DatabaseManager -from .content_extractor import ContentExtractor -from .llm.factory import get_model -from .hybrid_search import LocalNewsSearch - -# 默认搜索缓存 TTL(秒),可通过环境变量覆盖 -DEFAULT_SEARCH_TTL = int(os.getenv("SEARCH_CACHE_TTL", "3600")) # 默认 1 小时 - - -class JinaSearchEngine: - """Jina Search API 封装 - 使用 s.jina.ai 进行网络搜索""" - - JINA_SEARCH_URL = "https://s.jina.ai/" - - # 速率限制配置 - _rate_limit_no_key = 10 # 无 key 时每分钟最大请求数 - _rate_window = 60.0 - _min_interval = 2.0 - _request_times = [] - _last_request_time = 0.0 - _lock = threading.Lock() - - def __init__(self): - self.api_key = os.getenv("JINA_API_KEY", "").strip() - self.has_api_key = bool(self.api_key) - if self.has_api_key: - logger.info("✅ Jina Search API key configured") - - @classmethod - def _wait_for_rate_limit(cls, has_api_key: bool) -> None: - """等待以满足速率限制""" - if has_api_key: - time.sleep(0.3) - return - - with cls._lock: - current_time = time.time() - cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window] - - if len(cls._request_times) >= cls._rate_limit_no_key: - oldest = cls._request_times[0] - wait_time = cls._rate_window - (current_time - oldest) + 1.0 - if wait_time > 0: - logger.warning(f"⏳ Jina Search rate limit, waiting {wait_time:.1f}s...") - time.sleep(wait_time) - current_time = time.time() - cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window] - - time_since_last = current_time - cls._last_request_time - if time_since_last < cls._min_interval: - time.sleep(cls._min_interval - time_since_last) - - cls._request_times.append(time.time()) - cls._last_request_time = time.time() - - def search(self, query: str, max_results: int = 5) -> List[Dict]: - """ - 使用 Jina Search API 执行搜索 - - Args: - query: 搜索关键词 - max_results: 返回结果数量 - - Returns: - 搜索结果列表,每个结果包含 title, url, content - """ - if not query: - return [] - - logger.info(f"🔍 Jina Search: {query}") - - # 等待速率限制 - self._wait_for_rate_limit(self.has_api_key) - - headers = { - "Accept": "application/json", - "X-Retain-Images": "none", - } - - if self.has_api_key: - headers["Authorization"] = f"Bearer {self.api_key}" - - try: - # Jina Search API: https://s.jina.ai/{query} - import urllib.parse - encoded_query = urllib.parse.quote(query) - url = f"{self.JINA_SEARCH_URL}{encoded_query}" - - response = requests.get(url, headers=headers, timeout=30) - - if response.status_code == 429: - logger.warning("⚠️ Jina Search rate limited (429), waiting 30s...") - time.sleep(30) - return self.search(query, max_results) - - if 
response.status_code != 200: - logger.warning(f"Jina Search failed (Status {response.status_code})") - return [] - - # 解析响应 - try: - data = response.json() - except json.JSONDecodeError: - # 如果返回纯文本,尝试解析 - data = {"data": [{"title": "Search Result", "url": "", "content": response.text}]} - - results = [] - - # Jina 返回格式可能是 {"data": [...]} 或直接是列表 - items = data.get("data", []) if isinstance(data, dict) else data - if not isinstance(items, list): - items = [items] if items else [] - - for i, item in enumerate(items[:max_results]): - if isinstance(item, dict): - results.append({ - "title": item.get("title", f"Result {i+1}"), - "url": item.get("url", ""), - "href": item.get("url", ""), # 兼容性 - "content": item.get("content", item.get("description", "")), - "body": item.get("content", item.get("description", "")), # 兼容性 - }) - elif isinstance(item, str): - results.append({ - "title": f"Result {i+1}", - "url": "", - "content": item - }) - - logger.info(f"✅ Jina Search returned {len(results)} results") - return results - - except requests.exceptions.Timeout: - logger.error("Jina Search timeout") - return [] - except requests.exceptions.RequestException as e: - logger.error(f"Jina Search request error: {e}") - return [] - except Exception as e: - logger.error(f"Jina Search unexpected error: {e}") - return [] - -class SearchTools: - """扩展性搜索工具库 - 支持多引擎聚合与内容缓存""" - - def __init__(self, db: DatabaseManager): - self.db = db - - # 检查 Jina API Key 是否配置 - jina_api_key = os.getenv("JINA_API_KEY", "").strip() - self._jina_enabled = bool(jina_api_key) - - self._engines = { - "ddg": DuckDuckGoTools(), - "baidu": BaiduSearchTools(), - "local": LocalNewsSearch(db) - } - - # 如果配置了 Jina API Key,添加 Jina 引擎 - if self._jina_enabled: - self._engines["jina"] = JinaSearchEngine() - logger.info("🚀 Jina Search engine enabled (JINA_API_KEY configured)") - - # 确定默认搜索引擎 - self._default_engine = "jina" if self._jina_enabled else "ddg" - - def _generate_hash(self, query: str, engine: str, max_results: int) -> str: - return hashlib.md5(f"{engine}:{query}:{max_results}".encode()).hexdigest() - - def search(self, query: str, engine: str = None, max_results: int = 5, ttl: Optional[int] = None) -> str: - """ - 使用指定搜索引擎执行网络搜索,结果会被缓存以提高效率。 - - Args: - query: 搜索关键词,如 "英伟达财报" 或 "光伏行业政策"。 - engine: 搜索引擎选择。可选值: - "jina" (Jina Search,需配置 JINA_API_KEY,LLM友好输出), - "ddg" (DuckDuckGo,推荐英文/国际搜索), - "baidu" (百度,推荐中文/国内搜索), - "local" (本地历史新闻搜索,基于向量+BM25)。 - 默认: 若配置了 JINA_API_KEY 则使用 "jina",否则 "ddg"。 - max_results: 期望返回的结果数量,默认 5 条。 - ttl: 缓存有效期(秒)。如果缓存超过此时间会重新搜索。 - 默认使用环境变量 SEARCH_CACHE_TTL 或 3600 秒。 - 设为 0 可强制刷新。 - - Returns: - 搜索结果的文本描述,包含标题、摘要和链接。 - """ - # 使用默认引擎(如果配置了 Jina 则优先使用 Jina) - if engine is None: - engine = self._default_engine - - if engine not in self._engines: - return f"Error: Unsupported engine '{engine}'. Available: {list(self._engines.keys())}" - - query_hash = self._generate_hash(query, engine, max_results) - effective_ttl = ttl if ttl is not None else DEFAULT_SEARCH_TTL - - # 1. 尝试从缓存读取 (local 引擎不缓存,因为它本身就是查库) - if engine != "local": - cache = self.db.get_search_cache(query_hash, ttl_seconds=effective_ttl if effective_ttl > 0 else None) - if cache and effective_ttl != 0: - logger.info(f"ℹ️ Found search results in cache for: {query} ({engine})") - return cache['results'] - - # 2. 
执行真实搜索 - logger.info(f"📡 Searching {engine} for: {query}") - try: - tool = self._engines[engine] - if engine == "jina": - # Jina Search 返回 List[Dict] - jina_results = tool.search(query, max_results=max_results) - results = [] - for r in jina_results: - results.append({ - "title": r.get("title", ""), - "href": r.get("url", ""), - "body": r.get("content", "") - }) - elif engine == "ddg": - results = tool.duckduckgo_search(query, max_results=max_results) - elif engine == "baidu": - results = tool.baidu_search(query, max_results=max_results) - elif engine == "local": - # LocalNewsSearch 返回的是 List[Dict] - local_results = tool.search(query, top_n=max_results) - results = [] - for r in local_results: - results.append({ - "title": r.get("title"), - "href": r.get("url", "local"), - "body": r.get("content", "") - }) - else: - results = "Search not implemented for this engine." - - results_str = str(results) - if engine != "local": - self.db.save_search_cache(query_hash, query, engine, results_str) - return results_str - - except Exception as e: - # 搜索失败时的降级策略 - if engine == "jina": - logger.warning(f"⚠️ Jina search failed, falling back to ddg: {query} ({e})") - try: - return self.search(query, engine="ddg", max_results=max_results, ttl=ttl) - except Exception as e2: - logger.error(f"❌ DDG fallback also failed for {query}: {e2}") - elif engine == "ddg": - logger.warning(f"⚠️ DDG search failed, falling back to baidu: {query} ({e})") - try: - return self.search(query, engine="baidu", max_results=max_results, ttl=ttl) - except Exception as e2: - logger.error(f"❌ Baidu fallback also failed for {query}: {e2}") - - logger.error(f"❌ Search failed for {query}: {e}") - return f"Error occurred during search: {str(e)}" - - def search_list(self, query: str, engine: str = None, max_results: int = 5, ttl: Optional[int] = None, enrich: bool = True) -> List[Dict]: - """ - 执行搜索并返回结构化列表 (List[Dict])。 - Dict 包含: title, href (or url), body (or snippet) - - Args: - engine: 搜索引擎,默认使用配置的默认引擎(Jina 优先) - enrich: 是否抓取正文内容 (默认 True) - """ - # 使用默认引擎 - if engine is None: - engine = self._default_engine - - if engine not in self._engines: - logger.error(f"Unsupported engine {engine}") - return [] - - # 不同的 hash 以区分是否 enrichment - enrich_suffix = ":enriched" if enrich else "" - query_hash = self._generate_hash(query, engine + enrich_suffix, max_results) - effective_ttl = ttl if ttl is not None else DEFAULT_SEARCH_TTL - - # 1. 尝试从缓存读取 - cache = self.db.get_search_cache(query_hash, ttl_seconds=effective_ttl if effective_ttl > 0 else None) - if cache and effective_ttl != 0: - try: - cached_data = json.loads(cache['results']) - if isinstance(cached_data, list): - logger.info(f"ℹ️ Found structured search cache for: {query}") - return cached_data - except: - pass - - # 1.5 Smart Cache (Fuzzy + LLM) - if effective_ttl != 0: - try: - # 1. Similar cached queries - similar_queries = self.db.find_similar_queries(query, limit=3) - # Filter by TTL - valid_candidates = [] - for q in similar_queries: - if q['query'] == query: continue - q_time = datetime.fromisoformat(q['timestamp']) - if effective_ttl and (datetime.now() - q_time).total_seconds() > effective_ttl: - continue - q['type'] = 'cached_search' - valid_candidates.append(q) - - # 2. Relevant local news (as search results) - local_news = self.db.search_local_news(query, limit=3) - if local_news: - # Group local news as a single "candidate" source? Or individual? - # Better to treat "Local News Database" as one candidate source that contains X items. 
- # Or just add them to candidates list? - # Let's package strictly relevant news as a "local_news_bundle" - valid_candidates.append({ - 'type': 'local_news', - 'query': 'Local Database News', - 'items': local_news, - 'timestamp': datetime.now().isoformat() - }) - - if valid_candidates: - logger.info(f"🤔 Found {len(valid_candidates)} smart cache candidates (Queries/News). Asking LLM...") - evaluation = self._evaluate_cache_relevance(query, valid_candidates) - - if evaluation and evaluation.get('reuse', False): - idx = evaluation.get('index', -1) - if 0 <= idx < len(valid_candidates): - chosen = valid_candidates[idx] - logger.info(f"🤖 LLM suggested reusing: '{chosen.get('query')}' ({chosen['type']})") - - if chosen['type'] == 'cached_search': - # Load the chosen cache - cache = self.db.get_search_cache(chosen['query_hash']) - if cache: - try: - cached_data = json.loads(cache['results']) - if isinstance(cached_data, list): - return cached_data - except: - pass - elif chosen['type'] == 'local_news': - # Convert local news items to search result format - news_results = [] - for i, news in enumerate(chosen['items'], 1): - news_results.append({ - "id": news.get('id'), - "rank": i, - "title": news.get('title'), - "url": news.get('url'), - "content": news.get('content'), - "original_snippet": news.get('content')[:200] if news.get('content') else '', - "source": f"Local News ({news.get('source')})", - "publish_time": news.get('publish_time'), - "crawl_time": news.get('crawl_time'), - "sentiment_score": news.get('sentiment_score', 0), - "meta_data": {"origin": "local_db"} - }) - return news_results - - except Exception as e: - logger.warning(f"Smart cache check failed: {e}") - - # 2. 执行搜索 - logger.info(f"📡 Searching {engine} (structured) for: {query}") - try: - tool = self._engines[engine] - results = [] - if engine == "jina": - # Jina Search 直接返回结构化数据 - jina_results = tool.search(query, max_results=max_results) - for r in jina_results: - results.append({ - "title": r.get("title", ""), - "url": r.get("url", ""), - "href": r.get("url", ""), - "body": r.get("content", ""), - "content": r.get("content", ""), - "source": "Jina Search" - }) - elif engine == "ddg": - results = tool.duckduckgo_search(query, max_results=max_results) - elif engine == "baidu": - results = tool.baidu_search(query, max_results=max_results) - elif engine == "local": - # LocalNewsSearch 返回的是 List[Dict] - local_results = tool.search(query, top_n=max_results) - results = [] - for r in local_results: - results.append({ - "title": r.get("title"), - "url": r.get("url", "local"), - "body": r.get("content", "")[:500], - "source": f"Local ({r.get('source', 'db')})", - "publish_time": r.get("publish_time") - }) - - # 处理字符串类型的 JSON 返回 (Baidu 常返 JSON 字符串) - if isinstance(results, str) and engine not in ["local", "jina"]: - try: - results = json.loads(results) - except: - pass - - # 转为统一格式 - normalized_results = [] - if isinstance(results, list): - - for i, r in enumerate(results, 1): - title = r.get('title', '') - url = r.get('href') or r.get('url') or r.get('link', '') - content = r.get('body') or r.get('snippet') or r.get('abstract', '') - - if title and url: - normalized_results.append({ - "id": self._generate_hash(url + query, "search_item", i), - "rank": i, - "title": title, - "url": url, - "content": content, - "original_snippet": content, # 保留摘要 - "source": f"Search ({engine})", - "publish_time": datetime.now().isoformat(), # 暂用当前时间 - "crawl_time": datetime.now().isoformat(), - "meta_data": {"query": query, "engine": engine} - }) - - 
# Fallback if still string and failed to parse - elif isinstance(results, str) and results: - normalized_results.append({"title": query, "url": "", "content": results, "source": engine}) - - # 3. 抓取正文 & 计算情绪 (Enrichment) - # 注意:如果使用 Jina Search,内容已经是 LLM 友好格式,可选择跳过 enrichment - skip_content_enrichment = (engine == "jina") - - if enrich and normalized_results: - logger.info(f"🕸️ Enriching {len(normalized_results)} search results with Jina & Sentiment...") - extractor = ContentExtractor() - - # Lazy load sentiment tool - if not hasattr(self, 'sentiment_tool') or self.sentiment_tool is None: - from ..sentiment_tools import SentimentTools - self.sentiment_tool = SentimentTools(self.db) - - for item in normalized_results: - if item.get("url"): - try: - # 如果是 Jina Search,内容已经足够好,跳过额外抓取 - if skip_content_enrichment and item.get("content") and len(item.get("content", "")) > 100: - full_content = item["content"] - else: - # Use Jina Reader to get full content - full_content = extractor.extract_with_jina(item["url"], timeout=60) - - if full_content and len(full_content) > 100: - item["content"] = full_content - - # Calculate sentiment - # Use title + snippet of content for efficiency - text_to_analyze = f"{item['title']} {full_content[:500]}" - sent_result = self.sentiment_tool.analyze_sentiment(text_to_analyze) # Using self.sentiment_tool - score = sent_result.get('score', 0.0) - item["sentiment_score"] = float(score) - - logger.info(f" ✅ Enriched: {item['title'][:20]}... (Sentiment: {score:.2f})") - else: - # Fallback: Use snippet for sentiment - logger.info(f" ⚠️ Content short/failed for {item['url']}, using snippet for sentiment.") - text_to_analyze = f"{item['title']} {item['content']}" # content is snippet here - sent_result = self.sentiment_tool.analyze_sentiment(text_to_analyze) - score = sent_result.get('score', 0.0) - item["sentiment_score"] = float(score) - - except Exception as e: - # Fallback: Use snippet for sentiment on error - logger.warning(f"Failed to enrich {item['url']}: {e}. 
Using snippet.") - text_to_analyze = f"{item['title']} {item['content']}" - sent_result = self.sentiment_tool.analyze_sentiment(text_to_analyze) - score = sent_result.get('score', 0.0) - item["sentiment_score"] = float(score) - - # 缓存结果 list - if normalized_results: - # Pass list directly, DB manager will handle JSON dump for main cache and populate search_details - # Only cache if NOT from local news reuse (though this logic path is for fresh search) - self.db.save_search_cache(query_hash, query, engine, normalized_results) - - return normalized_results - - except Exception as e: - # 搜索失败时的降级策略 - if engine == "jina": - logger.warning(f"⚠️ Jina search_list failed, falling back to ddg: {query} ({e})") - try: - return self.search_list(query, engine="ddg", max_results=max_results, ttl=ttl, enrich=enrich) - except Exception as e2: - logger.error(f"❌ DDG fallback (search_list) also failed for {query}: {e2}") - elif engine == "ddg": - logger.warning(f"⚠️ DDG search_list failed, falling back to baidu: {query} ({e})") - try: - return self.search_list(query, engine="baidu", max_results=max_results, ttl=ttl, enrich=enrich) - except Exception as e2: - logger.error(f"❌ Baidu fallback (search_list) also failed for {query}: {e2}") - - logger.error(f"❌ Structured search failed for {query}: {e}") - return [] - - def _evaluate_cache_relevance(self, current_query: str, candidates: List[Dict]) -> Dict: - """ - 使用 LLM 评估缓存候选是否足以回答当前问题。 - """ - try: - # Prepare candidates text - candidates_desc = [] - for i, c in enumerate(candidates): - if c['type'] == 'cached_search': - # Preview cached results if available? - # Maybe just use the query string as a proxy for what's in there. - # Or peek at 'results' snippet. - preview = "" - try: - # Attempt to peek first result title from JSON string - # Note: c.get('results') might be a stringified JSON list - res_list = json.loads(c.get('results', '[]')) - if res_list and isinstance(res_list, list) and len(res_list) > 0: - first_item = res_list[0] - if isinstance(first_item, dict) and 'title' in first_item: - preview = f" (Contains: {first_item.get('title', '')[:50]}...)" - except: - pass - candidates_desc.append(f"[{i}] Old Search Query: '{c['query']}' {preview} (Time: {c['timestamp']})") - elif c['type'] == 'local_news': - # List titles of local news - titles = [item['title'] for item in c['items'][:3]] - candidates_desc.append(f"[{i}] Local Database News: {', '.join(titles)}... (Time: {c['timestamp']})") - - prompt = f""" - Task: Decide if existing information is sufficient for the new search query. - - New Query: "{current_query}" - - Available Information Candidates: - {chr(10).join(candidates_desc)} - - Instructions: - 1. Analyze if any candidate provides ENOUGH up-to-date info for the "New Query". - 2. If yes, choose the best one. - 3. If the query implies needing LATEST real-time info and candidates are old, choose none. - 4. 
Return strictly JSON: {{"reuse": true/false, "index": , "reason": "short explanation"}} - """ - # 初始化模型 - provider = os.getenv("LLM_PROVIDER", "minimax") - model_id = os.getenv("LLM_MODEL", "Qwen") - host = os.getenv("LLM_HOST") - if host: - model = get_model(provider, model_id, host=host) - else: - model = get_model(provider, model_id) - - agent = Agent(model=model, markdown=True) - - response = agent.run(prompt) - content = response.content - - # Parse JSON - json_match = re.search(r'```json\s*(.*?)\s*```', content, re.DOTALL) - if json_match: - return json.loads(json_match.group(1)) - elif '{' in content: - # Fallback for cases where LLM doesn't wrap in ```json - return json.loads(content[content.find('{'):content.rfind('}')+1]) - return {"reuse": False} - - except Exception as e: - logger.warning(f"LLM evaluation failed: {e}") - return {"reuse": False} - - def aggregate_search(self, query: str, engines: Optional[List[str]] = None, max_results: int = 5) -> str: - """ - 使用多个搜索引擎同时搜索并聚合结果,获得更全面的信息覆盖。 - - Args: - query: 搜索关键词。 - engines: 要使用的搜索引擎列表。可选值: ["ddg", "baidu"]。 - 默认同时使用 ddg 和 baidu。 - max_results: 每个引擎期望返回的结果数量。 - - Returns: - 聚合后的搜索结果,按引擎分组显示。 - """ - engines = engines or ["ddg", "baidu"] - aggregated_results = [] - for engine in engines: - res = self.search(query, engine=engine, max_results=max_results) - aggregated_results.append(f"--- Results from {engine.upper()} ---\n{res}") - - return "\n\n".join(aggregated_results) diff --git a/skills/alphaear-reporter/scripts/utils/sentiment_tools.py b/skills/alphaear-reporter/scripts/utils/sentiment_tools.py deleted file mode 100644 index f4278b5..0000000 --- a/skills/alphaear-reporter/scripts/utils/sentiment_tools.py +++ /dev/null @@ -1,287 +0,0 @@ -import os -from typing import Dict, List, Union, Optional -import json -from loguru import logger -from agno.agent import Agent -from .llm.factory import get_model -from .database_manager import DatabaseManager - -# 从环境变量读取默认情绪分析模式 -DEFAULT_SENTIMENT_MODE = os.getenv("SENTIMENT_MODE", "auto") # auto, bert, llm - - -class SentimentTools: - """ - 情绪分析工具 - 支持 LLM 和 BERT 两种模式 - - 模式说明: - - "auto": 自动选择,优先使用 BERT(速度快),不可用时回退到 LLM - - "bert": 强制使用 BERT 模型(需要 transformers 库) - - "llm": 强制使用 LLM(更准确但较慢) - - 可通过环境变量 SENTIMENT_MODE 设置默认模式。 - """ - - def __init__( - self, - db: DatabaseManager, - mode: Optional[str] = None, - model_provider: str = "openai", - model_id: str = "gpt-4o", - ): - """ - 初始化情绪分析工具。 - - Args: - db: 数据库管理器实例 - mode: 分析模式,可选 "auto", "bert", "llm"。None 则使用环境变量默认值。 - model_provider: LLM 提供商,如 "openai", "ust", "deepseek" - model_id: 模型标识符 - """ - self.db = db - self.mode = mode or DEFAULT_SENTIMENT_MODE - self.llm_model = None - self.bert_pipeline = None - - # Initialize LLM - try: - provider = "minimax" if os.getenv("MINIMAX_API_KEY") else model_provider - m_id = ( - os.getenv("LLM_MODEL", "MiniMax-Text-01") - if provider == "minimax" - else model_id - ) - self.llm_model = get_model(provider, m_id) - except Exception as e: - logger.warning(f"LLM initialization skipped: {e}") - - # Initialize BERT if needed - if self.mode in ["bert", "auto"]: - try: - from transformers import ( - pipeline, - AutoTokenizer, - AutoModelForSequenceClassification, - ) - from transformers.utils import logging as transformers_logging - - transformers_logging.set_verbosity_error() # 减少冗余日志 - - bert_model = os.getenv( - "BERT_SENTIMENT_MODEL", - "uer/roberta-base-finetuned-chinanews-chinese", - ) - - # 优先使用本地缓存 - try: - tokenizer = AutoTokenizer.from_pretrained( - bert_model, local_files_only=True - ) - 
model = AutoModelForSequenceClassification.from_pretrained( - bert_model, local_files_only=True - ) - - self.bert_pipeline = pipeline( - "sentiment-analysis", - model=model, - tokenizer=tokenizer, - device=-1, - ) - logger.info( - f"✅ BERT pipeline loaded from local cache: {bert_model}" - ) - except (OSError, ValueError, ImportError): - # 本地没有,则从网络下载 - logger.info(f"📡 Downloading BERT model: {bert_model}...") - tokenizer = AutoTokenizer.from_pretrained(bert_model) - model = AutoModelForSequenceClassification.from_pretrained( - bert_model - ) - - self.bert_pipeline = pipeline( - "sentiment-analysis", - model=model, - tokenizer=tokenizer, - device=-1, - ) - logger.info( - f"✅ BERT Sentiment pipeline ({bert_model}) initialized." - ) - except ImportError: - logger.warning( - "Transformers library not installed. BERT sentiment analysis disabled." - ) - except Exception as e: - if self.mode == "bert": - logger.error(f"BERT mode requested but failed: {e}") - else: - logger.warning(f"BERT unavailable, using LLM only. Error: {e}") - self.bert_pipeline = None - - def analyze_sentiment(self, text: str) -> Dict[str, Union[float, str]]: - """ - 分析文本的情绪极性。根据初始化时的 mode 自动选择分析方法。 - - Args: - text: 需要分析的文本内容,如新闻标题或摘要。 - - Returns: - 包含以下字段的字典: - - score: 情绪分值,范围 -1.0(极度负面)到 1.0(极度正面),0.0 为中性 - - label: 情绪标签,"positive"/"negative"/"neutral" - - reason: 分析理由(仅 LLM 模式提供详细理由) - """ - if self.mode == "bert" and self.bert_pipeline: - results = self.analyze_sentiment_bert([text]) - return results[0] if results else {"score": 0.0, "label": "error"} - elif self.mode == "llm" or (self.mode == "auto" and not self.bert_pipeline): - return self.analyze_sentiment_llm(text) - else: - # auto mode with BERT available - results = self.analyze_sentiment_bert([text]) - return results[0] if results else {"score": 0.0, "label": "error"} - - def analyze_sentiment_llm(self, text: str) -> Dict[str, Union[float, str]]: - """ - 使用 LLM 进行深度情绪分析,可获得详细的分析理由。 - - Args: - text: 需要分析的文本,最多处理前 1000 字符。 - - Returns: - 包含 score, label, reason 的字典。 - """ - if not self.llm_model: - return {"score": 0.0, "label": "neutral", "error": "LLM not initialized"} - - analyzer = Agent(model=self.llm_model, markdown=True) - prompt = f"""请分析以下金融/新闻文本的情绪极性。 - 返回严格的 JSON 格式: - {{"score": , "label": "", "reason": "<简短理由>"}} - - 文本: {text[:1000]}""" - - try: - response = analyzer.run(prompt) - content = response.content - if "```json" in content: - content = content.split("```json")[1].split("```")[0].strip() - elif "```" in content: - content = content.split("```")[1].split("```")[0].strip() - return json.loads(content) - except Exception as e: - logger.error(f"LLM sentiment failed: {e}") - return {"score": 0.0, "label": "error", "reason": str(e)} - - def analyze_sentiment_bert(self, texts: List[str]) -> List[Dict]: - """ - 使用 BERT 进行批量高速情绪分析。 - - Args: - texts: 需要分析的文本列表。 - - Returns: - 与输入列表等长的分析结果列表。 - """ - if not self.bert_pipeline: - return [ - {"score": 0.0, "label": "error", "reason": "BERT not available"} - ] * len(texts) - - try: - results = self.bert_pipeline(texts, truncation=True, max_length=512) - processed = [] - for r in results: - label = r["label"].lower() - score = r["score"] - - # 标准化不同模型的标签格式 - if "negative" in label or "neg" in label: - score = -score - elif "neutral" in label or "neu" in label: - score = 0.0 - - processed.append( - { - "score": float(round(score, 3)), - "label": "positive" - if score > 0.1 - else ("negative" if score < -0.1 else "neutral"), - "reason": "BERT automated analysis", - } - ) - return processed - except 
Exception as e: - logger.error(f"BERT analysis failed: {e}") - return [{"score": 0.0, "label": "error", "reason": str(e)}] * len(texts) - - def batch_update_news_sentiment( - self, - source: Optional[str] = None, - limit: int = 50, - use_bert: Optional[bool] = None, - ): - """ - 批量更新数据库中新闻的情绪分数。 - - Args: - source: 筛选特定新闻源,如 "wallstreetcn"。None 则处理所有来源。 - limit: 最多处理的新闻数量。 - use_bert: 是否使用 BERT。None 则根据初始化模式自动决定。 - - Returns: - 成功更新的新闻数量。 - """ - news_items = self.db.get_daily_news(source=source, limit=limit) - to_analyze = [item for item in news_items if not item.get("sentiment_score")] - - if not to_analyze: - return 0 - - # 决定使用哪种方法 - should_use_bert = ( - use_bert - if use_bert is not None - else (self.bert_pipeline is not None and self.mode != "llm") - ) - - updated_count = 0 - cursor = self.db.conn.cursor() - - if should_use_bert and self.bert_pipeline: - logger.info( - f"🚀 Using BERT for batch analysis of {len(to_analyze)} items..." - ) - titles = [item["title"] for item in to_analyze] - results = self.analyze_sentiment_bert(titles) - - for item, analysis in zip(to_analyze, results): - cursor.execute( - """ - UPDATE daily_news - SET sentiment_score = ?, meta_data = json_set(COALESCE(meta_data, '{}'), '$.sentiment_reason', ?) - WHERE id = ? - """, - (analysis["score"], analysis["reason"], item["id"]), - ) - updated_count += 1 - else: - logger.info(f"🚶 Using LLM for analysis of {len(to_analyze)} items...") - for item in to_analyze: - analysis = self.analyze_sentiment_llm(item["title"]) - cursor.execute( - """ - UPDATE daily_news - SET sentiment_score = ?, meta_data = json_set(COALESCE(meta_data, '{}'), '$.sentiment_reason', ?) - WHERE id = ? - """, - ( - analysis.get("score", 0.0), - analysis.get("reason", ""), - item["id"], - ), - ) - updated_count += 1 - - self.db.conn.commit() - return updated_count diff --git a/skills/alphaear-reporter/scripts/utils/stock_tools.py b/skills/alphaear-reporter/scripts/utils/stock_tools.py deleted file mode 100644 index 5929f74..0000000 --- a/skills/alphaear-reporter/scripts/utils/stock_tools.py +++ /dev/null @@ -1,257 +0,0 @@ -from datetime import datetime, timedelta -from typing import List, Dict, Optional -import akshare as ak -import pandas as pd -import re -import sqlite3 -from requests.exceptions import RequestException -from loguru import logger -from .database_manager import DatabaseManager -import os -from contextlib import contextmanager - -@contextmanager -def temporary_no_proxy(): - """Context manager to temporarily unset proxy environment variables.""" - proxies = {k: os.environ.get(k) for k in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY']} - for k in proxies: - if k in os.environ: - del os.environ[k] - try: - yield - finally: - for k, v in proxies.items(): - if v is not None: - os.environ[k] = v - -class StockTools: - """金融分析股票工具 - 结合高性能数据库缓存与增量更新""" - - def __init__(self, db: DatabaseManager, auto_update: bool = True): - """ - 初始化股票工具 - - Args: - db: 数据库管理器 - auto_update: 是否在列表为空时自动更新,默认 True - """ - self.db = db - if auto_update: - self._check_and_update_stock_list() - - def _check_and_update_stock_list(self, force: bool = False): - """检查并更新股票列表。仅在列表为空或 force=True 时从网络拉取。""" - # 直接查询表中记录数 - cursor = self.db.conn.cursor() - cursor.execute("SELECT COUNT(*) FROM stock_list") - count = cursor.fetchone()[0] - - if count > 0 and not force: - logger.info(f"ℹ️ Stock list already cached ({count} stocks)") - return - - logger.info("📡 Updating A-share and HK-share stock list from akshare...") - - def fetch_data(): - # A-share - df_a 
= ak.stock_zh_a_spot_em() - df_a = df_a[['代码', '名称']].copy() - df_a.columns = ['code', 'name'] - - # HK-share - df_hk = ak.stock_hk_spot_em() - df_hk = df_hk[['代码', '名称']].copy() - df_hk.columns = ['code', 'name'] - - # Combine - return pd.concat([df_a, df_hk], ignore_index=True) - - try: - try: - df_combined = fetch_data() - except (RequestException, Exception) as e: - if "Proxy" in str(e) or "proxy" in str(e): - logger.warning(f"⚠️ Proxy error detected: {e}. Retrying with proxy disabled...") - with temporary_no_proxy(): - df_combined = fetch_data() - else: - raise e - - self.db.save_stock_list(df_combined) - logger.info(f"✅ Cached {len(df_combined)} stocks (A-share + HK) to database.") - - except Exception as e: - logger.error(f"❌ Failed to sync stock list: {e}") - - - def search_ticker(self, query: str, limit: int = 5) -> List[Dict]: - """ - 模糊搜索 A 股股票代码或名称,支持常见缩写。 - """ - # 清洗后缀 (如 CATL.SZ -> CATL, 000001.SZ -> 000001) - clean_query = re.sub(r'\.(SZ|SH|HK|US)$', '', query, flags=re.IGNORECASE) - - # 常见缩写映射 - aliases = { - "CATL": "宁德时代", - "BYD": "比亚迪", - "TSLA": "特斯拉", - "Moutai": "贵州茅台", - "Tencent": "腾讯", - "Alibaba": "阿里巴巴", - "Meituan": "美团", - } - - search_query = aliases.get(clean_query.upper(), clean_query) - - # Robustness: if regex-like ticker code is embedded in query (e.g. "300364 中文在线"), try to extract it - if not search_query.isdigit(): - # Extract explicit 5-6 digit codes - match = re.search(r'\b(\d{5,6})\b', clean_query) - if match: - search_query = match.group(1) - - return self.db.search_stock(search_query, limit) - - def get_stock_price( - self, - ticker: str, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - force_sync: bool = False, - ) -> pd.DataFrame: - """ - 获取指定股票的历史价格数据。优先从本地缓存读取,缺失时自动从网络补齐。 - - Args: - ticker: 股票代码,如 "600519"(贵州茅台)或 "000001"(平安银行)。 - start_date: 开始日期,格式 "YYYY-MM-DD"。默认为 90 天前。 - end_date: 结束日期,格式 "YYYY-MM-DD"。默认为今天。 - - Returns: - 包含 date, open, close, high, low, volume, change_pct 列的 DataFrame。 - """ - now = datetime.now() - if not end_date: - end_date = now.strftime('%Y-%m-%d') - if not start_date: - start_date = (now - timedelta(days=90)).strftime('%Y-%m-%d') - - df_db = self.db.get_stock_prices(ticker, start_date, end_date) - - need_update = False - if df_db.empty: - need_update = True - else: - db_latest = pd.to_datetime(df_db['date'].max()) - req_latest = pd.to_datetime(end_date) - if (req_latest - db_latest).days > 2: - need_update = True - - if force_sync: - need_update = True - - if need_update: - logger.info(f"📡 Data stale or missing for {ticker}, syncing from network...") - - # 清洗 ticker,确保只包含数字(Akshare A 股接口通常只需要数字代码) - clean_ticker = "".join(filter(str.isdigit, ticker)) - if not clean_ticker: - # Non A/H numeric tickers are not supported by the current data source. - logger.warning(f"⚠️ Unsupported ticker format (A/H only): {ticker}") - return df_db - - try: - s_fmt = start_date.replace("-", "") - e_fmt = end_date.replace("-", "") - - df_remote = None - - def fetch_data(): - if len(clean_ticker) == 5: - # HK Stock - return ak.stock_hk_hist( - symbol=clean_ticker, period="daily", - start_date=s_fmt, end_date=e_fmt, - adjust="qfq" - ) - else: - # A-share Stock - return ak.stock_zh_a_hist( - symbol=clean_ticker, period="daily", - start_date=s_fmt, end_date=e_fmt, - adjust="qfq" - ) - - try: - df_remote = fetch_data() - except (RequestException, Exception) as e: - if "Proxy" in str(e) or "proxy" in str(e): - logger.warning(f"⚠️ Proxy error detected: {e}. 
Retrying with proxy disabled...") - with temporary_no_proxy(): - df_remote = fetch_data() - else: - raise e - - if df_remote is not None and not df_remote.empty: - df_remote = df_remote.rename(columns={ - '日期': 'date', '开盘': 'open', '收盘': 'close', - '最高': 'high', '最低': 'low', '成交量': 'volume', - '涨跌幅': 'change_pct' - }) - # 确保日期格式正确 - df_remote['date'] = pd.to_datetime(df_remote['date']).dt.strftime('%Y-%m-%d') - - # 只有在获取到有意义的数据时才保存 - self.db.save_stock_prices(clean_ticker, df_remote) # 保存时使用清洗后的 clean_ticker - - # 重新查询数据库返回结果,保证一致性 - return self.db.get_stock_prices(clean_ticker, start_date, end_date) - else: - logger.warning(f"⚠️ Akshare returned empty data for {clean_ticker}") - - except KeyError as e: - # Akshare 有时在某些股票无数据时会抛出 KeyError - logger.warning(f"⚠️ Akshare data missing for {clean_ticker}: {e}") - except (RequestException, ConnectionError) as e: - logger.error(f"❌ Network error during Akshare sync for {clean_ticker}: {e}") - except sqlite3.Error as e: - logger.error(f"❌ Database error during Akshare sync for {clean_ticker}: {e}") - except Exception as e: - logger.error(f"❌ Unexpected error during Akshare sync for {clean_ticker}: {e}") - - return df_db - - -def get_stock_analysis(ticker: str, db: DatabaseManager) -> str: - """ - 生成指定股票的分析摘要报告。 - - Args: - ticker: 股票代码 - db: 数据库管理器实例 - - Returns: - Markdown 格式的分析报告,包含价格走势和关键指标。 - """ - tools = StockTools(db) - df = tools.get_stock_price(ticker) - - if df.empty: - return f"❌ 未能获取 {ticker} 的股价数据。" - - latest = df.iloc[-1] - change = ((latest['close'] - df.iloc[0]['close']) / df.iloc[0]['close']) * 100 - - report = [ - f"## 📊 {ticker} 分析报告", - f"- **查询时段**: {df.iloc[0]['date']} -> {latest['date']}", - f"- **当前价**: ¥{latest['close']:.2f}", - f"- **时段涨跌**: {change:+.2f}%", - f"- **最高/最低**: ¥{df['high'].max():.2f} / ¥{df['low'].min():.2f}", - "\n### 最近交易概览", - "```", - df.tail(5)[['date', 'close', 'change_pct', 'volume']].to_string(index=False), - "```" - ] - return "\n".join(report) diff --git a/skills/alphaear-reporter/scripts/visualizer.py b/skills/alphaear-reporter/scripts/visualizer.py deleted file mode 100644 index 85a38cd..0000000 --- a/skills/alphaear-reporter/scripts/visualizer.py +++ /dev/null @@ -1,472 +0,0 @@ -import os -from typing import Dict, List, Any, Optional -import pandas as pd -from loguru import logger -from pyecharts.charts import Kline, Line, Bar, Grid, Radar, Graph -from pyecharts import options as opts -from pyecharts.globals import ThemeType -from datetime import datetime, timedelta - -class VisualizerTools: - """可视化工具库 - 使用 Pyecharts 生成 HTML 图表""" - - @staticmethod - def generate_stock_chart( - df: pd.DataFrame, - ticker: str, - title: str = None, - prediction: Optional[List[float]] = None, - forecast: Optional[Any] = None, # ForecastResult instance - ground_truth: Optional[pd.DataFrame] = None # For training visualization - ) -> Grid: - """ - 生成股票 K 线图 + 成交量 + 预测趋势 (支持多状态 K 线) - """ - if df.empty: - return None - - # 数据预处理 - df = df.sort_values('date') - dates = [str(d)[:10] for d in df['date'].tolist()] - k_data = df[['open', 'close', 'low', 'high']].values.tolist() - volumes = df['volume'].tolist() - - if not title: - title = f"{ticker} 股价走势与预测" - - legend_items = ["日K"] - - # 1. 
处理传统的简单预测线 (Line) - pred_line = None - if prediction and not forecast: - try: - last_date_str = dates[-1] - last_date = datetime.strptime(last_date_str, "%Y-%m-%d") - - pred_dates = [] - for i in range(1, len(prediction) + 1): - pred_dates.append((last_date + timedelta(days=i)).strftime("%Y-%m-%d")) - - ext_dates = dates + pred_dates - last_close = df.iloc[-1]['close'] - pred_values = [None] * (len(df) - 1) + [float(last_close)] + prediction - - pred_line = ( - Line() - .add_xaxis(ext_dates) - .add_yaxis( - "AI预测趋势", - pred_values, - is_connect_nones=True, - is_symbol_show=True, - linestyle_opts=opts.LineStyleOpts(width=2, type_="dashed", color="#FF8C00"), - label_opts=opts.LabelOpts(is_show=False) - ) - ) - dates = ext_dates - legend_items.append("AI预测趋势") - except Exception as e: - logger.error(f"Failed to process simple prediction: {e}") - - # 2. 处理复杂的 Kronos 预测 (Kline) - base_kline = None - adj_kline = None - - if forecast: - try: - # 获取预测数据点 - base_points = forecast.base_forecast # List[KLinePoint] - adj_points = forecast.adjusted_forecast # List[KLinePoint] - - # 提取日期 - pred_dates = [str(p.date)[:10] for p in (adj_points or base_points)] - - # 检查日期是否已经包含在主 dates 中,如果没有则扩展 - if pred_dates and pred_dates[0] not in dates: - dates = dates + pred_dates - - # 构建 Baseline 预测 K 线数据 - if base_points: - # 前面填充 None - base_k_data = [[None]*4] * len(df) + [[p.open, p.close, p.low, p.high] for p in base_points] - base_kline = ( - Kline() - .add_xaxis(dates) - .add_yaxis( - "模型原始预测", - base_k_data, - itemstyle_opts=opts.ItemStyleOpts( - color="transparent", - color0="transparent", - border_color="#FF8C00", # 橙色 - border_color0="#FF8C00", - opacity=0.6, - border_type="dashed" - ), - ) - ) - legend_items.append("模型原始预测") - - # 构建 Adjusted 调优 K 线数据 - if adj_points: - adj_k_data = [[None]*4] * len(df) + [[p.open, p.close, p.low, p.high] for p in adj_points] - adj_kline = ( - Kline() - .add_xaxis(dates) - .add_yaxis( - "LLM调优预测", - adj_k_data, - itemstyle_opts=opts.ItemStyleOpts( - color="#9333ea", # 紫色 - color0="#9333ea", - border_color="#9333ea", - border_color0="#9333ea", - opacity=0.8 - ), - ) - ) - legend_items.append("LLM调优预测") - - except Exception as e: - logger.error(f"Failed to process complex forecast: {e}") - - # 2.5 处理 Ground Truth (用于训练评估可视化) - gt_line = None - if ground_truth is not None and not ground_truth.empty: - try: - gt_dates = [str(d)[:10] for d in ground_truth['date'].tolist()] - # 确保日期包含在 dates 中 - for d in gt_dates: - if d not in dates: - dates.append(d) - dates = sorted(list(set(dates))) # Re-sort to maintain order - - gt_values = [None] * len(dates) - for _, row in ground_truth.iterrows(): - d_str = str(row['date'])[:10] - if d_str in dates: - idx = dates.index(d_str) - gt_values[idx] = float(row['close']) - - gt_line = ( - Line() - .add_xaxis(dates) - .add_yaxis( - "真实走势 (GT)", - gt_values, - is_connect_nones=True, - linestyle_opts=opts.LineStyleOpts(width=3, color="#2ecc71"), # 绿色粗线 - label_opts=opts.LabelOpts(is_show=False) - ) - ) - legend_items.append("真实走势 (GT)") - except Exception as e: - logger.error(f"Failed to process ground truth: {e}") - - # 3. 
主 K 线图 - # 为了展示预测,也需要对主 K 线数据进行填充 - main_k_data = k_data + [[None]*4] * (len(dates) - len(df)) - - kline = ( - Kline() - .add_xaxis(dates) - .add_yaxis( - "日K", - main_k_data, - itemstyle_opts=opts.ItemStyleOpts( - color="#ef4444", # 跌 - color0="#22c55e", # 涨 - border_color="#ef4444", - border_color0="#22c55e", - ), - ) - .set_global_opts( - title_opts=opts.TitleOpts(title=title, pos_left="center"), - xaxis_opts=opts.AxisOpts(is_scale=True), - yaxis_opts=opts.AxisOpts( - is_scale=True, - splitarea_opts=opts.SplitAreaOpts( - is_show=True, areastyle_opts=opts.AreaStyleOpts(opacity=1) - ), - ), - legend_opts=opts.LegendOpts(is_show=True, pos_top="5%"), - datazoom_opts=[opts.DataZoomOpts(type_="inside", range_start=50)], - tooltip_opts=opts.TooltipOpts(trigger="axis", axis_pointer_type="cross"), - ) - ) - - # Overlap all series - if pred_line: kline.overlap(pred_line) - if base_kline: kline.overlap(base_kline) - if adj_kline: kline.overlap(adj_kline) - if gt_line: kline.overlap(gt_line) - - # 4. 成交量柱状图 - # 同理扩展成交量数据 - ext_volumes = volumes + [0] * (len(dates) - len(df)) - - bar = ( - Bar() - .add_xaxis(dates) - .add_yaxis( - "成交量", - ext_volumes, - xaxis_index=1, - yaxis_index=1, - label_opts=opts.LabelOpts(is_show=False), - itemstyle_opts=opts.ItemStyleOpts(color="#7fbe9e"), - ) - .set_global_opts( - xaxis_opts=opts.AxisOpts( - type_="category", - grid_index=1, - axislabel_opts=opts.LabelOpts(is_show=False), - ), - legend_opts=opts.LegendOpts(is_show=False), - ) - ) - - # 5. 组合 Grid - grid_chart = Grid(init_opts=opts.InitOpts(width="100%", height="450px", theme=ThemeType.LIGHT)) - grid_chart.add( - kline, - grid_opts=opts.GridOpts(pos_left="10%", pos_right="8%", height="50%"), - ) - grid_chart.add( - bar, - grid_opts=opts.GridOpts( - pos_left="10%", pos_right="8%", pos_top="65%", height="20%" - ), - ) - - return grid_chart - - @staticmethod - def generate_loss_chart(losses: List[float], title: str = "训练损失收敛曲线") -> Line: - """生成 Loss 下降曲线图""" - line = ( - Line(init_opts=opts.InitOpts(width="100%", height="400px", theme=ThemeType.LIGHT)) - .add_xaxis(list(range(1, len(losses) + 1))) - .add_yaxis( - "Training Loss", - losses, - is_smooth=True, - linestyle_opts=opts.LineStyleOpts(width=2, color="#3b82f6"), - label_opts=opts.LabelOpts(is_show=False), - markpoint_opts=opts.MarkPointOpts(data=[opts.MarkPointItem(type_="min", name="最小值")]) - ) - .set_global_opts( - title_opts=opts.TitleOpts(title=title, pos_left="center"), - xaxis_opts=opts.AxisOpts(name="Epoch", is_scale=True), - yaxis_opts=opts.AxisOpts(name="Loss", is_scale=True), - tooltip_opts=opts.TooltipOpts(trigger="axis"), - ) - ) - return line - - @staticmethod - def generate_sentiment_trend_chart(sentiment_history: List[Dict[str, Any]]) -> Line: - """ - 生成舆情情绪趋势图 - :param sentiment_history: [{"date": "2024-01-01", "score": 0.8}, ...] 
- """ - dates = [item['date'] for item in sentiment_history] - scores = [item['score'] for item in sentiment_history] - - line = ( - Line(init_opts=opts.InitOpts(width="100%", height="300px", theme=ThemeType.LIGHT)) - .add_xaxis(dates) - .add_yaxis( - "情绪指数", - scores, - is_smooth=True, - markline_opts=opts.MarkLineOpts(data=[opts.MarkLineItem(y=0, name="中性线")]), - itemstyle_opts=opts.ItemStyleOpts(color="#5470c6"), - areastyle_opts=opts.AreaStyleOpts(opacity=0.3, color="#5470c6") - ) - .set_global_opts( - title_opts=opts.TitleOpts(title="舆情情绪趋势", pos_left="center"), - legend_opts=opts.LegendOpts(pos_top="8%"), - yaxis_opts=opts.AxisOpts(min_=-1, max_=1, name="Sentiment"), - tooltip_opts=opts.TooltipOpts(trigger="axis"), - ) - ) - return line - - @staticmethod - def generate_isq_radar_chart(sentiment: float, confidence: float, intensity: int, - expectation_gap: float = 0.5, timeliness: float = 0.8, - title: str = "信号质量 ISQ 评估") -> Radar: - """生成信号质量雷达图""" - # 标准化数据 (0-100) - # sentiment 强度: 绝对值越大强度越高 - sent_val = min(100, abs(sentiment) * 100) - # confidence: 0 to 1 -> 0 to 100 - conf_val = confidence * 100 - # intensity: 1 to 5 -> 20 to 100 - int_val = intensity * 20 - # gap & time: 0 to 1 -> 0 to 100 - gap_val = expectation_gap * 100 - time_val = timeliness * 100 - - schema = [ - opts.RadarIndicatorItem(name="情绪强度", max_=100), - opts.RadarIndicatorItem(name="确定性", max_=100), - opts.RadarIndicatorItem(name="影响力", max_=100), - opts.RadarIndicatorItem(name="预期差", max_=100), - opts.RadarIndicatorItem(name="时效性", max_=100), - ] - - radar = ( - Radar(init_opts=opts.InitOpts(width="100%", height="400px", theme=ThemeType.LIGHT)) - .add_schema(schema=schema) - .add( - "信号特征", - [[sent_val, conf_val, int_val, gap_val, time_val]], - color="#f97316", - areastyle_opts=opts.AreaStyleOpts(opacity=0.3, color="#fb923c"), - ) - .set_global_opts( - title_opts=opts.TitleOpts(title=title, pos_left="center"), - legend_opts=opts.LegendOpts(is_show=False), - ) - ) - return radar - - @staticmethod - def generate_transmission_graph(nodes_data: List[Dict[str, str]], title: str = "投资逻辑传导链条") -> Graph: - """生成逻辑传导拓扑图 (支持分支结构)""" - nodes = [] - links = [] - - # Helper for text wrapping - def wrap_text(text, width=6): - return '\n'.join([text[i:i+width] for i in range(0, len(text), width)]) - - # Map original names to wrapped names to handle links - name_map = {} - - for i, item in enumerate(nodes_data): - # 节点样式 - color = "#ef4444" if "利空" in item.get("impact_type", "") else "#22c55e" - if "中性" in item.get("impact_type", ""): color = "#6b7280" - - original_name = item.get("node_name", f"节点{i}") - wrapped_name = wrap_text(original_name) - name_map[original_name] = wrapped_name - name_map[str(item.get("id", ""))] = wrapped_name # Map ID if present - - nodes.append({ - "name": wrapped_name, - "symbolSize": 60 if i == 0 else 50, - "value": item.get("logic", ""), - "itemStyle": {"color": color}, - # Improve label readability - "label": {"show": True, "formatter": "{b}"} - }) - - # Logic for Links - source_key = item.get("source") or item.get("parent") or item.get("parent_id") - if source_key: - # Branching logic: Link from specified source - # Source needs to be resolved to its (wrapped) name - target_source_name = name_map.get(source_key) - if not target_source_name and source_key in name_map.values(): - target_source_name = source_key # It was already a mapped name? 
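For orientation, a minimal sketch of the `nodes_data` payload `generate_transmission_graph` consumes, using only the keys the loop above reads (`node_name`, `impact_type`, `logic`, plus an optional `source`/`parent` for branching); the node names and values here are purely illustrative:

```python
from scripts.visualizer import VisualizerTools  # module path as used in this skill

# Hypothetical three-node chain with one branch off the root node.
nodes_data = [
    {"node_name": "美联储加息", "impact_type": "利空", "logic": "流动性收紧"},
    {"node_name": "成长股估值承压", "impact_type": "利空",
     "logic": "贴现率上升压制估值", "source": "美联储加息"},
    {"node_name": "高股息板块获避险资金", "impact_type": "中性",
     "logic": "资金从成长切换至防御", "source": "美联储加息"},
]

graph = VisualizerTools.generate_transmission_graph(nodes_data, title="传导链示例")
VisualizerTools.render_chart_to_file(graph, "output/example/transmission.html")
```

Nodes tagged 利空 render red and 中性 gray per the color logic above; if a referenced `source` cannot be resolved, the link falls back to the previous node in the list.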
- - # If we found the source in our map (meaning it appeared before this node) - if target_source_name: - links.append({"source": target_source_name, "target": wrapped_name}) - elif i > 0: - # Fallback: Linear chain - links.append({"source": nodes[i-1]["name"], "target": wrapped_name}) - - graph = ( - Graph(init_opts=opts.InitOpts(width="100%", height="400px", theme=ThemeType.LIGHT)) - .add( - "", - nodes, - links, - repulsion=5000, - layout="force", - is_roam=True, - is_draggable=True, - symbol="circle", - edge_symbol=['circle', 'arrow'], # Add arrows - edge_symbol_size=[4, 10], - linestyle_opts=opts.LineStyleOpts(width=2, curve=0.2, opacity=0.9), - label_opts=opts.LabelOpts(is_show=True, position="inside", color="white", font_size=10), - edge_label=opts.LabelOpts(is_show=False), - ) - .set_global_opts( - title_opts=opts.TitleOpts(title=title, pos_left="center"), - tooltip_opts=opts.TooltipOpts(formatter="{b}: {c}") - ) - ) - return graph - - @staticmethod - def render_drawio_to_html(xml_content: str, filename: str, title: str = "Logic Diagram") -> str: - """ - 将 Draw.io XML 渲染为包含 Viewer 的 HTML 文件 - """ - import json - - # 构造配置字典 - config = { - "highlight": "#0000ff", - "nav": True, - "resize": True, - "toolbar": "zoom", - "xml": xml_content - } - - # 1. 转为 JSON 字符串 (自动处理内部的引号转义、换行符转义等) - json_str = json.dumps(config) - - # 2. 转为 HTML 属性安全的字符串 (主要是转义单引号,因为我们在 HTML 中用单引号包裹) - import html - safe_json_str = html.escape(json_str, quote=True) - - html_template = f""" - - - - - {title} - - - -

[viewer HTML template markup lost in extraction: the page titles itself with {title} and embeds safe_json_str as the Draw.io viewer configuration]
- - - - """ - - try: - os.makedirs(os.path.dirname(filename), exist_ok=True) - # Use 'w' mode with utf-8 encoding - with open(filename, 'w', encoding='utf-8') as f: - f.write(html_template) - logger.info(f"✅ Draw.io chart rendered to {filename}") - return filename - except Exception as e: - logger.error(f"Failed to render drawio chart: {e}") - return "" - - @staticmethod - def render_chart_to_file(chart: Any, filename: str) -> str: - """渲染并保存 HTML""" - try: - # 确保目录存在 - os.makedirs(os.path.dirname(filename), exist_ok=True) - chart.render(filename) - logger.info(f"✅ Chart rendered to {filename}") - return filename - except Exception as e: - logger.error(f"Failed to render chart: {e}") - return "" diff --git a/skills/alphaear-reporter/tests/test_reporter.py b/skills/alphaear-reporter/tests/test_reporter.py deleted file mode 100644 index 191c4fc..0000000 --- a/skills/alphaear-reporter/tests/test_reporter.py +++ /dev/null @@ -1,29 +0,0 @@ -import sys -import os -import unittest - -# Add skill root to path -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) - -# try: -from scripts.visualizer import VisualizerTools -from scripts.report_agent import ReportAgent -from scripts.utils.database_manager import DatabaseManager -# except ImportError as e: -# print(f"Import Error: {e}") -# sys.exit(1) - -class TestReporter(unittest.TestCase): - def test_visualizer(self): - print("Testing Visualizer...") - viz = VisualizerTools() - self.assertIsNotNone(viz) - - def test_agent_init(self): - print("Testing ReportAgent...") - # Mocking or simplified init might be needed if agent requires extensive config - # Just checking import for now is a big win - pass - -if __name__ == '__main__': - unittest.main() diff --git a/skills/alphaear-search/SKILL.md b/skills/alphaear-search/SKILL.md deleted file mode 100644 index e1318ca..0000000 --- a/skills/alphaear-search/SKILL.md +++ /dev/null @@ -1,35 +0,0 @@ ---- -name: alphaear-search -description: Perform finance web searches and local context searches. Use when the user needs general finance info from the web (Jina/DDG/Baidu) or needs to retrieve finance information from a local document store (RAG). ---- - -# AlphaEar Search Skill - -## Overview - -Unified search capabilities: web search (Jina/DDG/Baidu) and local RAG search. - -## Capabilities - -### 1. Web Search - -Use `scripts/search_tools.py` via `SearchTools`. - -- **Search**: `search(query, engine, max_results)` - - Engines: `jina`, `ddg`, `baidu`, `local`. - - Returns: JSON string (summary) or List[Dict] (via `search_list`). -- **Smart Cache (Agentic)**: If you want to avoid redundant searches, use the **Search Cache Relevance Prompt** in `references/PROMPTS.md`. Read the cache first and decide if it's usable. -- **Aggregate**: `aggregate_search(query)` - - Combines results from multiple engines. - - -### 2. Local RAG - -Use `scripts/hybrid_search.py` or `SearchTools` with `engine='local'`. - -- **Search**: Searches local `daily_news` database. 
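A minimal usage sketch covering both modes, assuming the script layout listed above (signatures as they appear in `search_tools.py` and `database_manager.py`):

```python
from scripts.database_manager import DatabaseManager
from scripts.search_tools import SearchTools

db = DatabaseManager("data/signal_flux.db")   # default cache / local-news DB path
tools = SearchTools(db)

# Web search: defaults to Jina when JINA_API_KEY is set, otherwise DuckDuckGo.
print(tools.search("英伟达财报", max_results=5))

# Local RAG over cached daily_news (BM25, optionally fused with vector ranks).
for hit in tools.search_list("光伏行业政策", engine="local", max_results=3, enrich=False):
    print(hit["title"], hit["url"])
```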
- -## Dependencies - -- `duckduckgo-search`, `requests` -- `scripts/database_manager.py` (search cache & local news) diff --git a/skills/alphaear-search/references/PROMPTS.md b/skills/alphaear-search/references/PROMPTS.md deleted file mode 100644 index f859eec..0000000 --- a/skills/alphaear-search/references/PROMPTS.md +++ /dev/null @@ -1,20 +0,0 @@ -# AlphaEar Search Prompts - -## Search Cache Relevance (Smart Cache) - -**Prompt:** - -```markdown -Task: Decide if existing information from previous searches or local news is sufficient for the new search query. - -New Query: "{current_query}" - -Available Information Candidates: -{candidates_desc} - -Instructions: -1. Analyze if any candidate provides ENOUGH up-to-date info for the "New Query". -2. If yes, choose the best one. -3. If the query implies needing LATEST real-time info and candidates are older than a few hours/days (depending on topic volatility), choose none. -4. Return strictly JSON: {"reuse": true/false, "index": , "reason": "short explanation"} -``` diff --git a/skills/alphaear-search/scripts/__init__.py b/skills/alphaear-search/scripts/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/alphaear-search/scripts/content_extractor.py b/skills/alphaear-search/scripts/content_extractor.py deleted file mode 100644 index 133207a..0000000 --- a/skills/alphaear-search/scripts/content_extractor.py +++ /dev/null @@ -1,122 +0,0 @@ -import requests -from requests.exceptions import RequestException, Timeout, ConnectionError -import os -import time -import json -import threading -from typing import Optional -from loguru import logger - - -class ContentExtractor: - """内容提取工具 - 主要接入 Jina Reader API""" - - JINA_BASE_URL = "https://r.jina.ai/" - - # 速率限制配置 (无 API Key 时:20 次/分钟) - _rate_limit_no_key = 20 # 每分钟最大请求数 - _rate_window = 60.0 # 时间窗口(秒) - _min_interval = 3.0 # 请求最小间隔(秒) - - # 类级别的速率限制状态 - _request_times = [] - _last_request_time = 0.0 - _lock = threading.Lock() - - @classmethod - def _wait_for_rate_limit(cls, has_api_key: bool) -> None: - """等待以满足速率限制要求""" - if has_api_key: - # 有 API Key 时,只需保持最小间隔 - time.sleep(0.5) - return - - with cls._lock: - current_time = time.time() - - # 1. 清理过期的请求记录 - cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window] - - # 2. 检查是否达到速率限制 - if len(cls._request_times) >= cls._rate_limit_no_key: - # 需要等待最旧的请求过期 - oldest = cls._request_times[0] - wait_time = cls._rate_window - (current_time - oldest) + 1.0 - if wait_time > 0: - logger.warning(f"⏳ Jina rate limit reached, waiting {wait_time:.1f}s...") - time.sleep(wait_time) - current_time = time.time() - cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window] - - # 3. 确保请求间隔不太快 - time_since_last = current_time - cls._last_request_time - if time_since_last < cls._min_interval: - sleep_time = cls._min_interval - time_since_last - time.sleep(sleep_time) - - # 4. 
记录本次请求 - cls._request_times.append(time.time()) - cls._last_request_time = time.time() - - @classmethod - def extract_with_jina(cls, url: str, timeout: int = 30) -> Optional[str]: - """ - 使用 Jina Reader 提取网页正文内容 (Markdown 格式) - - 无 API Key 时自动限速:每分钟最多 20 次请求,每次间隔至少 3 秒 - """ - if not url or not url.startswith("http"): - return None - - logger.info(f"🕸️ Extracting content from: {url} via Jina...") - - headers = { - "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", - "Accept": "application/json" - } - - # 使用统一的 JINA_API_KEY - api_key = os.getenv("JINA_API_KEY") - has_api_key = bool(api_key and api_key.strip()) - - if has_api_key: - headers["Authorization"] = f"Bearer {api_key}" - - # 等待速率限制 - cls._wait_for_rate_limit(has_api_key) - - try: - # Jina Reader API - full_url = f"{cls.JINA_BASE_URL}{url}" - response = requests.get(full_url, headers=headers, timeout=timeout) - - if response.status_code == 200: - try: - data = response.json() - # Jina JSON 响应格式通常在 data.content - if isinstance(data, dict) and "data" in data: - return data["data"].get("content", "") - return data.get("content", response.text) - except (json.JSONDecodeError, TypeError): - return response.text - elif response.status_code == 429: - # 触发速率限制,等待后重试一次 - logger.warning(f"⚠️ Jina rate limit (429), waiting 60s before retry...") - time.sleep(60) - return cls.extract_with_jina(url, timeout) - else: - logger.warning(f"Jina extraction failed (Status {response.status_code}) for {url}") - return None - - except Timeout: - logger.error(f"Timeout during Jina extraction for {url}") - return None - except ConnectionError: - logger.error(f"Connection error during Jina extraction for {url}") - return None - except RequestException as e: - logger.error(f"Request error during Jina extraction: {e}") - return None - except Exception as e: - logger.error(f"Unexpected error during Jina extraction: {e}") - return None diff --git a/skills/alphaear-search/scripts/database_manager.py b/skills/alphaear-search/scripts/database_manager.py deleted file mode 100644 index 26b1ca9..0000000 --- a/skills/alphaear-search/scripts/database_manager.py +++ /dev/null @@ -1,159 +0,0 @@ -import sqlite3 -import json -from datetime import datetime -from pathlib import Path -from typing import List, Dict, Optional, Union -from loguru import logger - -class DatabaseManager: - """ - AlphaEar Search Database Manager - Reduced version for alphaear-search skill - """ - - def __init__(self, db_path: str = "data/signal_flux.db"): - self.db_path = Path(db_path) - self.db_path.parent.mkdir(parents=True, exist_ok=True) - self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False) - self.conn.row_factory = sqlite3.Row - self._init_db() - logger.debug(f"💾 Search Database initialized at {self.db_path}") - - def _init_db(self): - cursor = self.conn.cursor() - - # 1. Daily News (Required for Local Search RAG) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS daily_news ( - id TEXT PRIMARY KEY, - source TEXT, - rank INTEGER, - title TEXT, - url TEXT, - content TEXT, - publish_time TEXT, - crawl_time TEXT, - sentiment_score REAL, - analysis TEXT, - meta_data TEXT - ) - """) - - # 2. Search Cache - cursor.execute(""" - CREATE TABLE IF NOT EXISTS search_cache ( - query_hash TEXT PRIMARY KEY, - query TEXT, - engine TEXT, - results TEXT, - timestamp TEXT - ) - """) - - # 3. 
Search Details - cursor.execute(""" - CREATE TABLE IF NOT EXISTS search_detail ( - id TEXT, - query_hash TEXT, - rank INTEGER, - title TEXT, - url TEXT, - content TEXT, - publish_time TEXT, - crawl_time TEXT, - sentiment_score REAL, - source TEXT, - meta_data TEXT, - PRIMARY KEY (query_hash, id) - ) - """) - - cursor.execute("CREATE INDEX IF NOT EXISTS idx_search_cache_timestamp ON search_cache(timestamp)") - self.conn.commit() - - # --- Search Cache Operations --- - - def get_search_cache(self, query_hash: str, ttl_seconds: Optional[int] = None) -> Optional[Dict]: - cursor = self.conn.cursor() - - # Try detailed cache first - cursor.execute(""" - SELECT * FROM search_detail - WHERE query_hash = ? - ORDER BY rank - """, (query_hash,)) - details = [dict(row) for row in cursor.fetchall()] - - if details: - first_time = datetime.fromisoformat(details[0]['crawl_time']) - if ttl_seconds and (datetime.now() - first_time).total_seconds() > ttl_seconds: - return None - return {"results": json.dumps(details), "timestamp": details[0]['crawl_time']} - - # Fallback to simple cache - cursor.execute("SELECT * FROM search_cache WHERE query_hash = ?", (query_hash,)) - row = cursor.fetchone() - - if not row: return None - row_dict = dict(row) - if ttl_seconds: - cache_time = datetime.fromisoformat(row_dict['timestamp']) - if (datetime.now() - cache_time).total_seconds() > ttl_seconds: - return None - return row_dict - - def save_search_cache(self, query_hash: str, query: str, engine: str, results: Union[str, List[Dict]]): - cursor = self.conn.cursor() - current_time = datetime.now().isoformat() - results_str = results if isinstance(results, str) else json.dumps(results) - - cursor.execute(""" - INSERT OR REPLACE INTO search_cache (query_hash, query, engine, results, timestamp) - VALUES (?, ?, ?, ?, ?) - """, (query_hash, query, engine, results_str, current_time)) - - if isinstance(results, list): - for item in results: - try: - item_id = item.get('id') or f"{hash(item.get('url', ''))}" - cursor.execute(""" - INSERT OR REPLACE INTO search_detail - (id, query_hash, rank, title, url, content, publish_time, crawl_time, sentiment_score, source, meta_data) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - str(item_id), query_hash, item.get('rank', 0), item.get('title'), - item.get('url'), item.get('content', ''), item.get('publish_time'), - item.get('crawl_time') or current_time, item.get('sentiment_score'), - item.get('source'), json.dumps(item.get('meta_data', {})) - )) - except Exception as e: - logger.error(f"Error saving search detail: {e}") - - self.conn.commit() - - def find_similar_queries(self, query: str, limit: int = 5) -> List[Dict]: - cursor = self.conn.cursor() - q_wild = f"%{query}%" - cursor.execute(""" - SELECT query, query_hash, timestamp, results - FROM search_cache - WHERE query LIKE ? OR ? LIKE ('%' || query || '%') - ORDER BY timestamp DESC - LIMIT ? - """, (q_wild, query, limit)) - return [dict(row) for row in cursor.fetchall()] - - def search_local_news(self, query: str, limit: int = 5) -> List[Dict]: - cursor = self.conn.cursor() - q_wild = f"%{query}%" - cursor.execute(""" - SELECT * FROM daily_news - WHERE title LIKE ? OR content LIKE ? - ORDER BY crawl_time DESC - LIMIT ? 
- """, (q_wild, q_wild, limit)) - return [dict(row) for row in cursor.fetchall()] - - def close(self): - if self.conn: - self.conn.close() diff --git a/skills/alphaear-search/scripts/hybrid_search.py b/skills/alphaear-search/scripts/hybrid_search.py deleted file mode 100644 index c597fee..0000000 --- a/skills/alphaear-search/scripts/hybrid_search.py +++ /dev/null @@ -1,216 +0,0 @@ -import numpy as np -import os -from typing import List, Dict, Any, Optional, Union -from rank_bm25 import BM25Okapi -from loguru import logger -from sentence_transformers import SentenceTransformer -from sklearn.metrics.pairwise import cosine_similarity - -class HybridSearcher: - """ - 统一混合检索引擎 (Hybrid RAG) - 实现 BM25 (文本) + 向量 (语义) 的融合搜索 (RRF) - """ - - def __init__(self, data: List[Dict[str, Any]], text_fields: List[str] = ["title", "content"], model_name: str = None): - """ - 初始化搜索器 - - Args: - data: 数据列表,每个元素为 Dict - text_fields: 用于建立索引的文本字段 - model_name: 向量模型名称,默认使用 paraphrase-multilingual-MiniLM-L12-v2 - """ - self.data = data - self.text_fields = text_fields - self._corpus = [] - self._bm25 = None - self._vector_model = None - self._embeddings = None - self._fitted = False - self._vector_fitted = False - - # 默认模型 - self.model_name = model_name or os.getenv("EMBEDDING_MODEL", "paraphrase-multilingual-MiniLM-L12-v2") - - if data: - self._prepare_corpus() - self._fit_bm25() - # 延迟加载向量模型,仅在需要时或初始化时显式调用 - # self._fit_vector() - - def _prepare_corpus(self): - """准备语料库用于分词""" - import jieba # 使用 jieba 进行中文分词 - - self._corpus = [] - self._full_texts = [] - for item in self.data: - text = " ".join([str(item.get(field, "")) for field in self.text_fields]) - self._full_texts.append(text) - # 中文分词优化 - tokens = list(jieba.cut(text)) - self._corpus.append(tokens) - - def _fit_bm25(self): - """训练 BM25 模型""" - if self._corpus: - self._bm25 = BM25Okapi(self._corpus) - self._fitted = True - logger.info(f"✅ BM25 index fitted with {len(self.data)} documents") - - def _fit_vector(self): - """训练向量模型并生成 Embeddings""" - if not self.data: - return - - try: - logger.info(f"📡 Loading embedding model: {self.model_name}...") - self._vector_model = SentenceTransformer(self.model_name) - logger.info(f"🧠 Encoding {len(self._full_texts)} documents...") - self._embeddings = self._vector_model.encode(self._full_texts, show_progress_bar=False) - self._vector_fitted = True - logger.info("✅ Vector index fitted successfully") - except Exception as e: - logger.error(f"❌ Failed to fit vector index: {e}") - self._vector_fitted = False - - def _compute_rrf(self, rank_lists: List[List[int]], k: int = 60) -> List[tuple]: - """ - 计算 Reciprocal Rank Fusion (RRF) - - Args: - rank_lists: 多个排序后的索引列表 - k: RRF 常数,默认 60 - """ - scores = {} - for rank_list in rank_lists: - for rank, idx in enumerate(rank_list): - if idx not in scores: - scores[idx] = 0 - scores[idx] += 1.0 / (k + rank + 1) - - # 按分数排序 - sorted_indices = sorted(scores.items(), key=lambda x: x[1], reverse=True) - return sorted_indices - - def search(self, query: str, top_n: int = 5, use_vector: bool = False) -> List[Dict[str, Any]]: - """ - 执行混合搜索 - - Args: - query: 搜索关键词 - top_n: 返回结果数量 - use_vector: 是否启用向量搜索 - """ - if not self._fitted or not query: - return [] - - import jieba - query_tokens = list(jieba.cut(query)) - - # 1. BM25 搜索结果 - bm25_scores = self._bm25.get_scores(query_tokens) - bm25_rank = np.argsort(bm25_scores)[::-1].tolist() - - rank_lists = [bm25_rank] - - # 2. 
向量搜索逻辑 - if use_vector: - if not self._vector_fitted: - self._fit_vector() - - if self._vector_fitted: - query_embedding = self._vector_model.encode([query], show_progress_bar=False) - similarities = cosine_similarity(query_embedding, self._embeddings)[0] - vector_rank = np.argsort(similarities)[::-1].tolist() - rank_lists.append(vector_rank) - else: - logger.warning("Vector search requested but model not fitted, falling back to BM25") - - # 3. 融合排序 (RRF) - if len(rank_lists) > 1: - rrf_results = self._compute_rrf(rank_lists) - # RRF 返回 (idx, score) 列表 - final_rank = [idx for idx, score in rrf_results] - else: - final_rank = bm25_rank - - # 返回前 top_n 条结果 - results = [self.data[idx].copy() for idx in final_rank[:top_n]] - - # 为每个结果注入相关性评分 - for i, res in enumerate(results): - try: - original_idx = final_rank[i] - res["_search_score"] = bm25_scores[original_idx] - if use_vector and self._vector_fitted: - res["_vector_score"] = float(similarities[original_idx]) - except: - res["_search_score"] = 0 - - return results - -class InMemoryRAG(HybridSearcher): - """专门用于 ReportAgent 跨章节检索的内存态 RAG""" - - def search(self, query: str, top_n: int = 3, use_vector: bool = True) -> List[Dict[str, Any]]: - """默认开启向量搜索的内存检索""" - return super().search(query, top_n=top_n, use_vector=use_vector) - - def update_data(self, new_data: List[Dict[str, Any]]): - """动态更新数据并重新训练索引""" - self.data = new_data - self._prepare_corpus() - self._fit_bm25() - # 如果之前已经加载过向量模型,则更新向量索引 - if self._vector_model: - self._fit_vector() - logger.info(f"🔄 InMemoryRAG updated with {len(new_data)} items") - -class LocalNewsSearch(HybridSearcher): - """持久态 RAG:检索数据库中的历史新闻""" - - def __init__(self, db_manager): - """ - Args: - db_manager: DatabaseManager 实例 - """ - self.db = db_manager - # 初始时不加载数据,需调用 load_history - super().__init__([], ["title", "content"]) - - def load_history(self, days: int = 30, limit: int = 1000): - """从数据库加载最近 N 天的新闻构建索引""" - try: - # 假设 db_manager 有 execute_query - query = f"SELECT title, content, publish_time, source FROM daily_news ORDER BY publish_time DESC LIMIT ?" 
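Stepping back to the `_compute_rrf` fusion used by `search()` above: a tiny self-contained illustration of the same formula (default k=60; the two rankings are made up):

```python
# Two rankings over the same four documents (indices 0-3).
bm25_rank   = [2, 0, 3, 1]   # doc 2 is the top BM25 hit
vector_rank = [0, 2, 1, 3]   # doc 0 is the top embedding hit

def rrf(rank_lists, k=60):
    scores = {}
    for rank_list in rank_lists:
        for rank, idx in enumerate(rank_list):
            scores[idx] = scores.get(idx, 0.0) + 1.0 / (k + rank + 1)
    return sorted(scores.items(), key=lambda x: x[1], reverse=True)

print(rrf([bm25_rank, vector_rank]))
# Docs 2 and 0 tie at the top (1/61 + 1/62 ≈ 0.0325);
# docs 3 and 1 tie just below (1/63 + 1/64 ≈ 0.0315).
```

Documents that rank well in both lists beat documents that dominate only one list, which is why the ranks are fused rather than the raw BM25 and cosine scores being averaged.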
- results = self.db.execute_query(query, (limit,)) - - data = [] - for row in results: - # 转换 Row 为 Dict - if hasattr(row, 'keys'): - item = dict(row) - else: - item = { - "title": row[0], - "content": row[1], - "publish_time": row[2], - "source": row[3] - } - data.append(item) - - self.data = data - self._prepare_corpus() - self._fit_bm25() - # 默认不立即训练向量,等到第一次搜索时按需训练 - logger.info(f"📚 LocalNewsSearch loaded {len(data)} items from history") - except Exception as e: - logger.error(f"Failed to load history for search: {e}") - - def search(self, query: str, top_n: int = 5, use_vector: bool = True) -> List[Dict[str, Any]]: - """执行本地历史搜索,默认开启向量搜索""" - if not self.data: - self.load_history() - return super().search(query, top_n=top_n, use_vector=use_vector) diff --git a/skills/alphaear-search/scripts/llm/__init__.py b/skills/alphaear-search/scripts/llm/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/alphaear-search/scripts/llm/capability.py b/skills/alphaear-search/scripts/llm/capability.py deleted file mode 100644 index d3fb2d7..0000000 --- a/skills/alphaear-search/scripts/llm/capability.py +++ /dev/null @@ -1,85 +0,0 @@ -import os -from typing import Optional, List, Dict, Any -from agno.agent import Agent -from agno.models.base import Model -from loguru import logger -from .factory import get_model - - -def test_tool_call_support(model: Model) -> bool: - """ - 测试模型是否支持原生的 Tool Call (Function Calling)。 - 通过尝试执行一个简单的加法工具来验证。 - """ - - def get_current_weather(location: str): - """获取指定地点的天气""" - return f"{location} 的天气是晴天,25度。" - - test_agent = Agent( - model=model, - tools=[get_current_weather], - instructions="请调用工具查询北京的天气,并直接返回工具的输出结果。", - ) - - try: - # 运行一个简单的任务,观察是否触发了 tool_call - response = test_agent.run("北京天气怎么样?") - - # 检查 response 中是否包含 tool_calls - # Agno 的 RunResponse 对象通常包含 messages,我们可以检查最后几条消息 - has_tool_call = False - for msg in response.messages: - if hasattr(msg, "tool_calls") and msg.tool_calls: - has_tool_call = True - break - - if has_tool_call: - logger.info(f"✅ Model {model.id} supports native tool calling.") - return True - else: - # 如果没有 tool_calls 但返回了正确答案,可能是模型通过纯文本模拟了工具调用(ReAct) - # 或者根本没用工具。对于原生支持的判断,我们坚持要求有 tool_calls 结构。 - logger.warning( - f"⚠️ Model {model.id} did NOT use native tool calling structure." 
- ) - return False - - except Exception as e: - logger.error(f"❌ Error testing tool call for {model.id}: {e}") - return False - - -class ModelCapabilityRegistry: - """ - 模型能力注册表,用于缓存和管理不同模型的能力测试结果。 - """ - - _cache = {} - - @classmethod - def get_capabilities( - cls, provider: str, model_id: str, **kwargs - ) -> Dict[str, bool]: - key = f"{provider}:{model_id}" - if key not in cls._cache: - logger.info(f"🔍 Testing capabilities for {key}...") - model = get_model(provider, model_id, **kwargs) - supports_tool_call = test_tool_call_support(model) - cls._cache[key] = {"supports_tool_call": supports_tool_call} - return cls._cache[key] - - -if __name__ == "__main__": - import os - from skills._env_loader import load_unified_env - - load_unified_env() - - # 测试当前配置的模型 - p = os.getenv("LLM_PROVIDER", "minimax") - m = os.getenv("LLM_MODEL", "Qwen") - - print(f"Testing {p}/{m}...") - res = ModelCapabilityRegistry.get_capabilities(p, m) - print(f"Result: {res}") diff --git a/skills/alphaear-search/scripts/llm/factory.py b/skills/alphaear-search/scripts/llm/factory.py deleted file mode 100644 index 09b6ea5..0000000 --- a/skills/alphaear-search/scripts/llm/factory.py +++ /dev/null @@ -1,114 +0,0 @@ -import os -from agno.models.openai import OpenAIChat -from agno.models.ollama import Ollama -from agno.models.dashscope import DashScope -from agno.models.deepseek import DeepSeek -from agno.models.openrouter import OpenRouter - -def get_model(model_provider: str, model_id: str, **kwargs): - """ - Factory to get the appropriate LLM model. - - Args: - model_provider: "openai", "ollama", "deepseek" - model_id: The specific model ID (e.g., "gpt-4o", "llama3", "deepseek-chat") - **kwargs: Additional arguments for the model constructor - """ - if model_provider == "openai": - return OpenAIChat(id=model_id, **kwargs) - - elif model_provider == "ollama": - return Ollama(id=model_id, **kwargs) - - elif model_provider == "deepseek": - # DeepSeek is OpenAI compatible - api_key = os.getenv("DEEPSEEK_API_KEY") - if not api_key: - print("Warning: DEEPSEEK_API_KEY not set.") - - return DeepSeek( - id=model_id, - api_key=api_key, - **kwargs - ) - elif model_provider == "dashscope": - api_key = os.getenv("DASHSCOPE_API_KEY") - if not api_key: - print("Warning: DASHSCOPE_API_KEY not set.") - - return DashScope( - id=model_id, - base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", - api_key=api_key, - **kwargs - ) - elif model_provider == 'openrouter': - api_key = os.getenv("OPENROUTER_API_KEY") - if not api_key: - print('Warning: OPENROUTER_API_KEY not set.') - - return OpenRouter( - id=model_id, - api_key=api_key, - **kwargs - ) - - elif model_provider == 'zai': - api_key = os.getenv("ZAI_KEY_API") - if not api_key: - print('Warning: ZAI_KEY_API not set.') - - # role_map to ensure compatibility. - default_role_map = { - "system": "system", - "user": "user", - "assistant": "assistant", - "tool": "tool", - "model": "assistant", - } - - # Allow callers to override role_map via kwargs, otherwise use default - role_map = kwargs.pop("role_map", default_role_map) - - return OpenAIChat( - id=model_id, - base_url="https://api.z.ai/api/paas/v4", - api_key=api_key, - timeout=60, - role_map=role_map, - extra_body={"enable_thinking": False}, # TODO: one more setting for thinking - **kwargs - ) - - elif model_provider == 'ust': - api_key = os.getenv("UST_KEY_API") - if not api_key: - print('Warning: UST_KEY_API not set.') - - # Some UST-compatible endpoints expect the standard OpenAI role names - # (e.g. 
"system", "user", "assistant") rather than Agno's default - # mapping which maps "system" -> "developer". Provide an explicit - # role_map to ensure compatibility. - default_role_map = { - "system": "system", - "user": "user", - "assistant": "assistant", - "tool": "tool", - "model": "assistant", - } - - # Allow callers to override role_map via kwargs, otherwise use default - role_map = kwargs.pop("role_map", default_role_map) - - return OpenAIChat( - id=model_id, - api_key=api_key, - base_url=os.getenv("UST_URL"), - role_map=role_map, - extra_body={"enable_thinking": False}, # TODO: one more setting for thinking - **kwargs - ) - - else: - raise ValueError(f"Unknown model provider: {model_provider}") - diff --git a/skills/alphaear-search/scripts/llm/router.py b/skills/alphaear-search/scripts/llm/router.py deleted file mode 100644 index 20e7d83..0000000 --- a/skills/alphaear-search/scripts/llm/router.py +++ /dev/null @@ -1,80 +0,0 @@ -import os -from typing import Optional, List, Dict, Any, Union -from agno.models.base import Model -from loguru import logger -from .factory import get_model -from .capability import ModelCapabilityRegistry -from skills._env_loader import load_unified_env - -load_unified_env() - - -class ModelRouter: - """ - 模型路由管理器 - - 功能: - 1. 管理“推理/写作模型” (Reasoning Model) 和“工具调用模型” (Tool Model)。 - 2. 根据任务需求自动选择合适的模型。 - """ - - def __init__(self): - # 默认从环境变量读取 - self.reasoning_provider = os.getenv( - "REASONING_MODEL_PROVIDER", os.getenv("LLM_PROVIDER", "openai") - ) - self.reasoning_id = os.getenv( - "REASONING_MODEL_ID", os.getenv("LLM_MODEL", "gpt-4o") - ) - self.reasoning_host = os.getenv("REASONING_MODEL_HOST", os.getenv("LLM_HOST")) - - self.tool_provider = os.getenv("TOOL_MODEL_PROVIDER", self.reasoning_provider) - self.tool_id = os.getenv("TOOL_MODEL_ID", self.reasoning_id) - self.tool_host = os.getenv("TOOL_MODEL_HOST", self.reasoning_host) - - self._reasoning_model = None - self._tool_model = None - - logger.info( - f"🤖 ModelRouter initialized: Reasoning={self.reasoning_id} ({self.reasoning_host or 'default'}), Tool={self.tool_id} ({self.tool_host or 'default'})" - ) - - def get_reasoning_model(self, **kwargs) -> Model: - if not self._reasoning_model: - # 优先使用路由配置的 host - if self.reasoning_host and "host" not in kwargs: - kwargs["host"] = self.reasoning_host - self._reasoning_model = get_model( - self.reasoning_provider, self.reasoning_id, **kwargs - ) - return self._reasoning_model - - def get_tool_model(self, **kwargs) -> Model: - if not self._tool_model: - # 优先使用路由配置的 host - if self.tool_host and "host" not in kwargs: - kwargs["host"] = self.tool_host - - # 检查 tool_model 是否真的支持 tool call - caps = ModelCapabilityRegistry.get_capabilities( - self.tool_provider, self.tool_id, **kwargs - ) - if not caps["supports_tool_call"]: - logger.warning( - f"⚠️ Configured tool model {self.tool_id} might not support native tool calls! Consider using ReAct mode or a different model." 
- ) - - self._tool_model = get_model(self.tool_provider, self.tool_id, **kwargs) - return self._tool_model - - def get_model_for_agent(self, has_tools: bool = False, **kwargs) -> Model: - """ - 根据 Agent 是否包含工具来返回合适的模型。 - """ - if has_tools: - return self.get_tool_model(**kwargs) - return self.get_reasoning_model(**kwargs) - - -# 全局单例 -router = ModelRouter() diff --git a/skills/alphaear-search/scripts/search_tools.py b/skills/alphaear-search/scripts/search_tools.py deleted file mode 100644 index ea83bfd..0000000 --- a/skills/alphaear-search/scripts/search_tools.py +++ /dev/null @@ -1,479 +0,0 @@ -import os -import hashlib -import json -import re -import requests -import time -import threading -from typing import List, Dict, Optional, Any -from agno.tools.duckduckgo import DuckDuckGoTools -from agno.tools.baidusearch import BaiduSearchTools -from datetime import datetime -from .database_manager import DatabaseManager -from .content_extractor import ContentExtractor -from .hybrid_search import LocalNewsSearch - -# 默认搜索缓存 TTL(秒),可通过环境变量覆盖 -DEFAULT_SEARCH_TTL = int(os.getenv("SEARCH_CACHE_TTL", "3600")) # 默认 1 小时 - - -class JinaSearchEngine: - """Jina Search API 封装 - 使用 s.jina.ai 进行网络搜索""" - - JINA_SEARCH_URL = "https://s.jina.ai/" - - # 速率限制配置 - _rate_limit_no_key = 10 # 无 key 时每分钟最大请求数 - _rate_window = 60.0 - _min_interval = 2.0 - _request_times = [] - _last_request_time = 0.0 - _lock = threading.Lock() - - def __init__(self): - self.api_key = os.getenv("JINA_API_KEY", "").strip() - self.has_api_key = bool(self.api_key) - if self.has_api_key: - logger.info("✅ Jina Search API key configured") - - @classmethod - def _wait_for_rate_limit(cls, has_api_key: bool) -> None: - """等待以满足速率限制""" - if has_api_key: - time.sleep(0.3) - return - - with cls._lock: - current_time = time.time() - cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window] - - if len(cls._request_times) >= cls._rate_limit_no_key: - oldest = cls._request_times[0] - wait_time = cls._rate_window - (current_time - oldest) + 1.0 - if wait_time > 0: - logger.warning(f"⏳ Jina Search rate limit, waiting {wait_time:.1f}s...") - time.sleep(wait_time) - current_time = time.time() - cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window] - - time_since_last = current_time - cls._last_request_time - if time_since_last < cls._min_interval: - time.sleep(cls._min_interval - time_since_last) - - cls._request_times.append(time.time()) - cls._last_request_time = time.time() - - def search(self, query: str, max_results: int = 5) -> List[Dict]: - """ - 使用 Jina Search API 执行搜索 - - Args: - query: 搜索关键词 - max_results: 返回结果数量 - - Returns: - 搜索结果列表,每个结果包含 title, url, content - """ - if not query: - return [] - - logger.info(f"🔍 Jina Search: {query}") - - # 等待速率限制 - self._wait_for_rate_limit(self.has_api_key) - - headers = { - "Accept": "application/json", - "X-Retain-Images": "none", - } - - if self.has_api_key: - headers["Authorization"] = f"Bearer {self.api_key}" - - try: - # Jina Search API: https://s.jina.ai/{query} - import urllib.parse - encoded_query = urllib.parse.quote(query) - url = f"{self.JINA_SEARCH_URL}{encoded_query}" - - response = requests.get(url, headers=headers, timeout=30) - - if response.status_code == 429: - logger.warning("⚠️ Jina Search rate limited (429), waiting 30s...") - time.sleep(30) - return self.search(query, max_results) - - if response.status_code != 200: - logger.warning(f"Jina Search failed (Status {response.status_code})") - return [] - - # 解析响应 - 
try: - data = response.json() - except json.JSONDecodeError: - # 如果返回纯文本,尝试解析 - data = {"data": [{"title": "Search Result", "url": "", "content": response.text}]} - - results = [] - - # Jina 返回格式可能是 {"data": [...]} 或直接是列表 - items = data.get("data", []) if isinstance(data, dict) else data - if not isinstance(items, list): - items = [items] if items else [] - - for i, item in enumerate(items[:max_results]): - if isinstance(item, dict): - results.append({ - "title": item.get("title", f"Result {i+1}"), - "url": item.get("url", ""), - "href": item.get("url", ""), # 兼容性 - "content": item.get("content", item.get("description", "")), - "body": item.get("content", item.get("description", "")), # 兼容性 - }) - elif isinstance(item, str): - results.append({ - "title": f"Result {i+1}", - "url": "", - "content": item - }) - - logger.info(f"✅ Jina Search returned {len(results)} results") - return results - - except requests.exceptions.Timeout: - logger.error("Jina Search timeout") - return [] - except requests.exceptions.RequestException as e: - logger.error(f"Jina Search request error: {e}") - return [] - except Exception as e: - logger.error(f"Jina Search unexpected error: {e}") - return [] - -class SearchTools: - """扩展性搜索工具库 - 支持多引擎聚合与内容缓存""" - - def __init__(self, db: DatabaseManager): - self.db = db - - # 检查 Jina API Key 是否配置 - jina_api_key = os.getenv("JINA_API_KEY", "").strip() - self._jina_enabled = bool(jina_api_key) - - self._engines = { - "ddg": DuckDuckGoTools(), - "baidu": BaiduSearchTools(), - "local": LocalNewsSearch(db) - } - - # 如果配置了 Jina API Key,添加 Jina 引擎 - if self._jina_enabled: - self._engines["jina"] = JinaSearchEngine() - logger.info("🚀 Jina Search engine enabled (JINA_API_KEY configured)") - - # 确定默认搜索引擎 - self._default_engine = "jina" if self._jina_enabled else "ddg" - - def _generate_hash(self, query: str, engine: str, max_results: int) -> str: - return hashlib.md5(f"{engine}:{query}:{max_results}".encode()).hexdigest() - - def search(self, query: str, engine: str = None, max_results: int = 5, ttl: Optional[int] = None) -> str: - """ - 使用指定搜索引擎执行网络搜索,结果会被缓存以提高效率。 - - Args: - query: 搜索关键词,如 "英伟达财报" 或 "光伏行业政策"。 - engine: 搜索引擎选择。可选值: - "jina" (Jina Search,需配置 JINA_API_KEY,LLM友好输出), - "ddg" (DuckDuckGo,推荐英文/国际搜索), - "baidu" (百度,推荐中文/国内搜索), - "local" (本地历史新闻搜索,基于向量+BM25)。 - 默认: 若配置了 JINA_API_KEY 则使用 "jina",否则 "ddg"。 - max_results: 期望返回的结果数量,默认 5 条。 - ttl: 缓存有效期(秒)。如果缓存超过此时间会重新搜索。 - 默认使用环境变量 SEARCH_CACHE_TTL 或 3600 秒。 - 设为 0 可强制刷新。 - - Returns: - 搜索结果的文本描述,包含标题、摘要和链接。 - """ - # 使用默认引擎(如果配置了 Jina 则优先使用 Jina) - if engine is None: - engine = self._default_engine - - if engine not in self._engines: - return f"Error: Unsupported engine '{engine}'. Available: {list(self._engines.keys())}" - - query_hash = self._generate_hash(query, engine, max_results) - effective_ttl = ttl if ttl is not None else DEFAULT_SEARCH_TTL - - # 1. 尝试从缓存读取 (local 引擎不缓存,因为它本身就是查库) - if engine != "local": - cache = self.db.get_search_cache(query_hash, ttl_seconds=effective_ttl if effective_ttl > 0 else None) - if cache and effective_ttl != 0: - logger.info(f"ℹ️ Found search results in cache for: {query} ({engine})") - return cache['results'] - - # 2. 
执行真实搜索 - logger.info(f"📡 Searching {engine} for: {query}") - try: - tool = self._engines[engine] - if engine == "jina": - # Jina Search 返回 List[Dict] - jina_results = tool.search(query, max_results=max_results) - results = [] - for r in jina_results: - results.append({ - "title": r.get("title", ""), - "href": r.get("url", ""), - "body": r.get("content", "") - }) - elif engine == "ddg": - results = tool.duckduckgo_search(query, max_results=max_results) - elif engine == "baidu": - results = tool.baidu_search(query, max_results=max_results) - elif engine == "local": - # LocalNewsSearch 返回的是 List[Dict] - local_results = tool.search(query, top_n=max_results) - results = [] - for r in local_results: - results.append({ - "title": r.get("title"), - "href": r.get("url", "local"), - "body": r.get("content", "") - }) - else: - results = "Search not implemented for this engine." - - results_str = str(results) - if engine != "local": - self.db.save_search_cache(query_hash, query, engine, results_str) - return results_str - - except Exception as e: - # 搜索失败时的降级策略 - if engine == "jina": - logger.warning(f"⚠️ Jina search failed, falling back to ddg: {query} ({e})") - try: - return self.search(query, engine="ddg", max_results=max_results, ttl=ttl) - except Exception as e2: - logger.error(f"❌ DDG fallback also failed for {query}: {e2}") - elif engine == "ddg": - logger.warning(f"⚠️ DDG search failed, falling back to baidu: {query} ({e})") - try: - return self.search(query, engine="baidu", max_results=max_results, ttl=ttl) - except Exception as e2: - logger.error(f"❌ Baidu fallback also failed for {query}: {e2}") - - logger.error(f"❌ Search failed for {query}: {e}") - return f"Error occurred during search: {str(e)}" - - def search_list(self, query: str, engine: str = None, max_results: int = 5, ttl: Optional[int] = None, enrich: bool = True) -> List[Dict]: - """ - 执行搜索并返回结构化列表 (List[Dict])。 - Dict 包含: title, href (or url), body (or snippet) - - Args: - engine: 搜索引擎,默认使用配置的默认引擎(Jina 优先) - enrich: 是否抓取正文内容 (默认 True) - """ - # 使用默认引擎 - if engine is None: - engine = self._default_engine - - if engine not in self._engines: - logger.error(f"Unsupported engine {engine}") - return [] - - # 不同的 hash 以区分是否 enrichment - enrich_suffix = ":enriched" if enrich else "" - query_hash = self._generate_hash(query, engine + enrich_suffix, max_results) - effective_ttl = ttl if ttl is not None else DEFAULT_SEARCH_TTL - - # 1. 尝试从缓存读取 - cache = self.db.get_search_cache(query_hash, ttl_seconds=effective_ttl if effective_ttl > 0 else None) - if cache and effective_ttl != 0: - try: - cached_data = json.loads(cache['results']) - if isinstance(cached_data, list): - logger.info(f"ℹ️ Found structured search cache for: {query}") - return cached_data - except: - pass - - # 1.5 Smart Cache (Delegated to Agent) - # The Agent should call list_similar_searches and judge relevance using PROMPTS.md - - - # 2. 
执行搜索 - logger.info(f"📡 Searching {engine} (structured) for: {query}") - try: - tool = self._engines[engine] - results = [] - if engine == "jina": - # Jina Search 直接返回结构化数据 - jina_results = tool.search(query, max_results=max_results) - for r in jina_results: - results.append({ - "title": r.get("title", ""), - "url": r.get("url", ""), - "href": r.get("url", ""), - "body": r.get("content", ""), - "content": r.get("content", ""), - "source": "Jina Search" - }) - elif engine == "ddg": - results = tool.duckduckgo_search(query, max_results=max_results) - elif engine == "baidu": - results = tool.baidu_search(query, max_results=max_results) - elif engine == "local": - # LocalNewsSearch 返回的是 List[Dict] - local_results = tool.search(query, top_n=max_results) - results = [] - for r in local_results: - results.append({ - "title": r.get("title"), - "url": r.get("url", "local"), - "body": r.get("content", "")[:500], - "source": f"Local ({r.get('source', 'db')})", - "publish_time": r.get("publish_time") - }) - - # 处理字符串类型的 JSON 返回 (Baidu 常返 JSON 字符串) - if isinstance(results, str) and engine not in ["local", "jina"]: - try: - results = json.loads(results) - except: - pass - - # 转为统一格式 - normalized_results = [] - if isinstance(results, list): - - for i, r in enumerate(results, 1): - title = r.get('title', '') - url = r.get('href') or r.get('url') or r.get('link', '') - content = r.get('body') or r.get('snippet') or r.get('abstract', '') - - if title and url: - normalized_results.append({ - "id": self._generate_hash(url + query, "search_item", i), - "rank": i, - "title": title, - "url": url, - "content": content, - "original_snippet": content, # 保留摘要 - "source": f"Search ({engine})", - "publish_time": datetime.now().isoformat(), # 暂用当前时间 - "crawl_time": datetime.now().isoformat(), - "meta_data": {"query": query, "engine": engine} - }) - - # Fallback if still string and failed to parse - elif isinstance(results, str) and results: - normalized_results.append({"title": query, "url": "", "content": results, "source": engine}) - - # 3. 抓取正文 & 计算情绪 (Enrichment) - # 注意:如果使用 Jina Search,内容已经是 LLM 友好格式,可选择跳过 enrichment - skip_content_enrichment = (engine == "jina") - - if enrich and normalized_results: - logger.info(f"🕸️ Enriching {len(normalized_results)} search results with Jina & Sentiment...") - extractor = ContentExtractor() - - # Lazy load sentiment tool - if not hasattr(self, 'sentiment_tool') or self.sentiment_tool is None: - from .sentiment_tools import SentimentTools - self.sentiment_tool = SentimentTools(self.db) - - for item in normalized_results: - if item.get("url"): - try: - # 如果是 Jina Search,内容已经足够好,跳过额外抓取 - if skip_content_enrichment and item.get("content") and len(item.get("content", "")) > 100: - full_content = item["content"] - else: - # Use Jina Reader to get full content - full_content = extractor.extract_with_jina(item["url"], timeout=60) - - if full_content and len(full_content) > 100: - item["content"] = full_content - - # Calculate sentiment - # Use title + snippet of content for efficiency - text_to_analyze = f"{item['title']} {full_content[:500]}" - sent_result = self.sentiment_tool.analyze_sentiment(text_to_analyze) # Using self.sentiment_tool - score = sent_result.get('score', 0.0) - item["sentiment_score"] = float(score) - - logger.info(f" ✅ Enriched: {item['title'][:20]}... 
(Sentiment: {score:.2f})") - else: - # Fallback: Use snippet for sentiment - logger.info(f" ⚠️ Content short/failed for {item['url']}, using snippet for sentiment.") - text_to_analyze = f"{item['title']} {item['content']}" # content is snippet here - sent_result = self.sentiment_tool.analyze_sentiment(text_to_analyze) - score = sent_result.get('score', 0.0) - item["sentiment_score"] = float(score) - - except Exception as e: - # Fallback: Use snippet for sentiment on error - logger.warning(f"Failed to enrich {item['url']}: {e}. Using snippet.") - text_to_analyze = f"{item['title']} {item['content']}" - sent_result = self.sentiment_tool.analyze_sentiment(text_to_analyze) - score = sent_result.get('score', 0.0) - item["sentiment_score"] = float(score) - - # 缓存结果 list - if normalized_results: - # Pass list directly, DB manager will handle JSON dump for main cache and populate search_details - # Only cache if NOT from local news reuse (though this logic path is for fresh search) - self.db.save_search_cache(query_hash, query, engine, normalized_results) - - return normalized_results - - except Exception as e: - # 搜索失败时的降级策略 - if engine == "jina": - logger.warning(f"⚠️ Jina search_list failed, falling back to ddg: {query} ({e})") - try: - return self.search_list(query, engine="ddg", max_results=max_results, ttl=ttl, enrich=enrich) - except Exception as e2: - logger.error(f"❌ DDG fallback (search_list) also failed for {query}: {e2}") - elif engine == "ddg": - logger.warning(f"⚠️ DDG search_list failed, falling back to baidu: {query} ({e})") - try: - return self.search_list(query, engine="baidu", max_results=max_results, ttl=ttl, enrich=enrich) - except Exception as e2: - logger.error(f"❌ Baidu fallback (search_list) also failed for {query}: {e2}") - - logger.error(f"❌ Structured search failed for {query}: {e}") - return [] - - def list_similar_queries(self, query: str, limit: int = 5) -> List[Dict]: - """ - 查找与当前查询类似的已缓存查询。 - Agent 可用此方法获取候选缓存,并使用 PROMPTS.md 进行评估以决定是否重用。 - """ - return self.db.find_similar_queries(query, limit=limit) - - - def aggregate_search(self, query: str, engines: Optional[List[str]] = None, max_results: int = 5) -> str: - """ - 使用多个搜索引擎同时搜索并聚合结果,获得更全面的信息覆盖。 - - Args: - query: 搜索关键词。 - engines: 要使用的搜索引擎列表。可选值: ["ddg", "baidu"]。 - 默认同时使用 ddg 和 baidu。 - max_results: 每个引擎期望返回的结果数量。 - - Returns: - 聚合后的搜索结果,按引擎分组显示。 - """ - engines = engines or ["ddg", "baidu"] - aggregated_results = [] - for engine in engines: - res = self.search(query, engine=engine, max_results=max_results) - aggregated_results.append(f"--- Results from {engine.upper()} ---\n{res}") - - return "\n\n".join(aggregated_results) diff --git a/skills/alphaear-search/scripts/sentiment_tools.py b/skills/alphaear-search/scripts/sentiment_tools.py deleted file mode 100644 index f4278b5..0000000 --- a/skills/alphaear-search/scripts/sentiment_tools.py +++ /dev/null @@ -1,287 +0,0 @@ -import os -from typing import Dict, List, Union, Optional -import json -from loguru import logger -from agno.agent import Agent -from .llm.factory import get_model -from .database_manager import DatabaseManager - -# 从环境变量读取默认情绪分析模式 -DEFAULT_SENTIMENT_MODE = os.getenv("SENTIMENT_MODE", "auto") # auto, bert, llm - - -class SentimentTools: - """ - 情绪分析工具 - 支持 LLM 和 BERT 两种模式 - - 模式说明: - - "auto": 自动选择,优先使用 BERT(速度快),不可用时回退到 LLM - - "bert": 强制使用 BERT 模型(需要 transformers 库) - - "llm": 强制使用 LLM(更准确但较慢) - - 可通过环境变量 SENTIMENT_MODE 设置默认模式。 - """ - - def __init__( - self, - db: DatabaseManager, - mode: Optional[str] = None, - model_provider: str = 
"openai", - model_id: str = "gpt-4o", - ): - """ - 初始化情绪分析工具。 - - Args: - db: 数据库管理器实例 - mode: 分析模式,可选 "auto", "bert", "llm"。None 则使用环境变量默认值。 - model_provider: LLM 提供商,如 "openai", "ust", "deepseek" - model_id: 模型标识符 - """ - self.db = db - self.mode = mode or DEFAULT_SENTIMENT_MODE - self.llm_model = None - self.bert_pipeline = None - - # Initialize LLM - try: - provider = "minimax" if os.getenv("MINIMAX_API_KEY") else model_provider - m_id = ( - os.getenv("LLM_MODEL", "MiniMax-Text-01") - if provider == "minimax" - else model_id - ) - self.llm_model = get_model(provider, m_id) - except Exception as e: - logger.warning(f"LLM initialization skipped: {e}") - - # Initialize BERT if needed - if self.mode in ["bert", "auto"]: - try: - from transformers import ( - pipeline, - AutoTokenizer, - AutoModelForSequenceClassification, - ) - from transformers.utils import logging as transformers_logging - - transformers_logging.set_verbosity_error() # 减少冗余日志 - - bert_model = os.getenv( - "BERT_SENTIMENT_MODEL", - "uer/roberta-base-finetuned-chinanews-chinese", - ) - - # 优先使用本地缓存 - try: - tokenizer = AutoTokenizer.from_pretrained( - bert_model, local_files_only=True - ) - model = AutoModelForSequenceClassification.from_pretrained( - bert_model, local_files_only=True - ) - - self.bert_pipeline = pipeline( - "sentiment-analysis", - model=model, - tokenizer=tokenizer, - device=-1, - ) - logger.info( - f"✅ BERT pipeline loaded from local cache: {bert_model}" - ) - except (OSError, ValueError, ImportError): - # 本地没有,则从网络下载 - logger.info(f"📡 Downloading BERT model: {bert_model}...") - tokenizer = AutoTokenizer.from_pretrained(bert_model) - model = AutoModelForSequenceClassification.from_pretrained( - bert_model - ) - - self.bert_pipeline = pipeline( - "sentiment-analysis", - model=model, - tokenizer=tokenizer, - device=-1, - ) - logger.info( - f"✅ BERT Sentiment pipeline ({bert_model}) initialized." - ) - except ImportError: - logger.warning( - "Transformers library not installed. BERT sentiment analysis disabled." - ) - except Exception as e: - if self.mode == "bert": - logger.error(f"BERT mode requested but failed: {e}") - else: - logger.warning(f"BERT unavailable, using LLM only. 
Error: {e}") - self.bert_pipeline = None - - def analyze_sentiment(self, text: str) -> Dict[str, Union[float, str]]: - """ - 分析文本的情绪极性。根据初始化时的 mode 自动选择分析方法。 - - Args: - text: 需要分析的文本内容,如新闻标题或摘要。 - - Returns: - 包含以下字段的字典: - - score: 情绪分值,范围 -1.0(极度负面)到 1.0(极度正面),0.0 为中性 - - label: 情绪标签,"positive"/"negative"/"neutral" - - reason: 分析理由(仅 LLM 模式提供详细理由) - """ - if self.mode == "bert" and self.bert_pipeline: - results = self.analyze_sentiment_bert([text]) - return results[0] if results else {"score": 0.0, "label": "error"} - elif self.mode == "llm" or (self.mode == "auto" and not self.bert_pipeline): - return self.analyze_sentiment_llm(text) - else: - # auto mode with BERT available - results = self.analyze_sentiment_bert([text]) - return results[0] if results else {"score": 0.0, "label": "error"} - - def analyze_sentiment_llm(self, text: str) -> Dict[str, Union[float, str]]: - """ - 使用 LLM 进行深度情绪分析,可获得详细的分析理由。 - - Args: - text: 需要分析的文本,最多处理前 1000 字符。 - - Returns: - 包含 score, label, reason 的字典。 - """ - if not self.llm_model: - return {"score": 0.0, "label": "neutral", "error": "LLM not initialized"} - - analyzer = Agent(model=self.llm_model, markdown=True) - prompt = f"""请分析以下金融/新闻文本的情绪极性。 - 返回严格的 JSON 格式: - {{"score": , "label": "", "reason": "<简短理由>"}} - - 文本: {text[:1000]}""" - - try: - response = analyzer.run(prompt) - content = response.content - if "```json" in content: - content = content.split("```json")[1].split("```")[0].strip() - elif "```" in content: - content = content.split("```")[1].split("```")[0].strip() - return json.loads(content) - except Exception as e: - logger.error(f"LLM sentiment failed: {e}") - return {"score": 0.0, "label": "error", "reason": str(e)} - - def analyze_sentiment_bert(self, texts: List[str]) -> List[Dict]: - """ - 使用 BERT 进行批量高速情绪分析。 - - Args: - texts: 需要分析的文本列表。 - - Returns: - 与输入列表等长的分析结果列表。 - """ - if not self.bert_pipeline: - return [ - {"score": 0.0, "label": "error", "reason": "BERT not available"} - ] * len(texts) - - try: - results = self.bert_pipeline(texts, truncation=True, max_length=512) - processed = [] - for r in results: - label = r["label"].lower() - score = r["score"] - - # 标准化不同模型的标签格式 - if "negative" in label or "neg" in label: - score = -score - elif "neutral" in label or "neu" in label: - score = 0.0 - - processed.append( - { - "score": float(round(score, 3)), - "label": "positive" - if score > 0.1 - else ("negative" if score < -0.1 else "neutral"), - "reason": "BERT automated analysis", - } - ) - return processed - except Exception as e: - logger.error(f"BERT analysis failed: {e}") - return [{"score": 0.0, "label": "error", "reason": str(e)}] * len(texts) - - def batch_update_news_sentiment( - self, - source: Optional[str] = None, - limit: int = 50, - use_bert: Optional[bool] = None, - ): - """ - 批量更新数据库中新闻的情绪分数。 - - Args: - source: 筛选特定新闻源,如 "wallstreetcn"。None 则处理所有来源。 - limit: 最多处理的新闻数量。 - use_bert: 是否使用 BERT。None 则根据初始化模式自动决定。 - - Returns: - 成功更新的新闻数量。 - """ - news_items = self.db.get_daily_news(source=source, limit=limit) - to_analyze = [item for item in news_items if not item.get("sentiment_score")] - - if not to_analyze: - return 0 - - # 决定使用哪种方法 - should_use_bert = ( - use_bert - if use_bert is not None - else (self.bert_pipeline is not None and self.mode != "llm") - ) - - updated_count = 0 - cursor = self.db.conn.cursor() - - if should_use_bert and self.bert_pipeline: - logger.info( - f"🚀 Using BERT for batch analysis of {len(to_analyze)} items..." 
- ) - titles = [item["title"] for item in to_analyze] - results = self.analyze_sentiment_bert(titles) - - for item, analysis in zip(to_analyze, results): - cursor.execute( - """ - UPDATE daily_news - SET sentiment_score = ?, meta_data = json_set(COALESCE(meta_data, '{}'), '$.sentiment_reason', ?) - WHERE id = ? - """, - (analysis["score"], analysis["reason"], item["id"]), - ) - updated_count += 1 - else: - logger.info(f"🚶 Using LLM for analysis of {len(to_analyze)} items...") - for item in to_analyze: - analysis = self.analyze_sentiment_llm(item["title"]) - cursor.execute( - """ - UPDATE daily_news - SET sentiment_score = ?, meta_data = json_set(COALESCE(meta_data, '{}'), '$.sentiment_reason', ?) - WHERE id = ? - """, - ( - analysis.get("score", 0.0), - analysis.get("reason", ""), - item["id"], - ), - ) - updated_count += 1 - - self.db.conn.commit() - return updated_count diff --git a/skills/alphaear-search/tests/test_search.py b/skills/alphaear-search/tests/test_search.py deleted file mode 100644 index 14838b3..0000000 --- a/skills/alphaear-search/tests/test_search.py +++ /dev/null @@ -1,31 +0,0 @@ -import sys -import os -import unittest - -# Add skill root to path -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) - -try: - from scripts.search_tools import SearchTools - from scripts.database_manager import DatabaseManager - from scripts.hybrid_search import InMemoryRAG -except ImportError as e: - print(f"Import Error: {e}") - sys.exit(1) - -class TestSearch(unittest.TestCase): - def test_init(self): - print("Testing SearchTools Initialization...") - db = DatabaseManager(":memory:") - tools = SearchTools(db) - self.assertIsNotNone(tools) - print("SearchTools Initialized.") - - def test_rag(self): - print("Testing InMemoryRAG...") - rag = InMemoryRAG([]) - self.assertIsNotNone(rag) - print("InMemoryRAG Initialized.") - -if __name__ == '__main__': - unittest.main() diff --git a/skills/alphaear-sentiment/SKILL.md b/skills/alphaear-sentiment/SKILL.md deleted file mode 100644 index 2d5fc7f..0000000 --- a/skills/alphaear-sentiment/SKILL.md +++ /dev/null @@ -1,57 +0,0 @@ ---- -name: alphaear-sentiment -description: Analyze financial text sentiment using FinBERT or LLM. Use when the user needs to determine the sentiment (positive/negative/neutral) and score of financial market texts. ---- - -# AlphaEar Sentiment Skill - -## Overview - -This skill provides sentiment analysis capabilities tailored for financial texts, supporting both FinBERT (local model) and LLM-based analysis modes. - -## Capabilities - -### 1. Analyze Sentiment (FinBERT / Local) - -Use `scripts/sentiment_tools.py` for high-speed, local sentiment analysis using FinBERT. - -**Key Methods:** - -- `analyze_sentiment(text)`: Get sentiment score and label using localized FinBERT model. - - **Returns**: `{'score': float, 'label': str, 'reason': str}`. - - **Score Range**: -1.0 (Negative) to 1.0 (Positive). -- `batch_update_news_sentiment(source, limit)`: Batch process unanalyzed news in the database (FinBERT only). - -### 2. Analyze Sentiment (LLM / Agentic) - -For higher accuracy or reasoning capabilities, **YOU (the Agent)** should perform the analysis using the Prompt below, calling the LLM directly, and then update the database if necessary. - -#### Sentiment Analysis Prompt - -Use this prompt to analyze financial texts if the local tool is insufficient or if reasoning is required.
- -```markdown -请分析以下金融/新闻文本的情绪极性。 -返回严格的 JSON 格式: -{"score": , "label": "", "reason": "<简短理由>"} - -文本: {text} -``` - -**Scoring Guide:** -- **Positive (0.1 to 1.0)**: Optimistic news, profit growth, policy support, etc. -- **Negative (-1.0 to -0.1)**: Losses, sanctions, price drops, pessimism. -- **Neutral (-0.1 to 0.1)**: Factual reporting, sideways movement, ambiguous impact. - -#### Helper Methods -- `update_single_news_sentiment(id, score, reason)`: Use this to save your manual analysis to the database. - -## Dependencies - -- `torch` (for FinBERT) -- `transformers` (for FinBERT) -- `sqlite3` (built-in) - -Ensure `DatabaseManager` is initialized correctly. diff --git a/skills/alphaear-sentiment/scripts/__init__.py b/skills/alphaear-sentiment/scripts/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/alphaear-sentiment/scripts/database_manager.py b/skills/alphaear-sentiment/scripts/database_manager.py deleted file mode 100644 index cfc362b..0000000 --- a/skills/alphaear-sentiment/scripts/database_manager.py +++ /dev/null @@ -1,581 +0,0 @@ -import sqlite3 -import json -from datetime import datetime, date -from pathlib import Path -from typing import List, Dict, Optional, Any, Union -import pandas as pd -from loguru import logger - -class DatabaseManager: - """ - AlphaEar 数据库管理器 - 负责存储热点数据、搜索缓存和股价数据 - 使用 SQLite 进行持久化存储 - """ - - def __init__(self, db_path: str = "data/signal_flux.db"): - self.db_path = Path(db_path) - self.db_path.parent.mkdir(parents=True, exist_ok=True) - self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False) - self.conn.row_factory = sqlite3.Row - self._init_db() - logger.info(f"💾 Database initialized at {self.db_path}") - - def _init_db(self): - """初始化表结构""" - cursor = self.conn.cursor() - - # 1. 每日热点新闻表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS daily_news ( - id TEXT PRIMARY KEY, - source TEXT, - rank INTEGER, - title TEXT, - url TEXT, - content TEXT, - publish_time TEXT, - crawl_time TEXT, - sentiment_score REAL, - analysis TEXT, - meta_data TEXT - ) - """) - - # 尝试添加 analysis 列(如果表已存在但没有该列) - try: - cursor.execute("ALTER TABLE daily_news ADD COLUMN analysis TEXT") - except: - pass # 列已存在 - - - # 2. 搜索缓存表 (原有 JSON 缓存) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS search_cache ( - query_hash TEXT PRIMARY KEY, - query TEXT, - engine TEXT, - results TEXT, - timestamp TEXT - ) - """) - - # 2.5 搜索详情表 (展开的搜索结果) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS search_detail ( - id TEXT, - query_hash TEXT, - rank INTEGER, - title TEXT, - url TEXT, - content TEXT, - publish_time TEXT, - crawl_time TEXT, - sentiment_score REAL, - source TEXT, - meta_data TEXT, - PRIMARY KEY (query_hash, id) - ) - """) - - # 3. 股价数据表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS stock_prices ( - ticker TEXT, - date TEXT, - open REAL, - close REAL, - high REAL, - low REAL, - volume REAL, - change_pct REAL, - PRIMARY KEY (ticker, date) - ) - """) - - # 4. 股票列表表 (用于检索) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS stock_list ( - code TEXT PRIMARY KEY, - name TEXT - ) - """) - - # 5. 投资信号表 (ISQ Framework) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS signals ( - signal_id TEXT PRIMARY KEY, - title TEXT, - summary TEXT, - transmission_chain TEXT, - sentiment_score REAL, - confidence REAL, - intensity INTEGER, - expected_horizon TEXT, - price_in_status TEXT, - impact_tickers TEXT, - industry_tags TEXT, - sources TEXT, - user_id TEXT, - created_at TEXT - ) - """) - - - - # 6. 
创建索引以优化查询性能 - cursor.execute("CREATE INDEX IF NOT EXISTS idx_news_crawl_time ON daily_news(crawl_time)") - cursor.execute("CREATE INDEX IF NOT EXISTS idx_news_source ON daily_news(source)") - cursor.execute("CREATE INDEX IF NOT EXISTS idx_search_cache_timestamp ON search_cache(timestamp)") - cursor.execute("CREATE INDEX IF NOT EXISTS idx_stock_prices_ticker_date ON stock_prices(ticker, date)") - # 尝试添加 user_id 列到 signals 表 - try: - cursor.execute("ALTER TABLE signals ADD COLUMN user_id TEXT") - except: - pass - - cursor.execute("CREATE INDEX IF NOT EXISTS idx_signals_user_id ON signals(user_id)") - - self.conn.commit() - - # - # self.conn.commit() - - - # --- 新闻数据操作 --- - - def save_daily_news(self, news_list: List[Dict]) -> int: - """保存热点新闻,包含发布时间与抓取时间""" - cursor = self.conn.cursor() - count = 0 - crawl_time = datetime.now().isoformat() - - for news in news_list: - try: - # 兼容不同来源的 ID 生成逻辑 - news_id = news.get('id') or f"{news.get('source')}_{news.get('rank')}_{crawl_time[:10]}" - cursor.execute(""" - INSERT OR REPLACE INTO daily_news - (id, source, rank, title, url, content, publish_time, crawl_time, sentiment_score, meta_data) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - news_id, - news.get('source'), - news.get('rank'), - news.get('title'), - news.get('url'), - news.get('content', ''), - news.get('publish_time'), # 新增支持发布时间 - crawl_time, - news.get('sentiment_score'), - json.dumps(news.get('meta_data', {})) - )) - count += 1 - except sqlite3.Error as e: - logger.error(f"Database error saving news item {news.get('title')}: {e}") - except Exception as e: - logger.error(f"Unexpected error saving news item {news.get('title')}: {e}") - - self.conn.commit() - return count - - def get_daily_news(self, source: Optional[str] = None, limit: int = 100, days: int = 1) -> List[Dict]: - """获取最近 N 天的热点新闻""" - cursor = self.conn.cursor() - # 使用 crawl_time 过滤,保证结果的新鲜度 - time_threshold = (datetime.now().timestamp() - days * 86400) - time_threshold_str = datetime.fromtimestamp(time_threshold).isoformat() - - query = "SELECT * FROM daily_news WHERE crawl_time >= ?" - params = [time_threshold_str] - - if source: - query += " AND source = ?" - params.append(source) - - query += " ORDER BY crawl_time DESC, rank LIMIT ?" - params.append(limit) - - cursor.execute(query, params) - return [dict(row) for row in cursor.fetchall()] - - def lookup_reference_by_url(self, url: str) -> Optional[Dict[str, Any]]: - """Best-effort lookup of a source item by URL. - - This is used to render a stable bibliography from DB-backed metadata. - It searches both `daily_news` and `search_detail`. - """ - url = (url or "").strip() - if not url: - return None - - cursor = self.conn.cursor() - - try: - cursor.execute( - """ - SELECT title, source, publish_time, crawl_time, url - FROM daily_news - WHERE url = ? - ORDER BY crawl_time DESC - LIMIT 1 - """, - (url,), - ) - row = cursor.fetchone() - if row: - return dict(row) - except Exception: - pass - - try: - cursor.execute( - """ - SELECT title, source, publish_time, crawl_time, url - FROM search_detail - WHERE url = ? 
- ORDER BY crawl_time DESC - LIMIT 1 - """, - (url,), - ) - row = cursor.fetchone() - if row: - return dict(row) - except Exception: - pass - - return None - - def delete_news(self, news_id: str) -> bool: - """删除特定新闻""" - cursor = self.conn.cursor() - cursor.execute("DELETE FROM daily_news WHERE id = ?", (news_id,)) - self.conn.commit() - return cursor.rowcount > 0 - - def update_news_content(self, news_id: str, content: str = None, analysis: str = None) -> bool: - """更新新闻的内容或分析结果""" - cursor = self.conn.cursor() - updates = [] - params = [] - - if content is not None: - updates.append("content = ?") - params.append(content) - if analysis is not None: - updates.append("analysis = ?") - params.append(analysis) - - if not updates: - return False - - params.append(news_id) - query = f"UPDATE daily_news SET {', '.join(updates)} WHERE id = ?" - cursor.execute(query, params) - self.conn.commit() - return cursor.rowcount > 0 - - # --- 搜索缓存辅助 --- - - def get_search_cache(self, query_hash: str, ttl_seconds: Optional[int] = None) -> Optional[Dict]: - """获取搜索缓存 (优先查 search_detail)""" - cursor = self.conn.cursor() - - # 1. 尝试从 search_detail 获取展开的结构化数据 - cursor.execute(""" - SELECT * FROM search_detail - WHERE query_hash = ? - ORDER BY rank - """, (query_hash,)) - details = [dict(row) for row in cursor.fetchall()] - - if details: - # 检查 TTL (取第一条的时间) - first_time = datetime.fromisoformat(details[0]['crawl_time']) - if ttl_seconds and (datetime.now() - first_time).total_seconds() > ttl_seconds: - logger.info(f"⌛ Detailed cache expired for hash {query_hash}") - pass # Expired, fall through or return None? If Detail expired, Cache likely expired too. - # But let's check basic cache just in case metadata differs? - # Actually if details exist, we prefer them. If expired, we return None. - return None - - logger.info(f"✅ Hit detailed search cache for {query_hash} ({len(details)} items)") - # Reconstruct the expected 'results' list format for SearchTools - # SearchTools expects a list of dicts. - # We return a dict wrapper to match get_search_cache signature returning Dict usually containing 'results' string. - # But SearchTools logic: - # cache = db.get_search_cache(...) - # cached_data = json.loads(cache['results']) - - # To minimize SearchTools changes, we can return a dict mimicking the old structure - # OR Change SearchTools to handle list return. - # Let's return a special dict that SearchTools can recognize or just format it as before. - return {"results": json.dumps(details), "timestamp": details[0]['crawl_time']} - - # 2. Fallback to old table - cursor.execute("SELECT * FROM search_cache WHERE query_hash = ?", (query_hash,)) - row = cursor.fetchone() - - if not row: - return None - - row_dict = dict(row) - if ttl_seconds: - cache_time = datetime.fromisoformat(row_dict['timestamp']) - if (datetime.now() - cache_time).total_seconds() > ttl_seconds: - logger.info(f"⌛ Cache expired for hash {query_hash}") - return None - - return row_dict - - def save_search_cache(self, query_hash: str, query: str, engine: str, results: Union[str, List[Dict]]): - """保存搜索结果 (同时保存到 search_cache 和 search_detail)""" - cursor = self.conn.cursor() - current_time = datetime.now().isoformat() - - results_str = results if isinstance(results, str) else json.dumps(results) - - # 1. Save summary to search_cache - cursor.execute(""" - INSERT OR REPLACE INTO search_cache (query_hash, query, engine, results, timestamp) - VALUES (?, ?, ?, ?, ?) - """, (query_hash, query, engine, results_str, current_time)) - - # 2. 
Save details to search_detail if results is a list - if isinstance(results, list): - for item in results: - try: - item_id = item.get('id') or f"{hash(item.get('url', ''))}" - cursor.execute(""" - INSERT OR REPLACE INTO search_detail - (id, query_hash, rank, title, url, content, publish_time, crawl_time, sentiment_score, source, meta_data) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - str(item_id), - query_hash, - item.get('rank', 0), - item.get('title'), - item.get('url'), - item.get('content', ''), - item.get('publish_time'), - item.get('crawl_time') or current_time, - item.get('sentiment_score'), - item.get('source'), - json.dumps(item.get('meta_data', {})) - )) - except sqlite3.Error as e: - logger.error(f"Database error saving search detail {item.get('title')}: {e}") - except Exception as e: - logger.error(f"Unexpected error saving search detail {item.get('title')}: {e}") - - self.conn.commit() - - def find_similar_queries(self, query: str, limit: int = 5) -> List[Dict]: - """模糊搜索相似的已缓存查询""" - cursor = self.conn.cursor() - - # Simple fuzzy match: query in cached OR cached in query - q_wild = f"%{query}%" - cursor.execute(""" - SELECT query, query_hash, timestamp, results - FROM search_cache - WHERE query LIKE ? OR ? LIKE ('%' || query || '%') - ORDER BY timestamp DESC - LIMIT ? - """, (q_wild, query, limit)) - - return [dict(row) for row in cursor.fetchall()] - - def search_local_news(self, query: str, limit: int = 5) -> List[Dict]: - """从本地 daily_news 搜索相关新闻""" - cursor = self.conn.cursor() - q_wild = f"%{query}%" - # Search title and content - cursor.execute(""" - SELECT * FROM daily_news - WHERE title LIKE ? OR content LIKE ? - ORDER BY crawl_time DESC - LIMIT ? - """, (q_wild, q_wild, limit)) - return [dict(row) for row in cursor.fetchall()] - - # --- 股票数据操作 --- - - def save_stock_list(self, df: pd.DataFrame): - """保存股票列表到 stock_list 表""" - cursor = self.conn.cursor() - try: - # 清空旧表 - cursor.execute("DELETE FROM stock_list") - - # 批量插入 - data = df[['code', 'name']].to_dict('records') - cursor.executemany( - "INSERT INTO stock_list (code, name) VALUES (:code, :name)", - data - ) - self.conn.commit() - except sqlite3.Error as e: - logger.error(f"Database error saving stock list: {e}") - except Exception as e: - logger.error(f"Unexpected error saving stock list: {e}") - - def search_stock(self, query: str, limit: int = 5) -> List[Dict]: - """模糊搜索股票代码或名称""" - cursor = self.conn.cursor() - wild = f"%{query}%" - cursor.execute(""" - SELECT code, name FROM stock_list - WHERE code LIKE ? OR name LIKE ? - LIMIT ? - """, (wild, wild, limit)) - return [dict(row) for row in cursor.fetchall()] - - def get_stock_by_code(self, code: str) -> Optional[Dict[str, str]]: - """精确按代码获取股票信息。 - - Args: - code: 股票代码(A股6位 / 港股5位),必须为纯数字字符串。 - - Returns: - dict: {"code": str, "name": str} 或 None。 - """ - if not code: - return None - clean = "".join([c for c in str(code).strip() if c.isdigit()]) - if not clean: - return None - - cursor = self.conn.cursor() - cursor.execute("SELECT code, name FROM stock_list WHERE code = ? 
LIMIT 1", (clean,)) - row = cursor.fetchone() - return dict(row) if row else None - - def save_stock_prices(self, ticker: str, df: pd.DataFrame): - """保存股价历史数据""" - if df.empty: - return - - cursor = self.conn.cursor() - - # 确保 DataFrame 有必要的列 - required_cols = ['date', 'open', 'close', 'high', 'low', 'volume', 'change_pct'] - for col in required_cols: - if col not in df.columns: - logger.warning(f"Missing column {col} in stock data for {ticker}") - return - - try: - for _, row in df.iterrows(): - cursor.execute(""" - INSERT OR REPLACE INTO stock_prices - (ticker, date, open, close, high, low, volume, change_pct) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, ( - ticker, - row['date'], - row['open'], - row['close'], - row['high'], - row['low'], - row['volume'], - row['change_pct'] - )) - self.conn.commit() - except sqlite3.Error as e: - logger.error(f"Database error saving stock prices for {ticker}: {e}") - except Exception as e: - logger.error(f"Unexpected error saving stock prices for {ticker}: {e}") - - def get_stock_prices(self, ticker: str, start_date: str, end_date: str) -> pd.DataFrame: - """获取指定日期范围的股价数据""" - cursor = self.conn.cursor() - - cursor.execute(""" - SELECT * FROM stock_prices - WHERE ticker = ? AND date >= ? AND date <= ? - ORDER BY date - """, (ticker, start_date, end_date)) - - rows = cursor.fetchall() - if not rows: - return pd.DataFrame() - - columns = ['ticker', 'date', 'open', 'close', 'high', 'low', 'volume', 'change_pct'] - return pd.DataFrame([dict(row) for row in rows], columns=columns) - - def execute_query(self, query: str, params: tuple = ()) -> List[Any]: - """执行自定义 SQL 查询""" - try: - cursor = self.conn.cursor() - cursor.execute(query, params) - if query.strip().upper().startswith("SELECT"): - return cursor.fetchall() - else: - self.conn.commit() - return [] - except sqlite3.Error as e: - logger.error(f"SQL execution failed (Database error): {e}") - return [] - except Exception as e: - logger.error(f"SQL execution failed (Unexpected error): {e}") - return [] - - # --- 投资信号操作 (ISQ Framework) --- - - def save_signal(self, signal: Dict[str, Any]): - """保存投资信号""" - cursor = self.conn.cursor() - created_at = datetime.now().isoformat() - - cursor.execute(""" - INSERT OR REPLACE INTO signals - (signal_id, title, summary, transmission_chain, sentiment_score, - confidence, intensity, expected_horizon, price_in_status, - impact_tickers, industry_tags, sources, user_id, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - signal.get('signal_id'), - signal.get('title'), - signal.get('summary'), - json.dumps(signal.get('transmission_chain', [])), - signal.get('sentiment_score', 0.0), - signal.get('confidence', 0.0), - signal.get('intensity', 1), - signal.get('expected_horizon', 'T+0'), - signal.get('price_in_status', '未知'), - json.dumps(signal.get('impact_tickers', [])), - json.dumps(signal.get('industry_tags', [])), - json.dumps(signal.get('sources', [])), - signal.get('user_id'), - created_at - )) - self.conn.commit() - - def get_recent_signals(self, limit: int = 20, user_id: Optional[str] = None) -> List[Dict]: - """获取最近的投资信号""" - cursor = self.conn.cursor() - if user_id: - cursor.execute("SELECT * FROM signals WHERE user_id = ? 
ORDER BY created_at DESC LIMIT ?", (user_id, limit)) - else: - cursor.execute("SELECT * FROM signals ORDER BY created_at DESC LIMIT ?", (limit,)) - rows = cursor.fetchall() - - signals = [] - for row in rows: - d = dict(row) - # 解析 JSON 字段 - for field in ['transmission_chain', 'impact_tickers', 'industry_tags', 'sources']: - if d.get(field): - try: - d[field] = json.loads(d[field]) - except: - pass - signals.append(d) - return signals - - def close(self): - if self.conn: - self.conn.close() - logger.info("Database connection closed.") - diff --git a/skills/alphaear-sentiment/scripts/llm/capability.py b/skills/alphaear-sentiment/scripts/llm/capability.py deleted file mode 100644 index de9de32..0000000 --- a/skills/alphaear-sentiment/scripts/llm/capability.py +++ /dev/null @@ -1,85 +0,0 @@ -import os -from typing import Optional, List, Dict, Any -from agno.agent import Agent -from agno.models.base import Model -from loguru import logger -from .llm.factory import get_model - - -def test_tool_call_support(model: Model) -> bool: - """ - 测试模型是否支持原生的 Tool Call (Function Calling)。 - 通过尝试执行一个简单的加法工具来验证。 - """ - - def get_current_weather(location: str): - """获取指定地点的天气""" - return f"{location} 的天气是晴天,25度。" - - test_agent = Agent( - model=model, - tools=[get_current_weather], - instructions="请调用工具查询北京的天气,并直接返回工具的输出结果。", - ) - - try: - # 运行一个简单的任务,观察是否触发了 tool_call - response = test_agent.run("北京天气怎么样?") - - # 检查 response 中是否包含 tool_calls - # Agno 的 RunResponse 对象通常包含 messages,我们可以检查最后几条消息 - has_tool_call = False - for msg in response.messages: - if hasattr(msg, "tool_calls") and msg.tool_calls: - has_tool_call = True - break - - if has_tool_call: - logger.info(f"✅ Model {model.id} supports native tool calling.") - return True - else: - # 如果没有 tool_calls 但返回了正确答案,可能是模型通过纯文本模拟了工具调用(ReAct) - # 或者根本没用工具。对于原生支持的判断,我们坚持要求有 tool_calls 结构。 - logger.warning( - f"⚠️ Model {model.id} did NOT use native tool calling structure." - ) - return False - - except Exception as e: - logger.error(f"❌ Error testing tool call for {model.id}: {e}") - return False - - -class ModelCapabilityRegistry: - """ - 模型能力注册表,用于缓存和管理不同模型的能力测试结果。 - """ - - _cache = {} - - @classmethod - def get_capabilities( - cls, provider: str, model_id: str, **kwargs - ) -> Dict[str, bool]: - key = f"{provider}:{model_id}" - if key not in cls._cache: - logger.info(f"🔍 Testing capabilities for {key}...") - model = get_model(provider, model_id, **kwargs) - supports_tool_call = test_tool_call_support(model) - cls._cache[key] = {"supports_tool_call": supports_tool_call} - return cls._cache[key] - - -if __name__ == "__main__": - import os - from skills._env_loader import load_unified_env - - load_unified_env() - - # 测试当前配置的模型 - p = os.getenv("LLM_PROVIDER", "minimax") - m = os.getenv("LLM_MODEL", "Qwen") - - print(f"Testing {p}/{m}...") - res = ModelCapabilityRegistry.get_capabilities(p, m) - print(f"Result: {res}") diff --git a/skills/alphaear-sentiment/scripts/llm/factory.py b/skills/alphaear-sentiment/scripts/llm/factory.py deleted file mode 100644 index 09b6ea5..0000000 --- a/skills/alphaear-sentiment/scripts/llm/factory.py +++ /dev/null @@ -1,114 +0,0 @@ -import os -from agno.models.openai import OpenAIChat -from agno.models.ollama import Ollama -from agno.models.dashscope import DashScope -from agno.models.deepseek import DeepSeek -from agno.models.openrouter import OpenRouter - -def get_model(model_provider: str, model_id: str, **kwargs): - """ - Factory to get the appropriate LLM model. 
- - Args: - model_provider: "openai", "ollama", "deepseek" - model_id: The specific model ID (e.g., "gpt-4o", "llama3", "deepseek-chat") - **kwargs: Additional arguments for the model constructor - """ - if model_provider == "openai": - return OpenAIChat(id=model_id, **kwargs) - - elif model_provider == "ollama": - return Ollama(id=model_id, **kwargs) - - elif model_provider == "deepseek": - # DeepSeek is OpenAI compatible - api_key = os.getenv("DEEPSEEK_API_KEY") - if not api_key: - print("Warning: DEEPSEEK_API_KEY not set.") - - return DeepSeek( - id=model_id, - api_key=api_key, - **kwargs - ) - elif model_provider == "dashscope": - api_key = os.getenv("DASHSCOPE_API_KEY") - if not api_key: - print("Warning: DASHSCOPE_API_KEY not set.") - - return DashScope( - id=model_id, - base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", - api_key=api_key, - **kwargs - ) - elif model_provider == 'openrouter': - api_key = os.getenv("OPENROUTER_API_KEY") - if not api_key: - print('Warning: OPENROUTER_API_KEY not set.') - - return OpenRouter( - id=model_id, - api_key=api_key, - **kwargs - ) - - elif model_provider == 'zai': - api_key = os.getenv("ZAI_KEY_API") - if not api_key: - print('Warning: ZAI_KEY_API not set.') - - # role_map to ensure compatibility. - default_role_map = { - "system": "system", - "user": "user", - "assistant": "assistant", - "tool": "tool", - "model": "assistant", - } - - # Allow callers to override role_map via kwargs, otherwise use default - role_map = kwargs.pop("role_map", default_role_map) - - return OpenAIChat( - id=model_id, - base_url="https://api.z.ai/api/paas/v4", - api_key=api_key, - timeout=60, - role_map=role_map, - extra_body={"enable_thinking": False}, # TODO: one more setting for thinking - **kwargs - ) - - elif model_provider == 'ust': - api_key = os.getenv("UST_KEY_API") - if not api_key: - print('Warning: UST_KEY_API not set.') - - # Some UST-compatible endpoints expect the standard OpenAI role names - # (e.g. "system", "user", "assistant") rather than Agno's default - # mapping which maps "system" -> "developer". Provide an explicit - # role_map to ensure compatibility. - default_role_map = { - "system": "system", - "user": "user", - "assistant": "assistant", - "tool": "tool", - "model": "assistant", - } - - # Allow callers to override role_map via kwargs, otherwise use default - role_map = kwargs.pop("role_map", default_role_map) - - return OpenAIChat( - id=model_id, - api_key=api_key, - base_url=os.getenv("UST_URL"), - role_map=role_map, - extra_body={"enable_thinking": False}, # TODO: one more setting for thinking - **kwargs - ) - - else: - raise ValueError(f"Unknown model provider: {model_provider}") - diff --git a/skills/alphaear-sentiment/scripts/llm/router.py b/skills/alphaear-sentiment/scripts/llm/router.py deleted file mode 100644 index 3a3cede..0000000 --- a/skills/alphaear-sentiment/scripts/llm/router.py +++ /dev/null @@ -1,80 +0,0 @@ -import os -from typing import Optional, List, Dict, Any, Union -from agno.models.base import Model -from loguru import logger -from .llm.factory import get_model -from utils.llm.capability import ModelCapabilityRegistry -from skills._env_loader import load_unified_env - -load_unified_env() - - -class ModelRouter: - """ - 模型路由管理器 - - 功能: - 1. 管理“推理/写作模型” (Reasoning Model) 和“工具调用模型” (Tool Model)。 - 2. 
根据任务需求自动选择合适的模型。 - """ - - def __init__(self): - # 默认从环境变量读取 - self.reasoning_provider = os.getenv( - "REASONING_MODEL_PROVIDER", os.getenv("LLM_PROVIDER", "openai") - ) - self.reasoning_id = os.getenv( - "REASONING_MODEL_ID", os.getenv("LLM_MODEL", "gpt-4o") - ) - self.reasoning_host = os.getenv("REASONING_MODEL_HOST", os.getenv("LLM_HOST")) - - self.tool_provider = os.getenv("TOOL_MODEL_PROVIDER", self.reasoning_provider) - self.tool_id = os.getenv("TOOL_MODEL_ID", self.reasoning_id) - self.tool_host = os.getenv("TOOL_MODEL_HOST", self.reasoning_host) - - self._reasoning_model = None - self._tool_model = None - - logger.info( - f"🤖 ModelRouter initialized: Reasoning={self.reasoning_id} ({self.reasoning_host or 'default'}), Tool={self.tool_id} ({self.tool_host or 'default'})" - ) - - def get_reasoning_model(self, **kwargs) -> Model: - if not self._reasoning_model: - # 优先使用路由配置的 host - if self.reasoning_host and "host" not in kwargs: - kwargs["host"] = self.reasoning_host - self._reasoning_model = get_model( - self.reasoning_provider, self.reasoning_id, **kwargs - ) - return self._reasoning_model - - def get_tool_model(self, **kwargs) -> Model: - if not self._tool_model: - # 优先使用路由配置的 host - if self.tool_host and "host" not in kwargs: - kwargs["host"] = self.tool_host - - # 检查 tool_model 是否真的支持 tool call - caps = ModelCapabilityRegistry.get_capabilities( - self.tool_provider, self.tool_id, **kwargs - ) - if not caps["supports_tool_call"]: - logger.warning( - f"⚠️ Configured tool model {self.tool_id} might not support native tool calls! Consider using ReAct mode or a different model." - ) - - self._tool_model = get_model(self.tool_provider, self.tool_id, **kwargs) - return self._tool_model - - def get_model_for_agent(self, has_tools: bool = False, **kwargs) -> Model: - """ - 根据 Agent 是否包含工具来返回合适的模型。 - """ - if has_tools: - return self.get_tool_model(**kwargs) - return self.get_reasoning_model(**kwargs) - - -# 全局单例 -router = ModelRouter() diff --git a/skills/alphaear-sentiment/scripts/sentiment_tools.py b/skills/alphaear-sentiment/scripts/sentiment_tools.py deleted file mode 100644 index 330a47e..0000000 --- a/skills/alphaear-sentiment/scripts/sentiment_tools.py +++ /dev/null @@ -1,205 +0,0 @@ -import os -from typing import Dict, List, Union, Optional -import json -from loguru import logger -# IMPORTS REMOVED: agno.agent, get_model -# Internal LLM logic has been removed to delegate analysis to the calling Agent. -from .database_manager import DatabaseManager - -# 从环境变量读取默认情绪分析模式 -DEFAULT_SENTIMENT_MODE = os.getenv("SENTIMENT_MODE", "auto") # auto, bert, llm - -class SentimentTools: - """ - 情绪分析工具 - 支持 LLM 和 BERT 两种模式 - - 模式说明: - - "auto": 自动选择,优先使用 BERT(速度快),不可用时回退到 LLM - - "bert": 强制使用 BERT 模型(需要 transformers 库) - - "llm": 强制使用 LLM(更准确但较慢) - - 可通过环境变量 SENTIMENT_MODE 设置默认模式。 - """ - - def __init__(self, db: DatabaseManager, mode: Optional[str] = None): - """ - 初始化情绪分析工具。 - - Args: - db: 数据库管理器实例 - mode: 分析模式,可选 "auto", "bert", "llm"。None 则使用环境变量默认值。 - model_provider: LLM 提供商,如 "openai", "ust", "deepseek" - model_id: 模型标识符 - """ - self.db = db - self.mode = mode or DEFAULT_SENTIMENT_MODE - self.bert_pipeline = None - - # LLM initialization removed. Agent should perform analysis if needed. 
- - # Initialize BERT if needed - if self.mode in ["bert", "auto"]: - try: - from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification - from transformers.utils import logging as transformers_logging - transformers_logging.set_verbosity_error() # 减少冗余日志 - - bert_model = os.getenv("BERT_SENTIMENT_MODEL", "uer/roberta-base-finetuned-chinanews-chinese") - - # 优先使用本地缓存 - try: - tokenizer = AutoTokenizer.from_pretrained(bert_model, local_files_only=True) - model = AutoModelForSequenceClassification.from_pretrained(bert_model, local_files_only=True) - - self.bert_pipeline = pipeline( - "sentiment-analysis", - model=model, - tokenizer=tokenizer, - device=-1 - ) - logger.info(f"✅ BERT pipeline loaded from local cache: {bert_model}") - except (OSError, ValueError, ImportError): - # 本地没有,则从网络下载 - logger.info(f"📡 Downloading BERT model: {bert_model}...") - tokenizer = AutoTokenizer.from_pretrained(bert_model) - model = AutoModelForSequenceClassification.from_pretrained(bert_model) - - self.bert_pipeline = pipeline( - "sentiment-analysis", - model=model, - tokenizer=tokenizer, - device=-1 - ) - logger.info(f"✅ BERT Sentiment pipeline ({bert_model}) initialized.") - except ImportError: - logger.warning("Transformers library not installed. BERT sentiment analysis disabled.") - except Exception as e: - if self.mode == "bert": - logger.error(f"BERT mode requested but failed: {e}") - else: - logger.warning(f"BERT unavailable, using LLM only. Error: {e}") - self.bert_pipeline = None - - - def analyze_sentiment(self, text: str) -> Dict[str, Union[float, str]]: - """ - 分析文本的情绪极性。仅支持 BERT 模式。 - 如需 LLM 分析,请 Agent 按照 SKILL.md 中的 Prompt 自行执行。 - - Args: - text: 需要分析的文本内容。 - - Returns: - BERT 分析结果,或错误信息。 - """ - if self.bert_pipeline: - results = self.analyze_sentiment_bert([text]) - return results[0] if results else {"score": 0.0, "label": "error"} - else: - return { - "score": 0.0, - "label": "error", - "reason": "BERT pipeline not initialized. For LLM analysis, please manually execute the prompt in SKILL.md." - } - - def update_single_news_sentiment(self, news_id: Union[str, int], score: float, reason: str = "") -> bool: - """ - 允许 Agent 将手动分析的结果保存到数据库。 - - Args: - news_id: 新闻 ID - score: -1.0 到 1.0 - reason: 分析理由 - - Returns: - Success bool - """ - try: - cursor = self.db.conn.cursor() - cursor.execute(""" - UPDATE daily_news - SET sentiment_score = ?, meta_data = json_set(COALESCE(meta_data, '{}'), '$.sentiment_reason', ?) - WHERE id = ? 
- """, (score, reason, news_id)) - self.db.conn.commit() - return True - except Exception as e: - logger.error(f"Failed to update sentiment for {news_id}: {e}") - return False - - def analyze_sentiment_bert(self, texts: List[str]) -> List[Dict]: - """ - 使用 BERT 进行批量高速情绪分析。 - - Args: - texts: 需要分析的文本列表。 - - Returns: - 与输入列表等长的分析结果列表。 - """ - if not self.bert_pipeline: - return [{"score": 0.0, "label": "error", "reason": "BERT not available"}] * len(texts) - - try: - results = self.bert_pipeline(texts, truncation=True, max_length=512) - processed = [] - for r in results: - label = r['label'].lower() - score = r['score'] - - # 标准化不同模型的标签格式 - if 'negative' in label or 'neg' in label: - score = -score - elif 'neutral' in label or 'neu' in label: - score = 0.0 - - processed.append({ - "score": float(round(score, 3)), - "label": "positive" if score > 0.1 else ("negative" if score < -0.1 else "neutral"), - "reason": "BERT automated analysis" - }) - return processed - except Exception as e: - logger.error(f"BERT analysis failed: {e}") - return [{"score": 0.0, "label": "error", "reason": str(e)}] * len(texts) - - def batch_update_news_sentiment(self, source: Optional[str] = None, limit: int = 50, use_bert: Optional[bool] = None): - """ - 批量更新数据库中新闻的情绪分数。 - - Args: - source: 筛选特定新闻源,如 "wallstreetcn"。None 则处理所有来源。 - limit: 最多处理的新闻数量。 - use_bert: 是否使用 BERT。None 则根据初始化模式自动决定。 - - Returns: - 成功更新的新闻数量。 - """ - news_items = self.db.get_daily_news(source=source, limit=limit) - to_analyze = [item for item in news_items if not item.get('sentiment_score')] - - if not to_analyze: - return 0 - - updated_count = 0 - cursor = self.db.conn.cursor() - - # 决定使用哪种方法 - if self.bert_pipeline: - logger.info(f"🚀 Using BERT for batch analysis of {len(to_analyze)} items...") - titles = [item['title'] for item in to_analyze] - results = self.analyze_sentiment_bert(titles) - - for item, analysis in zip(to_analyze, results): - cursor.execute(""" - UPDATE daily_news - SET sentiment_score = ?, meta_data = json_set(COALESCE(meta_data, '{}'), '$.sentiment_reason', ?) - WHERE id = ? - """, (analysis['score'], analysis['reason'], item['id'])) - updated_count += 1 - else: - logger.warning("BERT pipeline not available. Batch update skipped. 
Please use Agentic analysis for high-quality results.") - - self.db.conn.commit() - return updated_count - diff --git a/skills/alphaear-sentiment/tests/test_sentiment.py b/skills/alphaear-sentiment/tests/test_sentiment.py deleted file mode 100644 index 3e0549c..0000000 --- a/skills/alphaear-sentiment/tests/test_sentiment.py +++ /dev/null @@ -1,25 +0,0 @@ -import sys -import os -import unittest - -# Add skill root to path -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) - -try: - from scripts.sentiment_tools import SentimentTools - from scripts.database_manager import DatabaseManager -except ImportError as e: - print(f"Import Error: {e}") - sys.exit(1) - -class TestSentiment(unittest.TestCase): - def test_init(self): - print("Testing SentimentTools Initialization...") - db = DatabaseManager(":memory:") - # Mock mode="llm" to avoid loading large models or needing keys - tools = SentimentTools(db, mode="llm") - self.assertIsNotNone(tools) - print("SentimentTools Initialized.") - -if __name__ == '__main__': - unittest.main() diff --git a/skills/alphaear-signal-tracker/SKILL.md b/skills/alphaear-signal-tracker/SKILL.md deleted file mode 100644 index f4f4a28..0000000 --- a/skills/alphaear-signal-tracker/SKILL.md +++ /dev/null @@ -1,51 +0,0 @@ ---- -name: alphaear-signal-tracker -description: Track financial investment signal evolution and update logic based on new financial market information. Use when monitoring financial signals and determining if they are strengthened, weakened, or falsified. ---- - -# AlphaEar Signal Tracker Skill - -## Overview - -This skill provides logic to track and update investment signals. It assesses how new market information impacts existing signals (Strengthened, Weakened, Falsified, or Unchanged). - -## Capabilities - -### 1. Track Signal Evolution (Agentic Workflow) - -**YOU (the Agent)** are the Tracker. Use the prompts in `references/PROMPTS.md`. - -**Workflow:** -1. **Research**: Use **FinResearcher Prompt** to gather facts/price for a signal. -2. **Analyze**: Use **FinAnalyst Prompt** to generate the initial `InvestmentSignal`. -3. **Track**: For existing signals, use **Signal Tracking Prompt** to assess evolution (Strengthened/Weakened/Falsified) based on new info. - -**Tools:** -- Use `alphaear-search` and `alphaear-stock` skills to gather the necessary data. -- Use `scripts/fin_agent.py` helper `sanitize_signal_output` if needing to clean JSON. - -**Key Logic:** - -- **Input**: Existing Signal State + New Information (News/Price). - -**Process**: - 1. Compare new info with signal thesis. - 2. Determine impact direction (Positive/Negative/Neutral). - 3. Update confidence and intensity. -- **Output**: Updated Signal. - -**Example Usage (Conceptual):** - -```python -# This skill is currently a pattern extracted from FinAgent. -# In a future refactor, it should be a standalone utility class. -# For now, refer to `scripts/fin_agent.py`'s `track_signal` method implementation. -``` - -## Dependencies - -- `agno` (Agent framework) -- `sqlite3` (built-in) - -Ensure `DatabaseManager` is initialized correctly. diff --git a/skills/alphaear-signal-tracker/references/PROMPTS.md b/skills/alphaear-signal-tracker/references/PROMPTS.md deleted file mode 100644 index 5bff3b4..0000000 --- a/skills/alphaear-signal-tracker/references/PROMPTS.md +++ /dev/null @@ -1,72 +0,0 @@ -# AlphaEar Signal Tracker Prompts - -## 1. FinResearcher - -**Prompt:** - -```markdown -You are a senior financial researcher.
Current time: {current_time}. -Your task is to investigate the "Raw Signal" to provide materials for deep analysis. - -### Core Duties -1. **Identify Ticker**: Confirm specific listed company codes. Use tools to check price/history. -2. **Fact Check**: Verify signal authenticity via search/news. -3. **Industry Chain**: Map upstream/downstream. - -### Tool Usage -- Check price for EVERY mentioned company. -- Cross-verify information. - -### Output -Output a structured research report covering fundamentals, price trend, and industry background. -``` - -## 2. FinAnalyst (Signal Parsing) - -**Prompt:** - -```markdown -You are a senior financial analyst (FinAgent). Current time: {current_time}. -Task: transform research materials into actionable Investment Intelligence (ISQ). - -### Raw Signal -{signal_text} - -### Research Context -{research_context_str} - -### Analysis Requirements -1. **Title**: Concise (<15 words). -2. **Pricing**: Analyze if priced-in based on provided price data. -3. **Impact**: Fill `impact_tickers` with codes and weights. -4. **Logic**: `transmission_chain` with `node_name`, `impact_type`, `logic`. -5. **Prediction**: `summary` must contain specific targets (price/change). - -### Output (Strict JSON - InvestmentSignal) -Output valid JSON matching the InvestmentSignal schema. -``` - -## 3. Signal Tracking (Evolution) - -**Prompt:** - -```markdown -You are tracking signal evolution. -Task: Re-evaluate previous investment signal based on new market info. - -=== Baseline Signal === -{old_sig_str} - -=== Latest Tracking (NEWS & PRICE) === -{new_research_str} - -### Requirements -1. **Evolution Detection**: - - Has logic changed? (Falsified? Realized? strengthened?) - - Mark `reasoning` with "Logic Evolution: ...". -2. **Parameter Correction**: - - Update `sentiment_score`, `confidence`, `expectation_gap`. -3. **Output**: - - Keep `signal_id`. - - Output full InvestmentSignal JSON. 
-``` diff --git a/skills/alphaear-signal-tracker/scripts/__init__.py b/skills/alphaear-signal-tracker/scripts/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/alphaear-signal-tracker/scripts/fin_agent.py b/skills/alphaear-signal-tracker/scripts/fin_agent.py deleted file mode 100644 index 07608ed..0000000 --- a/skills/alphaear-signal-tracker/scripts/fin_agent.py +++ /dev/null @@ -1,106 +0,0 @@ -import time -from typing import Optional, List -from loguru import logger - -from .utils.database_manager import DatabaseManager - -class FinUtils: - """ - 金融分析辅助工具 (FinUtils) - 提供数据清洗、Output Sanitization 等功能。 - 核心分析逻辑已移交 Agent 执行 (参考 scripts/prompts/PROMPTS.md)。 - """ - - def __init__(self, db: DatabaseManager): - self.db = db - - @staticmethod - def _clean_digits(value: str) -> str: - s = (value or "").strip() - if not s: - return "" - return "".join([c for c in s if c.isdigit()]) - - def sanitize_signal_output(self, json_data: dict, research_data: Optional[dict] = None, raw_signal: str = "") -> dict: - """Post-process LLM output to prevent spurious ticker/name binding.""" - if not isinstance(json_data, dict): - return json_data - - tool_suggested: set[str] = set() - if isinstance(research_data, dict): - tf = research_data.get('tickers_found') - if isinstance(tf, list): - for item in tf: - if not isinstance(item, dict): - continue - code_raw = item.get('code') or item.get('ticker') or item.get('symbol') - code = self._clean_digits(str(code_raw or "")) - if code: - tool_suggested.add(code) - - sources = json_data.get('sources') - source_titles: list[str] = [] - source_urls: list[str] = [] - if isinstance(sources, list): - for s in sources: - if not isinstance(s, dict): - continue - t = str(s.get('title') or "").strip() - u = str(s.get('url') or "").strip() - if t: - source_titles.append(t) - if u: - source_urls.append(u) - - evidence_text = " ".join([ - str(raw_signal or ""), - str(json_data.get('title') or ""), - str(json_data.get('summary') or ""), - " ".join(source_titles), - " ".join(source_urls), - ]) - - impact = json_data.get('impact_tickers') - if not isinstance(impact, list): - return json_data - - if not impact: - return json_data - - sanitized: list[dict] = [] - for item in impact: - if not isinstance(item, dict): - continue - code_raw = item.get('ticker') or item.get('code') or item.get('symbol') - code = self._clean_digits(str(code_raw or "")) - - # Simple validation if DB lookup is too expensive or complex here. - # But the original code used self.db, so we try to use it. - if not (code.isdigit() and len(code) in (5, 6)): - continue - - # Original logic used DB to verify stock existence - try: - stock = self.db.get_stock_by_code(code) - if not stock: - continue - official_name = stock.get('name') or "" - - mentioned = (code in evidence_text) or (official_name and official_name in evidence_text) - if tool_suggested: - if code not in tool_suggested and not mentioned: - continue - else: - if not mentioned: - continue - - new_item = dict(item) - new_item['ticker'] = code - new_item['name'] = official_name - sanitized.append(new_item) - except Exception: - # If DB access fails, be permissive or conservative? Conservative to avoid hallucinations. 
- pass - - json_data['impact_tickers'] = sanitized - return json_data diff --git a/skills/alphaear-signal-tracker/scripts/prompts/fin_agent.py b/skills/alphaear-signal-tracker/scripts/prompts/fin_agent.py deleted file mode 100644 index 83386af..0000000 --- a/skills/alphaear-signal-tracker/scripts/prompts/fin_agent.py +++ /dev/null @@ -1,127 +0,0 @@ -from datetime import datetime -from .isq_prompt_generator import generate_isq_prompt_section - -def get_fin_researcher_instructions() -> str: - """生成金融研究员 (Researcher) 的系统指令""" - current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') - return f"""你是一名资深金融研究员,当前时间是 {current_time}。 -你的任务是针对给定的“原始信号”进行详尽的背景调查,为后续的深度分析提供素材。 - -### 1. 核心职责 -1. **标的识别**: 识别信号中涉及的具体上市公司。必须调用 `search_ticker` 确认代码,并调用 `get_stock_price` 获取最新价格和近 30 天走势。 -2. **事实核查**: 使用 `web_search` 或 `fetch_news_content` 验证信号的真实性,并寻找更多细节(如公告原文、行业研报摘要)。 -3. **产业链梳理**: 补充该信号涉及的上下游环节及竞争格局。 - -### 2. 工具使用规范 (CRITICAL) -- **每个提到的公司都需要调用工具**: 不能依赖记忆,必须实时查询。 -- **完整呈现工具结果**: 包括具体的股价数字、代码、技术面数据等,不要缩略。 -- **股价数据必需**: 当前价格、近期最高最低、技术面支撑阻力等数据是后续预测的基础。 -- **信息交叉验证**: 多个来源验证关键事实。 - -### 3. 输出要求 -你必须输出结构化的研究报告,涵盖标的基本面、股价走势、行业背景及最新进展。 -""" - -def get_fin_analyst_instructions(template_id: str = "default_isq_v1") -> str: - """生成金融分析师 (Analyst) 的系统指令 - - Args: - template_id: 使用的 ISQ 模板 ID - """ - current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') - isq_block = generate_isq_prompt_section(template_id=template_id) - - return f"""你是一位深耕二级市场的资深金融分析师 (FinAgent),当前时间是 {current_time}。 -你的核心任务是执行“信号解析”,将研究员搜集的素材转化为具有可操作性的投资情报(ISQ 框架)。 - -{isq_block} - -### 2. 分析约束 -- **严格基于具体数据**: 必须使用研究员提供的股价、技术面、新闻等具体数据进行分析。 -- **数据驱动的预测**: impact_tickers 中的权重应基于事件影响程度,不能随意赋值。 -- **逻辑严密**: 传导链条必须符合金融常识,能够自圆其说。 -- **技术面参考**: 如果研究员提供了股价走势,请分析当前位置相对于支撑/阻力位的关系。 - -### 3. 关键要求 -- **title**: 必须生成一个简练、准确概括信号核心内容的标题(不超过 15 字)。 -- **impact_tickers**: 必须填充具体的公司代码(6位数字)和名称,权重应该有区分。 -- **transmission_chain**: 必须是对象列表,每个对象包含: - - `node_name`: 节点名称(如“上游原材料”、“中游制造”) - - `impact_type`: 影响类型(“利好”、“利空”、“中性”) - - `logic`: 具体的传导逻辑描述 -- **summary**: 基于分析结果总结核心观点,包含具体数字(如股价目标、预期涨跌幅等)。 -- **reasoning**: 必须详细阐述推演逻辑,解释为什么得出上述结论(<200字)。 - -### 4. 输出格式 (严格 JSON 块) -你必须输出一个符合 InvestmentSignal 结构的 JSON 块,包含所有必需字段。 -""" - -def get_fin_agent_instructions() -> str: - # 保持兼容性,但内部调用 analyst 指令 - return get_fin_analyst_instructions() - -def get_fin_research_task(signal_text: str) -> str: - """生成研究员的任务描述""" - return f"请针对以下信号进行背景调查,搜集相关标的的股价、最新进展和行业背景:\n\n{signal_text}" - -def format_research_context(research_data: dict) -> str: - """将研究员搜集的结构化数据格式化为分析师可读的文本""" - if not research_data: - return "(未能搜集到额外背景信息)" - - return f""" -### 研究背景 -- **相关标的**: {research_data.get('tickers_found', [])} -- **行业背景**: {research_data.get('industry_background', '未知')} -- **最新进展**: {', '.join(research_data.get('latest_developments', []))} -- **关键风险**: {', '.join(research_data.get('key_risks', []))} -- **综合摘要**: {research_data.get('search_results_summary', '无')} -""" - -def get_fin_analysis_task(signal_text: str, research_context_str: str) -> str: - """生成分析师的任务描述""" - return f"""请基于以下信息进行深度 ISQ 分析。关键是:必须使用研究员搜集的具体数据(股价、技术面、新闻、代码等)进行分析。 - -=== 原始信号 === -{signal_text} - -=== 研究员搜集的背景信息 (CRITICAL DATA) === -{research_context_str} - -=== 分析要求 === -1. 必须生成 title:简练概括信号核心(<15字) -2. 基于研究员提供的具体股价数据,分析当前定价状态(已定价/未定价/部分定价) -3. impact_tickers 中填充具体的公司代码和权重,权重基于事件影响程度 -4. transmission_chain 必须是包含 node_name, impact_type, logic 的对象列表 -5. summary 中包含具体数字(预期目标价、涨跌幅范围等) -6. 
reasoning 必须详细解释推演逻辑,不要空泛,要言之有物 - -请严格按 InvestmentSignal JSON 格式输出。""" - -def get_tracking_analysis_task(old_signal: dict, new_research_str: str) -> str: - """生成信号追踪更新的任务描述""" - import json - old_sig_str = json.dumps(old_signal, ensure_ascii=False, indent=2) - return f"""你正在执行“信号逻辑演变追踪”任务。请基于最新的市场信息,重新评估之前的投资信号。 - -=== 基准信号 (上次分析) === -{old_sig_str} - -=== 最新市场追踪 (NEWS & PRICE) === -{new_research_str} - -=== 追踪分析要求 === -1. **逻辑演变检测**: - - 对比新旧信息,判断原逻辑 (`transmission_chain` 和 `reasoning`) 是否依然成立? - - 如果逻辑发生变化(如利好落空、逻辑证伪、新利好出现),请在新的 `reasoning` 中明确指出“逻辑演变:...” - - 如果逻辑未变且得到验证,请标记“逻辑维持:...” - -2. **参数修正**: - - 根据最新股价和新闻,更新 `sentiment_score` (情绪)、`confidence` (置信度) 和 `expectation_gap` (预期差)。 - - 例如:如果股价已经大涨反映了利好,`expectation_gap` 应该显著降低。 - -3. **输出更新后的信号**: - - 保留原 `signal_id` 和 `title`(除非有重大变化需要改名)。 - - 输出完整的 InvestmentSignal JSON。 - -请重点关注:为什么变了?还是为什么没变?理由要充分。""" diff --git a/skills/alphaear-signal-tracker/scripts/prompts/forecast_analyst.py b/skills/alphaear-signal-tracker/scripts/prompts/forecast_analyst.py deleted file mode 100644 index d6c7202..0000000 --- a/skills/alphaear-signal-tracker/scripts/prompts/forecast_analyst.py +++ /dev/null @@ -1,49 +0,0 @@ -from typing import List, Dict, Any -from ..schema.models import KLinePoint - -def get_forecast_adjustment_instructions(ticker: str, news_context: str, model_forecast: List[KLinePoint]): - """ - 生成 LLM 预测调整指令 - """ - forecast_str = "\n".join([f"- {p.date}: O:{p.open}, C:{p.close}" for p in model_forecast]) - - return f"""你是一位资深的量化策略分析师。 -你的任务是:根据给定的【Kronos 模型预测结果】和【最新的基本面/新闻背景】,对模型预测进行“主观/逻辑调整”。 - -股票代码: {ticker} - -【Kronos 模型原始预测 (OHLC)】: -{forecast_str} - -【最新情报背景】: -{news_context} - -调整原则: -1. 原始预测是基于历史的技术面推演。 -2. 情报背景中可能包含【Kronos模型定量修正预测】,这是基于历史新闻训练的专用模型计算出的量化结果。 -3. 如果存在“定量修正预测”,请**高度参考**该数值作为基础,除非你有非常确凿的逻辑认为该量化模型失效(例如遇到模型未见过的极端黑天鹅)。 -4. 你的核心任务是:结合定性分析(新闻及其逻辑)来验证或微调这些数字,并给出合理的解释(Rationale)。 -5. 如果没有“定量修正预测”,则你需要根据新闻信号手动大幅调整趋势。 - -输出要求 (严格 JSON 格式): -```json -{{ - "adjusted_forecast": [ - {{ - "date": "YYYY-MM-DD", - "open": float, - "high": float, - "low": float, - "close": float, - "volume": float - }}, - ... - ], - "rationale": "详细说明调整的逻辑依据,例如:考虑到[事件A],预期短线将突破压力位..." -}} -``` -注意:必须输出与原始预测相同数量的数据点,且日期一一对应。 -""" - -def get_forecast_task(): - return "请根据以上背景和模型预测,给出调整后的 K 线数据并说明理由。" diff --git a/skills/alphaear-signal-tracker/scripts/prompts/intent_agent.py b/skills/alphaear-signal-tracker/scripts/prompts/intent_agent.py deleted file mode 100644 index a8397d2..0000000 --- a/skills/alphaear-signal-tracker/scripts/prompts/intent_agent.py +++ /dev/null @@ -1,45 +0,0 @@ -def get_intent_analysis_instructions() -> str: - """生成意图分析 Agent 的系统指令,专注于金融市场影响分析""" - return """你是一个资深的金融市场意图分析专家。你的任务是将用户的自然语言查询转化为结构化的 JSON 分析结果,重点挖掘该查询与金融市场(尤其是股市)的潜在关联。 - -### 核心任务: -深入分析用户查询,识别核心金融实体、行业板块及潜在的市场影响点,生成利于搜索引擎抓取深度金融分析信息的查询词。 - -### 输出格式(严格 JSON): -```json -{ - "keywords": ["实体/行业/事件"], - "search_queries": ["针对市场影响的搜索词1", "针对行业变动的搜索词2"], - "affected_sectors": ["相关板块1", "相关板块2"], - "is_market_moving": true/false, - "time_range": "recent/all/specific_date", - "intent_summary": "一句话描述其金融市场分析意图" -} -``` - -### 字段说明: -1. **keywords**: 核心公司实体、所属行业、宏观经济事件或政策概念。 -2. **search_queries**: 优化后的搜索词,必须包含“股市影响”、“股价波动”、“行业逻辑”或“估值”等金融维度。 -3. **affected_sectors**: 可能受此事件或信息影响的二级市场板块(如:保险、半导体、房地产)。 -4. **is_market_moving**: 该事件是否具有显著的市场驱动潜力或属于重大基本面变化。 -5. 
**intent_summary**: 简述用户查询背后的金融研究目的。 - -### 示例: -用户输入:"帮我研究一下香港火灾的影响" -输出: -```json -{ - "keywords": ["香港", "火灾", "保险行业", "房地产"], - "search_queries": ["香港火灾对当地保险股股价影响", "香港大火对相关上市物业公司估值冲击", "近期香港火灾带来的市场避险情绪分析"], - "affected_sectors": ["保险", "房地产", "物业管理"], - "is_market_moving": true, - "time_range": "recent", - "intent_summary": "评估香港近期火灾对相关板块上市公司的潜在经济损失及股价冲击" -} -``` -""" - -def get_intent_task(query: str) -> str: - """生成意图分析任务描述""" - return f"Process this query and extract financial market intent: {query}" - diff --git a/skills/alphaear-signal-tracker/scripts/prompts/isq_prompt_generator.py b/skills/alphaear-signal-tracker/scripts/prompts/isq_prompt_generator.py deleted file mode 100644 index 007461b..0000000 --- a/skills/alphaear-signal-tracker/scripts/prompts/isq_prompt_generator.py +++ /dev/null @@ -1,43 +0,0 @@ -""" -ISQ prompt helpers to render dimension guidance directly from the template. -Any change in the template propagates to prompts automatically. -""" - -from typing import List, Optional -from ..schema.isq_template import get_isq_template, ISQTemplate - - -def _ordered_dimension_keys(template: ISQTemplate, order: Optional[List[str]] = None) -> List[str]: - if order: - return [k for k in order if k in template.dimensions] - # fallback to template insertion order - return list(template.dimensions.keys()) - - -def generate_isq_prompt_section(template_id: str = "default_isq_v1", order: Optional[List[str]] = None, include_header: bool = True) -> str: - """Render ISQ dimension text block based on the template. - This allows prompt text to stay in sync with template edits. - """ - template = get_isq_template(template_id) - keys = _ordered_dimension_keys(template, order) - - lines: List[str] = [] - if include_header: - lines.append("### 1. ISQ 评估框架 (Investment Signal Quality)") - lines.append(f"参考模板: {template.template_name} (id: {template.template_id})") - lines.append("") - lines.append("你需要对信号进行以下维度的评分:") - lines.append("") - - for idx, key in enumerate(keys, start=1): - spec = template.dimensions[key] - examples = ";".join([f"{k}: {v}" for k, v in spec.examples.items()]) if spec.examples else "" - lines.append(f"{idx}. 
**{spec.key} ({spec.name})**: {spec.range_type}") - lines.append(f" - 描述: {spec.description}") - if spec.scale_factor and spec.scale_factor != 1.0: - lines.append(f" - 缩放因子: {spec.scale_factor}") - if examples: - lines.append(f" - 示例: {examples}") - lines.append("") - - return "\n".join(lines).rstrip() diff --git a/skills/alphaear-signal-tracker/scripts/prompts/report_agent.py b/skills/alphaear-signal-tracker/scripts/prompts/report_agent.py deleted file mode 100644 index 6f25c3f..0000000 --- a/skills/alphaear-signal-tracker/scripts/prompts/report_agent.py +++ /dev/null @@ -1,415 +0,0 @@ -# src/prompts/report_agent.py -from datetime import datetime -from typing import Optional -from .isq_prompt_generator import generate_isq_prompt_section - -def get_report_planner_base_instructions() -> str: - """生成报告策划员 (Planner) 的基础系统指令""" - return """你是一名资深的金融研报主编。你的任务是规划报告的结构,将零散的信号聚类成有逻辑的主题。 -你拥有 RAG 搜索工具,可以检索已生成的章节内容以确保逻辑连贯性。 -在规划时,应重点关注信号之间的关联性、产业链的完整性以及用户特定的关注点。""" - -def get_report_writer_base_instructions() -> str: - """生成报告撰写员 (Writer) 的基础系统指令""" - return """你是一名资深金融分析师。你的任务是根据策划员提供的信号簇撰写深度研报章节。 -你应当运用专业的金融知识,将信号转化为深刻的洞察。 -注意:你没有外部搜索工具,你的分析必须基于提供给你的信号内容和行情数据。""" - -def get_report_editor_base_instructions() -> str: - """生成报告编辑 (Editor) 的基础系统指令""" - return """你是一名严谨的金融研报编辑。你的任务是审核和润色撰写员生成的章节。 -你拥有 RAG 搜索工具,可以检索其他章节的内容,以消除重复、修正逻辑冲突并确保术语一致性。 -你应当确保报告符合专业的金融写作规范,且标题层级正确。""" - -# 1. 策划阶段 (Structural Planning) -def format_signal_for_report(signal: any, index: int, cite_keys: Optional[list] = None) -> str: - """格式化单个信号供研报生成使用""" - # 这里的逻辑从 ReportAgent._format_signal_input 迁移过来 - from ..schema.models import InvestmentSignal - - if isinstance(signal, dict): - try: - sig_obj = InvestmentSignal(**signal) - except: - return f"--- 信号 [{index}] ---\n标题: {signal.get('title')}\n内容: {signal.get('content', '')[:500]}" - else: - sig_obj = signal - - chain_str = " -> ".join([f"{n.node_name}({n.impact_type})" for n in sig_obj.transmission_chain]) - - text = f"--- 信号 [{index}] ---\n" - text += f"标题: {sig_obj.title}\n" - text += f"逻辑摘要: {sig_obj.summary}\n" - text += f"传导链条: {chain_str}\n" - text += f"ISQ 评分: 情绪({sig_obj.sentiment_score}), 确定性({sig_obj.confidence}), 强度({sig_obj.intensity})\n" - text += f"预期博弈: 时窗({sig_obj.expected_horizon}), 预期差({sig_obj.price_in_status})\n" - - tickers = ", ".join([f"{t.get('name')}({t.get('ticker')})" for t in sig_obj.impact_tickers]) - if tickers: - text += f"受影响标的: {tickers}\n" - - # Stable bibliography-style citation keys (LaTeX/BibTeX-like) - if cite_keys: - joined = " ".join([f"[@{k}]" for k in cite_keys if k]) - if joined: - text += f"引用: {joined}\n" - - return text - -def get_cluster_planner_instructions(signals_text: str, user_query: str = None) -> str: - """生成信号聚类指令 - 将零散信号组织成逻辑主题""" - query_context = f"用户重点关注:{user_query}" if user_query else "" - return f"""你是一位资深的金融研报主编。你的任务是将以下零散的金融信号聚类成 3-5 个核心逻辑主题,以便撰写一份结构清晰的研报。 - - {query_context} - - ### 输入信号列表 - {signals_text} - - ### 聚类要求 - 1. **主题聚合**: 将相关性强的信号归为一组(例如:都涉及“建筑安全法规”或“某产业链上下游”)。 - 2. **叙事逻辑**: 只需要生成主题名称和包含的信号 ID。 - 3. **控制数量**: 将所有信号归类到 3-5 个主要主题中,不要遗漏。 - - ### 输出格式 (JSON) - 请仅输出以下 JSON 格式,不要包含 Markdown 标记: - {{ - "clusters": [ - {{ - "theme_title": "主题名称(如:建筑安全法规收紧引发的产业链重构)", - "signal_ids": [1, 3, 5], - "rationale": "这些信号都指向政府对高层建筑防火标准的政策调整..." - }}, - ... - ] - }} - """ - -def get_report_planner_instructions(toc: str, signal_count: int, user_query: str = None) -> str: - """生成报告规划指令 - 重点在于逻辑关联与分歧识别""" - # ... 
(原有逻辑保持不变,但实际在新的聚类流程后这个可能作为备用或二次优化) - query_context = f"用户重点关注:{user_query}" if user_query else "" - return f"""你是一位资深的金融研报主编。你的任务是根据现有的草稿章节,规划出一份逻辑严密、穿透力强的终稿结构。 - - ### 任务核心: - 1. **识别主线**: 从草稿中识别出贯穿多个章节的“核心逻辑主线”(如:产业链共振、货币政策转向)。 - 2. **分歧评估 (Entropy)**: 识别各章节中观点冲突或确定性不一之处,规划如何在正文中呈现这些“分歧点”。 - 3. **结构蓝图**: - - 定义一级标题(逻辑主题)。 - - 归类章节:哪些信号应放入同一主题下深度解析? - - 排序:将 ISQ 强度最高、与{query_context}最相关的信号置前。 - - ### 现有草稿目录 (TOC) - {toc} - - 请输出你的【终稿修订大纲】(Markdown 格式)。 - """ - -# 2. 撰写阶段 (Section Writing) -def get_report_writer_instructions(theme_title: str, signal_cluster_text: str, signal_indices: list, price_context: str = "", user_query: str = None) -> str: - """生成 Writer Agent 指令 - 基于主题聚类撰写综合分析""" - - price_info = f"\n### 近期价格参考\n{price_context}\n" if price_context else "" - query_context = f"\n**用户意图**: \"{user_query}\"\n请确保分析内容回应了用户的关注点。\n" if user_query else "" - isq_block = generate_isq_prompt_section(include_header=False) - - # Keep citation scheme stable across re-ordering / edits. - # Cite keys are provided in each signal block as: 引用: [@KEY] - - return f"""你是一位资深金融分析师。请针对核心主题 **"{theme_title}"** 撰写一篇深度研报章节。 - {query_context} - - ### 输入信号集 (本章节需综合的信号) - {signal_cluster_text} - {price_info} - - ### ISQ 评分说明 - {isq_block} - - ### 写作要求 - 1. **叙事逻辑**: 不要罗列信号,要将这些信号编织成一个连贯的故事。先讲宏观/行业背景,再讲具体事件传导,最后落脚到个股/标的影响。 - 2. **量化支撑**: 引用 ISQ 评分(确定性、强度、预期差)来佐证你的观点。关键观点必须关联相应的 ISQ 分值。 - 3. **引用规范(稳定 CiteKey)**: 关键论断必须标注来源引用,使用 `[@CITE_KEY]` 格式。 - - CiteKey 已在输入信号块中以 `引用: [@KEY]` 提供,请直接复制使用。 - - 不要使用 `[[1]]` 这类不稳定编号。 - 4. **关联标的预测**: **必须**在章节末尾明确给出受影响标的的预测分析,包括: - - 至少列出 1-2 个相关上市公司代码(如 600519.SH) - - 给出短期(T+3或T+5)的方向性判断 - - 如果可能,给出预期价格区间或涨跌幅预测 - - ### 【重要】标题层级规范 - - ❌ **错误示例**(绝对不要这样): - ```markdown - # {theme_title} - - ### 宏观背景 - ... - ``` - - ✅ **正确示例**(必须这样): - ```markdown - ## {theme_title} - - ### 宏观背景 - - 近期全球经济环境... - - ### 具体传导机制分析 - - ... - - ### 核心标的分析 - - 建议关注:贵州茅台(600519.SH)... - ``` - - **关键要求**: - - 章节主标题使用 `##` (H2) - - 章节子标题使用 `###` (H3) - - **绝对禁止**使用 `#` (H1) - - 第一行必须是 `## {theme_title}` 开头 - - ### 核心:图表叙事 (Visual Storytelling) - **必须**在文中插入至少 1-2 个图表,且图表必须与上下文紧密结合(不要堆砌在末尾)。 - - **可选图表类型 (请根据内容选择最合适的 1-2 种):** - - **A. 
AI 预测 + 走势 (Forecast) - 【强烈推荐 / 最新规范】** - *适用*: 当文中明确提及某上市公司时,**必须**使用此图表展示股价走势与 AI 预测。 - *必填字段*: - - `ticker`: 股票代码,A股 6 位 / 港股 5 位,允许带后缀(如 "002371.SZ"、"9868.HK") - - `pred_len`: 预测交易日长度(建议 3 或 5) - *代码示例*: - ```json-chart - {{"type": "forecast", "ticker": "002371.SZ", "title": "北方华创(002371)T+5 预测", "pred_len": 5}} - ``` - **重要**:禁止手写 `prediction` 数组(预测由系统自动生成并渲染)。 - *注意*: 如果提及多只股票,应为每只生成独立的 forecast 图表。 - - **【推荐写法:多情景 → 最终归因 → 产出唯一预测图】** - 你可以在正文里描述多种情景(如:基准/乐观/悲观),但在插入预测图之前,必须明确给出“本报告最终选择的最可能情景”及其归因,然后用 `forecast` 图表做最终总结。 - 为了让系统把“最终归因”可靠地传递给预测模块,请在 `forecast` JSON 中可选补充以下字段(字段均为可选,越完整越好): - - `selected_scenario`: 最可能情景名称(如 "基准" / "乐观" / "悲观") - - `selection_reason`: 选择该情景的归因理由(1-3 句) - - `scenarios`: 情景列表(数组),每个元素可包含 `name`、`description`、`probability`(0-1) - *示例*: - ```json-chart - {{ - "type": "forecast", - "ticker": "002371.SZ", - "title": "北方华创(002371)T+5 预测(基准情景)", - "pred_len": 5, - "selected_scenario": "基准", - "selection_reason": "结合订单能见度与行业景气,基准情景概率最高;短期扰动主要来自估值与市场风险偏好。", - "scenarios": [ - {{"name": "乐观", "description": "国产替代与资本开支超预期", "probability": 0.25}}, - {{"name": "基准", "description": "订单稳健、利润率小幅波动", "probability": 0.55}}, - {{"name": "悲观", "description": "需求回落或交付节奏放缓", "probability": 0.20}} - ] - }} - ``` - - **B. 历史走势 (Stock) - 仅作为兼容兜底** - *适用*: 当你无法给出预测时(例如无法确定标的),可仅展示历史走势。 - *代码示例*: - ```json-chart - {{"type": "stock", "ticker": "002371", "title": "北方华创历史走势"}} - ``` - - **C. 舆情情绪演变 (Sentiment Trend)** - *适用*: 当讨论行业政策、突发事件(如“火灾”、“新规”)的民意变化时。 - *注意*: `keywords` 必须是事件核心词。 - *代码*: - ```json-chart - {{"type": "sentiment", "keywords": ["建筑安全", "防火标准"], "title": "市场对防火新规的情绪演变"}} - ``` - - **D. 逻辑传导链条 (Transmission Chain)** - *适用*: 复杂的蝴蝶效应分析(支持分支结构)。 - *代码*: - ```json-chart - {{ - "type": "transmission", - "nodes": [ - {{"node_name": "突发火灾", "impact_type": "中性", "logic": "事件发端"}}, - {{"node_name": "监管收紧", "impact_type": "利空", "logic": "合规成本上升", "source": "突发火灾"}}, - {{"node_name": "设备升级", "impact_type": "利好", "logic": "采购需求释放", "source": "突发火灾"}}, - {{"node_name": "龙头受益", "impact_type": "利好", "logic": "市占率提升", "source": "设备升级"}} - ], - "title": "火灾事件的逻辑传导与分支" - }} - ``` - *说明*: 使用 `source` 字段指定父节点名称以创建分支结构。 - - **E. 信号质量评估 (ISQ Radar)** - *适用*: 对某个关键信号进行多维度(确定性、预期差等)定性评估时。 - *代码*: - ```json-chart - {{"type": "isq", "sentiment": 0.8, "confidence": 0.9, "intensity": 4, "expectation_gap": 0.7, "timeliness": 0.9, "title": "核心信号质量评估"}} - ``` - """ - -# 3. 整合阶段 (Final Assembly) - 原版,保留用于 fallback -def get_report_editor_instructions(draft_sections: str, plan: str, sources_list: str) -> str: - """生成最终编辑指令 - 根据规划蓝图重组内容""" - return f"""你是一位专业的研报编辑。请将以下基于主题撰写的草稿章节整合成最终研报。 - - ### 原始草稿内容 - {draft_sections} - - ### 原始引用来源 - {sources_list} - - ### 任务与要求 - 1. **结构化**: 为每个草稿章节添加合适的 Markdown 标题 (## 级别)。 - 2. **连贯性**: 确保章节之间过渡自然。 - 3. **完整性**: - - 必须保留所有 `json-chart` 代码块(图表配置)。 - - 必须保留引用标注 `[@CITE_KEY]`。 - - 生成 `## 核心观点摘要`、`## 参考文献` 和 `## 风险提示`。 - - ### 输出 - 只输出最终的 Markdown 研报内容。 - """ - - -# 4. 单节编辑 (Incremental Section Editing with RAG) -def get_section_editor_instructions(section_index: int, total_sections: int, toc: str) -> str: - """生成单节编辑 prompt,支持 RAG 工具调用""" - return f"""你是一位研报编辑。你正在编辑报告的第 {section_index}/{total_sections} 节。 - - ### 当前目录 (TOC) - {toc} - - ### 你的任务 - 1. 润色当前章节内容,确保逻辑清晰、语言专业。 - 2. 保留所有 `[@CITE_KEY](#ref-CITE_KEY)` 或 `[@CITE_KEY]` 格式的引用。 - 3. 保留所有 `json-chart` 代码块,不做修改。 - 4. 如果需要参考其他章节内容,使用 `search_context` 工具搜索。 - 5. 
只输出编辑后的章节内容,不要输出其他章节。 - - ### 【关键】标题层级规范 - **严格遵守以下规则:** - - 章节主标题使用 `##` (H2) - - 章节子标题使用 `###` (H3) - - **禁止使用** `#` (H1) - 只有报告大标题可以使用 H1 - - 如果原文中有 H1,必须将其降级为 H2 - - 不要输出与 "参考文献"、"风险提示" 相同的标题 - - 直接输出编辑后的 Markdown 内容。 - """ - - -# 5. 摘要生成 (Summary Generation) -def get_summary_generator_instructions(toc: str, section_summaries: str) -> str: - """生成报告摘要指令 - 包含市场分歧度分析""" - return f"""你是一位资深研报主笔。请生成今日报告的核心观点摘要的**正文内容**。 - - ### 章节摘要 - {section_summaries} - - ### 任务: - 1. **核心逻辑提炼**: 用 150 字以内总结今日最核心的投资主线。 - 2. **分歧识别**: 如果不同信号对同一板块有冲突观点,请明确指出"市场分歧点"。 - 3. **确定性排序**: 标记出今日确定性最高的前两个机会(需列出具体标的代码)。 - - ### 【重要】输出格式规范: - - ❌ **错误示例**(不要遗漏二级标题): - ```markdown - ### 核心逻辑提炼 - ... - ``` - - ✅ **正确示例**(应该这样输出): - ```markdown - ## 核心观点摘要 - - ### 核心逻辑提炼 - - 科技自立战略加速半导体设备国产化,叠加AI算力需求爆发... - - ### 市场分歧点 - - 资本市场波动显示医药、新能源等板块估值逻辑受政策敏感性增强... - - ### 确定性排序 - - 1. **网络安全替代需求**(ISQ确定性0.85,推荐标的:深信服 300454.SZ) - 2. **半导体设备材料**(ISQ确定性0.75,推荐标的:北方华创 002371.SZ) - ``` - - ### 关键要求: - - 第一行必须是 `## 核心观点摘要` - - 主体部分使用 H3 (`###`) 和 H4 (`####`) 级别标题 - - **必须**包含 `## 核心观点摘要` 这一级标题 - - 现在请按照正确示例的格式输出摘要内容。 - """ - - -# 6. 最终组装 (Final Assembly with Sections) -def get_final_assembly_instructions(sources_list: str) -> str: - """生成最终报告组装的 prompt""" - return f"""你是一位研报主笔。请完成以下任务: - - ### 任务 - 1. 生成 "## 参考文献" 章节(需要按照顺序,顺序不对时进行调整): - - 原始来源: - {sources_list} - - 格式:`[@CITE_KEY] 标题 (来源), [链接地址]` - 2. 生成 "## 风险提示" (标准免责声明)。 - 3. 生成 "## 快速扫描" 表格,汇总各主题的核心观点。 - - 表格列:**主题**, **核心观点**, **强度(Intensity)**, **确定性(Confidence)**。 - - 强度和确定性请参考原章节中的 ISQ 评分。 - - 只输出上述三个章节的 Markdown 内容。 - """ - -def get_cluster_task(signals_preview: str) -> str: - """生成聚类任务描述""" - return f"请对以下信号进行主题聚类:\n\n{signals_preview}" - -def get_writer_task(theme_title: str) -> str: - """生成撰写任务描述""" - return f"请依据主题 '{theme_title}' 和 输入信号集 开始撰写深度分析章节。" - -def get_planner_task() -> str: - """生成规划任务描述""" - return "请阅读现有草稿并规划终稿大纲,识别核心逻辑主线和市场分歧点。" - -def get_editor_task() -> str: - """生成编辑任务描述""" - return "请根据规划大纲和草稿内容,生成最终研报。确保逻辑连贯,保留所有图表和引用。" - diff --git a/skills/alphaear-signal-tracker/scripts/prompts/trend_agent.py b/skills/alphaear-signal-tracker/scripts/prompts/trend_agent.py deleted file mode 100644 index 54e6e22..0000000 --- a/skills/alphaear-signal-tracker/scripts/prompts/trend_agent.py +++ /dev/null @@ -1,156 +0,0 @@ -from typing import Any -from datetime import datetime -from .isq_prompt_generator import generate_isq_prompt_section - -def get_trend_scanner_instructions() -> str: - """生成趋势扫描员 (Scanner) 的系统指令""" - current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') - return f"""你是一名专业的数据扫描员,当前时间是 {current_time}。 -你的任务是利用各种工具从互联网和数据库中获取最新的金融新闻、热点趋势和市场数据。 - -### 1. 核心职责 -1. **多源采集**: 使用 `news_toolkit` 获取最新新闻,使用 `stock_toolkit` 获取行情,使用 `polymarket_toolkit` 获取预测市场数据。 -2. **情绪感知**: 使用 `sentiment_toolkit` 对关键新闻进行情绪分析。 -3. **深度搜索**: 针对模糊的热点,使用 `search_toolkit` 进行全网搜索补充细节。 - -### 2. 工具使用规范 -- **广度优先**: 尽可能覆盖多个数据源。 -- **数据新鲜度**: 优先获取最近 24 小时内的信息。 -- **结构化输出**: 整理搜集到的原始数据,为后续评估提供清晰的素材。 -""" - -def get_trend_evaluator_instructions() -> str: - """生成趋势评估员 (Evaluator) 的系统指令""" - current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') - isq_block = generate_isq_prompt_section(include_header=True) - - return f""" - 你是一名顶级的金融情报专家 (TrendAgent),擅长从海量信息中识别具有深度价值的"二级市场投资信号"。 - 当前时间:{current_time} - - ### 核心使命: - 不仅是发现"热点",更要解析"信号"。你需要识别那些能触发**传导链条 (Transmission Chain)** 且具有**高确定性 (Confidence)** 的事件。 - - {isq_block} - - ### 核心能力与标准: - 1. **信号识别 (Signal Discovery)**: 基于扫描员提供的素材,识别具有投资价值的信号。优先关注政策、产业变革、重大诉求及跨境套利机会。 - 2. **逻辑相干性**: 是否具备清晰的"原因-结果"传导? - 3. 
**影响力系数**: 是否会引发板块性的联动或财务指标的实质性扰动? - 4. **市场认知差**: 市场是否已提前消化(Price-in)?寻找尚未被充分交易的"Alpha"。 - 5. **实体穿透**: 必须关联到具体的 Ticker 或核心产业链节点。 - - ### 严禁事项: - - 严禁编造数据。 - - 严禁仅输出情绪极性(Positive/Negative),必须带有逻辑依据。 - - 严禁将纯娱乐或单纯的社会负面事件(除非具有宏观破坏性)视为金融信号。 - - ### 输出要求: - 你发现的每个信号应包含: - - **核心摘要**: 穿透表象的逻辑总结。 - - **传导节点**: A -> B -> C 的逻辑推导。 - - **推荐关注**: 板块或 Ticker。 - - **ISQ 评估**: 基于模板的 5 个维度进行初步评分(具体评分由后续 FinAgent 完成)。 - """ - -def get_trend_agent_instructions() -> str: - # 保持兼容性 - return get_trend_evaluator_instructions() - -def get_trend_scan_task(task_description: str) -> str: - """生成扫描员的任务描述""" - return f"请根据以下任务描述,搜集相关的原始数据和新闻:\n\n{task_description}" - -def format_scan_context(scan_data: dict) -> str: - """将扫描员搜集的结构化数据格式化为评估员可读的文本""" - if not scan_data: - return "(未能搜集到原始数据)" - - return f""" -### 扫描数据概览 -- **热点话题**: {', '.join(scan_data.get('hot_topics', []))} -- **情绪概览**: {scan_data.get('sentiment_overview', '未知')} -- **关键新闻**: {len(scan_data.get('news_summaries', []))} 条 -- **数据摘要**: {scan_data.get('raw_data_summary', '无')} -""" - -def get_trend_eval_task(task_description: str, raw_data_str: str) -> str: - """生成评估员的任务描述""" - return f"""请基于以下搜集到的原始数据,完成最终的分析任务: - -任务描述: {task_description} - -原始数据: -{raw_data_str} - -请识别出最具金融价值的信号,并给出评估理由。""" - -def get_news_filter_instructions(news_count: int, depth: Any, user_query: str = None) -> str: - """生成新闻筛选 prompt,使用 FilterResult schema 加快推理并减少 token 消耗 - - Args: - news_count: 输入新闻总数 - depth: 目标筛选数量,若为 auto 则由 LLM 自主判断 - user_query: 用户输入的查询/关注点(可选) - """ - - # 1. 深度控制逻辑 - if str(depth).lower() == 'auto': - depth_guide = "的数量不设固定限制(建议 3-10 条),根据新闻含金量自动判断" - limit_instruction = "宁缺毋滥,如果高价值信息很少,可以只选 1-2 条;如果都很重要,可以多选。" - else: - try: - d_int = int(depth) - depth_guide = f"约 {d_int} 条" - limit_instruction = f"请尽量凑满 {d_int} 条,但如果剩余新闻全是噪音,则不必强行凑数。" - except: - depth_guide = "适量" - limit_instruction = "根据内容价值判断。" - - target_desc = f"筛选出最具投资分析价值的新闻({depth_guide})。" - - # 2. 用户意图逻辑 - query_instruction = "" - if user_query: - target_desc = f"筛选出与用户意图【{user_query}】最相关的新闻。" - query_instruction = f""" - ### 核心任务(High Priority): - 用户明确关注:"{user_query}"。 - 1. **第一优先级**:必须包含所有与"{user_query}"直接或间接相关的新闻,不要遗漏。 - - 即使这些新闻看起来"价值不高",只要相关都要保留。 - 2. **第二优先级**:在满足第一优先级后,如果名额未满,再补充其他重大的市场热点。 - """ - - return f"""你是一名专业的金融情报精排师。你需要从给定的 {news_count} 条原始新闻流中,{target_desc} - - {query_instruction} - - ### FSD (Financial Signal Density) 筛选准则: - 1. **逻辑传导性 (Transmission)**: 该新闻是否预示着一个明确的产业链传导逻辑?(如:上游涨价 -> 中游成本压力 -> 下游提价预期) - 2. **预期差 (Alpha Potential)**: 是否包含尚未被市场充分Price-in的新突发情况? - 3. **确定性 (Confidence)**: 信息来源是否权威?是否包含具体的财务数据、订单金额或明确的政策日期? - 4. **排除噪音**: 坚决剔除明星八卦、鸡汤文、以及无实质增量的"口号式"新闻。 - - ### {limit_instruction} - - ### 快速有效性检查(TOKEN 优化): - 在开始详细筛选前,先快速判断:这 {news_count} 条新闻中是否至少包含 1 条有效的金融信号? 
- 如果全是无关内容(如体育、娱乐、纯生活信息),直接返回 "has_valid_signals": false - 如果有至少 1 条金融相关的新闻,再进行详细 FSD 筛选 - - ### 输出格式(必须为 JSON,使用 FilterResult schema): - ```json - {{ - "has_valid_signals": true/false, - "selected_ids": ["id_1", "id_2", ...], - "themes": [ - {{ - "name": "高概括性主题", - "news_ids": ["相关id_1", ...], - "fsd_reason": "基于 FSD 准则的筛选理由,重点描述传导逻辑和预期差。" - }} - ], - "reason": "如果 has_valid_signals=false,简要说明原因。否则可为空。" - }} - ``` - """ diff --git a/skills/alphaear-signal-tracker/scripts/prompts/visualizer.py b/skills/alphaear-signal-tracker/scripts/prompts/visualizer.py deleted file mode 100644 index f0b2933..0000000 --- a/skills/alphaear-signal-tracker/scripts/prompts/visualizer.py +++ /dev/null @@ -1,47 +0,0 @@ -def get_drawio_system_prompt(): - return """You are an expert at creating Draw.io (MxGraph) diagrams in XML format. -Your task is to generate a valid MXGraphModel XML based on the user's description. - -### Rules: -1. Output ONLY the XML code. Start with <mxGraphModel> and end with </mxGraphModel>. -2. Do not use compressed XML. Use plain XML. -3. Use standard shapes: 'rounded=1;whiteSpace=wrap;html=1;' for boxes. -4. Auto-layout Strategy: - - Identify "layers" or "stages" in the logic. - - Assign X coordinates based on layers (e.g., 0, 200, 400). - - Assign Y coordinates to distribute nodes vertically (e.g., 0, 100, 200). - - Ensure nodes do not overlap. -5. Edges: Connect nodes logically using <mxCell edge="1"> elements. - -### Template: -<mxGraphModel> - <root> - <mxCell id="0" /> - <mxCell id="1" parent="0" /> - <mxCell id="node_1" value="Stage A" style="rounded=1;whiteSpace=wrap;html=1;" vertex="1" parent="1"> - <mxGeometry x="0" y="0" width="160" height="60" as="geometry" /> - </mxCell> - <mxCell id="node_2" value="Stage B" style="rounded=1;whiteSpace=wrap;html=1;" vertex="1" parent="1"> - <mxGeometry x="240" y="0" width="160" height="60" as="geometry" /> - </mxCell> - <mxCell id="edge_1" edge="1" parent="1" source="node_1" target="node_2"> - <mxGeometry relative="1" as="geometry" /> - </mxCell> - </root> -</mxGraphModel> -""" - -def get_drawio_task(nodes_data: list, title: str) -> str: - import json - nodes_json = json.dumps(nodes_data, ensure_ascii=False, indent=2) - return f"""Please generate a Draw.io XML diagram for the following logic flow: - -**Title**: {title} - -**Nodes and Logic**: -{nodes_json} - -Ensure the layout flows logically from Left to Right (or Top to Bottom for hierarchies). -Use different colors for 'Positive' (Greenish), 'Negative' (Reddish), and 'Neutral' (Grey/Blue) impacts if described. 
-""" diff --git a/skills/alphaear-signal-tracker/scripts/schema/isq_template.py b/skills/alphaear-signal-tracker/scripts/schema/isq_template.py deleted file mode 100644 index 2709019..0000000 --- a/skills/alphaear-signal-tracker/scripts/schema/isq_template.py +++ /dev/null @@ -1,381 +0,0 @@ -""" -ISQ (Investment Signal Quality) 评估框架 Template - -统一定义 ISQ 的各个维度、评分标准、和使用方法。 -支持默认 template 和自定义 template。 -""" - -from typing import Dict, List, Any, Optional -from pydantic import BaseModel, Field -from enum import Enum -from pathlib import Path -import json - - -class ISQDimension(str, Enum): - """ISQ 评估维度""" - SENTIMENT = "sentiment" # 情绪/走势方向 - CONFIDENCE = "confidence" # 确定性/可信度 - INTENSITY = "intensity" # 强度/影响量级 - EXPECTATION_GAP = "expectation_gap" # 预期差/市场认知差 - TIMELINESS = "timeliness" # 时效性/窗口紧迫度 - TRANSMISSION = "transmission" # 逻辑传导清晰度 - - -class ISQDimensionSpec(BaseModel): - """ISQ 单个维度的定义规范""" - name: str = Field(..., description="维度名称") - key: str = Field(..., description="维度键名") - description: str = Field(..., description="维度描述") - range_type: str = Field(default="0-1", description="取值范围 (0-1 或 1-5 等)") - scale_factor: float = Field(default=1.0, description="显示时的缩放因子") - examples: Dict[str, str] = Field(default_factory=dict, description="不同分值的示例解释") - visualization_color: Optional[str] = Field(default=None, description="可视化颜色") - - -class ISQTemplate(BaseModel): - """ISQ 评估框架 Template""" - template_id: str = Field(..., description="模板 ID") - template_name: str = Field(..., description="模板名称") - description: str = Field(..., description="模板描述") - - # 核心维度定义 - dimensions: Dict[str, ISQDimensionSpec] = Field(..., description="维度定义字典") - - # 评分指导 - scoring_guide: str = Field(..., description="评分指导说明") - - # 应用场景 - applicable_scenarios: List[str] = Field(default_factory=list, description="适用场景") - - # 聚合算法 - aggregation_method: str = Field(default="weighted_average", description="聚合方法 (weighted_average, product 等)") - dimension_weights: Dict[str, float] = Field(default_factory=dict, description="维度权重") - - -class ISQScore(BaseModel): - """单个信号的 ISQ 评分结果""" - signal_id: str = Field(..., description="信号 ID") - template_id: str = Field(..., description="使用的模板 ID") - - # 各维度评分 - scores: Dict[str, float] = Field(..., description="各维度评分") - - # 总分 - overall_score: float = Field(..., description="综合评分") - - # 评分理由 - rationale: Dict[str, str] = Field(default_factory=dict, description="各维度评分理由") - - # 时间戳 - timestamp: str = Field(..., description="评分时间") - - -# ===================================================== -# 默认 Template -# ===================================================== - -DEFAULT_ISQ_TEMPLATE = ISQTemplate( - template_id="default_isq_v1", - template_name="标准投资信号质量评估框架 (ISQ v1.0)", - description="AlphaEar 默认的 ISQ 评估框架,用于标准化评估投资信号的质量维度", - - dimensions={ - "sentiment": ISQDimensionSpec( - name="情绪/走势", - key="sentiment", - description="基础情绪偏向和市场走势判断", - range_type="-1.0 到 1.0", - scale_factor=1.0, - examples={ - "-1.0": "极度悲观/极度看空", - "-0.5": "明显看空", - "0.0": "中性/没有明确方向", - "0.5": "明显看多", - "1.0": "极度乐观/极度看多" - }, - visualization_color="#ef4444" # 红色表示负面,绿色表示正面 - ), - - "confidence": ISQDimensionSpec( - name="确定性", - key="confidence", - description="信号的可信度和确定性程度", - range_type="0.0 到 1.0", - scale_factor=1.0, - examples={ - "0.0-0.3": "信息来源不可靠/传言多/逻辑推导牵强", - "0.3-0.6": "信息相对可靠/有一定逻辑/但仍有不确定性", - "0.6-0.8": "信息来源权威/逻辑清晰/高度可信", - "0.8-1.0": "官方确认/数据明确/完全确定" - }, - visualization_color="#3b82f6" # 蓝色 - ), - - "intensity": ISQDimensionSpec( - name="强度/影响量级", - key="intensity", - 
description="信号对相关板块/个股的潜在影响程度", - range_type="1 到 5", - scale_factor=20.0, # 用于雷达图缩放 (5 -> 100) - examples={ - "1": "影响微弱,可能被市场忽略", - "2": "小幅影响,短期可能有波动", - "3": "中等影响,值得重点关注", - "4": "强烈影响,可能成为市场焦点", - "5": "极强影响,市场预期明显变化" - }, - visualization_color="#f97316" # 橙色 - ), - - "expectation_gap": ISQDimensionSpec( - name="预期差", - key="expectation_gap", - description="市场预期与现实之间的差距", - range_type="0.0 到 1.0", - scale_factor=1.0, - examples={ - "0.0-0.2": "市场充分认知,预期差小", - "0.2-0.5": "市场部分认知,存在一定预期差", - "0.5-0.8": "市场认知不足,预期差较大,存在博弈空间", - "0.8-1.0": "市场严重低估/高估,巨大预期差" - }, - visualization_color="#22c55e" # 绿色 - ), - - "timeliness": ISQDimensionSpec( - name="时效性", - key="timeliness", - description="信号的时间窗口紧迫度", - range_type="0.0 到 1.0", - scale_factor=1.0, - examples={ - "0.0-0.2": "长期信号,反应窗口 > 3 月", - "0.2-0.5": "中期信号,反应窗口 1-3 月", - "0.5-0.8": "短期信号,反应窗口 1 周 - 1 月", - "0.8-1.0": "超短期信号,反应窗口 < 1 周(需立即行动)" - }, - visualization_color="#a855f7" # 紫色 - ), - }, - - scoring_guide=""" - ### ISQ 评分指导 (Investment Signal Quality) - - ISQ 框架用于多维度评估投资信号的质量。每个信号由 5 个维度组成: - - 1. **情绪 (Sentiment)**: -1.0 到 1.0,表示看空(-)/中性(0)/看多(+) - 2. **确定性 (Confidence)**: 0.0 到 1.0,数值越高越确定 - 3. **强度 (Intensity)**: 1 到 5,数值越高影响越大 - 4. **预期差 (Expectation Gap)**: 0.0 到 1.0,市场预期与现实的差距 - 5. **时效性 (Timeliness)**: 0.0 到 1.0,反应窗口的紧迫程度 - - ### 综合评分算法 - - 综合评分 = 确定性 × 0.35 + 强度/5 × 0.30 + 预期差 × 0.20 + 时效性 × 0.15 - - 范围: 0.0 到 1.0 - - 0.0-0.3: 信号质量较差,不建议跟进 - - 0.3-0.6: 信号质量一般,可作参考 - - 0.6-0.8: 信号质量良好,值得跟进 - - 0.8-1.0: 信号质量优异,强烈推荐 - - ### 评分时的注意事项 - - - **不要混淆方向和强度**:情绪可以是看空,但确定性和强度仍可能很高 - - **预期差往往是 Alpha 来源**:高预期差 + 高确定性 = 最佳博弈机会 - - **考虑时间成本**:长期信号需要更高的确定性才值得跟进 - - **数据为王**:所有评分必须有具体数据支撑 - """, - - applicable_scenarios=[ - "上市公司基本面变化分析", - "产业政策与监管事件评估", - "地缘政治与宏观经济影响", - "技术进步与产业升级", - "突发事件与应急响应" - ], - - aggregation_method="weighted_average", - dimension_weights={ - "confidence": 0.35, - "intensity": 0.30, - "expectation_gap": 0.20, - "timeliness": 0.15 - } -) - - -# ===================================================== -# ISQ Template 管理系统 -# ===================================================== - -class ISQTemplateManager: - """ISQ Template 管理器""" - - def __init__(self): - self.templates: Dict[str, ISQTemplate] = { - DEFAULT_ISQ_TEMPLATE.template_id: DEFAULT_ISQ_TEMPLATE - } - - def register_template(self, template: ISQTemplate) -> None: - """注册新的 template""" - self.templates[template.template_id] = template - - def register_template_dict(self, template_dict: Dict[str, Any]) -> ISQTemplate: - """从 dict 注册模板,返回实例。""" - tpl = ISQTemplate(**template_dict) - self.register_template(tpl) - return tpl - - def get_template(self, template_id: str) -> ISQTemplate: - """获取指定 template""" - if template_id not in self.templates: - return DEFAULT_ISQ_TEMPLATE - return self.templates[template_id] - - def list_templates(self) -> List[Dict[str, str]]: - """列出所有可用 template""" - return [ - { - "id": t.template_id, - "name": t.template_name, - "description": t.description, - "dimensions": list(t.dimensions.keys()) - } - for t in self.templates.values() - ] - - def get_dimension(self, template_id: str, dimension_key: str) -> ISQDimensionSpec: - """获取指定 template 的某个维度定义""" - template = self.get_template(template_id) - return template.dimensions.get(dimension_key) - - def get_scoring_prompt(self, template_id: str) -> str: - """获取用于 LLM 的评分 prompt""" - template = self.get_template(template_id) - - dimensions_desc = "\n".join([ - f"- **{d.name} ({d.key})**\n" - f" 范围: {d.range_type}\n" - f" 说明: {d.description}\n" - f" 示例: {', '.join(f'{k}={v}' for k, v in 
list(d.examples.items())[:3])}" - for d in template.dimensions.values() - ]) - - return f""" -### ISQ 评估指导 ({template.template_name}) - -使用以下 {len(template.dimensions)} 个维度评估信号质量: - -{dimensions_desc} - -### 评分标准 -{template.scoring_guide} - -### 输出格式 (JSON) -请输出以下 JSON 格式的评分结果: -{{ - "sentiment": <-1.0 到 1.0 的浮点数>, - "confidence": <0.0 到 1.0 的浮点数>, - "intensity": <1 到 5 的整数>, - "expectation_gap": <0.0 到 1.0 的浮点数>, - "timeliness": <0.0 到 1.0 的浮点数>, - "rationale": {{ - "sentiment": "评分理由", - "confidence": "评分理由", - "intensity": "评分理由", - "expectation_gap": "评分理由", - "timeliness": "评分理由" - }} -}} -""" - - -# 全局 template 管理器实例 -isq_template_manager = ISQTemplateManager() - - -# ===================================================== -# 配置加载 -# ===================================================== - -def load_templates_from_config(config_path: Optional[str] = None) -> None: - """从配置目录加载所有 JSON 模板文件,未找到则跳过,不影响默认模板。 - 支持单个 JSON 文件或目录(目录下的所有 .json 文件)。 - """ - if config_path: - path = Path(config_path) - else: - # 默认目录:config/isq_templates/ - # __file__ = src/schema/isq_template.py - # parent = src/schema, parent.parent = src, parent.parent.parent = 项目根目录 - path = Path(__file__).resolve().parent.parent.parent / "config" - - if not path.exists(): - return - - # 如果是目录,扫描所有 .json 文件 - if path.is_dir(): - json_files = list(path.glob("*.json")) - else: - json_files = [path] - - for json_file in json_files: - try: - data = json.loads(json_file.read_text(encoding="utf-8")) - - # 如果是单个模板对象,转为列表 - if isinstance(data, dict): - templates = [data] - elif isinstance(data, list): - templates = data - else: - continue - - # 注册所有模板 - for tpl_dict in templates: - if not isinstance(tpl_dict, dict): - continue - try: - isq_template_manager.register_template_dict(tpl_dict) - except Exception: - # 忽略单个模板的加载错误,继续其他模板 - continue - except Exception: - # JSON 解析失败,跳过该文件 - continue - - -# 在模块加载时自动尝试加载配置模板 -load_templates_from_config() - - -# ===================================================== -# 便利函数 -# ===================================================== - -def get_isq_template(template_id: str = "default_isq_v1") -> ISQTemplate: - """获取 ISQ template""" - return isq_template_manager.get_template(template_id) - - -def get_isq_scoring_prompt(template_id: str = "default_isq_v1") -> str: - """获取用于 LLM 的 ISQ 评分 prompt""" - return isq_template_manager.get_scoring_prompt(template_id) - - -def calculate_isq_overall_score(scores: Dict[str, float], template_id: str = "default_isq_v1") -> float: - """计算 ISQ 综合评分""" - template = get_isq_template(template_id) - - overall = 0.0 - for dim_key, weight in template.dimension_weights.items(): - if dim_key in scores: - score = scores[dim_key] - # 处理强度维度的特殊缩放 (1-5 -> 0-1) - if dim_key == "intensity": - score = score / 5.0 - overall += score * weight - - return min(1.0, max(0.0, overall)) # 限制在 0-1 之间 diff --git a/skills/alphaear-signal-tracker/scripts/schema/models.py b/skills/alphaear-signal-tracker/scripts/schema/models.py deleted file mode 100644 index 422ca9c..0000000 --- a/skills/alphaear-signal-tracker/scripts/schema/models.py +++ /dev/null @@ -1,100 +0,0 @@ -from pydantic import BaseModel, Field -from typing import List, Optional, Dict, Any -from datetime import datetime - -class TransmissionNode(BaseModel): - node_name: str = Field(..., description="产业链节点名称") - impact_type: str = Field(..., description="利好/利空/中性") - logic: str = Field(..., description="该节点的传导逻辑") - -class IntentAnalysis(BaseModel): - keywords: List[str] = Field(..., description="核心实体、事件或概念关键词") - search_queries: List[str] = Field(..., description="优化后的搜索引擎查询词") - is_specific_event: bool = 
Field(..., description="是否查询特定突发事件") - time_range: str = Field(..., description="时间范围 (recent/all/specific_date)") - intent_summary: str = Field(..., description="一句话意图描述") - -class FilterResult(BaseModel): - """LLM 筛选结果 - 快速判断是否有有效信号""" - has_valid_signals: bool = Field(..., description="列表中是否包含有效的金融信号") - selected_ids: List[int] = Field(default_factory=list, description="筛选出的有效信号 ID 列表") - themes: List[str] = Field(default_factory=list, description="信号涉及的主题") - reason: Optional[str] = Field(default=None, description="如果无有效信号,说明原因") - -class InvestmentSignal(BaseModel): - # 核心元数据 - signal_id: str = Field(default="unknown_sig", description="唯一信号 ID") - title: str = Field(..., description="信号标题") - summary: str = Field(default="暂无摘要分析", description="100 字核心观点快报") - reasoning: str = Field(default="", description="详细的推演逻辑和理由") - - # 逻辑传导 (ISQ Key 1) - transmission_chain: List[TransmissionNode] = Field(default_factory=list, description="产业链传导逻辑链条") - - # 信号质量 (ISQ Key 2) - 来自 isq_template.DEFAULT_ISQ_TEMPLATE - # 参考: src/schema/isq_template.py 的 DEFAULT_ISQ_TEMPLATE 定义 - sentiment_score: float = Field(default=0.0, description="[ISQ] 情绪/走势 (-1.0=极度看空 ~ 0.0=中性 ~ 1.0=极度看多)") - confidence: float = Field(default=0.5, description="[ISQ] 确定性 (0.0=不可信 ~ 1.0=完全确定)") - intensity: int = Field(default=3, description="[ISQ] 强度/影响量级 (1=微弱 ~ 5=极强)") - expectation_gap: float = Field(default=0.5, description="[ISQ] 预期差/博弈空间 (0.0=充分定价 ~ 1.0=巨大预期差)") - timeliness: float = Field(default=0.8, description="[ISQ] 时效性 (0.0=长期 ~ 1.0=超短期)") - - # 预测与博弈 (ISQ Key 3) - expected_horizon: str = Field(default="T+N", description="预期的反应时窗 (如: T+0, T+3, Long-term)") - price_in_status: str = Field(default="未知", description="市场预期消化程度 (未定价/部分定价/充分定价)") - - # 关联实体 - impact_tickers: List[Dict[str, Any]] = Field(default_factory=list, description="受影响的代码列表及其权重") - industry_tags: List[str] = Field(default_factory=list, description="关联行业标签") - - # 溯源 - sources: List[Dict[str, str]] = Field(default_factory=list, description="来源详情 (包含 title, url, source_name)") - -class ResearchContext(BaseModel): - """研究员搜集的背景信息结构""" - raw_signal: str = Field(..., description="原始信号内容") - tickers_found: List[Dict[str, Any]] = Field(default_factory=list, description="找到的相关标的及其基本面/股价信息") - industry_background: str = Field(..., description="行业背景及产业链现状") - latest_developments: List[str] = Field(default_factory=list, description="相关事件的最新进展") - key_risks: List[str] = Field(default_factory=list, description="潜在风险点") - search_results_summary: str = Field(..., description="搜索结果的综合摘要") - -class ScanContext(BaseModel): - """扫描员搜集的原始数据结构""" - hot_topics: List[str] = Field(..., description="当前市场热点话题") - news_summaries: List[Dict[str, Any]] = Field(..., description="关键新闻摘要列表") - market_data: Dict[str, Any] = Field(default_factory=dict, description="相关的市场行情数据") - sentiment_overview: str = Field(..., description="整体市场情绪概览") - raw_data_summary: str = Field(..., description="原始数据的综合摘要") - -class SignalCluster(BaseModel): - theme_title: str = Field(..., description="主题名称") - signal_ids: List[int] = Field(..., description="包含的信号 ID 列表") - rationale: str = Field(..., description="聚类理由") - -class ClusterContext(BaseModel): - """信号聚类结果结构""" - clusters: List[SignalCluster] = Field(..., description="聚类列表") - -class KLinePoint(BaseModel): - date: str = Field(..., description="日期") - open: float = Field(..., description="开盘价") - high: float = Field(..., description="最高价") - low: float = Field(..., description="最低价") - close: float = Field(..., description="收盘价") - volume: float = 
Field(..., description="成交量") - -class ForecastResult(BaseModel): - ticker: str = Field(..., description="股票代码") - base_forecast: List[KLinePoint] = Field(default_factory=list, description="Kronos 模型原始预测") - adjusted_forecast: List[KLinePoint] = Field(default_factory=list, description="LLM 调整后的预测") - rationale: str = Field(default="", description="预测调整理由及逻辑说明") - timestamp: str = Field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S"), description="生成时间") - -class InvestmentReport(BaseModel): - overall_sentiment: str = Field(..., description="整体市场情绪评价") - market_entropy: float = Field(..., description="市场分歧度 (0-1, 1代表极高分歧)") - signals: List[InvestmentSignal] = Field(..., description="深度解析的投资信号列表") - forecasts: List[ForecastResult] = Field(default_factory=list, description="相关标的的预测结果") - timestamp: str = Field(..., description="报告生成时间") - meta_info: Optional[Dict[str, Any]] = Field(default_factory=dict, description="其他元数据") diff --git a/skills/alphaear-signal-tracker/scripts/tools/__init__.py b/skills/alphaear-signal-tracker/scripts/tools/__init__.py deleted file mode 100644 index 97fbb5d..0000000 --- a/skills/alphaear-signal-tracker/scripts/tools/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -# src/tools/__init__.py -""" -AlphaEar 工具包层 - Agno Toolkit 适配器 - -提供的 Toolkit 类: -- NewsToolkit: 热点新闻获取 -- StockToolkit: 股票搜索与价格查询 -- SentimentToolkit: 情绪分析 -- SearchToolkit: 网络搜索 -""" - -from .toolkits import ( - NewsToolkit, - StockToolkit, - SentimentToolkit, - SearchToolkit, -) - -__all__ = [ - "NewsToolkit", - "StockToolkit", - "SentimentToolkit", - "SearchToolkit", -] diff --git a/skills/alphaear-signal-tracker/scripts/tools/toolkits.py b/skills/alphaear-signal-tracker/scripts/tools/toolkits.py deleted file mode 100644 index ebd0b69..0000000 --- a/skills/alphaear-signal-tracker/scripts/tools/toolkits.py +++ /dev/null @@ -1,526 +0,0 @@ -""" -AlphaEar 工具包层 - Agno Toolkit 适配器 -复用 utils 中的底层工具实现,提供 Agno Agent 兼容的 Toolkit 接口 -""" -from datetime import datetime -from typing import Optional -from agno.tools import Toolkit -from loguru import logger - -from ..utils.database_manager import DatabaseManager -from ..utils.news_tools import NewsNowTools, PolymarketTools -from ..utils.stock_tools import StockTools -from ..utils.search_tools import SearchTools -from ..utils.sentiment_tools import SentimentTools - - -class NewsToolkit(Toolkit): - """ - 新闻工具包 - 包装 NewsNowTools 为 Agno Toolkit - - 提供热点新闻获取、内容提取等功能 - """ - - def __init__(self, db: DatabaseManager, **kwargs): - self._news_tools = NewsNowTools(db) - self._sources = self._news_tools.SOURCES - - tools = [ - self.fetch_hot_news, - self.fetch_news_content, - self.get_unified_trends, - self.enrich_news_content, - ] - super().__init__(name="news_toolkit", tools=tools, **kwargs) - - - def fetch_hot_news(self, source_id: str, count: int = 10) -> str: - """ - 从指定新闻源获取热点新闻列表。 - - Args: - source_id: 新闻源标识符。可选值按类别: - **金融类**: "cls" (财联社), "wallstreetcn" (华尔街见闻), "xueqiu" (雪球) - **综合类**: "weibo" (微博热搜), "zhihu" (知乎热榜), "baidu" (百度热搜), - "toutiao" (今日头条), "douyin" (抖音), "thepaper" (澎湃新闻) - **科技类**: "36kr" (36氪), "ithome" (IT之家), "v2ex", "juejin" (掘金), - "hackernews" (Hacker News) - 推荐金融分析使用 "cls", "wallstreetcn", "xueqiu"。 - count: 获取的新闻数量,默认 10 条。 - - Returns: - 热点新闻列表的文本描述,包含排名、标题和链接。如果源不可用则返回错误信息。 - """ - logger.info(f"🔧 [TOOL CALLED] fetch_hot_news(source_id={source_id}, count={count})") - - items = self._news_tools.fetch_hot_news(source_id, count=count, fetch_content=False) - - if not items: - return f"获取 {source_id} 热点失败" - - source_name = 
self._sources.get(source_id, source_id) - result = f"## {source_name} 热点 (获取时间: {datetime.now().strftime('%Y-%m-%d %H:%M')})\n\n" - - for item in items: - result += f"{item['rank']}. {item['title']}\n 链接: {item['url']}\n\n" - - logger.info(f"✅ [TOOL SUCCESS] Got {len(items)} news items from {source_id}") - return result - - def fetch_news_content(self, url: str) -> str: - """ - 使用 Jina Reader 抓取指定 URL 的网页正文内容。 - - Args: - url: 需要抓取内容的完整网页 URL,必须以 http:// 或 https:// 开头。 - - Returns: - 提取的网页正文内容,如果失败则返回错误信息。 - """ - content = self._news_tools.fetch_news_content(url) - if content: - return content[:5000] # 限制长度 - return "内容抓取失败" - - def get_unified_trends(self, sources: str = "wallstreetcn,cls") -> str: - """ - 获取多平台综合热点报告。 - - Args: - sources: 要扫描的新闻源,用逗号分隔。 - 可选值: weibo, zhihu, baidu, toutiao, wallstreetcn, cls - 默认: "wallstreetcn,cls" (金融资讯) - - Returns: - 格式化的热点汇总报告。 - """ - source_list = [s.strip() for s in sources.split(",")] - report = self._news_tools.get_unified_trends(source_list) - return report - - def enrich_news_content(self, source: str = None, limit: int = 5) -> str: - """ - 为数据库中缺少正文内容的新闻补充内容。 - - Args: - source: 筛选特定新闻源(如 "cls"),为空则处理所有。 - limit: 最多处理的新闻数量,默认 5 条。 - - Returns: - 处理结果的描述。 - """ - logger.info(f"🔧 [TOOL CALLED] enrich_news_content(source={source}, limit={limit})") - - # 获取需要补充内容的新闻 - news_items = self._news_tools.db.get_daily_news(source=source, limit=limit) - items_without_content = [n for n in news_items if not n.get('content')] - - if not items_without_content: - return "没有需要补充内容的新闻" - - updated_count = 0 - cursor = self._news_tools.db.conn.cursor() - - for item in items_without_content[:limit]: - url = item.get('url') - if url: - content = self._news_tools.fetch_news_content(url) - if content: - cursor.execute( - "UPDATE daily_news SET content = ? WHERE id = ?", - (content[:10000], item['id']) - ) - updated_count += 1 - - self._news_tools.db.conn.commit() - logger.info(f"✅ [TOOL SUCCESS] Enriched {updated_count} news items with content") - - return f"✅ 已为 {updated_count} 条新闻补充正文内容" - - -class PolymarketToolkit(Toolkit): - """ - Polymarket 预测市场工具包 - 获取热门预测市场数据 - - 预测市场数据可反映公众情绪、预期和关注度 - """ - - def __init__(self, db: DatabaseManager, **kwargs): - self._poly_tools = PolymarketTools(db) - - tools = [ - self.get_prediction_markets, - self.get_market_summary, - ] - super().__init__(name="polymarket_toolkit", tools=tools, **kwargs) - - def get_prediction_markets(self, limit: int = 20) -> str: - """ - 获取 Polymarket 活跃预测市场的关键数据。 - - 预测市场反映公众对重大事件的概率预期,可用于: - - 分析市场情绪和风险偏好 - - 了解热门话题的关注度 - - 获取重大事件的概率预期 - - Args: - limit: 获取的市场数量,默认 20 个。 - - Returns: - 预测市场数据列表,包含问题、结果概率和交易量。 - 如果获取失败返回错误信息。 - """ - logger.info(f"🔧 [TOOL CALLED] get_prediction_markets(limit={limit})") - - markets = self._poly_tools.get_active_markets(limit) - if not markets: - return "❌ 无法获取 Polymarket 数据(可能是网络问题)" - - result = f"## 🔮 Polymarket 热门预测 (共 {len(markets)} 个)\n\n" - for i, m in enumerate(markets[:limit], 1): - question = m.get("question", "Unknown") - prices = m.get("outcomePrices", []) - volume = m.get("volume", 0) - - result += f"{i}. 
**{question}**\n" - if prices: - result += f" 概率: {prices}\n" - if volume: - try: - result += f" 交易量: ${float(volume):,.0f}\n" - except: - result += f" 交易量: {volume}\n" - result += "\n" - - logger.info(f"✅ [TOOL SUCCESS] Got {len(markets)} prediction markets") - return result - - def get_market_summary(self, limit: int = 10) -> str: - """ - 获取预测市场摘要报告,了解当前热门话题和公众预期。 - - Args: - limit: 获取的市场数量,默认 10 个。 - - Returns: - 格式化的预测市场报告。 - """ - return self._poly_tools.get_market_summary(limit) - - -class StockToolkit(Toolkit): - - """ - 股票工具包 - 包装 StockTools 为 Agno Toolkit - - 提供股票搜索、价格查询等功能 - """ - - def __init__(self, db: DatabaseManager, **kwargs): - self._stock_tools = StockTools(db) - - tools = [ - self.search_ticker, - self.get_stock_price, - ] - super().__init__(name="stock_toolkit", tools=tools, **kwargs) - - def search_ticker(self, query: str) -> str: - """ - 模糊搜索 A 股股票代码或名称。 - - Args: - query: 搜索关键词,可以是股票代码(如 "600519")或名称关键词(如 "茅台"、"宁德"、"比亚迪")。 - - Returns: - 匹配的股票列表,包含代码和名称。 - """ - q = (query or "").strip() - # Guardrails: prevent overly generic queries that tend to return arbitrary "...股份" matches. - generic_terms = { - "股份", - "有限公司", - "概念股", - "受益股", - "龙头", - "标的", - "相关股票", - "合作概念股", - } - if not q: - return "查询为空,无法搜索股票" - if q in generic_terms: - return f"查询过于泛化({q}),为避免误匹配已拒绝。请提供更具体的公司名或6位代码。" - # If it's not a numeric code, require at least 2 non-space chars. - if not any(ch.isdigit() for ch in q) and len(q.replace(" ", "")) < 2: - return "查询过短,无法搜索股票。请提供更具体的公司名或6位代码。" - - results = self._stock_tools.search_ticker(query) - - if not results: - return f"未找到匹配 '{query}' 的股票" - - output = f"## 股票搜索结果 (关键词: {query})\n\n" - for r in results: - output += f"- {r['code']} - {r['name']}\n" - return output - - def get_stock_price(self, ticker: str, days: int = 30) -> str: - """ - 获取指定股票的近期价格走势。 - - Args: - ticker: 股票代码,如 "600519"(贵州茅台)或 "000001"(平安银行)。 - days: 查询天数,默认 30 天。 - - Returns: - 价格走势的文本摘要。 - """ - from datetime import timedelta - end_date = datetime.now().strftime('%Y-%m-%d') - start_date = (datetime.now() - timedelta(days=days)).strftime('%Y-%m-%d') - - df = self._stock_tools.get_stock_price(ticker, start_date, end_date) - - if df.empty: - return f"未能获取 {ticker} 的股价数据" - - - latest = df.iloc[-1] - change = ((latest['close'] - df.iloc[0]['close']) / df.iloc[0]['close']) * 100 - - # 格式化历史数据供 LLM 分析 (取最近 15 天) - history_df = df.tail(15).copy() - history_df['date'] = history_df['date'].astype(str) - # 简化列名以节省 token - history_cols = ['date', 'open', 'close', 'high', 'low', 'volume'] - - # 尝试使用 markdown 格式,如果失败退回到 string - try: - history_str = history_df[history_cols].to_markdown(index=False, numalign="left", stralign="left") - except ImportError: - history_str = history_df[history_cols].to_string(index=False) - except Exception: - history_str = history_df[history_cols].to_string(index=False) - - return f"""## {ticker} 价格走势 ({days}天) -- 当前价: ¥{latest['close']:.2f} -- 期间涨跌: {change:+.2f}% -- 最高/最低: ¥{df['high'].max():.2f} / ¥{df['low'].min():.2f} -- 数据范围: {df.iloc[0]['date']} -> {latest['date']} - -### 最近 15 个交易日详细数据 (OHLCV): -{history_str} -""" - - - -class SentimentToolkit(Toolkit): - """ - 情绪分析工具包 - 包装 SentimentTools 为 Agno Toolkit - - 提供文本情绪分析功能(支持 BERT 和 LLM 模式) - """ - - def __init__(self, db: DatabaseManager, mode: str = "auto", **kwargs): - self._sentiment_tools = SentimentTools(db, mode=mode) - self._db = db - - tools = [ - self.analyze_sentiment, - self.batch_update_sentiment, - ] - super().__init__(name="sentiment_toolkit", tools=tools, **kwargs) - - def 
analyze_sentiment(self, text: str) -> str: - """ - 分析文本的情绪极性。 - - Args: - text: 需要分析的文本内容,如新闻标题或摘要。 - - Returns: - 情绪分析结果,包含分值(-1.0到1.0)和标签(positive/negative/neutral)。 - """ - result = self._sentiment_tools.analyze_sentiment(text) - - score = result.get('score', 0.0) - label = result.get('label', 'neutral') - reason = result.get('reason', '') - - return f"""情绪分析结果: -- 文本: {text[:100]}{'...' if len(text) > 100 else ''} -- 分值: {score:.2f} -- 标签: {label} -- 分析: {reason}""" - - def batch_update_sentiment(self, source: str = None, limit: int = 20) -> str: - """ - 批量更新数据库中新闻的情绪分数。 - - Args: - source: 筛选特定新闻源(如 "cls", "wallstreetcn"),为空则处理所有。 - limit: 最多处理的新闻数量,默认 20 条。 - - Returns: - 更新结果的描述。 - """ - logger.info(f"🔧 [TOOL CALLED] batch_update_sentiment(source={source}, limit={limit})") - - count = self._sentiment_tools.batch_update_news_sentiment(source=source, limit=limit) - - return f"✅ 已更新 {count} 条新闻的情绪分数" - - - -class SearchToolkit(Toolkit): - """ - 搜索工具包 - 包装 SearchTools 为 Agno Toolkit - - 提供网络搜索功能(支持 Jina、DuckDuckGo 和百度) - - 当环境变量 JINA_API_KEY 设置时,默认使用 Jina Search, - 提供 LLM 友好的搜索结果。 - """ - - def __init__(self, db: DatabaseManager, **kwargs): - self._search_tools = SearchTools(db) - - tools = [ - self.web_search, - self.aggregate_search, - ] - super().__init__(name="search_toolkit", tools=tools, **kwargs) - - def web_search(self, query: str, engine: str = None, max_results: int = 5) -> str: - """ - 使用指定搜索引擎执行网络搜索。 - - Args: - query: 搜索关键词,如 "英伟达财报" 或 "光伏行业政策"。 - engine: 搜索引擎选择。可选值: - "jina" (Jina Search,需配置 JINA_API_KEY,LLM友好输出), - "ddg" (DuckDuckGo,推荐英文/国际搜索), - "baidu" (百度,推荐中文/国内搜索)。 - 默认: 若配置了 JINA_API_KEY 则使用 "jina",否则 "ddg"。 - max_results: 返回结果数量。默认 5。 - - Returns: - 搜索结果的文本描述。 - """ - return self._search_tools.search(query, engine=engine, max_results=max_results) - - def aggregate_search(self, query: str, max_results: int = 5) -> str: - """ - 同时使用多个搜索引擎搜索并聚合结果。 - - Args: - query: 搜索关键词。 - max_results: 每个引擎返回的最大结果数。默认 5。 - - Returns: - 聚合后的搜索结果。 - """ - return self._search_tools.aggregate_search(query, max_results=max_results) - - -class ContextSearchToolkit(Toolkit): - """ - 上下文搜索工具包 - 用于 RAG 场景的文档片段检索 - - 支持在内存中存储文档片段,并通过关键词搜索相关内容。 - 适用于 ReportAgent 的分段编辑场景。 - """ - - def __init__(self, **kwargs): - self._store = {} # {doc_id: {"title": str, "content": str, "summary": str}} - - tools = [ - self.search_context, - self.get_toc, - ] - super().__init__(name="context_search_toolkit", tools=tools, **kwargs) - - def add_document(self, doc_id: str, title: str, content: str, summary: str = ""): - """添加文档到存储(供外部调用,非 LLM 工具)""" - self._store[doc_id] = { - "title": title, - "content": content, - "summary": summary or content[:200] + "..." 
- } - logger.info(f"📄 Added document to context store: {doc_id} - {title[:30]}...") - - def clear(self): - """清空文档存储""" - self._store.clear() - logger.info("🗑️ Context store cleared") - - def search_context(self, query: str, max_results: int = 3) -> str: - """ - 在已存储的文档中搜索与查询相关的内容片段。 - - Args: - query: 搜索关键词,如 "消费板块" 或 "茅台 预测"。 - max_results: 返回的最大结果数,默认 3。 - - Returns: - 匹配的文档片段,按相关性排序。 - """ - logger.info(f"🔍 [TOOL CALLED] search_context(query={query}, max_results={max_results})") - - if not self._store: - return "⚠️ 上下文存储为空,无可搜索内容。" - - # 简单的关键词匹配 + 计分 - query_terms = query.lower().split() - results = [] - - for doc_id, doc in self._store.items(): - score = 0 - content_lower = doc["content"].lower() - title_lower = doc["title"].lower() - - for term in query_terms: - # 标题匹配权重更高 - if term in title_lower: - score += 3 - if term in content_lower: - score += content_lower.count(term) - - if score > 0: - results.append((score, doc_id, doc)) - - # 按分数排序 - results.sort(key=lambda x: x[0], reverse=True) - results = results[:max_results] - - if not results: - return f"未找到与 '{query}' 相关的内容。" - - output = f"## 搜索结果 (查询: {query})\n\n" - for score, doc_id, doc in results: - output += f"### [{doc_id}] {doc['title']}\n" - # 返回摘要而非全文,节省 token - output += f"{doc['summary']}\n\n" - - logger.info(f"✅ [TOOL SUCCESS] Found {len(results)} matching documents") - return output - - def get_toc(self) -> str: - """ - 获取当前存储的所有文档的目录(TOC)。 - - Returns: - 文档目录列表,包含 ID 和标题。 - """ - logger.info("🔍 [TOOL CALLED] get_toc()") - - if not self._store: - return "⚠️ 上下文存储为空。" - - output = "## 文档目录 (TOC)\n\n" - for doc_id, doc in self._store.items(): - output += f"- **[{doc_id}]** {doc['title']}\n" - - return output - diff --git a/skills/alphaear-signal-tracker/scripts/utils/__init__.py b/skills/alphaear-signal-tracker/scripts/utils/__init__.py deleted file mode 100644 index 27e1961..0000000 --- a/skills/alphaear-signal-tracker/scripts/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# AlphaEar utils package diff --git a/skills/alphaear-signal-tracker/scripts/utils/content_extractor.py b/skills/alphaear-signal-tracker/scripts/utils/content_extractor.py deleted file mode 100644 index 133207a..0000000 --- a/skills/alphaear-signal-tracker/scripts/utils/content_extractor.py +++ /dev/null @@ -1,122 +0,0 @@ -import requests -from requests.exceptions import RequestException, Timeout, ConnectionError -import os -import time -import json -import threading -from typing import Optional -from loguru import logger - - -class ContentExtractor: - """内容提取工具 - 主要接入 Jina Reader API""" - - JINA_BASE_URL = "https://r.jina.ai/" - - # 速率限制配置 (无 API Key 时:20 次/分钟) - _rate_limit_no_key = 20 # 每分钟最大请求数 - _rate_window = 60.0 # 时间窗口(秒) - _min_interval = 3.0 # 请求最小间隔(秒) - - # 类级别的速率限制状态 - _request_times = [] - _last_request_time = 0.0 - _lock = threading.Lock() - - @classmethod - def _wait_for_rate_limit(cls, has_api_key: bool) -> None: - """等待以满足速率限制要求""" - if has_api_key: - # 有 API Key 时,只需保持最小间隔 - time.sleep(0.5) - return - - with cls._lock: - current_time = time.time() - - # 1. 清理过期的请求记录 - cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window] - - # 2. 
检查是否达到速率限制 - if len(cls._request_times) >= cls._rate_limit_no_key: - # 需要等待最旧的请求过期 - oldest = cls._request_times[0] - wait_time = cls._rate_window - (current_time - oldest) + 1.0 - if wait_time > 0: - logger.warning(f"⏳ Jina rate limit reached, waiting {wait_time:.1f}s...") - time.sleep(wait_time) - current_time = time.time() - cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window] - - # 3. 确保请求间隔不太快 - time_since_last = current_time - cls._last_request_time - if time_since_last < cls._min_interval: - sleep_time = cls._min_interval - time_since_last - time.sleep(sleep_time) - - # 4. 记录本次请求 - cls._request_times.append(time.time()) - cls._last_request_time = time.time() - - @classmethod - def extract_with_jina(cls, url: str, timeout: int = 30) -> Optional[str]: - """ - 使用 Jina Reader 提取网页正文内容 (Markdown 格式) - - 无 API Key 时自动限速:每分钟最多 20 次请求,每次间隔至少 3 秒 - """ - if not url or not url.startswith("http"): - return None - - logger.info(f"🕸️ Extracting content from: {url} via Jina...") - - headers = { - "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", - "Accept": "application/json" - } - - # 使用统一的 JINA_API_KEY - api_key = os.getenv("JINA_API_KEY") - has_api_key = bool(api_key and api_key.strip()) - - if has_api_key: - headers["Authorization"] = f"Bearer {api_key}" - - # 等待速率限制 - cls._wait_for_rate_limit(has_api_key) - - try: - # Jina Reader API - full_url = f"{cls.JINA_BASE_URL}{url}" - response = requests.get(full_url, headers=headers, timeout=timeout) - - if response.status_code == 200: - try: - data = response.json() - # Jina JSON 响应格式通常在 data.content - if isinstance(data, dict) and "data" in data: - return data["data"].get("content", "") - return data.get("content", response.text) - except (json.JSONDecodeError, TypeError): - return response.text - elif response.status_code == 429: - # 触发速率限制,等待后重试一次 - logger.warning(f"⚠️ Jina rate limit (429), waiting 60s before retry...") - time.sleep(60) - return cls.extract_with_jina(url, timeout) - else: - logger.warning(f"Jina extraction failed (Status {response.status_code}) for {url}") - return None - - except Timeout: - logger.error(f"Timeout during Jina extraction for {url}") - return None - except ConnectionError: - logger.error(f"Connection error during Jina extraction for {url}") - return None - except RequestException as e: - logger.error(f"Request error during Jina extraction: {e}") - return None - except Exception as e: - logger.error(f"Unexpected error during Jina extraction: {e}") - return None diff --git a/skills/alphaear-signal-tracker/scripts/utils/database_manager.py b/skills/alphaear-signal-tracker/scripts/utils/database_manager.py deleted file mode 100644 index cfc362b..0000000 --- a/skills/alphaear-signal-tracker/scripts/utils/database_manager.py +++ /dev/null @@ -1,581 +0,0 @@ -import sqlite3 -import json -from datetime import datetime, date -from pathlib import Path -from typing import List, Dict, Optional, Any, Union -import pandas as pd -from loguru import logger - -class DatabaseManager: - """ - AlphaEar 数据库管理器 - 负责存储热点数据、搜索缓存和股价数据 - 使用 SQLite 进行持久化存储 - """ - - def __init__(self, db_path: str = "data/signal_flux.db"): - self.db_path = Path(db_path) - self.db_path.parent.mkdir(parents=True, exist_ok=True) - self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False) - self.conn.row_factory = sqlite3.Row - self._init_db() - logger.info(f"💾 Database initialized at {self.db_path}") - - def _init_db(self): 
- """初始化表结构""" - cursor = self.conn.cursor() - - # 1. 每日热点新闻表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS daily_news ( - id TEXT PRIMARY KEY, - source TEXT, - rank INTEGER, - title TEXT, - url TEXT, - content TEXT, - publish_time TEXT, - crawl_time TEXT, - sentiment_score REAL, - analysis TEXT, - meta_data TEXT - ) - """) - - # 尝试添加 analysis 列(如果表已存在但没有该列) - try: - cursor.execute("ALTER TABLE daily_news ADD COLUMN analysis TEXT") - except: - pass # 列已存在 - - - # 2. 搜索缓存表 (原有 JSON 缓存) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS search_cache ( - query_hash TEXT PRIMARY KEY, - query TEXT, - engine TEXT, - results TEXT, - timestamp TEXT - ) - """) - - # 2.5 搜索详情表 (展开的搜索结果) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS search_detail ( - id TEXT, - query_hash TEXT, - rank INTEGER, - title TEXT, - url TEXT, - content TEXT, - publish_time TEXT, - crawl_time TEXT, - sentiment_score REAL, - source TEXT, - meta_data TEXT, - PRIMARY KEY (query_hash, id) - ) - """) - - # 3. 股价数据表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS stock_prices ( - ticker TEXT, - date TEXT, - open REAL, - close REAL, - high REAL, - low REAL, - volume REAL, - change_pct REAL, - PRIMARY KEY (ticker, date) - ) - """) - - # 4. 股票列表表 (用于检索) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS stock_list ( - code TEXT PRIMARY KEY, - name TEXT - ) - """) - - # 5. 投资信号表 (ISQ Framework) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS signals ( - signal_id TEXT PRIMARY KEY, - title TEXT, - summary TEXT, - transmission_chain TEXT, - sentiment_score REAL, - confidence REAL, - intensity INTEGER, - expected_horizon TEXT, - price_in_status TEXT, - impact_tickers TEXT, - industry_tags TEXT, - sources TEXT, - user_id TEXT, - created_at TEXT - ) - """) - - - - # 6. 创建索引以优化查询性能 - cursor.execute("CREATE INDEX IF NOT EXISTS idx_news_crawl_time ON daily_news(crawl_time)") - cursor.execute("CREATE INDEX IF NOT EXISTS idx_news_source ON daily_news(source)") - cursor.execute("CREATE INDEX IF NOT EXISTS idx_search_cache_timestamp ON search_cache(timestamp)") - cursor.execute("CREATE INDEX IF NOT EXISTS idx_stock_prices_ticker_date ON stock_prices(ticker, date)") - # 尝试添加 user_id 列到 signals 表 - try: - cursor.execute("ALTER TABLE signals ADD COLUMN user_id TEXT") - except: - pass - - cursor.execute("CREATE INDEX IF NOT EXISTS idx_signals_user_id ON signals(user_id)") - - self.conn.commit() - - # - # self.conn.commit() - - - # --- 新闻数据操作 --- - - def save_daily_news(self, news_list: List[Dict]) -> int: - """保存热点新闻,包含发布时间与抓取时间""" - cursor = self.conn.cursor() - count = 0 - crawl_time = datetime.now().isoformat() - - for news in news_list: - try: - # 兼容不同来源的 ID 生成逻辑 - news_id = news.get('id') or f"{news.get('source')}_{news.get('rank')}_{crawl_time[:10]}" - cursor.execute(""" - INSERT OR REPLACE INTO daily_news - (id, source, rank, title, url, content, publish_time, crawl_time, sentiment_score, meta_data) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- """, ( - news_id, - news.get('source'), - news.get('rank'), - news.get('title'), - news.get('url'), - news.get('content', ''), - news.get('publish_time'), # 新增支持发布时间 - crawl_time, - news.get('sentiment_score'), - json.dumps(news.get('meta_data', {})) - )) - count += 1 - except sqlite3.Error as e: - logger.error(f"Database error saving news item {news.get('title')}: {e}") - except Exception as e: - logger.error(f"Unexpected error saving news item {news.get('title')}: {e}") - - self.conn.commit() - return count - - def get_daily_news(self, source: Optional[str] = None, limit: int = 100, days: int = 1) -> List[Dict]: - """获取最近 N 天的热点新闻""" - cursor = self.conn.cursor() - # 使用 crawl_time 过滤,保证结果的新鲜度 - time_threshold = (datetime.now().timestamp() - days * 86400) - time_threshold_str = datetime.fromtimestamp(time_threshold).isoformat() - - query = "SELECT * FROM daily_news WHERE crawl_time >= ?" - params = [time_threshold_str] - - if source: - query += " AND source = ?" - params.append(source) - - query += " ORDER BY crawl_time DESC, rank LIMIT ?" - params.append(limit) - - cursor.execute(query, params) - return [dict(row) for row in cursor.fetchall()] - - def lookup_reference_by_url(self, url: str) -> Optional[Dict[str, Any]]: - """Best-effort lookup of a source item by URL. - - This is used to render a stable bibliography from DB-backed metadata. - It searches both `daily_news` and `search_detail`. - """ - url = (url or "").strip() - if not url: - return None - - cursor = self.conn.cursor() - - try: - cursor.execute( - """ - SELECT title, source, publish_time, crawl_time, url - FROM daily_news - WHERE url = ? - ORDER BY crawl_time DESC - LIMIT 1 - """, - (url,), - ) - row = cursor.fetchone() - if row: - return dict(row) - except Exception: - pass - - try: - cursor.execute( - """ - SELECT title, source, publish_time, crawl_time, url - FROM search_detail - WHERE url = ? - ORDER BY crawl_time DESC - LIMIT 1 - """, - (url,), - ) - row = cursor.fetchone() - if row: - return dict(row) - except Exception: - pass - - return None - - def delete_news(self, news_id: str) -> bool: - """删除特定新闻""" - cursor = self.conn.cursor() - cursor.execute("DELETE FROM daily_news WHERE id = ?", (news_id,)) - self.conn.commit() - return cursor.rowcount > 0 - - def update_news_content(self, news_id: str, content: str = None, analysis: str = None) -> bool: - """更新新闻的内容或分析结果""" - cursor = self.conn.cursor() - updates = [] - params = [] - - if content is not None: - updates.append("content = ?") - params.append(content) - if analysis is not None: - updates.append("analysis = ?") - params.append(analysis) - - if not updates: - return False - - params.append(news_id) - query = f"UPDATE daily_news SET {', '.join(updates)} WHERE id = ?" - cursor.execute(query, params) - self.conn.commit() - return cursor.rowcount > 0 - - # --- 搜索缓存辅助 --- - - def get_search_cache(self, query_hash: str, ttl_seconds: Optional[int] = None) -> Optional[Dict]: - """获取搜索缓存 (优先查 search_detail)""" - cursor = self.conn.cursor() - - # 1. 尝试从 search_detail 获取展开的结构化数据 - cursor.execute(""" - SELECT * FROM search_detail - WHERE query_hash = ? - ORDER BY rank - """, (query_hash,)) - details = [dict(row) for row in cursor.fetchall()] - - if details: - # 检查 TTL (取第一条的时间) - first_time = datetime.fromisoformat(details[0]['crawl_time']) - if ttl_seconds and (datetime.now() - first_time).total_seconds() > ttl_seconds: - logger.info(f"⌛ Detailed cache expired for hash {query_hash}") - pass # Expired, fall through or return None? 
If Detail expired, Cache likely expired too. - # But let's check basic cache just in case metadata differs? - # Actually if details exist, we prefer them. If expired, we return None. - return None - - logger.info(f"✅ Hit detailed search cache for {query_hash} ({len(details)} items)") - # Reconstruct the expected 'results' list format for SearchTools - # SearchTools expects a list of dicts. - # We return a dict wrapper to match get_search_cache signature returning Dict usually containing 'results' string. - # But SearchTools logic: - # cache = db.get_search_cache(...) - # cached_data = json.loads(cache['results']) - - # To minimize SearchTools changes, we can return a dict mimicking the old structure - # OR Change SearchTools to handle list return. - # Let's return a special dict that SearchTools can recognize or just format it as before. - return {"results": json.dumps(details), "timestamp": details[0]['crawl_time']} - - # 2. Fallback to old table - cursor.execute("SELECT * FROM search_cache WHERE query_hash = ?", (query_hash,)) - row = cursor.fetchone() - - if not row: - return None - - row_dict = dict(row) - if ttl_seconds: - cache_time = datetime.fromisoformat(row_dict['timestamp']) - if (datetime.now() - cache_time).total_seconds() > ttl_seconds: - logger.info(f"⌛ Cache expired for hash {query_hash}") - return None - - return row_dict - - def save_search_cache(self, query_hash: str, query: str, engine: str, results: Union[str, List[Dict]]): - """保存搜索结果 (同时保存到 search_cache 和 search_detail)""" - cursor = self.conn.cursor() - current_time = datetime.now().isoformat() - - results_str = results if isinstance(results, str) else json.dumps(results) - - # 1. Save summary to search_cache - cursor.execute(""" - INSERT OR REPLACE INTO search_cache (query_hash, query, engine, results, timestamp) - VALUES (?, ?, ?, ?, ?) - """, (query_hash, query, engine, results_str, current_time)) - - # 2. Save details to search_detail if results is a list - if isinstance(results, list): - for item in results: - try: - item_id = item.get('id') or f"{hash(item.get('url', ''))}" - cursor.execute(""" - INSERT OR REPLACE INTO search_detail - (id, query_hash, rank, title, url, content, publish_time, crawl_time, sentiment_score, source, meta_data) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - str(item_id), - query_hash, - item.get('rank', 0), - item.get('title'), - item.get('url'), - item.get('content', ''), - item.get('publish_time'), - item.get('crawl_time') or current_time, - item.get('sentiment_score'), - item.get('source'), - json.dumps(item.get('meta_data', {})) - )) - except sqlite3.Error as e: - logger.error(f"Database error saving search detail {item.get('title')}: {e}") - except Exception as e: - logger.error(f"Unexpected error saving search detail {item.get('title')}: {e}") - - self.conn.commit() - - def find_similar_queries(self, query: str, limit: int = 5) -> List[Dict]: - """模糊搜索相似的已缓存查询""" - cursor = self.conn.cursor() - - # Simple fuzzy match: query in cached OR cached in query - q_wild = f"%{query}%" - cursor.execute(""" - SELECT query, query_hash, timestamp, results - FROM search_cache - WHERE query LIKE ? OR ? LIKE ('%' || query || '%') - ORDER BY timestamp DESC - LIMIT ? 
- """, (q_wild, query, limit)) - - return [dict(row) for row in cursor.fetchall()] - - def search_local_news(self, query: str, limit: int = 5) -> List[Dict]: - """从本地 daily_news 搜索相关新闻""" - cursor = self.conn.cursor() - q_wild = f"%{query}%" - # Search title and content - cursor.execute(""" - SELECT * FROM daily_news - WHERE title LIKE ? OR content LIKE ? - ORDER BY crawl_time DESC - LIMIT ? - """, (q_wild, q_wild, limit)) - return [dict(row) for row in cursor.fetchall()] - - # --- 股票数据操作 --- - - def save_stock_list(self, df: pd.DataFrame): - """保存股票列表到 stock_list 表""" - cursor = self.conn.cursor() - try: - # 清空旧表 - cursor.execute("DELETE FROM stock_list") - - # 批量插入 - data = df[['code', 'name']].to_dict('records') - cursor.executemany( - "INSERT INTO stock_list (code, name) VALUES (:code, :name)", - data - ) - self.conn.commit() - except sqlite3.Error as e: - logger.error(f"Database error saving stock list: {e}") - except Exception as e: - logger.error(f"Unexpected error saving stock list: {e}") - - def search_stock(self, query: str, limit: int = 5) -> List[Dict]: - """模糊搜索股票代码或名称""" - cursor = self.conn.cursor() - wild = f"%{query}%" - cursor.execute(""" - SELECT code, name FROM stock_list - WHERE code LIKE ? OR name LIKE ? - LIMIT ? - """, (wild, wild, limit)) - return [dict(row) for row in cursor.fetchall()] - - def get_stock_by_code(self, code: str) -> Optional[Dict[str, str]]: - """精确按代码获取股票信息。 - - Args: - code: 股票代码(A股6位 / 港股5位),必须为纯数字字符串。 - - Returns: - dict: {"code": str, "name": str} 或 None。 - """ - if not code: - return None - clean = "".join([c for c in str(code).strip() if c.isdigit()]) - if not clean: - return None - - cursor = self.conn.cursor() - cursor.execute("SELECT code, name FROM stock_list WHERE code = ? LIMIT 1", (clean,)) - row = cursor.fetchone() - return dict(row) if row else None - - def save_stock_prices(self, ticker: str, df: pd.DataFrame): - """保存股价历史数据""" - if df.empty: - return - - cursor = self.conn.cursor() - - # 确保 DataFrame 有必要的列 - required_cols = ['date', 'open', 'close', 'high', 'low', 'volume', 'change_pct'] - for col in required_cols: - if col not in df.columns: - logger.warning(f"Missing column {col} in stock data for {ticker}") - return - - try: - for _, row in df.iterrows(): - cursor.execute(""" - INSERT OR REPLACE INTO stock_prices - (ticker, date, open, close, high, low, volume, change_pct) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, ( - ticker, - row['date'], - row['open'], - row['close'], - row['high'], - row['low'], - row['volume'], - row['change_pct'] - )) - self.conn.commit() - except sqlite3.Error as e: - logger.error(f"Database error saving stock prices for {ticker}: {e}") - except Exception as e: - logger.error(f"Unexpected error saving stock prices for {ticker}: {e}") - - def get_stock_prices(self, ticker: str, start_date: str, end_date: str) -> pd.DataFrame: - """获取指定日期范围的股价数据""" - cursor = self.conn.cursor() - - cursor.execute(""" - SELECT * FROM stock_prices - WHERE ticker = ? AND date >= ? AND date <= ? 
- ORDER BY date - """, (ticker, start_date, end_date)) - - rows = cursor.fetchall() - if not rows: - return pd.DataFrame() - - columns = ['ticker', 'date', 'open', 'close', 'high', 'low', 'volume', 'change_pct'] - return pd.DataFrame([dict(row) for row in rows], columns=columns) - - def execute_query(self, query: str, params: tuple = ()) -> List[Any]: - """执行自定义 SQL 查询""" - try: - cursor = self.conn.cursor() - cursor.execute(query, params) - if query.strip().upper().startswith("SELECT"): - return cursor.fetchall() - else: - self.conn.commit() - return [] - except sqlite3.Error as e: - logger.error(f"SQL execution failed (Database error): {e}") - return [] - except Exception as e: - logger.error(f"SQL execution failed (Unexpected error): {e}") - return [] - - # --- 投资信号操作 (ISQ Framework) --- - - def save_signal(self, signal: Dict[str, Any]): - """保存投资信号""" - cursor = self.conn.cursor() - created_at = datetime.now().isoformat() - - cursor.execute(""" - INSERT OR REPLACE INTO signals - (signal_id, title, summary, transmission_chain, sentiment_score, - confidence, intensity, expected_horizon, price_in_status, - impact_tickers, industry_tags, sources, user_id, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - signal.get('signal_id'), - signal.get('title'), - signal.get('summary'), - json.dumps(signal.get('transmission_chain', [])), - signal.get('sentiment_score', 0.0), - signal.get('confidence', 0.0), - signal.get('intensity', 1), - signal.get('expected_horizon', 'T+0'), - signal.get('price_in_status', '未知'), - json.dumps(signal.get('impact_tickers', [])), - json.dumps(signal.get('industry_tags', [])), - json.dumps(signal.get('sources', [])), - signal.get('user_id'), - created_at - )) - self.conn.commit() - - def get_recent_signals(self, limit: int = 20, user_id: Optional[str] = None) -> List[Dict]: - """获取最近的投资信号""" - cursor = self.conn.cursor() - if user_id: - cursor.execute("SELECT * FROM signals WHERE user_id = ? 
ORDER BY created_at DESC LIMIT ?", (user_id, limit)) - else: - cursor.execute("SELECT * FROM signals ORDER BY created_at DESC LIMIT ?", (limit,)) - rows = cursor.fetchall() - - signals = [] - for row in rows: - d = dict(row) - # 解析 JSON 字段 - for field in ['transmission_chain', 'impact_tickers', 'industry_tags', 'sources']: - if d.get(field): - try: - d[field] = json.loads(d[field]) - except: - pass - signals.append(d) - return signals - - def close(self): - if self.conn: - self.conn.close() - logger.info("Database connection closed.") - diff --git a/skills/alphaear-signal-tracker/scripts/utils/hybrid_search.py b/skills/alphaear-signal-tracker/scripts/utils/hybrid_search.py deleted file mode 100644 index c597fee..0000000 --- a/skills/alphaear-signal-tracker/scripts/utils/hybrid_search.py +++ /dev/null @@ -1,216 +0,0 @@ -import numpy as np -import os -from typing import List, Dict, Any, Optional, Union -from rank_bm25 import BM25Okapi -from loguru import logger -from sentence_transformers import SentenceTransformer -from sklearn.metrics.pairwise import cosine_similarity - -class HybridSearcher: - """ - 统一混合检索引擎 (Hybrid RAG) - 实现 BM25 (文本) + 向量 (语义) 的融合搜索 (RRF) - """ - - def __init__(self, data: List[Dict[str, Any]], text_fields: List[str] = ["title", "content"], model_name: str = None): - """ - 初始化搜索器 - - Args: - data: 数据列表,每个元素为 Dict - text_fields: 用于建立索引的文本字段 - model_name: 向量模型名称,默认使用 paraphrase-multilingual-MiniLM-L12-v2 - """ - self.data = data - self.text_fields = text_fields - self._corpus = [] - self._bm25 = None - self._vector_model = None - self._embeddings = None - self._fitted = False - self._vector_fitted = False - - # 默认模型 - self.model_name = model_name or os.getenv("EMBEDDING_MODEL", "paraphrase-multilingual-MiniLM-L12-v2") - - if data: - self._prepare_corpus() - self._fit_bm25() - # 延迟加载向量模型,仅在需要时或初始化时显式调用 - # self._fit_vector() - - def _prepare_corpus(self): - """准备语料库用于分词""" - import jieba # 使用 jieba 进行中文分词 - - self._corpus = [] - self._full_texts = [] - for item in self.data: - text = " ".join([str(item.get(field, "")) for field in self.text_fields]) - self._full_texts.append(text) - # 中文分词优化 - tokens = list(jieba.cut(text)) - self._corpus.append(tokens) - - def _fit_bm25(self): - """训练 BM25 模型""" - if self._corpus: - self._bm25 = BM25Okapi(self._corpus) - self._fitted = True - logger.info(f"✅ BM25 index fitted with {len(self.data)} documents") - - def _fit_vector(self): - """训练向量模型并生成 Embeddings""" - if not self.data: - return - - try: - logger.info(f"📡 Loading embedding model: {self.model_name}...") - self._vector_model = SentenceTransformer(self.model_name) - logger.info(f"🧠 Encoding {len(self._full_texts)} documents...") - self._embeddings = self._vector_model.encode(self._full_texts, show_progress_bar=False) - self._vector_fitted = True - logger.info("✅ Vector index fitted successfully") - except Exception as e: - logger.error(f"❌ Failed to fit vector index: {e}") - self._vector_fitted = False - - def _compute_rrf(self, rank_lists: List[List[int]], k: int = 60) -> List[tuple]: - """ - 计算 Reciprocal Rank Fusion (RRF) - - Args: - rank_lists: 多个排序后的索引列表 - k: RRF 常数,默认 60 - """ - scores = {} - for rank_list in rank_lists: - for rank, idx in enumerate(rank_list): - if idx not in scores: - scores[idx] = 0 - scores[idx] += 1.0 / (k + rank + 1) - - # 按分数排序 - sorted_indices = sorted(scores.items(), key=lambda x: x[1], reverse=True) - return sorted_indices - - def search(self, query: str, top_n: int = 5, use_vector: bool = False) -> List[Dict[str, Any]]: - """ - 执行混合搜索 - - Args: 
- query: 搜索关键词 - top_n: 返回结果数量 - use_vector: 是否启用向量搜索 - """ - if not self._fitted or not query: - return [] - - import jieba - query_tokens = list(jieba.cut(query)) - - # 1. BM25 搜索结果 - bm25_scores = self._bm25.get_scores(query_tokens) - bm25_rank = np.argsort(bm25_scores)[::-1].tolist() - - rank_lists = [bm25_rank] - - # 2. 向量搜索逻辑 - if use_vector: - if not self._vector_fitted: - self._fit_vector() - - if self._vector_fitted: - query_embedding = self._vector_model.encode([query], show_progress_bar=False) - similarities = cosine_similarity(query_embedding, self._embeddings)[0] - vector_rank = np.argsort(similarities)[::-1].tolist() - rank_lists.append(vector_rank) - else: - logger.warning("Vector search requested but model not fitted, falling back to BM25") - - # 3. 融合排序 (RRF) - if len(rank_lists) > 1: - rrf_results = self._compute_rrf(rank_lists) - # RRF 返回 (idx, score) 列表 - final_rank = [idx for idx, score in rrf_results] - else: - final_rank = bm25_rank - - # 返回前 top_n 条结果 - results = [self.data[idx].copy() for idx in final_rank[:top_n]] - - # 为每个结果注入相关性评分 - for i, res in enumerate(results): - try: - original_idx = final_rank[i] - res["_search_score"] = bm25_scores[original_idx] - if use_vector and self._vector_fitted: - res["_vector_score"] = float(similarities[original_idx]) - except: - res["_search_score"] = 0 - - return results - -class InMemoryRAG(HybridSearcher): - """专门用于 ReportAgent 跨章节检索的内存态 RAG""" - - def search(self, query: str, top_n: int = 3, use_vector: bool = True) -> List[Dict[str, Any]]: - """默认开启向量搜索的内存检索""" - return super().search(query, top_n=top_n, use_vector=use_vector) - - def update_data(self, new_data: List[Dict[str, Any]]): - """动态更新数据并重新训练索引""" - self.data = new_data - self._prepare_corpus() - self._fit_bm25() - # 如果之前已经加载过向量模型,则更新向量索引 - if self._vector_model: - self._fit_vector() - logger.info(f"🔄 InMemoryRAG updated with {len(new_data)} items") - -class LocalNewsSearch(HybridSearcher): - """持久态 RAG:检索数据库中的历史新闻""" - - def __init__(self, db_manager): - """ - Args: - db_manager: DatabaseManager 实例 - """ - self.db = db_manager - # 初始时不加载数据,需调用 load_history - super().__init__([], ["title", "content"]) - - def load_history(self, days: int = 30, limit: int = 1000): - """从数据库加载最近 N 天的新闻构建索引""" - try: - # 假设 db_manager 有 execute_query - query = f"SELECT title, content, publish_time, source FROM daily_news ORDER BY publish_time DESC LIMIT ?" 
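            # For reference (illustrative sketch, not from the original module): when the
            # vector index is available, the inherited search() fuses BM25 and vector
            # ranks with Reciprocal Rank Fusion, score(doc) = sum over rank lists of
            # 1 / (k + rank + 1) with k = 60 (see _compute_rrf above). A minimal
            # standalone version of the same formula, assuming two rank lists over
            # shared document indices:
            #
            #   def rrf(rank_lists, k=60):
            #       scores = {}
            #       for ranks in rank_lists:
            #           for pos, idx in enumerate(ranks):
            #               scores[idx] = scores.get(idx, 0.0) + 1.0 / (k + pos + 1)
            #       return sorted(scores, key=scores.get, reverse=True)
            #
            #   rrf([[0, 1, 2], [2, 0, 1]])  # -> [0, 2, 1]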
- results = self.db.execute_query(query, (limit,)) - - data = [] - for row in results: - # 转换 Row 为 Dict - if hasattr(row, 'keys'): - item = dict(row) - else: - item = { - "title": row[0], - "content": row[1], - "publish_time": row[2], - "source": row[3] - } - data.append(item) - - self.data = data - self._prepare_corpus() - self._fit_bm25() - # 默认不立即训练向量,等到第一次搜索时按需训练 - logger.info(f"📚 LocalNewsSearch loaded {len(data)} items from history") - except Exception as e: - logger.error(f"Failed to load history for search: {e}") - - def search(self, query: str, top_n: int = 5, use_vector: bool = True) -> List[Dict[str, Any]]: - """执行本地历史搜索,默认开启向量搜索""" - if not self.data: - self.load_history() - return super().search(query, top_n=top_n, use_vector=use_vector) diff --git a/skills/alphaear-signal-tracker/scripts/utils/json_utils.py b/skills/alphaear-signal-tracker/scripts/utils/json_utils.py deleted file mode 100644 index c29aab2..0000000 --- a/skills/alphaear-signal-tracker/scripts/utils/json_utils.py +++ /dev/null @@ -1,180 +0,0 @@ -import ast -import json -import re -from typing import Optional, Any -from loguru import logger - -def _strip_comments(text: str) -> str: - """ - Safely remove C-style comments (// and /* */) from JSON-like text, - preserving strings (including URLs like http://). - """ - result = [] - i = 0 - n = len(text) - in_string = False - escape = False - - while i < n: - char = text[i] - - if in_string: - if char == '\\': - escape = not escape - elif char == '"' and not escape: - in_string = False - else: - escape = False - result.append(char) - i += 1 - continue - - # Not in string - if char == '"': - in_string = True - result.append(char) - i += 1 - continue - - # Check for // comment - if i + 1 < n and text[i:i+2] == '//': - i += 2 - while i < n and text[i] != '\n': - i += 1 - continue - - # Check for /* comment - if i + 1 < n and text[i:i+2] == '/*': - i += 2 - while i + 1 < n and text[i:i+2] != '*/': - i += 1 - i += 2 - continue - - result.append(char) - i += 1 - - return ''.join(result) - -def extract_json(text: str) -> Optional[Any]: - """ - 更加鲁棒的 JSON 提取工具。 - 处理: - 1. Markdown 代码块 (```json ... ```) - 2. 首尾多余字符 - 3. 同一个文本中多个 JSON 对象 (仅提取第一个) - 4. 简单的 JSON 修复 (末尾逗号等) - 5. C 风格注释 (// 和 /* */) - """ - if not text: - return None - - # 1. 清理明显的 Markdown 包装 - text = text.strip() - - # 先尝试精确匹配 ```json ... ``` 或 ```...``` - md_match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', text, re.DOTALL) - if md_match: - text = md_match.group(1).strip() - elif text.startswith("```"): - # 回退:如果开头有 ``` 但没完整匹配 - text = re.sub(r'^```[a-z]*\n?', '', text) - text = re.sub(r'\n?```\s*$', '', text) - - # 2. 寻找第一个 JSON 起始符 { 或 [ - start_brace = text.find('{') - start_bracket = text.find('[') - - if start_brace == -1 and start_bracket == -1: - return None - - start_idx = start_brace if (start_bracket == -1 or (start_brace != -1 and start_brace < start_bracket)) else start_bracket - - # 2.5 预处理:修复一些极其常见的 LLM 错误 - potential_json = text[start_idx:].strip() - - # remove comments safely - potential_json = _strip_comments(potential_json) - - # b. 修复缺失开头引号的键: nodes": [ -> "nodes": [ - # 匹配模式: (空白或换行) 单词 紧跟引号和冒号 - potential_json = re.sub(r'([\{\,]\s*)([a-zA-Z_]\w*)\"\s*:', r'\1"\2":', potential_json) - - # c. 修复缺失末尾引号的键: "nodes: [ -> "nodes": [ - potential_json = re.sub(r'([\{\,]\s*)\"([a-zA-Z_]\w*)\s*:', r'\1"\2":', potential_json) - - # d. 
修复完全缺失引号的键: nodes: [ -> "nodes": [ - # 注意避免匹配到像 http:// 这种内容,所以限定在 { 或 , 之后 - potential_json = re.sub(r'([\{\,]\s*)([a-zA-Z_]\w*)\s*:', r'\1"\2":', potential_json) - - # 3. 使用 raw_decode 尝试解析 - decoder = json.JSONDecoder() - - # 首先尝试直接解析(不做任何预处理) - try: - obj = json.loads(potential_json) - return obj - except json.JSONDecodeError: - pass - - # 简单预处理:移除对象/列表末位多余逗号 - processed_json = re.sub(r',\s*([\]}])', r'\1', potential_json) - - try: - obj, end_pos = decoder.raw_decode(processed_json) - return obj - except json.JSONDecodeError: - pass - - # e. 修复未终止的字符串字面量问题:移除值中的实际换行符 - # LLM 可能在字符串值中生成包含真实 newline 的内容,导致 JSON 非法 - def fix_multiline_strings(s): - # 简单策略:将字符串值内的换行替换为空格 - lines = s.split('\n') - result = [] - in_string = False - for line in lines: - # 计算未转义的引号数 - quote_count = line.count('"') - line.count('\\"') - if in_string: - result[-1] += ' ' + line.strip() - else: - result.append(line) - - if quote_count % 2 == 1: - in_string = not in_string - return '\n'.join(result) - - fixed_json = fix_multiline_strings(processed_json) - - try: - obj, end_pos = decoder.raw_decode(fixed_json) - return obj - except json.JSONDecodeError: - try: - # 4. 尝试处理单引号问题 (JSON 规范要求双引号,但 LLM 常输出单引号) - # 这是一个简单的替换技巧,仅针对像 {'key': 'value'} 这样的结构 - # 注意:这可能会破坏包含单引号的字符串值,所以作为较后的回退 - fix_quotes = re.sub(r"'(.*?)':", r'"\1":', processed_json) # 修复键 - fix_quotes = re.sub(r":\s*'(.*?)'", r': "\1"', fix_quotes) # 修复简单值 - obj, end_pos = decoder.raw_decode(fix_quotes) - return obj - except (json.JSONDecodeError, TypeError): - try: - # 5. 使用 ast.literal_eval 作为终极回退 (处理 Python 字典格式) - # 提取第一个匹配的括号对内容 - # 寻找匹配的 { } - stack = [] - for i, char in enumerate(potential_json): - if char == '{': stack.append('{') - elif char == '}': - if stack: stack.pop() - if not stack: - content = potential_json[:i+1] - return ast.literal_eval(content) - except (ValueError, SyntaxError, MemoryError) as e: - logger.warning(f"All JSON extraction attempts failed: {e}") - except Exception as e: - logger.error(f"Unexpected error during JSON extraction: {e}") - - return None diff --git a/skills/alphaear-signal-tracker/scripts/utils/llm/capability.py b/skills/alphaear-signal-tracker/scripts/utils/llm/capability.py deleted file mode 100644 index d07ca4f..0000000 --- a/skills/alphaear-signal-tracker/scripts/utils/llm/capability.py +++ /dev/null @@ -1,85 +0,0 @@ -import os -from typing import Optional, List, Dict, Any -from agno.agent import Agent -from agno.models.base import Model -from loguru import logger -from ..llm.factory import get_model - - -def test_tool_call_support(model: Model) -> bool: - """ - 测试模型是否支持原生的 Tool Call (Function Calling)。 - 通过尝试执行一个简单的加法工具来验证。 - """ - - def get_current_weather(location: str): - """获取指定地点的天气""" - return f"{location} 的天气是晴天,25度。" - - test_agent = Agent( - model=model, - tools=[get_current_weather], - instructions="请调用工具查询北京的天气,并直接返回工具的输出结果。", - ) - - try: - # 运行一个简单的任务,观察是否触发了 tool_call - response = test_agent.run("北京天气怎么样?") - - # 检查 response 中是否包含 tool_calls - # Agno 的 RunResponse 对象通常包含 messages,我们可以检查最后几条消息 - has_tool_call = False - for msg in response.messages: - if hasattr(msg, "tool_calls") and msg.tool_calls: - has_tool_call = True - break - - if has_tool_call: - logger.info(f"✅ Model {model.id} supports native tool calling.") - return True - else: - # 如果没有 tool_calls 但返回了正确答案,可能是模型通过纯文本模拟了工具调用(ReAct) - # 或者根本没用工具。对于原生支持的判断,我们坚持要求有 tool_calls 结构。 - logger.warning( - f"⚠️ Model {model.id} did NOT use native tool calling structure." 
- ) - return False - - except Exception as e: - logger.error(f"❌ Error testing tool call for {model.id}: {e}") - return False - - -class ModelCapabilityRegistry: - """ - 模型能力注册表,用于缓存和管理不同模型的能力测试结果。 - """ - - _cache = {} - - @classmethod - def get_capabilities( - cls, provider: str, model_id: str, **kwargs - ) -> Dict[str, bool]: - key = f"{provider}:{model_id}" - if key not in cls._cache: - logger.info(f"🔍 Testing capabilities for {key}...") - model = get_model(provider, model_id, **kwargs) - supports_tool_call = test_tool_call_support(model) - cls._cache[key] = {"supports_tool_call": supports_tool_call} - return cls._cache[key] - - -if __name__ == "__main__": - import os - from skills._env_loader import load_unified_env - - load_unified_env() - - # 测试当前配置的模型 - p = os.getenv("LLM_PROVIDER", "minimax") - m = os.getenv("LLM_MODEL", "Qwen") - - print(f"Testing {p}/{m}...") - res = ModelCapabilityRegistry.get_capabilities(p, m) - print(f"Result: {res}") diff --git a/skills/alphaear-signal-tracker/scripts/utils/llm/factory.py b/skills/alphaear-signal-tracker/scripts/utils/llm/factory.py deleted file mode 100644 index 09b6ea5..0000000 --- a/skills/alphaear-signal-tracker/scripts/utils/llm/factory.py +++ /dev/null @@ -1,114 +0,0 @@ -import os -from agno.models.openai import OpenAIChat -from agno.models.ollama import Ollama -from agno.models.dashscope import DashScope -from agno.models.deepseek import DeepSeek -from agno.models.openrouter import OpenRouter - -def get_model(model_provider: str, model_id: str, **kwargs): - """ - Factory to get the appropriate LLM model. - - Args: - model_provider: "openai", "ollama", "deepseek" - model_id: The specific model ID (e.g., "gpt-4o", "llama3", "deepseek-chat") - **kwargs: Additional arguments for the model constructor - """ - if model_provider == "openai": - return OpenAIChat(id=model_id, **kwargs) - - elif model_provider == "ollama": - return Ollama(id=model_id, **kwargs) - - elif model_provider == "deepseek": - # DeepSeek is OpenAI compatible - api_key = os.getenv("DEEPSEEK_API_KEY") - if not api_key: - print("Warning: DEEPSEEK_API_KEY not set.") - - return DeepSeek( - id=model_id, - api_key=api_key, - **kwargs - ) - elif model_provider == "dashscope": - api_key = os.getenv("DASHSCOPE_API_KEY") - if not api_key: - print("Warning: DASHSCOPE_API_KEY not set.") - - return DashScope( - id=model_id, - base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", - api_key=api_key, - **kwargs - ) - elif model_provider == 'openrouter': - api_key = os.getenv("OPENROUTER_API_KEY") - if not api_key: - print('Warning: OPENROUTER_API_KEY not set.') - - return OpenRouter( - id=model_id, - api_key=api_key, - **kwargs - ) - - elif model_provider == 'zai': - api_key = os.getenv("ZAI_KEY_API") - if not api_key: - print('Warning: ZAI_KEY_API not set.') - - # role_map to ensure compatibility. 
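        # (The explicit mapping below keeps the standard OpenAI role names this
        # endpoint expects, rather than Agno's default mapping of "system" ->
        # "developer"; this is the same workaround documented for the 'ust'
        # provider further down.)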
- default_role_map = { - "system": "system", - "user": "user", - "assistant": "assistant", - "tool": "tool", - "model": "assistant", - } - - # Allow callers to override role_map via kwargs, otherwise use default - role_map = kwargs.pop("role_map", default_role_map) - - return OpenAIChat( - id=model_id, - base_url="https://api.z.ai/api/paas/v4", - api_key=api_key, - timeout=60, - role_map=role_map, - extra_body={"enable_thinking": False}, # TODO: one more setting for thinking - **kwargs - ) - - elif model_provider == 'ust': - api_key = os.getenv("UST_KEY_API") - if not api_key: - print('Warning: UST_KEY_API not set.') - - # Some UST-compatible endpoints expect the standard OpenAI role names - # (e.g. "system", "user", "assistant") rather than Agno's default - # mapping which maps "system" -> "developer". Provide an explicit - # role_map to ensure compatibility. - default_role_map = { - "system": "system", - "user": "user", - "assistant": "assistant", - "tool": "tool", - "model": "assistant", - } - - # Allow callers to override role_map via kwargs, otherwise use default - role_map = kwargs.pop("role_map", default_role_map) - - return OpenAIChat( - id=model_id, - api_key=api_key, - base_url=os.getenv("UST_URL"), - role_map=role_map, - extra_body={"enable_thinking": False}, # TODO: one more setting for thinking - **kwargs - ) - - else: - raise ValueError(f"Unknown model provider: {model_provider}") - diff --git a/skills/alphaear-signal-tracker/scripts/utils/llm/router.py b/skills/alphaear-signal-tracker/scripts/utils/llm/router.py deleted file mode 100644 index 8c69958..0000000 --- a/skills/alphaear-signal-tracker/scripts/utils/llm/router.py +++ /dev/null @@ -1,80 +0,0 @@ -import os -from typing import Optional, List, Dict, Any, Union -from agno.models.base import Model -from loguru import logger -from ..llm.factory import get_model -from ..llm.capability import ModelCapabilityRegistry -from skills._env_loader import load_unified_env - -load_unified_env() - - -class ModelRouter: - """ - 模型路由管理器 - - 功能: - 1. 管理“推理/写作模型” (Reasoning Model) 和“工具调用模型” (Tool Model)。 - 2. 
根据任务需求自动选择合适的模型。 - """ - - def __init__(self): - # 默认从环境变量读取 - self.reasoning_provider = os.getenv( - "REASONING_MODEL_PROVIDER", os.getenv("LLM_PROVIDER", "openai") - ) - self.reasoning_id = os.getenv( - "REASONING_MODEL_ID", os.getenv("LLM_MODEL", "gpt-4o") - ) - self.reasoning_host = os.getenv("REASONING_MODEL_HOST", os.getenv("LLM_HOST")) - - self.tool_provider = os.getenv("TOOL_MODEL_PROVIDER", self.reasoning_provider) - self.tool_id = os.getenv("TOOL_MODEL_ID", self.reasoning_id) - self.tool_host = os.getenv("TOOL_MODEL_HOST", self.reasoning_host) - - self._reasoning_model = None - self._tool_model = None - - logger.info( - f"🤖 ModelRouter initialized: Reasoning={self.reasoning_id} ({self.reasoning_host or 'default'}), Tool={self.tool_id} ({self.tool_host or 'default'})" - ) - - def get_reasoning_model(self, **kwargs) -> Model: - if not self._reasoning_model: - # 优先使用路由配置的 host - if self.reasoning_host and "host" not in kwargs: - kwargs["host"] = self.reasoning_host - self._reasoning_model = get_model( - self.reasoning_provider, self.reasoning_id, **kwargs - ) - return self._reasoning_model - - def get_tool_model(self, **kwargs) -> Model: - if not self._tool_model: - # 优先使用路由配置的 host - if self.tool_host and "host" not in kwargs: - kwargs["host"] = self.tool_host - - # 检查 tool_model 是否真的支持 tool call - caps = ModelCapabilityRegistry.get_capabilities( - self.tool_provider, self.tool_id, **kwargs - ) - if not caps["supports_tool_call"]: - logger.warning( - f"⚠️ Configured tool model {self.tool_id} might not support native tool calls! Consider using ReAct mode or a different model." - ) - - self._tool_model = get_model(self.tool_provider, self.tool_id, **kwargs) - return self._tool_model - - def get_model_for_agent(self, has_tools: bool = False, **kwargs) -> Model: - """ - 根据 Agent 是否包含工具来返回合适的模型。 - """ - if has_tools: - return self.get_tool_model(**kwargs) - return self.get_reasoning_model(**kwargs) - - -# 全局单例 -router = ModelRouter() diff --git a/skills/alphaear-signal-tracker/scripts/utils/logging_setup.py b/skills/alphaear-signal-tracker/scripts/utils/logging_setup.py deleted file mode 100644 index 9a2ca62..0000000 --- a/skills/alphaear-signal-tracker/scripts/utils/logging_setup.py +++ /dev/null @@ -1,45 +0,0 @@ -import os -import sys -from datetime import datetime -from typing import Optional - -from loguru import logger - - -def setup_file_logging( - run_id: str, - log_dir: str = "logs", - level: str = "INFO", - retention: str = "10 days", - rotation: str = "20 MB", -) -> str: - """Configure Loguru to log to stderr + a per-run file. - - Returns the log file path. - """ - os.makedirs(log_dir, exist_ok=True) - - # Remove default handler to avoid duplicate logs. 
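    # (Loguru installs a stderr sink by default; calling logger.remove() with no
    # arguments clears every existing sink, so only the console and file sinks
    # added below remain active.)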
- logger.remove() - - # Console - logger.add(sys.stderr, level=level, backtrace=False, diagnose=False) - - # File (safe for multi-thread via enqueue) - log_path = os.path.join(log_dir, f"signalflux_{run_id}.log") - logger.add( - log_path, - level=level, - rotation=rotation, - retention=retention, - enqueue=True, - backtrace=True, - diagnose=False, - encoding="utf-8", - ) - return log_path - - -def make_run_id(prefix: Optional[str] = None) -> str: - ts = datetime.now().strftime("%Y%m%d_%H%M%S") - return f"{prefix}_{ts}" if prefix else ts diff --git a/skills/alphaear-signal-tracker/scripts/utils/md_to_html.py b/skills/alphaear-signal-tracker/scripts/utils/md_to_html.py deleted file mode 100644 index 314c282..0000000 --- a/skills/alphaear-signal-tracker/scripts/utils/md_to_html.py +++ /dev/null @@ -1,185 +0,0 @@ -import markdown -import os -from loguru import logger - -def convert_md_to_html(md_content: str, title: str = "AlphaEar Report") -> str: - """ - 将 Markdown 转换为带样式的 HTML - """ - # 转换 Markdown 为 HTML - # 启用 table, toc 等扩展 - # 使用 'md_in_html' 来正确处理 markdown 中的 HTML 块 - html_body = markdown.markdown( - md_content, - extensions=['extra', 'toc', 'nl2br', 'md_in_html'] - ) - - - # 简单的 Premium CSS 模板 - html_template = f""" - - - - - - {title} - - - -
- {html_body} - -
- - - """ - return html_template - -def save_report_as_html(md_path: str, output_path: str = None): - if not output_path: - output_path = md_path.replace(".md", ".html") - - try: - with open(md_path, "r", encoding="utf-8") as f: - md_content = f.read() - - title = "AlphaEar 市场研报" - # 尝试从第一行获取标题 - lines = md_content.split('\n') - if lines and lines[0].startswith('# '): - title = lines[0].replace('# ', '').strip() - - html_content = convert_md_to_html(md_content, title) - - with open(output_path, "w", encoding="utf-8") as f: - f.write(html_content) - - logger.info(f"✅ HTML Report saved to: {output_path}") - return output_path - except Exception as e: - logger.error(f"Failed to convert report to HTML: {e}") - return None - -if __name__ == "__main__": - import sys - if len(sys.argv) > 1: - save_report_as_html(sys.argv[1]) - else: - print("Usage: python3 md_to_html.py ") diff --git a/skills/alphaear-signal-tracker/scripts/utils/news_tools.py b/skills/alphaear-signal-tracker/scripts/utils/news_tools.py deleted file mode 100644 index e833e2e..0000000 --- a/skills/alphaear-signal-tracker/scripts/utils/news_tools.py +++ /dev/null @@ -1,256 +0,0 @@ -import requests -from requests.exceptions import RequestException, Timeout -import json -import time -from datetime import datetime -from typing import List, Dict, Optional -from loguru import logger -from .database_manager import DatabaseManager -from .content_extractor import ContentExtractor - -class NewsNowTools: - """热点新闻获取工具 - 接入 NewsNow API 与 Jina 内容提取""" - - BASE_URL = "https://newsnow.busiyi.world" - SOURCES = { - # 金融类 - "cls": "财联社", - "wallstreetcn": "华尔街见闻", - "xueqiu": "雪球热榜", - # 综合/社交 - "weibo": "微博热搜", - "zhihu": "知乎热榜", - "baidu": "百度热搜", - "toutiao": "今日头条", - "douyin": "抖音热榜", - "thepaper": "澎湃新闻", - # 科技类 - "36kr": "36氪", - "ithome": "IT之家", - "v2ex": "V2EX", - "juejin": "掘金", - "hackernews": "Hacker News", - } - - - def __init__(self, db: DatabaseManager): - self.db = db - self.user_agent = ( - "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " - "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36" - ) - self.extractor = ContentExtractor() - # Simple in-memory cache: source_id -> {"time": timestamp, "data": []} - self._cache = {} - - def fetch_hot_news(self, source_id: str, count: int = 15, fetch_content: bool = False) -> List[Dict]: - """ - 从指定新闻源获取热点新闻列表(支持5分钟缓存)。 - """ - # 1. 
Check cache validity (5 minutes) - cache_key = f"{source_id}_{count}" - cached = self._cache.get(cache_key) - now = time.time() - - if cached and (now - cached["time"] < 300): - logger.info(f"⚡ Using cached news for {source_id} (Age: {int(now - cached['time'])}s)") - return cached["data"] - - try: - url = f"{self.BASE_URL}/api/s?id={source_id}" - response = requests.get(url, headers={"User-Agent": self.user_agent}, timeout=30) - if response.status_code == 200: - data = response.json() - items = data.get("items", [])[:count] - processed_items = [] - for i, item in enumerate(items, 1): - item_url = item.get("url", "") - content = "" - if fetch_content and item_url: - content = self.extractor.extract_with_jina(item_url) or "" - - processed_items.append({ - "id": item.get("id") or f"{source_id}_{int(time.time())}_{i}", - "source": source_id, - "rank": i, - "title": item.get("title", ""), - "url": item_url, - "content": content, - "publish_time": item.get("publish_time"), - "meta_data": item.get("extra", {}) - }) - - # Update Cache - self._cache[cache_key] = {"time": now, "data": processed_items} - logger.info(f"✅ Fetched and cached news for {source_id}") - - self.db.save_daily_news(processed_items) - return processed_items - else: - logger.error(f"NewsNow API Error: {response.status_code}") - # Fallback to stale cache if available - if cached: - logger.warning(f"⚠️ API failed, using stale cache for {source_id}") - return cached["data"] - return [] - except Timeout: - logger.error(f"Timeout fetching hot news from {source_id}") - if cached: - logger.warning(f"⚠️ Timeout, using stale cache for {source_id}") - return cached["data"] - return [] - except RequestException as e: - logger.error(f"Network error fetching hot news from {source_id}: {e}") - if cached: - logger.warning(f"⚠️ Network check failed, using stale cache for {source_id}") - return cached["data"] - return [] - except json.JSONDecodeError: - logger.error(f"Failed to parse JSON response from NewsNow for {source_id}") - return [] - except Exception as e: - logger.error(f"Unexpected error fetching hot news from {source_id}: {e}") - return [] - - def fetch_news_content(self, url: str) -> Optional[str]: - """ - 使用 Jina Reader 抓取指定 URL 的网页正文内容。 - - Args: - url: 需要抓取内容的完整网页 URL,必须以 http:// 或 https:// 开头。 - - Returns: - 提取的网页正文内容 (Markdown 格式),如果失败则返回 None。 - """ - return self.extractor.extract_with_jina(url) - - def get_unified_trends(self, sources: Optional[List[str]] = None) -> str: - """ - 获取多平台综合热点报告,自动聚合多个新闻源的热门内容。 - - Args: - sources: 要扫描的新闻源列表。可选值按类别: - **金融类**: "cls", "wallstreetcn", "xueqiu" - **综合类**: "weibo", "zhihu", "baidu", "toutiao", "douyin", "thepaper" - **科技类**: "36kr", "ithome", "v2ex", "juejin", "hackernews" - - Returns: - 格式化的 Markdown 热点汇总报告,包含各平台 Top 10 热点标题和链接。 - """ - sources = sources or ["weibo", "zhihu", "wallstreetcn"] - all_news = [] - for src in sources: - all_news.extend(self.fetch_hot_news(src)) - time.sleep(0.2) - - if not all_news: - return "❌ 未能获取到热点数据" - - report = f"# 实时全网热点汇总 ({datetime.now().strftime('%Y-%m-%d %H:%M')})\n\n" - for src in sources: - - src_name = self.SOURCES.get(src, src) - report += f"### 🔥 {src_name}\n" - src_news = [n for n in all_news if n['source'] == src] - for n in src_news[:10]: - report += f"- {n['title']} ([链接]({n['url']}))\n" - report += "\n" - - return report - - -class PolymarketTools: - """Polymarket 预测市场数据工具 - 获取热门预测市场反映公众情绪和预期""" - - BASE_URL = "https://gamma-api.polymarket.com" - - def __init__(self, db: DatabaseManager): - self.db = db - self.user_agent = "Mozilla/5.0 
(Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36" - - def get_active_markets(self, limit: int = 20) -> List[Dict]: - """ - 获取活跃的预测市场,用于分析公众情绪和预期。 - - 预测市场数据可以反映: - - 公众对重大事件的预期概率 - - 市场情绪和风险偏好 - - 热门话题的关注度 - - Args: - limit: 获取的市场数量,默认 20 个。 - - Returns: - 包含预测市场信息的列表,每个市场包含: - - question: 预测问题 - - outcomes: 可能的结果 - - outcomePrices: 各结果的概率价格 - - volume: 交易量 - """ - try: - response = requests.get( - f"{self.BASE_URL}/markets", - params={"active": "true", "closed": "false", "limit": limit}, - headers={"User-Agent": self.user_agent, "Accept": "application/json"}, - timeout=30 - ) - - if response.status_code == 200: - markets = response.json() - result = [] - for m in markets: - result.append({ - "id": m.get("id"), - "question": m.get("question"), - "slug": m.get("slug"), - "outcomes": m.get("outcomes"), - "outcomePrices": m.get("outcomePrices"), - "volume": m.get("volume"), - "liquidity": m.get("liquidity"), - }) - logger.info(f"✅ 获取 {len(result)} 个预测市场") - return result - else: - logger.warning(f"⚠️ Polymarket API 返回 {response.status_code}") - return [] - except Timeout: - logger.error("Timeout fetching Polymarket markets") - return [] - except RequestException as e: - logger.error(f"Network error fetching Polymarket markets: {e}") - return [] - except json.JSONDecodeError: - logger.error("Failed to parse JSON response from Polymarket") - return [] - except Exception as e: - logger.error(f"Unexpected error fetching Polymarket markets: {e}") - return [] - - def get_market_summary(self, limit: int = 10) -> str: - """ - 获取预测市场摘要报告,用于了解当前热门话题和公众预期。 - - Args: - limit: 获取的市场数量 - - Returns: - 格式化的预测市场报告 - """ - markets = self.get_active_markets(limit) - if not markets: - return "❌ 无法获取 Polymarket 数据" - - report = f"# 🔮 Polymarket 热门预测 ({datetime.now().strftime('%Y-%m-%d %H:%M')})\n\n" - for i, m in enumerate(markets, 1): - question = m.get("question", "Unknown") - prices = m.get("outcomePrices", []) - volume = m.get("volume", 0) - - report += f"**{i}. {question}**\n" - if prices: - report += f" 概率: {prices}\n" - if volume: - report += f" 交易量: ${float(volume):,.0f}\n" - report += "\n" - - return report diff --git a/skills/alphaear-signal-tracker/scripts/utils/predictor/evaluation.py b/skills/alphaear-signal-tracker/scripts/utils/predictor/evaluation.py deleted file mode 100644 index 26c5df7..0000000 --- a/skills/alphaear-signal-tracker/scripts/utils/predictor/evaluation.py +++ /dev/null @@ -1,137 +0,0 @@ -import os -import sys -import torch -import pandas as pd -import numpy as np -import glob -from loguru import logger -from datetime import datetime, timedelta - -# Setup paths -KRONOS_DIR = os.path.dirname(os.path.abspath(__file__)) -SRC_DIR = os.path.dirname(os.path.dirname(KRONOS_DIR)) -if SRC_DIR not in sys.path: - sys.path.insert(0, SRC_DIR) - -from ..kronos.auto_synthesis_training import AutoSynthesisTrainer -from ..kronos.model import KronosPredictor -from ..visualizer import VisualizerTools -from ..schema.models import ForecastResult, KLinePoint - -class NewsModelEvaluator: - def __init__(self, model_path=None): - self.trainer = AutoSynthesisTrainer() - self.device = self.trainer.device - - if model_path is None: - # Try to find the latest model in exports/models - model_files = glob.glob(os.path.join(SRC_DIR, "exports/models/*.pt")) - if not model_files: - logger.warning("⚠️ No trained models found in exports/models/. 
Using base model (zero-init proj).") - else: - model_path = max(model_files, key=os.path.getctime) - - if model_path: - self.load_weights(model_path) - - def load_weights(self, path): - logger.info(f"🔄 Loading model weights from {path}...") - checkpoint = torch.load(path, map_location=self.device) - self.trainer.model.news_proj.load_state_dict(checkpoint['news_proj_state_dict']) - logger.success("✅ News projection layer loaded.") - - def evaluate_range(self, start_idx=100, end_idx=200, pred_len=5): - # 1. Fetch Tickers - res = self.trainer.db.execute_query("SELECT code FROM stock_list") - all_tickers = [row['code'] for row in res] - test_tickers = all_tickers[start_idx:end_idx] - - if not test_tickers: - logger.error(f"No tickers found in range {start_idx}-{end_idx}") - return - - logger.info(f"🚀 Evaluating News Model on stocks {start_idx} to {end_idx}...") - - # 2. Discover Shocks - shocks = self.trainer.discover_shocks(test_tickers, pred_len=pred_len) - - # 3. Associate News & Predict - self.trainer.model.eval() - predictor = KronosPredictor(self.trainer.model, self.trainer.tokenizer, device=self.device) - - save_dir = os.path.join(SRC_DIR, "exports/evaluation_results") - os.makedirs(save_dir, exist_ok=True) - - count = 0 - for shock in shocks: - summary = self.trainer.find_reason_and_verify(shock) - if not summary: - continue - - logger.info(f"📈 Testing shock: {shock['ticker']} on {shock['date']}") - - # Embedding news - news_emb = self.trainer.embedder.encode(summary) - - # Prediction - h = shock['history'] - t = shock['target'] - actuals = t['close'].values[:pred_len] - - x_ts = pd.to_datetime(h['date']) - future_dates = pd.date_range(start=x_ts.iloc[-1] + timedelta(days=1), periods=pred_len, freq='B') - y_ts = pd.Series(future_dates) - - # A. Base Prediction (No news) - p_base = predictor.predict(h, x_ts, y_ts, pred_len=pred_len, news_emb=None, verbose=False) - - # B. News-Aware Prediction - p_news = predictor.predict(h, x_ts, y_ts, pred_len=pred_len, news_emb=news_emb, verbose=False) - - # Calculate Improvement - b_preds = p_base['close'].values[:len(actuals)] - n_preds = p_news['close'].values[:len(actuals)] - b_mae = np.mean(np.abs(b_preds - actuals)) - n_mae = np.mean(np.abs(n_preds - actuals)) - improvement = (b_mae - n_mae) / (b_mae + 1e-6) * 100 - - # C. Visualize - try: - def to_kp_list(preds_df): - points = [] - for idx, row in preds_df.iterrows(): - points.append(KLinePoint( - date=str(idx)[:10], open=row['open'], high=row['high'], - low=row['low'], close=row['close'], volume=row.get('volume', 0) - )) - return points - - forecast_obj = ForecastResult( - ticker=shock['ticker'], - base_forecast=to_kp_list(p_base), - adjusted_forecast=to_kp_list(p_news), - rationale=summary - ) - - chart = VisualizerTools.generate_stock_chart( - df=h, ticker=shock['ticker'], - title=f"Test Eval: {shock['ticker']} ({shock['date']}) Imp: {improvement:.1f}%", - forecast=forecast_obj, - ground_truth=t[['date', 'open', 'high', 'low', 'close', 'volume']] - ) - - safe_date = shock['date'].replace("-", "") - filename = f"test_{shock['ticker']}_{safe_date}.html" - VisualizerTools.render_chart_to_file(chart, os.path.join(save_dir, filename)) - - logger.success(f"📊 Result for {shock['ticker']} saved. Base MAE: {b_mae:.4f}, News MAE: {n_mae:.4f}") - count += 1 - except Exception as e: - logger.error(f"Visualization failed: {e}") - - logger.info(f"🏁 Finished evaluation. {count} cases visualized in {save_dir}") - -if __name__ == "__main__": - # If you have a specific model, pass the path here. 
Otherwise it picks the latest. - evaluator = NewsModelEvaluator() - evaluator.evaluate_range(start_idx=100, end_idx=200, pred_len=1) diff --git a/skills/alphaear-signal-tracker/scripts/utils/predictor/kline_generate.py b/skills/alphaear-signal-tracker/scripts/utils/predictor/kline_generate.py deleted file mode 100644 index 3224c21..0000000 --- a/skills/alphaear-signal-tracker/scripts/utils/predictor/kline_generate.py +++ /dev/null @@ -1,196 +0,0 @@ -# Ref: https://github.com/shiyu-coder/Kronos - -from model import Kronos, KronosTokenizer, KronosPredictor -import pandas as pd -import sqlite3 -import torch -import matplotlib.pyplot as plt -import matplotlib.gridspec as gridspec -from pandas.tseries.offsets import BusinessDay -import numpy as np - -def get_device(): - device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" - print(f"Using device: {device}") - return device - -def load_predictor(): - tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base") - model = Kronos.from_pretrained("NeoQuasar/Kronos-base") - device = get_device() - tokenizer = tokenizer.to(device) - model = model.to(device) - return KronosPredictor(model, tokenizer, device=device, max_context=512) - -def load_data(ticker="002111", db_path="AlphaEar/data/signal_flux.db"): - with sqlite3.connect(db_path) as conn: - df = pd.read_sql_query(f"SELECT * FROM stock_prices WHERE ticker = '{ticker}'", conn) - df['date'] = pd.to_datetime(df['date']) - df = df.sort_values('date').reset_index(drop=True) - return df - -def plot_kline_matplotlib(ax, ax_vol, dates, df, label_suffix="", color_up='#ef4444', color_down='#22c55e', alpha=1.0, is_prediction=False): - """ - 绘制 K 线图和成交量 - """ - # X axis mapping to integers for consistent spacing - x = np.arange(len(dates)) - - # K-line data - opens = df['open'].values - closes = df['close'].values - highs = df['high'].values - lows = df['low'].values - volumes = df['volume'].values - - # Width of the candlestick - width = 0.6 - - for i in range(len(x)): - color = color_up if closes[i] >= opens[i] else color_down - linestyle = '--' if is_prediction else '-' - - # Wick - ax.vlines(x[i], lows[i], highs[i], color=color, linewidth=1, alpha=alpha, linestyle=linestyle) - - # Body - rect_bottom = min(opens[i], closes[i]) - rect_height = abs(opens[i] - closes[i]) - if rect_height == 0: rect_height = 0.001 # Visual hair - - ax.add_patch(plt.Rectangle((x[i] - width/2, rect_bottom), width, rect_height, - edgecolor=color, facecolor=color if not is_prediction else 'none', - alpha=alpha, linewidth=1, linestyle=linestyle)) - - # Volume - ax_vol.bar(x[i], volumes[i], color=color, alpha=alpha * 0.5, width=width) - -def render_comparison_chart(history_df, actual_df, pred_df, title): - """ - 渲染组合图:历史 K 线 + 真值 K 线 + 预测 K 线 - """ - # Combine all dates for X axis - all_dates = pd.concat([history_df['date'], actual_df['date'] if actual_df is not None else pred_df.index.to_series()]).unique() - all_dates = sorted(all_dates) - date_to_idx = {date: i for i, date in enumerate(all_dates)} - - fig = plt.figure(figsize=(14, 8), facecolor='white') - gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.1) - ax_main = fig.add_subplot(gs[0]) - ax_vol = fig.add_subplot(gs[1], sharex=ax_main) - - # 1. Plot History - hist_indices = [date_to_idx[d] for d in history_df['date']] - # We use a custom x for plotting to ensure continuity - plot_kline_matplotlib(ax_main, ax_vol, history_df['date'], history_df, alpha=0.8) - - offset = len(history_df) - - # 2. 
Plot Actual if exists - if actual_df is not None: - # Shift indices - actual_x = np.arange(len(actual_df)) + offset - # Plotting manually to handle offset - for i in range(len(actual_df)): - idx = actual_x[i] - row = actual_df.iloc[i] - color = '#ef4444' if row['close'] >= row['open'] else '#22c55e' - ax_main.vlines(idx, row['low'], row['high'], color=color, linewidth=1, alpha=0.9) - ax_main.add_patch(plt.Rectangle((idx - 0.3, min(row['open'], row['close'])), 0.6, abs(row['open']-row['close']), - edgecolor=color, facecolor=color, alpha=0.9)) - ax_vol.bar(idx, row['volume'], color=color, alpha=0.4) - - # 3. Plot Prediction - pred_x = np.arange(len(pred_df)) + offset - for i in range(len(pred_df)): - idx = pred_x[i] - row = pred_df.iloc[i] - color = '#ff8c00' # Orange for prediction to distinguish - ax_main.vlines(idx, row['low'], row['high'], color=color, linewidth=1.5, linestyle='--') - ax_main.add_patch(plt.Rectangle((idx - 0.3, min(row['open'], row['close'])), 0.6, abs(row['open']-row['close']), - edgecolor=color, facecolor='none', linewidth=1.5, linestyle='--')) - # Plot secondary prediction line for close - if i == 0: - # Connect to history - ax_main.plot([offset-1, idx], [history_df['close'].iloc[-1], row['close']], color=color, linestyle='--', alpha=0.6) - elif i > 0: - ax_main.plot([idx-1, idx], [pred_df['close'].iloc[i-1], row['close']], color=color, linestyle='--', alpha=0.6) - - # Styling - ax_main.set_title(title, fontsize=14, fontweight='bold') - ax_main.grid(True, linestyle=':', alpha=0.6) - ax_vol.grid(True, linestyle=':', alpha=0.6) - ax_vol.set_ylabel('Volume') - ax_main.set_ylabel('Price') - - # Set X ticks - step = max(1, len(all_dates) // 10) - ax_vol.set_xticks(np.arange(0, len(all_dates), step)) - ax_vol.set_xticklabels([all_dates[i].strftime('%Y-%m-%d') for i in range(0, len(all_dates), step)], rotation=45) - - plt.tight_layout() - plt.show() - plt.close() - -def run_backtest(df, predictor, lookback, pred_len, start_index=0): - total_len = len(df) - history_start = start_index - history_end = start_index + lookback - pred_start = history_end - - available_pred_len = total_len - pred_start - if available_pred_len <= 0: return - actual_pred_len = min(pred_len, available_pred_len) - pred_end = pred_start + actual_pred_len - - x_df = df.iloc[history_start : history_end].copy() - y_true_df = df.iloc[pred_start : pred_end].copy() - y_timestamp = y_true_df['date'] - - print(f"Backtesting: {x_df['date'].iloc[0].date()} to {y_timestamp.iloc[-1].date()}") - - pred_df = predictor.predict( - df=x_df[['open', 'high', 'low', 'close', 'volume']], - x_timestamp=x_df['date'], - y_timestamp=y_timestamp, - pred_len=actual_pred_len, - T=1.0, top_p=0.9, sample_count=1 - ) - - render_comparison_chart(x_df, y_true_df, pred_df, f"Backtest: {TICKER} K-Line Comparison") - -def run_forecast(df, predictor, lookback, pred_len): - if len(df) < lookback: return - x_df = df.iloc[-lookback:].copy() - last_date = x_df['date'].iloc[-1] - future_dates = pd.date_range(start=last_date + BusinessDay(1), periods=pred_len, freq='B') - future_dates = pd.Series(future_dates) - - print(f"Forecasting: Starting from {future_dates.iloc[0].date()}") - - pred_df = predictor.predict( - df=x_df[['open', 'high', 'low', 'close', 'volume']], - x_timestamp=x_df['date'], - y_timestamp=future_dates, - pred_len=pred_len, - T=1.0, top_p=0.9, sample_count=1 - ) - - render_comparison_chart(x_df, None, pred_df, f"Forecast: {TICKER} Future K-Line") - -if __name__ == "__main__": - LOOKBACK = 20 - PRED_LEN = 10 - TICKER = '002111' 
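    # NOTE: TICKER is read as a module-level global when run_backtest/run_forecast
    # build their chart titles, so it must be assigned before those calls;
    # LOOKBACK/PRED_LEN control how many historical bars feed the predictor and
    # how many future bars it generates.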
- - pred_model = load_predictor() - stock_data = load_data(TICKER) - - total_rows = len(stock_data) - backtest_start = max(0, total_rows - LOOKBACK - PRED_LEN - 10) # Leave some space to see trend - - print("\n--- Running Backtest ---") - run_backtest(stock_data, pred_model, LOOKBACK, PRED_LEN, start_index=backtest_start) - - print("\n--- Running Forecast ---") - run_forecast(stock_data, pred_model, LOOKBACK, PRED_LEN) \ No newline at end of file diff --git a/skills/alphaear-signal-tracker/scripts/utils/predictor/model/__init__.py b/skills/alphaear-signal-tracker/scripts/utils/predictor/model/__init__.py deleted file mode 100644 index d10e200..0000000 --- a/skills/alphaear-signal-tracker/scripts/utils/predictor/model/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -from .kronos import KronosTokenizer, Kronos, KronosPredictor - -model_dict = { - 'kronos_tokenizer': KronosTokenizer, - 'kronos': Kronos, - 'kronos_predictor': KronosPredictor -} - - -def get_model_class(model_name): - if model_name in model_dict: - return model_dict[model_name] - else: - print(f"Model {model_name} not found in model_dict") - raise NotImplementedError - diff --git a/skills/alphaear-signal-tracker/scripts/utils/predictor/model/kronos.py b/skills/alphaear-signal-tracker/scripts/utils/predictor/model/kronos.py deleted file mode 100644 index cf8bece..0000000 --- a/skills/alphaear-signal-tracker/scripts/utils/predictor/model/kronos.py +++ /dev/null @@ -1,676 +0,0 @@ -import numpy as np -import pandas as pd -import torch -from huggingface_hub import PyTorchModelHubMixin -import sys - -from tqdm import trange - -sys.path.append("../") -from model.module import * - - -class KronosTokenizer(nn.Module, PyTorchModelHubMixin): - """ - KronosTokenizer module for tokenizing input data using a hybrid quantization approach. - - This tokenizer utilizes a combination of encoder and decoder Transformer blocks - along with the Binary Spherical Quantization (BSQuantizer) to compress and decompress input data. - - Args: - d_in (int): Input dimension. - d_model (int): Model dimension. - n_heads (int): Number of attention heads. - ff_dim (int): Feed-forward dimension. - n_enc_layers (int): Number of encoder layers. - n_dec_layers (int): Number of decoder layers. - ffn_dropout_p (float): Dropout probability for feed-forward networks. - attn_dropout_p (float): Dropout probability for attention mechanisms. - resid_dropout_p (float): Dropout probability for residual connections. - s1_bits (int): Number of bits for the pre token in BSQuantizer. - s2_bits (int): Number of bits for the post token in BSQuantizer. - beta (float): Beta parameter for BSQuantizer. - gamma0 (float): Gamma0 parameter for BSQuantizer. - gamma (float): Gamma parameter for BSQuantizer. - zeta (float): Zeta parameter for BSQuantizer. - group_size (int): Group size parameter for BSQuantizer. 
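        Note:
            codebook_dim = s1_bits + s2_bits; the first s1_bits of each quantized
            code form the "pre" token (reconstructed via post_quant_embed_pre), and
            the remaining s2_bits complete the full code.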
- - """ - - def __init__(self, d_in, d_model, n_heads, ff_dim, n_enc_layers, n_dec_layers, ffn_dropout_p, attn_dropout_p, resid_dropout_p, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size): - - super().__init__() - self.d_in = d_in - self.d_model = d_model - self.n_heads = n_heads - self.ff_dim = ff_dim - self.enc_layers = n_enc_layers - self.dec_layers = n_dec_layers - self.ffn_dropout_p = ffn_dropout_p - self.attn_dropout_p = attn_dropout_p - self.resid_dropout_p = resid_dropout_p - - self.s1_bits = s1_bits - self.s2_bits = s2_bits - self.codebook_dim = s1_bits + s2_bits # Total dimension of the codebook after quantization - self.embed = nn.Linear(self.d_in, self.d_model) - self.head = nn.Linear(self.d_model, self.d_in) - - # Encoder Transformer Blocks - self.encoder = nn.ModuleList([ - TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p) - for _ in range(self.enc_layers - 1) - ]) - # Decoder Transformer Blocks - self.decoder = nn.ModuleList([ - TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p) - for _ in range(self.dec_layers - 1) - ]) - self.quant_embed = nn.Linear(in_features=self.d_model, out_features=self.codebook_dim) # Linear layer before quantization - self.post_quant_embed_pre = nn.Linear(in_features=self.s1_bits, out_features=self.d_model) # Linear layer after quantization (pre part - s1 bits) - self.post_quant_embed = nn.Linear(in_features=self.codebook_dim, out_features=self.d_model) # Linear layer after quantization (full codebook) - self.tokenizer = BSQuantizer(self.s1_bits, self.s2_bits, beta, gamma0, gamma, zeta, group_size) # BSQuantizer module - - def forward(self, x): - """ - Forward pass of the KronosTokenizer. - - Args: - x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in). - - Returns: - tuple: A tuple containing: - - tuple: (z_pre, z) - Reconstructed outputs from decoder with s1_bits and full codebook respectively, - both of shape (batch_size, seq_len, d_in). - - torch.Tensor: bsq_loss - Loss from the BSQuantizer. - - torch.Tensor: quantized - Quantized representation from BSQuantizer. - - torch.Tensor: z_indices - Indices from the BSQuantizer. - """ - z = self.embed(x) - - for layer in self.encoder: - z = layer(z) - - z = self.quant_embed(z) # (B, T, codebook) - - bsq_loss, quantized, z_indices = self.tokenizer(z) - - quantized_pre = quantized[:, :, :self.s1_bits] # Extract the first part of quantized representation (s1_bits) - z_pre = self.post_quant_embed_pre(quantized_pre) - - z = self.post_quant_embed(quantized) - - # Decoder layers (for pre part - s1 bits) - for layer in self.decoder: - z_pre = layer(z_pre) - z_pre = self.head(z_pre) - - # Decoder layers (for full codebook) - for layer in self.decoder: - z = layer(z) - z = self.head(z) - - return (z_pre, z), bsq_loss, quantized, z_indices - - def indices_to_bits(self, x, half=False): - """ - Converts indices to bit representations and scales them. - - Args: - x (torch.Tensor): Indices tensor. - half (bool, optional): Whether to process only half of the codebook dimension. Defaults to False. - - Returns: - torch.Tensor: Bit representation tensor. 
- """ - if half: - x1 = x[0] # Assuming x is a tuple of indices if half is True - x2 = x[1] - mask = 2 ** torch.arange(self.codebook_dim//2, device=x1.device, dtype=torch.long) # Create a mask for bit extraction - x1 = (x1.unsqueeze(-1) & mask) != 0 # Extract bits for the first half - x2 = (x2.unsqueeze(-1) & mask) != 0 # Extract bits for the second half - x = torch.cat([x1, x2], dim=-1) # Concatenate the bit representations - else: - mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long) # Create a mask for bit extraction - x = (x.unsqueeze(-1) & mask) != 0 # Extract bits - - x = x.float() * 2 - 1 # Convert boolean to bipolar (-1, 1) - q_scale = 1. / (self.codebook_dim ** 0.5) # Scaling factor - x = x * q_scale - return x - - def encode(self, x, half=False): - """ - Encodes the input data into quantized indices. - - Args: - x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in). - half (bool, optional): Whether to use half quantization in BSQuantizer. Defaults to False. - - Returns: - torch.Tensor: Quantized indices from BSQuantizer. - """ - z = self.embed(x) - for layer in self.encoder: - z = layer(z) - z = self.quant_embed(z) - - bsq_loss, quantized, z_indices = self.tokenizer(z, half=half, collect_metrics=False) - return z_indices - - def decode(self, x, half=False): - """ - Decodes quantized indices back to the input data space. - - Args: - x (torch.Tensor): Quantized indices tensor. - half (bool, optional): Whether the indices were generated with half quantization. Defaults to False. - - Returns: - torch.Tensor: Reconstructed output tensor of shape (batch_size, seq_len, d_in). - """ - quantized = self.indices_to_bits(x, half) - z = self.post_quant_embed(quantized) - for layer in self.decoder: - z = layer(z) - z = self.head(z) - return z - - -class Kronos(nn.Module, PyTorchModelHubMixin): - """ - Kronos Model. - - Args: - s1_bits (int): Number of bits for pre tokens. - s2_bits (int): Number of bits for post tokens. - n_layers (int): Number of Transformer blocks. - d_model (int): Dimension of the model's embeddings and hidden states. - n_heads (int): Number of attention heads in the MultiheadAttention layers. - ff_dim (int): Dimension of the feedforward network in the Transformer blocks. - ffn_dropout_p (float): Dropout probability for the feedforward network. - attn_dropout_p (float): Dropout probability for the attention layers. - resid_dropout_p (float): Dropout probability for residual connections. - token_dropout_p (float): Dropout probability for token embeddings. - learn_te (bool): Whether to use learnable temporal embeddings. 
- """ - - def __init__(self, s1_bits, s2_bits, n_layers, d_model, n_heads, ff_dim, ffn_dropout_p, attn_dropout_p, resid_dropout_p, token_dropout_p, learn_te, news_dim=None): - super().__init__() - self.s1_bits = s1_bits - self.s2_bits = s2_bits - self.n_layers = n_layers - self.d_model = d_model - self.n_heads = n_heads - self.learn_te = learn_te - self.ff_dim = ff_dim - self.ffn_dropout_p = ffn_dropout_p - self.attn_dropout_p = attn_dropout_p - self.resid_dropout_p = resid_dropout_p - self.token_dropout_p = token_dropout_p - self.news_dim = news_dim - - self.s1_vocab_size = 2 ** self.s1_bits - self.token_drop = nn.Dropout(self.token_dropout_p) - self.embedding = HierarchicalEmbedding(self.s1_bits, self.s2_bits, self.d_model) - self.time_emb = TemporalEmbedding(self.d_model, self.learn_te) - self.transformer = nn.ModuleList([ - TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p) - for _ in range(self.n_layers) - ]) - self.norm = RMSNorm(self.d_model) - self.dep_layer = DependencyAwareLayer(self.d_model) - self.head = DualHead(self.s1_bits, self.s2_bits, self.d_model) - - if self.news_dim is not None: - self.news_proj = nn.Linear(self.news_dim, self.d_model) - else: - self.news_proj = None - - self.apply(self._init_weights) - - def _init_weights(self, module): - - if isinstance(module, nn.Linear): - nn.init.xavier_normal_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - nn.init.normal_(module.weight, mean=0, std=self.embedding.d_model ** -0.5) - elif isinstance(module, nn.LayerNorm): - nn.init.ones_(module.weight) - nn.init.zeros_(module.bias) - elif isinstance(module, RMSNorm): - nn.init.ones_(module.weight) - - def forward(self, s1_ids, s2_ids, stamp=None, padding_mask=None, use_teacher_forcing=False, s1_targets=None, news_emb=None): - """ - Args: - s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len] - s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len] - stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None. - padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None. - use_teacher_forcing (bool, optional): Whether to use teacher forcing for s1 decoding. Defaults to False. - s1_targets (torch.Tensor, optional): Target s1 token IDs for teacher forcing. Shape: [batch_size, seq_len]. Defaults to None. - news_emb (torch.Tensor, optional): News embedding tensor. Shape: [batch_size, news_dim]. Defaults to None. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: - - s1 logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size] - - s2_logits: Logits for s2 token predictions, conditioned on s1. 
Shape: [batch_size, seq_len, s2_vocab_size] - """ - x = self.embedding([s1_ids, s2_ids]) - if stamp is not None: - time_embedding = self.time_emb(stamp) - x = x + time_embedding - x = self.token_drop(x) - - for layer in self.transformer: - x = layer(x, key_padding_mask=padding_mask) - - x = self.norm(x) - - if news_emb is not None and self.news_proj is not None: - news_bias = self.news_proj(news_emb).unsqueeze(1) # [B, 1, d_model] - x = x + news_bias - - s1_logits = self.head(x) - - if use_teacher_forcing: - sibling_embed = self.embedding.emb_s1(s1_targets) - else: - s1_probs = F.softmax(s1_logits.detach(), dim=-1) - sample_s1_ids = torch.multinomial(s1_probs.view(-1, self.s1_vocab_size), 1).view(s1_ids.shape) - sibling_embed = self.embedding.emb_s1(sample_s1_ids) - - x2 = self.dep_layer(x, sibling_embed, key_padding_mask=padding_mask) # Dependency Aware Layer: Condition on s1 embeddings - s2_logits = self.head.cond_forward(x2) - return s1_logits, s2_logits - - def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None, news_emb=None): - """ - Decodes only the s1 tokens. - - This method performs a forward pass to predict only s1 tokens. It returns the s1 logits - and the context representation from the Transformer, which can be used for subsequent s2 decoding. - - Args: - s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len] - s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len] - stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None. - padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None. - news_emb (torch.Tensor, optional): News embedding tensor. Shape: [batch_size, news_dim]. Defaults to None. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: - - s1 logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size] - - context: Context representation from the Transformer. Shape: [batch_size, seq_len, d_model] - """ - x = self.embedding([s1_ids, s2_ids]) - if stamp is not None: - time_embedding = self.time_emb(stamp) - x = x + time_embedding - x = self.token_drop(x) - - for layer in self.transformer: - x = layer(x, key_padding_mask=padding_mask) - - x = self.norm(x) - - if news_emb is not None and self.news_proj is not None: - news_bias = self.news_proj(news_emb).unsqueeze(1) # [B, 1, d_model] - x = x + news_bias - - s1_logits = self.head(x) - return s1_logits, x - - def decode_s2(self, context, s1_ids, padding_mask=None): - """ - Decodes the s2 tokens, conditioned on the context and s1 tokens. - - This method decodes s2 tokens based on a pre-computed context representation (typically from `decode_s1`) - and the s1 token IDs. It uses the dependency-aware layer and the conditional s2 head to predict s2 tokens. - - Args: - context (torch.Tensor): Context representation from the transformer (output of decode_s1). - Shape: [batch_size, seq_len, d_model] - s1_ids (torch.torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len] - padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None. - - Returns: - torch.Tensor: s2 logits. 
Shape: [batch_size, seq_len, s2_vocab_size] - """ - sibling_embed = self.embedding.emb_s1(s1_ids) - x2 = self.dep_layer(context, sibling_embed, key_padding_mask=padding_mask) - return self.head.cond_forward(x2) - - -def top_k_top_p_filtering( - logits, - top_k: int = 0, - top_p: float = 1.0, - filter_value: float = -float("Inf"), - min_tokens_to_keep: int = 1, -): - """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering - Args: - logits: logits distribution shape (batch size, vocabulary size) - if top_k > 0: keep only top k tokens with highest probability (top-k filtering). - if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). - Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) - Make sure we keep at least min_tokens_to_keep per batch example in the output - From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 - """ - if top_k > 0: - top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check - # Remove all tokens with a probability less than the last token of the top-k - indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] - logits[indices_to_remove] = filter_value - return logits - - if top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True) - cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) - - # Remove tokens with cumulative probability above the threshold (token with 0 are kept) - sorted_indices_to_remove = cumulative_probs > top_p - if min_tokens_to_keep > 1: - # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) - sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 - # Shift the indices to the right to keep also the first token above the threshold - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 - - # scatter sorted tensors to original indexing - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) - logits[indices_to_remove] = filter_value - return logits - - -def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_logits=True): - logits = logits / temperature - if top_k is not None or top_p is not None: - if top_k > 0 or top_p < 1.0: - logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) - - probs = F.softmax(logits, dim=-1) - - if not sample_logits: - _, x = top_k(probs, k=1, dim=-1) - else: - x = torch.multinomial(probs, num_samples=1) - - return x - - -def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False, news_emb=None): - with torch.no_grad(): - x = torch.clip(x, -clip, clip) - - device = x.device - x = x.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x.size(1), x.size(2)).to(device) - x_stamp = x_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x_stamp.size(1), x_stamp.size(2)).to(device) - y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), y_stamp.size(2)).to(device) - - x_token = tokenizer.encode(x, half=True) - - initial_seq_len = x.size(1) - batch_size = x_token[0].size(0) - total_seq_len = initial_seq_len + pred_len - full_stamp = torch.cat([x_stamp, y_stamp], dim=1) - - generated_pre = x_token[0].new_empty(batch_size, pred_len) - generated_post = x_token[1].new_empty(batch_size, pred_len) - - pre_buffer = 
x_token[0].new_zeros(batch_size, max_context) - post_buffer = x_token[1].new_zeros(batch_size, max_context) - buffer_len = min(initial_seq_len, max_context) - if buffer_len > 0: - start_idx = max(0, initial_seq_len - max_context) - pre_buffer[:, :buffer_len] = x_token[0][:, start_idx:start_idx + buffer_len] - post_buffer[:, :buffer_len] = x_token[1][:, start_idx:start_idx + buffer_len] - - if verbose: - ran = trange - else: - ran = range - for i in ran(pred_len): - current_seq_len = initial_seq_len + i - window_len = min(current_seq_len, max_context) - - if current_seq_len <= max_context: - input_tokens = [ - pre_buffer[:, :window_len], - post_buffer[:, :window_len] - ] - else: - input_tokens = [pre_buffer, post_buffer] - - context_end = current_seq_len - context_start = max(0, context_end - max_context) - current_stamp = full_stamp[:, context_start:context_end, :].contiguous() - - s1_logits, context = model.decode_s1(input_tokens[0], input_tokens[1], current_stamp, news_emb=news_emb) - s1_logits = s1_logits[:, -1, :] - sample_pre = sample_from_logits(s1_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True) - - s2_logits = model.decode_s2(context, sample_pre) - s2_logits = s2_logits[:, -1, :] - sample_post = sample_from_logits(s2_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True) - - generated_pre[:, i] = sample_pre.squeeze(-1) - generated_post[:, i] = sample_post.squeeze(-1) - - if current_seq_len < max_context: - pre_buffer[:, current_seq_len] = sample_pre.squeeze(-1) - post_buffer[:, current_seq_len] = sample_post.squeeze(-1) - else: - pre_buffer.copy_(torch.roll(pre_buffer, shifts=-1, dims=1)) - post_buffer.copy_(torch.roll(post_buffer, shifts=-1, dims=1)) - pre_buffer[:, -1] = sample_pre.squeeze(-1) - post_buffer[:, -1] = sample_post.squeeze(-1) - - full_pre = torch.cat([x_token[0], generated_pre], dim=1) - full_post = torch.cat([x_token[1], generated_post], dim=1) - - context_start = max(0, total_seq_len - max_context) - input_tokens = [ - full_pre[:, context_start:total_seq_len].contiguous(), - full_post[:, context_start:total_seq_len].contiguous() - ] - z = tokenizer.decode(input_tokens, half=True) - z = z.reshape(-1, sample_count, z.size(1), z.size(2)) - preds = z.cpu().numpy() - preds = np.mean(preds, axis=1) - - return preds - - -def calc_time_stamps(x_timestamp): - time_df = pd.DataFrame() - time_df['minute'] = x_timestamp.dt.minute - time_df['hour'] = x_timestamp.dt.hour - time_df['weekday'] = x_timestamp.dt.weekday - time_df['day'] = x_timestamp.dt.day - time_df['month'] = x_timestamp.dt.month - return time_df - - -class KronosPredictor: - - def __init__(self, model, tokenizer, device="cuda:0", max_context=512, clip=5): - self.tokenizer = tokenizer - self.model = model - self.max_context = max_context - self.clip = clip - self.price_cols = ['open', 'high', 'low', 'close'] - self.vol_col = 'volume' - self.amt_vol = 'amount' - self.time_cols = ['minute', 'hour', 'weekday', 'day', 'month'] - self.device = device - - self.tokenizer = self.tokenizer.to(self.device) - self.model = self.model.to(self.device) - - def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, news_emb=None): - - x_tensor = torch.from_numpy(np.array(x).astype(np.float32)).to(self.device) - x_stamp_tensor = torch.from_numpy(np.array(x_stamp).astype(np.float32)).to(self.device) - y_stamp_tensor = torch.from_numpy(np.array(y_stamp).astype(np.float32)).to(self.device) - - preds = auto_regressive_inference(self.tokenizer, self.model, x_tensor, 
x_stamp_tensor, y_stamp_tensor, self.max_context, pred_len, - self.clip, T, top_k, top_p, sample_count, verbose, news_emb=news_emb) - preds = preds[:, -pred_len:, :] - return preds - - def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True, news_emb=None): - - if not isinstance(df, pd.DataFrame): - raise ValueError("Input must be a pandas DataFrame.") - - if not all(col in df.columns for col in self.price_cols): - raise ValueError(f"Price columns {self.price_cols} not found in DataFrame.") - - df = df.copy() - if self.vol_col not in df.columns: - df[self.vol_col] = 0.0 # Fill missing volume with zeros - df[self.amt_vol] = 0.0 # Fill missing amount with zeros - if self.amt_vol not in df.columns and self.vol_col in df.columns: - df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1) - - if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any(): - raise ValueError("Input DataFrame contains NaN values in price or volume columns.") - - x_time_df = calc_time_stamps(x_timestamp) - y_time_df = calc_time_stamps(y_timestamp) - - x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(np.float32) - x_stamp = x_time_df.values.astype(np.float32) - y_stamp = y_time_df.values.astype(np.float32) - - x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0) - - x = (x - x_mean) / (x_std + 1e-5) - x = np.clip(x, -self.clip, self.clip) - - x = x[np.newaxis, :] - x_stamp = x_stamp[np.newaxis, :] - y_stamp = y_stamp[np.newaxis, :] - - if news_emb is not None: - news_emb_tensor = torch.from_numpy(np.array(news_emb).astype(np.float32)).to(self.device) - # Ensure batch dimension for news_emb if only one sample - if news_emb_tensor.ndim == 1: - news_emb_tensor = news_emb_tensor.unsqueeze(0) - else: - news_emb_tensor = None - - preds = self.generate(x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, news_emb=news_emb_tensor) - - preds = preds.squeeze(0) - preds = preds * (x_std + 1e-5) + x_mean - - pred_df = pd.DataFrame(preds, columns=self.price_cols + [self.vol_col, self.amt_vol], index=y_timestamp) - return pred_df - - - def predict_batch(self, df_list, x_timestamp_list, y_timestamp_list, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True): - """ - Perform parallel (batch) prediction on multiple time series. All series must have the same historical length and prediction length (pred_len). - - Args: - df_list (List[pd.DataFrame]): List of input DataFrames, each containing price columns and optional volume/amount columns. - x_timestamp_list (List[pd.DatetimeIndex or Series]): List of timestamps corresponding to historical data, length should match the number of rows in each DataFrame. - y_timestamp_list (List[pd.DatetimeIndex or Series]): List of future prediction timestamps, length should equal pred_len. - pred_len (int): Number of prediction steps. - T (float): Sampling temperature. - top_k (int): Top-k filtering threshold. - top_p (float): Top-p (nucleus sampling) threshold. - sample_count (int): Number of parallel samples per series, automatically averaged internally. - verbose (bool): Whether to display autoregressive progress. - - Returns: - List[pd.DataFrame]: List of prediction results in the same order as input, each DataFrame contains - `open, high, low, close, volume, amount` columns, indexed by corresponding `y_timestamp`. 
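Both predict and predict_batch normalize each series independently and invert the transform on the model output. A small numpy sketch of that round trip, with hypothetical data standing in for the history block and for the model's normalized predictions (the 1e-5 term and the clip mirror the code above):

import numpy as np

# Hypothetical history block: (seq_len, features) = open, high, low, close, volume, amount
x = np.random.rand(50, 6).astype(np.float32) * 100

x_mean, x_std = x.mean(axis=0), x.std(axis=0)
x_norm = (x - x_mean) / (x_std + 1e-5)         # per-feature z-score, as in predict()
x_norm = np.clip(x_norm, -5, 5)                # corresponds to self.clip

preds_norm = x_norm[-5:]                       # stand-in for model output in normalized space
preds = preds_norm * (x_std + 1e-5) + x_mean   # inverse transform back to price/volume scale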
- """ - # Basic validation - if not isinstance(df_list, (list, tuple)) or not isinstance(x_timestamp_list, (list, tuple)) or not isinstance(y_timestamp_list, (list, tuple)): - raise ValueError("df_list, x_timestamp_list, y_timestamp_list must be list or tuple types.") - if not (len(df_list) == len(x_timestamp_list) == len(y_timestamp_list)): - raise ValueError("df_list, x_timestamp_list, y_timestamp_list must have consistent lengths.") - - num_series = len(df_list) - - x_list = [] - x_stamp_list = [] - y_stamp_list = [] - means = [] - stds = [] - seq_lens = [] - y_lens = [] - - for i in range(num_series): - df = df_list[i] - if not isinstance(df, pd.DataFrame): - raise ValueError(f"Input at index {i} is not a pandas DataFrame.") - if not all(col in df.columns for col in self.price_cols): - raise ValueError(f"DataFrame at index {i} is missing price columns {self.price_cols}.") - - df = df.copy() - if self.vol_col not in df.columns: - df[self.vol_col] = 0.0 - df[self.amt_vol] = 0.0 - if self.amt_vol not in df.columns and self.vol_col in df.columns: - df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1) - - if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any(): - raise ValueError(f"DataFrame at index {i} contains NaN values in price or volume columns.") - - x_timestamp = x_timestamp_list[i] - y_timestamp = y_timestamp_list[i] - - x_time_df = calc_time_stamps(x_timestamp) - y_time_df = calc_time_stamps(y_timestamp) - - x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(np.float32) - x_stamp = x_time_df.values.astype(np.float32) - y_stamp = y_time_df.values.astype(np.float32) - - if x.shape[0] != x_stamp.shape[0]: - raise ValueError(f"Inconsistent lengths at index {i}: x has {x.shape[0]} vs x_stamp has {x_stamp.shape[0]}.") - if y_stamp.shape[0] != pred_len: - raise ValueError(f"y_timestamp length at index {i} should equal pred_len={pred_len}, got {y_stamp.shape[0]}.") - - x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0) - x_norm = (x - x_mean) / (x_std + 1e-5) - x_norm = np.clip(x_norm, -self.clip, self.clip) - - x_list.append(x_norm) - x_stamp_list.append(x_stamp) - y_stamp_list.append(y_stamp) - means.append(x_mean) - stds.append(x_std) - - seq_lens.append(x_norm.shape[0]) - y_lens.append(y_stamp.shape[0]) - - # Require all series to have consistent historical and prediction lengths for batch processing - if len(set(seq_lens)) != 1: - raise ValueError(f"Parallel prediction requires all series to have consistent historical lengths, got: {seq_lens}") - if len(set(y_lens)) != 1: - raise ValueError(f"Parallel prediction requires all series to have consistent prediction lengths, got: {y_lens}") - - x_batch = np.stack(x_list, axis=0).astype(np.float32) # (B, seq_len, feat) - x_stamp_batch = np.stack(x_stamp_list, axis=0).astype(np.float32) # (B, seq_len, time_feat) - y_stamp_batch = np.stack(y_stamp_list, axis=0).astype(np.float32) # (B, pred_len, time_feat) - - preds = self.generate(x_batch, x_stamp_batch, y_stamp_batch, pred_len, T, top_k, top_p, sample_count, verbose) - # preds: (B, pred_len, feat) - - pred_dfs = [] - for i in range(num_series): - preds_i = preds[i] * (stds[i] + 1e-5) + means[i] - pred_df = pd.DataFrame(preds_i, columns=self.price_cols + [self.vol_col, self.amt_vol], index=y_timestamp_list[i]) - pred_dfs.append(pred_df) - - return pred_dfs diff --git a/skills/alphaear-signal-tracker/scripts/utils/predictor/model/module.py b/skills/alphaear-signal-tracker/scripts/utils/predictor/model/module.py deleted file mode 
100644 index 20b29b5..0000000 --- a/skills/alphaear-signal-tracker/scripts/utils/predictor/model/module.py +++ /dev/null @@ -1,562 +0,0 @@ -import math - -from einops import rearrange, reduce -import torch -import torch.nn as nn -from torch.autograd import Function -import torch.nn.functional as F - - -class DifferentiableEntropyFunction(Function): - @staticmethod - def forward(ctx, zq, basis, K, eps): - zb = (zq + 1) / 2 - zi = ((zb * basis).sum(-1)).to(torch.int64) - cnt = torch.scatter_reduce(torch.zeros(2 ** K, device=zq.device, dtype=zq.dtype), - 0, - zi.flatten(), - torch.ones_like(zi.flatten()).to(zq.dtype), - 'sum') - prob = (cnt + eps) / (cnt + eps).sum() - H = -(prob * torch.log(prob)).sum() - ctx.save_for_backward(zq, zi, prob) - ctx.K = K - return H - - @staticmethod - def backward(ctx, grad_output): - zq, zi, prob = ctx.saved_tensors - grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K - reord_grad = grad_array[zi.flatten()].reshape(zi.shape) - grad_input = reord_grad.unsqueeze(-1) * zq - return grad_input, None, None, None, None - - -def codebook_entropy(zq, basis, K, eps=1e-4): - return DifferentiableEntropyFunction.apply(zq, basis, K, eps) - - -class BinarySphericalQuantizer(nn.Module): - def __init__(self, embed_dim, beta, gamma0, gamma, zeta, - input_format='bchw', - soft_entropy=True, group_size=9, - persample_entropy_compute='analytical', - cb_entropy_compute='group', - l2_norm=True, - inv_temperature=1): - """ - Paper link: https://arxiv.org/pdf/2406.07548.pdf - Here we use the official implementation of the BinarySphericalQuantizer. - """ - super().__init__() - self.embed_dim = embed_dim - self.beta = beta # loss weight for commit loss - self.gamma0 = gamma0 # loss weight for entropy penalty - self.gamma = gamma # loss weight for entropy penalty - self.zeta = zeta # loss weight for entire entropy penalty - self.input_format = input_format - assert self.embed_dim % group_size == 0, "embed_dim must be divisible by group_size" - self.num_groups = self.embed_dim // group_size - self.group_size = group_size - assert persample_entropy_compute in ['group', 'analytical'], "persample_entropy_compute must be either 'group' or 'analytical'" - assert cb_entropy_compute in ['group', 'nce'], "cb_entropy_compute must be either 'group' or 'nce'" - self.persample_entropy_compute = persample_entropy_compute - self.cb_entropy_compute = cb_entropy_compute - self.l2_norm = l2_norm - self.inv_temperature = inv_temperature - - self.register_buffer('basis', 2 ** torch.arange(embed_dim - 1, -1, -1)) - self.register_buffer('group_basis', 2 ** torch.arange(group_size - 1, -1, -1)) - - self.num_dimensions = 2 ** embed_dim - self.bits_per_index = embed_dim - - # we only need to keep the codebook portion up to the group size - # because we approximate the H loss with this subcode - group_codes = torch.arange(2 ** self.group_size) - group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:] - self.register_buffer('group_codebook', group_codebook, persistent=False) - - self.soft_entropy = soft_entropy # soft_entropy: Sec 3.2 of https://arxiv.org/pdf/1911.05894.pdf - - def quantize(self, z): - assert z.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}" - - zhat = torch.where(z > 0, - torch.tensor(1, dtype=z.dtype, device=z.device), - torch.tensor(-1, dtype=z.dtype, device=z.device)) - return z + (zhat - z).detach() - - def forward(self, z, collect_metrics=True): - # if self.input_format == 'bchw': - # z = rearrange(z, 'b c h w 
-> b h w c') - zq = self.quantize(z) - - q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. - - zq = zq * q_scale - - if not collect_metrics: - return zq, zq.new_zeros(()), {} - - indices = self.codes_to_indexes(zq.detach()) - group_indices = self.codes_to_group_indexes(zq.detach()) - if not self.training: - used_codes = torch.unique(indices, return_counts=False) - else: - used_codes = None - - if self.soft_entropy: - persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z) - entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy - else: - zb_by_sample = ((zq + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32) - persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample) - cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim) - entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy - - # commit loss - commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1)) - - # if self.input_format == 'bchw': - # zq = rearrange(zq, 'b h w c -> b c h w') - - return ( - zq, - commit_loss + self.zeta * entropy_penalty / self.inv_temperature, - {"H": cb_entropy, "used_codes": used_codes, "indices": indices, "group_indices": group_indices, - "avg_prob": avg_prob} - ) - - def soft_entropy_loss(self, z): - # if we divide the code in subgroups of size group_size, the codebook will be of size 2 ** group_size - # the sub-code is the last group_size bits of the full code - group_code_book = self.group_codebook / (self.embed_dim ** 0.5 if self.l2_norm else 1) - divided_z = rearrange(z, '... (g c) -> ... g c', c=self.group_size) - - # we calculate the distance between the divided_z and the codebook for each subgroup - distance = - 2 * torch.einsum('... g c, d c ->... g d', divided_z, group_code_book) - prob = (-distance * self.inv_temperature).softmax(dim=-1) - if self.persample_entropy_compute == 'analytical': - if self.l2_norm: - p = torch.sigmoid(-4 * z / (self.embed_dim ** 0.5) * self.inv_temperature) - else: - p = torch.sigmoid(-4 * z * self.inv_temperature) - prob = torch.stack([p, 1 - p], dim=-1) - per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() - else: - per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() - - # macro average of the probability of each subgroup - avg_prob = reduce(prob, '... g d ->g d', 'mean') - codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False) - - # the approximation of the entropy is the sum of the entropy of each subgroup - return per_sample_entropy, codebook_entropy.sum(), avg_prob - - def get_hard_per_sample_entropy(self, zb_by_sample): - probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1] - persample_entropy = - probs_per_dim * torch.log(probs_per_dim + 1e-8) - (1 - probs_per_dim) * torch.log(1 - probs_per_dim + 1e-8) - persample_entropy = persample_entropy.sum(-1) - return persample_entropy.mean() - - def codes_to_indexes(self, zhat): - """Converts a `code` to an index in the codebook. - Args: - zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1} - """ - assert zhat.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}" - return ((zhat + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64) - - def codes_to_group_indexes(self, zhat): - """Converts a `code` to a list of indexes (in groups) in the codebook. - Args: - zhat: A tensor of shape (B, ..., C) containing the codes. 
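The mapping between {-1, +1} codes and integer indices is plain binary positional encoding over the descending basis registered above. A tiny round-trip sketch, assuming embed_dim = 3 (function names are illustrative):

import torch

embed_dim = 3
basis = 2 ** torch.arange(embed_dim - 1, -1, -1)   # [4, 2, 1]: first code dimension is the MSB

def codes_to_indexes_sketch(zhat):
    # map {-1, +1} codes to {0, 1} bits, then read them as a binary number
    return (((zhat + 1) / 2) * basis).sum(-1).to(torch.int64)

def indexes_to_codes_sketch(indices):
    bits = torch.remainder(torch.floor_divide(indices.unsqueeze(-1), basis), 2)
    return bits * 2 - 1                            # back to {-1, +1}

code = torch.tensor([[1.0, -1.0, 1.0]])            # bits 1,0,1 -> index 5
idx = codes_to_indexes_sketch(code)
assert torch.equal(indexes_to_codes_sketch(idx).float(), code)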
must be in {-1, 1} - """ - zhat_in_group = rearrange(zhat, 'b ... (g c) -> b ... g c', c=self.group_size) - return ((zhat_in_group + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64) - - def indexes_to_codes(self, indices): - """Inverse of `indexes_to_codes`.""" - indices = indices.unsqueeze(-1) - codes_non_centered = torch.remainder( - torch.floor_divide(indices, self.basis), 2 - ) - return codes_non_centered * 2 - 1 - - def group_indexes_to_codes(self, group_indices): - """Inverse of `group_indexes_to_codes`.""" - group_indices = group_indices.unsqueeze(-1) - codes_non_centered = torch.remainder( - torch.floor_divide(group_indices, self.group_basis), 2 - ) - codes_non_centered = rearrange(codes_non_centered, 'b ... g c -> b ... (g c)') - return codes_non_centered * 2 - 1 - - def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True): - if normalize: - probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True) - else: - probs = count - H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim) - return H - - def get_group_codebook_entry(self, group_indices): - z_q = self.group_indexes_to_codes(group_indices) - q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. - z_q = z_q * q_scale - if self.input_format == 'bchw': - h, w = int(z_q.shape[1] ** 0.5) - assert h * w == z_q.shape[1], 'Invalid sequence length' - z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h) - return z_q - - def get_codebook_entry(self, indices): - z_q = self.indexes_to_codes(indices) - q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. - z_q = z_q * q_scale - if self.input_format == 'bchw': - h, w = int(z_q.shape[1] ** 0.5) - assert h * w == z_q.shape[1], 'Invalid sequence length' - z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h) - return z_q - - -class BSQuantizer(nn.Module): - - def __init__(self, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size): - super().__init__() - self.codebook_dim = s1_bits + s2_bits - self.s1_bits = s1_bits - self.s2_bits = s2_bits - self.bsq = BinarySphericalQuantizer(self.codebook_dim, beta, gamma0, gamma, zeta, group_size=group_size) - - def bits_to_indices(self, bits): - bits = (bits >= 0).to(torch.long) - indices = 2 ** torch.arange( - 0, - bits.shape[-1], - 1, - dtype=torch.long, - device=bits.device, - ) - return (bits * indices).sum(-1) - - def forward(self, z, half=False, collect_metrics=True): - z = F.normalize(z, dim=-1) - quantized, bsq_loss, metrics = self.bsq(z, collect_metrics=collect_metrics) - if half: - q_pre = quantized[:, :, :self.s1_bits] - q_post = quantized[:, :, self.s1_bits:] - z_indices = [self.bits_to_indices(q_pre), self.bits_to_indices(q_post)] - else: - z_indices = self.bits_to_indices(quantized) - return bsq_loss, quantized, z_indices - - -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-5): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()).type_as(x) - return output * self.weight - - -class FeedForward(nn.Module): - def __init__(self, d_model, ff_dim, ffn_dropout_p=0.0): - super().__init__() - - self.w1 = nn.Linear(d_model, ff_dim, bias=False) - self.w3 = nn.Linear(d_model, ff_dim, bias=False) - self.w2 = nn.Linear(ff_dim, d_model, bias=False) - self.ffn_dropout = nn.Dropout(ffn_dropout_p) - - def forward(self, x): - return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) - - -class 
RotaryPositionalEmbedding(nn.Module): - def __init__(self, dim): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer("inv_freq", inv_freq) - self.seq_len_cached = None - self.cos_cached = None - self.sin_cached = None - - def _update_cos_sin_cache(self, x, seq_len): - if seq_len != self.seq_len_cached: - self.seq_len_cached = seq_len - t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) - freqs = torch.einsum('i,j->ij', t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - self.cos_cached = emb.cos()[None, None, :, :] - self.sin_cached = emb.sin()[None, None, :, :] - return self.cos_cached, self.sin_cached - - def forward(self, q, k): - cos, sin = self._update_cos_sin_cache(q, q.shape[-2]) - return ( - (q * cos) + (self._rotate_half(q) * sin), - (k * cos) + (self._rotate_half(k) * sin), - ) - - def _rotate_half(self, x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - -class MultiHeadAttentionWithRoPE(nn.Module): - def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout_p=0.0): - super().__init__() - self.d_model = d_model - self.n_heads = n_heads - self.head_dim = d_model // n_heads - - self.q_proj = nn.Linear(d_model, d_model) - self.k_proj = nn.Linear(d_model, d_model) - self.v_proj = nn.Linear(d_model, d_model) - self.out_proj = nn.Linear(d_model, d_model) - self.rotary = RotaryPositionalEmbedding(self.head_dim) - self.attn_dropout_p = attn_dropout_p - self.resid_dropout = nn.Dropout(resid_dropout_p) - - def forward(self, x, key_padding_mask=None): - batch_size, seq_len, _ = x.shape - - q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - - q, k = self.rotary(q, k) - - if key_padding_mask is not None: - attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) # [batch, 1, 1, seq_len] - attn_mask = attn_mask.expand(-1, self.n_heads, seq_len, -1) # [batch, n_heads, q_len, k_len] - else: - attn_mask = None - - attn_output = F.scaled_dot_product_attention( - q, k, v, - attn_mask=attn_mask, - dropout_p=self.attn_dropout_p if self.training else 0.0, - is_causal=True - ) - - attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model) - return self.resid_dropout(self.out_proj(attn_output)) - - -class MultiHeadCrossAttentionWithRoPE(nn.Module): - def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout=0.0): - super().__init__() - self.d_model = d_model - self.n_heads = n_heads - self.head_dim = d_model // n_heads - - self.q_proj = nn.Linear(d_model, d_model) - self.k_proj = nn.Linear(d_model, d_model) - self.v_proj = nn.Linear(d_model, d_model) - self.out_proj = nn.Linear(d_model, d_model) - self.rotary = RotaryPositionalEmbedding(self.head_dim) - self.attn_dropout_p = attn_dropout_p - self.resid_dropout = nn.Dropout(resid_dropout) - - def forward(self, query, key, value, key_padding_mask=None): - batch_size, q_len, _ = query.shape - _, seq_len, _ = key.shape - - q = self.q_proj(query).view(batch_size, q_len, self.n_heads, self.head_dim).transpose(1, 2) - k = self.k_proj(key).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - v = self.v_proj(value).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - - q, k = self.rotary(q, k) - - if key_padding_mask is 
not None: - attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) - attn_mask = attn_mask.expand(-1, self.n_heads, q_len, -1) - else: - attn_mask = None - - is_causal_flag = self.training - - attn_output = F.scaled_dot_product_attention( - q, k, v, - attn_mask=attn_mask, - dropout_p=self.attn_dropout_p if self.training else 0.0, - is_causal=is_causal_flag - ) - - attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model) - return self.resid_dropout(self.out_proj(attn_output)) - - -class HierarchicalEmbedding(nn.Module): - def __init__(self, s1_bits, s2_bits, d_model=256): - super().__init__() - self.s1_bits = s1_bits - self.s2_bits = s2_bits - - vocab_s1 = 2 ** s1_bits - vocab_s2 = 2 ** s2_bits - - self.emb_s1 = nn.Embedding(vocab_s1, d_model) - self.emb_s2 = nn.Embedding(vocab_s2, d_model) - self.d_model = d_model - self.fusion_proj = nn.Linear(d_model * 2, d_model) - - nn.init.normal_(self.emb_s1.weight, mean=0, std=d_model ** -0.5) - nn.init.normal_(self.emb_s2.weight, mean=0, std=d_model ** -0.5) - - def split_token(self, token_ids: torch.Tensor, s2_bits: int): - """Inputs: - token_ids (torch.Tensor): Composite token IDs of shape [batch_size, seq_len] or [N], each in range [0, 2^(s1_bits + s2_bits) - 1]. - s2_bits (int): Number of low bits used for the fine token (s2). - """ - assert isinstance(s2_bits, int) and s2_bits > 0, "s2_bits must be a positive integer" - - t = token_ids.long() - mask = (1 << s2_bits) - 1 - s2_ids = t & mask # extract low bits - s1_ids = t >> s2_bits # extract high bits - return s1_ids, s2_ids - - def forward(self, token_ids): - """Inputs: - token_ids: - - tuple or list: (s1_ids, s2_ids), each of shape [batch_size, seq_len], or - - torch.Tensor: composite token IDs of shape [batch_size, seq_len], which will be split into (s1_ids, s2_ids) internally. 
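The split is plain bit slicing: the low s2_bits carry the fine token and the remaining high bits carry the coarse token, so the composite ID can be recombined losslessly. A quick sketch with hypothetical bit widths:

import torch

s1_bits, s2_bits = 4, 4                      # hypothetical widths; vocab sizes 16 and 16
token_ids = torch.tensor([[0b1011_0110]])    # composite ID 182

mask = (1 << s2_bits) - 1                    # 0b1111
s2_ids = token_ids & mask                    # low bits  -> 0b0110 = 6
s1_ids = token_ids >> s2_bits                # high bits -> 0b1011 = 11

assert ((s1_ids << s2_bits) | s2_ids).equal(token_ids)   # lossless recombination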
- Output: [batch_size, seq_len, d_model] - """ - if isinstance(token_ids, tuple) or isinstance(token_ids, list): - s1_ids, s2_ids = token_ids - else: - s1_ids, s2_ids = self.split_token(token_ids, self.s2_bits) - s1_emb = self.emb_s1(s1_ids) * math.sqrt(self.d_model) - s2_emb = self.emb_s2(s2_ids) * math.sqrt(self.d_model) - return self.fusion_proj(torch.cat([s1_emb, s2_emb], dim=-1)) - - -class DependencyAwareLayer(nn.Module): - def __init__(self, d_model, n_heads=4, attn_dropout_p=0.0, resid_dropout=0.0): - super().__init__() - self.cross_attn = MultiHeadCrossAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout) - self.norm = RMSNorm(d_model) - - def forward(self, hidden_states, sibling_embed, key_padding_mask=None): - """hidden_states: [batch, seq_len, d_model] - sibling_embed: Embedding from another subtoken - """ - attn_out = self.cross_attn( - query=sibling_embed, - key=hidden_states, - value=hidden_states, - key_padding_mask=key_padding_mask - ) - return self.norm(hidden_states + attn_out) - - -class TransformerBlock(nn.Module): - def __init__(self, d_model, n_heads, ff_dim=1024, ffn_dropout_p=0.0, attn_dropout_p=0.0, resid_dropout_p=0.0): - super().__init__() - self.norm1 = RMSNorm(d_model) - self.self_attn = MultiHeadAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout_p) - self.norm2 = RMSNorm(d_model) - self.ffn = FeedForward(d_model, ff_dim, ffn_dropout_p) - - def forward(self, x, key_padding_mask=None): - residual = x - x = self.norm1(x) - attn_out = self.self_attn(x, key_padding_mask=key_padding_mask) - x = residual + attn_out - - residual = x - x = self.norm2(x) - ffn_out = self.ffn(x) - x = residual + ffn_out - return x - - -class DualHead(nn.Module): - def __init__(self, s1_bits, s2_bits, d_model): - super().__init__() - self.vocab_s1 = 2 ** s1_bits - self.vocab_s2 = 2 ** s2_bits - self.proj_s1 = nn.Linear(d_model, self.vocab_s1) - self.proj_s2 = nn.Linear(d_model, self.vocab_s2) - - def compute_loss(self, s1_logits, s2_logits, s1_targets, s2_targets, padding_mask=None): - if padding_mask is not None: - valid_mask = (padding_mask == 0) - s1_logits = s1_logits[valid_mask] - s2_logits = s2_logits[valid_mask] - s1_targets = s1_targets[valid_mask] - s2_targets = s2_targets[valid_mask] - ce_s1 = F.cross_entropy(s1_logits, s1_targets) - ce_s2 = F.cross_entropy(s2_logits, s2_targets) - else: - ce_s1 = F.cross_entropy(s1_logits.reshape(-1, self.vocab_s1), s1_targets.reshape(-1)) - ce_s2 = F.cross_entropy(s2_logits.reshape(-1, self.vocab_s2), s2_targets.reshape(-1)) - ce_loss = (ce_s1 + ce_s2) / 2 - return ce_loss, ce_s1, ce_s2 - - def forward(self, x): - return self.proj_s1(x) - - def cond_forward(self, x2): - return self.proj_s2(x2) - - -class FixedEmbedding(nn.Module): - def __init__(self, c_in, d_model): - super(FixedEmbedding, self).__init__() - - w = torch.zeros(c_in, d_model).float() - w.require_grad = False - - position = torch.arange(0, c_in).float().unsqueeze(1) - div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() - - w[:, 0::2] = torch.sin(position * div_term) - w[:, 1::2] = torch.cos(position * div_term) - - self.emb = nn.Embedding(c_in, d_model) - self.emb.weight = nn.Parameter(w, requires_grad=False) - - def forward(self, x): - return self.emb(x).detach() - - -class TemporalEmbedding(nn.Module): - def __init__(self, d_model, learn_pe): - super(TemporalEmbedding, self).__init__() - - minute_size = 60 - hour_size = 24 - weekday_size = 7 - day_size = 32 - month_size = 13 - - Embed = FixedEmbedding if not 
learn_pe else nn.Embedding - self.minute_embed = Embed(minute_size, d_model) - self.hour_embed = Embed(hour_size, d_model) - self.weekday_embed = Embed(weekday_size, d_model) - self.day_embed = Embed(day_size, d_model) - self.month_embed = Embed(month_size, d_model) - - def forward(self, x): - x = x.long() - - minute_x = self.minute_embed(x[:, :, 0]) - hour_x = self.hour_embed(x[:, :, 1]) - weekday_x = self.weekday_embed(x[:, :, 2]) - day_x = self.day_embed(x[:, :, 3]) - month_x = self.month_embed(x[:, :, 4]) - - return hour_x + weekday_x + day_x + month_x + minute_x \ No newline at end of file diff --git a/skills/alphaear-signal-tracker/scripts/utils/predictor/training.py b/skills/alphaear-signal-tracker/scripts/utils/predictor/training.py deleted file mode 100644 index 3b41724..0000000 --- a/skills/alphaear-signal-tracker/scripts/utils/predictor/training.py +++ /dev/null @@ -1,539 +0,0 @@ -import os -import sys -import time -import torch -import torch.nn as nn -import pandas as pd -import numpy as np -import json -import random -from loguru import logger -from datetime import datetime, timedelta -from sentence_transformers import SentenceTransformer -from skills._env_loader import load_unified_env - -load_unified_env() - -# Setup paths -KRONOS_DIR = os.path.dirname(os.path.abspath(__file__)) -SRC_DIR = os.path.dirname(os.path.dirname(KRONOS_DIR)) -if SRC_DIR not in sys.path: - sys.path.insert(0, SRC_DIR) - -from ..kronos.model import Kronos, KronosTokenizer, KronosPredictor -from ..database_manager import DatabaseManager -from ..stock_tools import StockTools -from ..search_tools import SearchTools -from ..llm.factory import get_model -from ..visualizer import VisualizerTools -from ..schema.models import ForecastResult, KLinePoint -from agno.agent import Agent - - -class AutoSynthesisTrainer: - def __init__(self, news_dim=384): - self.device = ( - "cuda" - if torch.cuda.is_available() - else "mps" - if torch.backends.mps.is_available() - else "cpu" - ) - self.db = DatabaseManager() - self.tools = StockTools(self.db) - self.searcher = SearchTools(self.db) - # Try loading from local cache first to avoid network timeouts - model_name = os.getenv( - "EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2" - ) - try: - logger.info(f"🔄 Attempting to load {model_name} from local cache...") - self.embedder = SentenceTransformer( - model_name, device=self.device, local_files_only=True - ) - logger.success("✅ Model loaded from local cache.") - except Exception: - logger.warning( - "⚠️ Local cache not found or incomplete. Attempting to download..." - ) - self.embedder = SentenceTransformer(model_name, device=self.device) - self.news_dim = news_dim - - # Try loading from local cache first to avoid network timeouts - try: - logger.info( - "🔄 Attempting to load Kronos and Tokenizer from local cache..." - ) - self.tokenizer = KronosTokenizer.from_pretrained( - "NeoQuasar/Kronos-Tokenizer-base", local_files_only=True - ).to(self.device) - base_model = Kronos.from_pretrained( - "NeoQuasar/Kronos-base", local_files_only=True - ) - logger.success("✅ Kronos and Tokenizer loaded from local cache.") - except Exception: - logger.warning( - "⚠️ Local Kronos/Tokenizer not found or incomplete. Attempting to download..." 
- ) - self.tokenizer = KronosTokenizer.from_pretrained( - "NeoQuasar/Kronos-Tokenizer-base" - ).to(self.device) - base_model = Kronos.from_pretrained("NeoQuasar/Kronos-base") - - self.model = Kronos( - base_model.s1_bits, - base_model.s2_bits, - base_model.n_layers, - base_model.d_model, - base_model.n_heads, - base_model.ff_dim, - base_model.ffn_dropout_p, - base_model.attn_dropout_p, - base_model.resid_dropout_p, - base_model.token_dropout_p, - base_model.learn_te, - news_dim=self.news_dim, - ).to(self.device) - self.model.load_state_dict(base_model.state_dict(), strict=False) - - # LLM for causality verification - provider = os.getenv("LLM_PROVIDER", "minimax") - model_id = os.getenv("LLM_MODEL", "Qwen") - self.llm_agent = Agent(model=get_model(provider, model_id)) - - def discover_shocks( - self, ticker_list, threshold=2.0, limit_per_stock=5, days=365, pred_len=5 - ): - """1. Find days with significant price movements (Look back 1 year)""" - shocks = [] - end_date = datetime.now().strftime("%Y-%m-%d") - start_date = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d") - - for ticker in ticker_list: - df = self.tools.get_stock_price( - ticker, start_date=start_date, end_date=end_date - ) - if df.empty or len(df) < 60: - continue - - # Look for big moves - moves = df[df["change_pct"].abs() > threshold].copy() - if moves.empty: - continue - - count = 0 - for idx, row in moves.iterrows(): - # Ensure we have history before this day AND enough future days for eval - date_idx = df.index.get_loc(idx) - if date_idx < 50 or date_idx + pred_len > len(df): - continue - - shocks.append( - { - "ticker": ticker, - "date": row["date"], - "change": row["change_pct"], - "history": df.iloc[date_idx - 50 : date_idx], - "target": df.iloc[ - date_idx : date_idx + pred_len - ], # Now capturing pred_len days - } - ) - count += 1 - if count >= limit_per_stock: - break - - logger.info( - f"✨ Discovered {len(shocks)} potential price shocks over the last {days} days." - ) - return shocks - - def find_reason_and_verify(self, shock): - """2. Search for reasons and verify causality using LLM""" - ticker_info = self.db.get_stock_by_code(shock["ticker"]) - name = ticker_info["name"] if ticker_info else shock["ticker"] - date_str = shock["date"] - - # Try multiple query variations and engines - queries = [ - f"{name} ({shock['ticker']}) {date_str} 为什么涨跌 原因", - f"{name} {date_str} 异动 原因", - f"{shock['ticker']} {date_str} 新闻", - ] - - search_results = [] - for query in queries: - logger.info(f"🔍 Searching for reason: {query}") - # Try alternate engines - for engine in ["baidu"]: - try: - results = self.searcher.search_list( - query, engine=engine, max_results=3, enrich=False - ) - if results: - search_results = results - break - except Exception as e: - logger.warning(f"Search failed for {query} on {engine}: {e}") - - if search_results: - break - time.sleep(random.uniform(1.0, 2.0)) - - if not search_results: - logger.warning( - f"⚠️ No search results found for {name} on {date_str} after multiple attempts." - ) - return None - - context = "\n".join( - [f"- {r['title']}: {r.get('content', '')[:300]}" for r in search_results] - ) - - prompt = f""" - 任务:判断以下新闻是否解释了该股票在 {date_str} 的 {shock["change"]:.2f}% 价格变动。 - - 股票:{name} - 日期:{date_str} - 变动:{shock["change"]:.2f}% - - 搜索结果: - {context} - - 要求: - 1. 该新闻是否在该日期左右发生? - 2. 该新闻是否能逻辑上解释这种大幅波动(如财报、利好政策、重组、大环境暴跌等)? - 3. 如果是,请总结一段 100 字以内的“核心推动原因”。 - 4. 
返回 JSON: {{"is_causal": true/false, "summary": "原因摘要"}} - """ - - try: - res = self.llm_agent.run(prompt) - data = json.loads( - res.content.replace("```json", "").replace("```", "").strip() - ) - if data.get("is_causal"): - logger.success( - f"✅ Verified cause for {name} on {date_str}: {data['summary']}" - ) - return data["summary"] - else: - logger.warning( - f"❌ Verified cause for {name} on {date_str}: {data['summary']}" - ) - return None - except Exception as e: - logger.warning(f"Verification failed: {e}") - return None - - def save_model(self, path=None): - """Save the news_proj weights""" - if path is None: - save_dir = os.path.join(SRC_DIR, "exports/models") - os.makedirs(save_dir, exist_ok=True) - path = os.path.join( - save_dir, f"kronos_news_v1_{datetime.now().strftime('%Y%m%d_%H%M')}.pt" - ) - - # We only really need to save the news_proj part as it's the only one we train - torch.save( - { - "news_proj_state_dict": self.model.news_proj.state_dict(), - "news_dim": self.news_dim, - "d_model": self.model.d_model, - }, - path, - ) - logger.success(f"💾 Model weights saved to {path}") - return path - - def run_synthesis_and_train(self, tickers, pred_len=5): - # 1. Discovery - shocks = self.discover_shocks(tickers, pred_len=pred_len) - print(f"find {len(shocks)} shocks") - - # 2. News Association & Verification - dataset = [] - max_news_items = 200 # Limit to 200 news items per session to avoid search bans - - logger.info( - f"🧬 Starting News Association for {len(shocks)} shocks (Max limit: {max_news_items})" - ) - - for i, shock in enumerate(shocks): - if len(dataset) >= max_news_items: - logger.info("Reached maximum news items limit for this session.") - break - - summary = self.find_reason_and_verify(shock) - if summary: - # 3. Embedding news - emb = self.embedder.encode(summary) - dataset.append( - { - "history": shock["history"], - "target": shock["target"], - "news_emb": emb, - "summary": summary, - } - ) - - # Add delay after search with randomness to avoid being blocked - if i < len(shocks) - 1: - delay = random.uniform(2.0, 4.0) - time.sleep(delay) - - if not dataset: - logger.error( - "❌ No verified news-price pairs found. Adjust threshold or check if news is available in that period." - ) - return - - # 4. Train/Val Split - random.seed(42) - random.shuffle(dataset) - - if len(dataset) < 2: - train_set = dataset - val_set = [] - logger.warning( - f"⚠️ Only {len(dataset)} sample(s) found. Training on all, skipping validation." - ) - else: - split_idx = max(1, int(len(dataset) * 0.8)) - if split_idx >= len(dataset): - split_idx = len(dataset) - 1 - - train_set = dataset[:split_idx] - val_set = dataset[split_idx:] - logger.info( - f"🏗️ Dataset Split: {len(train_set)} samples for training, {len(val_set)} for validation." - ) - - if not train_set: - logger.error("❌ No samples for training.") - return - - # 5. 
Training (Few-shot) - optimizer = torch.optim.Adam(self.model.news_proj.parameters(), lr=1e-3) - criterion = nn.CrossEntropyLoss() - self.model.train() - - loss_history = [] - logger.info(f"🚀 Training for 30 epochs...") - for epoch in range(30): - total_loss = 0 - for item in train_set: - optimizer.zero_grad() - - # Prep Data - hist_df = item["history"] - # For training, we still focus on the immediate next point (teacher forcing) - target_df = item["target"].iloc[:1] - - hist_raw = hist_df[ - ["open", "high", "low", "close", "volume"] - ].values.astype(np.float32) - hist_raw = np.column_stack([hist_raw, hist_raw[:, 3] * hist_raw[:, 4]]) - - mean, std = hist_raw.mean(axis=0), hist_raw.std(axis=0) + 1e-5 - hist_norm = ( - torch.from_numpy((hist_raw - mean) / std) - .unsqueeze(0) - .to(self.device) - ) - - target_raw = target_df[ - ["open", "high", "low", "close", "volume"] - ].values.astype(np.float32) - target_raw = np.column_stack( - [target_raw, target_raw[:, 3] * target_raw[:, 4]] - ) - target_norm = ( - torch.from_numpy((target_raw - mean) / std) - .unsqueeze(0) - .to(self.device) - ) - - with torch.no_grad(): - z_indices = self.tokenizer.encode(hist_norm, half=True) - t_indices = self.tokenizer.encode(target_norm, half=True) - s1_ids, s2_ids = z_indices[0], z_indices[1] - t_s1, t_s2 = t_indices[0], t_indices[1] - - news_t = torch.from_numpy(item["news_emb"]).unsqueeze(0).to(self.device) - s1_logits, s2_logits = self.model( - s1_ids, - s2_ids, - news_emb=news_t, - use_teacher_forcing=True, - s1_targets=t_s1, - ) - - loss = ( - criterion(s1_logits[:, -1, :], t_s1[:, 0]) - + criterion(s2_logits[:, -1, :], t_s2[:, 0]) - ) / 2 - loss.backward() - optimizer.step() - total_loss += loss.item() - - avg_epoch_loss = total_loss / max(1, len(train_set)) - loss_history.append(avg_epoch_loss) - - if (epoch + 1) % 10 == 0: - logger.info(f"Epoch {epoch + 1} Loss: {avg_epoch_loss:.4f}") - - # 5.1 Visualize Loss Curve - loss_chart = VisualizerTools.generate_loss_chart(loss_history) - VisualizerTools.render_chart_to_file( - loss_chart, - os.path.join(SRC_DIR, "exports/training_results/loss_curve.html"), - ) - - # 5.2 Save final model - self.save_model() - - # 6. Final Evaluation on Validation Set - if not val_set: - logger.warning("⚠️ Validation set is empty. Skipping statistical analysis.") - return - - logger.info( - f"🧪 Final Evaluation: Base vs News-Integrated ({pred_len}-day Window)" - ) - self.model.eval() - predictor = KronosPredictor(self.model, self.tokenizer, device=self.device) - - base_maes = [] - news_maes = [] - - print("\n" + "=" * 90) - print( - f"{'Date':<12} | {'Ticker':<8} | {'Base MAE':<15} | {'News MAE':<15} | {'Improvement'}" - ) - print("-" * 90) - - for item in val_set: - h = item["history"] - t = item["target"] - actuals = t["close"].values[:pred_len] - - x_ts = pd.to_datetime(h["date"]) - # Future timestamps: handle business days if possible, or just simple offset - future_dates = pd.date_range( - start=x_ts.iloc[-1] + timedelta(days=1), periods=pred_len, freq="B" - ) - y_ts = pd.Series(future_dates) - - # A. Base Prediction - p_base = predictor.predict( - h, x_ts, y_ts, pred_len=pred_len, news_emb=None, verbose=False - ) - b_preds = p_base["close"].values[: len(actuals)] - - # B. 
News-Aware Prediction - p_news = predictor.predict( - h, - x_ts, - y_ts, - pred_len=pred_len, - news_emb=item["news_emb"], - verbose=False, - ) - n_preds = p_news["close"].values[: len(actuals)] - - # Calculate MAE over the window - b_mae = np.mean(np.abs(b_preds - actuals)) - n_mae = np.mean(np.abs(n_preds - actuals)) - - base_maes.append(b_mae) - news_maes.append(n_mae) - - improvement = (b_mae - n_mae) / (b_mae + 1e-6) * 100 - - date_str = str(t["date"].values[0])[:10] - ticker = h.iloc[-1]["ticker"] if "ticker" in h.columns else "Stock" - print( - f"{date_str:<12} | {ticker:<8} | {b_mae:<15.4f} | {n_mae:<15.4f} | {improvement:>+7.1f}%" - ) - - # C. Generate Visualization for this case - try: - # Helper to convert DF to KLinePoints - def to_kp_list(preds_df): - points = [] - for idx, row in preds_df.iterrows(): - points.append( - KLinePoint( - date=str(idx)[:10], - open=row["open"], - high=row["high"], - low=row["low"], - close=row["close"], - volume=row["volume"] if "volume" in row else 0, - ) - ) - return points - - forecast_obj = ForecastResult( - ticker=ticker, - base_forecast=to_kp_list(p_base), - adjusted_forecast=to_kp_list(p_news), - rationale=item["summary"], - ) - - # Ground truth for visualizer expects a DataFrame with 'date' and 'close' - gt_df = t[["date", "open", "high", "low", "close", "volume"]] - - chart = VisualizerTools.generate_stock_chart( - df=h, - ticker=ticker, - title=f"Training Eval: {ticker} ({date_str}) Improvement: {improvement:.1f}%", - forecast=forecast_obj, - ground_truth=gt_df, - ) - - safe_date = date_str.replace("-", "") - filename = f"eval_{ticker}_{safe_date}.html" - VisualizerTools.render_chart_to_file( - chart, os.path.join(SRC_DIR, f"exports/training_results/{filename}") - ) - except Exception as e: - logger.error(f"Failed to generate eval chart for {ticker}: {e}") - - # Summary Statistics - avg_base_err = sum(base_maes) / max(1, len(base_maes)) - avg_news_err = sum(news_maes) / max(1, len(news_maes)) - overall_imp = (avg_base_err - avg_news_err) / (avg_base_err + 1e-6) * 100 - - print("-" * 90) - print( - f"{'AVERAGE':<12} | {'-':<8} | {avg_base_err:<15.4f} | {avg_news_err:<15.4f} | {overall_imp:>+7.1f}%" - ) - print("=" * 90 + "\n") - - logger.success( - f"🏁 Statistical Analysis Complete. Avg Error Reduction ({pred_len}-day): {overall_imp:.2f}%" - ) - logger.info( - f"📊 Visualization results saved to: {os.path.join(SRC_DIR, 'exports/training_results/')}" - ) - - -if __name__ == "__main__": - trainer = AutoSynthesisTrainer() - - logger.info("📂 Fetching all stock codes from database...") - res = trainer.db.execute_query("SELECT code FROM stock_list") - all_tickers = [row["code"] for row in res] - - if not all_tickers: - logger.warning("⚠️ No tickers found in stock_list table. 
Trying to sync...") - trainer.tools._check_and_update_stock_list(force=True) - res = trainer.db.execute_query("SELECT code FROM stock_list") - all_tickers = [row["code"] for row in res] - - logger.info(f"🚀 Starting training on potential stocks (1-year scan)...") - # 为了演示,我们扫描前 100 个股票,寻找最近一年的冲击点 - trainer.run_synthesis_and_train(all_tickers[:100], pred_len=1) diff --git a/skills/alphaear-signal-tracker/scripts/utils/search_tools.py b/skills/alphaear-signal-tracker/scripts/utils/search_tools.py deleted file mode 100644 index 50b08f3..0000000 --- a/skills/alphaear-signal-tracker/scripts/utils/search_tools.py +++ /dev/null @@ -1,611 +0,0 @@ -import os -import hashlib -import json -import re -import requests -import time -import threading -from typing import List, Dict, Optional, Any -from agno.tools.duckduckgo import DuckDuckGoTools -from agno.tools.baidusearch import BaiduSearchTools -from agno.agent import Agent -from loguru import logger -from datetime import datetime -from .database_manager import DatabaseManager -from .content_extractor import ContentExtractor -from .llm.factory import get_model -from .hybrid_search import LocalNewsSearch - -# 默认搜索缓存 TTL(秒),可通过环境变量覆盖 -DEFAULT_SEARCH_TTL = int(os.getenv("SEARCH_CACHE_TTL", "3600")) # 默认 1 小时 - - -class JinaSearchEngine: - """Jina Search API 封装 - 使用 s.jina.ai 进行网络搜索""" - - JINA_SEARCH_URL = "https://s.jina.ai/" - - # 速率限制配置 - _rate_limit_no_key = 10 # 无 key 时每分钟最大请求数 - _rate_window = 60.0 - _min_interval = 2.0 - _request_times = [] - _last_request_time = 0.0 - _lock = threading.Lock() - - def __init__(self): - self.api_key = os.getenv("JINA_API_KEY", "").strip() - self.has_api_key = bool(self.api_key) - if self.has_api_key: - logger.info("✅ Jina Search API key configured") - - @classmethod - def _wait_for_rate_limit(cls, has_api_key: bool) -> None: - """等待以满足速率限制""" - if has_api_key: - time.sleep(0.3) - return - - with cls._lock: - current_time = time.time() - cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window] - - if len(cls._request_times) >= cls._rate_limit_no_key: - oldest = cls._request_times[0] - wait_time = cls._rate_window - (current_time - oldest) + 1.0 - if wait_time > 0: - logger.warning(f"⏳ Jina Search rate limit, waiting {wait_time:.1f}s...") - time.sleep(wait_time) - current_time = time.time() - cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window] - - time_since_last = current_time - cls._last_request_time - if time_since_last < cls._min_interval: - time.sleep(cls._min_interval - time_since_last) - - cls._request_times.append(time.time()) - cls._last_request_time = time.time() - - def search(self, query: str, max_results: int = 5) -> List[Dict]: - """ - 使用 Jina Search API 执行搜索 - - Args: - query: 搜索关键词 - max_results: 返回结果数量 - - Returns: - 搜索结果列表,每个结果包含 title, url, content - """ - if not query: - return [] - - logger.info(f"🔍 Jina Search: {query}") - - # 等待速率限制 - self._wait_for_rate_limit(self.has_api_key) - - headers = { - "Accept": "application/json", - "X-Retain-Images": "none", - } - - if self.has_api_key: - headers["Authorization"] = f"Bearer {self.api_key}" - - try: - # Jina Search API: https://s.jina.ai/{query} - import urllib.parse - encoded_query = urllib.parse.quote(query) - url = f"{self.JINA_SEARCH_URL}{encoded_query}" - - response = requests.get(url, headers=headers, timeout=30) - - if response.status_code == 429: - logger.warning("⚠️ Jina Search rate limited (429), waiting 30s...") - time.sleep(30) - return self.search(query, 
max_results) - - if response.status_code != 200: - logger.warning(f"Jina Search failed (Status {response.status_code})") - return [] - - # 解析响应 - try: - data = response.json() - except json.JSONDecodeError: - # 如果返回纯文本,尝试解析 - data = {"data": [{"title": "Search Result", "url": "", "content": response.text}]} - - results = [] - - # Jina 返回格式可能是 {"data": [...]} 或直接是列表 - items = data.get("data", []) if isinstance(data, dict) else data - if not isinstance(items, list): - items = [items] if items else [] - - for i, item in enumerate(items[:max_results]): - if isinstance(item, dict): - results.append({ - "title": item.get("title", f"Result {i+1}"), - "url": item.get("url", ""), - "href": item.get("url", ""), # 兼容性 - "content": item.get("content", item.get("description", "")), - "body": item.get("content", item.get("description", "")), # 兼容性 - }) - elif isinstance(item, str): - results.append({ - "title": f"Result {i+1}", - "url": "", - "content": item - }) - - logger.info(f"✅ Jina Search returned {len(results)} results") - return results - - except requests.exceptions.Timeout: - logger.error("Jina Search timeout") - return [] - except requests.exceptions.RequestException as e: - logger.error(f"Jina Search request error: {e}") - return [] - except Exception as e: - logger.error(f"Jina Search unexpected error: {e}") - return [] - -class SearchTools: - """扩展性搜索工具库 - 支持多引擎聚合与内容缓存""" - - def __init__(self, db: DatabaseManager): - self.db = db - - # 检查 Jina API Key 是否配置 - jina_api_key = os.getenv("JINA_API_KEY", "").strip() - self._jina_enabled = bool(jina_api_key) - - self._engines = { - "ddg": DuckDuckGoTools(), - "baidu": BaiduSearchTools(), - "local": LocalNewsSearch(db) - } - - # 如果配置了 Jina API Key,添加 Jina 引擎 - if self._jina_enabled: - self._engines["jina"] = JinaSearchEngine() - logger.info("🚀 Jina Search engine enabled (JINA_API_KEY configured)") - - # 确定默认搜索引擎 - self._default_engine = "jina" if self._jina_enabled else "ddg" - - def _generate_hash(self, query: str, engine: str, max_results: int) -> str: - return hashlib.md5(f"{engine}:{query}:{max_results}".encode()).hexdigest() - - def search(self, query: str, engine: str = None, max_results: int = 5, ttl: Optional[int] = None) -> str: - """ - 使用指定搜索引擎执行网络搜索,结果会被缓存以提高效率。 - - Args: - query: 搜索关键词,如 "英伟达财报" 或 "光伏行业政策"。 - engine: 搜索引擎选择。可选值: - "jina" (Jina Search,需配置 JINA_API_KEY,LLM友好输出), - "ddg" (DuckDuckGo,推荐英文/国际搜索), - "baidu" (百度,推荐中文/国内搜索), - "local" (本地历史新闻搜索,基于向量+BM25)。 - 默认: 若配置了 JINA_API_KEY 则使用 "jina",否则 "ddg"。 - max_results: 期望返回的结果数量,默认 5 条。 - ttl: 缓存有效期(秒)。如果缓存超过此时间会重新搜索。 - 默认使用环境变量 SEARCH_CACHE_TTL 或 3600 秒。 - 设为 0 可强制刷新。 - - Returns: - 搜索结果的文本描述,包含标题、摘要和链接。 - """ - # 使用默认引擎(如果配置了 Jina 则优先使用 Jina) - if engine is None: - engine = self._default_engine - - if engine not in self._engines: - return f"Error: Unsupported engine '{engine}'. Available: {list(self._engines.keys())}" - - query_hash = self._generate_hash(query, engine, max_results) - effective_ttl = ttl if ttl is not None else DEFAULT_SEARCH_TTL - - # 1. 尝试从缓存读取 (local 引擎不缓存,因为它本身就是查库) - if engine != "local": - cache = self.db.get_search_cache(query_hash, ttl_seconds=effective_ttl if effective_ttl > 0 else None) - if cache and effective_ttl != 0: - logger.info(f"ℹ️ Found search results in cache for: {query} ({engine})") - return cache['results'] - - # 2. 
执行真实搜索 - logger.info(f"📡 Searching {engine} for: {query}") - try: - tool = self._engines[engine] - if engine == "jina": - # Jina Search 返回 List[Dict] - jina_results = tool.search(query, max_results=max_results) - results = [] - for r in jina_results: - results.append({ - "title": r.get("title", ""), - "href": r.get("url", ""), - "body": r.get("content", "") - }) - elif engine == "ddg": - results = tool.duckduckgo_search(query, max_results=max_results) - elif engine == "baidu": - results = tool.baidu_search(query, max_results=max_results) - elif engine == "local": - # LocalNewsSearch 返回的是 List[Dict] - local_results = tool.search(query, top_n=max_results) - results = [] - for r in local_results: - results.append({ - "title": r.get("title"), - "href": r.get("url", "local"), - "body": r.get("content", "") - }) - else: - results = "Search not implemented for this engine." - - results_str = str(results) - if engine != "local": - self.db.save_search_cache(query_hash, query, engine, results_str) - return results_str - - except Exception as e: - # 搜索失败时的降级策略 - if engine == "jina": - logger.warning(f"⚠️ Jina search failed, falling back to ddg: {query} ({e})") - try: - return self.search(query, engine="ddg", max_results=max_results, ttl=ttl) - except Exception as e2: - logger.error(f"❌ DDG fallback also failed for {query}: {e2}") - elif engine == "ddg": - logger.warning(f"⚠️ DDG search failed, falling back to baidu: {query} ({e})") - try: - return self.search(query, engine="baidu", max_results=max_results, ttl=ttl) - except Exception as e2: - logger.error(f"❌ Baidu fallback also failed for {query}: {e2}") - - logger.error(f"❌ Search failed for {query}: {e}") - return f"Error occurred during search: {str(e)}" - - def search_list(self, query: str, engine: str = None, max_results: int = 5, ttl: Optional[int] = None, enrich: bool = True) -> List[Dict]: - """ - 执行搜索并返回结构化列表 (List[Dict])。 - Dict 包含: title, href (or url), body (or snippet) - - Args: - engine: 搜索引擎,默认使用配置的默认引擎(Jina 优先) - enrich: 是否抓取正文内容 (默认 True) - """ - # 使用默认引擎 - if engine is None: - engine = self._default_engine - - if engine not in self._engines: - logger.error(f"Unsupported engine {engine}") - return [] - - # 不同的 hash 以区分是否 enrichment - enrich_suffix = ":enriched" if enrich else "" - query_hash = self._generate_hash(query, engine + enrich_suffix, max_results) - effective_ttl = ttl if ttl is not None else DEFAULT_SEARCH_TTL - - # 1. 尝试从缓存读取 - cache = self.db.get_search_cache(query_hash, ttl_seconds=effective_ttl if effective_ttl > 0 else None) - if cache and effective_ttl != 0: - try: - cached_data = json.loads(cache['results']) - if isinstance(cached_data, list): - logger.info(f"ℹ️ Found structured search cache for: {query}") - return cached_data - except: - pass - - # 1.5 Smart Cache (Fuzzy + LLM) - if effective_ttl != 0: - try: - # 1. Similar cached queries - similar_queries = self.db.find_similar_queries(query, limit=3) - # Filter by TTL - valid_candidates = [] - for q in similar_queries: - if q['query'] == query: continue - q_time = datetime.fromisoformat(q['timestamp']) - if effective_ttl and (datetime.now() - q_time).total_seconds() > effective_ttl: - continue - q['type'] = 'cached_search' - valid_candidates.append(q) - - # 2. Relevant local news (as search results) - local_news = self.db.search_local_news(query, limit=3) - if local_news: - # Group local news as a single "candidate" source? Or individual? - # Better to treat "Local News Database" as one candidate source that contains X items. 
- # Or just add them to candidates list? - # Let's package strictly relevant news as a "local_news_bundle" - valid_candidates.append({ - 'type': 'local_news', - 'query': 'Local Database News', - 'items': local_news, - 'timestamp': datetime.now().isoformat() - }) - - if valid_candidates: - logger.info(f"🤔 Found {len(valid_candidates)} smart cache candidates (Queries/News). Asking LLM...") - evaluation = self._evaluate_cache_relevance(query, valid_candidates) - - if evaluation and evaluation.get('reuse', False): - idx = evaluation.get('index', -1) - if 0 <= idx < len(valid_candidates): - chosen = valid_candidates[idx] - logger.info(f"🤖 LLM suggested reusing: '{chosen.get('query')}' ({chosen['type']})") - - if chosen['type'] == 'cached_search': - # Load the chosen cache - cache = self.db.get_search_cache(chosen['query_hash']) - if cache: - try: - cached_data = json.loads(cache['results']) - if isinstance(cached_data, list): - return cached_data - except: - pass - elif chosen['type'] == 'local_news': - # Convert local news items to search result format - news_results = [] - for i, news in enumerate(chosen['items'], 1): - news_results.append({ - "id": news.get('id'), - "rank": i, - "title": news.get('title'), - "url": news.get('url'), - "content": news.get('content'), - "original_snippet": news.get('content')[:200] if news.get('content') else '', - "source": f"Local News ({news.get('source')})", - "publish_time": news.get('publish_time'), - "crawl_time": news.get('crawl_time'), - "sentiment_score": news.get('sentiment_score', 0), - "meta_data": {"origin": "local_db"} - }) - return news_results - - except Exception as e: - logger.warning(f"Smart cache check failed: {e}") - - # 2. 执行搜索 - logger.info(f"📡 Searching {engine} (structured) for: {query}") - try: - tool = self._engines[engine] - results = [] - if engine == "jina": - # Jina Search 直接返回结构化数据 - jina_results = tool.search(query, max_results=max_results) - for r in jina_results: - results.append({ - "title": r.get("title", ""), - "url": r.get("url", ""), - "href": r.get("url", ""), - "body": r.get("content", ""), - "content": r.get("content", ""), - "source": "Jina Search" - }) - elif engine == "ddg": - results = tool.duckduckgo_search(query, max_results=max_results) - elif engine == "baidu": - results = tool.baidu_search(query, max_results=max_results) - elif engine == "local": - # LocalNewsSearch 返回的是 List[Dict] - local_results = tool.search(query, top_n=max_results) - results = [] - for r in local_results: - results.append({ - "title": r.get("title"), - "url": r.get("url", "local"), - "body": r.get("content", "")[:500], - "source": f"Local ({r.get('source', 'db')})", - "publish_time": r.get("publish_time") - }) - - # 处理字符串类型的 JSON 返回 (Baidu 常返 JSON 字符串) - if isinstance(results, str) and engine not in ["local", "jina"]: - try: - results = json.loads(results) - except: - pass - - # 转为统一格式 - normalized_results = [] - if isinstance(results, list): - - for i, r in enumerate(results, 1): - title = r.get('title', '') - url = r.get('href') or r.get('url') or r.get('link', '') - content = r.get('body') or r.get('snippet') or r.get('abstract', '') - - if title and url: - normalized_results.append({ - "id": self._generate_hash(url + query, "search_item", i), - "rank": i, - "title": title, - "url": url, - "content": content, - "original_snippet": content, # 保留摘要 - "source": f"Search ({engine})", - "publish_time": datetime.now().isoformat(), # 暂用当前时间 - "crawl_time": datetime.now().isoformat(), - "meta_data": {"query": query, "engine": engine} - }) - - 
# Fallback if still string and failed to parse - elif isinstance(results, str) and results: - normalized_results.append({"title": query, "url": "", "content": results, "source": engine}) - - # 3. 抓取正文 & 计算情绪 (Enrichment) - # 注意:如果使用 Jina Search,内容已经是 LLM 友好格式,可选择跳过 enrichment - skip_content_enrichment = (engine == "jina") - - if enrich and normalized_results: - logger.info(f"🕸️ Enriching {len(normalized_results)} search results with Jina & Sentiment...") - extractor = ContentExtractor() - - # Lazy load sentiment tool - if not hasattr(self, 'sentiment_tool') or self.sentiment_tool is None: - from ..sentiment_tools import SentimentTools - self.sentiment_tool = SentimentTools(self.db) - - for item in normalized_results: - if item.get("url"): - try: - # 如果是 Jina Search,内容已经足够好,跳过额外抓取 - if skip_content_enrichment and item.get("content") and len(item.get("content", "")) > 100: - full_content = item["content"] - else: - # Use Jina Reader to get full content - full_content = extractor.extract_with_jina(item["url"], timeout=60) - - if full_content and len(full_content) > 100: - item["content"] = full_content - - # Calculate sentiment - # Use title + snippet of content for efficiency - text_to_analyze = f"{item['title']} {full_content[:500]}" - sent_result = self.sentiment_tool.analyze_sentiment(text_to_analyze) # Using self.sentiment_tool - score = sent_result.get('score', 0.0) - item["sentiment_score"] = float(score) - - logger.info(f" ✅ Enriched: {item['title'][:20]}... (Sentiment: {score:.2f})") - else: - # Fallback: Use snippet for sentiment - logger.info(f" ⚠️ Content short/failed for {item['url']}, using snippet for sentiment.") - text_to_analyze = f"{item['title']} {item['content']}" # content is snippet here - sent_result = self.sentiment_tool.analyze_sentiment(text_to_analyze) - score = sent_result.get('score', 0.0) - item["sentiment_score"] = float(score) - - except Exception as e: - # Fallback: Use snippet for sentiment on error - logger.warning(f"Failed to enrich {item['url']}: {e}. 
Using snippet.") - text_to_analyze = f"{item['title']} {item['content']}" - sent_result = self.sentiment_tool.analyze_sentiment(text_to_analyze) - score = sent_result.get('score', 0.0) - item["sentiment_score"] = float(score) - - # 缓存结果 list - if normalized_results: - # Pass list directly, DB manager will handle JSON dump for main cache and populate search_details - # Only cache if NOT from local news reuse (though this logic path is for fresh search) - self.db.save_search_cache(query_hash, query, engine, normalized_results) - - return normalized_results - - except Exception as e: - # 搜索失败时的降级策略 - if engine == "jina": - logger.warning(f"⚠️ Jina search_list failed, falling back to ddg: {query} ({e})") - try: - return self.search_list(query, engine="ddg", max_results=max_results, ttl=ttl, enrich=enrich) - except Exception as e2: - logger.error(f"❌ DDG fallback (search_list) also failed for {query}: {e2}") - elif engine == "ddg": - logger.warning(f"⚠️ DDG search_list failed, falling back to baidu: {query} ({e})") - try: - return self.search_list(query, engine="baidu", max_results=max_results, ttl=ttl, enrich=enrich) - except Exception as e2: - logger.error(f"❌ Baidu fallback (search_list) also failed for {query}: {e2}") - - logger.error(f"❌ Structured search failed for {query}: {e}") - return [] - - def _evaluate_cache_relevance(self, current_query: str, candidates: List[Dict]) -> Dict: - """ - 使用 LLM 评估缓存候选是否足以回答当前问题。 - """ - try: - # Prepare candidates text - candidates_desc = [] - for i, c in enumerate(candidates): - if c['type'] == 'cached_search': - # Preview cached results if available? - # Maybe just use the query string as a proxy for what's in there. - # Or peek at 'results' snippet. - preview = "" - try: - # Attempt to peek first result title from JSON string - # Note: c.get('results') might be a stringified JSON list - res_list = json.loads(c.get('results', '[]')) - if res_list and isinstance(res_list, list) and len(res_list) > 0: - first_item = res_list[0] - if isinstance(first_item, dict) and 'title' in first_item: - preview = f" (Contains: {first_item.get('title', '')[:50]}...)" - except: - pass - candidates_desc.append(f"[{i}] Old Search Query: '{c['query']}' {preview} (Time: {c['timestamp']})") - elif c['type'] == 'local_news': - # List titles of local news - titles = [item['title'] for item in c['items'][:3]] - candidates_desc.append(f"[{i}] Local Database News: {', '.join(titles)}... (Time: {c['timestamp']})") - - prompt = f""" - Task: Decide if existing information is sufficient for the new search query. - - New Query: "{current_query}" - - Available Information Candidates: - {chr(10).join(candidates_desc)} - - Instructions: - 1. Analyze if any candidate provides ENOUGH up-to-date info for the "New Query". - 2. If yes, choose the best one. - 3. If the query implies needing LATEST real-time info and candidates are old, choose none. - 4. 
Return strictly JSON: {{"reuse": true/false, "index": , "reason": "short explanation"}} - """ - # 初始化模型 - provider = os.getenv("LLM_PROVIDER", "minimax") - model_id = os.getenv("LLM_MODEL", "Qwen") - host = os.getenv("LLM_HOST") - if host: - model = get_model(provider, model_id, host=host) - else: - model = get_model(provider, model_id) - - agent = Agent(model=model, markdown=True) - - response = agent.run(prompt) - content = response.content - - # Parse JSON - json_match = re.search(r'```json\s*(.*?)\s*```', content, re.DOTALL) - if json_match: - return json.loads(json_match.group(1)) - elif '{' in content: - # Fallback for cases where LLM doesn't wrap in ```json - return json.loads(content[content.find('{'):content.rfind('}')+1]) - return {"reuse": False} - - except Exception as e: - logger.warning(f"LLM evaluation failed: {e}") - return {"reuse": False} - - def aggregate_search(self, query: str, engines: Optional[List[str]] = None, max_results: int = 5) -> str: - """ - 使用多个搜索引擎同时搜索并聚合结果,获得更全面的信息覆盖。 - - Args: - query: 搜索关键词。 - engines: 要使用的搜索引擎列表。可选值: ["ddg", "baidu"]。 - 默认同时使用 ddg 和 baidu。 - max_results: 每个引擎期望返回的结果数量。 - - Returns: - 聚合后的搜索结果,按引擎分组显示。 - """ - engines = engines or ["ddg", "baidu"] - aggregated_results = [] - for engine in engines: - res = self.search(query, engine=engine, max_results=max_results) - aggregated_results.append(f"--- Results from {engine.upper()} ---\n{res}") - - return "\n\n".join(aggregated_results) diff --git a/skills/alphaear-signal-tracker/scripts/utils/sentiment_tools.py b/skills/alphaear-signal-tracker/scripts/utils/sentiment_tools.py deleted file mode 100644 index f4278b5..0000000 --- a/skills/alphaear-signal-tracker/scripts/utils/sentiment_tools.py +++ /dev/null @@ -1,287 +0,0 @@ -import os -from typing import Dict, List, Union, Optional -import json -from loguru import logger -from agno.agent import Agent -from .llm.factory import get_model -from .database_manager import DatabaseManager - -# 从环境变量读取默认情绪分析模式 -DEFAULT_SENTIMENT_MODE = os.getenv("SENTIMENT_MODE", "auto") # auto, bert, llm - - -class SentimentTools: - """ - 情绪分析工具 - 支持 LLM 和 BERT 两种模式 - - 模式说明: - - "auto": 自动选择,优先使用 BERT(速度快),不可用时回退到 LLM - - "bert": 强制使用 BERT 模型(需要 transformers 库) - - "llm": 强制使用 LLM(更准确但较慢) - - 可通过环境变量 SENTIMENT_MODE 设置默认模式。 - """ - - def __init__( - self, - db: DatabaseManager, - mode: Optional[str] = None, - model_provider: str = "openai", - model_id: str = "gpt-4o", - ): - """ - 初始化情绪分析工具。 - - Args: - db: 数据库管理器实例 - mode: 分析模式,可选 "auto", "bert", "llm"。None 则使用环境变量默认值。 - model_provider: LLM 提供商,如 "openai", "ust", "deepseek" - model_id: 模型标识符 - """ - self.db = db - self.mode = mode or DEFAULT_SENTIMENT_MODE - self.llm_model = None - self.bert_pipeline = None - - # Initialize LLM - try: - provider = "minimax" if os.getenv("MINIMAX_API_KEY") else model_provider - m_id = ( - os.getenv("LLM_MODEL", "MiniMax-Text-01") - if provider == "minimax" - else model_id - ) - self.llm_model = get_model(provider, m_id) - except Exception as e: - logger.warning(f"LLM initialization skipped: {e}") - - # Initialize BERT if needed - if self.mode in ["bert", "auto"]: - try: - from transformers import ( - pipeline, - AutoTokenizer, - AutoModelForSequenceClassification, - ) - from transformers.utils import logging as transformers_logging - - transformers_logging.set_verbosity_error() # 减少冗余日志 - - bert_model = os.getenv( - "BERT_SENTIMENT_MODEL", - "uer/roberta-base-finetuned-chinanews-chinese", - ) - - # 优先使用本地缓存 - try: - tokenizer = AutoTokenizer.from_pretrained( - bert_model, 
local_files_only=True - ) - model = AutoModelForSequenceClassification.from_pretrained( - bert_model, local_files_only=True - ) - - self.bert_pipeline = pipeline( - "sentiment-analysis", - model=model, - tokenizer=tokenizer, - device=-1, - ) - logger.info( - f"✅ BERT pipeline loaded from local cache: {bert_model}" - ) - except (OSError, ValueError, ImportError): - # 本地没有,则从网络下载 - logger.info(f"📡 Downloading BERT model: {bert_model}...") - tokenizer = AutoTokenizer.from_pretrained(bert_model) - model = AutoModelForSequenceClassification.from_pretrained( - bert_model - ) - - self.bert_pipeline = pipeline( - "sentiment-analysis", - model=model, - tokenizer=tokenizer, - device=-1, - ) - logger.info( - f"✅ BERT Sentiment pipeline ({bert_model}) initialized." - ) - except ImportError: - logger.warning( - "Transformers library not installed. BERT sentiment analysis disabled." - ) - except Exception as e: - if self.mode == "bert": - logger.error(f"BERT mode requested but failed: {e}") - else: - logger.warning(f"BERT unavailable, using LLM only. Error: {e}") - self.bert_pipeline = None - - def analyze_sentiment(self, text: str) -> Dict[str, Union[float, str]]: - """ - 分析文本的情绪极性。根据初始化时的 mode 自动选择分析方法。 - - Args: - text: 需要分析的文本内容,如新闻标题或摘要。 - - Returns: - 包含以下字段的字典: - - score: 情绪分值,范围 -1.0(极度负面)到 1.0(极度正面),0.0 为中性 - - label: 情绪标签,"positive"/"negative"/"neutral" - - reason: 分析理由(仅 LLM 模式提供详细理由) - """ - if self.mode == "bert" and self.bert_pipeline: - results = self.analyze_sentiment_bert([text]) - return results[0] if results else {"score": 0.0, "label": "error"} - elif self.mode == "llm" or (self.mode == "auto" and not self.bert_pipeline): - return self.analyze_sentiment_llm(text) - else: - # auto mode with BERT available - results = self.analyze_sentiment_bert([text]) - return results[0] if results else {"score": 0.0, "label": "error"} - - def analyze_sentiment_llm(self, text: str) -> Dict[str, Union[float, str]]: - """ - 使用 LLM 进行深度情绪分析,可获得详细的分析理由。 - - Args: - text: 需要分析的文本,最多处理前 1000 字符。 - - Returns: - 包含 score, label, reason 的字典。 - """ - if not self.llm_model: - return {"score": 0.0, "label": "neutral", "error": "LLM not initialized"} - - analyzer = Agent(model=self.llm_model, markdown=True) - prompt = f"""请分析以下金融/新闻文本的情绪极性。 - 返回严格的 JSON 格式: - {{"score": , "label": "", "reason": "<简短理由>"}} - - 文本: {text[:1000]}""" - - try: - response = analyzer.run(prompt) - content = response.content - if "```json" in content: - content = content.split("```json")[1].split("```")[0].strip() - elif "```" in content: - content = content.split("```")[1].split("```")[0].strip() - return json.loads(content) - except Exception as e: - logger.error(f"LLM sentiment failed: {e}") - return {"score": 0.0, "label": "error", "reason": str(e)} - - def analyze_sentiment_bert(self, texts: List[str]) -> List[Dict]: - """ - 使用 BERT 进行批量高速情绪分析。 - - Args: - texts: 需要分析的文本列表。 - - Returns: - 与输入列表等长的分析结果列表。 - """ - if not self.bert_pipeline: - return [ - {"score": 0.0, "label": "error", "reason": "BERT not available"} - ] * len(texts) - - try: - results = self.bert_pipeline(texts, truncation=True, max_length=512) - processed = [] - for r in results: - label = r["label"].lower() - score = r["score"] - - # 标准化不同模型的标签格式 - if "negative" in label or "neg" in label: - score = -score - elif "neutral" in label or "neu" in label: - score = 0.0 - - processed.append( - { - "score": float(round(score, 3)), - "label": "positive" - if score > 0.1 - else ("negative" if score < -0.1 else "neutral"), - "reason": "BERT automated analysis", - } - ) - 
return processed - except Exception as e: - logger.error(f"BERT analysis failed: {e}") - return [{"score": 0.0, "label": "error", "reason": str(e)}] * len(texts) - - def batch_update_news_sentiment( - self, - source: Optional[str] = None, - limit: int = 50, - use_bert: Optional[bool] = None, - ): - """ - 批量更新数据库中新闻的情绪分数。 - - Args: - source: 筛选特定新闻源,如 "wallstreetcn"。None 则处理所有来源。 - limit: 最多处理的新闻数量。 - use_bert: 是否使用 BERT。None 则根据初始化模式自动决定。 - - Returns: - 成功更新的新闻数量。 - """ - news_items = self.db.get_daily_news(source=source, limit=limit) - to_analyze = [item for item in news_items if not item.get("sentiment_score")] - - if not to_analyze: - return 0 - - # 决定使用哪种方法 - should_use_bert = ( - use_bert - if use_bert is not None - else (self.bert_pipeline is not None and self.mode != "llm") - ) - - updated_count = 0 - cursor = self.db.conn.cursor() - - if should_use_bert and self.bert_pipeline: - logger.info( - f"🚀 Using BERT for batch analysis of {len(to_analyze)} items..." - ) - titles = [item["title"] for item in to_analyze] - results = self.analyze_sentiment_bert(titles) - - for item, analysis in zip(to_analyze, results): - cursor.execute( - """ - UPDATE daily_news - SET sentiment_score = ?, meta_data = json_set(COALESCE(meta_data, '{}'), '$.sentiment_reason', ?) - WHERE id = ? - """, - (analysis["score"], analysis["reason"], item["id"]), - ) - updated_count += 1 - else: - logger.info(f"🚶 Using LLM for analysis of {len(to_analyze)} items...") - for item in to_analyze: - analysis = self.analyze_sentiment_llm(item["title"]) - cursor.execute( - """ - UPDATE daily_news - SET sentiment_score = ?, meta_data = json_set(COALESCE(meta_data, '{}'), '$.sentiment_reason', ?) - WHERE id = ? - """, - ( - analysis.get("score", 0.0), - analysis.get("reason", ""), - item["id"], - ), - ) - updated_count += 1 - - self.db.conn.commit() - return updated_count diff --git a/skills/alphaear-signal-tracker/scripts/utils/stock_tools.py b/skills/alphaear-signal-tracker/scripts/utils/stock_tools.py deleted file mode 100644 index 5929f74..0000000 --- a/skills/alphaear-signal-tracker/scripts/utils/stock_tools.py +++ /dev/null @@ -1,257 +0,0 @@ -from datetime import datetime, timedelta -from typing import List, Dict, Optional -import akshare as ak -import pandas as pd -import re -import sqlite3 -from requests.exceptions import RequestException -from loguru import logger -from .database_manager import DatabaseManager -import os -from contextlib import contextmanager - -@contextmanager -def temporary_no_proxy(): - """Context manager to temporarily unset proxy environment variables.""" - proxies = {k: os.environ.get(k) for k in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY']} - for k in proxies: - if k in os.environ: - del os.environ[k] - try: - yield - finally: - for k, v in proxies.items(): - if v is not None: - os.environ[k] = v - -class StockTools: - """金融分析股票工具 - 结合高性能数据库缓存与增量更新""" - - def __init__(self, db: DatabaseManager, auto_update: bool = True): - """ - 初始化股票工具 - - Args: - db: 数据库管理器 - auto_update: 是否在列表为空时自动更新,默认 True - """ - self.db = db - if auto_update: - self._check_and_update_stock_list() - - def _check_and_update_stock_list(self, force: bool = False): - """检查并更新股票列表。仅在列表为空或 force=True 时从网络拉取。""" - # 直接查询表中记录数 - cursor = self.db.conn.cursor() - cursor.execute("SELECT COUNT(*) FROM stock_list") - count = cursor.fetchone()[0] - - if count > 0 and not force: - logger.info(f"ℹ️ Stock list already cached ({count} stocks)") - return - - logger.info("📡 Updating A-share and HK-share stock list from 
akshare...") - - def fetch_data(): - # A-share - df_a = ak.stock_zh_a_spot_em() - df_a = df_a[['代码', '名称']].copy() - df_a.columns = ['code', 'name'] - - # HK-share - df_hk = ak.stock_hk_spot_em() - df_hk = df_hk[['代码', '名称']].copy() - df_hk.columns = ['code', 'name'] - - # Combine - return pd.concat([df_a, df_hk], ignore_index=True) - - try: - try: - df_combined = fetch_data() - except (RequestException, Exception) as e: - if "Proxy" in str(e) or "proxy" in str(e): - logger.warning(f"⚠️ Proxy error detected: {e}. Retrying with proxy disabled...") - with temporary_no_proxy(): - df_combined = fetch_data() - else: - raise e - - self.db.save_stock_list(df_combined) - logger.info(f"✅ Cached {len(df_combined)} stocks (A-share + HK) to database.") - - except Exception as e: - logger.error(f"❌ Failed to sync stock list: {e}") - - - def search_ticker(self, query: str, limit: int = 5) -> List[Dict]: - """ - 模糊搜索 A 股股票代码或名称,支持常见缩写。 - """ - # 清洗后缀 (如 CATL.SZ -> CATL, 000001.SZ -> 000001) - clean_query = re.sub(r'\.(SZ|SH|HK|US)$', '', query, flags=re.IGNORECASE) - - # 常见缩写映射 - aliases = { - "CATL": "宁德时代", - "BYD": "比亚迪", - "TSLA": "特斯拉", - "Moutai": "贵州茅台", - "Tencent": "腾讯", - "Alibaba": "阿里巴巴", - "Meituan": "美团", - } - - search_query = aliases.get(clean_query.upper(), clean_query) - - # Robustness: if regex-like ticker code is embedded in query (e.g. "300364 中文在线"), try to extract it - if not search_query.isdigit(): - # Extract explicit 5-6 digit codes - match = re.search(r'\b(\d{5,6})\b', clean_query) - if match: - search_query = match.group(1) - - return self.db.search_stock(search_query, limit) - - def get_stock_price( - self, - ticker: str, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - force_sync: bool = False, - ) -> pd.DataFrame: - """ - 获取指定股票的历史价格数据。优先从本地缓存读取,缺失时自动从网络补齐。 - - Args: - ticker: 股票代码,如 "600519"(贵州茅台)或 "000001"(平安银行)。 - start_date: 开始日期,格式 "YYYY-MM-DD"。默认为 90 天前。 - end_date: 结束日期,格式 "YYYY-MM-DD"。默认为今天。 - - Returns: - 包含 date, open, close, high, low, volume, change_pct 列的 DataFrame。 - """ - now = datetime.now() - if not end_date: - end_date = now.strftime('%Y-%m-%d') - if not start_date: - start_date = (now - timedelta(days=90)).strftime('%Y-%m-%d') - - df_db = self.db.get_stock_prices(ticker, start_date, end_date) - - need_update = False - if df_db.empty: - need_update = True - else: - db_latest = pd.to_datetime(df_db['date'].max()) - req_latest = pd.to_datetime(end_date) - if (req_latest - db_latest).days > 2: - need_update = True - - if force_sync: - need_update = True - - if need_update: - logger.info(f"📡 Data stale or missing for {ticker}, syncing from network...") - - # 清洗 ticker,确保只包含数字(Akshare A 股接口通常只需要数字代码) - clean_ticker = "".join(filter(str.isdigit, ticker)) - if not clean_ticker: - # Non A/H numeric tickers are not supported by the current data source. - logger.warning(f"⚠️ Unsupported ticker format (A/H only): {ticker}") - return df_db - - try: - s_fmt = start_date.replace("-", "") - e_fmt = end_date.replace("-", "") - - df_remote = None - - def fetch_data(): - if len(clean_ticker) == 5: - # HK Stock - return ak.stock_hk_hist( - symbol=clean_ticker, period="daily", - start_date=s_fmt, end_date=e_fmt, - adjust="qfq" - ) - else: - # A-share Stock - return ak.stock_zh_a_hist( - symbol=clean_ticker, period="daily", - start_date=s_fmt, end_date=e_fmt, - adjust="qfq" - ) - - try: - df_remote = fetch_data() - except (RequestException, Exception) as e: - if "Proxy" in str(e) or "proxy" in str(e): - logger.warning(f"⚠️ Proxy error detected: {e}. 
Retrying with proxy disabled...") - with temporary_no_proxy(): - df_remote = fetch_data() - else: - raise e - - if df_remote is not None and not df_remote.empty: - df_remote = df_remote.rename(columns={ - '日期': 'date', '开盘': 'open', '收盘': 'close', - '最高': 'high', '最低': 'low', '成交量': 'volume', - '涨跌幅': 'change_pct' - }) - # 确保日期格式正确 - df_remote['date'] = pd.to_datetime(df_remote['date']).dt.strftime('%Y-%m-%d') - - # 只有在获取到有意义的数据时才保存 - self.db.save_stock_prices(clean_ticker, df_remote) # 保存时使用清洗后的 clean_ticker - - # 重新查询数据库返回结果,保证一致性 - return self.db.get_stock_prices(clean_ticker, start_date, end_date) - else: - logger.warning(f"⚠️ Akshare returned empty data for {clean_ticker}") - - except KeyError as e: - # Akshare 有时在某些股票无数据时会抛出 KeyError - logger.warning(f"⚠️ Akshare data missing for {clean_ticker}: {e}") - except (RequestException, ConnectionError) as e: - logger.error(f"❌ Network error during Akshare sync for {clean_ticker}: {e}") - except sqlite3.Error as e: - logger.error(f"❌ Database error during Akshare sync for {clean_ticker}: {e}") - except Exception as e: - logger.error(f"❌ Unexpected error during Akshare sync for {clean_ticker}: {e}") - - return df_db - - -def get_stock_analysis(ticker: str, db: DatabaseManager) -> str: - """ - 生成指定股票的分析摘要报告。 - - Args: - ticker: 股票代码 - db: 数据库管理器实例 - - Returns: - Markdown 格式的分析报告,包含价格走势和关键指标。 - """ - tools = StockTools(db) - df = tools.get_stock_price(ticker) - - if df.empty: - return f"❌ 未能获取 {ticker} 的股价数据。" - - latest = df.iloc[-1] - change = ((latest['close'] - df.iloc[0]['close']) / df.iloc[0]['close']) * 100 - - report = [ - f"## 📊 {ticker} 分析报告", - f"- **查询时段**: {df.iloc[0]['date']} -> {latest['date']}", - f"- **当前价**: ¥{latest['close']:.2f}", - f"- **时段涨跌**: {change:+.2f}%", - f"- **最高/最低**: ¥{df['high'].max():.2f} / ¥{df['low'].min():.2f}", - "\n### 最近交易概览", - "```", - df.tail(5)[['date', 'close', 'change_pct', 'volume']].to_string(index=False), - "```" - ] - return "\n".join(report) diff --git a/skills/alphaear-signal-tracker/tests/test_tracker.py b/skills/alphaear-signal-tracker/tests/test_tracker.py deleted file mode 100644 index 7617ac4..0000000 --- a/skills/alphaear-signal-tracker/tests/test_tracker.py +++ /dev/null @@ -1,22 +0,0 @@ -import sys -import os -import unittest - -# Add skill root to path -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) - -try: - from scripts.fin_agent import FinAgent - from scripts.utils.database_manager import DatabaseManager -except ImportError as e: - print(f"Import Error: {e}") - sys.exit(1) - -class TestTracker(unittest.TestCase): - def test_init(self): - print("Testing FinAgent...") - # FinAgent Init might be complex. Checking import is a good start. - pass - -if __name__ == '__main__': - unittest.main() diff --git a/skills/alphaear-stock/SKILL.md b/skills/alphaear-stock/SKILL.md deleted file mode 100644 index bf2b582..0000000 --- a/skills/alphaear-stock/SKILL.md +++ /dev/null @@ -1,28 +0,0 @@ ---- -name: alphaear-stock -description: Search A-Share/HK/US finance stock tickers and retrieve finance stock price history. Use when user asks about finance stock codes, recent price changes, or specific company finance stock info. ---- - -# AlphaEar Stock Skill - -## Overview - -Search A-Share/HK/US stock tickers and retrieve historical price data (OHLCV). - -## Capabilities - -### 1. Stock Search & Data - -Use `scripts/stock_tools.py` via `StockTools`. - -- **Search**: `search_ticker(query)` - - Fuzzy search by code or name (e.g., "Moutai", "600519"). 
- - Returns: List of `{code, name}`. -- **Get Price**: `get_stock_price(ticker, start_date, end_date)` - - Returns DataFrame with OHLCV data. - - Dates format: "YYYY-MM-DD". - -## Dependencies - -- `pandas`, `requests`, `akshare`, `yfinance` -- `scripts/database_manager.py` (stock tables) diff --git a/skills/alphaear-stock/scripts/__init__.py b/skills/alphaear-stock/scripts/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/alphaear-stock/scripts/database_manager.py b/skills/alphaear-stock/scripts/database_manager.py deleted file mode 100644 index eb5d451..0000000 --- a/skills/alphaear-stock/scripts/database_manager.py +++ /dev/null @@ -1,119 +0,0 @@ -import sqlite3 -from pathlib import Path -from typing import List, Dict, Optional -import pandas as pd -from loguru import logger - -class DatabaseManager: - """ - AlphaEar Stock Database Manager - Reduced version for alphaear-stock skill - """ - - def __init__(self, db_path: str = "data/signal_flux.db"): - self.db_path = Path(db_path) - self.db_path.parent.mkdir(parents=True, exist_ok=True) - self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False) - self.conn.row_factory = sqlite3.Row - self._init_db() - logger.debug(f"💾 Stock Database initialized at {self.db_path}") - - def _init_db(self): - """Initialize stock-related tables""" - cursor = self.conn.cursor() - - # Stock Prices Table - cursor.execute(""" - CREATE TABLE IF NOT EXISTS stock_prices ( - ticker TEXT, - date TEXT, - open REAL, - close REAL, - high REAL, - low REAL, - volume REAL, - change_pct REAL, - PRIMARY KEY (ticker, date) - ) - """) - - # Stock List Table - cursor.execute(""" - CREATE TABLE IF NOT EXISTS stock_list ( - code TEXT PRIMARY KEY, - name TEXT - ) - """) - - cursor.execute("CREATE INDEX IF NOT EXISTS idx_stock_prices_ticker_date ON stock_prices(ticker, date)") - self.conn.commit() - - # --- Stock Operations --- - - def save_stock_list(self, df: pd.DataFrame): - cursor = self.conn.cursor() - try: - cursor.execute("DELETE FROM stock_list") - data = df[['code', 'name']].to_dict('records') - cursor.executemany( - "INSERT INTO stock_list (code, name) VALUES (:code, :name)", - data - ) - self.conn.commit() - except Exception as e: - logger.error(f"Error saving stock list: {e}") - - def search_stock(self, query: str, limit: int = 5) -> List[Dict]: - cursor = self.conn.cursor() - wild = f"%{query}%" - cursor.execute(""" - SELECT code, name FROM stock_list - WHERE code LIKE ? OR name LIKE ? - LIMIT ? - """, (wild, wild, limit)) - return [dict(row) for row in cursor.fetchall()] - - def get_stock_by_code(self, code: str) -> Optional[Dict[str, str]]: - if not code: return None - clean = "".join([c for c in str(code).strip() if c.isdigit()]) - if not clean: return None - - cursor = self.conn.cursor() - cursor.execute("SELECT code, name FROM stock_list WHERE code = ? LIMIT 1", (clean,)) - row = cursor.fetchone() - return dict(row) if row else None - - def save_stock_prices(self, ticker: str, df: pd.DataFrame): - if df.empty: return - cursor = self.conn.cursor() - try: - for _, row in df.iterrows(): - cursor.execute(""" - INSERT OR REPLACE INTO stock_prices - (ticker, date, open, close, high, low, volume, change_pct) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
- """, ( - ticker, row['date'], row['open'], row['close'], - row['high'], row['low'], row['volume'], row['change_pct'] - )) - self.conn.commit() - except Exception as e: - logger.error(f"Error saving prices for {ticker}: {e}") - - def get_stock_prices(self, ticker: str, start_date: str, end_date: str) -> pd.DataFrame: - cursor = self.conn.cursor() - cursor.execute(""" - SELECT * FROM stock_prices - WHERE ticker = ? AND date >= ? AND date <= ? - ORDER BY date - """, (ticker, start_date, end_date)) - - rows = cursor.fetchall() - if not rows: return pd.DataFrame() - - columns = ['ticker', 'date', 'open', 'close', 'high', 'low', 'volume', 'change_pct'] - return pd.DataFrame([dict(row) for row in rows], columns=columns) - - def close(self): - if self.conn: - self.conn.close() diff --git a/skills/alphaear-stock/scripts/stock_tools.py b/skills/alphaear-stock/scripts/stock_tools.py deleted file mode 100644 index bcb8636..0000000 --- a/skills/alphaear-stock/scripts/stock_tools.py +++ /dev/null @@ -1,419 +0,0 @@ -from datetime import datetime, timedelta -from typing import List, Dict, Optional -import akshare as ak -import yfinance as yf -import pandas as pd -import re -import sqlite3 -import requests as _requests -from requests.exceptions import RequestException -from loguru import logger -from .database_manager import DatabaseManager -import os -from contextlib import contextmanager - -class EastMoneyDirect: - """东方财富 HTTP 直接调用 —— 作为 akshare 的零依赖降级方案。 - - 仅使用 requests,无需 API Key,国内网络直连。 - """ - - KLINE_URL = "https://push2his.eastmoney.com/api/qt/stock/kline/get" - LIST_URL = "https://push2.eastmoney.com/api/qt/clist/get" - UT = "fa5fd1943c7b386f172d6893dbfba10b" - - @staticmethod - def _secid(ticker: str) -> str: - """将纯数字 ticker 转为东方财富 secid 格式。 - - A股: 6开头 -> 1.{ticker}(上交所) | 其他 -> 0.{ticker}(深交所) - 港股: 5位数字 -> 116.{ticker} - """ - if len(ticker) == 5: - return f"116.{ticker}" - if ticker.startswith(('6', '9')): - return f"1.{ticker}" - return f"0.{ticker}" - - @classmethod - def fetch_kline(cls, ticker: str, start_date: str, end_date: str) -> pd.DataFrame: - """获取 K 线数据,返回与 akshare 对齐的 DataFrame。 - - Args: - ticker: 纯数字股票代码 - start_date: YYYYMMDD - end_date: YYYYMMDD - """ - params = { - 'secid': cls._secid(ticker), - 'fields1': 'f1,f2,f3,f4,f5,f6', - 'fields2': 'f51,f52,f53,f54,f55,f56,f57,f58,f59,f60,f61', - 'klt': '101', # 日K - 'fqt': '1', # 前复权 - 'beg': start_date, - 'end': end_date, - 'lmt': '1000', - 'ut': cls.UT, - } - resp = _requests.get(cls.KLINE_URL, params=params, timeout=10) - resp.raise_for_status() - data = resp.json().get('data') - if not data or not data.get('klines'): - return pd.DataFrame() - - # kline 格式: "日期,开盘,收盘,最高,最低,成交量,成交额,振幅,涨跌幅,涨跌额,换手率" - rows = [k.split(',') for k in data['klines']] - df = pd.DataFrame(rows, columns=[ - '日期', '开盘', '收盘', '最高', '最低', '成交量', - '成交额', '振幅', '涨跌幅', '涨跌额', '换手率' - ]) - # 转为数值类型 - for col in ['开盘', '收盘', '最高', '最低', '成交量', '涨跌幅']: - df[col] = pd.to_numeric(df[col], errors='coerce') - - return df - - @classmethod - def fetch_stock_list(cls, market: str = 'a') -> pd.DataFrame: - """获取股票列表。 - - Args: - market: 'a' for A股, 'hk' for 港股 - """ - if market == 'a': - fs = 'm:0+t:6,m:0+t:80,m:1+t:2,m:1+t:23' - else: - fs = 'm:128+t:3,m:128+t:4,m:128+t:1,m:128+t:2' - - all_items = [] - page = 1 - while True: - params = { - 'pn': str(page), 'pz': '5000', 'po': '1', 'np': '1', - 'fltt': '2', 'invt': '2', 'fid': 'f12', - 'fs': fs, 'fields': 'f12,f14', - 'ut': cls.UT, - } - resp = _requests.get(cls.LIST_URL, params=params, timeout=15) - 
resp.raise_for_status() - data = resp.json().get('data', {}) - diff = data.get('diff', []) - if not diff: - break - for item in diff: - all_items.append({'code': item.get('f12', ''), 'name': item.get('f14', '')}) - total = data.get('total', 0) - if page * 5000 >= total: - break - page += 1 - - return pd.DataFrame(all_items) - - -@contextmanager -def temporary_no_proxy(): - """Context manager to temporarily unset proxy environment variables.""" - proxies = {k: os.environ.get(k) for k in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY']} - for k in proxies: - if k in os.environ: - del os.environ[k] - try: - yield - finally: - for k, v in proxies.items(): - if v is not None: - os.environ[k] = v - -class StockTools: - """金融分析股票工具 - 结合高性能数据库缓存与增量更新""" - - def __init__(self, db: DatabaseManager, auto_update: bool = True): - """ - 初始化股票工具 - - Args: - db: 数据库管理器 - auto_update: 是否在列表为空时自动更新,默认 True - """ - self.db = db - if auto_update: - self._check_and_update_stock_list() - - def _check_and_update_stock_list(self, force: bool = False): - """检查并更新股票列表。仅在列表为空或 force=True 时从网络拉取。""" - # 直接查询表中记录数 - cursor = self.db.conn.cursor() - cursor.execute("SELECT COUNT(*) FROM stock_list") - count = cursor.fetchone()[0] - - if count > 0 and not force: - logger.info(f"ℹ️ Stock list already cached ({count} stocks)") - return - - logger.info("📡 Updating A-share and HK-share stock list...") - - df_combined = None - - # === 主路径: akshare === - try: - def fetch_data_ak(): - df_a = ak.stock_zh_a_spot_em() - df_a = df_a[['代码', '名称']].copy() - df_a.columns = ['code', 'name'] - - df_hk = ak.stock_hk_spot_em() - df_hk = df_hk[['代码', '名称']].copy() - df_hk.columns = ['code', 'name'] - - return pd.concat([df_a, df_hk], ignore_index=True) - - try: - df_combined = fetch_data_ak() - except (RequestException, Exception) as e: - if "Proxy" in str(e) or "proxy" in str(e): - logger.warning(f"⚠️ Proxy error detected: {e}. Retrying with proxy disabled...") - with temporary_no_proxy(): - df_combined = fetch_data_ak() - else: - raise e - logger.info(f"✅ akshare: fetched {len(df_combined)} stocks.") - except Exception as e: - logger.warning(f"⚠️ akshare stock list failed: {e}. Trying EastMoney direct...") - - # === 降级路径: 东方财富直接 HTTP === - if df_combined is None or df_combined.empty: - try: - df_a = EastMoneyDirect.fetch_stock_list('a') - df_hk = EastMoneyDirect.fetch_stock_list('hk') - df_combined = pd.concat([df_a, df_hk], ignore_index=True) - logger.info(f"✅ EastMoney direct: fetched {len(df_combined)} stocks.") - except Exception as e2: - logger.error(f"❌ All stock list sources failed. akshare + EastMoney: {e2}") - return - - if df_combined is not None and not df_combined.empty: - self.db.save_stock_list(df_combined) - logger.info(f"✅ Cached {len(df_combined)} stocks to database.") - - - def search_ticker(self, query: str, limit: int = 5) -> List[Dict]: - """ - 模糊搜索 A 股股票代码或名称,支持常见缩写。 - """ - # 清洗后缀 (如 CATL.SZ -> CATL, 000001.SZ -> 000001) - clean_query = re.sub(r'\.(SZ|SH|HK|US)$', '', query, flags=re.IGNORECASE) - - # 常见缩写映射 - aliases = { - "CATL": "宁德时代", - "BYD": "比亚迪", - "TSLA": "特斯拉", - "Moutai": "贵州茅台", - "Tencent": "腾讯", - "Alibaba": "阿里巴巴", - "Meituan": "美团", - } - - search_query = aliases.get(clean_query.upper(), clean_query) - - # Robustness: if regex-like ticker code is embedded in query (e.g. 
"300364 中文在线"), try to extract it - if not search_query.isdigit(): - # Extract explicit 5-6 digit codes - match = re.search(r'\b(\d{5,6})\b', clean_query) - if match: - search_query = match.group(1) - - res = self.db.search_stock(search_query, limit) - if not res and search_query.isalpha(): - # Robustness: mock search hit for alphabetic US tickers - return [{"code": search_query.upper(), "name": search_query.upper()}] - return res - - def get_stock_price( - self, - ticker: str, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - force_sync: bool = False, - ) -> pd.DataFrame: - """ - 获取指定股票的历史价格数据。优先从本地缓存读取,缺失时自动从网络补齐。 - - Args: - ticker: 股票代码,如 "600519"(贵州茅台)或 "000001"(平安银行)。 - start_date: 开始日期,格式 "YYYY-MM-DD"。默认为 90 天前。 - end_date: 结束日期,格式 "YYYY-MM-DD"。默认为今天。 - - Returns: - 包含 date, open, close, high, low, volume, change_pct 列的 DataFrame。 - """ - now = datetime.now() - if not end_date: - end_date = now.strftime('%Y-%m-%d') - if not start_date: - start_date = (now - timedelta(days=90)).strftime('%Y-%m-%d') - - df_db = self.db.get_stock_prices(ticker, start_date, end_date) - - need_update = False - if df_db.empty: - need_update = True - else: - db_latest = pd.to_datetime(df_db['date'].max()) - req_latest = pd.to_datetime(end_date) - if (req_latest - db_latest).days > 2: - need_update = True - - if force_sync: - need_update = True - - if need_update: - logger.info(f"📡 Data stale or missing for {ticker}, syncing from network...") - - is_us_stock = bool(re.search(r'[a-zA-Z]', ticker)) and not bool(re.search(r'\d{5,6}', ticker)) - - if is_us_stock: - clean_ticker = ticker.upper() - else: - # 清洗 ticker,确保只包含数字(Akshare A 股接口通常只需要数字代码) - clean_ticker = "".join(filter(str.isdigit, ticker)) - if not clean_ticker: - logger.warning(f"⚠️ Unsupported ticker format: {ticker}") - return df_db - - try: - s_fmt = start_date.replace("-", "") - e_fmt = end_date.replace("-", "") - - df_remote = None - - def fetch_data_akshare(): - """主路径: akshare""" - if is_us_stock: - return _fetch_data_yfinance() - if len(clean_ticker) == 5: - return ak.stock_hk_hist( - symbol=clean_ticker, period="daily", - start_date=s_fmt, end_date=e_fmt, - adjust="qfq" - ) - else: - return ak.stock_zh_a_hist( - symbol=clean_ticker, period="daily", - start_date=s_fmt, end_date=e_fmt, - adjust="qfq" - ) - - def _fetch_data_yfinance(): - """美股路径: yfinance""" - yf_ticker = yf.Ticker(clean_ticker) - end_dt = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1) - df_us = yf_ticker.history(start=start_date, end=end_dt.strftime("%Y-%m-%d")) - if df_us.empty: - return pd.DataFrame() - - df_us = df_us.reset_index() - date_col = 'Date' if 'Date' in df_us.columns else df_us.columns[0] - df_us = df_us.rename(columns={ - 'Open': 'open', 'Close': 'close', - 'High': 'high', 'Low': 'low', 'Volume': 'volume' - }) - - if pd.api.types.is_datetime64_any_dtype(df_us[date_col]): - df_us['date'] = df_us[date_col].dt.strftime('%Y-%m-%d') - else: - df_us['date'] = pd.to_datetime(df_us[date_col]).dt.strftime('%Y-%m-%d') - - df_us['change_pct'] = df_us['close'].pct_change() * 100 - df_us['change_pct'] = df_us['change_pct'].fillna(0) - - return df_us[['date', 'open', 'close', 'high', 'low', 'volume', 'change_pct']] - - def fetch_data_eastmoney(): - """降级路径: 东方财富直接 HTTP""" - logger.info(f"📡 Trying EastMoney direct for {clean_ticker}...") - return EastMoneyDirect.fetch_kline(clean_ticker, s_fmt, e_fmt) - - # === 多源尝试: akshare → 东方财富直接 === - try: - try: - df_remote = fetch_data_akshare() - except (RequestException, Exception) as e: - if 
"Proxy" in str(e) or "proxy" in str(e): - logger.warning(f"⚠️ Proxy error detected: {e}. Retrying with proxy disabled...") - with temporary_no_proxy(): - df_remote = fetch_data_akshare() - else: - raise e - except Exception as e: - logger.warning(f"⚠️ akshare failed for {clean_ticker}: {e}") - if not is_us_stock: - try: - df_remote = fetch_data_eastmoney() - except Exception as e2: - logger.warning(f"⚠️ EastMoney direct also failed for {clean_ticker}: {e2}") - raise e # 抛出原始错误 - - if df_remote is not None and not df_remote.empty: - if not is_us_stock: - df_remote = df_remote.rename(columns={ - '日期': 'date', '开盘': 'open', '收盘': 'close', - '最高': 'high', '最低': 'low', '成交量': 'volume', - '涨跌幅': 'change_pct' - }) - # 确保日期格式正确 - df_remote['date'] = pd.to_datetime(df_remote['date']).dt.strftime('%Y-%m-%d') - - # 只有在获取到有意义的数据时才保存 - self.db.save_stock_prices(clean_ticker, df_remote) # 保存时使用清洗后的 clean_ticker - - # 重新查询数据库返回结果,保证一致性 - return self.db.get_stock_prices(clean_ticker, start_date, end_date) - else: - logger.warning(f"⚠️ Akshare returned empty data for {clean_ticker}") - - except KeyError as e: - # Akshare 有时在某些股票无数据时会抛出 KeyError - logger.warning(f"⚠️ Akshare data missing for {clean_ticker}: {e}") - except (RequestException, ConnectionError) as e: - logger.error(f"❌ Network error during Akshare sync for {clean_ticker}: {e}") - except sqlite3.Error as e: - logger.error(f"❌ Database error during Akshare sync for {clean_ticker}: {e}") - except Exception as e: - logger.error(f"❌ Unexpected error during Akshare sync for {clean_ticker}: {e}") - - return df_db - - -def get_stock_analysis(ticker: str, db: DatabaseManager) -> str: - """ - 生成指定股票的分析摘要报告。 - - Args: - ticker: 股票代码 - db: 数据库管理器实例 - - Returns: - Markdown 格式的分析报告,包含价格走势和关键指标。 - """ - tools = StockTools(db) - df = tools.get_stock_price(ticker) - - if df.empty: - return f"❌ 未能获取 {ticker} 的股价数据。" - - latest = df.iloc[-1] - change = ((latest['close'] - df.iloc[0]['close']) / df.iloc[0]['close']) * 100 - - report = [ - f"## 📊 {ticker} 分析报告", - f"- **查询时段**: {df.iloc[0]['date']} -> {latest['date']}", - f"- **当前价**: ¥{latest['close']:.2f}", - f"- **时段涨跌**: {change:+.2f}%", - f"- **最高/最低**: ¥{df['high'].max():.2f} / ¥{df['low'].min():.2f}", - "\n### 最近交易概览", - "```", - df.tail(5)[['date', 'close', 'change_pct', 'volume']].to_string(index=False), - "```" - ] - return "\n".join(report) diff --git a/skills/alphaear-stock/tests/test_stock.py b/skills/alphaear-stock/tests/test_stock.py deleted file mode 100644 index 3f548df..0000000 --- a/skills/alphaear-stock/tests/test_stock.py +++ /dev/null @@ -1,24 +0,0 @@ -import sys -import os -import unittest - -# Add skill root to path -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) - -try: - from scripts.stock_tools import StockTools - from scripts.database_manager import DatabaseManager -except ImportError as e: - print(f"Import Error: {e}") - sys.exit(1) - -class TestStock(unittest.TestCase): - def test_init(self): - print("Testing StockTools Iteration...") - db = DatabaseManager(":memory:") - tools = StockTools(db) - self.assertIsNotNone(tools) - print("StockTools Initialized.") - -if __name__ == '__main__': - unittest.main() diff --git a/skills/ecommerce-astro/SKILL.md b/skills/ecommerce-astro/SKILL.md new file mode 100644 index 0000000..ae4730d --- /dev/null +++ b/skills/ecommerce-astro/SKILL.md @@ -0,0 +1,620 @@ +--- +name: ecommerce-astro +description: | + Full-featured e-commerce site builder with Astro 6, React, Supabase backend. 
+ Creates online stores with optional multi-vendor marketplace, Thai language support, + inventory tracking, and order management. Use when: building e-commerce sites, + marketplaces, online stores, or Thai e-commerce stores. +--- + +# E-commerce Astro - E-commerce Site Builder + +**Category:** `fullstack` +**Tech Stack:** Astro 6 + React + Supabase + Tailwind v4 + +--- + +## 🎯 Purpose + +Create complete e-commerce websites with these core features: + +- ✅ **Product catalog** - Browse, filter, search products from Supabase +- ✅ **Inventory management** - Stock tracking with low-stock alerts +- ✅ **Order management** - Cart, checkout, order tracking, status updates +- ✅ **Thai language support** - Bilingual Thai/English with i18n routing +- ✅ **Review system** - Verified purchase reviews with ratings +- ✅ **Responsive design** - Mobile-first with React components + +### Optional Features (Enable/Disable) + +| Feature | Default | Description | +|---------|---------|-------------| +| `multi_vendor` | `false` | Multi-vendor marketplace with vendor dashboards | +| `payso_payment` | `false` | PaySo Thai payment gateway (stub for now) | +| `vendor_payouts` | `false` | Automated payout tracking (requires multi_vendor) | + +--- + +## 🚀 Quick Start + +```bash +# Generate e-commerce site (interactive mode) +python3 skills/ecommerce-astro/scripts/create_ecommerce.py \ + --name "My Store" \ + --output "./my-store" + +# With options +python3 skills/ecommerce-astro/scripts/create_ecommerce.py \ + --name "My Store" \ + --output "./my-store" \ + --multi-vendor true \ + --languages "th" +``` + +--- + +## 📋 Pre-Flight Questions + +Before running the script, gather these details: + +1. **Store Name:** (e.g., "Deal Plus Tech Store") +2. **Store Slug:** (e.g., "deal-plus-tech-store") +3. **Supabase Project URL:** From supabase.com dashboard +4. **Supabase Anon Key:** Public key for client-side +5. **Supabase Service Role Key:** For admin/server-side operations +6. **Multi-Vendor Mode:** Enable/disable vendor system (true/false) +7. 
**Languages:** Thai only (th), English only (en), or bilingual (th,en) + +--- + +## 📁 Generated Project Structure (Base) + +``` +store-name/ +├── astro.config.mjs +├── package.json +├── Dockerfile +├── docker-compose.yml +├── .env.example +├── .gitignore +│ +├── supabase/ +│ └── migrations/ +│ └── 001_initial_schema.sql +│ +├── src/ +│ ├── components/ +│ │ ├── cart/ +│ │ │ ├── CartBadge.tsx # Floating cart button +│ │ │ ├── CartButton.tsx # Header cart icon +│ │ │ ├── CartDrawer.tsx # Slide-out cart panel +│ │ │ ├── CartItems.tsx # Cart item list +│ │ │ └── CartSummary.tsx # Price breakdown +│ │ ├── checkout/ +│ │ │ └── CheckoutForm.tsx # Checkout form +│ │ ├── product/ +│ │ │ ├── ProductCard.astro # Product grid card +│ │ │ ├── ProductFilters.tsx # Category/price filters +│ │ │ ├── ProductGallery.tsx # Image gallery +│ │ │ ├── ProductVariants.tsx # Size/color variants +│ │ │ └── StockBadge.tsx # Inventory status +│ │ ├── review/ +│ │ │ ├── ReviewList.tsx # Product reviews +│ │ │ └── StarRating.tsx # Star rating display +│ │ └── layout/ +│ │ ├── Header.astro # Site header +│ │ └── Footer.astro # Site footer +│ │ +│ ├── layouts/ +│ │ └── Layout.astro # Base layout +│ │ +│ ├── lib/ +│ │ ├── supabase.ts # Supabase client (SSR-safe) +│ │ ├── auth.ts # JWT auth helpers +│ │ ├── utils.ts # Utility functions +│ │ └── types.ts # TypeScript types +│ │ +│ ├── stores/ +│ │ ├── cart.ts # Zustand cart (SSR-safe) +│ │ ├── auth.ts # Auth state +│ │ └── vendor.ts # Vendor state (if multi_vendor) +│ │ +│ ├── pages/ +│ │ ├── index.astro # Homepage +│ │ ├── products/ +│ │ │ ├── index.astro # Product listing +│ │ │ └── [slug].astro # Product detail +│ │ ├── cart.astro # Full cart page +│ │ ├── checkout.astro # Checkout page +│ │ ├── search.astro # Search page +│ │ ├── auth/ +│ │ │ ├── login.astro # Login page +│ │ │ └── register.astro # Register page +│ │ ├── account/ +│ │ │ ├── index.astro # Account dashboard +│ │ │ └── orders/ +│ │ │ ├── index.astro # Order history +│ │ │ └── [id].astro # Order detail +│ │ ├── vendor/ # Only if multi_vendor=true +│ │ │ ├── dashboard.astro # Vendor dashboard +│ │ │ ├── products/ +│ │ │ ├── orders.astro +│ │ │ └── settings.astro +│ │ ├── admin/ # Only if multi_vendor=true +│ │ │ ├── dashboard.astro +│ │ │ ├── vendors.astro +│ │ │ ├── users.astro +│ │ │ ├── orders.astro +│ │ │ └── categories.astro +│ │ └── api/ +│ │ ├── auth/ +│ │ ├── products/ +│ │ ├── orders/ +│ │ └── payments/ +│ │ +│ ├── i18n/ +│ │ ├── th.json # Thai translations +│ │ └── en.json # English translations +│ │ +│ └── styles/ +│ └── global.css # Global styles +│ +└── public/ + └── images/ +``` + +--- + +## 🗄️ Database Schema (Supabase PostgreSQL) + +The migration creates these tables. Schema can be customized per project. 
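+
+As an orientation sketch before the table-by-table breakdown (this file is not generated by the script), the snippet below shows how an SSR route can read these tables over the Supabase REST API with the service-role key and defensively parse the JSONB `images` column, per the lessons-learned and troubleshooting notes later in this document. Column names follow the core `products` table described below; the `'active'` status value and the `lib/products.server.ts` filename are assumptions made only for this example.
+
+```typescript
+// lib/products.server.ts — hedged sketch only; not part of the generated project.
+// Assumes PUBLIC_SUPABASE_URL / SUPABASE_SERVICE_ROLE_KEY from .env and an 'active' status value.
+interface Product {
+  id: string;
+  name: string;
+  slug: string;
+  price: number;
+  images: string[] | string; // JSONB column: may arrive as a JSON string
+  inventory: number;
+  status: string;
+}
+
+// JSONB images sometimes come back as a string, so parse defensively.
+function parseImages(raw: Product['images']): string[] {
+  return typeof raw === 'string' ? JSON.parse(raw || '[]') : (raw ?? []);
+}
+
+export async function listActiveProducts(): Promise<Product[]> {
+  const key = import.meta.env.SUPABASE_SERVICE_ROLE_KEY; // server-side only: bypasses RLS during SSR
+  const url = `${import.meta.env.PUBLIC_SUPABASE_URL}/rest/v1/products?select=*&status=eq.active`;
+  const res = await fetch(url, {
+    headers: { apikey: key, Authorization: `Bearer ${key}` },
+  });
+  if (!res.ok) throw new Error(`Supabase REST error: ${res.status}`);
+  const rows: Product[] = await res.json();
+  return rows.map((p) => ({ ...p, images: parseImages(p.images) }));
+}
+```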
+ +### Core Tables (Always Included) + +| Table | Purpose | +|-------|---------| +| `users` | Customer/admin accounts (id, email, password_hash, name, role, avatar_url) | +| `categories` | Product categories (hierarchical with parent_id) | +| `products` | Product catalog (id, vendor_id, category_id, name, slug, description, price, images JSONB, inventory, status, track_inventory, featured) | +| `reviews` | Product reviews (product_id, user_id, rating, comment, status) | +| `orders` | Customer orders (id, order_number, user_id, status, payment_status, total, shipping_address JSONB) | +| `order_items` | Line items per order (order_id, product_id, quantity, unit_price) | + +### Multi-Vendor Tables (Only if multi_vendor=true) + +| Table | Purpose | +|-------|---------| +| `vendor_profiles` | Store info (user_id, store_name, store_slug, store_description, status) | +| `product_variants` | Size/color variants (product_id, name, sku, price, inventory) | + +### Indexes + +```sql +-- Core indexes +CREATE INDEX idx_products_category ON products(category_id); +CREATE INDEX idx_products_vendor ON products(vendor_id); +CREATE INDEX idx_products_slug ON products(slug); +CREATE INDEX idx_products_status ON products(status); +CREATE INDEX idx_orders_user ON orders(user_id); +CREATE INDEX idx_orders_status ON orders(status); +CREATE INDEX idx_reviews_product ON reviews(product_id); +``` + +### Row Level Security (RLS) + +```sql +-- Products: Public read, vendor write own +-- Orders: User read own, vendor read own orders +-- Vendors: Admin manages vendor_profiles +``` + +### Key Lessons Learned + +1. **Images stored as JSONB** - Parse with `typeof images === 'string' ? JSON.parse(images) : images` +2. **Use service_role key for SSR** - Anonymous key blocked by RLS during server-side rendering +3. **Format:** `Authorization: Bearer {key}` header for Supabase REST API + +--- + +## 💳 Payment Integration + +Payment is stubbed by default. To enable real payments: + +1. Add PaySo credentials to `.env` +2. Create `/api/payments/create.ts` endpoint +3. Create `/api/webhooks/payso.ts` handler +4. Update `checkout.astro` to call payment API + +### Stub Implementation (Default) + +```typescript +// lib/payso.ts (stub) +export async function createPayment(order: Order) { + // TODO: Implement PaySo integration + console.log('Payment stub for order:', order.id); + return { success: true, paymentUrl: '/checkout/success' }; +} +``` + +--- + +## 🔐 Authentication + +### User Roles + +| Role | Permissions | +|------|-------------| +| `customer` | Browse, cart, checkout, orders | +| `vendor` | Products, orders (requires multi_vendor=true) | +| `admin` | All management (requires multi_vendor=true) | + +### Auth Flow + +1. Register with email/password +2. Login → JWT token stored in httpOnly cookie +3. Protected routes check session cookie +4. 
Role-based access for vendor/admin pages + +### SSR Authentication + +```typescript +// Check auth in Astro pages +const token = Astro.cookies.get('session')?.value; +if (!token) return Astro.redirect('/login'); +``` + +--- + +## 🛒 Cart & Checkout + +### Cart Features + +- **Floating Cart Button** - Fixed position bottom-right, blue circular button +- **Cart Drawer** - Slide-out panel with HeadlessUI +- **Persistent** - Zustand with localStorage (SSR-safe with getStorage) +- **Guest cart** - localStorage only +- **Logged-in cart** - Synced to database (optional) + +### Cart Store (SSR-Safe) + +```typescript +// stores/cart.ts +export const useCartStore = create()( + persist( + (set, get) => ({ ... }), + { + name: 'cart-storage', + partialize: (state) => ({ items: state.items }), + getStorage: () => { + if (typeof window === 'undefined') { + return { getItem: () => null, setItem: () => {}, removeItem: () => {} }; + } + return localStorage; + }, + } + ) +); +``` + +### Checkout Flow + +1. Cart Review → 2. Shipping Info → 3. Payment → 4. Confirmation + +--- + +## 📦 Vendor Dashboard (Only if multi_vendor=true) + +When `multi_vendor=true`, these pages are generated: + +### Vendor Features + +- **Dashboard** - Stats (products, orders, sales) +- **Products** - Add/edit/archive products +- **Orders** - View and manage orders +- **Settings** - Store profile + +### Vendor Onboarding Flow + +1. Register as customer +2. Apply for vendor status (`/vendors/apply`) +3. Admin approves → Vendor profile created +4. Access `/vendor/dashboard` + +### Hiding Vendor Pages + +When `multi_vendor=false`: +- No vendor registration link +- No `/vendor/*` routes +- No admin vendor management pages +- Products belong to "store" (no vendor_id) + +--- + +## 🌐 Internationalization + +### Thai/English Support + +- **URL Structure:** `/th/products`, `/en/products` (if bilingual) +- **Fallback:** Missing translation → English +- **Default:** Thai-only if `--languages th` + +### Translation Keys + +```json +// i18n/th.json +{ + "common": { + "addToCart": "เพิ่มลงตะกร้า", + "checkout": "ชำระเงิน", + "login": "เข้าสู่ระบบ" + }, + "product": { + "outOfStock": "สินค้าหมด", + "inStock": "มีสินค้า" + } +} +``` + +--- + +## 🔧 Environment Variables + +```bash +# Supabase (Required) +SUPABASE_URL=https://xxx.supabase.co +SUPABASE_ANON_KEY=eyJxxx # Public key (client-side) +SUPABASE_SERVICE_ROLE_KEY=eyJxxx # Admin key (server-side only!) + +# JWT (Required for auth) +JWT_SECRET=your-super-secret-jwt-key-min-32-chars + +# Site +SITE_URL=https://yourdomain.com +SITE_NAME=My Store + +# PaySo (Optional - stub by default) +PAYSOLO_MERCHANT_ID=your-merchant-id +PAYSOLO_API_KEY=your-api-key +PAYSOLO_SECRET_KEY=your-secret-key +PAYSOLO_CALLBACK_URL=https://yourdomain.com/api/webhooks/payso +``` + +--- + +## 🐳 Docker Deployment + +### Dockerfile (Astro SSR Mode) + +```dockerfile +FROM node:20-alpine +WORKDIR /app + +# Build-time env vars (needed for npm run build) +ENV PUBLIC_SUPABASE_URL=https://xxx.supabase.co +ENV PUBLIC_SUPABASE_ANON_KEY=eyJxxx +ENV SUPABASE_SERVICE_ROLE_KEY=eyJxxx +ENV SITE_URL=https://yourdomain.com +ENV JWT_SECRET=your-32-char-min-secret-key + +COPY package*.json ./ +RUN npm install +COPY . . +RUN npm run build + +EXPOSE 4321 +ENV HOST=0.0.0.0 +ENV PORT=4321 + +CMD ["npm", "run", "start"] +``` + +### docker-compose.yml + +```yaml +services: + web: + build: . + ports: + - "4321:4321" + env_file: + - .env +``` + +### Key Points + +1. **Build-time env vars** - Set `ENV` before `npm run build` so Astro can access them +2. 
**Service role key** - Only in Dockerfile ENV, not in client code +3. **Port 4321** - Astro default, map to 80 or your preferred port + +--- + +## 🚀 Deployment to Easypanel + +```bash +# 1. Generate site locally +python3 skills/ecommerce-astro/scripts/create_ecommerce.py \ + --name "my-store" \ + --output "./my-store" + +# 2. Push to Gitea +cd my-store +git init +git add . +git commit -m "Initial e-commerce site" +git remote add origin https://git.moreminimore.com/user/my-store.git +git push -u origin main + +# 3. Deploy to Easypanel +# Use easypanel-deploy skill or dashboard +``` + +--- + +## ✅ Success Criteria + +- [ ] Astro dev server runs without errors +- [ ] Supabase tables created successfully +- [ ] Products display from Supabase (images as JSONB) +- [ ] Cart adds/removes items (SSR-safe Zustand) +- [ ] Checkout creates order +- [ ] User registration/login works +- [ ] (If multi_vendor=true) Vendor dashboard accessible +- [ ] (If bilingual) Language switching works +- [ ] Docker build succeeds +- [ ] Deploys to Easypanel + +--- + +## 📚 Dependencies + +```json +{ + "astro": "^6.1.4", + "@astrojs/react": "^4.2.0", + "@astrojs/node": "^9.1.0", + "@astrojs/sitemap": "^3.2.0", + "@supabase/supabase-js": "^2.47.0", + "@supabase/ssr": "^0.6.1", + "@tailwindcss/vite": "^4.2.1", + "tailwindcss": "^4.2.1", + "zustand": "^5.0.0", + "react": "^19.0.0", + "react-dom": "^19.0.0", + "jose": "^6.0.0", + "@headlessui/react": "^2.0.0", + "lucide-react": "^0.400.0" +} +``` + +**Note:** Astro 6 is required for CSRF protection and other security features. + +--- + +## 🔗 Related Skills + +- **thai-frontend-dev** - Base Astro setup with PDPA compliance, cookie consent +- **easypanel-deploy** - Deploy to Easypanel +- **gitea-sync** - Sync code to Gitea + +--- + +## 📝 Example Usage + +### Single Vendor Store (Thai Only) + +```bash +python3 skills/ecommerce-astro/scripts/create_ecommerce.py \ + --name "ร้านค้าออนไลน์" \ + --slug "online-store" \ + --output "./thai-store" +# multi_vendor=false by default +``` + +### Multi-Vendor Marketplace (Bilingual) + +```bash +python3 skills/ecommerce-astro/scripts/create_ecommerce.py \ + --name "ThaiMart" \ + --slug "thaimart" \ + --multi-vendor true \ + --languages "th,en" \ + --output "./thaimart" +``` + +--- + +**Note:** After generation: +1. Run the Supabase migration in your dashboard +2. Update `.env` with your Supabase credentials +3. Add sample products to test + +--- + +## 🔧 Troubleshooting + +### SSR Error: `Cannot read properties of undefined (reading 'value')` + +**Cause:** Zustand persist middleware tries to access localStorage during SSR. + +**Fix:** Add `getStorage` to handle server-side: + +```typescript +getStorage: () => { + if (typeof window === 'undefined') { + return { getItem: () => null, setItem: () => {}, removeItem: () => {} }; + } + return localStorage; +} +``` + +### RLS Policy Blocks Read + +**Cause:** Anon key doesn't bypass RLS during SSR. + +**Fix:** Use service role key for server-side fetches: + +```typescript +headers: { + 'apikey': import.meta.env.SUPABASE_SERVICE_ROLE_KEY, + 'Authorization': `Bearer ${import.meta.env.SUPABASE_SERVICE_ROLE_KEY}` +} +``` + +### Images Show as `[` or Empty + +**Cause:** Images stored as JSONB string, not array. + +**Fix:** Parse before use: + +```typescript +const images = typeof product.images === 'string' + ? JSON.parse(product.images || '[]') + : (product.images || []); +``` + +### URLSearchParams Error + +**Cause:** Spread operator with undefined in template literal. 
+
+**Fix:** Build the query string directly instead of spreading a possibly-undefined value into `URLSearchParams`:
+
+```typescript
+// Bad
+href={`/products?${new URLSearchParams({...category && {category}})}`}
+
+// Good
+href={`/products?sort=${sort}`}
+```
+
+### Cross-site POST form submissions are forbidden
+
+**Cause:** Astro 6 has built-in CSRF protection that blocks native form POST from different origins.
+
+**Fix:** Use client-side fetch instead of native form submission:
+
+```typescript
+// In your .astro page
+
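+// Minimal illustrative sketch, not the generated code: intercept the native
+// submit and POST the data with fetch so the request stays same-origin and
+// passes the CSRF check. The '#checkout-form' selector and the
+// '/api/orders/create' endpoint are assumed names for this example only.
+const form = document.querySelector<HTMLFormElement>('#checkout-form');
+if (form) {
+  form.addEventListener('submit', async (event) => {
+    event.preventDefault();
+    const payload = Object.fromEntries(new FormData(form).entries());
+    const res = await fetch('/api/orders/create', {
+      method: 'POST',
+      headers: { 'Content-Type': 'application/json' },
+      body: JSON.stringify(payload),
+    });
+    if (res.ok) window.location.href = '/checkout/success';
+  });
+}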
+ + +``` diff --git a/skills/ecommerce-astro/scripts/.env.example b/skills/ecommerce-astro/scripts/.env.example new file mode 100644 index 0000000..e72ec3f --- /dev/null +++ b/skills/ecommerce-astro/scripts/.env.example @@ -0,0 +1,17 @@ +# Supabase Configuration +SUPABASE_URL=https://your-project.supabase.co +SUPABASE_ANON_KEY=your-anon-key +SUPABASE_SERVICE_ROLE_KEY=your-service-role-key + +# PaySo Payment Gateway (Thai Payment) +PAYSOLO_MERCHANT_ID=your-merchant-id +PAYSOLO_API_KEY=your-api-key +PAYSOLO_SECRET_KEY=your-secret-key +PAYSOLO_CALLBACK_URL=https://yourdomain.com/api/webhooks/payso + +# JWT Authentication +JWT_SECRET=your-super-secret-jwt-key-min-32-chars-here + +# Site Configuration +SITE_URL=https://yourdomain.com +SITE_NAME=My Store diff --git a/skills/ecommerce-astro/scripts/create_ecommerce.py b/skills/ecommerce-astro/scripts/create_ecommerce.py new file mode 100755 index 0000000..ee273ba --- /dev/null +++ b/skills/ecommerce-astro/scripts/create_ecommerce.py @@ -0,0 +1,2034 @@ +#!/usr/bin/env python3 +import os +import sys +import argparse +import shutil +from pathlib import Path + + +SCRIPT_DIR = Path(__file__).parent + + +def pkg_json(name): + return ( + '''{ + "name": "''' + + name + + """", + "type": "module", + "version": "1.0.0", + "scripts": { + "dev": "astro dev", + "build": "astro build", + "preview": "astro preview", + "astro": "astro" + }, + "dependencies": { + "astro": "^5.17.1", + "@astrojs/react": "^4.2.0", + "@astrojs/node": "^9.1.0", + "@astrojs/sitemap": "^3.2.0", + "@supabase/supabase-js": "^2.47.0", + "@supabase/ssr": "^0.6.1", + "@tailwindcss/vite": "^4.2.1", + "tailwindcss": "^4.2.1", + "zustand": "^5.0.0", + "react": "^19.0.0", + "react-dom": "^19.0.0", + "jose": "^6.0.0" + }, + "devDependencies": { + "@types/react": "^19.0.0", + "@types/react-dom": "^19.0.0", + "typescript": "^5.7.0" + } +}""" + ) + + +ASTRO_CONFIG = """import {{ defineConfig }} from 'astro/config'; +import react from '@astrojs/react'; +import node from '@astrojs/node'; +import sitemap from '@astrojs/sitemap'; +import tailwindcss from '@tailwindcss/vite'; + +export default defineConfig({{ + site: '{site_url}', + output: 'hybrid', + adapter: node({{ mode: 'standalone' }}), + i18n: {{ + locales: [{locales}], + defaultLocale: '{default_locale}', + routing: {{ + prefixDefaultLocale: false, + fallbackType: 'rewrite', + }}, + }}, + integrations: [ + react(), + sitemap({{ + i18n: {{ defaultLocale: '{default_locale}' }}, + }}), + ], + vite: {{ + plugins: [tailwindcss()], + }}, +}}); +""" + +TSCONFIG = """{ + "extends": "astro/tsconfigs/strict", + "compilerOptions": { + "jsx": "react-jsx", + "jsxImportSource": "react", + "baseUrl": ".", + "paths": { + "@/*": ["src/*"] + } + } +}""" + +SUPABASE_TS = """import {{ createClient }} from '@supabase/supabase-js'; +import type {{ Database }} from './types'; + +const supabaseUrl = import.meta.env.SUPABASE_URL; +const supabaseAnonKey = import.meta.env.SUPABASE_ANON_KEY; + +export const supabase = createClient(supabaseUrl, supabaseAnonKey); + +export type {{ Database }}; +""" + +AUTH_TS = r"""import { supabase } from './supabase'; +import { SignJWT, jwtVerify } from 'jose'; +import type { User, VendorProfile } from './types'; + +const JWT_SECRET = new TextEncoder().encode( + import.meta.env.JWT_SECRET || 'default-secret-change-me' +); + +export interface AuthUser { + id: string; + email: string; + name: string | null; + role: 'customer' | 'vendor' | 'admin'; + vendor_id?: string; +} + +export async function createSessionToken(user: AuthUser): Promise { 
+ return new SignJWT({ + sub: user.id, + email: user.email, + name: user.name, + role: user.role, + vendor_id: user.vendor_id, + }) + .setProtectedHeader({ alg: 'HS256' }) + .setIssuedAt() + .setExpirationTime('7d') + .sign(JWT_SECRET); +} + +export async function verifySessionToken(token: string): Promise { + try { + const { payload } = await jwtVerify(token, JWT_SECRET); + return { + id: payload.sub as string, + email: payload.email as string, + name: payload.name as string | null, + role: payload.role as 'customer' | 'vendor' | 'admin', + vendor_id: payload.vendor_id as string | undefined, + }; + } catch { + return null; + } +} + +export async function registerUser( + email: string, + password: string, + name: string +): Promise<{ user: User | null; error: string | null }> { + const { data, error } = await supabase.auth.signUp({ + email, + password, + options: { + data: { name }, + }, + }); + + if (error) return { user: null, error: error.message }; + + const user = data.user; + if (!user) return { user: null, error: 'Registration failed' }; + + return { + user: { + id: user.id, + email: user.email || email, + name, + role: 'customer', + }, + error: null, + }; +} + +export async function loginUser( + email: string, + password: string +): Promise<{ user: AuthUser | null; token: string | null; error: string | null }> { + const { data, error } = await supabase.auth.signInWithPassword({ + email, + password, + }); + + if (error) return { user: null, token: null, error: error.message }; + + const user = data.user; + if (!user) return { user: null, token: null, error: 'Login failed' }; + + const { data: profileData } = await supabase + .from('vendor_profiles') + .select('id') + .eq('user_id', user.id) + .single(); + + const authUser: AuthUser = { + id: user.id, + email: user.email || email, + name: user.user_metadata?.name || null, + role: profileData ? 
'vendor' : 'customer', + vendor_id: profileData?.id, + }; + + const token = await createSessionToken(authUser); + + return { user: authUser, token, error: null }; +} + +export async function logoutUser(): Promise { + await supabase.auth.signOut(); +} +""" + +PAYSOS_TS = r"""export interface PaySoPayment { + merchant_id: string; + order_id: string; + amount: number; + currency: string; + description: string; + callback_url: string; + return_url: string; + customer_name?: string; + customer_email?: string; + customer_phone?: string; +} + +export interface PaySoResponse { + code: string; + message: string; + data: { + payment_url: string; + qr_code?: string; + qr_image?: string; + transaction_id: string; + }; +} + +export interface PaySoWebhookPayload { + transaction_id: string; + order_id: string; + amount: number; + status: 'pending' | 'success' | 'failed'; + timestamp: string; + signature: string; +} + +export async function createPaySoPayment(payment: PaySoPayment): Promise { + const response = await fetch('https://api.paysogateway.com/v1/payment', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${import.meta.env.PAYSOLO_API_KEY}`, + }, + body: JSON.stringify({ + merchant_id: import.meta.env.PAYSOLO_MERCHANT_ID, + order_id: payment.order_id, + amount: payment.amount, + currency: payment.currency || 'THB', + description: payment.description, + callback_url: payment.callback_url, + return_url: payment.return_url, + customer: { + name: payment.customer_name, + email: payment.customer_email, + phone: payment.customer_phone, + }, + }), + }); + + if (!response.ok) { + throw new Error(`PaySo API error: ${response.status}`); + } + + return response.json(); +} + +export async function verifyPaySoSignature( + payload: PaySoWebhookPayload, + signature: string +): Promise { + const crypto = await import('crypto'); + const secret = import.meta.env.PAYSOLO_SECRET_KEY; + const data = JSON.stringify({ + transaction_id: payload.transaction_id, + order_id: payload.order_id, + amount: payload.amount, + status: payload.status, + }); + const expectedSignature = crypto + .createHmac('sha256', secret) + .update(data) + .digest('hex'); + return signature === expectedSignature; +} +""" + +TYPES_TS = r"""export interface User { + id: string; + email: string; + name: string | null; + phone: string | null; + role: 'customer' | 'vendor' | 'admin'; + avatar_url: string | null; + created_at: string; + updated_at: string; +} + +export interface VendorProfile { + id: string; + user_id: string; + store_name: string; + store_slug: string; + store_description: string | null; + store_logo: string | null; + bank_account: string | null; + bank_name: string | null; + payout_status: 'pending' | 'approved' | 'rejected'; + total_earnings: number; + approved_at: string | null; + created_at: string; +} + +export interface Category { + id: string; + name: string; + slug: string; + description: string | null; + image_url: string | null; + parent_id: string | null; + sort_order: number; +} + +export interface Product { + id: string; + vendor_id: string; + category_id: string | null; + name: string; + slug: string; + description: string | null; + price: number; + compare_at_price: number | null; + cost_price: number | null; + sku: string | null; + barcode: string | null; + inventory: number; + low_stock_threshold: number; + track_inventory: boolean; + allow_backorder: boolean; + weight: number | null; + images: string[]; + metadata: Record; + status: 'draft' | 'active' | 'archived'; + 
featured: boolean; + created_at: string; + updated_at: string; +} + +export interface ProductVariant { + id: string; + product_id: string; + name: string; + sku: string | null; + price: number | null; + inventory: number; + attributes: Record; + image_url: string | null; +} + +export interface Review { + id: string; + product_id: string; + user_id: string; + order_id: string | null; + rating: number; + title: string | null; + comment: string | null; + images: string[]; + verified_purchase: boolean; + status: 'pending' | 'approved' | 'rejected'; + created_at: string; +} + +export interface Order { + id: string; + order_number: string; + user_id: string; + vendor_id: string | null; + status: 'pending' | 'confirmed' | 'processing' | 'shipped' | 'delivered' | 'cancelled' | 'refunded'; + payment_status: 'unpaid' | 'paid' | 'failed' | 'refunded'; + subtotal: number; + tax: number; + shipping_cost: number; + total: number; + currency: string; + payment_method: string | null; + payment_provider: string | null; + payment_ref: string | null; + shipping_name: string | null; + shipping_phone: string | null; + shipping_address: string | null; + shipping_city: string | null; + shipping_postal: string | null; + shipping_country: string; + notes: string | null; + created_at: string; + updated_at: string; +} + +export interface OrderItem { + id: string; + order_id: string; + product_id: string; + variant_id: string | null; + vendor_id: string | null; + quantity: number; + unit_price: number; + total_price: number; +} + +export interface CartItem { + id: string; + product: Product; + variant: ProductVariant | null; + quantity: number; +} +""" + +UTILS_TS = r"""export function formatPrice(amount: number, currency = 'THB'): string { + return new Intl.NumberFormat('th-TH', { + style: 'currency', + currency, + }).format(amount); +} + +export function generateSlug(text: string): string { + return text + .toLowerCase() + .replace(/[^\w\s-]/g, '') + .replace(/[\s_-]+/g, '-') + .replace(/^-+|-+$/g, ''); +} + +export function generateOrderNumber(): string { + const date = new Date(); + const dateStr = date.toISOString().slice(0, 10).replace(/-/g, ''); + const random = Math.random().toString(36).substring(2, 10).toUpperCase(); + return `ORD-${dateStr}-${random}`; +} + +export function cn(...classes: (string | undefined | null | false)[]): string { + return classes.filter(Boolean).join(' '); +} + +export function debounce any>( + fn: T, + delay: number +): (...args: Parameters) => void { + let timeoutId: ReturnType; + return (...args: Parameters) => { + clearTimeout(timeoutId); + timeoutId = setTimeout(() => fn(...args), delay); + }; +} +""" + +CART_STORE = r"""import { create } from 'zustand'; +import { persist } from 'zustand/middleware'; +import type { CartItem, Product, ProductVariant } from '../lib/types'; + +interface CartStore { + items: CartItem[]; + addItem: (product: Product, variant?: ProductVariant, quantity?: number) => void; + removeItem: (productId: string, variantId?: string) => void; + updateQuantity: (productId: string, variantId: string | undefined, quantity: number) => void; + clearCart: () => void; + getTotal: () => number; + getItemCount: () => number; +} + +export const useCartStore = create()( + persist( + (set, get) => ({ + items: [], + addItem: (product, variant, quantity = 1) => { + set((state) => { + const existingIndex = state.items.findIndex( + (item) => + item.product.id === product.id && + item.variant?.id === variant?.id + ); + + if (existingIndex >= 0) { + const newItems = 
[...state.items]; + newItems[existingIndex].quantity += quantity; + return { items: newItems }; + } + + return { + items: [ + ...state.items, + { id: crypto.randomUUID(), product, variant: variant || null, quantity }, + ], + }; + }); + }, + removeItem: (productId, variantId) => { + set((state) => ({ + items: state.items.filter( + (item) => + !(item.product.id === productId && item.variant?.id === variantId) + ), + })); + }, + updateQuantity: (productId, variantId, quantity) => { + if (quantity <= 0) { + get().removeItem(productId, variantId); + return; + } + set((state) => ({ + items: state.items.map((item) => + item.product.id === productId && item.variant?.id === variantId + ? { ...item, quantity } + : item + ), + })); + }, + clearCart: () => set({ items: [] }), + getTotal: () => { + return get().items.reduce( + (total, item) => total + (item.variant?.price || item.product.price) * item.quantity, + 0 + ); + }, + getItemCount: () => { + return get().items.reduce((count, item) => count + item.quantity, 0); + }, + }), + { + name: 'ecommerce-cart', + } + ) +); +""" + +AUTH_STORE = r"""import { create } from 'zustand'; +import { persist } from 'zustand/middleware'; +import type { AuthUser } from '../lib/auth'; + +interface AuthStore { + user: AuthUser | null; + token: string | null; + isLoading: boolean; + setAuth: (user: AuthUser, token: string) => void; + logout: () => void; + setLoading: (loading: boolean) => void; +} + +export const useAuthStore = create()( + persist( + (set) => ({ + user: null, + token: null, + isLoading: false, + setAuth: (user, token) => set({ user, token }), + logout: () => set({ user: null, token: null }), + setLoading: (isLoading) => set({ isLoading }), + }), + { + name: 'ecommerce-auth', + } + ) +); +""" + +VENDOR_STORE = r"""import { create } from 'zustand'; +import { supabase } from '../lib/supabase'; +import type { Product, Order, VendorProfile } from '../lib/types'; + +interface VendorStore { + profile: VendorProfile | null; + products: Product[]; + orders: Order[]; + isLoading: boolean; + fetchProfile: (userId: string) => Promise; + fetchProducts: (vendorId: string) => Promise; + fetchOrders: (vendorId: string) => Promise; + createProduct: (product: Partial) => Promise; + updateProduct: (id: string, updates: Partial) => Promise; + updateOrderStatus: (orderId: string, status: string) => Promise; +} + +export const useVendorStore = create((set, get) => ({ + profile: null, + products: [], + orders: [], + isLoading: false, + fetchProfile: async (userId) => { + set({ isLoading: true }); + const { data } = await supabase + .from('vendor_profiles') + .select('*') + .eq('user_id', userId) + .single(); + set({ profile: data, isLoading: false }); + }, + fetchProducts: async (vendorId) => { + set({ isLoading: true }); + const { data } = await supabase + .from('products') + .select('*') + .eq('vendor_id', vendorId) + .order('created_at', { ascending: false }); + set({ products: data || [], isLoading: false }); + }, + fetchOrders: async (vendorId) => { + set({ isLoading: true }); + const { data } = await supabase + .from('orders') + .select('*') + .eq('vendor_id', vendorId) + .order('created_at', { ascending: false }); + set({ orders: data || [], isLoading: false }); + }, + createProduct: async (product) => { + const { data, error } = await supabase + .from('products') + .insert(product) + .select() + .single(); + if (error) return null; + set((state) => ({ products: [data, ...state.products] })); + return data; + }, + updateProduct: async (id, updates) => { + await 
supabase.from('products').update(updates).eq('id', id); + set((state) => ({ + products: state.products.map((p) => + p.id === id ? { ...p, ...updates } : p + ), + })); + }, + updateOrderStatus: async (orderId, status) => { + await supabase.from('orders').update({ status }).eq('id', orderId); + set((state) => ({ + orders: state.orders.map((o) => + o.id === orderId ? { ...o, status: status as Order['status'] } : o + ), + })); + }, +})); +""" + +BASE_LAYOUT = """--- +interface Props {{ + title: string; + description?: string; +}} + +const {{ title, description = '{site_name}' }} = Astro.props; +--- + + + + + + + + + {{title}} | {site_name} + + + + + + + +""" + +TH_JSON = """{{ + "common": {{ + "home": "หน้าแรก", + "products": "สินค้า", + "cart": "ตะกร้า", + "checkout": "ชำระเงิน", + "login": "เข้าสู่ระบบ", + "register": "ลงทะเบียน", + "logout": "ออกจากระบบ" + }}, + "product": {{ + "addToCart": "เพิ่มลงตะกร้า", + "outOfStock": "สินค้าหมด", + "inStock": "มีสินค้า" + }}, + "cart": {{ + "empty": "ตะกร้าว่าง", + "total": "รวม" + }} +}}""" + +EN_JSON = """{{ + "common": {{ + "home": "Home", + "products": "Products", + "cart": "Cart", + "checkout": "Checkout", + "login": "Login", + "register": "Register", + "logout": "Logout" + }}, + "product": {{ + "addToCart": "Add to Cart", + "outOfStock": "Out of Stock", + "inStock": "In Stock" + }}, + "cart": {{ + "empty": "Your cart is empty", + "total": "Total" + }} +}}""" + +GLOBAL_CSS = """@import "tailwindcss"; + +@theme { + --font-sans: "Inter", system-ui, sans-serif; + --color-primary: #2563eb; + --color-secondary: #64748b; +}""" + +CART_BUTTON = r"""import { useCartStore } from '../../stores/cart'; + +export function CartButton() { + const getItemCount = useCartStore((state) => state.getItemCount); + const count = getItemCount(); + + return ( + + + + + {count > 0 && ( + + {count > 9 ? '9+' : count} + + )} + + ); +}""" + +CART_DRAWER = r"""import { useState } from 'react'; +import { useCartStore } from '../../stores/cart'; +import { formatPrice } from '../../lib/utils'; +import { CartItem } from './CartItem'; + +export function CartDrawer() { + const [isOpen, setIsOpen] = useState(false); + const items = useCartStore((state) => state.items); + const getTotal = useCartStore((state) => state.getTotal); + + return ( + <> + + + {isOpen && ( +
+
setIsOpen(false)} /> +
+
+
+

ตะกร้าสินค้า

+ +
+ +
+ {items.length === 0 ? ( +

ตะกร้าว่างเปล่า

+ ) : ( +
+ {items.map((item) => ( + + ))} +
+ )} +
+ + {items.length > 0 && ( +
+
+ รวม + {formatPrice(getTotal())} +
+ + ชำระเงิน + +
+ )} +
+
+
+ )} + + ); +}""" + +CART_ITEM_COMPONENT = r"""import { useCartStore } from '../../stores/cart'; +import { formatPrice } from '../../lib/utils'; +import type { CartItem as CartItemType } from '../../lib/types'; + +interface Props { + item: CartItemType; +} + +export function CartItem({ item }: Props) { + const updateQuantity = useCartStore((state) => state.updateQuantity); + const removeItem = useCartStore((state) => state.removeItem); + const price = item.variant?.price || item.product.price; + + return ( +
+ {item.product.name} +
+

{item.product.name}

+ {item.variant && ( +

{item.variant.name}

+ )} +

+ {formatPrice(price)} +

+
+
+ +
+ + {item.quantity} + +
+
+
+ ); +}""" + +PRODUCT_CARD = r"""import { Link } from '@astrojs/react/components'; +import { formatPrice } from '../../lib/utils'; +import type { Product } from '../../lib/types'; + +interface Props { + product: Product; +} + +export function ProductCard({ product }: Props) { + const isOutOfStock = product.inventory <= 0 && !product.allow_backorder; + const isLowStock = product.inventory > 0 && product.inventory <= product.low_stock_threshold; + + return ( + +
+
+ {product.name} + {isOutOfStock && ( +
+ + สินค้าหมด + +
+ )} + {product.featured && !isOutOfStock && ( + + แนะนำ + + )} +
+
+

+ {product.name} +

+
+ + {formatPrice(product.price)} + + {product.compare_at_price && ( + + {formatPrice(product.compare_at_price)} + + )} +
+ {isLowStock && ( +

สินค้าใกล้หมด ({product.inventory} ชิ้น)

+ )} +
+
+ + ); +}""" + +CHECKOUT_FORM = r"""import { useState } from 'react'; +import { useCartStore } from '../../stores/cart'; +import { useAuthStore } from '../../stores/auth'; +import { formatPrice, generateOrderNumber } from '../../lib/utils'; + +export function CheckoutForm() { + const items = useCartStore((state) => state.items); + const getTotal = useCartStore((state) => state.getTotal); + const clearCart = useCartStore((state) => state.clearCart); + const user = useAuthStore((state) => state.user); + + const [formData, setFormData] = useState({ + name: user?.name || '', + phone: '', + address: '', + city: '', + postal: '', + paymentMethod: 'qr', + }); + const [isSubmitting, setIsSubmitting] = useState(false); + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + setIsSubmitting(true); + + try { + const response = await fetch('/api/checkout/create-order', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + ...formData, + items: items.map((item) => ({ + product_id: item.product.id, + variant_id: item.variant?.id, + quantity: item.quantity, + unit_price: item.variant?.price || item.product.price, + })), + order_number: generateOrderNumber(), + }), + }); + + const data = await response.json(); + + if (data.payment_url) { + window.location.href = data.payment_url; + } else { + clearCart(); + window.location.href = `/orders/${data.order_id}`; + } + } catch (error) { + console.error('Checkout error:', error); + } finally { + setIsSubmitting(false); + } + }; + + return ( +
+
+

ข้อมูลจัดส่ง

+
+
+ + setFormData({ ...formData, name: e.target.value })} + className="w-full px-4 py-2 border rounded-lg focus:ring-2 focus:ring-primary focus:border-primary" + /> +
+
+ + setFormData({ ...formData, phone: e.target.value })} + className="w-full px-4 py-2 border rounded-lg focus:ring-2 focus:ring-primary focus:border-primary" + /> +
+
+ + +``` + +### Custom Plugin + +```javascript +// tailwind.config.js +const plugin = require('tailwindcss/plugin') + +export default { + plugins: [ + plugin(function({ addUtilities, addComponents, theme }) { + // Add utilities + addUtilities({ + '.text-shadow': { + textShadow: '2px 2px 4px rgba(0, 0, 0, 0.1)', + }, + '.text-shadow-lg': { + textShadow: '4px 4px 8px rgba(0, 0, 0, 0.2)', + }, + }) + + // Add components + addComponents({ + '.card-custom': { + backgroundColor: theme('colors.white'), + borderRadius: theme('borderRadius.lg'), + padding: theme('spacing.6'), + boxShadow: theme('boxShadow.md'), + }, + }) + }), + ], +} +``` + +## Configuration Examples + +### Complete Tailwind Config + +```javascript +// tailwind.config.ts +import type { Config } from 'tailwindcss' + +const config: Config = { + darkMode: ["class"], + content: [ + './pages/**/*.{ts,tsx}', + './components/**/*.{ts,tsx}', + './app/**/*.{ts,tsx}', + ], + theme: { + container: { + center: true, + padding: "2rem", + screens: { + "2xl": "1400px", + }, + }, + extend: { + colors: { + border: "hsl(var(--border))", + input: "hsl(var(--input))", + ring: "hsl(var(--ring))", + background: "hsl(var(--background))", + foreground: "hsl(var(--foreground))", + primary: { + DEFAULT: "hsl(var(--primary))", + foreground: "hsl(var(--primary-foreground))", + }, + brand: { + 50: '#f0f9ff', + 500: '#3b82f6', + 900: '#1e3a8a', + }, + }, + fontFamily: { + sans: ['Inter', 'system-ui', 'sans-serif'], + display: ['Playfair Display', 'serif'], + }, + spacing: { + '18': '4.5rem', + '88': '22rem', + '128': '32rem', + }, + borderRadius: { + lg: "var(--radius)", + md: "calc(var(--radius) - 2px)", + sm: "calc(var(--radius) - 4px)", + }, + keyframes: { + "slide-in": { + "0%": { transform: "translateX(-100%)" }, + "100%": { transform: "translateX(0)" }, + }, + }, + animation: { + "slide-in": "slide-in 0.5s ease-out", + }, + }, + }, + plugins: [require("tailwindcss-animate")], +} + +export default config +``` + +## Dark Mode Configuration + +```javascript +// tailwind.config.js +export default { + darkMode: ["class"], // or "media" for automatic + // ... +} +``` + +**Usage:** +```html + + +
+ Responds to .dark class +
+ + + +
+ Responds to system preference automatically +
+``` + +## Content Configuration + +Specify files to scan for classes: + +```javascript +// tailwind.config.js +export default { + content: [ + "./src/**/*.{js,jsx,ts,tsx}", + "./app/**/*.{js,jsx,ts,tsx}", + "./components/**/*.{js,jsx,ts,tsx}", + "./pages/**/*.{js,jsx,ts,tsx}", + ], + // ... +} +``` + +### Safelist + +Preserve dynamic classes: + +```javascript +export default { + safelist: [ + 'bg-red-500', + 'bg-green-500', + 'bg-blue-500', + { + pattern: /bg-(red|green|blue)-(100|500|900)/, + }, + ], +} +``` + +## Best Practices + +1. **Use @theme for simple customizations**: Prefer CSS-based customization +2. **Extract components sparingly**: Use @apply only for truly repeated patterns +3. **Leverage design tokens**: Define custom tokens in @theme +4. **Layer organization**: Keep base, components, and utilities separate +5. **Plugin for complex logic**: Use plugins for advanced customizations +6. **Test dark mode**: Ensure custom colors work in both themes +7. **Document custom utilities**: Add comments explaining custom classes +8. **Semantic naming**: Use descriptive names (primary not blue) diff --git a/skills/website-creator/ui-styling/references/tailwind-responsive.md b/skills/website-creator/ui-styling/references/tailwind-responsive.md new file mode 100644 index 0000000..f252e18 --- /dev/null +++ b/skills/website-creator/ui-styling/references/tailwind-responsive.md @@ -0,0 +1,382 @@ +# Tailwind CSS Responsive Design + +Mobile-first breakpoints, responsive utilities, and adaptive layouts. + +## Mobile-First Approach + +Tailwind uses mobile-first responsive design. Base styles apply to all screen sizes, then use breakpoint prefixes to override at larger sizes. + +```html + +
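+<!-- Illustrative sketch only; the class names are assumptions, not the original example. -->
+<!-- One column by default, two columns from md: (768px), four from lg: (1024px). -->
+<div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4">
+  <div class="rounded bg-gray-100 p-4">Item 1</div>
+  <div class="rounded bg-gray-100 p-4">Item 2</div>
+  <div class="rounded bg-gray-100 p-4">Item 3</div>
+  <div class="rounded bg-gray-100 p-4">Item 4</div>
+</div>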
+
Item 1
+
Item 2
+
Item 3
+
Item 4
+
+``` + +## Breakpoint System + +**Default breakpoints:** + +| Prefix | Min Width | CSS Media Query | +|--------|-----------|-----------------| +| `sm:` | 640px | `@media (min-width: 640px)` | +| `md:` | 768px | `@media (min-width: 768px)` | +| `lg:` | 1024px | `@media (min-width: 1024px)` | +| `xl:` | 1280px | `@media (min-width: 1280px)` | +| `2xl:` | 1536px | `@media (min-width: 1536px)` | + +## Responsive Patterns + +### Layout Changes + +```html + +
+
Left
+
Right
+
+ + +
+
Item 1
+
Item 2
+
Item 3
+
+``` + +### Visibility + +```html + + + + +
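+<!-- Illustrative sketch; class names assumed for the example. -->
+<div class="md:hidden">Mobile only content</div>
+<div class="hidden md:block">Visible from md: upward</div>
+<button class="lg:hidden">Mobile menu</button>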
+ Mobile only content +
+ + +
Mobile menu
+ +``` + +### Typography + +```html + +
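+<!-- Illustrative sketch; class names assumed for the example. -->
+<h1 class="text-2xl md:text-4xl lg:text-5xl font-bold">Heading scales with screen size</h1>
+<p class="text-sm md:text-base lg:text-lg">Body text scales appropriately</p>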

+ Heading scales with screen size +

+ +

+ Body text scales appropriately +

+``` + +### Spacing + +```html + +
+ More padding on larger screens +
+ + +
+
Item 1
+
Item 2
+
+``` + +### Width + +```html + +
+ Responsive width +
+ + +
+ Centered with responsive max width +
+``` + +## Common Responsive Layouts + +### Sidebar Layout + +```html +
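+<!-- Illustrative sketch; class names assumed for the example. -->
+<div class="flex flex-col lg:flex-row gap-6">
+  <aside class="w-full lg:w-64 shrink-0">Sidebar</aside>
+  <main class="flex-1">Main content</main>
+</div>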
+ + + + +
+ Main content +
+
+``` + +### Card Grid + +```html +
+
Card 1
+
Card 2
+
Card 3
+
Card 4
+
+``` + +### Hero Section + +```html +
+
+
+
+

+ Hero Title +

+

+ Hero description +

+ +
+
+ +
+
+
+
+``` + +### Navigation + +```html + +``` + +## Max-Width Queries + +Apply styles only below certain breakpoint using `max-*:` prefix: + +```html + +
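+<!-- Illustrative sketch; class names assumed for the example. -->
+<p class="max-lg:text-center">Centered on mobile/tablet, left-aligned on desktop</p>
+<div class="max-md:hidden">Hidden only on mobile</div>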
+ Centered on mobile/tablet, left-aligned on desktop +
+ + +
+ Hidden only on mobile +
+``` + +Available: `max-sm:` `max-md:` `max-lg:` `max-xl:` `max-2xl:` + +## Range Queries + +Apply styles between breakpoints: + +```html + +
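+<!-- Illustrative sketch; class names assumed for the example. -->
+<div class="hidden md:max-lg:block">Visible only on tablets</div>
+<div class="grid grid-cols-1 md:max-xl:grid-cols-2 xl:grid-cols-4">2 columns on tablet, 4 on extra large</div>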
+ Visible only on tablets +
+ + +
+ 2 columns on tablet, 4 on extra large +
+``` + +## Container Queries + +Style elements based on parent container width: + +```html +
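+<!-- Illustrative sketch; class names assumed for the example. -->
+<div class="@container">
+  <div class="grid grid-cols-1 @md:grid-cols-2">
+    Responds to parent width, not viewport
+  </div>
+</div>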
+
+ Responds to parent width, not viewport +
+
+``` + +Container query breakpoints: `@sm:` `@md:` `@lg:` `@xl:` `@2xl:` + +## Custom Breakpoints + +Define custom breakpoints in theme: + +```css +@theme { + --breakpoint-3xl: 120rem; /* 1920px */ + --breakpoint-tablet: 48rem; /* 768px */ +} +``` + +```html +
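+<!-- Illustrative sketch; uses the tablet and 3xl breakpoints defined above. -->
+<div class="grid grid-cols-1 tablet:grid-cols-2 3xl:grid-cols-6">Uses custom breakpoints</div>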
+ Uses custom breakpoints +
+``` + +## Responsive State Variants + +Combine responsive with hover/focus: + +```html + + + + + + Link + +``` + +## Best Practices + +### 1. Mobile-First Design + +Start with mobile styles, add complexity at larger breakpoints: + +```html + +
+ + +
+``` + +### 2. Consistent Breakpoint Usage + +Use same breakpoints across related elements: + +```html +
+ Spacing scales with layout +
+``` + +### 3. Test at Breakpoint Boundaries + +Test at exact breakpoint widths (640px, 768px, 1024px, etc.) to catch edge cases. + +### 4. Use Container for Content Width + +```html +
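+<!-- Illustrative sketch; class names assumed for the example. -->
+<div class="container mx-auto px-4 sm:px-6 lg:px-8">
+  Content with consistent max width
+</div>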
+
+ Content with consistent max width +
+
+``` + +### 5. Progressive Enhancement + +Ensure core functionality works on mobile, enhance for larger screens: + +```html + +
+ +
+ Content +
+
+``` + +### 6. Avoid Too Many Breakpoints + +Use 2-3 breakpoints per element for maintainability: + +```html + +
+ + +
+``` + +## Common Responsive Utilities + +### Responsive Display + +```html +
+ Changes display type per breakpoint +
+``` + +### Responsive Position + +```html +
+ Positioned differently per breakpoint +
+``` + +### Responsive Order + +```html +
+
First on desktop
+
First on mobile
+
+``` + +### Responsive Overflow + +```html +
+ Scrollable on mobile, expanded on desktop +
+``` + +## Testing Checklist + +- [ ] Test at 320px (small mobile) +- [ ] Test at 640px (mobile breakpoint) +- [ ] Test at 768px (tablet breakpoint) +- [ ] Test at 1024px (desktop breakpoint) +- [ ] Test at 1280px (large desktop breakpoint) +- [ ] Test landscape orientation +- [ ] Verify touch targets (min 44x44px) +- [ ] Check text readability at all sizes +- [ ] Verify navigation works on mobile +- [ ] Test with browser zoom diff --git a/skills/website-creator/ui-styling/references/tailwind-utilities.md b/skills/website-creator/ui-styling/references/tailwind-utilities.md new file mode 100644 index 0000000..7b7b123 --- /dev/null +++ b/skills/website-creator/ui-styling/references/tailwind-utilities.md @@ -0,0 +1,455 @@ +# Tailwind CSS Utility Reference + +Core utility classes for layout, spacing, typography, colors, borders, and shadows. + +## Layout Utilities + +### Display + +```html +
Block
+
Inline Block
+
Inline
+
Flexbox
+
Inline Flex
+
Grid
+
Inline Grid
+ +``` + +### Flexbox + +**Container:** +```html +
Row (default)
+
Column
+
Reverse row
+
Reverse column
+``` + +**Justify (main axis):** +```html +
Start
+
Center
+
End
+
Space between
+
Space around
+
Space evenly
+``` + +**Align (cross axis):** +```html +
Start
+
Center
+
End
+
Baseline
+
Stretch
+``` + +**Gap:** +```html +
All sides
+
X and Y
+``` + +**Wrap:** +```html +
Wrap
+
No wrap
+``` + +### Grid + +**Columns:** +```html +
1 column
+
2 columns
+
3 columns
+
4 columns
+
12 columns
+
Custom
+``` + +**Rows:** +```html +
3 rows
+
Custom
+``` + +**Span:** +```html +
Span 2 columns
+
Span 3 rows
+``` + +**Gap:** +```html +
All sides
+
X and Y
+``` + +### Positioning + +```html +
Static (default)
+
Relative
+
Absolute
+
Fixed
+
Sticky
+ + +
Top right
+
All sides 0
+
Left/right 4
+
Top/bottom 8
+``` + +### Z-Index + +```html +
z-index: 0
+
z-index: 10
+
z-index: 20
+
z-index: 50
+``` + +## Spacing Utilities + +### Padding + +```html +
All sides
+
Left and right
+
Top and bottom
+
Top
+
Right
+
Bottom
+
Left
+``` + +### Margin + +```html +
All sides
+
Center horizontally
+
Top and bottom
+
Top
+
Negative top
+
Push to right
+``` + +### Space Between + +```html +
Horizontal spacing
+
Vertical spacing
+``` + +### Spacing Scale + +- `0`: 0px +- `px`: 1px +- `0.5`: 0.125rem (2px) +- `1`: 0.25rem (4px) +- `2`: 0.5rem (8px) +- `3`: 0.75rem (12px) +- `4`: 1rem (16px) +- `6`: 1.5rem (24px) +- `8`: 2rem (32px) +- `12`: 3rem (48px) +- `16`: 4rem (64px) +- `24`: 6rem (96px) + +## Typography + +### Font Size + +```html +

Extra small (12px)

+

Small (14px)

+

Base (16px)

+

Large (18px)

+

XL (20px)

+

2XL (24px)

+

3XL (30px)

+

4XL (36px)

+

5XL (48px)

+``` + +### Font Weight + +```html +

Thin (100)

+

Light (300)

+

Normal (400)

+

Medium (500)

+

Semibold (600)

+

Bold (700)

+

Black (900)

+``` + +### Text Alignment + +```html +

Left

+

Center

+

Right

+

Justify

+``` + +### Line Height + +```html +

1

+

1.25

+

1.5

+

1.75

+

2

+``` + +### Combined Font Utilities + +```html +

+ Font size 4xl with tight line height +

+``` + +### Text Transform + +```html +

UPPERCASE

+

lowercase

+

Capitalize

+

Normal

+``` + +### Text Decoration + +```html +

Underline

+

Line through

+

No underline

+``` + +### Text Overflow + +```html +

Truncate with ellipsis...

+

Clamp to 3 lines...

+

Ellipsis

+``` + +## Colors + +### Text Colors + +```html +

Black

+

White

+

Gray 500

+

Red 600

+

Blue 500

+

Green 600

+``` + +### Background Colors + +```html +
White
+
Gray 100
+
Blue 500
+
Red 600
+``` + +### Color Scale + +Each color has 11 shades (50-950): +- `50`: Lightest +- `100-400`: Light variations +- `500`: Base color +- `600-800`: Dark variations +- `950`: Darkest + +### Opacity Modifiers + +```html +
75% opacity
+
30% opacity
+
87% opacity
+``` + +### Gradients + +```html +
+ Left to right gradient +
+
+ With via color +
+``` + +Directions: `to-t | to-tr | to-r | to-br | to-b | to-bl | to-l | to-tl` + +## Borders + +### Border Width + +```html +
1px all sides
+
2px all sides
+
Top only
+
Right 4px
+
Bottom 2px
+
Left only
+
No border
+``` + +### Border Color + +```html +
Gray
+
Blue
+
Red with opacity
+``` + +### Border Radius + +```html +
0.25rem
+
0.375rem
+
0.5rem
+
0.75rem
+
1rem
+
9999px
+ + +
Top corners
+
Bottom right
+``` + +### Border Style + +```html +
Solid
+
Dashed
+
Dotted
+``` + +## Shadows + +```html +
Small
+
Default
+
Medium
+
Large
+
Extra large
+
2XL
+
No shadow
+``` + +### Colored Shadows + +```html +
Blue shadow
+``` + +## Width & Height + +### Width + +```html +
100%
+
50%
+
33.333%
+
16rem
+
500px
+
100vw
+ + +
min-width: 0
+
max-width: 28rem
+
max-width: 1280px
+``` + +### Height + +```html +
100%
+
100vh
+
16rem
+
500px
+ + +
min-height: 100vh
+
max-height: 24rem
+``` + +## Arbitrary Values + +Use square brackets for custom values: + +```html + +
Custom padding
+
Custom position
+ + +
Hex color
+
RGB
+ + +
Custom width
+
Custom font size
+ + +
CSS var
+ + +
Custom grid
+``` + +## Aspect Ratio + +```html +
1:1
+
16:9
+
4:3
+``` + +## Overflow + +```html +
Auto scroll
+
Hidden
+
Always scroll
+
Horizontal scroll
+
No vertical scroll
+``` + +## Opacity + +```html +
0%
+
50%
+
75%
+
100%
+``` + +## Cursor + +```html +
Pointer
+
Wait
+
Not allowed
+
Default
+``` + +## User Select + +```html +
No select
+
Text selectable
+
Select all
+``` diff --git a/skills/website-creator/ui-styling/scripts/.coverage b/skills/website-creator/ui-styling/scripts/.coverage new file mode 100644 index 0000000..4382142 Binary files /dev/null and b/skills/website-creator/ui-styling/scripts/.coverage differ diff --git a/skills/website-creator/ui-styling/scripts/requirements.txt b/skills/website-creator/ui-styling/scripts/requirements.txt new file mode 100644 index 0000000..75f72ca --- /dev/null +++ b/skills/website-creator/ui-styling/scripts/requirements.txt @@ -0,0 +1,17 @@ +# UI Styling Skill Dependencies +# Python 3.10+ required + +# No Python package dependencies - uses only standard library + +# Testing dependencies (dev) +pytest>=8.0.0 +pytest-cov>=4.1.0 +pytest-mock>=3.12.0 + +# Note: This skill works with shadcn/ui and Tailwind CSS +# Requires Node.js and package managers: +# - Node.js 18+: https://nodejs.org/ +# - npm (comes with Node.js) +# +# shadcn/ui CLI is installed per-project: +# npx shadcn-ui@latest init diff --git a/skills/website-creator/ui-styling/scripts/shadcn_add.py b/skills/website-creator/ui-styling/scripts/shadcn_add.py new file mode 100644 index 0000000..e2a9799 --- /dev/null +++ b/skills/website-creator/ui-styling/scripts/shadcn_add.py @@ -0,0 +1,292 @@ +#!/usr/bin/env python3 +""" +shadcn/ui Component Installer + +Add shadcn/ui components to project with automatic dependency handling. +Wraps shadcn CLI for programmatic component installation. +""" + +import argparse +import json +import subprocess +import sys +from pathlib import Path +from typing import List, Optional + + +class ShadcnInstaller: + """Handle shadcn/ui component installation.""" + + def __init__(self, project_root: Optional[Path] = None, dry_run: bool = False): + """ + Initialize installer. + + Args: + project_root: Project root directory (default: current directory) + dry_run: If True, show actions without executing + """ + self.project_root = project_root or Path.cwd() + self.dry_run = dry_run + self.components_json = self.project_root / "components.json" + + def check_shadcn_config(self) -> bool: + """ + Check if shadcn is initialized in project. + + Returns: + True if components.json exists + """ + return self.components_json.exists() + + def get_installed_components(self) -> List[str]: + """ + Get list of already installed components. + + Returns: + List of installed component names + """ + if not self.check_shadcn_config(): + return [] + + try: + with open(self.components_json) as f: + config = json.load(f) + + components_dir = self.project_root / config.get("aliases", {}).get( + "components", "components" + ).replace("@/", "") + ui_dir = components_dir / "ui" + + if not ui_dir.exists(): + return [] + + return [f.stem for f in ui_dir.glob("*.tsx") if f.is_file()] + except (json.JSONDecodeError, KeyError, OSError): + return [] + + def add_components( + self, components: List[str], overwrite: bool = False + ) -> tuple[bool, str]: + """ + Add shadcn/ui components. + + Args: + components: List of component names to add + overwrite: If True, overwrite existing components + + Returns: + Tuple of (success, message) + """ + if not components: + return False, "No components specified" + + if not self.check_shadcn_config(): + return ( + False, + "shadcn not initialized. 
Run 'npx shadcn@latest init' first", + ) + + # Check which components already exist + installed = self.get_installed_components() + already_installed = [c for c in components if c in installed] + + if already_installed and not overwrite: + return ( + False, + f"Components already installed: {', '.join(already_installed)}. " + "Use --overwrite to reinstall", + ) + + # Build command + cmd = ["npx", "shadcn@latest", "add"] + components + + if overwrite: + cmd.append("--overwrite") + + if self.dry_run: + return True, f"Would run: {' '.join(cmd)}" + + # Execute command + try: + result = subprocess.run( + cmd, + cwd=self.project_root, + capture_output=True, + text=True, + check=True, + ) + + success_msg = f"Successfully added components: {', '.join(components)}" + if result.stdout: + success_msg += f"\n\nOutput:\n{result.stdout}" + + return True, success_msg + + except subprocess.CalledProcessError as e: + error_msg = f"Failed to add components: {e.stderr or e.stdout or str(e)}" + return False, error_msg + except FileNotFoundError: + return False, "npx not found. Ensure Node.js is installed" + + def add_all_components(self, overwrite: bool = False) -> tuple[bool, str]: + """ + Add all available shadcn/ui components. + + Args: + overwrite: If True, overwrite existing components + + Returns: + Tuple of (success, message) + """ + if not self.check_shadcn_config(): + return ( + False, + "shadcn not initialized. Run 'npx shadcn@latest init' first", + ) + + cmd = ["npx", "shadcn@latest", "add", "--all"] + + if overwrite: + cmd.append("--overwrite") + + if self.dry_run: + return True, f"Would run: {' '.join(cmd)}" + + try: + result = subprocess.run( + cmd, + cwd=self.project_root, + capture_output=True, + text=True, + check=True, + ) + + success_msg = "Successfully added all components" + if result.stdout: + success_msg += f"\n\nOutput:\n{result.stdout}" + + return True, success_msg + + except subprocess.CalledProcessError as e: + error_msg = f"Failed to add all components: {e.stderr or e.stdout or str(e)}" + return False, error_msg + except FileNotFoundError: + return False, "npx not found. Ensure Node.js is installed" + + def list_installed(self) -> tuple[bool, str]: + """ + List installed components. 
+ + Returns: + Tuple of (success, message with component list) + """ + if not self.check_shadcn_config(): + return False, "shadcn not initialized" + + installed = self.get_installed_components() + + if not installed: + return True, "No components installed" + + return True, f"Installed components:\n" + "\n".join(f" - {c}" for c in sorted(installed)) + + +def main(): + """CLI entry point.""" + parser = argparse.ArgumentParser( + description="Add shadcn/ui components to your project", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Add single component + python shadcn_add.py button + + # Add multiple components + python shadcn_add.py button card dialog + + # Add all components + python shadcn_add.py --all + + # Overwrite existing components + python shadcn_add.py button --overwrite + + # Dry run (show what would be done) + python shadcn_add.py button card --dry-run + + # List installed components + python shadcn_add.py --list + """, + ) + + parser.add_argument( + "components", + nargs="*", + help="Component names to add (e.g., button, card, dialog)", + ) + + parser.add_argument( + "--all", + action="store_true", + help="Add all available components", + ) + + parser.add_argument( + "--overwrite", + action="store_true", + help="Overwrite existing components", + ) + + parser.add_argument( + "--dry-run", + action="store_true", + help="Show what would be done without executing", + ) + + parser.add_argument( + "--list", + action="store_true", + help="List installed components", + ) + + parser.add_argument( + "--project-root", + type=Path, + help="Project root directory (default: current directory)", + ) + + args = parser.parse_args() + + # Initialize installer + installer = ShadcnInstaller( + project_root=args.project_root, + dry_run=args.dry_run, + ) + + # Handle list command + if args.list: + success, message = installer.list_installed() + print(message) + sys.exit(0 if success else 1) + + # Handle add all command + if args.all: + success, message = installer.add_all_components(overwrite=args.overwrite) + print(message) + sys.exit(0 if success else 1) + + # Handle add specific components + if not args.components: + parser.print_help() + sys.exit(1) + + success, message = installer.add_components( + args.components, + overwrite=args.overwrite, + ) + + print(message) + sys.exit(0 if success else 1) + + +if __name__ == "__main__": + main() diff --git a/skills/website-creator/ui-styling/scripts/tailwind_config_gen.py b/skills/website-creator/ui-styling/scripts/tailwind_config_gen.py new file mode 100644 index 0000000..5109311 --- /dev/null +++ b/skills/website-creator/ui-styling/scripts/tailwind_config_gen.py @@ -0,0 +1,456 @@ +#!/usr/bin/env python3 +""" +Tailwind CSS Configuration Generator + +Generate tailwind.config.js/ts with custom theme configuration. +Supports colors, fonts, spacing, breakpoints, and plugin recommendations. +""" + +import argparse +import json +import sys +from pathlib import Path +from typing import Any, Dict, List, Optional + + +class TailwindConfigGenerator: + """Generate Tailwind CSS configuration files.""" + + def __init__( + self, + typescript: bool = True, + framework: str = "react", + output_path: Optional[Path] = None, + ): + """ + Initialize generator. 
+ + Args: + typescript: If True, generate .ts config, else .js + framework: Framework name (react, vue, svelte, nextjs) + output_path: Output file path (default: auto-detect) + """ + self.typescript = typescript + self.framework = framework + self.output_path = output_path or self._default_output_path() + self.config: Dict[str, Any] = self._base_config() + + def _default_output_path(self) -> Path: + """Determine default output path.""" + ext = "ts" if self.typescript else "js" + return Path.cwd() / f"tailwind.config.{ext}" + + def _base_config(self) -> Dict[str, Any]: + """Create base configuration structure.""" + return { + "darkMode": ["class"], + "content": self._default_content_paths(), + "theme": { + "extend": {} + }, + "plugins": [] + } + + def _default_content_paths(self) -> List[str]: + """Get default content paths for framework.""" + paths = { + "react": [ + "./src/**/*.{js,jsx,ts,tsx}", + "./index.html", + ], + "vue": [ + "./src/**/*.{vue,js,ts,jsx,tsx}", + "./index.html", + ], + "svelte": [ + "./src/**/*.{svelte,js,ts}", + "./src/app.html", + ], + "nextjs": [ + "./app/**/*.{js,ts,jsx,tsx}", + "./pages/**/*.{js,ts,jsx,tsx}", + "./components/**/*.{js,ts,jsx,tsx}", + ], + } + return paths.get(self.framework, paths["react"]) + + def add_colors(self, colors: Dict[str, str]) -> None: + """ + Add custom colors to theme. + + Args: + colors: Dict of color_name: color_value + Value can be hex (#3b82f6) or variable (hsl(var(--primary))) + """ + if "colors" not in self.config["theme"]["extend"]: + self.config["theme"]["extend"]["colors"] = {} + + self.config["theme"]["extend"]["colors"].update(colors) + + def add_color_palette(self, name: str, base_color: str) -> None: + """ + Add full color palette (50-950 shades) for a base color. + + Args: + name: Color name (e.g., 'brand', 'primary') + base_color: Base color in oklch format or hex + """ + # For simplicity, use CSS variable approach + if "colors" not in self.config["theme"]["extend"]: + self.config["theme"]["extend"]["colors"] = {} + + self.config["theme"]["extend"]["colors"][name] = { + "50": f"var(--color-{name}-50)", + "100": f"var(--color-{name}-100)", + "200": f"var(--color-{name}-200)", + "300": f"var(--color-{name}-300)", + "400": f"var(--color-{name}-400)", + "500": f"var(--color-{name}-500)", + "600": f"var(--color-{name}-600)", + "700": f"var(--color-{name}-700)", + "800": f"var(--color-{name}-800)", + "900": f"var(--color-{name}-900)", + "950": f"var(--color-{name}-950)", + } + + def add_fonts(self, fonts: Dict[str, List[str]]) -> None: + """ + Add custom font families. + + Args: + fonts: Dict of font_type: [font_names] + e.g., {'sans': ['Inter', 'system-ui', 'sans-serif']} + """ + if "fontFamily" not in self.config["theme"]["extend"]: + self.config["theme"]["extend"]["fontFamily"] = {} + + self.config["theme"]["extend"]["fontFamily"].update(fonts) + + def add_spacing(self, spacing: Dict[str, str]) -> None: + """ + Add custom spacing values. + + Args: + spacing: Dict of name: value + e.g., {'18': '4.5rem', 'navbar': '4rem'} + """ + if "spacing" not in self.config["theme"]["extend"]: + self.config["theme"]["extend"]["spacing"] = {} + + self.config["theme"]["extend"]["spacing"].update(spacing) + + def add_breakpoints(self, breakpoints: Dict[str, str]) -> None: + """ + Add custom breakpoints. 
+ + Args: + breakpoints: Dict of name: width + e.g., {'3xl': '1920px', 'tablet': '768px'} + """ + if "screens" not in self.config["theme"]["extend"]: + self.config["theme"]["extend"]["screens"] = {} + + self.config["theme"]["extend"]["screens"].update(breakpoints) + + def add_plugins(self, plugins: List[str]) -> None: + """ + Add plugin requirements. + + Args: + plugins: List of plugin names + e.g., ['@tailwindcss/typography', '@tailwindcss/forms'] + """ + for plugin in plugins: + if plugin not in self.config["plugins"]: + self.config["plugins"].append(plugin) + + def recommend_plugins(self) -> List[str]: + """ + Get plugin recommendations based on configuration. + + Returns: + List of recommended plugin package names + """ + recommendations = [] + + # Always recommend animation plugin + recommendations.append("tailwindcss-animate") + + # Framework-specific recommendations + if self.framework == "nextjs": + recommendations.append("@tailwindcss/typography") + + return recommendations + + def generate_config_string(self) -> str: + """ + Generate configuration file content. + + Returns: + Configuration file as string + """ + if self.typescript: + return self._generate_typescript() + return self._generate_javascript() + + def _generate_typescript(self) -> str: + """Generate TypeScript configuration.""" + plugins_str = self._format_plugins() + + config_json = json.dumps(self.config, indent=2) + + # Remove plugin array from JSON (we'll add it with require()) + config_obj = self.config.copy() + config_obj.pop("plugins", None) + config_json = json.dumps(config_obj, indent=2) + + return f"""import type {{ Config }} from 'tailwindcss' + +const config: Config = {{ +{self._indent_json(config_json, 1)} + plugins: [{plugins_str}], +}} + +export default config +""" + + def _generate_javascript(self) -> str: + """Generate JavaScript configuration.""" + plugins_str = self._format_plugins() + + config_obj = self.config.copy() + config_obj.pop("plugins", None) + config_json = json.dumps(config_obj, indent=2) + + return f"""/** @type {{import('tailwindcss').Config}} */ +module.exports = {{ +{self._indent_json(config_json, 1)} + plugins: [{plugins_str}], +}} +""" + + def _format_plugins(self) -> str: + """Format plugins array for config.""" + if not self.config["plugins"]: + return "" + + plugin_requires = [ + f"require('{plugin}')" for plugin in self.config["plugins"] + ] + return ", ".join(plugin_requires) + + def _indent_json(self, json_str: str, level: int) -> str: + """Add indentation to JSON string.""" + indent = " " * level + lines = json_str.split("\n") + # Skip first and last lines (braces) + indented = [indent + line for line in lines[1:-1]] + return "\n".join(indented) + + def write_config(self) -> tuple[bool, str]: + """ + Write configuration to file. + + Returns: + Tuple of (success, message) + """ + try: + config_content = self.generate_config_string() + + self.output_path.write_text(config_content) + + return True, f"Configuration written to {self.output_path}" + + except OSError as e: + return False, f"Failed to write config: {e}" + + def validate_config(self) -> tuple[bool, str]: + """ + Validate configuration. 
+ + Returns: + Tuple of (valid, message) + """ + # Check content paths exist + if not self.config["content"]: + return False, "No content paths specified" + + # Check if extending empty theme + if not self.config["theme"]["extend"]: + return True, "Warning: No theme extensions defined" + + return True, "Configuration valid" + + +def main(): + """CLI entry point.""" + parser = argparse.ArgumentParser( + description="Generate Tailwind CSS configuration", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Generate TypeScript config for Next.js + python tailwind_config_gen.py --framework nextjs + + # Generate JavaScript config with custom colors + python tailwind_config_gen.py --js --colors brand:#3b82f6 accent:#8b5cf6 + + # Add custom fonts + python tailwind_config_gen.py --fonts display:"Playfair Display,serif" + + # Add custom spacing and breakpoints + python tailwind_config_gen.py --spacing navbar:4rem --breakpoints 3xl:1920px + + # Add recommended plugins + python tailwind_config_gen.py --plugins + """, + ) + + parser.add_argument( + "--framework", + choices=["react", "vue", "svelte", "nextjs"], + default="react", + help="Target framework (default: react)", + ) + + parser.add_argument( + "--js", + action="store_true", + help="Generate JavaScript config instead of TypeScript", + ) + + parser.add_argument( + "--output", + type=Path, + help="Output file path", + ) + + parser.add_argument( + "--colors", + nargs="*", + metavar="NAME:VALUE", + help="Custom colors (e.g., brand:#3b82f6)", + ) + + parser.add_argument( + "--fonts", + nargs="*", + metavar="TYPE:FAMILY", + help="Custom fonts (e.g., sans:'Inter,system-ui')", + ) + + parser.add_argument( + "--spacing", + nargs="*", + metavar="NAME:VALUE", + help="Custom spacing (e.g., navbar:4rem)", + ) + + parser.add_argument( + "--breakpoints", + nargs="*", + metavar="NAME:WIDTH", + help="Custom breakpoints (e.g., 3xl:1920px)", + ) + + parser.add_argument( + "--plugins", + action="store_true", + help="Add recommended plugins", + ) + + parser.add_argument( + "--validate-only", + action="store_true", + help="Validate config without writing file", + ) + + args = parser.parse_args() + + # Initialize generator + generator = TailwindConfigGenerator( + typescript=not args.js, + framework=args.framework, + output_path=args.output, + ) + + # Add custom colors + if args.colors: + colors = {} + for color_spec in args.colors: + try: + name, value = color_spec.split(":", 1) + colors[name] = value + except ValueError: + print(f"Invalid color spec: {color_spec}", file=sys.stderr) + sys.exit(1) + generator.add_colors(colors) + + # Add custom fonts + if args.fonts: + fonts = {} + for font_spec in args.fonts: + try: + font_type, family = font_spec.split(":", 1) + fonts[font_type] = [f.strip().strip("'\"") for f in family.split(",")] + except ValueError: + print(f"Invalid font spec: {font_spec}", file=sys.stderr) + sys.exit(1) + generator.add_fonts(fonts) + + # Add custom spacing + if args.spacing: + spacing = {} + for spacing_spec in args.spacing: + try: + name, value = spacing_spec.split(":", 1) + spacing[name] = value + except ValueError: + print(f"Invalid spacing spec: {spacing_spec}", file=sys.stderr) + sys.exit(1) + generator.add_spacing(spacing) + + # Add custom breakpoints + if args.breakpoints: + breakpoints = {} + for bp_spec in args.breakpoints: + try: + name, width = bp_spec.split(":", 1) + breakpoints[name] = width + except ValueError: + print(f"Invalid breakpoint spec: {bp_spec}", file=sys.stderr) + sys.exit(1) + 
generator.add_breakpoints(breakpoints) + + # Add recommended plugins + if args.plugins: + recommended = generator.recommend_plugins() + generator.add_plugins(recommended) + print(f"Added recommended plugins: {', '.join(recommended)}") + print("\nInstall with:") + print(f" npm install -D {' '.join(recommended)}") + + # Validate + valid, message = generator.validate_config() + if not valid: + print(f"Validation failed: {message}", file=sys.stderr) + sys.exit(1) + + if message.startswith("Warning"): + print(message) + + # Validate only mode + if args.validate_only: + print("Configuration valid") + print("\nGenerated config:") + print(generator.generate_config_string()) + sys.exit(0) + + # Write config + success, message = generator.write_config() + print(message) + sys.exit(0 if success else 1) + + +if __name__ == "__main__": + main() diff --git a/skills/website-creator/ui-styling/scripts/tests/coverage-ui.json b/skills/website-creator/ui-styling/scripts/tests/coverage-ui.json new file mode 100644 index 0000000..2a20568 --- /dev/null +++ b/skills/website-creator/ui-styling/scripts/tests/coverage-ui.json @@ -0,0 +1 @@ +{"meta": {"format": 3, "version": "7.11.0", "timestamp": "2025-11-05T00:57:08.005243", "branch_coverage": false, "show_contexts": false}, "files": {"shadcn_add.py": {"executed_lines": [2, 9, 10, 11, 12, 13, 14, 17, 18, 20, 28, 29, 30, 32, 39, 41, 48, 49, 51, 52, 53, 55, 58, 60, 63, 67, 80, 81, 83, 84, 90, 91, 93, 94, 101, 103, 104, 106, 107, 110, 111, 119, 120, 121, 123, 125, 126, 127, 128, 129, 131, 141, 142, 147, 149, 152, 153, 155, 156, 164, 165, 166, 168, 176, 183, 184, 186, 188, 189, 191, 194, 291], "summary": {"covered_lines": 70, "num_statements": 103, "percent_covered": 67.96116504854369, "percent_covered_display": "68", "missing_lines": 33, "excluded_lines": 0}, "missing_lines": [61, 64, 65, 150, 170, 171, 172, 173, 174, 196, 221, 227, 233, 239, 245, 251, 257, 260, 266, 267, 268, 269, 272, 273, 274, 275, 278, 279, 280, 282, 287, 288, 292], "excluded_lines": [], "functions": {"ShadcnInstaller.__init__": {"executed_lines": [28, 29, 30], "summary": {"covered_lines": 3, "num_statements": 3, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "ShadcnInstaller.check_shadcn_config": {"executed_lines": [39], "summary": {"covered_lines": 1, "num_statements": 1, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "ShadcnInstaller.get_installed_components": {"executed_lines": [48, 49, 51, 52, 53, 55, 58, 60, 63], "summary": {"covered_lines": 9, "num_statements": 12, "percent_covered": 75.0, "percent_covered_display": "75", "missing_lines": 3, "excluded_lines": 0}, "missing_lines": [61, 64, 65], "excluded_lines": []}, "ShadcnInstaller.add_components": {"executed_lines": [80, 81, 83, 84, 90, 91, 93, 94, 101, 103, 104, 106, 107, 110, 111, 119, 120, 121, 123, 125, 126, 127, 128, 129], "summary": {"covered_lines": 24, "num_statements": 24, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "ShadcnInstaller.add_all_components": {"executed_lines": [141, 142, 147, 149, 152, 153, 155, 156, 164, 165, 166, 168], "summary": {"covered_lines": 12, "num_statements": 18, "percent_covered": 66.66666666666667, "percent_covered_display": "67", "missing_lines": 6, "excluded_lines": 0}, "missing_lines": [150, 170, 171, 172, 
173, 174], "excluded_lines": []}, "ShadcnInstaller.list_installed": {"executed_lines": [183, 184, 186, 188, 189, 191], "summary": {"covered_lines": 6, "num_statements": 6, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "main": {"executed_lines": [], "summary": {"covered_lines": 0, "num_statements": 23, "percent_covered": 0.0, "percent_covered_display": "0", "missing_lines": 23, "excluded_lines": 0}, "missing_lines": [196, 221, 227, 233, 239, 245, 251, 257, 260, 266, 267, 268, 269, 272, 273, 274, 275, 278, 279, 280, 282, 287, 288], "excluded_lines": []}, "": {"executed_lines": [2, 9, 10, 11, 12, 13, 14, 17, 18, 20, 32, 41, 67, 131, 176, 194, 291], "summary": {"covered_lines": 15, "num_statements": 16, "percent_covered": 93.75, "percent_covered_display": "94", "missing_lines": 1, "excluded_lines": 0}, "missing_lines": [292], "excluded_lines": []}}, "classes": {"ShadcnInstaller": {"executed_lines": [28, 29, 30, 39, 48, 49, 51, 52, 53, 55, 58, 60, 63, 80, 81, 83, 84, 90, 91, 93, 94, 101, 103, 104, 106, 107, 110, 111, 119, 120, 121, 123, 125, 126, 127, 128, 129, 141, 142, 147, 149, 152, 153, 155, 156, 164, 165, 166, 168, 183, 184, 186, 188, 189, 191], "summary": {"covered_lines": 55, "num_statements": 64, "percent_covered": 85.9375, "percent_covered_display": "86", "missing_lines": 9, "excluded_lines": 0}, "missing_lines": [61, 64, 65, 150, 170, 171, 172, 173, 174], "excluded_lines": []}, "": {"executed_lines": [2, 9, 10, 11, 12, 13, 14, 17, 18, 20, 32, 41, 67, 131, 176, 194, 291], "summary": {"covered_lines": 15, "num_statements": 39, "percent_covered": 38.46153846153846, "percent_covered_display": "38", "missing_lines": 24, "excluded_lines": 0}, "missing_lines": [196, 221, 227, 233, 239, 245, 251, 257, 260, 266, 267, 268, 269, 272, 273, 274, 275, 278, 279, 280, 282, 287, 288, 292], "excluded_lines": []}}}, "tailwind_config_gen.py": {"executed_lines": [2, 9, 10, 11, 12, 13, 16, 17, 19, 33, 34, 35, 36, 38, 40, 41, 43, 45, 54, 56, 75, 77, 85, 86, 88, 90, 99, 100, 102, 116, 124, 125, 127, 129, 137, 138, 140, 142, 150, 151, 153, 155, 163, 164, 165, 167, 174, 177, 180, 181, 183, 185, 192, 193, 194, 196, 198, 200, 203, 204, 205, 207, 217, 219, 221, 222, 223, 225, 232, 234, 235, 237, 240, 242, 244, 245, 247, 248, 250, 257, 258, 260, 262, 264, 265, 267, 275, 276, 279, 280, 285, 455], "summary": {"covered_lines": 90, "num_statements": 164, "percent_covered": 54.8780487804878, "percent_covered_display": "55", "missing_lines": 74, "excluded_lines": 0}, "missing_lines": [282, 287, 309, 316, 322, 328, 335, 342, 349, 356, 362, 368, 371, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 426, 427, 428, 429, 430, 431, 434, 435, 436, 437, 439, 440, 443, 444, 445, 446, 447, 450, 451, 452, 456], "excluded_lines": [], "functions": {"TailwindConfigGenerator.__init__": {"executed_lines": [33, 34, 35, 36], "summary": {"covered_lines": 4, "num_statements": 4, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TailwindConfigGenerator._default_output_path": {"executed_lines": [40, 41], "summary": {"covered_lines": 2, "num_statements": 2, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, 
"TailwindConfigGenerator._base_config": {"executed_lines": [45], "summary": {"covered_lines": 1, "num_statements": 1, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TailwindConfigGenerator._default_content_paths": {"executed_lines": [56, 75], "summary": {"covered_lines": 2, "num_statements": 2, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TailwindConfigGenerator.add_colors": {"executed_lines": [85, 86, 88], "summary": {"covered_lines": 3, "num_statements": 3, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TailwindConfigGenerator.add_color_palette": {"executed_lines": [99, 100, 102], "summary": {"covered_lines": 3, "num_statements": 3, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TailwindConfigGenerator.add_fonts": {"executed_lines": [124, 125, 127], "summary": {"covered_lines": 3, "num_statements": 3, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TailwindConfigGenerator.add_spacing": {"executed_lines": [137, 138, 140], "summary": {"covered_lines": 3, "num_statements": 3, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TailwindConfigGenerator.add_breakpoints": {"executed_lines": [150, 151, 153], "summary": {"covered_lines": 3, "num_statements": 3, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TailwindConfigGenerator.add_plugins": {"executed_lines": [163, 164, 165], "summary": {"covered_lines": 3, "num_statements": 3, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TailwindConfigGenerator.recommend_plugins": {"executed_lines": [174, 177, 180, 181, 183], "summary": {"covered_lines": 5, "num_statements": 5, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TailwindConfigGenerator.generate_config_string": {"executed_lines": [192, 193, 194], "summary": {"covered_lines": 3, "num_statements": 3, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TailwindConfigGenerator._generate_typescript": {"executed_lines": [198, 200, 203, 204, 205, 207], "summary": {"covered_lines": 6, "num_statements": 6, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TailwindConfigGenerator._generate_javascript": {"executed_lines": [219, 221, 222, 223, 225], "summary": {"covered_lines": 5, "num_statements": 5, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TailwindConfigGenerator._format_plugins": {"executed_lines": [234, 235, 237, 240], "summary": {"covered_lines": 4, "num_statements": 4, "percent_covered": 100.0, "percent_covered_display": "100", 
"missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TailwindConfigGenerator._indent_json": {"executed_lines": [244, 245, 247, 248], "summary": {"covered_lines": 4, "num_statements": 4, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TailwindConfigGenerator.write_config": {"executed_lines": [257, 258, 260, 262, 264, 265], "summary": {"covered_lines": 6, "num_statements": 6, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TailwindConfigGenerator.validate_config": {"executed_lines": [275, 276, 279, 280], "summary": {"covered_lines": 4, "num_statements": 5, "percent_covered": 80.0, "percent_covered_display": "80", "missing_lines": 1, "excluded_lines": 0}, "missing_lines": [282], "excluded_lines": []}, "main": {"executed_lines": [], "summary": {"covered_lines": 0, "num_statements": 72, "percent_covered": 0.0, "percent_covered_display": "0", "missing_lines": 72, "excluded_lines": 0}, "missing_lines": [287, 309, 316, 322, 328, 335, 342, 349, 356, 362, 368, 371, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 426, 427, 428, 429, 430, 431, 434, 435, 436, 437, 439, 440, 443, 444, 445, 446, 447, 450, 451, 452], "excluded_lines": []}, "": {"executed_lines": [2, 9, 10, 11, 12, 13, 16, 17, 19, 38, 43, 54, 77, 90, 116, 129, 142, 155, 167, 185, 196, 217, 232, 242, 250, 267, 285, 455], "summary": {"covered_lines": 26, "num_statements": 27, "percent_covered": 96.29629629629629, "percent_covered_display": "96", "missing_lines": 1, "excluded_lines": 0}, "missing_lines": [456], "excluded_lines": []}}, "classes": {"TailwindConfigGenerator": {"executed_lines": [33, 34, 35, 36, 40, 41, 45, 56, 75, 85, 86, 88, 99, 100, 102, 124, 125, 127, 137, 138, 140, 150, 151, 153, 163, 164, 165, 174, 177, 180, 181, 183, 192, 193, 194, 198, 200, 203, 204, 205, 207, 219, 221, 222, 223, 225, 234, 235, 237, 240, 244, 245, 247, 248, 257, 258, 260, 262, 264, 265, 275, 276, 279, 280], "summary": {"covered_lines": 64, "num_statements": 65, "percent_covered": 98.46153846153847, "percent_covered_display": "98", "missing_lines": 1, "excluded_lines": 0}, "missing_lines": [282], "excluded_lines": []}, "": {"executed_lines": [2, 9, 10, 11, 12, 13, 16, 17, 19, 38, 43, 54, 77, 90, 116, 129, 142, 155, 167, 185, 196, 217, 232, 242, 250, 267, 285, 455], "summary": {"covered_lines": 26, "num_statements": 99, "percent_covered": 26.262626262626263, "percent_covered_display": "26", "missing_lines": 73, "excluded_lines": 0}, "missing_lines": [287, 309, 316, 322, 328, 335, 342, 349, 356, 362, 368, 371, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 426, 427, 428, 429, 430, 431, 434, 435, 436, 437, 439, 440, 443, 444, 445, 446, 447, 450, 451, 452, 456], "excluded_lines": []}}}, "tests/test_shadcn_add.py": {"executed_lines": [1, 3, 4, 5, 6, 8, 11, 12, 14, 17, 18, 20, 21, 23, 24, 27, 28, 39, 40, 42, 44, 46, 47, 48, 50, 52, 53, 55, 57, 58, 60, 62, 63, 65, 67, 68, 70, 72, 73, 74, 76, 78, 81, 82, 84, 85, 87, 89, 91, 92, 93, 95, 97, 98, 100, 101, 103, 105, 106, 108, 109, 111, 113, 114, 116, 117, 119, 120, 121, 123, 125, 126, 128, 
130, 131, 136, 138, 139, 140, 143, 144, 146, 148, 149, 151, 152, 153, 154, 156, 157, 159, 165, 166, 168, 169, 170, 171, 174, 175, 176, 177, 178, 180, 181, 183, 187, 188, 190, 191, 193, 194, 196, 198, 199, 201, 202, 204, 206, 207, 209, 210, 212, 214, 215, 217, 218, 219, 221, 222, 224, 229, 230, 232, 233, 236, 237, 239, 241, 242, 244, 245, 247, 249, 250, 252, 253, 255, 257, 258, 259, 261, 262, 264, 265, 266], "summary": {"covered_lines": 153, "num_statements": 153, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": [], "functions": {"TestShadcnInstaller.temp_project": {"executed_lines": [23, 24, 27, 28, 39, 40, 42], "summary": {"covered_lines": 7, "num_statements": 7, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestShadcnInstaller.test_init_default_project_root": {"executed_lines": [46, 47, 48], "summary": {"covered_lines": 3, "num_statements": 3, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestShadcnInstaller.test_init_custom_project_root": {"executed_lines": [52, 53], "summary": {"covered_lines": 2, "num_statements": 2, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestShadcnInstaller.test_init_dry_run": {"executed_lines": [57, 58], "summary": {"covered_lines": 2, "num_statements": 2, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestShadcnInstaller.test_check_shadcn_config_exists": {"executed_lines": [62, 63], "summary": {"covered_lines": 2, "num_statements": 2, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestShadcnInstaller.test_check_shadcn_config_not_exists": {"executed_lines": [67, 68], "summary": {"covered_lines": 2, "num_statements": 2, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestShadcnInstaller.test_get_installed_components_empty": {"executed_lines": [72, 73, 74], "summary": {"covered_lines": 3, "num_statements": 3, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestShadcnInstaller.test_get_installed_components_with_files": {"executed_lines": [78, 81, 82, 84, 85, 87], "summary": {"covered_lines": 6, "num_statements": 6, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestShadcnInstaller.test_get_installed_components_no_config": {"executed_lines": [91, 92, 93], "summary": {"covered_lines": 3, "num_statements": 3, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestShadcnInstaller.test_add_components_no_components": {"executed_lines": [97, 98, 100, 101], "summary": {"covered_lines": 4, "num_statements": 4, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, 
"TestShadcnInstaller.test_add_components_no_config": {"executed_lines": [105, 106, 108, 109], "summary": {"covered_lines": 4, "num_statements": 4, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestShadcnInstaller.test_add_components_already_installed": {"executed_lines": [113, 114, 116, 117, 119, 120, 121], "summary": {"covered_lines": 7, "num_statements": 7, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestShadcnInstaller.test_add_components_with_overwrite": {"executed_lines": [125, 126, 128, 130, 131, 136, 138, 139, 140, 143, 144], "summary": {"covered_lines": 11, "num_statements": 11, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestShadcnInstaller.test_add_components_dry_run": {"executed_lines": [148, 149, 151, 152, 153, 154], "summary": {"covered_lines": 6, "num_statements": 6, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestShadcnInstaller.test_add_components_success": {"executed_lines": [159, 165, 166, 168, 169, 170, 171, 174, 175, 176, 177, 178], "summary": {"covered_lines": 12, "num_statements": 12, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestShadcnInstaller.test_add_components_subprocess_error": {"executed_lines": [183, 187, 188, 190, 191], "summary": {"covered_lines": 5, "num_statements": 5, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestShadcnInstaller.test_add_components_npx_not_found": {"executed_lines": [196, 198, 199, 201, 202], "summary": {"covered_lines": 5, "num_statements": 5, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestShadcnInstaller.test_add_all_components_no_config": {"executed_lines": [206, 207, 209, 210], "summary": {"covered_lines": 4, "num_statements": 4, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestShadcnInstaller.test_add_all_components_dry_run": {"executed_lines": [214, 215, 217, 218, 219], "summary": {"covered_lines": 5, "num_statements": 5, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestShadcnInstaller.test_add_all_components_success": {"executed_lines": [224, 229, 230, 232, 233, 236, 237], "summary": {"covered_lines": 7, "num_statements": 7, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestShadcnInstaller.test_list_installed_no_config": {"executed_lines": [241, 242, 244, 245], "summary": {"covered_lines": 4, "num_statements": 4, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestShadcnInstaller.test_list_installed_empty": {"executed_lines": [249, 250, 252, 253], "summary": {"covered_lines": 4, "num_statements": 4, "percent_covered": 100.0, 
"percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestShadcnInstaller.test_list_installed_with_components": {"executed_lines": [257, 258, 259, 261, 262, 264, 265, 266], "summary": {"covered_lines": 8, "num_statements": 8, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "": {"executed_lines": [1, 3, 4, 5, 6, 8, 11, 12, 14, 17, 18, 20, 21, 44, 50, 55, 60, 65, 70, 76, 89, 95, 103, 111, 123, 146, 156, 157, 180, 181, 193, 194, 204, 212, 221, 222, 239, 247, 255], "summary": {"covered_lines": 37, "num_statements": 37, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}}, "classes": {"TestShadcnInstaller": {"executed_lines": [23, 24, 27, 28, 39, 40, 42, 46, 47, 48, 52, 53, 57, 58, 62, 63, 67, 68, 72, 73, 74, 78, 81, 82, 84, 85, 87, 91, 92, 93, 97, 98, 100, 101, 105, 106, 108, 109, 113, 114, 116, 117, 119, 120, 121, 125, 126, 128, 130, 131, 136, 138, 139, 140, 143, 144, 148, 149, 151, 152, 153, 154, 159, 165, 166, 168, 169, 170, 171, 174, 175, 176, 177, 178, 183, 187, 188, 190, 191, 196, 198, 199, 201, 202, 206, 207, 209, 210, 214, 215, 217, 218, 219, 224, 229, 230, 232, 233, 236, 237, 241, 242, 244, 245, 249, 250, 252, 253, 257, 258, 259, 261, 262, 264, 265, 266], "summary": {"covered_lines": 116, "num_statements": 116, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "": {"executed_lines": [1, 3, 4, 5, 6, 8, 11, 12, 14, 17, 18, 20, 21, 44, 50, 55, 60, 65, 70, 76, 89, 95, 103, 111, 123, 146, 156, 157, 180, 181, 193, 194, 204, 212, 221, 222, 239, 247, 255], "summary": {"covered_lines": 37, "num_statements": 37, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}}}, "tests/test_tailwind_config_gen.py": {"executed_lines": [1, 3, 5, 8, 9, 11, 14, 15, 17, 19, 20, 21, 23, 25, 26, 28, 30, 31, 32, 34, 36, 37, 39, 41, 42, 44, 46, 47, 48, 50, 52, 53, 55, 56, 57, 58, 59, 61, 63, 64, 66, 67, 69, 71, 72, 74, 75, 76, 78, 80, 81, 83, 85, 87, 88, 92, 94, 95, 96, 98, 100, 102, 103, 105, 106, 107, 109, 111, 112, 114, 116, 117, 118, 119, 120, 122, 124, 125, 129, 131, 132, 133, 135, 137, 138, 142, 144, 145, 146, 148, 150, 151, 155, 157, 158, 159, 161, 163, 164, 165, 167, 168, 170, 172, 173, 174, 176, 177, 179, 181, 182, 184, 185, 187, 189, 190, 192, 194, 196, 197, 199, 200, 201, 203, 205, 206, 208, 209, 211, 213, 214, 215, 217, 218, 220, 222, 223, 224, 226, 227, 229, 231, 232, 234, 236, 238, 239, 241, 243, 244, 246, 248, 251, 253, 254, 256, 258, 259, 261, 263, 264, 265, 267, 269, 270, 271, 273, 275, 276, 277, 279, 281, 283, 285, 286, 288, 290, 291, 298, 299, 300, 301, 302, 304, 305, 307, 310, 311, 312, 313, 314, 315, 317, 319, 320, 326, 327, 329, 330, 332, 334, 335, 336], "summary": {"covered_lines": 201, "num_statements": 201, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": [], "functions": {"TestTailwindConfigGenerator.test_init_default_typescript": {"executed_lines": [19, 20, 21], "summary": {"covered_lines": 3, "num_statements": 3, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, 
"TestTailwindConfigGenerator.test_init_javascript": {"executed_lines": [25, 26], "summary": {"covered_lines": 2, "num_statements": 2, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_init_framework": {"executed_lines": [30, 31, 32], "summary": {"covered_lines": 3, "num_statements": 3, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_default_output_path_typescript": {"executed_lines": [36, 37], "summary": {"covered_lines": 2, "num_statements": 2, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_default_output_path_javascript": {"executed_lines": [41, 42], "summary": {"covered_lines": 2, "num_statements": 2, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_custom_output_path": {"executed_lines": [46, 47, 48], "summary": {"covered_lines": 3, "num_statements": 3, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_base_config_structure": {"executed_lines": [52, 53, 55, 56, 57, 58, 59], "summary": {"covered_lines": 7, "num_statements": 7, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_default_content_paths_react": {"executed_lines": [63, 64, 66, 67], "summary": {"covered_lines": 4, "num_statements": 4, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_default_content_paths_nextjs": {"executed_lines": [71, 72, 74, 75, 76], "summary": {"covered_lines": 5, "num_statements": 5, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_default_content_paths_vue": {"executed_lines": [80, 81, 83], "summary": {"covered_lines": 3, "num_statements": 3, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_add_colors": {"executed_lines": [87, 88, 92, 94, 95, 96], "summary": {"covered_lines": 6, "num_statements": 6, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_add_colors_multiple_times": {"executed_lines": [100, 102, 103, 105, 106, 107], "summary": {"covered_lines": 6, "num_statements": 6, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_add_color_palette": {"executed_lines": [111, 112, 114, 116, 117, 118, 119, 120], "summary": {"covered_lines": 8, "num_statements": 8, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], 
"excluded_lines": []}, "TestTailwindConfigGenerator.test_add_fonts": {"executed_lines": [124, 125, 129, 131, 132, 133], "summary": {"covered_lines": 6, "num_statements": 6, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_add_spacing": {"executed_lines": [137, 138, 142, 144, 145, 146], "summary": {"covered_lines": 6, "num_statements": 6, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_add_breakpoints": {"executed_lines": [150, 151, 155, 157, 158, 159], "summary": {"covered_lines": 6, "num_statements": 6, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_add_plugins": {"executed_lines": [163, 164, 165, 167, 168], "summary": {"covered_lines": 5, "num_statements": 5, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_add_plugins_no_duplicates": {"executed_lines": [172, 173, 174, 176, 177], "summary": {"covered_lines": 5, "num_statements": 5, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_recommend_plugins": {"executed_lines": [181, 182, 184, 185], "summary": {"covered_lines": 4, "num_statements": 4, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_recommend_plugins_nextjs": {"executed_lines": [189, 190, 192], "summary": {"covered_lines": 3, "num_statements": 3, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_generate_typescript_config": {"executed_lines": [196, 197, 199, 200, 201], "summary": {"covered_lines": 5, "num_statements": 5, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_generate_javascript_config": {"executed_lines": [205, 206, 208, 209], "summary": {"covered_lines": 4, "num_statements": 4, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_generate_config_with_colors": {"executed_lines": [213, 214, 215, 217, 218], "summary": {"covered_lines": 5, "num_statements": 5, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_generate_config_with_plugins": {"executed_lines": [222, 223, 224, 226, 227], "summary": {"covered_lines": 5, "num_statements": 5, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_validate_config_valid": {"executed_lines": [231, 232, 234], "summary": {"covered_lines": 3, "num_statements": 3, "percent_covered": 100.0, "percent_covered_display": "100", 
"missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_validate_config_no_content": {"executed_lines": [238, 239, 241, 243, 244], "summary": {"covered_lines": 5, "num_statements": 5, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_validate_config_empty_theme": {"executed_lines": [248, 251, 253, 254], "summary": {"covered_lines": 4, "num_statements": 4, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_write_config": {"executed_lines": [258, 259, 261, 263, 264, 265], "summary": {"covered_lines": 6, "num_statements": 6, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_write_config_creates_content": {"executed_lines": [269, 270, 271, 273, 275, 276, 277], "summary": {"covered_lines": 7, "num_statements": 7, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_write_config_invalid_path": {"executed_lines": [281, 283, 285, 286], "summary": {"covered_lines": 4, "num_statements": 4, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_full_configuration_typescript": {"executed_lines": [290, 291, 298, 299, 300, 301, 302, 304, 305, 307, 310, 311, 312, 313, 314, 315], "summary": {"covered_lines": 16, "num_statements": 16, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "TestTailwindConfigGenerator.test_full_configuration_javascript": {"executed_lines": [319, 320, 326, 327, 329, 330, 332, 334, 335, 336], "summary": {"covered_lines": 10, "num_statements": 10, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "": {"executed_lines": [1, 3, 5, 8, 9, 11, 14, 15, 17, 23, 28, 34, 39, 44, 50, 61, 69, 78, 85, 98, 109, 122, 135, 148, 161, 170, 179, 187, 194, 203, 211, 220, 229, 236, 246, 256, 267, 279, 288, 317], "summary": {"covered_lines": 38, "num_statements": 38, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}}, "classes": {"TestTailwindConfigGenerator": {"executed_lines": [19, 20, 21, 25, 26, 30, 31, 32, 36, 37, 41, 42, 46, 47, 48, 52, 53, 55, 56, 57, 58, 59, 63, 64, 66, 67, 71, 72, 74, 75, 76, 80, 81, 83, 87, 88, 92, 94, 95, 96, 100, 102, 103, 105, 106, 107, 111, 112, 114, 116, 117, 118, 119, 120, 124, 125, 129, 131, 132, 133, 137, 138, 142, 144, 145, 146, 150, 151, 155, 157, 158, 159, 163, 164, 165, 167, 168, 172, 173, 174, 176, 177, 181, 182, 184, 185, 189, 190, 192, 196, 197, 199, 200, 201, 205, 206, 208, 209, 213, 214, 215, 217, 218, 222, 223, 224, 226, 227, 231, 232, 234, 238, 239, 241, 243, 244, 248, 251, 253, 254, 258, 259, 261, 263, 264, 265, 269, 270, 271, 273, 275, 276, 277, 281, 283, 285, 286, 290, 291, 298, 299, 300, 301, 302, 304, 305, 307, 310, 311, 312, 313, 314, 315, 319, 320, 326, 327, 329, 330, 332, 334, 
335, 336], "summary": {"covered_lines": 163, "num_statements": 163, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}, "": {"executed_lines": [1, 3, 5, 8, 9, 11, 14, 15, 17, 23, 28, 34, 39, 44, 50, 61, 69, 78, 85, 98, 109, 122, 135, 148, 161, 170, 179, 187, 194, 203, 211, 220, 229, 236, 246, 256, 267, 279, 288, 317], "summary": {"covered_lines": 38, "num_statements": 38, "percent_covered": 100.0, "percent_covered_display": "100", "missing_lines": 0, "excluded_lines": 0}, "missing_lines": [], "excluded_lines": []}}}}, "totals": {"covered_lines": 514, "num_statements": 621, "percent_covered": 82.76972624798712, "percent_covered_display": "83", "missing_lines": 107, "excluded_lines": 0}} \ No newline at end of file diff --git a/skills/website-creator/ui-styling/scripts/tests/requirements.txt b/skills/website-creator/ui-styling/scripts/tests/requirements.txt new file mode 100644 index 0000000..3a0f66d --- /dev/null +++ b/skills/website-creator/ui-styling/scripts/tests/requirements.txt @@ -0,0 +1,3 @@ +pytest>=7.4.0 +pytest-cov>=4.1.0 +pytest-mock>=3.11.1 diff --git a/skills/website-creator/ui-styling/scripts/tests/test_shadcn_add.py b/skills/website-creator/ui-styling/scripts/tests/test_shadcn_add.py new file mode 100644 index 0000000..03c8f31 --- /dev/null +++ b/skills/website-creator/ui-styling/scripts/tests/test_shadcn_add.py @@ -0,0 +1,266 @@ +"""Tests for shadcn_add.py""" + +import json +import subprocess +from pathlib import Path +from unittest.mock import MagicMock, mock_open, patch + +import pytest + +# Add parent directory to path for imports +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from shadcn_add import ShadcnInstaller + + +class TestShadcnInstaller: + """Test ShadcnInstaller class.""" + + @pytest.fixture + def temp_project(self, tmp_path): + """Create temporary project structure.""" + project_root = tmp_path / "test-project" + project_root.mkdir() + + # Create components.json + components_json = project_root / "components.json" + components_json.write_text( + json.dumps({ + "style": "new-york", + "aliases": { + "components": "@/components", + "utils": "@/lib/utils" + } + }) + ) + + # Create components directory + ui_dir = project_root / "components" / "ui" + ui_dir.mkdir(parents=True) + + return project_root + + def test_init_default_project_root(self): + """Test initialization with default project root.""" + installer = ShadcnInstaller() + assert installer.project_root == Path.cwd() + assert installer.dry_run is False + + def test_init_custom_project_root(self, tmp_path): + """Test initialization with custom project root.""" + installer = ShadcnInstaller(project_root=tmp_path) + assert installer.project_root == tmp_path + + def test_init_dry_run(self): + """Test initialization with dry run mode.""" + installer = ShadcnInstaller(dry_run=True) + assert installer.dry_run is True + + def test_check_shadcn_config_exists(self, temp_project): + """Test checking for existing shadcn config.""" + installer = ShadcnInstaller(project_root=temp_project) + assert installer.check_shadcn_config() is True + + def test_check_shadcn_config_not_exists(self, tmp_path): + """Test checking for non-existent shadcn config.""" + installer = ShadcnInstaller(project_root=tmp_path) + assert installer.check_shadcn_config() is False + + def test_get_installed_components_empty(self, temp_project): + """Test getting installed components when none exist.""" + installer = 
ShadcnInstaller(project_root=temp_project) + installed = installer.get_installed_components() + assert installed == [] + + def test_get_installed_components_with_files(self, temp_project): + """Test getting installed components when files exist.""" + ui_dir = temp_project / "components" / "ui" + + # Create component files + (ui_dir / "button.tsx").write_text("export const Button = () => {}") + (ui_dir / "card.tsx").write_text("export const Card = () => {}") + + installer = ShadcnInstaller(project_root=temp_project) + installed = installer.get_installed_components() + + assert sorted(installed) == ["button", "card"] + + def test_get_installed_components_no_config(self, tmp_path): + """Test getting installed components without config.""" + installer = ShadcnInstaller(project_root=tmp_path) + installed = installer.get_installed_components() + assert installed == [] + + def test_add_components_no_components(self, temp_project): + """Test adding components with empty list.""" + installer = ShadcnInstaller(project_root=temp_project) + success, message = installer.add_components([]) + + assert success is False + assert "No components specified" in message + + def test_add_components_no_config(self, tmp_path): + """Test adding components without shadcn config.""" + installer = ShadcnInstaller(project_root=tmp_path) + success, message = installer.add_components(["button"]) + + assert success is False + assert "not initialized" in message + + def test_add_components_already_installed(self, temp_project): + """Test adding components that are already installed.""" + ui_dir = temp_project / "components" / "ui" + (ui_dir / "button.tsx").write_text("export const Button = () => {}") + + installer = ShadcnInstaller(project_root=temp_project) + success, message = installer.add_components(["button"]) + + assert success is False + assert "already installed" in message + assert "button" in message + + def test_add_components_with_overwrite(self, temp_project): + """Test adding components with overwrite flag.""" + ui_dir = temp_project / "components" / "ui" + (ui_dir / "button.tsx").write_text("export const Button = () => {}") + + installer = ShadcnInstaller(project_root=temp_project) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock( + stdout="Component added successfully", + returncode=0 + ) + + success, message = installer.add_components(["button"], overwrite=True) + + assert success is True + assert "Successfully added" in message + mock_run.assert_called_once() + + # Verify --overwrite flag was passed + call_args = mock_run.call_args[0][0] + assert "--overwrite" in call_args + + def test_add_components_dry_run(self, temp_project): + """Test adding components in dry run mode.""" + installer = ShadcnInstaller(project_root=temp_project, dry_run=True) + success, message = installer.add_components(["button", "card"]) + + assert success is True + assert "Would run:" in message + assert "button" in message + assert "card" in message + + @patch("subprocess.run") + def test_add_components_success(self, mock_run, temp_project): + """Test successful component addition.""" + mock_run.return_value = MagicMock( + stdout="Components added successfully", + stderr="", + returncode=0 + ) + + installer = ShadcnInstaller(project_root=temp_project) + success, message = installer.add_components(["button", "card"]) + + assert success is True + assert "Successfully added" in message + assert "button" in message + assert "card" in message + + # Verify correct command was called + 
mock_run.assert_called_once() + call_args = mock_run.call_args[0][0] + assert call_args[:3] == ["npx", "shadcn@latest", "add"] + assert "button" in call_args + assert "card" in call_args + + @patch("subprocess.run") + def test_add_components_subprocess_error(self, mock_run, temp_project): + """Test component addition with subprocess error.""" + mock_run.side_effect = subprocess.CalledProcessError( + 1, "cmd", stderr="Error occurred" + ) + + installer = ShadcnInstaller(project_root=temp_project) + success, message = installer.add_components(["button"]) + + assert success is False + assert "Failed to add" in message + + @patch("subprocess.run") + def test_add_components_npx_not_found(self, mock_run, temp_project): + """Test component addition when npx is not found.""" + mock_run.side_effect = FileNotFoundError() + + installer = ShadcnInstaller(project_root=temp_project) + success, message = installer.add_components(["button"]) + + assert success is False + assert "npx not found" in message + + def test_add_all_components_no_config(self, tmp_path): + """Test adding all components without config.""" + installer = ShadcnInstaller(project_root=tmp_path) + success, message = installer.add_all_components() + + assert success is False + assert "not initialized" in message + + def test_add_all_components_dry_run(self, temp_project): + """Test adding all components in dry run mode.""" + installer = ShadcnInstaller(project_root=temp_project, dry_run=True) + success, message = installer.add_all_components() + + assert success is True + assert "Would run:" in message + assert "--all" in message + + @patch("subprocess.run") + def test_add_all_components_success(self, mock_run, temp_project): + """Test successful addition of all components.""" + mock_run.return_value = MagicMock( + stdout="All components added", + returncode=0 + ) + + installer = ShadcnInstaller(project_root=temp_project) + success, message = installer.add_all_components() + + assert success is True + assert "Successfully added all" in message + + # Verify --all flag was passed + call_args = mock_run.call_args[0][0] + assert "--all" in call_args + + def test_list_installed_no_config(self, tmp_path): + """Test listing installed components without config.""" + installer = ShadcnInstaller(project_root=tmp_path) + success, message = installer.list_installed() + + assert success is False + assert "not initialized" in message + + def test_list_installed_empty(self, temp_project): + """Test listing installed components when none exist.""" + installer = ShadcnInstaller(project_root=temp_project) + success, message = installer.list_installed() + + assert success is True + assert "No components installed" in message + + def test_list_installed_with_components(self, temp_project): + """Test listing installed components when they exist.""" + ui_dir = temp_project / "components" / "ui" + (ui_dir / "button.tsx").write_text("export const Button = () => {}") + (ui_dir / "card.tsx").write_text("export const Card = () => {}") + + installer = ShadcnInstaller(project_root=temp_project) + success, message = installer.list_installed() + + assert success is True + assert "button" in message + assert "card" in message diff --git a/skills/website-creator/ui-styling/scripts/tests/test_tailwind_config_gen.py b/skills/website-creator/ui-styling/scripts/tests/test_tailwind_config_gen.py new file mode 100644 index 0000000..a08414e --- /dev/null +++ b/skills/website-creator/ui-styling/scripts/tests/test_tailwind_config_gen.py @@ -0,0 +1,336 @@ +"""Tests for 
tailwind_config_gen.py""" + +from pathlib import Path + +import pytest + +# Add parent directory to path for imports +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from tailwind_config_gen import TailwindConfigGenerator + + +class TestTailwindConfigGenerator: + """Test TailwindConfigGenerator class.""" + + def test_init_default_typescript(self): + """Test initialization with default settings.""" + generator = TailwindConfigGenerator() + assert generator.typescript is True + assert generator.framework == "react" + + def test_init_javascript(self): + """Test initialization for JavaScript config.""" + generator = TailwindConfigGenerator(typescript=False) + assert generator.typescript is False + + def test_init_framework(self): + """Test initialization with different frameworks.""" + for framework in ["react", "vue", "svelte", "nextjs"]: + generator = TailwindConfigGenerator(framework=framework) + assert generator.framework == framework + + def test_default_output_path_typescript(self): + """Test default output path for TypeScript.""" + generator = TailwindConfigGenerator(typescript=True) + assert generator.output_path.name == "tailwind.config.ts" + + def test_default_output_path_javascript(self): + """Test default output path for JavaScript.""" + generator = TailwindConfigGenerator(typescript=False) + assert generator.output_path.name == "tailwind.config.js" + + def test_custom_output_path(self, tmp_path): + """Test custom output path.""" + custom_path = tmp_path / "custom-config.ts" + generator = TailwindConfigGenerator(output_path=custom_path) + assert generator.output_path == custom_path + + def test_base_config_structure(self): + """Test base configuration structure.""" + generator = TailwindConfigGenerator() + config = generator.config + + assert "darkMode" in config + assert "content" in config + assert "theme" in config + assert "plugins" in config + assert "extend" in config["theme"] + + def test_default_content_paths_react(self): + """Test default content paths for React.""" + generator = TailwindConfigGenerator(framework="react") + paths = generator.config["content"] + + assert any("src/**/*.{js,jsx,ts,tsx}" in p for p in paths) + assert any("index.html" in p for p in paths) + + def test_default_content_paths_nextjs(self): + """Test default content paths for Next.js.""" + generator = TailwindConfigGenerator(framework="nextjs") + paths = generator.config["content"] + + assert any("app/**" in p for p in paths) + assert any("pages/**" in p for p in paths) + assert any("components/**" in p for p in paths) + + def test_default_content_paths_vue(self): + """Test default content paths for Vue.""" + generator = TailwindConfigGenerator(framework="vue") + paths = generator.config["content"] + + assert any("vue" in p for p in paths) + + def test_add_colors(self): + """Test adding custom colors.""" + generator = TailwindConfigGenerator() + colors = { + "brand": "#3b82f6", + "accent": "#8b5cf6" + } + generator.add_colors(colors) + + assert "colors" in generator.config["theme"]["extend"] + assert generator.config["theme"]["extend"]["colors"]["brand"] == "#3b82f6" + assert generator.config["theme"]["extend"]["colors"]["accent"] == "#8b5cf6" + + def test_add_colors_multiple_times(self): + """Test adding colors multiple times.""" + generator = TailwindConfigGenerator() + + generator.add_colors({"brand": "#3b82f6"}) + generator.add_colors({"accent": "#8b5cf6"}) + + colors = generator.config["theme"]["extend"]["colors"] + assert "brand" in colors + assert "accent" in colors + + 
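# Illustrative sketch of an extra structural check; the test name is hypothetical and it assumes only that generate_config_string() returns the complete file text. + def test_generated_config_braces_balanced(self): + """Sketch: every brace opened in the emitted config should be closed.""" + generator = TailwindConfigGenerator() + config = generator.generate_config_string() + # Balanced braces are a necessary condition for the emitted object literal to parse + assert config.count("{") == config.count("}") + + 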
def test_add_color_palette(self): + """Test adding full color palette.""" + generator = TailwindConfigGenerator() + generator.add_color_palette("brand", "#3b82f6") + + brand = generator.config["theme"]["extend"]["colors"]["brand"] + + assert isinstance(brand, dict) + assert "50" in brand + assert "500" in brand + assert "950" in brand + assert "var(--color-brand" in brand["500"] + + def test_add_fonts(self): + """Test adding custom fonts.""" + generator = TailwindConfigGenerator() + fonts = { + "sans": ["Inter", "system-ui", "sans-serif"], + "display": ["Playfair Display", "serif"] + } + generator.add_fonts(fonts) + + font_family = generator.config["theme"]["extend"]["fontFamily"] + assert font_family["sans"] == ["Inter", "system-ui", "sans-serif"] + assert font_family["display"] == ["Playfair Display", "serif"] + + def test_add_spacing(self): + """Test adding custom spacing.""" + generator = TailwindConfigGenerator() + spacing = { + "18": "4.5rem", + "navbar": "4rem" + } + generator.add_spacing(spacing) + + spacing_config = generator.config["theme"]["extend"]["spacing"] + assert spacing_config["18"] == "4.5rem" + assert spacing_config["navbar"] == "4rem" + + def test_add_breakpoints(self): + """Test adding custom breakpoints.""" + generator = TailwindConfigGenerator() + breakpoints = { + "3xl": "1920px", + "tablet": "768px" + } + generator.add_breakpoints(breakpoints) + + screens = generator.config["theme"]["extend"]["screens"] + assert screens["3xl"] == "1920px" + assert screens["tablet"] == "768px" + + def test_add_plugins(self): + """Test adding plugins.""" + generator = TailwindConfigGenerator() + plugins = ["@tailwindcss/typography", "@tailwindcss/forms"] + generator.add_plugins(plugins) + + assert "@tailwindcss/typography" in generator.config["plugins"] + assert "@tailwindcss/forms" in generator.config["plugins"] + + def test_add_plugins_no_duplicates(self): + """Test that adding same plugin twice doesn't duplicate.""" + generator = TailwindConfigGenerator() + generator.add_plugins(["@tailwindcss/typography"]) + generator.add_plugins(["@tailwindcss/typography"]) + + count = generator.config["plugins"].count("@tailwindcss/typography") + assert count == 1 + + def test_recommend_plugins(self): + """Test plugin recommendations.""" + generator = TailwindConfigGenerator() + recommendations = generator.recommend_plugins() + + assert isinstance(recommendations, list) + assert "tailwindcss-animate" in recommendations + + def test_recommend_plugins_nextjs(self): + """Test plugin recommendations for Next.js.""" + generator = TailwindConfigGenerator(framework="nextjs") + recommendations = generator.recommend_plugins() + + assert "@tailwindcss/typography" in recommendations + + def test_generate_typescript_config(self): + """Test generating TypeScript configuration.""" + generator = TailwindConfigGenerator(typescript=True) + config = generator.generate_config_string() + + assert "import type { Config } from 'tailwindcss'" in config + assert "const config: Config" in config + assert "export default config" in config + + def test_generate_javascript_config(self): + """Test generating JavaScript configuration.""" + generator = TailwindConfigGenerator(typescript=False) + config = generator.generate_config_string() + + assert "module.exports" in config + assert "@type" in config + + def test_generate_config_with_colors(self): + """Test generating config with custom colors.""" + generator = TailwindConfigGenerator() + generator.add_colors({"brand": "#3b82f6"}) + config = 
generator.generate_config_string() + + assert "colors" in config + assert "brand" in config + + def test_generate_config_with_plugins(self): + """Test generating config with plugins.""" + generator = TailwindConfigGenerator() + generator.add_plugins(["tailwindcss-animate"]) + config = generator.generate_config_string() + + assert "plugins:" in config + assert "require('tailwindcss-animate')" in config + + def test_validate_config_valid(self): + """Test validating valid configuration.""" + generator = TailwindConfigGenerator() + valid, message = generator.validate_config() + + assert valid is True + + def test_validate_config_no_content(self): + """Test validating config with no content paths.""" + generator = TailwindConfigGenerator() + generator.config["content"] = [] + + valid, message = generator.validate_config() + + assert valid is False + assert "No content paths" in message + + def test_validate_config_empty_theme(self): + """Test validating config with empty theme extensions.""" + generator = TailwindConfigGenerator() + # Default has empty theme.extend + + valid, message = generator.validate_config() + + assert valid is True + assert "Warning" in message + + def test_write_config(self, tmp_path): + """Test writing configuration to file.""" + output_path = tmp_path / "tailwind.config.ts" + generator = TailwindConfigGenerator(output_path=output_path) + + success, message = generator.write_config() + + assert success is True + assert output_path.exists() + assert "written to" in message + + def test_write_config_creates_content(self, tmp_path): + """Test that written config contains expected content.""" + output_path = tmp_path / "tailwind.config.ts" + generator = TailwindConfigGenerator(output_path=output_path) + generator.add_colors({"brand": "#3b82f6"}) + + generator.write_config() + + content = output_path.read_text() + assert "import type { Config }" in content + assert "brand" in content + + def test_write_config_invalid_path(self): + """Test writing config to invalid path.""" + generator = TailwindConfigGenerator(output_path=Path("/invalid/path/config.ts")) + + success, message = generator.write_config() + + assert success is False + assert "Failed to write" in message + + def test_full_configuration_typescript(self, tmp_path): + """Test generating complete TypeScript configuration.""" + output_path = tmp_path / "tailwind.config.ts" + generator = TailwindConfigGenerator( + typescript=True, + framework="nextjs", + output_path=output_path + ) + + # Add various customizations + generator.add_colors({"brand": "#3b82f6", "accent": "#8b5cf6"}) + generator.add_fonts({"sans": ["Inter", "sans-serif"]}) + generator.add_spacing({"navbar": "4rem"}) + generator.add_breakpoints({"3xl": "1920px"}) + generator.add_plugins(["tailwindcss-animate"]) + + success, _ = generator.write_config() + assert success is True + + content = output_path.read_text() + + # Verify all customizations are present + assert "brand" in content + assert "accent" in content + assert "Inter" in content + assert "navbar" in content + assert "3xl" in content + assert "tailwindcss-animate" in content + + def test_full_configuration_javascript(self, tmp_path): + """Test generating complete JavaScript configuration.""" + output_path = tmp_path / "tailwind.config.js" + generator = TailwindConfigGenerator( + typescript=False, + framework="react", + output_path=output_path + ) + + generator.add_colors({"primary": "#3b82f6"}) + generator.add_plugins(["@tailwindcss/forms"]) + + success, _ = generator.write_config() + assert 
success is True + + content = output_path.read_text() + + assert "module.exports" in content + assert "primary" in content + assert "@tailwindcss/forms" in content diff --git a/skills/website-creator/ui-ux-pro-max/SKILL.md b/skills/website-creator/ui-ux-pro-max/SKILL.md new file mode 100644 index 0000000..6b2e0ea --- /dev/null +++ b/skills/website-creator/ui-ux-pro-max/SKILL.md @@ -0,0 +1,659 @@ +--- +name: ui-ux-pro-max +description: "UI/UX design intelligence for web and mobile. Includes 50+ styles, 161 color palettes, 57 font pairings, 161 product types, 99 UX guidelines, and 25 chart types across 10 stacks (React, Next.js, Vue, Svelte, SwiftUI, React Native, Flutter, Tailwind, shadcn/ui, and HTML/CSS). Actions: plan, build, create, design, implement, review, fix, improve, optimize, enhance, refactor, and check UI/UX code. Projects: website, landing page, dashboard, admin panel, e-commerce, SaaS, portfolio, blog, and mobile app. Elements: button, modal, navbar, sidebar, card, table, form, and chart. Styles: glassmorphism, claymorphism, minimalism, brutalism, neumorphism, bento grid, dark mode, responsive, skeuomorphism, and flat design. Topics: color systems, accessibility, animation, layout, typography, font pairing, spacing, interaction states, shadow, and gradient. Integrations: shadcn/ui MCP for component search and examples." +--- + +# UI/UX Pro Max - Design Intelligence + +Comprehensive design guide for web and mobile applications. Contains 50+ styles, 161 color palettes, 57 font pairings, 161 product types with reasoning rules, 99 UX guidelines, and 25 chart types across 10 technology stacks. Searchable database with priority-based recommendations. + +## When to Apply + +This Skill should be used when the task involves **UI structure, visual design decisions, interaction patterns, or user experience quality control**. + +### Must Use + +This Skill must be invoked in the following situations: + +- Designing new pages (Landing Page, Dashboard, Admin, SaaS, Mobile App) +- Creating or refactoring UI components (buttons, modals, forms, tables, charts, etc.) +- Choosing color schemes, typography systems, spacing standards, or layout systems +- Reviewing UI code for user experience, accessibility, or visual consistency +- Implementing navigation structures, animations, or responsive behavior +- Making product-level design decisions (style, information hierarchy, brand expression) +- Improving perceived quality, clarity, or usability of interfaces + +### Recommended + +This Skill is recommended in the following situations: + +- UI looks "not professional enough" but the reason is unclear +- Receiving feedback on usability or experience +- Pre-launch UI quality optimization +- Aligning cross-platform design (Web / iOS / Android) +- Building design systems or reusable component libraries + +### Skip + +This Skill is not needed in the following situations: + +- Pure backend logic development +- Only involving API or database design +- Performance optimization unrelated to the interface +- Infrastructure or DevOps work +- Non-visual scripts or automation tasks + +**Decision criteria**: If the task will change how a feature **looks, feels, moves, or is interacted with**, this Skill should be used. + +## Rule Categories by Priority + +*For human/AI reference: follow priority 1→10 to decide which rule category to focus on first; use `--domain ` to query details when needed. 
Scripts do not read this table.* + +| Priority | Category | Impact | Domain | Key Checks (Must Have) | Anti-Patterns (Avoid) | +|----------|----------|--------|--------|------------------------|------------------------| +| 1 | Accessibility | CRITICAL | `ux` | Contrast 4.5:1, Alt text, Keyboard nav, Aria-labels | Removing focus rings, Icon-only buttons without labels | +| 2 | Touch & Interaction | CRITICAL | `ux` | Min size 44×44px, 8px+ spacing, Loading feedback | Reliance on hover only, Instant state changes (0ms) | +| 3 | Performance | HIGH | `ux` | WebP/AVIF, Lazy loading, Reserve space (CLS < 0.1) | Layout thrashing, Cumulative Layout Shift | +| 4 | Style Selection | HIGH | `style`, `product` | Match product type, Consistency, SVG icons (no emoji) | Mixing flat & skeuomorphic randomly, Emoji as icons | +| 5 | Layout & Responsive | HIGH | `ux` | Mobile-first breakpoints, Viewport meta, No horizontal scroll | Horizontal scroll, Fixed px container widths, Disable zoom | +| 6 | Typography & Color | MEDIUM | `typography`, `color` | Base 16px, Line-height 1.5, Semantic color tokens | Text < 12px body, Gray-on-gray, Raw hex in components | +| 7 | Animation | MEDIUM | `ux` | Duration 150–300ms, Motion conveys meaning, Spatial continuity | Decorative-only animation, Animating width/height, No reduced-motion | +| 8 | Forms & Feedback | MEDIUM | `ux` | Visible labels, Error near field, Helper text, Progressive disclosure | Placeholder-only label, Errors only at top, Overwhelm upfront | +| 9 | Navigation Patterns | HIGH | `ux` | Predictable back, Bottom nav ≤5, Deep linking | Overloaded nav, Broken back behavior, No deep links | +| 10 | Charts & Data | LOW | `chart` | Legends, Tooltips, Accessible colors | Relying on color alone to convey meaning | + +## Quick Reference + +### 1. Accessibility (CRITICAL) + +- `color-contrast` - Minimum 4.5:1 ratio for normal text (large text 3:1); Material Design +- `focus-states` - Visible focus rings on interactive elements (2–4px; Apple HIG, MD) +- `alt-text` - Descriptive alt text for meaningful images +- `aria-labels` - aria-label for icon-only buttons; accessibilityLabel in native (Apple HIG) +- `keyboard-nav` - Tab order matches visual order; full keyboard support (Apple HIG) +- `form-labels` - Use label with for attribute +- `skip-links` - Skip to main content for keyboard users +- `heading-hierarchy` - Sequential h1→h6, no level skip +- `color-not-only` - Don't convey info by color alone (add icon/text) +- `dynamic-type` - Support system text scaling; avoid truncation as text grows (Apple Dynamic Type, MD) +- `reduced-motion` - Respect prefers-reduced-motion; reduce/disable animations when requested (Apple Reduced Motion API, MD) +- `voiceover-sr` - Meaningful accessibilityLabel/accessibilityHint; logical reading order for VoiceOver/screen readers (Apple HIG, MD) +- `escape-routes` - Provide cancel/back in modals and multi-step flows (Apple HIG) +- `keyboard-shortcuts` - Preserve system and a11y shortcuts; offer keyboard alternatives for drag-and-drop (Apple HIG) + +### 2. 
Touch & Interaction (CRITICAL) + +- `touch-target-size` - Min 44×44pt (Apple) / 48×48dp (Material); extend hit area beyond visual bounds if needed +- `touch-spacing` - Minimum 8px/8dp gap between touch targets (Apple HIG, MD) +- `hover-vs-tap` - Use click/tap for primary interactions; don't rely on hover alone +- `loading-buttons` - Disable button during async operations; show spinner or progress +- `error-feedback` - Clear error messages near problem +- `cursor-pointer` - Add cursor-pointer to clickable elements (Web) +- `gesture-conflicts` - Avoid horizontal swipe on main content; prefer vertical scroll +- `tap-delay` - Use touch-action: manipulation to reduce 300ms delay (Web) +- `standard-gestures` - Use platform standard gestures consistently; don't redefine (e.g. swipe-back, pinch-zoom) (Apple HIG) +- `system-gestures` - Don't block system gestures (Control Center, back swipe, etc.) (Apple HIG) +- `press-feedback` - Visual feedback on press (ripple/highlight; MD state layers) +- `haptic-feedback` - Use haptic for confirmations and important actions; avoid overuse (Apple HIG) +- `gesture-alternative` - Don't rely on gesture-only interactions; always provide visible controls for critical actions +- `safe-area-awareness` - Keep primary touch targets away from notch, Dynamic Island, gesture bar and screen edges +- `no-precision-required` - Avoid requiring pixel-perfect taps on small icons or thin edges +- `swipe-clarity` - Swipe actions must show clear affordance or hint (chevron, label, tutorial) +- `drag-threshold` - Use a movement threshold before starting drag to avoid accidental drags + +### 3. Performance (HIGH) + +- `image-optimization` - Use WebP/AVIF, responsive images (srcset/sizes), lazy load non-critical assets +- `image-dimension` - Declare width/height or use aspect-ratio to prevent layout shift (Core Web Vitals: CLS) +- `font-loading` - Use font-display: swap/optional to avoid invisible text (FOIT); reserve space to reduce layout shift (MD) +- `font-preload` - Preload only critical fonts; avoid overusing preload on every variant +- `critical-css` - Prioritize above-the-fold CSS (inline critical CSS or early-loaded stylesheet) +- `lazy-loading` - Lazy load non-hero components via dynamic import / route-level splitting +- `bundle-splitting` - Split code by route/feature (React Suspense / Next.js dynamic) to reduce initial load and TTI +- `third-party-scripts` - Load third-party scripts async/defer; audit and remove unnecessary ones (MD) +- `reduce-reflows` - Avoid frequent layout reads/writes; batch DOM reads then writes +- `content-jumping` - Reserve space for async content to avoid layout jumps (Core Web Vitals: CLS) +- `lazy-load-below-fold` - Use loading="lazy" for below-the-fold images and heavy media +- `virtualize-lists` - Virtualize lists with 50+ items to improve memory efficiency and scroll performance +- `main-thread-budget` - Keep per-frame work under ~16ms for 60fps; move heavy tasks off main thread (HIG, MD) +- `progressive-loading` - Use skeleton screens / shimmer instead of long blocking spinners for >1s operations (Apple HIG) +- `input-latency` - Keep input latency under ~100ms for taps/scrolls (Material responsiveness standard) +- `tap-feedback-speed` - Provide visual feedback within 100ms of tap (Apple HIG) +- `debounce-throttle` - Use debounce/throttle for high-frequency events (scroll, resize, input) +- `offline-support` - Provide offline state messaging and basic fallback (PWA / mobile) +- `network-fallback` - Offer degraded modes for slow networks 
(lower-res images, fewer animations) + +### 4. Style Selection (HIGH) + +- `style-match` - Match style to product type (use `--design-system` for recommendations) +- `consistency` - Use same style across all pages +- `no-emoji-icons` - Use SVG icons (Heroicons, Lucide), not emojis +- `color-palette-from-product` - Choose palette from product/industry (search `--domain color`) +- `effects-match-style` - Shadows, blur, radius aligned with chosen style (glass / flat / clay etc.) +- `platform-adaptive` - Respect platform idioms (iOS HIG vs Material): navigation, controls, typography, motion +- `state-clarity` - Make hover/pressed/disabled states visually distinct while staying on-style (Material state layers) +- `elevation-consistent` - Use a consistent elevation/shadow scale for cards, sheets, modals; avoid random shadow values +- `dark-mode-pairing` - Design light/dark variants together to keep brand, contrast, and style consistent +- `icon-style-consistent` - Use one icon set/visual language (stroke width, corner radius) across the product +- `system-controls` - Prefer native/system controls over fully custom ones; only customize when branding requires it (Apple HIG) +- `blur-purpose` - Use blur to indicate background dismissal (modals, sheets), not as decoration (Apple HIG) +- `primary-action` - Each screen should have only one primary CTA; secondary actions visually subordinate (Apple HIG) + +### 5. Layout & Responsive (HIGH) + +- `viewport-meta` - width=device-width initial-scale=1 (never disable zoom) +- `mobile-first` - Design mobile-first, then scale up to tablet and desktop +- `breakpoint-consistency` - Use systematic breakpoints (e.g. 375 / 768 / 1024 / 1440) +- `readable-font-size` - Minimum 16px body text on mobile (avoids iOS auto-zoom) +- `line-length-control` - Mobile 35–60 chars per line; desktop 60–75 chars +- `horizontal-scroll` - No horizontal scroll on mobile; ensure content fits viewport width +- `spacing-scale` - Use 4pt/8dp incremental spacing system (Material Design) +- `touch-density` - Keep component spacing comfortable for touch: not cramped, not causing mis-taps +- `container-width` - Consistent max-width on desktop (max-w-6xl / 7xl) +- `z-index-management` - Define layered z-index scale (e.g. 0 / 10 / 20 / 40 / 100 / 1000) +- `fixed-element-offset` - Fixed navbar/bottom bar must reserve safe padding for underlying content +- `scroll-behavior` - Avoid nested scroll regions that interfere with the main scroll experience +- `viewport-units` - Prefer min-h-dvh over 100vh on mobile +- `orientation-support` - Keep layout readable and operable in landscape mode +- `content-priority` - Show core content first on mobile; fold or hide secondary content +- `visual-hierarchy` - Establish hierarchy via size, spacing, contrast — not color alone + +### 6. Typography & Color (MEDIUM) + +- `line-height` - Use 1.5-1.75 for body text +- `line-length` - Limit to 65-75 characters per line +- `font-pairing` - Match heading/body font personalities +- `font-scale` - Consistent type scale (e.g. 12 14 16 18 24 32) +- `contrast-readability` - Darker text on light backgrounds (e.g. 
slate-900 on white) +- `text-styles-system` - Use platform type system: iOS 11 Dynamic Type styles / Material 5 type roles (display, headline, title, body, label) (HIG, MD) +- `weight-hierarchy` - Use font-weight to reinforce hierarchy: Bold headings (600–700), Regular body (400), Medium labels (500) (MD) +- `color-semantic` - Define semantic color tokens (primary, secondary, error, surface, on-surface) not raw hex in components (Material color system) +- `color-dark-mode` - Dark mode uses desaturated / lighter tonal variants, not inverted colors; test contrast separately (HIG, MD) +- `color-accessible-pairs` - Foreground/background pairs must meet 4.5:1 (AA) or 7:1 (AAA); use tools to verify (WCAG, MD) +- `color-not-decorative-only` - Functional color (error red, success green) must include icon/text; avoid color-only meaning (HIG, MD) +- `truncation-strategy` - Prefer wrapping over truncation; when truncating use ellipsis and provide full text via tooltip/expand (Apple HIG) +- `letter-spacing` - Respect default letter-spacing per platform; avoid tight tracking on body text (HIG, MD) +- `number-tabular` - Use tabular/monospaced figures for data columns, prices, and timers to prevent layout shift +- `whitespace-balance` - Use whitespace intentionally to group related items and separate sections; avoid visual clutter (Apple HIG) + +### 7. Animation (MEDIUM) + +- `duration-timing` - Use 150–300ms for micro-interactions; complex transitions ≤400ms; avoid >500ms (MD) +- `transform-performance` - Use transform/opacity only; avoid animating width/height/top/left +- `loading-states` - Show skeleton or progress indicator when loading exceeds 300ms +- `excessive-motion` - Animate 1-2 key elements per view max +- `easing` - Use ease-out for entering, ease-in for exiting; avoid linear for UI transitions +- `motion-meaning` - Every animation must express a cause-effect relationship, not just be decorative (Apple HIG) +- `state-transition` - State changes (hover / active / expanded / collapsed / modal) should animate smoothly, not snap +- `continuity` - Page/screen transitions should maintain spatial continuity (shared element, directional slide) (Apple HIG) +- `parallax-subtle` - Use parallax sparingly; must respect reduced-motion and not cause disorientation (Apple HIG) +- `spring-physics` - Prefer spring/physics-based curves over linear or cubic-bezier for natural feel (Apple HIG fluid animations) +- `exit-faster-than-enter` - Exit animations shorter than enter (~60–70% of enter duration) to feel responsive (MD motion) +- `stagger-sequence` - Stagger list/grid item entrance by 30–50ms per item; avoid all-at-once or too-slow reveals (MD) +- `shared-element-transition` - Use shared element / hero transitions for visual continuity between screens (MD, HIG) +- `interruptible` - Animations must be interruptible; user tap/gesture cancels in-progress animation immediately (Apple HIG) +- `no-blocking-animation` - Never block user input during an animation; UI must stay interactive (Apple HIG) +- `fade-crossfade` - Use crossfade for content replacement within the same container (MD) +- `scale-feedback` - Subtle scale (0.95–1.05) on press for tappable cards/buttons; restore on release (HIG, MD) +- `gesture-feedback` - Drag, swipe, and pinch must provide real-time visual response tracking the finger (MD Motion) +- `hierarchy-motion` - Use translate/scale direction to express hierarchy: enter from below = deeper, exit upward = back (MD) +- `motion-consistency` - Unify duration/easing tokens globally; all 
animations share the same rhythm and feel +- `opacity-threshold` - Fading elements should not linger below opacity 0.2; either fade fully or remain visible +- `modal-motion` - Modals/sheets should animate from their trigger source (scale+fade or slide-in) for spatial context (HIG, MD) +- `navigation-direction` - Forward navigation animates left/up; backward animates right/down — keep direction logically consistent (HIG) +- `layout-shift-avoid` - Animations must not cause layout reflow or CLS; use transform for position changes + +### 8. Forms & Feedback (MEDIUM) + +- `input-labels` - Visible label per input (not placeholder-only) +- `error-placement` - Show error below the related field +- `submit-feedback` - Loading then success/error state on submit +- `required-indicators` - Mark required fields (e.g. asterisk) +- `empty-states` - Helpful message and action when no content +- `toast-dismiss` - Auto-dismiss toasts in 3-5s +- `confirmation-dialogs` - Confirm before destructive actions +- `input-helper-text` - Provide persistent helper text below complex inputs, not just placeholder (Material Design) +- `disabled-states` - Disabled elements use reduced opacity (0.38–0.5) + cursor change + semantic attribute (MD) +- `progressive-disclosure` - Reveal complex options progressively; don't overwhelm users upfront (Apple HIG) +- `inline-validation` - Validate on blur (not keystroke); show error only after user finishes input (MD) +- `input-type-keyboard` - Use semantic input types (email, tel, number) to trigger the correct mobile keyboard (HIG, MD) +- `password-toggle` - Provide show/hide toggle for password fields (MD) +- `autofill-support` - Use autocomplete / textContentType attributes so the system can autofill (HIG, MD) +- `undo-support` - Allow undo for destructive or bulk actions (e.g. "Undo delete" toast) (Apple HIG) +- `success-feedback` - Confirm completed actions with brief visual feedback (checkmark, toast, color flash) (MD) +- `error-recovery` - Error messages must include a clear recovery path (retry, edit, help link) (HIG, MD) +- `multi-step-progress` - Multi-step flows show step indicator or progress bar; allow back navigation (MD) +- `form-autosave` - Long forms should auto-save drafts to prevent data loss on accidental dismissal (Apple HIG) +- `sheet-dismiss-confirm` - Confirm before dismissing a sheet/modal with unsaved changes (Apple HIG) +- `error-clarity` - Error messages must state cause + how to fix (not just "Invalid input") (HIG, MD) +- `field-grouping` - Group related fields logically (fieldset/legend or visual grouping) (MD) +- `read-only-distinction` - Read-only state should be visually and semantically different from disabled (MD) +- `focus-management` - After submit error, auto-focus the first invalid field (WCAG, MD) +- `error-summary` - For multiple errors, show summary at top with anchor links to each field (WCAG) +- `touch-friendly-input` - Mobile input height ≥44px to meet touch target requirements (Apple HIG) +- `destructive-emphasis` - Destructive actions use semantic danger color (red) and are visually separated from primary actions (HIG, MD) +- `toast-accessibility` - Toasts must not steal focus; use aria-live="polite" for screen reader announcement (WCAG) +- `aria-live-errors` - Form errors use aria-live region or role="alert" to notify screen readers (WCAG) +- `contrast-feedback` - Error and success state colors must meet 4.5:1 contrast ratio (WCAG, MD) +- `timeout-feedback` - Request timeout must show clear feedback with retry option (MD) + +### 9. 
Navigation Patterns (HIGH) + +- `bottom-nav-limit` - Bottom navigation max 5 items; use labels with icons (Material Design) +- `drawer-usage` - Use drawer/sidebar for secondary navigation, not primary actions (Material Design) +- `back-behavior` - Back navigation must be predictable and consistent; preserve scroll/state (Apple HIG, MD) +- `deep-linking` - All key screens must be reachable via deep link / URL for sharing and notifications (Apple HIG, MD) +- `tab-bar-ios` - iOS: use bottom Tab Bar for top-level navigation (Apple HIG) +- `top-app-bar-android` - Android: use Top App Bar with navigation icon for primary structure (Material Design) +- `nav-label-icon` - Navigation items must have both icon and text label; icon-only nav harms discoverability (MD) +- `nav-state-active` - Current location must be visually highlighted (color, weight, indicator) in navigation (HIG, MD) +- `nav-hierarchy` - Primary nav (tabs/bottom bar) vs secondary nav (drawer/settings) must be clearly separated (MD) +- `modal-escape` - Modals and sheets must offer a clear close/dismiss affordance; swipe-down to dismiss on mobile (Apple HIG) +- `search-accessible` - Search must be easily reachable (top bar or tab); provide recent/suggested queries (MD) +- `breadcrumb-web` - Web: use breadcrumbs for 3+ level deep hierarchies to aid orientation (MD) +- `state-preservation` - Navigating back must restore previous scroll position, filter state, and input (HIG, MD) +- `gesture-nav-support` - Support system gesture navigation (iOS swipe-back, Android predictive back) without conflict (HIG, MD) +- `tab-badge` - Use badges on nav items sparingly to indicate unread/pending; clear after user visits (HIG, MD) +- `overflow-menu` - When actions exceed available space, use overflow/more menu instead of cramming (MD) +- `bottom-nav-top-level` - Bottom nav is for top-level screens only; never nest sub-navigation inside it (MD) +- `adaptive-navigation` - Large screens (≥1024px) prefer sidebar; small screens use bottom/top nav (Material Adaptive) +- `back-stack-integrity` - Never silently reset the navigation stack or unexpectedly jump to home (HIG, MD) +- `navigation-consistency` - Navigation placement must stay the same across all pages; don't change by page type +- `avoid-mixed-patterns` - Don't mix Tab + Sidebar + Bottom Nav at the same hierarchy level +- `modal-vs-navigation` - Modals must not be used for primary navigation flows; they break the user's path (HIG) +- `focus-on-route-change` - After page transition, move focus to main content region for screen reader users (WCAG) +- `persistent-nav` - Core navigation must remain reachable from deep pages; don't hide it entirely in sub-flows (HIG, MD) +- `destructive-nav-separation` - Dangerous actions (delete account, logout) must be visually and spatially separated from normal nav items (HIG, MD) +- `empty-nav-state` - When a nav destination is unavailable, explain why instead of silently hiding it (MD) + +### 10. 
Charts & Data (LOW) + +- `chart-type` - Match chart type to data type (trend → line, comparison → bar, proportion → pie/donut) +- `color-guidance` - Use accessible color palettes; avoid red/green only pairs for colorblind users (WCAG, MD) +- `data-table` - Provide table alternative for accessibility; charts alone are not screen-reader friendly (WCAG) +- `pattern-texture` - Supplement color with patterns, textures, or shapes so data is distinguishable without color (WCAG, MD) +- `legend-visible` - Always show legend; position near the chart, not detached below a scroll fold (MD) +- `tooltip-on-interact` - Provide tooltips/data labels on hover (Web) or tap (mobile) showing exact values (HIG, MD) +- `axis-labels` - Label axes with units and readable scale; avoid truncated or rotated labels on mobile +- `responsive-chart` - Charts must reflow or simplify on small screens (e.g. horizontal bar instead of vertical, fewer ticks) +- `empty-data-state` - Show meaningful empty state when no data exists ("No data yet" + guidance), not a blank chart (MD) +- `loading-chart` - Use skeleton or shimmer placeholder while chart data loads; don't show an empty axis frame +- `animation-optional` - Chart entrance animations must respect prefers-reduced-motion; data should be readable immediately (HIG) +- `large-dataset` - For 1000+ data points, aggregate or sample; provide drill-down for detail instead of rendering all (MD) +- `number-formatting` - Use locale-aware formatting for numbers, dates, currencies on axes and labels (HIG, MD) +- `touch-target-chart` - Interactive chart elements (points, segments) must have ≥44pt tap area or expand on touch (Apple HIG) +- `no-pie-overuse` - Avoid pie/donut for >5 categories; switch to bar chart for clarity +- `contrast-data` - Data lines/bars vs background ≥3:1; data text labels ≥4.5:1 (WCAG) +- `legend-interactive` - Legends should be clickable to toggle series visibility (MD) +- `direct-labeling` - For small datasets, label values directly on the chart to reduce eye travel +- `tooltip-keyboard` - Tooltip content must be keyboard-reachable and not rely on hover alone (WCAG) +- `sortable-table` - Data tables must support sorting with aria-sort indicating current sort state (WCAG) +- `axis-readability` - Axis ticks must not be cramped; maintain readable spacing, auto-skip on small screens +- `data-density` - Limit information density per chart to avoid cognitive overload; split into multiple charts if needed +- `trend-emphasis` - Emphasize data trends over decoration; avoid heavy gradients/shadows that obscure the data +- `gridline-subtle` - Grid lines should be low-contrast (e.g. gray-200) so they don't compete with data +- `focusable-elements` - Interactive chart elements (points, bars, slices) must be keyboard-navigable (WCAG) +- `screen-reader-summary` - Provide a text summary or aria-label describing the chart's key insight for screen readers (WCAG) +- `error-state-chart` - Data load failure must show error message with retry action, not a broken/empty chart +- `export-option` - For data-heavy products, offer CSV/image export of chart data +- `drill-down-consistency` - Drill-down interactions must maintain a clear back-path and hierarchy breadcrumb +- `time-scale-clarity` - Time series charts must clearly label time granularity (day/week/month) and allow switching + +## How to Use + +Search specific domains using the CLI tool below. 
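Before moving into the CLI workflow, here is a minimal React Native sketch of how a few of the CRITICAL rules above (`touch-target-size`, `press-feedback`, `aria-labels`, `reduced-motion`) typically land in component code. The component, prop names, and style values are illustrative assumptions, not output from this skill's data or scripts.

```tsx
import React, { useEffect, useState } from 'react';
import { AccessibilityInfo, Pressable, StyleSheet } from 'react-native';

type IconActionProps = {
  label: string;              // spoken label for an icon-only control
  onPress: () => void;
  children: React.ReactNode;  // a vector icon (e.g. react-native-vector-icons), never an emoji
};

// Hypothetical icon-only button that satisfies several CRITICAL rules at once.
export function IconAction({ label, onPress, children }: IconActionProps) {
  const [reduceMotion, setReduceMotion] = useState(false);

  useEffect(() => {
    // reduced-motion: read the system setting and skip decorative motion when enabled
    AccessibilityInfo.isReduceMotionEnabled().then(setReduceMotion);
  }, []);

  return (
    <Pressable
      onPress={onPress}
      accessibilityRole="button"
      accessibilityLabel={label}                          // aria-labels / voiceover-sr
      hitSlop={{ top: 8, bottom: 8, left: 8, right: 8 }}  // extend hit area beyond visual bounds
      style={({ pressed }) => [
        styles.target,
        pressed && styles.pressed,                          // press-feedback without layout shift
        pressed && !reduceMotion && styles.pressedScale,    // subtle scale-feedback, skipped under reduced motion
      ]}
    >
      {children}
    </Pressable>
  );
}

const styles = StyleSheet.create({
  target: {
    minWidth: 44,   // touch-target-size: at least 44x44pt
    minHeight: 44,
    alignItems: 'center',
    justifyContent: 'center',
    borderRadius: 8,
  },
  pressed: { opacity: 0.7 },
  pressedScale: { transform: [{ scale: 0.97 }] },
});
```

Reusing one such wrapper for every icon-only control keeps hit areas, labels, and press feedback consistent across the app instead of re-deciding them per screen.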
+ +--- + +## Prerequisites + +Check if Python is installed: + +```bash +python3 --version || python --version +``` + +If Python is not installed, install it based on the user's OS: + +**macOS:** +```bash +brew install python3 +``` + +**Ubuntu/Debian:** +```bash +sudo apt update && sudo apt install python3 +``` + +**Windows:** +```powershell +winget install Python.Python.3.12 +``` + +--- + +## How to Use This Skill + +Use this skill when the user requests any of the following: + +| Scenario | Trigger Examples | Start From | +|----------|-----------------|------------| +| **New project / page** | "Build a landing page", "Build a dashboard" | Step 1 → Step 2 (design system) | +| **New component** | "Create a pricing card", "Add a modal" | Step 3 (domain search: style, ux) | +| **Choose style / color / font** | "What style fits a fintech app?", "Recommend a color palette" | Step 2 (design system) | +| **Review existing UI** | "Review this page for UX issues", "Check accessibility" | Quick Reference checklist above | +| **Fix a UI bug** | "Button hover is broken", "Layout shifts on load" | Quick Reference → relevant section | +| **Improve / optimize** | "Make this faster", "Improve mobile experience" | Step 3 (domain search: ux, react) | +| **Implement dark mode** | "Add dark mode support" | Step 3 (domain: style "dark mode") | +| **Add charts / data viz** | "Add an analytics dashboard chart" | Step 3 (domain: chart) | +| **Stack best practices** | "React performance tips", "SwiftUI navigation" | Step 4 (stack search) | + +Follow this workflow: + +### Step 1: Analyze User Requirements + +Extract key information from the user request: +- **Product type**: Entertainment (social, video, music, gaming), Tool (scanner, editor, converter), Productivity (task manager, notes, calendar), or hybrid +- **Target audience**: C-end consumer users; consider age group, usage context (commute, leisure, work) +- **Style keywords**: playful, vibrant, minimal, dark mode, content-first, immersive, etc. +- **Stack**: React Native (this project's only tech stack) + +### Step 2: Generate Design System (REQUIRED) + +**Always start with `--design-system`** to get comprehensive recommendations with reasoning: + +```bash +python3 ~/.hermes/skills/website-creator/ui-ux-pro-max/scripts/search.py "<keywords>" --design-system [-p "Project Name"] +``` + +This command: +1. Searches domains in parallel (product, style, color, landing, typography) +2. Applies reasoning rules from `ui-reasoning.csv` to select best matches +3. Returns complete design system: pattern, style, colors, typography, effects +4. 
Includes anti-patterns to avoid + +**Example:** +```bash +python3 ~/.hermes/skills/website-creator/ui-ux-pro-max/scripts/search.py "beauty spa wellness service" --design-system -p "Serenity Spa" +``` + +### Step 2b: Persist Design System (Master + Overrides Pattern) + +To save the design system for **hierarchical retrieval across sessions**, add `--persist`: + +```bash +python3 ~/.hermes/skills/website-creator/ui-ux-pro-max/scripts/search.py "<keywords>" --design-system --persist -p "Project Name" +``` + +This creates: +- `design-system/MASTER.md` — Global Source of Truth with all design rules +- `design-system/pages/` — Folder for page-specific overrides + +**With page-specific override:** +```bash +python3 ~/.hermes/skills/website-creator/ui-ux-pro-max/scripts/search.py "<keywords>" --design-system --persist -p "Project Name" --page "dashboard" +``` + +This also creates: +- `design-system/pages/dashboard.md` — Page-specific deviations from Master + +**How hierarchical retrieval works:** +1. When building a specific page (e.g., "Checkout"), first check `design-system/pages/checkout.md` +2. If the page file exists, its rules **override** the Master file +3. If not, use `design-system/MASTER.md` exclusively + +**Context-aware retrieval prompt:** +``` +I am building the [Page Name] page. Please read design-system/MASTER.md. +Also check if design-system/pages/[page-name].md exists. +If the page file exists, prioritize its rules. +If not, use the Master rules exclusively. +Now, generate the code... +``` + +### Step 3: Supplement with Detailed Searches (as needed) + +After getting the design system, use domain searches to get additional details: + +```bash +python3 ~/.hermes/skills/website-creator/ui-ux-pro-max/scripts/search.py "<keywords>" --domain <domain> [-n <max_results>] +``` + +**When to use detailed searches:** + +| Need | Domain | Example | +|------|--------|---------| +| Product type patterns | `product` | `--domain product "entertainment social"` | +| More style options | `style` | `--domain style "glassmorphism dark"` | +| Color palettes | `color` | `--domain color "entertainment vibrant"` | +| Font pairings | `typography` | `--domain typography "playful modern"` | +| Chart recommendations | `chart` | `--domain chart "real-time dashboard"` | +| UX best practices | `ux` | `--domain ux "animation accessibility"` | +| Alternative fonts | `typography` | `--domain typography "elegant luxury"` | +| Individual Google Fonts | `google-fonts` | `--domain google-fonts "sans serif popular variable"` | +| Landing structure | `landing` | `--domain landing "hero social-proof"` | +| React Native perf | `react` | `--domain react "rerender memo list"` | +| App interface a11y | `web` | `--domain web "accessibilityLabel touch safe-areas"` | +| AI prompt / CSS keywords | `prompt` | `--domain prompt "minimalism"` | + +### Step 4: Stack Guidelines (React Native) + +Get React Native implementation-specific best practices: + +```bash +python3 ~/.hermes/skills/website-creator/ui-ux-pro-max/scripts/search.py "<keywords>" --stack react-native +``` + +--- + +## Search Reference + +### Available Domains + +| Domain | Use For | Example Keywords | +|--------|---------|------------------| +| `product` | Product type recommendations | SaaS, e-commerce, portfolio, healthcare, beauty, service | +| `style` | UI styles, colors, effects | glassmorphism, minimalism, dark mode, brutalism | +| `typography` | Font pairings, Google Fonts | elegant, playful, professional, modern | +| `color` | Color palettes by product type | saas, ecommerce, healthcare, beauty, fintech, service | +| 
`landing` | Page structure, CTA strategies | hero, hero-centric, testimonial, pricing, social-proof | +| `chart` | Chart types, library recommendations | trend, comparison, timeline, funnel, pie | +| `ux` | Best practices, anti-patterns | animation, accessibility, z-index, loading | +| `google-fonts` | Individual Google Fonts lookup | sans serif, monospace, japanese, variable font, popular | +| `react` | React/Next.js performance | waterfall, bundle, suspense, memo, rerender, cache | +| `web` | App interface guidelines (iOS/Android/React Native) | accessibilityLabel, touch targets, safe areas, Dynamic Type | +| `prompt` | AI prompts, CSS keywords | (style name) | + +### Available Stacks + +| Stack | Focus | +|-------|-------| +| `react-native` | Components, Navigation, Lists | + +--- + +## Example Workflow + +**User request:** "Make an AI search homepage." + +### Step 1: Analyze Requirements +- Product type: Tool (AI search engine) +- Target audience: C-end users looking for fast, intelligent search +- Style keywords: modern, minimal, content-first, dark mode +- Stack: React Native + +### Step 2: Generate Design System (REQUIRED) + +```bash +python3 ~/.hermes/skills/website-creator/ui-ux-pro-max/scripts/search.py "AI search tool modern minimal" --design-system -p "AI Search" +``` + +**Output:** Complete design system with pattern, style, colors, typography, effects, and anti-patterns. + +### Step 3: Supplement with Detailed Searches (as needed) + +```bash +# Get style options for a modern tool product +python3 ~/.hermes/skills/website-creator/ui-ux-pro-max/scripts/search.py "minimalism dark mode" --domain style + +# Get UX best practices for search interaction and loading +python3 ~/.hermes/skills/website-creator/ui-ux-pro-max/scripts/search.py "search loading animation" --domain ux +``` + +### Step 4: Stack Guidelines + +```bash +python3 ~/.hermes/skills/website-creator/ui-ux-pro-max/scripts/search.py "list performance navigation" --stack react-native +``` + +**Then:** Synthesize design system + detailed searches and implement the design. 
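To make that final step concrete, the sketch below shows one way the synthesis could look for the AI search homepage: the generated design system is distilled into semantic tokens, and the component consumes tokens rather than raw hex values. The token names and values here are placeholders, not actual `search.py` output.

```tsx
import React from 'react';
import { SafeAreaView, StyleSheet, Text, TextInput, View } from 'react-native';

// Hypothetical tokens distilled from the --design-system recommendation
// (color-semantic, spacing-scale, font-scale); values are illustrative only.
const tokens = {
  color: { surface: '#0B0F19', onSurface: '#E5E7EB', primary: '#3B82F6', muted: '#9CA3AF' },
  space: { sm: 8, md: 16, lg: 24 },   // 4/8pt spacing rhythm
  type: { title: 24, body: 16 },      // 16px minimum body size
};

export function SearchHome() {
  return (
    <SafeAreaView style={styles.screen}>
      <View style={styles.header}>
        <Text style={styles.title}>Search</Text>
        <TextInput
          style={styles.input}
          placeholder="Ask anything…"
          placeholderTextColor={tokens.color.muted}  // keep placeholder readable on the dark surface
          accessibilityLabel="Search query"
          returnKeyType="search"                     // input-type-keyboard: semantic keyboard action
        />
      </View>
    </SafeAreaView>
  );
}

const styles = StyleSheet.create({
  screen: { flex: 1, backgroundColor: tokens.color.surface },
  header: { padding: tokens.space.md },
  title: {
    fontSize: tokens.type.title,
    fontWeight: '700',                 // weight-hierarchy: bold heading over regular body
    color: tokens.color.onSurface,
    marginBottom: tokens.space.sm,
  },
  input: {
    minHeight: 44,                     // touch-friendly-input
    fontSize: tokens.type.body,
    color: tokens.color.onSurface,
    borderWidth: 1,
    borderColor: tokens.color.primary,
    borderRadius: 12,
    paddingHorizontal: tokens.space.md,
  },
});
```

Because every value routes through `tokens`, swapping in a different palette or type scale from a re-run of `--design-system` touches one object rather than every screen.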
+ +--- + +## Output Formats + +The `--design-system` flag supports two output formats: + +```bash +# ASCII box (default) - best for terminal display +python3 ~/.hermes/skills/website-creator/ui-ux-pro-max/scripts/search.py "fintech crypto" --design-system + +# Markdown - best for documentation +python3 ~/.hermes/skills/website-creator/ui-ux-pro-max/scripts/search.py "fintech crypto" --design-system -f markdown +``` + +--- + +## Tips for Better Results + +### Query Strategy + +- Use **multi-dimensional keywords** — combine product + industry + tone + density: `"entertainment social vibrant content-dense"` not just `"app"` +- Try different keywords for the same need: `"playful neon"` → `"vibrant dark"` → `"content-first minimal"` +- Use `--design-system` first for full recommendations, then `--domain` to deep-dive any dimension you're unsure about +- Always add `--stack react-native` for implementation-specific guidance + +### Common Sticking Points + +| Problem | What to Do | +|---------|------------| +| Can't decide on style/color | Re-run `--design-system` with different keywords | +| Dark mode contrast issues | Quick Reference §6: `color-dark-mode` + `color-accessible-pairs` | +| Animations feel unnatural | Quick Reference §7: `spring-physics` + `easing` + `exit-faster-than-enter` | +| Form UX is poor | Quick Reference §8: `inline-validation` + `error-clarity` + `focus-management` | +| Navigation feels confusing | Quick Reference §9: `nav-hierarchy` + `bottom-nav-limit` + `back-behavior` | +| Layout breaks on small screens | Quick Reference §5: `mobile-first` + `breakpoint-consistency` | +| Performance / jank | Quick Reference §3: `virtualize-lists` + `main-thread-budget` + `debounce-throttle` | + +### Pre-Delivery Checklist + +- Run `--domain ux "animation accessibility z-index loading"` as a UX validation pass before implementation +- Run through Quick Reference **§1–§3** (CRITICAL + HIGH) as a final review +- Test on 375px (small phone) and landscape orientation +- Verify behavior with **reduced-motion** enabled and **Dynamic Type** at largest size +- Check dark mode contrast independently (don't assume light mode values work) +- Confirm all touch targets ≥44pt and no content hidden behind safe areas + +--- + +## Common Rules for Professional UI + +These are frequently overlooked issues that make UI look unprofessional: +Scope notice: The rules below are for App UI (iOS/Android/React Native/Flutter), not desktop-web interaction patterns. + +### Icons & Visual Elements + +| Rule | Standard | Avoid | Why It Matters | +|------|----------|--------|----------------| +| **No Emoji as Structural Icons** | Use vector-based icons (e.g., Lucide, react-native-vector-icons, @expo/vector-icons). | Using emojis (🎨 🚀 ⚙️) for navigation, settings, or system controls. | Emojis are font-dependent, inconsistent across platforms, and cannot be controlled via design tokens. | +| **Vector-Only Assets** | Use SVG or platform vector icons that scale cleanly and support theming. | Raster PNG icons that blur or pixelate. | Ensures scalability, crisp rendering, and dark/light mode adaptability. | +| **Stable Interaction States** | Use color, opacity, or elevation transitions for press states without changing layout bounds. | Layout-shifting transforms that move surrounding content or trigger visual jitter. | Prevents unstable interactions and preserves smooth motion/perceived quality on mobile. 
| +| **Correct Brand Logos** | Use official brand assets and follow their usage guidelines (spacing, color, clear space). | Guessing logo paths, recoloring unofficially, or modifying proportions. | Prevents brand misuse and ensures legal/platform compliance. | +| **Consistent Icon Sizing** | Define icon sizes as design tokens (e.g., icon-sm, icon-md = 24pt, icon-lg). | Mixing arbitrary values like 20pt / 24pt / 28pt randomly. | Maintains rhythm and visual hierarchy across the interface. | +| **Stroke Consistency** | Use a consistent stroke width within the same visual layer (e.g., 1.5px or 2px). | Mixing thick and thin stroke styles arbitrarily. | Inconsistent strokes reduce perceived polish and cohesion. | +| **Filled vs Outline Discipline** | Use one icon style per hierarchy level. | Mixing filled and outline icons at the same hierarchy level. | Maintains semantic clarity and stylistic coherence. | +| **Touch Target Minimum** | Minimum 44×44pt interactive area (use hitSlop if icon is smaller). | Small icons without expanded tap area. | Meets accessibility and platform usability standards. | +| **Icon Alignment** | Align icons to text baseline and maintain consistent padding. | Misaligned icons or inconsistent spacing around them. | Prevents subtle visual imbalance that reduces perceived quality. | +| **Icon Contrast** | Follow WCAG contrast standards: 4.5:1 for small elements, 3:1 minimum for larger UI glyphs. | Low-contrast icons that blend into the background. | Ensures accessibility in both light and dark modes. | + + +### Interaction (App) + +| Rule | Do | Don't | +|------|----|----- | +| **Tap feedback** | Provide clear pressed feedback (ripple/opacity/elevation) within 80-150ms | No visual response on tap | +| **Animation timing** | Keep micro-interactions around 150-300ms with platform-native easing | Instant transitions or slow animations (>500ms) | +| **Accessibility focus** | Ensure screen reader focus order matches visual order and labels are descriptive | Unlabeled controls or confusing focus traversal | +| **Disabled state clarity** | Use disabled semantics (`disabled`/native disabled props), reduced emphasis, and no tap action | Controls that look tappable but do nothing | +| **Touch target minimum** | Keep tap areas >=44x44pt (iOS) or >=48x48dp (Android), expand hit area when icon is smaller | Tiny tap targets or icon-only hit areas without padding | +| **Gesture conflict prevention** | Keep one primary gesture per region and avoid nested tap/drag conflicts | Overlapping gestures causing accidental actions | +| **Semantic native controls** | Prefer native interactive primitives (`Button`, `Pressable`, platform equivalents) with proper accessibility roles | Generic containers used as primary controls without semantics | + +### Light/Dark Mode Contrast + +| Rule | Do | Don't | +|------|----|----- | +| **Surface readability (light)** | Keep cards/surfaces clearly separated from background with sufficient opacity/elevation | Overly transparent surfaces that blur hierarchy | +| **Text contrast (light)** | Maintain body text contrast >=4.5:1 against light surfaces | Low-contrast gray body text | +| **Text contrast (dark)** | Maintain primary text contrast >=4.5:1 and secondary text >=3:1 on dark surfaces | Dark mode text that blends into background | +| **Border and divider visibility** | Ensure separators are visible in both themes (not just light mode) | Theme-specific borders disappearing in one mode | +| **State contrast parity** | Keep pressed/focused/disabled states 
equally distinguishable in light and dark themes | Defining interaction states for one theme only | +| **Token-driven theming** | Use semantic color tokens mapped per theme across app surfaces/text/icons | Hardcoded per-screen hex values | +| **Scrim and modal legibility** | Use a modal scrim strong enough to isolate foreground content (typically 40-60% black) | Weak scrim that leaves background visually competing | + +### Layout & Spacing + +| Rule | Do | Don't | +|------|----|----- | +| **Safe-area compliance** | Respect top/bottom safe areas for all fixed headers, tab bars, and CTA bars | Placing fixed UI under notch, status bar, or gesture area | +| **System bar clearance** | Add spacing for status/navigation bars and gesture home indicator | Let tappable content collide with OS chrome | +| **Consistent content width** | Keep predictable content width per device class (phone/tablet) | Mixing arbitrary widths between screens | +| **8dp spacing rhythm** | Use a consistent 4/8dp spacing system for padding/gaps/section spacing | Random spacing increments with no rhythm | +| **Readable text measure** | Keep long-form text readable on large devices (avoid edge-to-edge paragraphs on tablets) | Full-width long text that hurts readability | +| **Section spacing hierarchy** | Define clear vertical rhythm tiers (e.g., 16/24/32/48) by hierarchy | Similar UI levels with inconsistent spacing | +| **Adaptive gutters by breakpoint** | Increase horizontal insets on larger widths and in landscape | Same narrow gutter on all device sizes/orientations | +| **Scroll and fixed element coexistence** | Add bottom/top content insets so lists are not hidden behind fixed bars | Scroll content obscured by sticky headers/footers | + +--- + +## Pre-Delivery Checklist + +Before delivering UI code, verify these items: +Scope notice: This checklist is for App UI (iOS/Android/React Native/Flutter). 
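One of the most frequently missed items in this checklist is token-driven light/dark theming. The sketch below shows a minimal way to satisfy the token-driven theming and contrast-parity rules above: semantic tokens mapped per theme via `useColorScheme`, so screens never hardcode per-screen hex values. Token names and hex values are illustrative assumptions, not prescribed by this skill.

```tsx
import { useColorScheme } from 'react-native';

type Theme = {
  surface: string;
  onSurface: string;       // primary text: target >= 4.5:1 against surface
  onSurfaceMuted: string;  // secondary text: target >= 3:1 against surface
  divider: string;         // must stay visible in both themes
};

const light: Theme = {
  surface: '#FFFFFF',
  onSurface: '#111827',
  onSurfaceMuted: '#4B5563',
  divider: '#E5E7EB',
};

const dark: Theme = {
  surface: '#111827',
  onSurface: '#F9FAFB',
  onSurfaceMuted: '#9CA3AF',
  divider: '#374151',
};

// Screens consume tokens only; both themes are adjusted (and contrast-checked) in one place.
export function useTheme(): Theme {
  return useColorScheme() === 'dark' ? dark : light;
}
```

Verify both palettes against the contrast items below before delivery rather than assuming the light-mode values carry over.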
+ +### Visual Quality +- [ ] No emojis used as icons (use SVG instead) +- [ ] All icons come from a consistent icon family and style +- [ ] Official brand assets are used with correct proportions and clear space +- [ ] Pressed-state visuals do not shift layout bounds or cause jitter +- [ ] Semantic theme tokens are used consistently (no ad-hoc per-screen hardcoded colors) + +### Interaction +- [ ] All tappable elements provide clear pressed feedback (ripple/opacity/elevation) +- [ ] Touch targets meet minimum size (>=44x44pt iOS, >=48x48dp Android) +- [ ] Micro-interaction timing stays in the 150-300ms range with native-feeling easing +- [ ] Disabled states are visually clear and non-interactive +- [ ] Screen reader focus order matches visual order, and interactive labels are descriptive +- [ ] Gesture regions avoid nested/conflicting interactions (tap/drag/back-swipe conflicts) + +### Light/Dark Mode +- [ ] Primary text contrast >=4.5:1 in both light and dark mode +- [ ] Secondary text contrast >=3:1 in both light and dark mode +- [ ] Dividers/borders and interaction states are distinguishable in both modes +- [ ] Modal/drawer scrim opacity is strong enough to preserve foreground legibility (typically 40-60% black) +- [ ] Both themes are tested before delivery (not inferred from a single theme) + +### Layout +- [ ] Safe areas are respected for headers, tab bars, and bottom CTA bars +- [ ] Scroll content is not hidden behind fixed/sticky bars +- [ ] Verified on small phone, large phone, and tablet (portrait + landscape) +- [ ] Horizontal insets/gutters adapt correctly by device size and orientation +- [ ] 4/8dp spacing rhythm is maintained across component, section, and page levels +- [ ] Long-form text measure remains readable on larger devices (no edge-to-edge paragraphs) + +### Accessibility +- [ ] All meaningful images/icons have accessibility labels +- [ ] Form fields have labels, hints, and clear error messages +- [ ] Color is not the only indicator +- [ ] Reduced motion and dynamic text size are supported without layout breakage +- [ ] Accessibility traits/roles/states (selected, disabled, expanded) are announced correctly \ No newline at end of file diff --git a/skills/website-creator/ui-ux-pro-max/data b/skills/website-creator/ui-ux-pro-max/data new file mode 120000 index 0000000..e5b9469 --- /dev/null +++ b/skills/website-creator/ui-ux-pro-max/data @@ -0,0 +1 @@ +../../../src/ui-ux-pro-max/data \ No newline at end of file diff --git a/skills/website-creator/ui-ux-pro-max/scripts b/skills/website-creator/ui-ux-pro-max/scripts new file mode 120000 index 0000000..ccb93f7 --- /dev/null +++ b/skills/website-creator/ui-ux-pro-max/scripts @@ -0,0 +1 @@ +../../../src/ui-ux-pro-max/scripts \ No newline at end of file