initial commit
This commit is contained in:
313
backend/main.py
Normal file
313
backend/main.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import os
|
||||
import uvicorn
|
||||
import threading
|
||||
import asyncio
|
||||
from fastapi import FastAPI, HTTPException, Depends, Body, Header, BackgroundTasks
|
||||
import fastapi
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
|
||||
from database import init_db, get_db, TradeSetting, Order, News, Stock, Watchlist
|
||||
from config import load_config, save_config, get_kis_config
|
||||
from kis_api import kis
|
||||
from trader import trader
|
||||
from news_ai import news_bot
|
||||
from telegram_notifier import notifier
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("MAIN")
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Get absolute path to the directory where main.py is located
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
# Root project directory (one level up from backend)
|
||||
PROJECT_ROOT = os.path.dirname(BASE_DIR)
|
||||
# Frontend directory
|
||||
FRONTEND_DIR = os.path.join(PROJECT_ROOT, "frontend")
|
||||
|
||||
# Create frontend directory if it doesn't exist (for safety)
|
||||
if not os.path.exists(FRONTEND_DIR):
|
||||
os.makedirs(FRONTEND_DIR)
|
||||
|
||||
# Mount static files
|
||||
app.mount("/static", StaticFiles(directory=FRONTEND_DIR), name="static")
|
||||
|
||||
from websocket_manager import ws_manager
|
||||
from fastapi import WebSocket
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await ws_manager.connect_frontend(websocket)
|
||||
try:
|
||||
while True:
|
||||
# Keep alive or handle frontend formatting
|
||||
data = await websocket.receive_text()
|
||||
# If frontend sends "subscribe:CODE", we could forward to KIS
|
||||
if data.startswith("sub:"):
|
||||
code = data.split(":")[1]
|
||||
await ws_manager.subscribe_stock(code)
|
||||
except:
|
||||
ws_manager.disconnect_frontend(websocket)
|
||||
|
||||
@app.on_event("startup")
|
||||
def startup_event():
|
||||
init_db()
|
||||
# Start Background Threads
|
||||
trader.start()
|
||||
news_bot.start()
|
||||
|
||||
# Start KIS WebSocket Loop
|
||||
# We need a way to run async loop in bg.
|
||||
# Uvicorn runs in asyncio loop. We can create task.
|
||||
asyncio.create_task(ws_manager.start_kis_socket())
|
||||
|
||||
notifier.send_message("🚀 KisStock AI 시스템이 시작되었습니다.")
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
def shutdown_event():
|
||||
notifier.send_message("🛑 KisStock AI 시스템이 종료됩니다.")
|
||||
trader.stop()
|
||||
news_bot.stop()
|
||||
|
||||
|
||||
# --- Pages ---
|
||||
@app.get("/")
|
||||
async def read_index():
|
||||
return FileResponse(os.path.join(FRONTEND_DIR, "index.html"))
|
||||
|
||||
@app.get("/news")
|
||||
async def read_news():
|
||||
return FileResponse(os.path.join(FRONTEND_DIR, "news.html"))
|
||||
|
||||
@app.get("/stocks")
|
||||
async def read_stocks():
|
||||
return FileResponse(os.path.join(FRONTEND_DIR, "stocks.html"))
|
||||
|
||||
@app.get("/settings")
|
||||
async def read_settings():
|
||||
return FileResponse(os.path.join(FRONTEND_DIR, "settings.html"))
|
||||
|
||||
@app.get("/trade")
|
||||
async def read_trade():
|
||||
return FileResponse(os.path.join(FRONTEND_DIR, "trade.html"))
|
||||
|
||||
# --- API ---
|
||||
from master_loader import master_loader
|
||||
from database import Watchlist
|
||||
|
||||
@app.get("/api/sync/status")
|
||||
def get_sync_status():
|
||||
return master_loader.get_status()
|
||||
|
||||
@app.post("/api/sync/master")
|
||||
def sync_master_data(background_tasks: fastapi.BackgroundTasks):
|
||||
# Run in background
|
||||
background_tasks.add_task(master_loader.download_and_parse_domestic)
|
||||
background_tasks.add_task(master_loader.download_and_parse_overseas)
|
||||
return {"status": "started", "message": "Master data sync started in background"}
|
||||
|
||||
@app.get("/api/stocks")
|
||||
def search_stocks(keyword: str = "", market: str = "", page: int = 1, db: Session = Depends(get_db)):
|
||||
query = db.query(Stock)
|
||||
if keyword:
|
||||
query = query.filter(Stock.name.contains(keyword) | Stock.code.contains(keyword))
|
||||
if market:
|
||||
query = query.filter(Stock.market == market)
|
||||
|
||||
limit = 50
|
||||
offset = (page - 1) * limit
|
||||
items = query.limit(limit).offset(offset).all()
|
||||
return {"items": items}
|
||||
|
||||
@app.get("/api/watchlist")
|
||||
def get_watchlist(db: Session = Depends(get_db)):
|
||||
return db.query(Watchlist).order_by(Watchlist.created_at.desc()).all()
|
||||
|
||||
@app.post("/api/watchlist")
|
||||
def add_watchlist(code: str = Body(...), name: str = Body(...), market: str = Body(...), db: Session = Depends(get_db)):
|
||||
exists = db.query(Watchlist).filter(Watchlist.code == code).first()
|
||||
if exists: return {"status": "exists"}
|
||||
item = Watchlist(code=code, name=name, market=market)
|
||||
db.add(item)
|
||||
db.commit()
|
||||
return {"status": "added"}
|
||||
|
||||
@app.delete("/api/watchlist/{code}")
|
||||
def delete_watchlist(code: str, db: Session = Depends(get_db)):
|
||||
db.query(Watchlist).filter(Watchlist.code == code).delete()
|
||||
db.commit()
|
||||
return {"status": "deleted"}
|
||||
|
||||
@app.get("/api/settings")
|
||||
def get_settings():
|
||||
return load_config()
|
||||
|
||||
@app.post("/api/settings")
|
||||
def update_settings(settings: dict = Body(...)):
|
||||
save_config(settings)
|
||||
return {"status": "ok"}
|
||||
|
||||
from database import AccountBalance, Holding
|
||||
import datetime
|
||||
|
||||
@app.get("/api/balance")
|
||||
def get_my_balance(source: str = "db", db: Session = Depends(get_db)):
|
||||
"""
|
||||
Return persisted balance and holdings from DB.
|
||||
Structure similar to KIS API to minimize frontend changes, or simplified.
|
||||
"""
|
||||
# 1. Balance Summary
|
||||
acc = db.query(AccountBalance).first()
|
||||
output2 = []
|
||||
if acc:
|
||||
output2.append({
|
||||
"tot_evlu_amt": acc.total_eval,
|
||||
"dnca_tot_amt": acc.deposit,
|
||||
"evlu_pfls_smtl_amt": acc.total_profit
|
||||
})
|
||||
|
||||
# 2. Holdings (Domestic)
|
||||
holdings = db.query(Holding).filter(Holding.market == 'DOMESTIC').all()
|
||||
output1 = []
|
||||
for h in holdings:
|
||||
output1.append({
|
||||
"pdno": h.code,
|
||||
"prdt_name": h.name,
|
||||
"hldg_qty": h.quantity,
|
||||
"prpr": h.current_price,
|
||||
"pchs_avg_pric": h.price,
|
||||
"evlu_pfls_rt": h.profit_rate
|
||||
})
|
||||
|
||||
return {
|
||||
"rt_cd": "0",
|
||||
"msg1": "Success from DB",
|
||||
"output1": output1,
|
||||
"output2": output2
|
||||
}
|
||||
|
||||
@app.get("/api/balance/overseas")
|
||||
def get_my_overseas_balance(db: Session = Depends(get_db)):
|
||||
# Persisted Overseas Holdings
|
||||
holdings = db.query(Holding).filter(Holding.market == 'NASD').all()
|
||||
output1 = []
|
||||
for h in holdings:
|
||||
output1.append({
|
||||
"ovrs_pdno": h.code,
|
||||
"ovrs_item_name": h.name,
|
||||
"ovrs_cblc_qty": h.quantity,
|
||||
"now_pric2": h.current_price,
|
||||
"frcr_pchs_amt1": h.price,
|
||||
"evlu_pfls_rt": h.profit_rate
|
||||
})
|
||||
|
||||
return {
|
||||
"rt_cd": "0",
|
||||
"output1": output1
|
||||
}
|
||||
|
||||
@app.post("/api/balance/refresh")
|
||||
def force_refresh_balance(background_tasks: BackgroundTasks):
|
||||
trader.refresh_assets() # Run synchronously to return fresh data immediately? Or BG?
|
||||
# User perception: "Loading..." -> Show data.
|
||||
# If we run BG, frontend needs to poll.
|
||||
# Let's run Sync for "Refresh" button (unless it takes too long).
|
||||
# KIS API is reasonably fast (milliseconds).
|
||||
# trader.refresh_assets()
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/api/price/{code}")
|
||||
def get_stock_price(code: str):
|
||||
price = kis.get_current_price(code)
|
||||
if price:
|
||||
return price
|
||||
raise HTTPException(status_code=404, detail="Stock info not found")
|
||||
|
||||
@app.post("/api/order")
|
||||
def place_order_api(
|
||||
code: str = Body(...),
|
||||
type: str = Body(...),
|
||||
qty: int = Body(...),
|
||||
price: int = Body(...),
|
||||
market: str = Body("DOMESTIC")
|
||||
):
|
||||
if market in ["NASD", "NYSE", "AMEX"]:
|
||||
res = kis.place_overseas_order(code, type, qty, price, market)
|
||||
else:
|
||||
res = kis.place_order(code, type, qty, price)
|
||||
|
||||
if res and res.get('rt_cd') == '0':
|
||||
return res
|
||||
raise HTTPException(status_code=400, detail=f"Order Failed: {res}")
|
||||
|
||||
@app.get("/api/orders")
|
||||
def get_db_orders(db: Session = Depends(get_db)):
|
||||
orders = db.query(Order).order_by(Order.created_at.desc()).limit(50).all()
|
||||
return orders
|
||||
|
||||
@app.get("/api/orders/daily")
|
||||
def get_daily_orders_api():
|
||||
res = kis.get_daily_orders()
|
||||
if res:
|
||||
return res
|
||||
raise HTTPException(status_code=500, detail="Failed to fetch daily orders")
|
||||
|
||||
@app.get("/api/orders/cancelable")
|
||||
def get_cancelable_orders_api():
|
||||
res = kis.get_cancelable_orders()
|
||||
if res:
|
||||
return res
|
||||
raise HTTPException(status_code=500, detail="Failed to fetch cancelable orders")
|
||||
|
||||
@app.post("/api/order/cancel")
|
||||
def cancel_order_api(
|
||||
org_no: str = Body(...),
|
||||
order_no: str = Body(...),
|
||||
qty: int = Body(...),
|
||||
is_buy: bool = Body(...),
|
||||
price: int = Body(0)
|
||||
):
|
||||
res = kis.cancel_order(org_no, order_no, qty, is_buy, price)
|
||||
if res and res.get('rt_cd') == '0':
|
||||
return res
|
||||
raise HTTPException(status_code=400, detail=f"Cancel Failed: {res}")
|
||||
|
||||
|
||||
@app.get("/api/news")
|
||||
def get_news(db: Session = Depends(get_db)):
|
||||
news = db.query(News).order_by(News.created_at.desc()).limit(20).all()
|
||||
return news
|
||||
|
||||
@app.get("/api/trade_settings")
|
||||
def get_trade_settings(db: Session = Depends(get_db)):
|
||||
return db.query(TradeSetting).all()
|
||||
|
||||
@app.post("/api/trade_settings")
|
||||
def set_trade_setting(
|
||||
code: str = Body(...),
|
||||
target_price: Optional[float] = Body(None),
|
||||
stop_loss_price: Optional[float] = Body(None),
|
||||
is_active: bool = Body(True),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
setting = db.query(TradeSetting).filter(TradeSetting.code == code).first()
|
||||
if not setting:
|
||||
setting = TradeSetting(code=code)
|
||||
db.add(setting)
|
||||
|
||||
if target_price is not None:
|
||||
setting.target_price = target_price
|
||||
if stop_loss_price is not None:
|
||||
setting.stop_loss_price = stop_loss_price
|
||||
setting.is_active = is_active
|
||||
|
||||
db.commit()
|
||||
return {"status": "ok"}
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
||||
Reference in New Issue
Block a user