# File: KisStock/backend/app/services/ai_factory.py
from typing import Optional, Dict, Any
from abc import ABC, abstractmethod
import httpx
class BaseAIProvider(ABC):
    """Abstract base for asynchronous AI text-generation providers.

    Holds the credential and endpoint configuration shared by all concrete
    providers and defines the async generation interface they implement.
    """

    def __init__(self, api_key: str, model_name: str, base_url: Optional[str] = None):
        # api_key: provider credential (may be unused by local backends).
        # model_name: identifier of the model to query.
        # base_url: optional endpoint override; required by providers that
        # talk to a self-hosted server (e.g. Ollama).
        self.api_key = api_key
        self.model_name = model_name
        self.base_url = base_url

    @abstractmethod
    async def generate_content(self, prompt: str, system_instruction: Optional[str] = None) -> str:
        """Return the model's text response to *prompt*.

        system_instruction, when provided, is forwarded to the backend as
        the system-level prompt.
        """
class GeminiProvider(BaseAIProvider):
    """Google Gemini provider (stub).

    A real implementation would POST to the Generative Language API, e.g.
    https://generativelanguage.googleapis.com/v1beta/models/...
    """

    async def generate_content(self, prompt: str, system_instruction: str = None) -> str:
        # TODO: replace this stub with an actual Gemini API call.
        stub_reply = f"Gemini Response to: {prompt}"
        return stub_reply
class OpenAIProvider(BaseAIProvider):
    """OpenAI provider (stub)."""

    async def generate_content(self, prompt: str, system_instruction: str = None) -> str:
        # TODO: replace this stub with an actual OpenAI API call.
        stub_reply = f"OpenAI Response to: {prompt}"
        return stub_reply
class OllamaProvider(BaseAIProvider):
    """Local-LLM provider that calls an Ollama server's /api/generate route.

    Ollama is usually OpenAI-client compatible as well, but this provider
    uses the direct /api/generate endpoint.
    """

    async def generate_content(self, prompt: str, system_instruction: str = None) -> str:
        # Non-streaming generation request; the optional system prompt is
        # sent under Ollama's "system" key.
        endpoint = f"{self.base_url}/api/generate"
        request_body = {
            "model": self.model_name,
            "prompt": prompt,
            "stream": False,
        }
        # NOTE: truthiness check (an empty system instruction is omitted),
        # matching the provider's original behavior.
        if system_instruction:
            request_body["system"] = system_instruction
        try:
            async with httpx.AsyncClient() as http:
                reply = await http.post(endpoint, json=request_body, timeout=60.0)
                reply.raise_for_status()
                # Ollama returns the generated text under the "response" key.
                return reply.json().get("response", "")
        except Exception as e:
            # Best-effort contract: surface failures as a string rather
            # than raising to the caller.
            return f"Error: {e}"
class AIFactory:
    """Factory mapping a provider-type string to a concrete AI provider."""

    @staticmethod
    def get_provider(provider_type: str, api_key: str, model_name: str,
                     base_url: Optional[str] = None) -> "BaseAIProvider":
        """Instantiate the provider named by *provider_type* (case-insensitive).

        Args:
            provider_type: one of "gemini", "openai", or "ollama".
            api_key: credential forwarded to the provider.
            model_name: model identifier forwarded to the provider.
            base_url: optional endpoint override (used by local backends).

        Raises:
            ValueError: if *provider_type* is not a known provider.
        """
        # Normalize once instead of lowercasing on every comparison.
        kind = provider_type.lower()
        if kind == "gemini":
            return GeminiProvider(api_key, model_name, base_url)
        if kind == "openai":
            return OpenAIProvider(api_key, model_name, base_url)
        if kind == "ollama":
            return OllamaProvider(api_key, model_name, base_url)
        raise ValueError(f"Unknown Provider: {provider_type}")