initial commit

This commit is contained in:
2026-02-04 00:16:34 +09:00
commit ae11528dd9
867 changed files with 209640 additions and 0 deletions

18
backend/config.py Normal file
View File

@@ -0,0 +1,18 @@
import os
import yaml
CONFIG_FILE = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "settings.yaml")
def load_config():
if not os.path.exists(CONFIG_FILE):
return {}
with open(CONFIG_FILE, 'r', encoding='utf-8') as f:
return yaml.safe_load(f)
def save_config(config_data):
with open(CONFIG_FILE, 'w', encoding='utf-8') as f:
yaml.dump(config_data, f, allow_unicode=True)
def get_kis_config():
config = load_config()
return config.get('kis', {})

5285
backend/data/nasdaq.txt Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

112
backend/database.py Normal file
View File

@@ -0,0 +1,112 @@
from sqlalchemy import create_engine, Column, Integer, String, Float, DateTime, Boolean, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
import datetime
import os
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DB_URL = f"sqlite:///{os.path.join(BASE_DIR, 'kis_stock.db')}"
engine = create_engine(DB_URL, connect_args={"check_same_thread": False})
# Enable WAL mode for better concurrency
from sqlalchemy import event
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA journal_mode=WAL")
cursor.close()
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
class Stock(Base):
__tablename__ = "stocks"
code = Column(String, primary_key=True, index=True)
name = Column(String, index=True)
name_eng = Column(String, nullable=True) # English Name
market = Column(String) # KOSPI, KOSDAQ, NASD, NYSE, AMEX
sector = Column(String, nullable=True)
industry = Column(String, nullable=True) # Detailed Industry
type = Column(String, default="DOMESTIC") # DOMESTIC, OVERSEAS
financial_status = Column(String, nullable=True) # 'N', 'D', 'E' etc from Nasdaq
is_etf = Column(Boolean, default=False)
current_price = Column(Float, default=0.0)
class Watchlist(Base):
__tablename__ = "watchlist"
id = Column(Integer, primary_key=True, index=True)
code = Column(String, index=True)
name = Column(String) # Cache name for display
market = Column(String)
is_monitoring = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.datetime.now)
class Order(Base):
__tablename__ = "orders"
id = Column(Integer, primary_key=True, index=True)
order_id = Column(String, nullable=True) # KIS Order ID
code = Column(String, index=True)
type = Column(String) # BUY, SELL
price = Column(Float)
quantity = Column(Integer)
status = Column(String) # PENDING, FILLED, CANCELLED
created_at = Column(DateTime, default=datetime.datetime.now)
class TradeSetting(Base):
__tablename__ = "trade_settings"
code = Column(String, primary_key=True)
target_price = Column(Float, nullable=True)
stop_loss_price = Column(Float, nullable=True)
trailing_stop_percent = Column(Float, nullable=True)
is_active = Column(Boolean, default=False)
class News(Base):
__tablename__ = "news"
id = Column(Integer, primary_key=True, index=True)
title = Column(String)
link = Column(String, unique=True)
pub_date = Column(String)
analysis_result = Column(Text)
impact_score = Column(Integer)
related_sector = Column(String)
created_at = Column(DateTime, default=datetime.datetime.now)
class StockPrice(Base):
__tablename__ = "stock_prices"
id = Column(Integer, primary_key=True, index=True)
code = Column(String, index=True)
price = Column(Float)
change = Column(Float)
volume = Column(Integer)
created_at = Column(DateTime, default=datetime.datetime.now)
class AccountBalance(Base):
__tablename__ = "account_balance"
id = Column(Integer, primary_key=True)
total_eval = Column(Float, default=0.0) # 총평가금액
deposit = Column(Float, default=0.0) # 예수금
total_profit = Column(Float, default=0.0) # 평가손익
updated_at = Column(DateTime, default=datetime.datetime.now)
class Holding(Base):
__tablename__ = "holdings"
id = Column(Integer, primary_key=True, index=True)
code = Column(String, index=True)
name = Column(String)
quantity = Column(Integer)
price = Column(Float) # 매입평단가
current_price = Column(Float) # 현재가
profit_rate = Column(Float)
market = Column(String) # DOMESTIC, NASD, etc.
updated_at = Column(DateTime, default=datetime.datetime.now)
def init_db():
Base.metadata.create_all(bind=engine)
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()

379
backend/kis_api.py Normal file
View File

