import asyncio
import json
import logging
from datetime import datetime
from typing import Callable, Dict, Optional, Set

import websockets

from app.services.kis_auth import kis_auth
from app.core.crypto import aes_cbc_base64_dec
from app.db.database import SessionLocal
# from app.db.crud import update_stock_price  # TODO: Implement CRUD

logger = logging.getLogger(__name__)


class RealtimeManager:
    """
    Manage the KIS real-time WebSocket feed.

    Handles reference-counted subscriptions, register/unregister commands,
    parsing of H0STCNT0 price frames into ``price_cache`` and echoing
    PINGPONG keep-alive frames.

    NOTE(review): ``self.ws`` and ``self.approval_key`` are expected to be
    populated by connection code that is not visible in this file.
    """

    WS_URL_REAL = "ws://ops.koreainvestment.com:21000"

    # Highest index read from a price record is 5, so a record needs >= 6 fields.
    _MIN_PRICE_FIELDS = 6

    def __init__(self):
        self.ws: Optional[websockets.WebSocketClientProtocol] = None
        self.approval_key: Optional[str] = None

        # Reference counting: code -> set of sources that requested it.
        # e.g. "005930": {"HOLDING", "FRONTEND_DASHBOARD"}
        self.subscriptions: Dict[str, Set[str]] = {}

        self.running = False
        self.data_map: Dict[str, Dict] = {}

        # Realtime data cache (code -> latest update dict).
        # Used by the scheduler to persist data periodically.
        self.price_cache: Dict[str, Dict] = {}

    async def add_subscription(self, code: str, source: str):
        """
        Register ``source`` as a subscriber of ``code``.

        The register command ("1") is sent only when ``code`` goes from zero
        to one subscriber; additional sources only raise the count.

        Args:
            code: Stock code sent as ``tr_key`` (e.g. "005930").
            source: Identifier of the component requesting the feed.

        Raises:
            Whatever ``_send_subscribe`` raises; ``source`` is then not left
            registered.
        """
        sources = self.subscriptions.setdefault(code, set())
        is_first = not sources
        # Add the source *before* awaiting the send, so a concurrent call for
        # the same code sees a non-empty set and does not register twice.
        sources.add(source)
        if is_first:
            try:
                await self._send_subscribe(code, "1")  # 1=Register
            except BaseException:
                # Roll back so a failed register leaves no phantom entry that
                # _resubscribe_all would later re-register.
                sources.discard(source)
                if not sources and self.subscriptions.get(code) is sources:
                    del self.subscriptions[code]
                raise
        logger.info("Subscribed %s by %s. RefCount: %d", code, source, len(sources))

    async def remove_subscription(self, code: str, source: str):
        """
        Drop ``source`` from the subscribers of ``code``.

        When the last subscriber leaves, the entry is removed and the
        unregister command ("2") is sent. Unknown code/source pairs are ignored.
        """
        sources = self.subscriptions.get(code)
        if not sources or source not in sources:
            return
        sources.remove(source)
        logger.info("Unsubscribed %s by %s. RefCount: %d", code, source, len(sources))
        if not sources:
            # Delete before awaiting: a concurrent add_subscription then starts
            # a fresh set (and re-registers) instead of losing its source here.
            del self.subscriptions[code]
            await self._send_subscribe(code, "2")  # 2=Unregister

    async def _send_subscribe(self, code: str, tr_type: str):
        """
        Send an H0STCNT0 register/unregister command for ``code``.

        Args:
            code: Stock code sent as ``tr_key``.
            tr_type: "1" to register, "2" to unregister.

        Does nothing while there is no socket or approval key; codes still in
        ``self.subscriptions`` can be re-sent via ``_resubscribe_all``.
        """
        if not self.ws or not self.approval_key:
            return  # Will resubscribe on connect

        payload = {
            "header": {
                "approval_key": self.approval_key,
                "custtype": "P",
                "tr_type": tr_type,
                "content-type": "utf-8",
            },
            "body": {
                "input": {
                    "tr_id": "H0STCNT0",
                    "tr_key": code,
                }
            },
        }
        await self.ws.send(json.dumps(payload))

    async def _resubscribe_all(self):
        """Re-send the register command for every code that has subscribers."""
        for code, sources in list(self.subscriptions.items()):
            if sources:
                await self._send_subscribe(code, "1")

    @staticmethod
    def _record_count(raw_count: str) -> int:
        """Parse a data frame's record-count field, falling back to 1."""
        try:
            return max(int(raw_count), 1)
        except ValueError:
            return 1

    async def _listen(self):
        """
        Consume frames from ``self.ws`` until the connection ends.

        Frames starting with '0'/'1' are pipe-delimited data frames
        (flag|tr_id|record count|payload — count per KIS docs, confirm);
        frames starting with '{' are JSON control messages, of which PINGPONG
        is echoed back. A frame that fails to process is logged and skipped so
        the listener keeps running.
        """
        async for message in self.ws:
            try:
                if isinstance(message, bytes):
                    message = message.decode('utf-8')
                if not message:
                    continue

                first_char = message[0]
                if first_char in ('0', '1'):
                    # NOTE(review): '1' presumably marks an encrypted payload;
                    # it is parsed as plain text here (aes_cbc_base64_dec is
                    # imported but unused) — confirm H0STCNT0 is never encrypted.
                    parts = message.split('|')
                    if len(parts) < 4:
                        continue
                    tr_id = parts[1]
                    raw_data = parts[3]
                    if tr_id == "H0STCNT0":
                        await self._parse_domestic_price(
                            raw_data, self._record_count(parts[2])
                        )
                elif first_char == '{':
                    data = json.loads(message)
                    if data.get('header', {}).get('tr_id') == "PINGPONG":
                        await self.ws.send(message)
            except Exception as e:
                logger.exception("WS Error: %s", e)

    async def _parse_domestic_price(self, raw_data: str, record_count: int = 1):
        """
        Parse an H0STCNT0 payload and cache each record in ``price_cache``.

        Record format: MKSC_SHRN_ISCD^EXEC_TIME^CURRENT_PRICE^..., with the
        change at index 4 and the change rate at index 5.

        Args:
            raw_data: '^'-joined fields of one or more records.
            record_count: Number of equal-width records packed in
                ``raw_data``. If the fields do not split evenly, the payload
                is treated as a single record.
        """
        fields = raw_data.split('^')
        if record_count > 1 and len(fields) % record_count == 0:
            width = len(fields) // record_count
        else:
            record_count, width = 1, len(fields)
        for start in range(0, record_count * width, width):
            self._cache_price_record(fields[start:start + width])

    def _cache_price_record(self, fields):
        """Store one price record in ``price_cache``; records too short are skipped."""
        if len(fields) < self._MIN_PRICE_FIELDS:
            return

        code = fields[0]
        # Create lightweight update object (Dict)
        update_data = {
            "code": code,
            "price": fields[2],
            "change": fields[4],
            "rate": fields[5],
            "timestamp": datetime.now().isoformat(),
        }
        # Latest tick wins; the scheduler reads this cache.
        self.price_cache[code] = update_data
        # logger.debug(f"Price Update: {code} {fields[2]}")


realtime_manager = RealtimeManager()