61 lines
2.3 KiB
Python
61 lines
2.3 KiB
Python
from typing import Optional, Dict, Any
|
|
from abc import ABC, abstractmethod
|
|
import httpx
|
|
|
|
class BaseAIProvider(ABC):
    """Abstract interface for asynchronous text-generation backends.

    Subclasses must implement :meth:`generate_content`. The constructor only
    stores connection settings as plain attributes for subclasses to use.
    """

    def __init__(
        self, api_key: str, model_name: str, base_url: Optional[str] = None
    ) -> None:
        """Store the connection settings for this provider.

        Args:
            api_key: Credential for the backend (a subclass may ignore it).
            model_name: Identifier of the model the subclass should target.
            base_url: Optional endpoint override; ``None`` when not supplied.
        """
        self.api_key = api_key
        self.model_name = model_name
        self.base_url = base_url

    @abstractmethod
    async def generate_content(
        self, prompt: str, system_instruction: Optional[str] = None
    ) -> str:
        """Generate text for *prompt*.

        Args:
            prompt: The user prompt.
            system_instruction: Optional system/instruction text; ``None``
                means no system instruction.

        Returns:
            The generated text.
        """
|
|
|
|
class GeminiProvider(BaseAIProvider):
    """Stub provider for Google Gemini.

    The real API call is not implemented yet; :meth:`generate_content`
    returns a canned string that echoes the prompt.
    """

    async def generate_content(
        self, prompt: str, system_instruction: Optional[str] = None
    ) -> str:
        """Return a placeholder response echoing *prompt*.

        ``system_instruction`` is accepted for interface compatibility but is
        currently ignored.
        """
        # TODO: call the Gemini API, e.g.
        # https://generativelanguage.googleapis.com/v1beta/models/...
        return f"Gemini Response to: {prompt}"
|
|
|
|
class OpenAIProvider(BaseAIProvider):
    """Stub provider for OpenAI.

    The real API call is not implemented yet; :meth:`generate_content`
    returns a canned string that echoes the prompt.
    """

    async def generate_content(
        self, prompt: str, system_instruction: Optional[str] = None
    ) -> str:
        """Return a placeholder response echoing *prompt*.

        ``system_instruction`` is accepted for interface compatibility but is
        currently ignored.
        """
        # TODO: call the OpenAI API.
        return f"OpenAI Response to: {prompt}"
|
|
|
|
class OllamaProvider(BaseAIProvider):
    """Provider backed by an Ollama server (typically a local LLM).

    Uses Ollama's native ``/api/generate`` endpoint with streaming disabled.
    ``api_key`` is stored by the base class but is not sent with requests.
    """

    # Fallback when no base_url is given. Previously a None base_url produced
    # the URL "None/api/generate", so the factory's default always failed.
    # NOTE(review): assumed default address of a locally running Ollama.
    DEFAULT_BASE_URL = "http://localhost:11434"

    # Seconds to wait for the whole (non-streamed) completion.
    request_timeout: float = 60.0

    async def generate_content(
        self, prompt: str, system_instruction: Optional[str] = None
    ) -> str:
        """Send *prompt* to Ollama and return the generated text.

        Args:
            prompt: The user prompt.
            system_instruction: Optional system prompt; sent as the ``system``
                field only when non-empty.

        Returns:
            The ``response`` field of the JSON reply (``""`` if absent).
            Any exception raised while sending the request, checking the
            status, or reading the reply is not propagated; it is returned
            as a string of the form ``"Error: <exception>"``.
            NOTE(review): callers cannot tell that string apart from model
            output.
        """
        # Strip trailing slashes so "http://host:11434/" doesn't yield "//api".
        base_url = (self.base_url or self.DEFAULT_BASE_URL).rstrip("/")
        url = f"{base_url}/api/generate"

        payload: Dict[str, Any] = {
            "model": self.model_name,
            "prompt": prompt,
            "stream": False,  # ask for a single JSON object, not a stream
        }
        if system_instruction:
            payload["system"] = system_instruction

        try:
            async with httpx.AsyncClient() as client:
                resp = await client.post(
                    url, json=payload, timeout=self.request_timeout
                )
                resp.raise_for_status()
                data = resp.json()
                return data.get("response", "")
        except Exception as e:
            # Deliberate best-effort contract: report failures as text.
            return f"Error: {e}"
|
|
|
|
class AIFactory:
    """Create AI provider instances from a provider-type string."""

    @staticmethod
    def get_provider(
        provider_type: str,
        api_key: str,
        model_name: str,
        base_url: Optional[str] = None,
    ) -> BaseAIProvider:
        """Instantiate the provider registered under *provider_type*.

        Args:
            provider_type: Case-insensitive provider name: ``"gemini"``,
                ``"openai"`` or ``"ollama"``.
            api_key: Credential passed through to the provider.
            model_name: Model identifier passed through to the provider.
            base_url: Optional endpoint override passed through unchanged.

        Returns:
            A new provider instance.

        Raises:
            ValueError: If *provider_type* is not a known provider name.
        """
        providers = {
            "gemini": GeminiProvider,
            "openai": OpenAIProvider,
            "ollama": OllamaProvider,
        }
        provider_cls = providers.get(provider_type.lower())
        if provider_cls is None:
            raise ValueError(f"Unknown Provider: {provider_type}")
        return provider_cls(api_key, model_name, base_url)
|