@@ -0,0 +1,379 @@
import requests
import json
import datetime
import os
import time
import copy
from config import get_kis_config
import logging
# Basic Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("KIS_API")
class KisApi:
def __init__(self):
self.config = get_kis_config()
self.app_key = self.config.get("app_key")
self.app_secret = self.config.get("app_secret")
self.account_no = str(self.config.get("account_no", "")).replace("-", "").strip() # 8 digits
self.account_prod = str(self.config.get("account_prod", "01")).strip() # 2 digits
self.is_paper = self.config.get("is_paper", True)
self.htsid = self.config.get("htsid", "")
logger.info(f"Initialized KIS API: Account={self.account_no}, Prod={self.account_prod}, Paper={self.is_paper}")
if self.is_paper:
self.base_url = "https://openapivts.koreainvestment.com:29443"
else:
self.base_url = "https://openapi.koreainvestment.com:9443"
self.token_file = "kis_token.tmp"
self.access_token = None
self.token_expired = None
self.last_req_time = 0
self._auth()
def _auth(self):
# Clean outdated token file if needed or load existing
if os.path.exists(self.token_file):
with open(self.token_file, 'r') as f:
data = json.load(f)
expired = datetime.datetime.strptime(data['expired'], "%Y-%m-%d %H:%M:%S")
if expired > datetime.datetime.now():
self.access_token = data['token']
self.token_expired = expired
logger.info("Loaded credentials from cache.")
return
# Request new token
url = f"{self.base_url}/oauth2/tokenP"
headers = {"content-type": "application/json"}
body = {
"grant_type": "client_credentials",
"appkey": self.app_key,
"appsecret": self.app_secret
}
res = requests.post(url, headers=headers, data=json.dumps(body))
if res.status_code == 200:
data = res.json()
self.access_token = data['access_token']
self.token_expired = datetime.datetime.strptime(data['access_token_token_expired'], "%Y-%m-%d %H:%M:%S")
# Save to file
with open(self.token_file, 'w') as f:
json.dump({
"token": self.access_token,
"expired": data['access_token_token_expired']
}, f)
logger.info("Issued new access token.")
else:
logger.error(f"Auth Failed: {res.text}")
raise Exception("Authentication Failed")
def get_websocket_key(self):
"""
Get Approval Key for WebSocket
"""
url = f"{self.base_url}/oauth2/Approval"
headers = {"content-type": "application/json"}
body = {
"grant_type": "client_credentials",
"appkey": self.app_key,
"secretkey": self.app_secret
}
res = requests.post(url, headers=headers, data=json.dumps(body))
if res.status_code == 200:
return res.json().get("approval_key")
else:
logger.error(f"WS Key Failed: {res.text}")
return None
def _get_header(self, tr_id=None):
header = {
"Content-Type": "application/json",
"authorization": f"Bearer {self.access_token}",
"appkey": self.app_key,
"appsecret": self.app_secret,
"tr_id": tr_id,
"custtype": "P"
}
return header
def _request(self, method, path, tr_id=None, x_tr_id_buy=None, x_tr_id_sell=None, **kwargs):
"""
Centralized request handler with auto-token refresh logic and 500ms throttling.
"""
# Throttling
now = time.time()
diff = now - self.last_req_time
if diff < 0.5:
time.sleep(0.5 - diff)
self.last_req_time = time.time()
url = f"{self.base_url}{path}"
# Determine TR ID
curr_tr_id = tr_id
if x_tr_id_buy and x_tr_id_sell:
pass
# Prepare headers
headers = self._get_header(curr_tr_id)
# Execute Request
try:
if method.upper() == "GET":
res = requests.get(url, headers=headers, **kwargs)
else:
res = requests.post(url, headers=headers, **kwargs)
# Check for Token Expiration (EGW00123)
if res.status_code == 200:
data = res.json()
if isinstance(data, dict):
msg_cd = data.get('msg_cd', '')
if msg_cd == 'EGW00123': # Expired Token
logger.warning("Token expired (EGW00123). Refreshing token and retrying...")
# Remove token file
if os.path.exists(self.token_file):
os.remove(self.token_file)
# Re-auth
self._auth()
# Update headers with new token
headers = self._get_header(curr_tr_id)
# Retry
if method.upper() == "GET":
res = requests.get(url, headers=headers, **kwargs)
else:
res = requests.post(url, headers=headers, **kwargs)
return res.json()
return data
else:
logger.error(f"API Request Failed [{res.status_code}]: {res.text}")
return None
except Exception as e:
logger.error(f"Request Exception: {e}")
return None
def get_current_price(self, code):
"""
inquire-price
"""
tr_id = "FHKST01010100"
params = {
"FID_COND_MRKT_DIV_CODE": "J",
"FID_INPUT_ISCD": code
}
res = self._request("GET", "/uapi/domestic-stock/v1/quotations/inquire-price", tr_id=tr_id, params=params)
return res.get('output', {}) if res else None
def get_balance(self, source=None):
"""
inquire-balance
Cached for 2 seconds to prevent rate limit (EGW00201)
"""
# Check cache
now = time.time()
if hasattr(self, '_balance_cache') and self._balance_cache:
last_time, data = self._balance_cache
if now - last_time < 2.0: # 2 Seconds TTL
return data
log_source = f" [Source: {source}]" if source else ""
logger.info(f"get_balance{log_source}: Account={self.account_no}, Paper={self.is_paper}")
tr_id = "VTTC8434R" if self.is_paper else "TTTC8434R"
params = {
"CANO": self.account_no,
"ACNT_PRDT_CD": self.account_prod,
"AFHR_FLPR_YN": "N",
"OFL_YN": "",
"INQR_DVSN": "02", # 02: By Stock
"UNPR_DVSN": "01",
"FUND_STTL_ICLD_YN": "N",
"FNCG_AMT_AUTO_RDPT_YN": "N",
"PRCS_DVSN": "00",
"CTX_AREA_FK100": "",
"CTX_AREA_NK100": ""
}
res = self._request("GET", "/uapi/domestic-stock/v1/trading/inquire-balance", tr_id=tr_id, params=params)
# Update Cache if success
if res and res.get('rt_cd') == '0':
self._balance_cache = (now, res)
return res
def get_overseas_balance(self, exchange="NASD"):
"""
overseas-stock inquire-balance
"""
# For overseas, we need to know the exchange code often, but inquire-balance might return all if configured?
# Typically requires an exchange code in some params or TR IDs specific to exchange?
# Looking at docs: TTTS3012R is for "Overseas Stock Balance"
tr_id = "VTTS3012R" if self.is_paper else "TTTS3012R"
# Paper env TR ID is tricky, usually V... but let's assume VTTS3012R or JTTT3012R?
# Common pattern: Real 'T...' -> Paper 'V...'
params = {
"CANO": self.account_no,
"ACNT_PRDT_CD": self.account_prod,
"OVRS_EXCG_CD": exchange, # NASD, NYSE, AMEX, HKS, TSE, etc.
"TR_CRCY_CD": "USD", # Transaction Currency
"CTX_AREA_FK200": "",
"CTX_AREA_NK200": ""
}
return self._request("GET", "/uapi/overseas-stock/v1/trading/inquire-balance", tr_id=tr_id, params=params)
def place_order(self, code, type, qty, price):
"""
order-cash
type: 'buy' or 'sell'
"""
if self.is_paper:
tr_id = "VTTC0012U" if type == 'buy' else "VTTC0011U"
else:
tr_id = "TTTC0012U" if type == 'buy' else "TTTC0011U"
# 00: Limit (Specified Price), 01: Market Price
ord_dvsn = "00" if int(price) > 0 else "01"
body = {
"CANO": self.account_no,
"ACNT_PRDT_CD": self.account_prod,
"PDNO": code,
"ORD_DVSN": ord_dvsn,
"ORD_QTY": str(qty),
"ORD_UNPR": str(price),
"EXCG_ID_DVSN_CD": "KRX", # KRX for Exchange
"SLL_TYPE": "01", # 01: Normal Sell
"CNDT_PRIC": ""
}
return self._request("POST", "/uapi/domestic-stock/v1/trading/order-cash", tr_id=tr_id, data=json.dumps(body))
def place_overseas_order(self, code, type, qty, price, market="NASD"):
"""
overseas-stock order
"""
# Checks for Paper vs Real and Buy vs Sell
# TR_ID might vary by country. Assuming US (NASD, NYSE, AMEX).
if self.is_paper:
tr_id = "VTTT1002U" if type == 'buy' else "VTTT1006U"
else:
# US Real: JTTT1002U (Buy), JTTT1006U (Sell)
# Note: This TR ID is for US Night Market (Main).
tr_id = "JTTT1002U" if type == 'buy' else "JTTT1006U"
# Price '0' or empty is usually Market Price, but overseas api often requires specific handling.
# US Market Price Order usually uses ord_dvsn="00" and price="0" or empty?
# KIS Docs: US Limit="00", Market="00"?? No, usually "00" is Limit.
# "32": LOO, "33": LOC, "34": MOO, "35": MOC...
# Let's stick to Limit ("00") for now. If price is 0, user might mean Market, but US API requires price for limit.
# If user sends 0, let's try to assume "00" (Limit) with price 0 (might fail) or valid Market ("01"? No).
# Safe bet: US Market Order is "00" (Limit) with Price 0? No.
# Use "00" (Limit) as default. If price is 0, we can't easily do Market on US via standard "01" like domestic.
ord_dvsn = "00" # Limit
body = {
"CANO": self.account_no,
"ACNT_PRDT_CD": self.account_prod,
"OVRS_EXCG_CD": market,
"PDNO": code,
"ORD_QTY": str(qty),
"OVRS_ORD_UNPR": str(price),
"ORD_SVR_DVSN_CD": "0",
"ORD_DVSN": ord_dvsn
}
return self._request("POST", "/uapi/overseas-stock/v1/trading/order", tr_id=tr_id, data=json.dumps(body))
def get_daily_orders(self, start_date=None, end_date=None, expanded_code=None):
"""
inquire-daily-ccld
"""
if self.is_paper:
tr_id = "VTTC0081R" # 3-month inner
else:
tr_id = "TTTC0081R" # 3-month inner
if not start_date:
start_date = datetime.datetime.now().strftime("%Y%m%d")
if not end_date:
end_date = datetime.datetime.now().strftime("%Y%m%d")
params = {
"CANO": self.account_no,
"ACNT_PRDT_CD": self.account_prod,
"INQR_STRT_DT": start_date,
"INQR_END_DT": end_date,
"SLL_BUY_DVSN_CD": "00", # All
"PDNO": "",
"CCLD_DVSN": "00", # All (Executed + Unexecuted)
"INQR_DVSN": "00", # Reverse Order
"INQR_DVSN_3": "00", # All
"ORD_GNO_BRNO": "",
"ODNO": "",
"INQR_DVSN_1": "",
"CTX_AREA_FK100": "",
"CTX_AREA_NK100": ""
}
return self._request("GET", "/uapi/domestic-stock/v1/trading/inquire-daily-ccld", tr_id=tr_id, params=params)
def get_cancelable_orders(self):
"""
inquire-psbl-rvsecncl
"""
tr_id = "VTTC0084R" if self.is_paper else "TTTC0084R"
params = {
"CANO": self.account_no,
"ACNT_PRDT_CD": self.account_prod,
"INQR_DVSN_1": "0", # 0: Order No order
"INQR_DVSN_2": "0", # 0: All
"CTX_AREA_FK100": "",
"CTX_AREA_NK100": ""
}
return self._request("GET", "/uapi/domestic-stock/v1/trading/inquire-psbl-rvsecncl", tr_id=tr_id, params=params)
def cancel_order(self, org_no, order_no, qty, is_buy, price="0", total=True):
"""
order-rvsecncl
"""
if self.is_paper:
tr_id = "VTTC0013U"
else:
tr_id = "TTTC0013U"
rvse_cncl_dvsn_cd = "02"
qty_all = "Y" if total else "N"
body = {
"CANO": self.account_no,
"ACNT_PRDT_CD": self.account_prod,
"KRX_FWDG_ORD_ORGNO": org_no,
"ORGN_ODNO": order_no,
"ORD_DVSN": "00",
"RVSE_CNCL_DVSN_CD": rvse_cncl_dvsn_cd,
"ORD_QTY": str(qty),
"ORD_UNPR": str(price),
"QTY_ALL_ORD_YN": qty_all,
"EXCG_ID_DVSN_CD": "KRX"
}
return self._request("POST", "/uapi/domestic-stock/v1/trading/order-rvsecncl", tr_id=tr_id, data=json.dumps(body))
# Singleton Instance
kis = KisApi()

