314 lines
9.6 KiB
Python
314 lines
9.6 KiB
Python
import os
|
|
import uvicorn
|
|
import threading
|
|
import asyncio
|
|
from fastapi import FastAPI, HTTPException, Depends, Body, Header, BackgroundTasks
|
|
import fastapi
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.responses import FileResponse, JSONResponse
|
|
from sqlalchemy.orm import Session
|
|
from typing import List, Optional
|
|
|
|
from database import init_db, get_db, TradeSetting, Order, News, Stock, Watchlist
|
|
from config import load_config, save_config, get_kis_config
|
|
from kis_api import kis
|
|
from trader import trader
|
|
from news_ai import news_bot
|
|
from telegram_notifier import notifier
|
|
import logging
|
|
|
|
logger = logging.getLogger("MAIN")
|
|
|
|
|
|
app = FastAPI()

# Directory that contains this module.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Project root: one level above the directory that holds this module.
PROJECT_ROOT = os.path.dirname(BASE_DIR)
# Directory the frontend HTML pages and static assets are served from.
FRONTEND_DIR = os.path.join(PROJECT_ROOT, "frontend")

# Ensure the frontend directory exists before mounting it. exist_ok=True
# replaces the separate os.path.exists() test, which left a window between
# the check and the creation.
os.makedirs(FRONTEND_DIR, exist_ok=True)

