56 lines
2.4 KiB
Python
56 lines
2.4 KiB
Python
from sqlalchemy import select
|
|
from app.db.database import SessionLocal
|
|
from app.db.models import AiConfig, ApiSettings
|
|
from app.services.ai_factory import AIFactory, BaseAIProvider
|
|
|
|
class AIOrchestrator:
    """Resolve the configured AI provider for a task and run prompts through it.

    Provider choices are read from the database. The ``ApiSettings`` row with
    id 1 names a preferred ``AiConfig`` id per purpose. That ``AiConfig`` row
    supplies the provider type, model name and base URL passed to
    ``AIFactory.get_provider``.
    """

    # Maps each supported ``purpose`` value to the ApiSettings attribute that
    # holds the preferred AiConfig id for it.
    PURPOSE_SETTING_FIELDS = {
        "news": "preferredNewsAiId",
        "stock": "preferredStockAiId",
        "judgement": "preferredNewsJudgementAiId",
        "buy": "preferredAutoBuyAiId",
        "sell": "preferredAutoSellAiId",
    }

    def _resolve_api_key(self, config: AiConfig) -> str:
        """Return the API key to hand to the provider built from *config*.

        TODO: API keys are not stored in the AiConfig schema. Load them from a
        secure source (environment, ApiSettings, or a vault). Until then every
        provider receives the same placeholder value, so real provider calls
        will presumably fail authentication. Override this method (or replace
        its body) to plug in real key lookup.
        """
        return "place_holder"

    async def _get_provider_by_id(self, config_id: str) -> BaseAIProvider:
        """Build the provider described by the AiConfig row with id *config_id*.

        Raises:
            ValueError: if no AiConfig row has that id.
        """
        async with SessionLocal() as session:
            result = await session.execute(
                select(AiConfig).where(AiConfig.id == config_id)
            )
            config = result.scalar_one_or_none()
            if config is None:
                raise ValueError("AI Config not found")

            # Read every needed column while the session is still open, so no
            # attribute access happens on a detached instance afterwards.
            provider_type = config.providerType
            model_name = config.modelName
            base_url = config.baseUrl
            api_key = self._resolve_api_key(config)

        return AIFactory.get_provider(provider_type, api_key, model_name, base_url)

    async def get_preferred_provider(self, purpose: str) -> BaseAIProvider:
        """Return the provider configured as preferred for *purpose*.

        Args:
            purpose: one of 'news', 'stock', 'judgement', 'buy', 'sell'.

        Raises:
            ValueError: in any of these cases:
                - *purpose* is not a supported value.
                - The ApiSettings row with id 1 does not exist.
                - No preferred AI id is set for *purpose*.
                - The referenced AiConfig row is missing.
        """
        field = self.PURPOSE_SETTING_FIELDS.get(purpose)
        if field is None:
            # Fail fast, before touching the DB, and report the real problem
            # instead of claiming the configuration is missing.
            supported = ", ".join(self.PURPOSE_SETTING_FIELDS)
            raise ValueError(
                f"Unknown AI purpose {purpose!r}; expected one of: {supported}"
            )

        async with SessionLocal() as session:
            result = await session.execute(
                select(ApiSettings).where(ApiSettings.id == 1)
            )
            settings = result.scalar_one_or_none()
            if settings is None:
                raise ValueError("Settings not initialized")
            config_id = getattr(settings, field)

        # Falsy check kept from the original: None / empty id means "not set".
        if not config_id:
            raise ValueError(f"No preferred AI configured for {purpose}")

        return await self._get_provider_by_id(config_id)

    async def analyze_text(self, text: str, purpose: str = "news") -> str:
        """Send *text* to the preferred provider for *purpose*.

        Returns the result of that provider's ``generate_content(text)``.

        Raises:
            ValueError: propagated from ``get_preferred_provider`` when no
                usable provider can be resolved for *purpose*.
        """
        provider = await self.get_preferred_provider(purpose)
        return await provider.generate_content(text)
|
|
|
|
# Shared module-level instance. The orchestrator keeps no per-instance state
# (each call opens its own DB session), so a single instance can serve all callers.
ai_orchestrator = AIOrchestrator()
|