313
backend/main.py Normal file
View File

@@ -0,0 +1,313 @@
import os
import uvicorn
import threading
import asyncio
from fastapi import FastAPI, HTTPException, Depends, Body, Header, BackgroundTasks
import fastapi
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse
from sqlalchemy.orm import Session
from typing import List, Optional
from database import init_db, get_db, TradeSetting, Order, News, Stock, Watchlist
from config import load_config, save_config, get_kis_config
from kis_api import kis
from trader import trader
from news_ai import news_bot
from telegram_notifier import notifier
import logging
logger = logging.getLogger("MAIN")
app = FastAPI()
# Get absolute path to the directory where main.py is located
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Root project directory (one level up from backend)
PROJECT_ROOT = os.path.dirname(BASE_DIR)
# Frontend directory
FRONTEND_DIR = os.path.join(PROJECT_ROOT, "frontend")
# Create frontend directory if it doesn't exist (for safety)
if not os.path.exists(FRONTEND_DIR):
os.makedirs(FRONTEND_DIR)
# Mount static files
app.mount("/static", StaticFiles(directory=FRONTEND_DIR), name="static")
from websocket_manager import ws_manager
from fastapi import WebSocket
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await ws_manager.connect_frontend(websocket)
try:
while True:
# Keep alive or handle frontend formatting
data = await websocket.receive_text()
# If frontend sends "subscribe:CODE", we could forward to KIS
if data.startswith("sub:"):
code = data.split(":")[1]
await ws_manager.subscribe_stock(code)
except:
ws_manager.disconnect_frontend(websocket)
@app.on_event("startup")
def startup_event():
init_db()
# Start Background Threads
trader.start()
news_bot.start()
# Start KIS WebSocket Loop
# We need a way to run async loop in bg.
# Uvicorn runs in asyncio loop. We can create task.
asyncio.create_task(ws_manager.start_kis_socket())
notifier.send_message("🚀 KisStock AI 시스템이 시작되었습니다.")
@app.on_event("shutdown")
def shutdown_event():
notifier.send_message("🛑 KisStock AI 시스템이 종료됩니다.")
trader.stop()
news_bot.stop()
# --- Pages ---
@app.get("/")
async def read_index():
return FileResponse(os.path.join(FRONTEND_DIR, "index.html"))
@app.get("/news")
async def read_news():
return FileResponse(os.path.join(FRONTEND_DIR, "news.html"))
@app.get("/stocks")
async def read_stocks():
return FileResponse(os.path.join(FRONTEND_DIR, "stocks.html"))
@app.get("/settings")
async def read_settings():
return FileResponse(os.path.join(FRONTEND_DIR, "settings.html"))
@app.get("/trade")
async def read_trade():
return FileResponse(os.path.join(FRONTEND_DIR, "trade.html"))
# --- API ---
from master_loader import master_loader
from database import Watchlist
@app.get("/api/sync/status")
def get_sync_status():
return master_loader.get_status()
@app.post("/api/sync/master")
def sync_master_data(background_tasks: fastapi.BackgroundTasks):
# Run in background
background_tasks.add_task(master_loader.download_and_parse_domestic)
background_tasks.add_task(master_loader.download_and_parse_overseas)
return {"status": "started", "message": "Master data sync started in background"}
@app.get("/api/stocks")
def search_stocks(keyword: str = "", market: str = "", page: int = 1, db: Session = Depends(get_db)):
query = db.query(Stock)
if keyword:
query = query.filter(Stock.name.contains(keyword) | Stock.code.contains(keyword))
if market:
query = query.filter(Stock.market == market)
limit = 50
offset = (page - 1) * limit
items = query.limit(limit).offset(offset).all()
return {"items": items}
@app.get("/api/watchlist")
def get_watchlist(db: Session = Depends(get_db)):
return db.query(Watchlist).order_by(Watchlist.created_at.desc()).all()
@app.post("/api/watchlist")
def add_watchlist(code: str = Body(...), name: str = Body(...), market: str = Body(...), db: Session = Depends(get_db)):
exists = db.query(Watchlist).filter(Watchlist.code == code).first()
if exists: return {"status": "exists"}
item = Watchlist(code=code, name=name, market=market)
db.add(item)
db.commit()
return {"status": "added"}
@app.delete("/api/watchlist/{code}")
def delete_watchlist(code: str, db: Session = Depends(get_db)):
db.query(Watchlist).filter(Watchlist.code == code).delete()
db.commit()
return {"status": "deleted"}
@app.get("/api/settings")
def get_settings():
return load_config()
@app.post("/api/settings")
def update_settings(settings: dict = Body(...)):
save_config(settings)
return {"status": "ok"}
from database import AccountBalance, Holding
import datetime
@app.get("/api/balance")
def get_my_balance(source: str = "db", db: Session = Depends(get_db)):
"""
Return persisted balance and holdings from DB.
Structure similar to KIS API to minimize frontend changes, or simplified.
"""
# 1. Balance Summary
acc = db.query(AccountBalance).first()
output2 = []
if acc:
output2.append({
"tot_evlu_amt": acc.total_eval,
"dnca_tot_amt": acc.deposit,
"evlu_pfls_smtl_amt": acc.total_profit
})
# 2. Holdings (Domestic)
holdings = db.query(Holding).filter(Holding.market == 'DOMESTIC').all()
output1 = []
for h in holdings:
output1.append({
"pdno": h.code,
"prdt_name": h.name,
"hldg_qty": h.quantity,
"prpr": h.current_price,
"pchs_avg_pric": h.price,
"evlu_pfls_rt": h.profit_rate
})
return {
"rt_cd": "0",
"msg1": "Success from DB",
"output1": output1,
"output2": output2
}
@app.get("/api/balance/overseas")
def get_my_overseas_balance(db: Session = Depends(get_db)):
# Persisted Overseas Holdings
holdings = db.query(Holding).filter(Holding.market == 'NASD').all()
output1 = []
for h in holdings:
output1.append({
"ovrs_pdno": h.code,
"ovrs_item_name": h.name,
"ovrs_cblc_qty": h.quantity,
"now_pric2": h.current_price,
"frcr_pchs_amt1": h.price,
"evlu_pfls_rt": h.profit_rate
})
return {
"rt_cd": "0",
"output1": output1
}
@app.post("/api/balance/refresh")
def force_refresh_balance(background_tasks: BackgroundTasks):
trader.refresh_assets() # Run synchronously to return fresh data immediately? Or BG?
# User perception: "Loading..." -> Show data.
# If we run BG, frontend needs to poll.
# Let's run Sync for "Refresh" button (unless it takes too long).
# KIS API is reasonably fast (milliseconds).
# trader.refresh_assets()
return {"status": "ok"}
@app.get("/api/price/{code}")
def get_stock_price(code: str):
price = kis.get_current_price(code)
if price:
return price
raise HTTPException(status_code=404, detail="Stock info not found")
@app.post("/api/order")
def place_order_api(
code: str = Body(...),
type: str = Body(...),
qty: int = Body(...),
price: int = Body(...),
market: str = Body("DOMESTIC")
):
if market in ["NASD", "NYSE", "AMEX"]:
res = kis.place_overseas_order(code, type, qty, price, market)
else:
res = kis.place_order(code, type, qty, price)
if res and res.get('rt_cd') == '0':
return res
raise HTTPException(status_code=400, detail=f"Order Failed: {res}")
@app.get("/api/orders")
def get_db_orders(db: Session = Depends(get_db)):
orders = db.query(Order).order_by(Order.created_at.desc()).limit(50).all()
return orders
@app.get("/api/orders/daily")
def get_daily_orders_api():
res = kis.get_daily_orders()
if res:
return res
raise HTTPException(status_code=500, detail="Failed to fetch daily orders")
@app.get("/api/orders/cancelable")
def get_cancelable_orders_api():
res = kis.get_cancelable_orders()
if res:
return res
raise HTTPException(status_code=500, detail="Failed to fetch cancelable orders")
@app.post("/api/order/cancel")
def cancel_order_api(
org_no: str = Body(...),
order_no: str = Body(...),
qty: int = Body(...),
is_buy: bool = Body(...),
price: int = Body(0)
):
res = kis.cancel_order(org_no, order_no, qty, is_buy, price)
if res and res.get('rt_cd') == '0':
return res
raise HTTPException(status_code=400, detail=f"Cancel Failed: {res}")
@app.get("/api/news")
def get_news(db: Session = Depends(get_db)):
news = db.query(News).order_by(News.created_at.desc()).limit(20).all()
return news
@app.get("/api/trade_settings")
def get_trade_settings(db: Session = Depends(get_db)):
return db.query(TradeSetting).all()
@app.post("/api/trade_settings")
def set_trade_setting(
code: str = Body(...),
target_price: Optional[float] = Body(None),
stop_loss_price: Optional[float] = Body(None),
is_active: bool = Body(True),
db: Session = Depends(get_db)
):
setting = db.query(TradeSetting).filter(TradeSetting.code == code).first()
if not setting:
setting = TradeSetting(code=code)
db.add(setting)
if target_price is not None:
setting.target_price = target_price
if stop_loss_price is not None:
setting.stop_loss_price = stop_loss_price
setting.is_active = is_active
db.commit()
return {"status": "ok"}
if __name__ == "__main__":
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)