# Expose everything under FRONTEND_DIR at the /static URL prefix.
app.mount("/static", StaticFiles(directory=FRONTEND_DIR), name="static")
|
|
|
|
from websocket_manager import ws_manager
|
|
from fastapi import WebSocket
|
|
|
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one frontend WebSocket connection.

    The socket is registered with ``ws_manager`` and then read in a loop.
    A text frame starting with ``sub:`` asks ``ws_manager`` to subscribe to
    the stock code that follows the prefix; any other frame is ignored and
    only keeps the connection alive. The socket is always unregistered from
    ``ws_manager`` when the loop ends, whatever the reason.
    """
    await ws_manager.connect_frontend(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            if data.startswith("sub:"):
                code = data.split(":")[1].strip()
                # A bare "sub:" carries no code; do not forward an empty one.
                if code:
                    await ws_manager.subscribe_stock(code)
    except fastapi.WebSocketDisconnect:
        # Normal client disconnect: nothing to report.
        pass
    except Exception:
        # Unexpected failure: record it instead of swallowing it silently.
        logger.exception("Frontend WebSocket handler failed")
    finally:
        # Runs for disconnects, errors and task cancellation alike.
        ws_manager.disconnect_frontend(websocket)
|
|
|
|
@app.on_event("startup")
def startup_event():
    """Initialise the database and start the background services.

    Runs ``init_db()``, starts ``trader`` and ``news_bot``, schedules the KIS
    WebSocket coroutine on the running event loop, and sends a start-up
    notification through ``notifier``.
    """
    init_db()

    # Start the trader and the news bot.
    trader.start()
    news_bot.start()

    # Schedule the KIS WebSocket loop as an asyncio task.
    # NOTE(review): create_task() needs a running loop in the calling thread;
    # this assumes the framework invokes this sync handler on the loop thread.
    # Keep a strong reference on app.state: the event loop holds only weak
    # references to tasks, so an unreferenced task may be garbage-collected
    # before it finishes.
    app.state.kis_socket_task = asyncio.create_task(ws_manager.start_kis_socket())

    notifier.send_message("🚀 KisStock AI 시스템이 시작되었습니다.")
|
|
|
|
|
|
@app.on_event("shutdown")
def shutdown_event():
    """Send the shutdown notification, then stop the trader and the news bot."""
    notifier.send_message("🛑 KisStock AI 시스템이 종료됩니다.")
    # Stop order matches start-up order: trader first, then the news bot.
    for service in (trader, news_bot):
        service.stop()
|
|
|
|
|
|
# --- Pages ---
|
|
@app.get("/")
async def read_index():
    """Serve ``index.html`` from the frontend directory."""
    page_path = os.path.join(FRONTEND_DIR, "index.html")
    return FileResponse(page_path)
|
|
|
|
@app.get("/news")
async def read_news():
    """Serve ``news.html`` from the frontend directory."""
    page_path = os.path.join(FRONTEND_DIR, "news.html")
    return FileResponse(page_path)
|
|
|
|
@app.get("/stocks")
async def read_stocks():
    """Serve ``stocks.html`` from the frontend directory."""
    page_path = os.path.join(FRONTEND_DIR, "stocks.html")
    return FileResponse(page_path)
|
|
|
|
@app.get("/settings")
async def read_settings():
    """Serve ``settings.html`` from the frontend directory."""
    page_path = os.path.join(FRONTEND_DIR, "settings.html")
    return FileResponse(page_path)
|
|
|
|
@app.get("/trade")
async def read_trade():
    """Serve ``trade.html`` from the frontend directory."""
    page_path = os.path.join(FRONTEND_DIR, "trade.html")
    return FileResponse(page_path)
|
|
|
|
# --- API ---
|
|
from master_loader import master_loader
|
|
from database import Watchlist
|
|
|
|
@app.get("/api/sync/status")
def get_sync_status():
    """Return the value reported by ``master_loader.get_status()``."""
    sync_status = master_loader.get_status()
    return sync_status
|
|
|
|
@app.post("/api/sync/master")
def sync_master_data(background_tasks: fastapi.BackgroundTasks):
    """Queue the domestic and overseas master-data loads and return at once.

    Both loader methods are added to ``background_tasks`` (domestic first), so
    the response only confirms that the sync was scheduled.
    """
    loader_jobs = (
        master_loader.download_and_parse_domestic,
        master_loader.download_and_parse_overseas,
    )
    for job in loader_jobs:
        background_tasks.add_task(job)
    return {"status": "started", "message": "Master data sync started in background"}
|
|
|
|
@app.get("/api/stocks")
def search_stocks(keyword: str = "", market: str = "", page: int = 1, db: Session = Depends(get_db)):
    """Search the ``Stock`` table and return one page of matches.

    Args:
        keyword: When non-empty, keep stocks whose name or code contains it.
        market: When non-empty, keep stocks whose ``market`` equals it.
        page: 1-based page number; values below 1 are treated as page 1.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        ``{"items": [...]}`` with at most 50 stocks, ordered by stock code.
    """
    query = db.query(Stock)
    if keyword:
        query = query.filter(Stock.name.contains(keyword) | Stock.code.contains(keyword))
    if market:
        query = query.filter(Stock.market == market)

    limit = 50
    # Clamp the page so page <= 0 cannot produce a negative OFFSET.
    offset = (max(page, 1) - 1) * limit
    # LIMIT/OFFSET only yields consistent, non-overlapping pages when the
    # result set has a defined order.
    items = query.order_by(Stock.code).limit(limit).offset(offset).all()
    return {"items": items}
|
|
|
|
@app.get("/api/watchlist")
def get_watchlist(db: Session = Depends(get_db)):
    """Return every watchlist entry, newest (``created_at``) first."""
    newest_first = db.query(Watchlist).order_by(Watchlist.created_at.desc())
    return newest_first.all()
|
|
|
|
@app.post("/api/watchlist")
def add_watchlist(code: str = Body(...), name: str = Body(...), market: str = Body(...), db: Session = Depends(get_db)):
    """Add a stock to the watchlist unless its code is already present.

    Returns ``{"status": "exists"}`` when an entry with the same code is
    found, otherwise inserts a new row and returns ``{"status": "added"}``.
    """
    already_listed = db.query(Watchlist).filter(Watchlist.code == code).first()
    if already_listed:
        return {"status": "exists"}

    db.add(Watchlist(code=code, name=name, market=market))
    db.commit()
    return {"status": "added"}
|
|
|
|
@app.delete("/api/watchlist/{code}")
def delete_watchlist(code: str, db: Session = Depends(get_db)):
    """Delete the watchlist rows whose code matches and commit.

    The response is ``{"status": "deleted"}`` whether or not any row matched.
    """
    matching_rows = db.query(Watchlist).filter(Watchlist.code == code)
    matching_rows.delete()
    db.commit()
    return {"status": "deleted"}
|
|
|
|
@app.get("/api/settings")
def get_settings():
    """Return the configuration produced by ``load_config()``."""
    current_config = load_config()
    return current_config
|
|
|
|
@app.post("/api/settings")
def update_settings(settings: dict = Body(...)):
    """Pass the request body to ``save_config()`` and acknowledge.

    The body is forwarded as received; no validation happens in this handler.
    """
    save_config(settings)
    acknowledgement = {"status": "ok"}
    return acknowledgement
|
|
|
|
from database import AccountBalance, Holding
|
|
import datetime
|
|
|
|
@app.get("/api/balance")
def get_my_balance(source: str = "db", db: Session = Depends(get_db)):
    """Return the balance summary and domestic holdings stored in the DB.

    The response reuses KIS-style field names so the frontend needs few
    changes: ``output1`` lists the holdings whose market is ``DOMESTIC`` and
    ``output2`` holds at most one summary row (empty when no
    ``AccountBalance`` row exists). ``source`` is accepted but not read here.
    """
    # 1. Balance summary: zero or one row.
    account = db.query(AccountBalance).first()
    summary_rows = (
        [
            {
                "tot_evlu_amt": account.total_eval,
                "dnca_tot_amt": account.deposit,
                "evlu_pfls_smtl_amt": account.total_profit,
            }
        ]
        if account
        else []
    )

    # 2. Domestic holdings, one row per Holding record.
    domestic_holdings = db.query(Holding).filter(Holding.market == 'DOMESTIC').all()
    holding_rows = [
        {
            "pdno": item.code,
            "prdt_name": item.name,
            "hldg_qty": item.quantity,
            "prpr": item.current_price,
            "pchs_avg_pric": item.price,
            "evlu_pfls_rt": item.profit_rate,
        }
        for item in domestic_holdings
    ]

    return {
        "rt_cd": "0",
        "msg1": "Success from DB",
        "output1": holding_rows,
        "output2": summary_rows,
    }
|
|
|
|
@app.get("/api/balance/overseas")
def get_my_overseas_balance(db: Session = Depends(get_db)):
    """Return the overseas holdings stored in the DB under KIS-style keys.

    Only holdings whose market is ``NASD`` are selected.
    """
    # NOTE(review): the order endpoint also accepts NYSE and AMEX; confirm
    # whether holdings for those markets are stored under 'NASD' as well.
    nasd_holdings = db.query(Holding).filter(Holding.market == 'NASD').all()
    holding_rows = [
        {
            "ovrs_pdno": item.code,
            "ovrs_item_name": item.name,
            "ovrs_cblc_qty": item.quantity,
            "now_pric2": item.current_price,
            "frcr_pchs_amt1": item.price,
            "evlu_pfls_rt": item.profit_rate,
        }
        for item in nasd_holdings
    ]

    return {
        "rt_cd": "0",
        "output1": holding_rows,
    }
|
|
|
|
@app.post("/api/balance/refresh")
def force_refresh_balance(background_tasks: BackgroundTasks):
    """Refresh the account assets inline, then acknowledge.

    ``trader.refresh_assets()`` is called synchronously rather than queued on
    ``background_tasks`` (accepted but unused). The original author's stated
    reason is that the frontend would otherwise have to poll, and that the
    KIS API is assumed to answer quickly.
    """
    acknowledgement = {"status": "ok"}
    trader.refresh_assets()
    return acknowledgement
|
|
|
|
@app.get("/api/price/{code}")
def get_stock_price(code: str):
    """Return the current price for ``code`` as reported by ``kis``.

    Raises:
        HTTPException: 404 when ``kis.get_current_price`` returns a falsy value.
    """
    quote = kis.get_current_price(code)
    if not quote:
        raise HTTPException(status_code=404, detail="Stock info not found")
    return quote
|
|
|
|
@app.post("/api/order")
def place_order_api(
    code: str = Body(...),
    type: str = Body(...),
    qty: int = Body(...),
    price: int = Body(...),
    market: str = Body("DOMESTIC")
):
    """Place an order through ``kis`` and return its response.

    Orders for the NASD, NYSE and AMEX markets go to
    ``kis.place_overseas_order``; every other market value goes to
    ``kis.place_order``.

    Raises:
        HTTPException: 400 when the response is falsy or its ``rt_cd`` is
            not ``'0'``.
    """
    overseas_markets = ("NASD", "NYSE", "AMEX")
    result = (
        kis.place_overseas_order(code, type, qty, price, market)
        if market in overseas_markets
        else kis.place_order(code, type, qty, price)
    )

    if not (result and result.get('rt_cd') == '0'):
        raise HTTPException(status_code=400, detail=f"Order Failed: {result}")
    return result
|
|
|
|
@app.get("/api/orders")
def get_db_orders(db: Session = Depends(get_db)):
    """Return the 50 most recently created orders stored in the DB."""
    newest_first = db.query(Order).order_by(Order.created_at.desc())
    return newest_first.limit(50).all()
|
|
|
|
@app.get("/api/orders/daily")
def get_daily_orders_api():
    """Return the result of ``kis.get_daily_orders()``.

    Raises:
        HTTPException: 500 when the call returns a falsy value.
    """
    daily_orders = kis.get_daily_orders()
    if not daily_orders:
        raise HTTPException(status_code=500, detail="Failed to fetch daily orders")
    return daily_orders
|
|
|
|
@app.get("/api/orders/cancelable")
def get_cancelable_orders_api():
    """Return the result of ``kis.get_cancelable_orders()``.

    Raises:
        HTTPException: 500 when the call returns a falsy value.
    """
    cancelable = kis.get_cancelable_orders()
    # NOTE(review): if kis reports "no cancelable orders" as an empty
    # container, that is also treated as a failure here; confirm its contract.
    if not cancelable:
        raise HTTPException(status_code=500, detail="Failed to fetch cancelable orders")
    return cancelable
|
|
|
|
@app.post("/api/order/cancel")
def cancel_order_api(
    org_no: str = Body(...),
    order_no: str = Body(...),
    qty: int = Body(...),
    is_buy: bool = Body(...),
    price: int = Body(0)
):
    """Cancel an order through ``kis.cancel_order`` and return its response.

    Raises:
        HTTPException: 400 when the response is falsy or its ``rt_cd`` is
            not ``'0'``.
    """
    result = kis.cancel_order(org_no, order_no, qty, is_buy, price)
    if not (result and result.get('rt_cd') == '0'):
        raise HTTPException(status_code=400, detail=f"Cancel Failed: {result}")
    return result
|
|
|
|
|
|
@app.get("/api/news")
def get_news(db: Session = Depends(get_db)):
    """Return the 20 most recently created news rows stored in the DB."""
    newest_first = db.query(News).order_by(News.created_at.desc())
    return newest_first.limit(20).all()
|
|
|
|
@app.get("/api/trade_settings")
def get_trade_settings(db: Session = Depends(get_db)):
    """Return every ``TradeSetting`` row stored in the DB."""
    all_settings = db.query(TradeSetting).all()
    return all_settings
|
|
|
|
@app.post("/api/trade_settings")
def set_trade_setting(
    code: str = Body(...),
    target_price: Optional[float] = Body(None),
    stop_loss_price: Optional[float] = Body(None),
    is_active: bool = Body(True),
    db: Session = Depends(get_db)
):
    """Create or update the trade setting for ``code``.

    A price passed as ``None`` leaves the stored value untouched, so this
    endpoint cannot clear a price. ``is_active`` is always overwritten.
    """
    record = db.query(TradeSetting).filter(TradeSetting.code == code).first()
    if not record:
        # No setting for this code yet: create one and track it in the session.
        record = TradeSetting(code=code)
        db.add(record)

    # Apply only the prices the caller actually supplied.
    supplied_prices = {
        "target_price": target_price,
        "stop_loss_price": stop_loss_price,
    }
    for field_name, value in supplied_prices.items():
        if value is not None:
            setattr(record, field_name, value)
    record.is_active = is_active

    db.commit()
    return {"status": "ok"}
|
|
|
|
if __name__ == "__main__":
    # Run the server directly when this file is executed as a script:
    # all interfaces, port 8000, auto-reload enabled.
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("main:app", **server_options)
|