백엔드 전체 구현 완료: 내부 서비스(Auth, Client, Realtime), API 엔드포인트 및 스케줄러 구현
This commit is contained in:
60
backend/app/services/ai_factory.py
Normal file
60
backend/app/services/ai_factory.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from typing import Optional, Dict, Any
|
||||
from abc import ABC, abstractmethod
|
||||
import httpx
|
||||
|
||||
class BaseAIProvider(ABC):
|
||||
def __init__(self, api_key: str, model_name: str, base_url: str = None):
|
||||
self.api_key = api_key
|
||||
self.model_name = model_name
|
||||
self.base_url = base_url
|
||||
|
||||
@abstractmethod
|
||||
async def generate_content(self, prompt: str, system_instruction: str = None) -> str:
|
||||
pass
|
||||
|
||||
class GeminiProvider(BaseAIProvider):
|
||||
async def generate_content(self, prompt: str, system_instruction: str = None) -> str:
|
||||
# Placeholder for Gemini API Implementation
|
||||
# https://generativelanguage.googleapis.com/v1beta/models/...
|
||||
return f"Gemini Response to: {prompt}"
|
||||
|
||||
class OpenAIProvider(BaseAIProvider):
|
||||
async def generate_content(self, prompt: str, system_instruction: str = None) -> str:
|
||||
# Placeholder for OpenAI API
|
||||
return f"OpenAI Response to: {prompt}"
|
||||
|
||||
class OllamaProvider(BaseAIProvider):
|
||||
"""
|
||||
Ollama (Local LLM), compatible with OpenAI client usually, or direct /api/generate
|
||||
"""
|
||||
async def generate_content(self, prompt: str, system_instruction: str = None) -> str:
|
||||
# Placeholder for Ollama API
|
||||
url = f"{self.base_url}/api/generate"
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"prompt": prompt,
|
||||
"stream": False
|
||||
}
|
||||
if system_instruction:
|
||||
payload["system"] = system_instruction
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(url, json=payload, timeout=60.0)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return data.get("response", "")
|
||||
except Exception as e:
|
||||
return f"Error: {e}"
|
||||
|
||||
class AIFactory:
|
||||
@staticmethod
|
||||
def get_provider(provider_type: str, api_key: str, model_name: str, base_url: str = None) -> BaseAIProvider:
|
||||
if provider_type.lower() == "gemini":
|
||||
return GeminiProvider(api_key, model_name, base_url)
|
||||
elif provider_type.lower() == "openai":
|
||||
return OpenAIProvider(api_key, model_name, base_url)
|
||||
elif provider_type.lower() == "ollama":
|
||||
return OllamaProvider(api_key, model_name, base_url)
|
||||
else:
|
||||
raise ValueError(f"Unknown Provider: {provider_type}")
|
||||
Reference in New Issue
Block a user