253
backend/master_loader.py Normal file
View File

@@ -0,0 +1,253 @@
import os
import requests
import zipfile
import io
import pandas as pd
from database import SessionLocal, Stock, engine
from sqlalchemy.orm import Session
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("MASTER_LOADER")
class MasterLoader:
def __init__(self):
self.base_dir = os.path.dirname(os.path.abspath(__file__))
self.tmp_dir = os.path.join(self.base_dir, "tmp_master")
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)
self.sync_status = {"status": "idle", "message": ""}
def get_status(self):
return self.sync_status
def _set_status(self, status, message):
self.sync_status = {"status": status, "message": message}
logger.info(f"Sync Status: {status} - {message}")
def download_and_parse_domestic(self):
self._set_status("running", "Downloading Domestic Master...")
urls = {
"kospi": "https://new.real.download.dws.co.kr/common/master/kospi_code.mst.zip",
"kosdaq": "https://new.real.download.dws.co.kr/common/master/kosdaq_code.mst.zip"
}
db = SessionLocal()
try:
for market, url in urls.items():
logger.info(f"Downloading {market} master data from {url}...")
try:
res = requests.get(url)
if res.status_code != 200:
logger.error(f"Failed to download {market} master")
self._set_status("error", f"Failed to download {market}")
continue
with zipfile.ZipFile(io.BytesIO(res.content)) as z:
filename = f"{market}_code.mst"
z.extract(filename, self.tmp_dir)
file_path = os.path.join(self.tmp_dir, filename)
self._parse_domestic_file(file_path, market.upper(), db)
except Exception as e:
logger.error(f"Error processing {market}: {e}")
self._set_status("error", f"Error processing {market}: {e}")
db.commit()
if self.sync_status['status'] != 'error':
self._set_status("running", "Domestic Sync Complete")
finally:
db.close()
def _parse_domestic_file(self, file_path, market_name, db: Session):
with open(file_path, 'r', encoding='cp949') as f:
lines = f.readlines()
logger.info(f"Parsing {len(lines)} lines for {market_name}...")
batch = []
for line in lines:
code = line[0:9].strip()
name = line[21:61].strip()
if not code or not name:
continue
batch.append({
"code": code,
"name": name,
"market": market_name,
"type": "DOMESTIC"
})
if len(batch) >= 1000:
self._upsert_batch(db, batch)
batch = []
if batch:
self._upsert_batch(db, batch)
def download_and_parse_overseas(self):
if self.sync_status['status'] == 'error': return
self._set_status("running", "Downloading Overseas Master...")
# NASDAQ from text file
urls = {
"NASD": "https://www.nasdaqtrader.com/dynamic/symdir/nasdaqlisted.txt",
# "NYSE": "https://new.real.download.dws.co.kr/common/master/usa_nys.mst.zip",
# "AMEX": "https://new.real.download.dws.co.kr/common/master/usa_ams.mst.zip"
}
db = SessionLocal()
error_count = 0
try:
for market, url in urls.items():
logger.info(f"Downloading {market} master data from {url}...")
try:
res = requests.get(url)
logger.info(f"HTTP Status: {res.status_code}")
if res.status_code != 200:
logger.error(f"Download failed for {market}. Status: {res.status_code}")
error_count += 1
continue
if url.endswith('.txt'):
self._parse_nasdaq_txt(res.text, market, db)
else:
with zipfile.ZipFile(io.BytesIO(res.content)) as z:
target_file = None
for f in z.namelist():
if f.endswith(".mst"):
target_file = f
break
if target_file:
z.extract(target_file, self.tmp_dir)
file_path = os.path.join(self.tmp_dir, target_file)
self._parse_overseas_file(file_path, market, db)
except Exception as e:
logger.error(f"Error processing {market}: {e}")
error_count += 1
db.commit()
if error_count == len(urls):
self._set_status("error", "All overseas downloads failed.")
elif error_count > 0:
self._set_status("warning", f"Overseas Sync Partial ({error_count} failed).")
else:
self._set_status("done", "All Sync Complete.")
finally:
db.close()
def _parse_nasdaq_txt(self, content, market_name, db: Session):
# Format: Symbol|Security Name|Market Category|Test Issue|Financial Status|Round Lot Size|ETF|NextShares
lines = content.splitlines()
logger.info(f"Parsing {len(lines)} lines for {market_name} (TXT)...")
batch = []
parsed_count = 0
for line in lines:
try:
if not line or line.startswith('Symbol|') or line.startswith('File Creation Time'):
continue
parts = line.split('|')
if len(parts) < 7: continue
symbol = parts[0]
name = parts[1]
# market_category = parts[2]
financial_status = parts[4] # N=Normal, D=Deficient, E=Delinquent, Q=Bankrupt, G=Deficient and Bankrupt
etf_flag = parts[6] # Y/N
is_etf = (etf_flag == 'Y')
batch.append({
"code": symbol,
"name": name,
"name_eng": name,
"market": market_name,
"type": "OVERSEAS",
"financial_status": financial_status,
"is_etf": is_etf
})
if len(batch) >= 1000:
self._upsert_batch(db, batch)
parsed_count += len(batch)
batch = []
except Exception as e:
# logger.error(f"Parse error: {e}")
continue
if batch:
self._upsert_batch(db, batch)
parsed_count += len(batch)
logger.info(f"Parsed and Upserted {parsed_count} items for {market_name}")
def _parse_overseas_file(self, file_path, market_name, db: Session):
with open(file_path, 'r', encoding='cp949', errors='ignore') as f:
lines = f.readlines()
logger.info(f"Parsing {len(lines)} lines for {market_name}... (File: {os.path.basename(file_path)})")
batch = []
parsed_count = 0
for line in lines:
try:
b_line = line.encode('cp949')
symbol = b_line[0:16].decode('cp949').strip()
name_eng = b_line[16:80].decode('cp949').strip()
if not symbol: continue
batch.append({
"code": symbol,
"name": name_eng,
"name_eng": name_eng,
"market": market_name,
"type": "OVERSEAS"
})
if len(batch) >= 1000:
self._upsert_batch(db, batch)
parsed_count += len(batch)
batch = []
except Exception as e:
continue
if batch:
self._upsert_batch(db, batch)
parsed_count += len(batch)
logger.info(f"Parsed and Upserted {parsed_count} items for {market_name}")
def _upsert_batch(self, db: Session, batch):
for item in batch:
existing = db.query(Stock).filter(Stock.code == item['code']).first()
if existing:
existing.name = item['name']
existing.market = item['market']
existing.type = item['type']
if 'name_eng' in item: existing.name_eng = item['name_eng']
if 'financial_status' in item: existing.financial_status = item['financial_status']
if 'is_etf' in item: existing.is_etf = item['is_etf']
else:
stock = Stock(**item)
db.add(stock)
db.commit()
master_loader = MasterLoader()
if __name__ == "__main__":
print("Starting sync...")
master_loader.download_and_parse_domestic()
print("Domestic Done. Starting Overseas...")
master_loader.download_and_parse_overseas()
print("Sync Complete.")

