"""FastAPI entry point for the KisStock AI backend.

Serves the static frontend pages from ``<project root>/frontend``, exposes the
REST API used by those pages (master-data sync, stock search, watchlist,
settings, balances, orders, news, trade settings) and a ``/ws`` WebSocket that
forwards stock subscriptions to ``ws_manager``.
"""
import asyncio
import datetime
import logging
import os
import threading
from typing import List, Optional

import fastapi
import uvicorn
from fastapi import (
    BackgroundTasks,
    Body,
    Depends,
    FastAPI,
    Header,
    HTTPException,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from sqlalchemy.orm import Session

from database import (
    AccountBalance,
    Holding,
    News,
    Order,
    Stock,
    TradeSetting,
    Watchlist,
    get_db,
    init_db,
)
from config import load_config, save_config, get_kis_config
from kis_api import kis
from trader import trader
from news_ai import news_bot
from telegram_notifier import notifier
from websocket_manager import ws_manager
from master_loader import master_loader

logger = logging.getLogger("MAIN")

app = FastAPI()

# Directory containing this module (the backend package).
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Project root: one level above the backend directory.
PROJECT_ROOT = os.path.dirname(BASE_DIR)
# Static frontend assets and HTML pages.
FRONTEND_DIR = os.path.join(PROJECT_ROOT, "frontend")

# Markets routed to the overseas order API and reported by the overseas
# balance endpoint. Both endpoints share this tuple so they cannot drift apart.
OVERSEAS_MARKETS = ("NASD", "NYSE", "AMEX")

# Rows per page returned by /api/stocks.
STOCK_PAGE_SIZE = 50

# StaticFiles requires the directory to exist at mount time; exist_ok avoids
# the check-then-create race of a separate os.path.exists() test.
os.makedirs(FRONTEND_DIR, exist_ok=True)
app.mount("/static", StaticFiles(directory=FRONTEND_DIR), name="static")


# --- WebSocket ---

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Register a frontend socket with ``ws_manager`` and relay subscriptions.

    A text message of the form ``"sub:<CODE>"`` subscribes ``<CODE>`` via
    ``ws_manager.subscribe_stock``. Any other message is read and ignored, which
    keeps the connection alive. The socket is always deregistered on exit.
    """
    await ws_manager.connect_frontend(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            if data.startswith("sub:"):
                code = data.split(":")[1].strip()
                if code:  # ignore an empty "sub:" message
                    await ws_manager.subscribe_stock(code)
    except WebSocketDisconnect:
        pass  # normal client disconnect
    except Exception:
        logger.exception("Frontend WebSocket handler failed")
    finally:
        # Runs on every exit path, including task cancellation. The previous
        # bare ``except:`` swallowed CancelledError.
        ws_manager.disconnect_frontend(websocket)


# --- Lifecycle ---

@app.on_event("startup")
async def startup_event():
    """Initialise the DB, start background workers and the KIS socket loop."""
    init_db()
    trader.start()
    news_bot.start()
    # The event loop keeps only weak references to tasks. Store the task so it
    # cannot be garbage-collected mid-run, and so shutdown can cancel it.
    app.state.kis_socket_task = asyncio.create_task(ws_manager.start_kis_socket())
    notifier.send_message("🚀 KisStock AI 시스템이 시작되었습니다.")


@app.on_event("shutdown")
def shutdown_event():
    """Notify, stop background workers and cancel the KIS socket task."""
    notifier.send_message("🛑 KisStock AI 시스템이 종료됩니다.")
    trader.stop()
    news_bot.stop()
    task = getattr(app.state, "kis_socket_task", None)
    if task is not None and not task.done():
        task.cancel()


# --- Pages ---

def _frontend_file(filename: str) -> FileResponse:
    """Return a FileResponse for ``filename`` inside FRONTEND_DIR."""
    return FileResponse(os.path.join(FRONTEND_DIR, filename))


@app.get("/")
async def read_index():
    return _frontend_file("index.html")


@app.get("/news")
async def read_news():
    return _frontend_file("news.html")


@app.get("/stocks")
async def read_stocks():
    return _frontend_file("stocks.html")


@app.get("/settings")
async def read_settings():
    return _frontend_file("settings.html")


@app.get("/trade")
async def read_trade():
    return _frontend_file("trade.html")


# --- API: master data ---

@app.get("/api/sync/status")
def get_sync_status():
    """Return ``master_loader.get_status()``."""
    return master_loader.get_status()


@app.post("/api/sync/master")
def sync_master_data(background_tasks: BackgroundTasks):
    """Schedule domestic and overseas master downloads after the response is sent."""
    background_tasks.add_task(master_loader.download_and_parse_domestic)
    background_tasks.add_task(master_loader.download_and_parse_overseas)
    return {"status": "started", "message": "Master data sync started in background"}


@app.get("/api/stocks")
def search_stocks(keyword: str = "", market: str = "", page: int = 1, db: Session = Depends(get_db)):
    """Search stocks by name/code substring and optional exact market.

    ``page`` is 1-based; values below 1 are treated as 1 so the SQL OFFSET is
    never negative. Returns ``{"items": [...]}`` with at most STOCK_PAGE_SIZE rows.
    """
    query = db.query(Stock)
    if keyword:
        query = query.filter(Stock.name.contains(keyword) | Stock.code.contains(keyword))
    if market:
        query = query.filter(Stock.market == market)
    page = max(page, 1)
    offset = (page - 1) * STOCK_PAGE_SIZE
    items = query.limit(STOCK_PAGE_SIZE).offset(offset).all()
    return {"items": items}


# --- API: watchlist ---

@app.get("/api/watchlist")
def get_watchlist(db: Session = Depends(get_db)):
    """Return all watchlist rows, newest first."""
    return db.query(Watchlist).order_by(Watchlist.created_at.desc()).all()


@app.post("/api/watchlist")
def add_watchlist(code: str = Body(...), name: str = Body(...), market: str = Body(...), db: Session = Depends(get_db)):
    """Add ``code`` to the watchlist unless a row with that code already exists."""
    exists = db.query(Watchlist).filter(Watchlist.code == code).first()
    if exists:
        return {"status": "exists"}
    db.add(Watchlist(code=code, name=name, market=market))
    db.commit()
    return {"status": "added"}


@app.delete("/api/watchlist/{code}")
def delete_watchlist(code: str, db: Session = Depends(get_db)):
    """Delete all watchlist rows with ``code``; succeeds even if none exist."""
    db.query(Watchlist).filter(Watchlist.code == code).delete()
    db.commit()
    return {"status": "deleted"}


# --- API: settings ---

@app.get("/api/settings")
def get_settings():
    return load_config()


@app.post("/api/settings")
def update_settings(settings: dict = Body(...)):
    """Persist the posted dict via ``save_config`` (no validation here)."""
    save_config(settings)
    return {"status": "ok"}


# --- API: balances ---

@app.get("/api/balance")
def get_my_balance(source: str = "db", db: Session = Depends(get_db)):
    """Return the persisted domestic balance and holdings from the DB.

    The response mimics KIS API field names (``output1`` = holdings,
    ``output2`` = account summary) so the frontend can consume it unchanged.
    ``source`` is accepted but currently unused.
    """
    acc = db.query(AccountBalance).first()
    output2 = []
    if acc:
        output2.append({
            "tot_evlu_amt": acc.total_eval,
            "dnca_tot_amt": acc.deposit,
            "evlu_pfls_smtl_amt": acc.total_profit,
        })

    holdings = db.query(Holding).filter(Holding.market == 'DOMESTIC').all()
    output1 = [
        {
            "pdno": h.code,
            "prdt_name": h.name,
            "hldg_qty": h.quantity,
            "prpr": h.current_price,
            "pchs_avg_pric": h.price,
            "evlu_pfls_rt": h.profit_rate,
        }
        for h in holdings
    ]

    return {
        "rt_cd": "0",
        "msg1": "Success from DB",
        "output1": output1,
        "output2": output2,
    }


@app.get("/api/balance/overseas")
def get_my_overseas_balance(db: Session = Depends(get_db)):
    """Return persisted overseas holdings in KIS-style field names.

    Includes every market in OVERSEAS_MARKETS, the same set ``place_order_api``
    routes to the overseas order API.
    """
    holdings = db.query(Holding).filter(Holding.market.in_(OVERSEAS_MARKETS)).all()
    output1 = [
        {
            "ovrs_pdno": h.code,
            "ovrs_item_name": h.name,
            "ovrs_cblc_qty": h.quantity,
            "now_pric2": h.current_price,
            "frcr_pchs_amt1": h.price,
            "evlu_pfls_rt": h.profit_rate,
        }
        for h in holdings
    ]
    return {"rt_cd": "0", "output1": output1}


@app.post("/api/balance/refresh")
def force_refresh_balance(background_tasks: BackgroundTasks):
    """Run ``trader.refresh_assets()`` inline before responding.

    It runs synchronously, not via ``background_tasks``, so the client can
    re-fetch ``/api/balance`` right after this returns. ``background_tasks`` is
    unused and kept only for signature compatibility.
    """
    trader.refresh_assets()
    return {"status": "ok"}


# --- API: prices & orders ---

@app.get("/api/price/{code}")
def get_stock_price(code: str):
    """Return ``kis.get_current_price(code)``; 404 if it returns a falsy value."""
    price = kis.get_current_price(code)
    if price:
        return price
    raise HTTPException(status_code=404, detail="Stock info not found")


@app.post("/api/order")
def place_order_api(
    code: str = Body(...),
    type: str = Body(...),  # name is the JSON body key; kept despite shadowing the builtin
    qty: int = Body(...),
    price: int = Body(...),
    market: str = Body("DOMESTIC"),
):
    """Place an order via the overseas or domestic KIS API depending on ``market``.

    Returns the KIS response when ``rt_cd == '0'``; otherwise raises HTTP 400
    with the raw response in the detail.
    """
    if market in OVERSEAS_MARKETS:
        res = kis.place_overseas_order(code, type, qty, price, market)
    else:
        res = kis.place_order(code, type, qty, price)

    if res and res.get('rt_cd') == '0':
        return res
    raise HTTPException(status_code=400, detail=f"Order Failed: {res}")


@app.get("/api/orders")
def get_db_orders(db: Session = Depends(get_db)):
    """Return the 50 most recent persisted orders."""
    return db.query(Order).order_by(Order.created_at.desc()).limit(50).all()


@app.get("/api/orders/daily")
def get_daily_orders_api():
    res = kis.get_daily_orders()
    if res:
        return res
    raise HTTPException(status_code=500, detail="Failed to fetch daily orders")


@app.get("/api/orders/cancelable")
def get_cancelable_orders_api():
    res = kis.get_cancelable_orders()
    if res:
        return res
    raise HTTPException(status_code=500, detail="Failed to fetch cancelable orders")


@app.post("/api/order/cancel")
def cancel_order_api(
    org_no: str = Body(...),
    order_no: str = Body(...),
    qty: int = Body(...),
    is_buy: bool = Body(...),
    price: int = Body(0),
):
    """Cancel an order via ``kis.cancel_order``; HTTP 400 unless ``rt_cd == '0'``."""
    res = kis.cancel_order(org_no, order_no, qty, is_buy, price)
    if res and res.get('rt_cd') == '0':
        return res
    raise HTTPException(status_code=400, detail=f"Cancel Failed: {res}")


# --- API: news & trade settings ---

@app.get("/api/news")
def get_news(db: Session = Depends(get_db)):
    """Return the 20 most recent news rows."""
    return db.query(News).order_by(News.created_at.desc()).limit(20).all()


@app.get("/api/trade_settings")
def get_trade_settings(db: Session = Depends(get_db)):
    return db.query(TradeSetting).all()


@app.post("/api/trade_settings")
def set_trade_setting(
    code: str = Body(...),
    target_price: Optional[float] = Body(None),
    stop_loss_price: Optional[float] = Body(None),
    is_active: bool = Body(True),
    db: Session = Depends(get_db),
):
    """Create or update the TradeSetting for ``code``.

    ``target_price`` / ``stop_loss_price`` are updated only when provided
    (``None`` leaves the stored value untouched); ``is_active`` is always set.
    """
    setting = db.query(TradeSetting).filter(TradeSetting.code == code).first()
    if not setting:
        setting = TradeSetting(code=code)
        db.add(setting)

    if target_price is not None:
        setting.target_price = target_price
    if stop_loss_price is not None:
        setting.stop_loss_price = stop_loss_price
    setting.is_active = is_active

    db.commit()
    return {"status": "ok"}


if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)