37
backend/migrate_db.py Normal file
View File

@@ -0,0 +1,37 @@
import sqlite3
import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DB_PATH = os.path.join(os.path.dirname(BASE_DIR), 'kis_stock.db')
def migrate():
print(f"Migrating database at {DB_PATH}...")
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
# Check if columns exist, if not add them
try:
cursor.execute("ALTER TABLE stocks ADD COLUMN name_eng VARCHAR")
print("Added name_eng")
except Exception as e:
print(f"Skipping name_eng: {e}")
try:
cursor.execute("ALTER TABLE stocks ADD COLUMN industry VARCHAR")
print("Added industry")
except Exception as e:
print(f"Skipping industry: {e}")
try:
cursor.execute("ALTER TABLE stocks ADD COLUMN type VARCHAR DEFAULT 'DOMESTIC'")
print("Added type")
except Exception as e:
print(f"Skipping type: {e}")
conn.commit()
conn.close()
print("Migration complete.")
if __name__ == "__main__":
migrate()

57
backend/migrate_db_v2.py Normal file
View File

@@ -0,0 +1,57 @@
import sqlite3
import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DB_PATH = os.path.join(os.path.dirname(BASE_DIR), 'kis_stock.db')
def migrate():
print(f"Migrating database v2 at {DB_PATH}...")
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
# Financial Status (Health)
try:
cursor.execute("ALTER TABLE stocks ADD COLUMN financial_status VARCHAR")
print("Added financial_status")
except Exception as e:
print(f"Skipping financial_status: {e}")
# ETF Facet
try:
cursor.execute("ALTER TABLE stocks ADD COLUMN is_etf BOOLEAN DEFAULT 0")
print("Added is_etf")
except Exception as e:
print(f"Skipping is_etf: {e}")
# Current Price (Cache)
try:
cursor.execute("ALTER TABLE stocks ADD COLUMN current_price FLOAT DEFAULT 0")
print("Added current_price")
except Exception as e:
print(f"Skipping current_price: {e}")
# Stock Price History Table
try:
cursor.execute("""
CREATE TABLE IF NOT EXISTS stock_prices (
id INTEGER PRIMARY KEY AUTOINCREMENT,
code VARCHAR NOT NULL,
price FLOAT,
change FLOAT,
volume INTEGER,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
""")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_stock_prices_code ON stock_prices (code)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_stock_prices_created_at ON stock_prices (created_at)")
print("Created stock_prices table")
except Exception as e:
print(f"Error creating stock_prices: {e}")
conn.commit()
conn.close()
print("Migration v2 complete.")
if __name__ == "__main__":
migrate()

18
backend/migrate_db_v3.py Normal file
View File

@@ -0,0 +1,18 @@
from sqlalchemy import create_engine, text
import os
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DB_URL = f"sqlite:///{os.path.join(BASE_DIR, 'kis_stock.db')}"
engine = create_engine(DB_URL)
def migrate():
with engine.connect() as conn:
try:
conn.execute(text("ALTER TABLE watchlist ADD COLUMN is_monitoring BOOLEAN DEFAULT 1"))
print("Added is_monitoring column to watchlist")
except Exception as e:
print(f"Column might already exist: {e}")
if __name__ == "__main__":
migrate()

147
backend/news_ai.py Normal file
View File

@@ -0,0 +1,147 @@
import time
import threading
import logging
import requests
import json
from sqlalchemy.orm import Session
from database import SessionLocal, News, Stock
from config import get_kis_config, load_config
logger = logging.getLogger("NEWS_AI")
class NewsBot:
def __init__(self):
self.is_running = False
self.thread = None
self.config = load_config()
self.naver_id = self.config.get('naver', {}).get('client_id', '')
self.naver_secret = self.config.get('naver', {}).get('client_secret', '')
self.google_key = self.config.get('google', {}).get('api_key', '')
def start(self):
if self.is_running:
return
self.is_running = True
self.thread = threading.Thread(target=self._run_loop, daemon=True)
self.thread.start()
logger.info("News Bot Started")
def stop(self):
self.is_running = False
if self.thread:
self.thread.join()
logger.info("News Bot Stopped")
def _run_loop(self):
while self.is_running:
try:
# Reload config to check current settings
self.config = load_config()
if self.config.get('preferences', {}).get('enable_news', False):
self._fetch_and_analyze()
else:
logger.info("News collection is disabled.")
except Exception as e:
logger.error(f"Error in news loop: {e}")
# Sleep 10 minutes (600 seconds)
for _ in range(600):
if not self.is_running: break
time.sleep(1)
def _fetch_and_analyze(self):
logger.info("Fetching News...")
if not self.naver_id or not self.naver_secret:
logger.warning("Naver API Credentials missing.")
return
# 1. Fetch News (Naver)
# Search for generic economy terms or specific watchlist
query = "주식 시장" # General Stock Market
url = "https://openapi.naver.com/v1/search/news.json"
headers = {
"X-Naver-Client-Id": self.naver_id,
"X-Naver-Client-Secret": self.naver_secret
}
params = {"query": query, "display": 10, "sort": "date"}
res = requests.get(url, headers=headers, params=params)
if res.status_code != 200:
logger.error(f"Naver News Failed: {res.text}")
return
items = res.json().get('items', [])
db = SessionLocal()
try:
for item in items:
title = item['title']
link = item['originallink'] or item['link']
pub_date = item['pubDate']
# Check duplication
if db.query(News).filter(News.link == link).first():
continue
# 2. AI Analysis (Google Gemini)
analysis = self._analyze_with_ai(title, item['description'])
# Save to DB
news = News(
title=title,
link=link,
pub_date=pub_date,
analysis_result=analysis.get('summary', ''),
impact_score=analysis.get('score', 0),
related_sector=analysis.get('sector', '')
)
db.add(news)
db.commit()
logger.info(f"Processed {len(items)} news items.")
finally:
db.close()
def _analyze_with_ai(self, title, description):
if not self.google_key:
return {"summary": "No API Key", "score": 0, "sector": ""}
logger.info(f"Analyzing: {title[:30]}...")
# Prompt
prompt = f"""
Analyze the following news for stock market impact.
Title: {title}
Description: {description}
Return JSON format:
{{
"summary": "One line summary of impact",
"score": Integer between -10 (Negative) to 10 (Positive),
"sector": "Related Industry/Sector or 'None'"
}}
"""
url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key={self.google_key}"
headers = {"Content-Type": "application/json"}
body = {
"contents": [{
"parts": [{"text": prompt}]
}]
}
try:
res = requests.post(url, headers=headers, data=json.dumps(body))
if res.status_code == 200:
result = res.json()
text = result['candidates'][0]['content']['parts'][0]['text']
# Clean markdown json if any
text = text.replace("```json", "").replace("```", "").strip()
return json.loads(text)
else:
logger.error(f"Gemini API Error: {res.text}")
except Exception as e:
logger.error(f"AI Analysis Exception: {e}")
return {"summary": "Error", "score": 0, "sector": ""}
news_bot = NewsBot()

View File

@@ -0,0 +1,43 @@
import requests
import logging
from config import load_config
logger = logging.getLogger("TELEGRAM")
class TelegramNotifier:
def __init__(self):
self.config = load_config()
self.bot_token = self.config.get('telegram', {}).get('bot_token', '')
self.chat_id = self.config.get('telegram', {}).get('chat_id', '')
def reload_config(self):
self.config = load_config()
self.bot_token = self.config.get('telegram', {}).get('bot_token', '')
self.chat_id = self.config.get('telegram', {}).get('chat_id', '')
def send_message(self, text):
# Reload to ensure we have latest from settings
self.reload_config()
# Check if enabled
if not self.config.get('preferences', {}).get('enable_telegram', True):
return
if not self.bot_token or not self.chat_id:
logger.warning("Telegram credentials missing.")
return
url = f"https://api.telegram.org/bot{self.bot_token}/sendMessage"
payload = {
"chat_id": self.chat_id,
"text": text
}
try:
res = requests.post(url, json=payload, timeout=5)
if res.status_code != 200:
logger.error(f"Telegram Send Failed: {res.text}")
except Exception as e:
logger.error(f"Telegram Error: {e}")
notifier = TelegramNotifier()

View File

@@ -0,0 +1,29 @@
import sys
import os
# Ensure current dir is in path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from database import init_db, SessionLocal, AccountBalance, Holding, engine
from trader import trader
import logging
logging.basicConfig(level=logging.INFO)
def test_db_migration():
print("Initializing DB...")
init_db()
# Check tables
from sqlalchemy import inspect
inspector = inspect(engine)
tables = inspector.get_table_names()
print(f"Tables: {tables}")
if "account_balance" in tables and "holdings" in tables:
print("PASS: New tables created.")
else:
print("FAIL: Tables missing.")
if __name__ == "__main__":
test_db_migration()

View File

@@ -0,0 +1,43 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from telegram_notifier import notifier
from config import save_config, load_config
def test_telegram_toggle():
print("Testing Telegram Toggle...")
original_config = load_config()
try:
# 1. Enable
print("1. Testing ENABLED...")
cfg = load_config()
if 'preferences' not in cfg: cfg['preferences'] = {}
cfg['preferences']['enable_telegram'] = True
save_config(cfg)
# We can't easily mock requests.post here without importing mock,
# but we can check if it attempts to read credentials.
# Ideally, we'd check if it returns early.
# For this environment, let's just ensure no crash.
notifier.send_message("Test Message (Should Send)")
# 2. Disable
print("2. Testing DISABLED...")
cfg['preferences']['enable_telegram'] = False
save_config(cfg)
# This should return early and NOT log "Telegram credentials missing" if implemented right.
notifier.send_message("Test Message (Should NOT Send)")
print("Toggle logic executed without error.")
finally:
# Restore
save_config(original_config)
print("Original config restored.")
if __name__ == "__main__":
test_telegram_toggle()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

214
backend/trader.py Normal file
View File

@@ -0,0 +1,214 @@
import time
import threading
import logging
from sqlalchemy.orm import Session
from database import SessionLocal, TradeSetting, Order, Stock, AccountBalance, Holding
import datetime
from kis_api import kis
from telegram_notifier import notifier
logger = logging.getLogger("TRADER")
class TradingBot:
def __init__(self):
self.is_running = False
self.thread = None
self.holdings = {} # Local cache for holdings: {code: {qty: int, price: float}}
self.last_chart_update = 0
def refresh_assets(self):
"""
Fetch Balance and Holdings from KIS and save to DB
"""
logger.info("Syncing Assets to Database...")
db = SessionLocal()
try:
# 1. Domestic Balance
balance = kis.get_balance(source="Automated_Sync")
if balance and 'output2' in balance and balance['output2']:
summary = balance['output2'][0]
# Upsert AccountBalance
fn_status = db.query(AccountBalance).first()
if not fn_status:
fn_status = AccountBalance()
db.add(fn_status)
fn_status.total_eval = float(summary['tot_evlu_amt'])
fn_status.deposit = float(summary['dnca_tot_amt'])
fn_status.total_profit = float(summary['evlu_pfls_smtl_amt'])
fn_status.updated_at = datetime.datetime.now()
# 2. Holdings (Domestic)
# Clear existing DOMESTIC
db.query(Holding).filter(Holding.market == 'DOMESTIC').delete()
if balance and 'output1' in balance:
self.holdings = {} # Keep memory cache for trading logic
for item in balance['output1']:
code = item['pdno']
qty = int(item['hldg_qty'])
if qty > 0:
buy_price = float(item['pchs_avg_pric'])
current_price = float(item['prpr'])
profit_rate = float(item['evlu_pfls_rt'])
# Save to DB
db.add(Holding(
code=code,
name=item['prdt_name'],
quantity=qty,
price=buy_price,
current_price=current_price,
profit_rate=profit_rate,
market="DOMESTIC"
))
# Memory Cache for Trade Logic
self.holdings[code] = {'qty': qty, 'price': buy_price}
# 3. Overseas Balance (NASD default)
# TODO: Multi-market support if needed
overseas = kis.get_overseas_balance(exchange="NASD")
# Clear existing NASD
db.query(Holding).filter(Holding.market == 'NASD').delete()
if overseas and 'output1' in overseas:
for item in overseas['output1']:
qty = float(item['ovrs_cblc_qty']) # Overseas can be fractional? KIS is usually int but check.
if qty > 0:
code = item['ovrs_pdno']
# name = item.get('ovrs_item_name') or item.get('prdt_name')
# KIS overseas output keys vary.
db.add(Holding(
code=code,
name=item.get('ovrs_item_name', code),
quantity=int(qty),
price=float(item.get('frcr_pchs_amt1', 0)), # Avg Price? Check API
current_price=float(item.get('now_pric2', 0)),
profit_rate=float(item.get('evlu_pfls_rt', 0)),
market="NASD"
))
db.commit()
logger.info("Assets Synced Successfully.")
except Exception as e:
logger.error(f"Failed to sync assets: {e}")
db.rollback()
finally:
db.close()
def start(self):
if self.is_running:
return
self.is_running = True
self.refresh_assets() # Fetch on start
self.thread = threading.Thread(target=self._run_loop, daemon=True)
self.thread.start()
logger.info("Trading Bot Started")
def stop(self):
self.is_running = False
if self.thread:
self.thread.join()
logger.info("Trading Bot Stopped")
def _run_loop(self):
while self.is_running:
try:
self._process_cycle()
except Exception as e:
logger.error(f"Error in trading loop: {e}")
# Sleep 1 second to avoid hammering
time.sleep(1)
def _process_cycle(self):
db = SessionLocal()
try:
# Get active trade settings
settings = db.query(TradeSetting).filter(TradeSetting.is_active == True).all()
for setting in settings:
self._check_and_trade(db, setting)
finally:
db.close()
def _check_and_trade(self, db: Session, setting: TradeSetting):
code = setting.code
# Get Current Price
# Optimization: Ideally read from a shared cache from WebSocket
# For now, we still poll price or should use WS logic?
# User said "Websocket... automatic decision".
# But trader.py is isolated.
# For simplicity in this step (removing balance poll), we keep price fetch but remove balance poll.
price_data = kis.get_current_price(code)
if not price_data:
return
current_price = float(price_data.get('stck_prpr', 0))
if current_price == 0:
return
# Check holdings from Cache
if code not in self.holdings:
return # No holdings, nothing to sell (if logic is Sell)
holding = self.holdings[code]
holding_qty = holding['qty']
# SELL Logic
if holding_qty > 0:
# Stop Loss
if setting.stop_loss_price and current_price <= setting.stop_loss_price:
logger.info(f"Stop Loss Triggered for {code}. Price: {current_price}, SL: {setting.stop_loss_price}")
self._place_order(db, code, 'sell', holding_qty, 0) # 0 means Market Price
return
# Target Profit
if setting.target_price and current_price >= setting.target_price:
logger.info(f"Target Price Triggered for {code}. Price: {current_price}, TP: {setting.target_price}")
self._place_order(db, code, 'sell', holding_qty, 0)
return
def _place_order(self, db: Session, code: str, type: str, qty: int, price: int):
logger.info(f"Placing Order: {code} {type} {qty} @ {price}")
res = kis.place_order(code, type, qty, price)
status = "FAILED"
order_id = ""
if res and res.get('rt_cd') == '0':
status = "PENDING"
order_id = res.get('output', {}).get('ODNO', '')
logger.info(f"Order Success: {order_id}")
notifier.send_message(f"🔔 주문 전송 완료\n[{type.upper()}] {code}\n수량: {qty}\n가격: {price if price > 0 else '시장가'}")
# Optimistic Update or Refresh?
# User said "If execution happens, update list".
# We should schedule a refresh.
time.sleep(1) # Wait for execution
self.refresh_assets()
else:
logger.error(f"Order Failed: {res}")
notifier.send_message(f"⚠️ 주문 실패\n[{type.upper()}] {code}\n이유: {res}")
# Record to DB
new_order = Order(
code=code,
order_id=order_id,
type=type.upper(),
price=price,
quantity=qty,
status=status
)
db.add(new_order)
db.commit()
trader = TradingBot()

View File

@@ -0,0 +1,182 @@
import asyncio
import websockets
import json
import logging
import datetime
from typing import List, Set
from database import SessionLocal, Stock, StockPrice
from kis_api import kis
logger = logging.getLogger("WEBSOCKET")
class KisWebSocketManager:
def __init__(self):
self.active_frontend_connections: List[any] = []
self.subscribed_codes: Set[str] = set()
self.running = False
self.approval_key = None
self.msg_queue = asyncio.Queue() # For outgoing subscription requests
# KIS Environment
if kis.is_paper:
self.url = "ws://ops.koreainvestment.com:31000"
else:
self.url = "ws://ops.koreainvestment.com:21000"
# ... (connect/disconnect/broadcast remains same)
# ... (connect/disconnect/broadcast remains same)
# ... (_handle_realtime_data remains same)
async def subscribe_stock(self, code):
if code in self.subscribed_codes:
return
self.subscribed_codes.add(code)
await self.msg_queue.put(code)
logger.info(f"Queued Subscription for {code}")
async def connect_frontend(self, websocket):
await websocket.accept()
self.active_frontend_connections.append(websocket)
logger.info(f"Frontend Client Connected. Total: {len(self.active_frontend_connections)}")
def disconnect_frontend(self, websocket):
if websocket in self.active_frontend_connections:
self.active_frontend_connections.remove(websocket)
logger.info("Frontend Client Disconnected")
async def broadcast_to_frontend(self, message: dict):
# Broadcast to all connected frontend clients
for connection in self.active_frontend_connections:
try:
await connection.send_json(message)
except Exception as e:
logger.error(f"Broadcast error: {e}")
self.disconnect_frontend(connection)
async def start_kis_socket(self):
self.running = True
logger.info(f"Starting KIS WebSocket Service... Target: {self.url}")
while self.running:
# 1. Ensure Approval Key
if not self.approval_key:
self.approval_key = kis.get_websocket_key()
if not self.approval_key:
logger.error("Failed to get WebSocket Approval Key. Retrying in 10s...")
await asyncio.sleep(10)
continue
logger.info(f"Got WS Key: {self.approval_key[:10]}...")
# 2. Connect
try:
# KIS doesn't use standard ping frames often, handle manually or disable auto-ping
async with websockets.connect(self.url, ping_interval=None, open_timeout=20) as ws:
logger.info("Connected to KIS WebSocket Server")
# Process initial subscriptions
for code in self.subscribed_codes:
await self._send_subscription(ws, code)
while self.running:
try:
# 1. Check for incoming data
msg = await asyncio.wait_for(ws.recv(), timeout=0.1)
# PING/PONG (String starting with 0 or 1 usually means data)
if msg[0] in ['0', '1']:
await self._handle_realtime_data(msg)
else:
# JSON Message (System, PINGPONG)
try:
data = json.loads(msg)
if data.get('header', {}).get('tr_id') == 'PINGPONG':
await ws.send(msg) # Echo back
continue
except:
pass
except asyncio.TimeoutError:
pass
except websockets.ConnectionClosed:
logger.warning("KIS WS Closed. Reconnecting...")
break
except Exception as e:
logger.error(f"WS Connection Error: {e}")
# If auth failed (maybe expired key?), clear key to force refresh
# simplified check: if "Approval key" error in exception message?
# For now just retry.
await asyncio.sleep(5)
async def _handle_realtime_data(self, msg: str):
# Format: 0|TR_ID|DATA_CNT|Code^Time^Price...
try:
parts = msg.split('|')
if len(parts) < 4: return
tr_id = parts[1]
data_part = parts[3]
if tr_id == "H0STCNT0": # Domestic Stock Price
# Data format: Code^Time^CurrentPrice^Sign^Change...
# Actually, data_part is delimiter separated.
values = data_part.split('^')
code = values[0]
price = values[2]
change = values[4]
rate = values[5]
# Broadcast
payload = {
"type": "PRICE",
"code": code,
"price": price,
"change": change,
"rate": rate,
"timestamp": datetime.datetime.now().isoformat()
}
await self.broadcast_to_frontend(payload)
# Update DB (Optional? Too frequent writes maybe bad)
# Let's save only significant updates or throttle?
# For now just log/broadcast.
except Exception as e:
logger.error(f"Data Parse Error: {e} | Msg: {msg[:50]}")
async def subscribe_stock(self, code):
if code in self.subscribed_codes:
return
self.subscribed_codes.add(code)
# If socket is active, send subscription (Implementation complexity: need access to active 'ws' object)
# Will handle by restarting connection or using a queue?
# Better: just set it in set, and the main loop will pick it up on reconnect,
# BUT for immediate sub, we need a way to signal the running loop.
# For MVP, let's assume we subscribe on startup or bulk.
# Real-time dynamic sub needs a queue.
logger.info(f"Subscribed to {code} (Pending next reconnect/sweep)")
async def _send_subscription(self, ws, code):
# Domestic Stock Realtime Price: H0STCNT0
body = {
"header": {
"approval_key": self.approval_key,
"custtype": "P",
"tr_type": "1", # 1: Register, 2: Unregister
"content-type": "utf-8"
},
"body": {
"input": {
"tr_id": "H0STCNT0",
"tr_key": code
}
}
}
await ws.send(json.dumps(body))
logger.info(f"Sent Subscription Request for {code}")
ws_manager = KisWebSocketManager()