initial commit
This commit is contained in:
18
backend/config.py
Normal file
18
backend/config.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import os
|
||||
import yaml
|
||||
|
||||
# settings.yaml lives in the project root (one directory above backend/).
CONFIG_FILE = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "settings.yaml")
|
||||
|
||||
def load_config():
    """Load settings.yaml and return it as a dict.

    Returns an empty dict when the file does not exist AND when the file
    exists but is empty or contains only comments — in that case
    ``yaml.safe_load`` returns ``None``, which previously leaked to callers
    (e.g. ``get_kis_config``) that immediately call ``.get`` on the result.
    """
    if not os.path.exists(CONFIG_FILE):
        return {}
    with open(CONFIG_FILE, 'r', encoding='utf-8') as f:
        # `or {}` guards against safe_load's None for an empty document.
        return yaml.safe_load(f) or {}
|
||||
|
||||
def save_config(config_data):
    """Write *config_data* to settings.yaml (UTF-8; non-ASCII kept readable)."""
    with open(CONFIG_FILE, 'w', encoding='utf-8') as out_file:
        yaml.dump(config_data, out_file, allow_unicode=True)
|
||||
|
||||
def get_kis_config():
    """Return the 'kis' section of the loaded config, or an empty dict."""
    return load_config().get('kis', {})
|
||||
5285
backend/data/nasdaq.txt
Normal file
5285
backend/data/nasdaq.txt
Normal file
File diff suppressed because it is too large
Load Diff
6979
backend/data/nasdaq_screener_1769693338554.csv
Normal file
6979
backend/data/nasdaq_screener_1769693338554.csv
Normal file
File diff suppressed because it is too large
Load Diff
112
backend/database.py
Normal file
112
backend/database.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from sqlalchemy import create_engine, Column, Integer, String, Float, DateTime, Boolean, Text
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
import datetime
|
||||
import os
|
||||
|
||||
# SQLite DB file lives in the project root (one level above backend/).
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DB_URL = f"sqlite:///{os.path.join(BASE_DIR, 'kis_stock.db')}"

# check_same_thread=False: the engine is shared across FastAPI worker threads.
engine = create_engine(DB_URL, connect_args={"check_same_thread": False})
|
||||
|
||||
# Enable WAL mode for better concurrency
|
||||
from sqlalchemy import event
|
||||
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
    """Switch every new SQLite connection to WAL journal mode.

    WAL allows concurrent readers while one writer is active, which matters
    because the trader thread and the web handlers share this database.
    """
    cursor = dbapi_connection.cursor()
    cursor.execute("PRAGMA journal_mode=WAL")
    cursor.close()
|
||||
|
||||
# Session factory and declarative base shared by all models below.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
|
||||
|
||||
class Stock(Base):
    """Master record for a tradable symbol (domestic or overseas)."""
    __tablename__ = "stocks"
    code = Column(String, primary_key=True, index=True)
    name = Column(String, index=True)
    name_eng = Column(String, nullable=True) # English Name
    market = Column(String) # KOSPI, KOSDAQ, NASD, NYSE, AMEX
    sector = Column(String, nullable=True)
    industry = Column(String, nullable=True) # Detailed Industry
    type = Column(String, default="DOMESTIC") # DOMESTIC, OVERSEAS
    financial_status = Column(String, nullable=True) # 'N', 'D', 'E' etc from Nasdaq
    is_etf = Column(Boolean, default=False)
    current_price = Column(Float, default=0.0)
|
||||
|
||||
class Watchlist(Base):
    """A stock the user wants monitored; `name`/`market` are display caches."""
    __tablename__ = "watchlist"
    id = Column(Integer, primary_key=True, index=True)
    code = Column(String, index=True)
    name = Column(String) # Cache name for display
    market = Column(String)
    is_monitoring = Column(Boolean, default=True)
    created_at = Column(DateTime, default=datetime.datetime.now)
|
||||
|
||||
class Order(Base):
    """Local record of an order placed through the KIS API."""
    __tablename__ = "orders"
    id = Column(Integer, primary_key=True, index=True)
    order_id = Column(String, nullable=True) # KIS Order ID
    code = Column(String, index=True)
    type = Column(String) # BUY, SELL
    price = Column(Float)
    quantity = Column(Integer)
    status = Column(String) # PENDING, FILLED, CANCELLED
    created_at = Column(DateTime, default=datetime.datetime.now)
|
||||
|
||||
class TradeSetting(Base):
    """Per-stock automated trading thresholds (one row per stock code)."""
    __tablename__ = "trade_settings"
    code = Column(String, primary_key=True)
    target_price = Column(Float, nullable=True)
    stop_loss_price = Column(Float, nullable=True)
    trailing_stop_percent = Column(Float, nullable=True)
    is_active = Column(Boolean, default=False)
|
||||
|
||||
class News(Base):
    """A news article plus its AI analysis; `link` is unique to dedupe feeds."""
    __tablename__ = "news"
    id = Column(Integer, primary_key=True, index=True)
    title = Column(String)
    link = Column(String, unique=True)
    pub_date = Column(String)
    analysis_result = Column(Text)
    impact_score = Column(Integer)
    related_sector = Column(String)
    created_at = Column(DateTime, default=datetime.datetime.now)
|
||||
|
||||
class StockPrice(Base):
    """Price snapshot history for a stock code."""
    __tablename__ = "stock_prices"
    id = Column(Integer, primary_key=True, index=True)
    code = Column(String, index=True)
    price = Column(Float)
    change = Column(Float)
    volume = Column(Integer)
    created_at = Column(DateTime, default=datetime.datetime.now)
|
||||
|
||||
class AccountBalance(Base):
    """Latest account-level balance summary persisted from KIS."""
    __tablename__ = "account_balance"
    id = Column(Integer, primary_key=True)
    total_eval = Column(Float, default=0.0) # total evaluation amount
    deposit = Column(Float, default=0.0) # cash deposit
    total_profit = Column(Float, default=0.0) # valuation profit/loss
    updated_at = Column(DateTime, default=datetime.datetime.now)
|
||||
|
||||
class Holding(Base):
    """A currently held position, refreshed from the KIS balance endpoints."""
    __tablename__ = "holdings"
    id = Column(Integer, primary_key=True, index=True)
    code = Column(String, index=True)
    name = Column(String)
    quantity = Column(Integer)
    price = Column(Float) # average purchase price
    current_price = Column(Float) # current market price
    profit_rate = Column(Float)
    market = Column(String) # DOMESTIC, NASD, etc.
    updated_at = Column(DateTime, default=datetime.datetime.now)
|
||||
|
||||
def init_db():
    """Create all tables defined on Base (no-op for tables that exist)."""
    Base.metadata.create_all(bind=engine)
|
||||
|
||||
def get_db():
    """FastAPI dependency: yield a session and always close it afterwards."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
|
||||
379
backend/kis_api.py
Normal file
379
backend/kis_api.py
Normal file
@@ -0,0 +1,379 @@
|
||||
import requests
|
||||
import json
|
||||
import datetime
|
||||
import os
|
||||
import time
|
||||
import copy
|
||||
from config import get_kis_config
|
||||
import logging
|
||||
|
||||
# Basic Logging — module-level logger for the KIS client.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("KIS_API")
|
||||
|
||||
class KisApi:
    """Thin client for the Korea Investment & Securities (KIS) Open API.

    Responsibilities:
      * OAuth token issuance with on-disk caching (`kis_token.tmp`) and
        automatic re-auth + retry when the server reports an expired token.
      * 500ms request throttling (KIS enforces per-second rate limits).
      * Domestic/overseas quote, balance and order endpoints.

    Real vs paper (mock) trading is selected via the `is_paper` config flag,
    which switches both the gateway host and the TR IDs.
    """

    def __init__(self):
        # Credentials and account identifiers come from the 'kis' section of settings.yaml.
        self.config = get_kis_config()
        self.app_key = self.config.get("app_key")
        self.app_secret = self.config.get("app_secret")
        self.account_no = str(self.config.get("account_no", "")).replace("-", "").strip() # 8 digits
        self.account_prod = str(self.config.get("account_prod", "01")).strip() # 2 digits
        self.is_paper = self.config.get("is_paper", True)
        self.htsid = self.config.get("htsid", "")

        logger.info(f"Initialized KIS API: Account={self.account_no}, Prod={self.account_prod}, Paper={self.is_paper}")

        # Paper (VTS) and real trading use different gateways.
        if self.is_paper:
            self.base_url = "https://openapivts.koreainvestment.com:29443"
        else:
            self.base_url = "https://openapi.koreainvestment.com:9443"

        # Token cache file is relative to the process CWD — see _auth().
        self.token_file = "kis_token.tmp"
        self.access_token = None
        self.token_expired = None
        self.last_req_time = 0  # wall-clock timestamp of the last request, for throttling

        # NOTE: performs a network call (token issuance) at construction time.
        self._auth()

    def _auth(self):
        """Obtain an access token, reusing the cached one when still valid.

        Raises a plain Exception when token issuance fails.
        NOTE(review): a corrupt/truncated token file would make json.load or
        strptime raise here — consider wrapping the cache read; verify.
        """
        # Clean outdated token file if needed or load existing
        if os.path.exists(self.token_file):
            with open(self.token_file, 'r') as f:
                data = json.load(f)
                expired = datetime.datetime.strptime(data['expired'], "%Y-%m-%d %H:%M:%S")
                if expired > datetime.datetime.now():
                    self.access_token = data['token']
                    self.token_expired = expired
                    logger.info("Loaded credentials from cache.")
                    return

        # Request new token
        url = f"{self.base_url}/oauth2/tokenP"
        headers = {"content-type": "application/json"}
        body = {
            "grant_type": "client_credentials",
            "appkey": self.app_key,
            "appsecret": self.app_secret
        }

        res = requests.post(url, headers=headers, data=json.dumps(body))
        if res.status_code == 200:
            data = res.json()
            self.access_token = data['access_token']
            # 'access_token_token_expired' is the (oddly named) KIS response field.
            self.token_expired = datetime.datetime.strptime(data['access_token_token_expired'], "%Y-%m-%d %H:%M:%S")

            # Save to file so restarts within the validity window reuse it.
            with open(self.token_file, 'w') as f:
                json.dump({
                    "token": self.access_token,
                    "expired": data['access_token_token_expired']
                }, f)
            logger.info("Issued new access token.")
        else:
            logger.error(f"Auth Failed: {res.text}")
            raise Exception("Authentication Failed")

    def get_websocket_key(self):
        """
        Get Approval Key for WebSocket.

        Returns the approval key string, or None on failure.
        """
        url = f"{self.base_url}/oauth2/Approval"
        headers = {"content-type": "application/json"}
        body = {
            "grant_type": "client_credentials",
            "appkey": self.app_key,
            # NOTE: this endpoint expects 'secretkey', not 'appsecret'.
            "secretkey": self.app_secret
        }
        res = requests.post(url, headers=headers, data=json.dumps(body))
        if res.status_code == 200:
            return res.json().get("approval_key")
        else:
            logger.error(f"WS Key Failed: {res.text}")
            return None

    def _get_header(self, tr_id=None):
        """Build the standard KIS request headers for the given TR ID."""
        header = {
            "Content-Type": "application/json",
            "authorization": f"Bearer {self.access_token}",
            "appkey": self.app_key,
            "appsecret": self.app_secret,
            "tr_id": tr_id,
            "custtype": "P"  # 'P' = personal (retail) customer
        }
        return header

    def _request(self, method, path, tr_id=None, x_tr_id_buy=None, x_tr_id_sell=None, **kwargs):
        """
        Centralized request handler with auto-token refresh logic and 500ms throttling.

        Returns the parsed JSON body (dict) on HTTP 200, otherwise None.
        Extra kwargs (params=/data=) are forwarded to requests.
        NOTE(review): x_tr_id_buy/x_tr_id_sell are accepted but never used —
        dead parameters kept for interface stability.
        """
        # Throttling: enforce at least 500ms between consecutive requests.
        now = time.time()
        diff = now - self.last_req_time
        if diff < 0.5:
            time.sleep(0.5 - diff)
        self.last_req_time = time.time()

        url = f"{self.base_url}{path}"

        # Determine TR ID
        curr_tr_id = tr_id
        if x_tr_id_buy and x_tr_id_sell:
            pass  # NOTE(review): dead branch — buy/sell TR selection never implemented

        # Prepare headers
        headers = self._get_header(curr_tr_id)

        # Execute Request
        try:
            if method.upper() == "GET":
                res = requests.get(url, headers=headers, **kwargs)
            else:
                res = requests.post(url, headers=headers, **kwargs)

            # Check for Token Expiration (EGW00123)
            if res.status_code == 200:
                data = res.json()
                if isinstance(data, dict):
                    msg_cd = data.get('msg_cd', '')
                    if msg_cd == 'EGW00123': # Expired Token
                        logger.warning("Token expired (EGW00123). Refreshing token and retrying...")

                        # Remove token file so _auth() issues a fresh token.
                        if os.path.exists(self.token_file):
                            os.remove(self.token_file)

                        # Re-auth
                        self._auth()

                        # Update headers with new token
                        headers = self._get_header(curr_tr_id)

                        # Retry once with the refreshed token.
                        if method.upper() == "GET":
                            res = requests.get(url, headers=headers, **kwargs)
                        else:
                            res = requests.post(url, headers=headers, **kwargs)

                        # NOTE(review): the retry response's HTTP status is not
                        # re-checked before parsing — confirm acceptable.
                        return res.json()
                return data
            else:
                logger.error(f"API Request Failed [{res.status_code}]: {res.text}")
                return None

        except Exception as e:
            # Network errors, JSON decode errors etc. are logged and mapped to None.
            logger.error(f"Request Exception: {e}")
            return None

    def get_current_price(self, code):
        """
        inquire-price: current quote for a domestic stock.

        Returns the 'output' dict of the response, or None on failure.
        """
        tr_id = "FHKST01010100"
        params = {
            "FID_COND_MRKT_DIV_CODE": "J",  # 'J' = equity market
            "FID_INPUT_ISCD": code
        }
        res = self._request("GET", "/uapi/domestic-stock/v1/quotations/inquire-price", tr_id=tr_id, params=params)
        return res.get('output', {}) if res else None

    def get_balance(self, source=None):
        """
        inquire-balance: domestic account balance + holdings.
        Cached for 2 seconds to prevent rate limit (EGW00201).

        `source` is only used for log attribution.
        """
        # Check cache
        now = time.time()
        if hasattr(self, '_balance_cache') and self._balance_cache:
            last_time, data = self._balance_cache
            if now - last_time < 2.0: # 2 Seconds TTL
                return data

        log_source = f" [Source: {source}]" if source else ""
        logger.info(f"get_balance{log_source}: Account={self.account_no}, Paper={self.is_paper}")

        tr_id = "VTTC8434R" if self.is_paper else "TTTC8434R"

        params = {
            "CANO": self.account_no,
            "ACNT_PRDT_CD": self.account_prod,
            "AFHR_FLPR_YN": "N",
            "OFL_YN": "",
            "INQR_DVSN": "02", # 02: By Stock
            "UNPR_DVSN": "01",
            "FUND_STTL_ICLD_YN": "N",
            "FNCG_AMT_AUTO_RDPT_YN": "N",
            "PRCS_DVSN": "00",
            "CTX_AREA_FK100": "",
            "CTX_AREA_NK100": ""
        }
        res = self._request("GET", "/uapi/domestic-stock/v1/trading/inquire-balance", tr_id=tr_id, params=params)

        # Update Cache if success (rt_cd '0' means OK in KIS responses).
        if res and res.get('rt_cd') == '0':
            self._balance_cache = (now, res)

        return res

    def get_overseas_balance(self, exchange="NASD"):
        """
        overseas-stock inquire-balance for one exchange (USD-denominated).
        """
        # For overseas, we need to know the exchange code often, but inquire-balance might return all if configured?
        # Typically requires an exchange code in some params or TR IDs specific to exchange?
        # Looking at docs: TTTS3012R is for "Overseas Stock Balance"

        tr_id = "VTTS3012R" if self.is_paper else "TTTS3012R"
        # Paper env TR ID is tricky, usually V... but let's assume VTTS3012R or JTTT3012R?
        # Common pattern: Real 'T...' -> Paper 'V...'

        params = {
            "CANO": self.account_no,
            "ACNT_PRDT_CD": self.account_prod,
            "OVRS_EXCG_CD": exchange, # NASD, NYSE, AMEX, HKS, TSE, etc.
            "TR_CRCY_CD": "USD", # Transaction Currency
            "CTX_AREA_FK200": "",
            "CTX_AREA_NK200": ""
        }

        return self._request("GET", "/uapi/overseas-stock/v1/trading/inquire-balance", tr_id=tr_id, params=params)

    def place_order(self, code, type, qty, price):
        """
        order-cash: place a domestic cash order.
        type: 'buy' or 'sell'
        price 0 means a market order (ORD_DVSN '01'), otherwise limit ('00').
        """
        if self.is_paper:
            tr_id = "VTTC0012U" if type == 'buy' else "VTTC0011U"
        else:
            tr_id = "TTTC0012U" if type == 'buy' else "TTTC0011U"

        # 00: Limit (Specified Price), 01: Market Price
        ord_dvsn = "00" if int(price) > 0 else "01"

        body = {
            "CANO": self.account_no,
            "ACNT_PRDT_CD": self.account_prod,
            "PDNO": code,
            "ORD_DVSN": ord_dvsn,
            "ORD_QTY": str(qty),
            "ORD_UNPR": str(price),
            "EXCG_ID_DVSN_CD": "KRX", # KRX for Exchange
            "SLL_TYPE": "01", # 01: Normal Sell
            "CNDT_PRIC": ""
        }
        return self._request("POST", "/uapi/domestic-stock/v1/trading/order-cash", tr_id=tr_id, data=json.dumps(body))

    def place_overseas_order(self, code, type, qty, price, market="NASD"):
        """
        overseas-stock order (US exchanges). Always a limit order ('00');
        market orders are not supported here — see notes below.
        """
        # Checks for Paper vs Real and Buy vs Sell
        # TR_ID might vary by country. Assuming US (NASD, NYSE, AMEX).
        if self.is_paper:
            tr_id = "VTTT1002U" if type == 'buy' else "VTTT1006U"
        else:
            # US Real: JTTT1002U (Buy), JTTT1006U (Sell)
            # Note: This TR ID is for US Night Market (Main).
            tr_id = "JTTT1002U" if type == 'buy' else "JTTT1006U"

        # NOTE(review): the US API has no domestic-style '01' market order;
        # special ORD_DVSN values exist ("32" LOO, "33" LOC, "34" MOO, "35" MOC).
        # A price of 0 with a limit order will likely be rejected — confirm
        # against the KIS overseas-order documentation.
        ord_dvsn = "00" # Limit

        body = {
            "CANO": self.account_no,
            "ACNT_PRDT_CD": self.account_prod,
            "OVRS_EXCG_CD": market,
            "PDNO": code,
            "ORD_QTY": str(qty),
            "OVRS_ORD_UNPR": str(price),
            "ORD_SVR_DVSN_CD": "0",
            "ORD_DVSN": ord_dvsn
        }

        return self._request("POST", "/uapi/overseas-stock/v1/trading/order", tr_id=tr_id, data=json.dumps(body))

    def get_daily_orders(self, start_date=None, end_date=None, expanded_code=None):
        """
        inquire-daily-ccld: list today's (or a date range of) order executions.

        Dates are 'YYYYMMDD' strings; both default to today.
        NOTE(review): `expanded_code` is accepted but unused.
        """
        if self.is_paper:
            tr_id = "VTTC0081R" # 3-month inner
        else:
            tr_id = "TTTC0081R" # 3-month inner

        if not start_date:
            start_date = datetime.datetime.now().strftime("%Y%m%d")
        if not end_date:
            end_date = datetime.datetime.now().strftime("%Y%m%d")

        params = {
            "CANO": self.account_no,
            "ACNT_PRDT_CD": self.account_prod,
            "INQR_STRT_DT": start_date,
            "INQR_END_DT": end_date,
            "SLL_BUY_DVSN_CD": "00", # All
            "PDNO": "",
            "CCLD_DVSN": "00", # All (Executed + Unexecuted)
            "INQR_DVSN": "00", # Reverse Order
            "INQR_DVSN_3": "00", # All
            "ORD_GNO_BRNO": "",
            "ODNO": "",
            "INQR_DVSN_1": "",
            "CTX_AREA_FK100": "",
            "CTX_AREA_NK100": ""
        }
        return self._request("GET", "/uapi/domestic-stock/v1/trading/inquire-daily-ccld", tr_id=tr_id, params=params)

    def get_cancelable_orders(self):
        """
        inquire-psbl-rvsecncl: list open orders that can be revised/cancelled.
        """
        tr_id = "VTTC0084R" if self.is_paper else "TTTC0084R"

        params = {
            "CANO": self.account_no,
            "ACNT_PRDT_CD": self.account_prod,
            "INQR_DVSN_1": "0", # 0: Order No order
            "INQR_DVSN_2": "0", # 0: All
            "CTX_AREA_FK100": "",
            "CTX_AREA_NK100": ""
        }
        return self._request("GET", "/uapi/domestic-stock/v1/trading/inquire-psbl-rvsecncl", tr_id=tr_id, params=params)

    def cancel_order(self, org_no, order_no, qty, is_buy, price="0", total=True):
        """
        order-rvsecncl: cancel an open order.

        org_no/order_no come from get_cancelable_orders(); `total=True`
        cancels the full remaining quantity (QTY_ALL_ORD_YN='Y').
        NOTE(review): `is_buy` is accepted but unused — the cancel TR ID does
        not distinguish buy/sell.
        """
        if self.is_paper:
            tr_id = "VTTC0013U"
        else:
            tr_id = "TTTC0013U"

        rvse_cncl_dvsn_cd = "02"  # 01: revise, 02: cancel
        qty_all = "Y" if total else "N"

        body = {
            "CANO": self.account_no,
            "ACNT_PRDT_CD": self.account_prod,
            "KRX_FWDG_ORD_ORGNO": org_no,
            "ORGN_ODNO": order_no,
            "ORD_DVSN": "00",
            "RVSE_CNCL_DVSN_CD": rvse_cncl_dvsn_cd,
            "ORD_QTY": str(qty),
            "ORD_UNPR": str(price),
            "QTY_ALL_ORD_YN": qty_all,
            "EXCG_ID_DVSN_CD": "KRX"
        }
        return self._request("POST", "/uapi/domestic-stock/v1/trading/order-rvsecncl", tr_id=tr_id, data=json.dumps(body))
|
||||
|
||||
# Singleton Instance — NOTE: constructing KisApi performs network auth, so
# importing this module requires valid credentials and connectivity.
kis = KisApi()
|
||||
|
||||
313
backend/main.py
Normal file
313
backend/main.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import os
|
||||
import uvicorn
|
||||
import threading
|
||||
import asyncio
|
||||
from fastapi import FastAPI, HTTPException, Depends, Body, Header, BackgroundTasks
|
||||
import fastapi
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
|
||||
from database import init_db, get_db, TradeSetting, Order, News, Stock, Watchlist
|
||||
from config import load_config, save_config, get_kis_config
|
||||
from kis_api import kis
|
||||
from trader import trader
|
||||
from news_ai import news_bot
|
||||
from telegram_notifier import notifier
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("MAIN")


app = FastAPI()

# Get absolute path to the directory where main.py is located
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Root project directory (one level up from backend)
PROJECT_ROOT = os.path.dirname(BASE_DIR)
# Frontend directory
FRONTEND_DIR = os.path.join(PROJECT_ROOT, "frontend")

# Create frontend directory if it doesn't exist (for safety)
if not os.path.exists(FRONTEND_DIR):
    os.makedirs(FRONTEND_DIR)

# Mount static files (served at /static from the frontend directory).
app.mount("/static", StaticFiles(directory=FRONTEND_DIR), name="static")
|
||||
|
||||
from websocket_manager import ws_manager
|
||||
from fastapi import WebSocket
|
||||
|
||||
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Frontend websocket: relay 'sub:<code>' subscriptions to the KIS socket.

    Fix: the original used a bare ``except:``, which also swallows
    ``asyncio.CancelledError``/``KeyboardInterrupt`` and hides real bugs.
    Cleanup now lives in ``finally`` so the frontend is unregistered on any
    exit path (disconnect, error, or task cancellation).
    """
    await ws_manager.connect_frontend(websocket)
    try:
        while True:
            # Keep alive or handle frontend formatting
            data = await websocket.receive_text()
            # If frontend sends "sub:CODE", forward the subscription to KIS.
            if data.startswith("sub:"):
                code = data.split(":", 1)[1]
                await ws_manager.subscribe_stock(code)
    except Exception:
        # WebSocketDisconnect and protocol errors both simply end the session.
        pass
    finally:
        ws_manager.disconnect_frontend(websocket)
|
||||
|
||||
@app.on_event("startup")
def startup_event():
    """App startup: init DB schema, launch workers, open the KIS websocket."""
    init_db()
    # Start Background Threads
    trader.start()
    news_bot.start()

    # Start KIS WebSocket Loop.
    # NOTE(review): asyncio.create_task from this sync handler assumes it runs
    # on the server's event-loop thread — confirm with the Starlette version in use.
    asyncio.create_task(ws_manager.start_kis_socket())

    notifier.send_message("🚀 KisStock AI 시스템이 시작되었습니다.")
|
||||
|
||||
|
||||
@app.on_event("shutdown")
def shutdown_event():
    """App shutdown: notify via Telegram, then stop background workers."""
    notifier.send_message("🛑 KisStock AI 시스템이 종료됩니다.")
    trader.stop()
    news_bot.stop()
|
||||
|
||||
|
||||
# --- Pages ---
|
||||
@app.get("/")
async def read_index():
    """Serve the dashboard page."""
    return FileResponse(os.path.join(FRONTEND_DIR, "index.html"))
|
||||
|
||||
@app.get("/news")
async def read_news():
    """Serve the news page."""
    return FileResponse(os.path.join(FRONTEND_DIR, "news.html"))
|
||||
|
||||
@app.get("/stocks")
async def read_stocks():
    """Serve the stock search page."""
    return FileResponse(os.path.join(FRONTEND_DIR, "stocks.html"))
|
||||
|
||||
@app.get("/settings")
async def read_settings():
    """Serve the settings page."""
    return FileResponse(os.path.join(FRONTEND_DIR, "settings.html"))
|
||||
|
||||
@app.get("/trade")
async def read_trade():
    """Serve the trading page."""
    return FileResponse(os.path.join(FRONTEND_DIR, "trade.html"))
|
||||
|
||||
# --- API ---
|
||||
from master_loader import master_loader
|
||||
from database import Watchlist
|
||||
|
||||
@app.get("/api/sync/status")
def get_sync_status():
    """Report the master-data sync status ({'status', 'message'})."""
    return master_loader.get_status()
|
||||
|
||||
@app.post("/api/sync/master")
def sync_master_data(background_tasks: fastapi.BackgroundTasks):
    """Kick off domestic + overseas master-data downloads in the background.

    Returns immediately; progress is polled via /api/sync/status.
    """
    # Run in background
    background_tasks.add_task(master_loader.download_and_parse_domestic)
    background_tasks.add_task(master_loader.download_and_parse_overseas)
    return {"status": "started", "message": "Master data sync started in background"}
|
||||
|
||||
@app.get("/api/stocks")
def search_stocks(keyword: str = "", market: str = "", page: int = 1, db: Session = Depends(get_db)):
    """Paged stock search by name/code substring and optional market filter.

    Fix: `page` is clamped to >= 1 — the original computed a negative OFFSET
    for page <= 0, which is invalid SQL for SQLite.
    Returns {"items": [Stock, ...]} with at most 50 rows per page.
    """
    query = db.query(Stock)
    if keyword:
        # Matches either the display name or the ticker/code.
        query = query.filter(Stock.name.contains(keyword) | Stock.code.contains(keyword))
    if market:
        query = query.filter(Stock.market == market)

    limit = 50
    page = max(page, 1)  # guard against page <= 0 producing a negative OFFSET
    offset = (page - 1) * limit
    items = query.limit(limit).offset(offset).all()
    return {"items": items}
|
||||
|
||||
@app.get("/api/watchlist")
def get_watchlist(db: Session = Depends(get_db)):
    """Return all watchlist entries, newest first."""
    return db.query(Watchlist).order_by(Watchlist.created_at.desc()).all()
|
||||
|
||||
@app.post("/api/watchlist")
def add_watchlist(code: str = Body(...), name: str = Body(...), market: str = Body(...), db: Session = Depends(get_db)):
    """Add a stock to the watchlist; no-op when the code is already present."""
    already_there = db.query(Watchlist).filter(Watchlist.code == code).first()
    if already_there:
        return {"status": "exists"}
    db.add(Watchlist(code=code, name=name, market=market))
    db.commit()
    return {"status": "added"}
|
||||
|
||||
@app.delete("/api/watchlist/{code}")
def delete_watchlist(code: str, db: Session = Depends(get_db)):
    """Remove a stock from the watchlist (idempotent: succeeds if absent)."""
    db.query(Watchlist).filter(Watchlist.code == code).delete()
    db.commit()
    return {"status": "deleted"}
|
||||
|
||||
@app.get("/api/settings")
def get_settings():
    """Return the full settings.yaml contents (includes API credentials)."""
    return load_config()
|
||||
|
||||
@app.post("/api/settings")
def update_settings(settings: dict = Body(...)):
    """Overwrite settings.yaml with the posted dict (full replace, not merge)."""
    save_config(settings)
    return {"status": "ok"}
|
||||
|
||||
from database import AccountBalance, Holding
|
||||
import datetime
|
||||
|
||||
@app.get("/api/balance")
def get_my_balance(source: str = "db", db: Session = Depends(get_db)):
    """
    Return persisted balance and holdings from DB.
    Structure similar to KIS API to minimize frontend changes, or simplified.

    `source` is currently unused (kept for the frontend's query string).
    output2 = account summary row(s); output1 = domestic holdings.
    """
    # 1. Balance Summary
    acc = db.query(AccountBalance).first()
    output2 = []
    if acc:
        output2.append({
            "tot_evlu_amt": acc.total_eval,
            "dnca_tot_amt": acc.deposit,
            "evlu_pfls_smtl_amt": acc.total_profit
        })

    # 2. Holdings (Domestic)
    holdings = db.query(Holding).filter(Holding.market == 'DOMESTIC').all()
    output1 = []
    for h in holdings:
        # Keys mirror the KIS inquire-balance response fields.
        output1.append({
            "pdno": h.code,
            "prdt_name": h.name,
            "hldg_qty": h.quantity,
            "prpr": h.current_price,
            "pchs_avg_pric": h.price,
            "evlu_pfls_rt": h.profit_rate
        })

    return {
        "rt_cd": "0",
        "msg1": "Success from DB",
        "output1": output1,
        "output2": output2
    }
|
||||
|
||||
@app.get("/api/balance/overseas")
def get_my_overseas_balance(db: Session = Depends(get_db)):
    """Return persisted overseas holdings in KIS-like response shape.

    NOTE(review): only market == 'NASD' rows are returned — NYSE/AMEX
    holdings (if the trader ever stores them) would be invisible; verify
    against what trader.refresh_assets() writes.
    """
    # Persisted Overseas Holdings
    holdings = db.query(Holding).filter(Holding.market == 'NASD').all()
    output1 = []
    for h in holdings:
        # Keys mirror the KIS overseas inquire-balance response fields.
        output1.append({
            "ovrs_pdno": h.code,
            "ovrs_item_name": h.name,
            "ovrs_cblc_qty": h.quantity,
            "now_pric2": h.current_price,
            "frcr_pchs_amt1": h.price,
            "evlu_pfls_rt": h.profit_rate
        })

    return {
        "rt_cd": "0",
        "output1": output1
    }
|
||||
|
||||
@app.post("/api/balance/refresh")
def force_refresh_balance(background_tasks: BackgroundTasks):
    """Synchronously re-pull balances from KIS so the next /api/balance read is fresh.

    Runs inline rather than in the background so the frontend's "Refresh"
    button gets up-to-date data without polling; `background_tasks` is
    accepted but unused, kept for a future async variant.
    """
    trader.refresh_assets()
    return {"status": "ok"}
|
||||
|
||||
@app.get("/api/price/{code}")
def get_stock_price(code: str):
    """Live domestic quote for one code via the KIS API; 404 when unavailable."""
    price = kis.get_current_price(code)
    if price:
        return price
    raise HTTPException(status_code=404, detail="Stock info not found")
|
||||
|
||||
@app.post("/api/order")
def place_order_api(
    code: str = Body(...),
    type: str = Body(...),      # 'buy' or 'sell'
    qty: int = Body(...),
    price: int = Body(...),     # 0 = market order (domestic only)
    market: str = Body("DOMESTIC")
):
    """Route an order to the domestic or overseas KIS endpoint by market.

    Returns the raw KIS response on success (rt_cd == '0'), 400 otherwise.
    """
    if market in ["NASD", "NYSE", "AMEX"]:
        res = kis.place_overseas_order(code, type, qty, price, market)
    else:
        res = kis.place_order(code, type, qty, price)

    if res and res.get('rt_cd') == '0':
        return res
    raise HTTPException(status_code=400, detail=f"Order Failed: {res}")
|
||||
|
||||
@app.get("/api/orders")
def get_db_orders(db: Session = Depends(get_db)):
    """Return the 50 most recent locally recorded orders."""
    orders = db.query(Order).order_by(Order.created_at.desc()).limit(50).all()
    return orders
|
||||
|
||||
@app.get("/api/orders/daily")
def get_daily_orders_api():
    """Today's executions straight from KIS; 500 when the API call fails."""
    res = kis.get_daily_orders()
    if res:
        return res
    raise HTTPException(status_code=500, detail="Failed to fetch daily orders")
|
||||
|
||||
@app.get("/api/orders/cancelable")
def get_cancelable_orders_api():
    """Open (revisable/cancelable) orders from KIS; 500 when the call fails."""
    res = kis.get_cancelable_orders()
    if res:
        return res
    raise HTTPException(status_code=500, detail="Failed to fetch cancelable orders")
|
||||
|
||||
@app.post("/api/order/cancel")
def cancel_order_api(
    org_no: str = Body(...),    # KRX forwarding org number from the open-order list
    order_no: str = Body(...),  # original KIS order number
    qty: int = Body(...),
    is_buy: bool = Body(...),
    price: int = Body(0)
):
    """Cancel an open order via KIS; 400 with the raw response on failure."""
    res = kis.cancel_order(org_no, order_no, qty, is_buy, price)
    if res and res.get('rt_cd') == '0':
        return res
    raise HTTPException(status_code=400, detail=f"Cancel Failed: {res}")
|
||||
|
||||
|
||||
@app.get("/api/news")
def get_news(db: Session = Depends(get_db)):
    """Return the 20 most recent analyzed news items."""
    news = db.query(News).order_by(News.created_at.desc()).limit(20).all()
    return news
|
||||
|
||||
@app.get("/api/trade_settings")
def get_trade_settings(db: Session = Depends(get_db)):
    """Return all per-stock automated trading settings."""
    return db.query(TradeSetting).all()
|
||||
|
||||
@app.post("/api/trade_settings")
def set_trade_setting(
    code: str = Body(...),
    target_price: Optional[float] = Body(None),
    stop_loss_price: Optional[float] = Body(None),
    is_active: bool = Body(True),
    db: Session = Depends(get_db)
):
    """Upsert the trade setting for one stock code.

    None values leave the existing price thresholds untouched (partial
    update); `is_active` is always applied.
    """
    setting = db.query(TradeSetting).filter(TradeSetting.code == code).first()
    if not setting:
        setting = TradeSetting(code=code)
        db.add(setting)

    if target_price is not None:
        setting.target_price = target_price
    if stop_loss_price is not None:
        setting.stop_loss_price = stop_loss_price
    setting.is_active = is_active

    db.commit()
    return {"status": "ok"}
|
||||
|
||||
# Dev entry point: run with auto-reload on all interfaces.
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
||||
253
backend/master_loader.py
Normal file
253
backend/master_loader.py
Normal file
@@ -0,0 +1,253 @@
|
||||
|
||||
import os
|
||||
import requests
|
||||
import zipfile
|
||||
import io
|
||||
import pandas as pd
|
||||
from database import SessionLocal, Stock, engine
|
||||
from sqlalchemy.orm import Session
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("MASTER_LOADER")
|
||||
|
||||
class MasterLoader:
|
||||
    def __init__(self):
        """Prepare a tmp_master/ scratch directory and an idle status."""
        self.base_dir = os.path.dirname(os.path.abspath(__file__))
        self.tmp_dir = os.path.join(self.base_dir, "tmp_master")
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)
        # Polled by the /api/sync/status endpoint.
        self.sync_status = {"status": "idle", "message": ""}
|
||||
|
||||
    def get_status(self):
        """Return the current sync status dict ({'status': ..., 'message': ...})."""
        return self.sync_status
|
||||
|
||||
    def _set_status(self, status, message):
        """Replace the status dict atomically and log the transition."""
        self.sync_status = {"status": status, "message": message}
        logger.info(f"Sync Status: {status} - {message}")
|
||||
|
||||
    def download_and_parse_domestic(self):
        """Download KOSPI/KOSDAQ master .mst zips from KIS and upsert into stocks.

        Per-market failures set status to 'error' but do not abort the other
        market; a single commit happens after both markets are processed.
        """
        self._set_status("running", "Downloading Domestic Master...")
        urls = {
            "kospi": "https://new.real.download.dws.co.kr/common/master/kospi_code.mst.zip",
            "kosdaq": "https://new.real.download.dws.co.kr/common/master/kosdaq_code.mst.zip"
        }

        db = SessionLocal()
        try:
            for market, url in urls.items():
                logger.info(f"Downloading {market} master data from {url}...")
                try:
                    res = requests.get(url)
                    if res.status_code != 200:
                        logger.error(f"Failed to download {market} master")
                        self._set_status("error", f"Failed to download {market}")
                        continue

                    # Each zip contains exactly '<market>_code.mst'.
                    with zipfile.ZipFile(io.BytesIO(res.content)) as z:
                        filename = f"{market}_code.mst"
                        z.extract(filename, self.tmp_dir)

                    file_path = os.path.join(self.tmp_dir, filename)
                    self._parse_domestic_file(file_path, market.upper(), db)

                except Exception as e:
                    logger.error(f"Error processing {market}: {e}")
                    self._set_status("error", f"Error processing {market}: {e}")

            db.commit()
            # Status stays "running" so the overseas step (run next) continues the flow.
            if self.sync_status['status'] != 'error':
                self._set_status("running", "Domestic Sync Complete")
        finally:
            db.close()
|
||||
|
||||
def _parse_domestic_file(self, file_path, market_name, db: Session):
|
||||
with open(file_path, 'r', encoding='cp949') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
logger.info(f"Parsing {len(lines)} lines for {market_name}...")
|
||||
|
||||
batch = []
|
||||
for line in lines:
|
||||
code = line[0:9].strip()
|
||||
name = line[21:61].strip()
|
||||
|
||||
if not code or not name:
|
||||
continue
|
||||
|
||||
batch.append({
|
||||
"code": code,
|
||||
"name": name,
|
||||
"market": market_name,
|
||||
"type": "DOMESTIC"
|
||||
})
|
||||
|
||||
if len(batch) >= 1000:
|
||||
self._upsert_batch(db, batch)
|
||||
batch = []
|
||||
|
||||
if batch:
|
||||
self._upsert_batch(db, batch)
|
||||
|
||||
def download_and_parse_overseas(self):
|
||||
if self.sync_status['status'] == 'error': return
|
||||
|
||||
self._set_status("running", "Downloading Overseas Master...")
|
||||
|
||||
# NASDAQ from text file
|
||||
urls = {
|
||||
"NASD": "https://www.nasdaqtrader.com/dynamic/symdir/nasdaqlisted.txt",
|
||||
# "NYSE": "https://new.real.download.dws.co.kr/common/master/usa_nys.mst.zip",
|
||||
# "AMEX": "https://new.real.download.dws.co.kr/common/master/usa_ams.mst.zip"
|
||||
}
|
||||
|
||||
db = SessionLocal()
|
||||
error_count = 0
|
||||
try:
|
||||
for market, url in urls.items():
|
||||
logger.info(f"Downloading {market} master data from {url}...")
|
||||
try:
|
||||
res = requests.get(url)
|
||||
logger.info(f"HTTP Status: {res.status_code}")
|
||||
if res.status_code != 200:
|
||||
logger.error(f"Download failed for {market}. Status: {res.status_code}")
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
if url.endswith('.txt'):
|
||||
self._parse_nasdaq_txt(res.text, market, db)
|
||||
else:
|
||||
with zipfile.ZipFile(io.BytesIO(res.content)) as z:
|
||||
target_file = None
|
||||
for f in z.namelist():
|
||||
if f.endswith(".mst"):
|
||||
target_file = f
|
||||
break
|
||||
if target_file:
|
||||
z.extract(target_file, self.tmp_dir)
|
||||
file_path = os.path.join(self.tmp_dir, target_file)
|
||||
self._parse_overseas_file(file_path, market, db)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {market}: {e}")
|
||||
error_count += 1
|
||||
|
||||
db.commit()
|
||||
|
||||
if error_count == len(urls):
|
||||
self._set_status("error", "All overseas downloads failed.")
|
||||
elif error_count > 0:
|
||||
self._set_status("warning", f"Overseas Sync Partial ({error_count} failed).")
|
||||
else:
|
||||
self._set_status("done", "All Sync Complete.")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _parse_nasdaq_txt(self, content, market_name, db: Session):
|
||||
# Format: Symbol|Security Name|Market Category|Test Issue|Financial Status|Round Lot Size|ETF|NextShares
|
||||
lines = content.splitlines()
|
||||
logger.info(f"Parsing {len(lines)} lines for {market_name} (TXT)...")
|
||||
|
||||
batch = []
|
||||
parsed_count = 0
|
||||
|
||||
for line in lines:
|
||||
try:
|
||||
if not line or line.startswith('Symbol|') or line.startswith('File Creation Time'):
|
||||
continue
|
||||
|
||||
parts = line.split('|')
|
||||
if len(parts) < 7: continue
|
||||
|
||||
symbol = parts[0]
|
||||
name = parts[1]
|
||||
# market_category = parts[2]
|
||||
financial_status = parts[4] # N=Normal, D=Deficient, E=Delinquent, Q=Bankrupt, G=Deficient and Bankrupt
|
||||
etf_flag = parts[6] # Y/N
|
||||
|
||||
is_etf = (etf_flag == 'Y')
|
||||
|
||||
batch.append({
|
||||
"code": symbol,
|
||||
"name": name,
|
||||
"name_eng": name,
|
||||
"market": market_name,
|
||||
"type": "OVERSEAS",
|
||||
"financial_status": financial_status,
|
||||
"is_etf": is_etf
|
||||
})
|
||||
|
||||
if len(batch) >= 1000:
|
||||
self._upsert_batch(db, batch)
|
||||
parsed_count += len(batch)
|
||||
batch = []
|
||||
except Exception as e:
|
||||
# logger.error(f"Parse error: {e}")
|
||||
continue
|
||||
|
||||
if batch:
|
||||
self._upsert_batch(db, batch)
|
||||
parsed_count += len(batch)
|
||||
|
||||
logger.info(f"Parsed and Upserted {parsed_count} items for {market_name}")
|
||||
|
||||
def _parse_overseas_file(self, file_path, market_name, db: Session):
|
||||
with open(file_path, 'r', encoding='cp949', errors='ignore') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
logger.info(f"Parsing {len(lines)} lines for {market_name}... (File: {os.path.basename(file_path)})")
|
||||
|
||||
batch = []
|
||||
parsed_count = 0
|
||||
for line in lines:
|
||||
try:
|
||||
b_line = line.encode('cp949')
|
||||
symbol = b_line[0:16].decode('cp949').strip()
|
||||
name_eng = b_line[16:80].decode('cp949').strip()
|
||||
|
||||
if not symbol: continue
|
||||
|
||||
batch.append({
|
||||
"code": symbol,
|
||||
"name": name_eng,
|
||||
"name_eng": name_eng,
|
||||
"market": market_name,
|
||||
"type": "OVERSEAS"
|
||||
})
|
||||
|
||||
if len(batch) >= 1000:
|
||||
self._upsert_batch(db, batch)
|
||||
parsed_count += len(batch)
|
||||
batch = []
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
if batch:
|
||||
self._upsert_batch(db, batch)
|
||||
parsed_count += len(batch)
|
||||
|
||||
logger.info(f"Parsed and Upserted {parsed_count} items for {market_name}")
|
||||
|
||||
def _upsert_batch(self, db: Session, batch):
|
||||
for item in batch:
|
||||
existing = db.query(Stock).filter(Stock.code == item['code']).first()
|
||||
if existing:
|
||||
existing.name = item['name']
|
||||
existing.market = item['market']
|
||||
existing.type = item['type']
|
||||
if 'name_eng' in item: existing.name_eng = item['name_eng']
|
||||
if 'financial_status' in item: existing.financial_status = item['financial_status']
|
||||
if 'is_etf' in item: existing.is_etf = item['is_etf']
|
||||
else:
|
||||
stock = Stock(**item)
|
||||
db.add(stock)
|
||||
db.commit()
|
||||
|
||||
# Module-level singleton shared by the API layer and the CLI below.
master_loader = MasterLoader()

if __name__ == "__main__":
    # Manual entry point: refresh domestic masters, then overseas.
    print("Starting sync...")
    master_loader.download_and_parse_domestic()
    print("Domestic Done. Starting Overseas...")
    master_loader.download_and_parse_overseas()
    print("Sync Complete.")
|
||||
37
backend/migrate_db.py
Normal file
37
backend/migrate_db.py
Normal file
@@ -0,0 +1,37 @@
|
||||
|
||||
import sqlite3
import os

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# The SQLite file lives one level above the backend/ directory.
DB_PATH = os.path.join(os.path.dirname(BASE_DIR), 'kis_stock.db')


def migrate():
    """Add the v1 columns (name_eng, industry, type) to `stocks`.

    Idempotent: each ALTER is attempted individually and a duplicate
    column simply logs a "Skipping" message, so re-running is safe.
    The connection is now closed even if something unexpected fails.
    """
    print(f"Migrating database at {DB_PATH}...")
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        # (column name, DDL fragment) pairs added by this migration.
        columns = [
            ("name_eng", "VARCHAR"),
            ("industry", "VARCHAR"),
            ("type", "VARCHAR DEFAULT 'DOMESTIC'"),
        ]
        for name, ddl in columns:
            try:
                cursor.execute(f"ALTER TABLE stocks ADD COLUMN {name} {ddl}")
                print(f"Added {name}")
            except Exception as e:
                # Column already exists (or table missing) — report and move on.
                print(f"Skipping {name}: {e}")

        conn.commit()
    finally:
        # Original leaked the connection on error; always close it.
        conn.close()
    print("Migration complete.")


if __name__ == "__main__":
    migrate()
|
||||
57
backend/migrate_db_v2.py
Normal file
57
backend/migrate_db_v2.py
Normal file
@@ -0,0 +1,57 @@
|
||||
|
||||
import sqlite3
import os

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# The SQLite file lives one level above the backend/ directory.
DB_PATH = os.path.join(os.path.dirname(BASE_DIR), 'kis_stock.db')


def migrate():
    """Apply v2 schema changes: price/health columns on `stocks` plus the
    `stock_prices` history table and its indexes.

    Idempotent: duplicate-column errors are logged and skipped, and the
    table/index creation uses IF NOT EXISTS. The connection is now closed
    even when something unexpected fails.
    """
    print(f"Migrating database v2 at {DB_PATH}...")
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        # (column name, DDL fragment) pairs added by this migration.
        columns = [
            ("financial_status", "VARCHAR"),       # Health flag from the NASDAQ feed
            ("is_etf", "BOOLEAN DEFAULT 0"),       # ETF facet
            ("current_price", "FLOAT DEFAULT 0"),  # Cached latest price
        ]
        for name, ddl in columns:
            try:
                cursor.execute(f"ALTER TABLE stocks ADD COLUMN {name} {ddl}")
                print(f"Added {name}")
            except Exception as e:
                print(f"Skipping {name}: {e}")

        # Stock price history table + lookup indexes.
        try:
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS stock_prices (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    code VARCHAR NOT NULL,
                    price FLOAT,
                    change FLOAT,
                    volume INTEGER,
                    created_at DATETIME DEFAULT CURRENT_TIMESTAMP
                )
            """)
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_stock_prices_code ON stock_prices (code)")
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_stock_prices_created_at ON stock_prices (created_at)")
            print("Created stock_prices table")
        except Exception as e:
            print(f"Error creating stock_prices: {e}")

        conn.commit()
    finally:
        # Original leaked the connection on error; always close it.
        conn.close()
    print("Migration v2 complete.")


if __name__ == "__main__":
    migrate()
|
||||
18
backend/migrate_db_v3.py
Normal file
18
backend/migrate_db_v3.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from sqlalchemy import create_engine, text
import os

BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DB_URL = f"sqlite:///{os.path.join(BASE_DIR, 'kis_stock.db')}"

engine = create_engine(DB_URL)


def migrate():
    """Add watchlist.is_monitoring (default on). Idempotent.

    Uses engine.begin() so the DDL is committed on success: a plain
    engine.connect() leaves the statement in an open transaction that is
    rolled back on close under SQLAlchemy 2.x (no implicit autocommit),
    which would silently undo the migration.
    """
    with engine.begin() as conn:
        try:
            conn.execute(text("ALTER TABLE watchlist ADD COLUMN is_monitoring BOOLEAN DEFAULT 1"))
            print("Added is_monitoring column to watchlist")
        except Exception as e:
            # Duplicate column (re-run) or missing table — report and continue.
            print(f"Column might already exist: {e}")


if __name__ == "__main__":
    migrate()
|
||||
147
backend/news_ai.py
Normal file
147
backend/news_ai.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import time
|
||||
import threading
|
||||
import logging
|
||||
import requests
|
||||
import json
|
||||
from sqlalchemy.orm import Session
|
||||
from database import SessionLocal, News, Stock
|
||||
from config import get_kis_config, load_config
|
||||
|
||||
logger = logging.getLogger("NEWS_AI")
|
||||
|
||||
class NewsBot:
    """Background worker that polls Naver News and scores each article's
    market impact with the Gemini API, persisting results to the News table.

    Runs in a daemon thread; one fetch cycle every 10 minutes while the
    `preferences.enable_news` setting is on.
    """

    def __init__(self):
        # Thread lifecycle flags.
        self.is_running = False
        self.thread = None
        # Credentials are read from settings.yaml; empty strings when absent.
        self.config = load_config()
        self.naver_id = self.config.get('naver', {}).get('client_id', '')
        self.naver_secret = self.config.get('naver', {}).get('client_secret', '')
        self.google_key = self.config.get('google', {}).get('api_key', '')

    def start(self):
        """Start the polling thread (no-op if already running)."""
        if self.is_running:
            return
        self.is_running = True
        self.thread = threading.Thread(target=self._run_loop, daemon=True)
        self.thread.start()
        logger.info("News Bot Started")

    def stop(self):
        """Signal the loop to stop and wait for the thread to exit."""
        self.is_running = False
        if self.thread:
            self.thread.join()
        logger.info("News Bot Stopped")

    def _run_loop(self):
        """Worker thread: fetch/analyze every 10 minutes while enabled."""
        while self.is_running:
            try:
                # Reload config each cycle so settings changes take effect
                # without a restart.
                self.config = load_config()
                if self.config.get('preferences', {}).get('enable_news', False):
                    self._fetch_and_analyze()
                else:
                    logger.info("News collection is disabled.")
            except Exception as e:
                logger.error(f"Error in news loop: {e}")

            # Sleep 10 minutes (600s) in 1s slices so stop() is responsive.
            for _ in range(600):
                if not self.is_running: break
                time.sleep(1)

    def _fetch_and_analyze(self):
        """Fetch the latest Naver news, analyze new items, store to DB.

        Items whose link already exists in the News table are skipped,
        so repeated polls don't produce duplicates.
        """
        logger.info("Fetching News...")
        if not self.naver_id or not self.naver_secret:
            logger.warning("Naver API Credentials missing.")
            return

        # 1. Fetch news from the Naver search API (general market query).
        query = "주식 시장"  # General Stock Market
        url = "https://openapi.naver.com/v1/search/news.json"
        headers = {
            "X-Naver-Client-Id": self.naver_id,
            "X-Naver-Client-Secret": self.naver_secret
        }
        params = {"query": query, "display": 10, "sort": "date"}

        # NOTE(review): no timeout on this request — a stall blocks the
        # news thread until the socket gives up.
        res = requests.get(url, headers=headers, params=params)
        if res.status_code != 200:
            logger.error(f"Naver News Failed: {res.text}")
            return

        items = res.json().get('items', [])

        db = SessionLocal()
        try:
            for item in items:
                title = item['title']
                # Prefer the original publisher link over the Naver one.
                link = item['originallink'] or item['link']
                pub_date = item['pubDate']

                # Deduplicate on the article link.
                if db.query(News).filter(News.link == link).first():
                    continue

                # 2. AI impact analysis (Google Gemini).
                analysis = self._analyze_with_ai(title, item['description'])

                news = News(
                    title=title,
                    link=link,
                    pub_date=pub_date,
                    analysis_result=analysis.get('summary', ''),
                    impact_score=analysis.get('score', 0),
                    related_sector=analysis.get('sector', '')
                )
                db.add(news)
            db.commit()
            logger.info(f"Processed {len(items)} news items.")
        finally:
            db.close()

    def _analyze_with_ai(self, title, description):
        """Score one article with Gemini.

        Returns a dict with keys summary (str), score (int, -10..10) and
        sector (str); a placeholder dict on any failure so callers never
        see an exception.
        """
        if not self.google_key:
            return {"summary": "No API Key", "score": 0, "sector": ""}

        logger.info(f"Analyzing: {title[:30]}...")

        # Prompt asks the model for strict JSON so the reply can be parsed.
        prompt = f"""
        Analyze the following news for stock market impact.
        Title: {title}
        Description: {description}

        Return JSON format:
        {{
            "summary": "One line summary of impact",
            "score": Integer between -10 (Negative) to 10 (Positive),
            "sector": "Related Industry/Sector or 'None'"
        }}
        """

        # NOTE(review): "gemini-pro" is a legacy model name — confirm it is
        # still served by the v1beta generateContent endpoint.
        url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key={self.google_key}"
        headers = {"Content-Type": "application/json"}
        body = {
            "contents": [{
                "parts": [{"text": prompt}]
            }]
        }

        try:
            res = requests.post(url, headers=headers, data=json.dumps(body))
            if res.status_code == 200:
                result = res.json()
                text = result['candidates'][0]['content']['parts'][0]['text']
                # Strip markdown code fences the model sometimes adds.
                text = text.replace("```json", "").replace("```", "").strip()
                return json.loads(text)
            else:
                logger.error(f"Gemini API Error: {res.text}")
        except Exception as e:
            logger.error(f"AI Analysis Exception: {e}")

        return {"summary": "Error", "score": 0, "sector": ""}
|
||||
|
||||
# Module-level singleton; started/stopped by the application lifecycle.
news_bot = NewsBot()
|
||||
43
backend/telegram_notifier.py
Normal file
43
backend/telegram_notifier.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import requests
|
||||
import logging
|
||||
from config import load_config
|
||||
|
||||
logger = logging.getLogger("TELEGRAM")
|
||||
|
||||
class TelegramNotifier:
    """Pushes plain-text alerts to a Telegram chat via the Bot API.

    Credentials live in settings.yaml under `telegram`; the
    `preferences.enable_telegram` flag (default on) gates every send.
    """

    def __init__(self):
        # Initial credential load; send_message() re-reads before each send.
        self.reload_config()

    def reload_config(self):
        """Re-read settings.yaml and refresh the bot token / chat id."""
        self.config = load_config()
        telegram_cfg = self.config.get('telegram', {})
        self.bot_token = telegram_cfg.get('bot_token', '')
        self.chat_id = telegram_cfg.get('chat_id', '')

    def send_message(self, text):
        """Send *text* to the configured chat.

        Silently no-ops when the feature is disabled; logs (but never
        raises) on missing credentials or HTTP/network errors.
        """
        # Always pick up the latest settings before deciding to send.
        self.reload_config()

        preferences = self.config.get('preferences', {})
        if not preferences.get('enable_telegram', True):
            return

        if not self.bot_token or not self.chat_id:
            logger.warning("Telegram credentials missing.")
            return

        endpoint = f"https://api.telegram.org/bot{self.bot_token}/sendMessage"
        payload = {
            "chat_id": self.chat_id,
            "text": text
        }

        try:
            response = requests.post(endpoint, json=payload, timeout=5)
            if response.status_code != 200:
                logger.error(f"Telegram Send Failed: {response.text}")
        except Exception as e:
            logger.error(f"Telegram Error: {e}")
|
||||
|
||||
# Module-level singleton imported wherever alerts are sent.
notifier = TelegramNotifier()
|
||||
29
backend/test_db_verification.py
Normal file
29
backend/test_db_verification.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import sys
import os

# Make sibling backend modules importable when run as a script.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from database import init_db, SessionLocal, AccountBalance, Holding, engine
from trader import trader
import logging

logging.basicConfig(level=logging.INFO)


def test_db_migration():
    """Smoke-check that init_db() creates the asset tables."""
    print("Initializing DB...")
    init_db()

    # Inspect the live schema rather than trusting the ORM metadata.
    from sqlalchemy import inspect
    table_names = inspect(engine).get_table_names()
    print(f"Tables: {table_names}")

    required = {"account_balance", "holdings"}
    if required.issubset(table_names):
        print("PASS: New tables created.")
    else:
        print("FAIL: Tables missing.")


if __name__ == "__main__":
    test_db_migration()
|
||||
43
backend/test_telegram_toggle.py
Normal file
43
backend/test_telegram_toggle.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import sys
import os

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from telegram_notifier import notifier
from config import save_config, load_config


def test_telegram_toggle():
    """Exercise the enable_telegram preference both ways.

    The original settings file is restored afterwards regardless of
    outcome; no mocking is available here, so this only verifies the
    call path doesn't crash.
    """
    print("Testing Telegram Toggle...")
    original_config = load_config()

    try:
        # 1. With the toggle ON, send_message should attempt a send.
        print("1. Testing ENABLED...")
        cfg = load_config()
        cfg.setdefault('preferences', {})['enable_telegram'] = True
        save_config(cfg)
        notifier.send_message("Test Message (Should Send)")

        # 2. With the toggle OFF, send_message should return early —
        # no "credentials missing" warning, no network call.
        print("2. Testing DISABLED...")
        cfg['preferences']['enable_telegram'] = False
        save_config(cfg)
        notifier.send_message("Test Message (Should NOT Send)")

        print("Toggle logic executed without error.")

    finally:
        # Put the user's settings back exactly as they were.
        save_config(original_config)
        print("Original config restored.")


if __name__ == "__main__":
    test_telegram_toggle()
|
||||
1828
backend/tmp_master/kosdaq_code.mst
Normal file
1828
backend/tmp_master/kosdaq_code.mst
Normal file
File diff suppressed because it is too large
Load Diff
2482
backend/tmp_master/kospi_code.mst
Normal file
2482
backend/tmp_master/kospi_code.mst
Normal file
File diff suppressed because it is too large
Load Diff
214
backend/trader.py
Normal file
214
backend/trader.py
Normal file
@@ -0,0 +1,214 @@
|
||||
import time
|
||||
import threading
|
||||
import logging
|
||||
from sqlalchemy.orm import Session
|
||||
from database import SessionLocal, TradeSetting, Order, Stock, AccountBalance, Holding
|
||||
import datetime
|
||||
from kis_api import kis
|
||||
from telegram_notifier import notifier
|
||||
|
||||
|
||||
logger = logging.getLogger("TRADER")
|
||||
|
||||
class TradingBot:
    """Background worker that syncs account assets from KIS and applies
    stop-loss / take-profit rules to held positions once per second.
    """

    def __init__(self):
        # Thread lifecycle flags.
        self.is_running = False
        self.thread = None
        # Local cache for holdings: {code: {qty: int, price: float}};
        # refreshed by refresh_assets() and read by the trade logic.
        self.holdings = {}

        self.last_chart_update = 0  # currently unused placeholder

    def refresh_assets(self):
        """
        Fetch Balance and Holdings from KIS and save to DB.

        Rewrites the AccountBalance row and replaces the DOMESTIC and
        NASD rows of the holdings table wholesale; also rebuilds the
        in-memory self.holdings cache for the sell logic. Rolls back and
        logs on any failure.
        """
        logger.info("Syncing Assets to Database...")
        db = SessionLocal()
        try:
            # 1. Domestic balance summary.
            balance = kis.get_balance(source="Automated_Sync")
            if balance and 'output2' in balance and balance['output2']:
                summary = balance['output2'][0]
                # Upsert the single AccountBalance row.
                fn_status = db.query(AccountBalance).first()
                if not fn_status:
                    fn_status = AccountBalance()
                    db.add(fn_status)

                fn_status.total_eval = float(summary['tot_evlu_amt'])
                fn_status.deposit = float(summary['dnca_tot_amt'])
                fn_status.total_profit = float(summary['evlu_pfls_smtl_amt'])
                fn_status.updated_at = datetime.datetime.now()

            # 2. Domestic holdings: replace rather than merge.
            db.query(Holding).filter(Holding.market == 'DOMESTIC').delete()

            if balance and 'output1' in balance:
                self.holdings = {}  # Rebuild memory cache for trading logic
                for item in balance['output1']:
                    code = item['pdno']
                    qty = int(item['hldg_qty'])
                    if qty > 0:
                        buy_price = float(item['pchs_avg_pric'])
                        current_price = float(item['prpr'])
                        profit_rate = float(item['evlu_pfls_rt'])

                        # Persist for the UI.
                        db.add(Holding(
                            code=code,
                            name=item['prdt_name'],
                            quantity=qty,
                            price=buy_price,
                            current_price=current_price,
                            profit_rate=profit_rate,
                            market="DOMESTIC"
                        ))

                        # Memory cache consumed by _check_and_trade.
                        self.holdings[code] = {'qty': qty, 'price': buy_price}

            # 3. Overseas balance (NASDAQ only for now).
            # TODO: multi-market support if needed.
            overseas = kis.get_overseas_balance(exchange="NASD")

            db.query(Holding).filter(Holding.market == 'NASD').delete()

            if overseas and 'output1' in overseas:
                for item in overseas['output1']:
                    # NOTE(review): parsed as float in case KIS reports
                    # fractional overseas quantities — confirm; it is
                    # truncated to int when persisted below.
                    qty = float(item['ovrs_cblc_qty'])
                    if qty > 0:
                        code = item['ovrs_pdno']
                        # KIS overseas output keys vary between endpoints,
                        # hence the defensive .get() lookups below — verify
                        # against the actual response schema.
                        db.add(Holding(
                            code=code,
                            name=item.get('ovrs_item_name', code),
                            quantity=int(qty),
                            price=float(item.get('frcr_pchs_amt1', 0)),
                            current_price=float(item.get('now_pric2', 0)),
                            profit_rate=float(item.get('evlu_pfls_rt', 0)),
                            market="NASD"
                        ))

            db.commit()
            logger.info("Assets Synced Successfully.")

        except Exception as e:
            logger.error(f"Failed to sync assets: {e}")
            db.rollback()
        finally:
            db.close()

    def start(self):
        """Sync assets once, then start the trading loop thread."""
        if self.is_running:
            return
        self.is_running = True
        self.refresh_assets()  # Fetch on start
        self.thread = threading.Thread(target=self._run_loop, daemon=True)
        self.thread.start()
        logger.info("Trading Bot Started")

    def stop(self):
        """Signal the loop to stop and wait for the thread to exit."""
        self.is_running = False
        if self.thread:
            self.thread.join()
        logger.info("Trading Bot Stopped")

    def _run_loop(self):
        """Worker thread: one trade cycle per second until stopped."""
        while self.is_running:
            try:
                self._process_cycle()
            except Exception as e:
                logger.error(f"Error in trading loop: {e}")

            # Sleep 1 second to avoid hammering the broker API.
            time.sleep(1)

    def _process_cycle(self):
        """Evaluate every active TradeSetting in its own DB session."""
        db = SessionLocal()
        try:
            settings = db.query(TradeSetting).filter(TradeSetting.is_active == True).all()

            for setting in settings:
                self._check_and_trade(db, setting)

        finally:
            db.close()

    def _check_and_trade(self, db: Session, setting: TradeSetting):
        """Check stop-loss / take-profit for one symbol; sell everything
        at market price when a threshold is crossed.
        """
        code = setting.code

        # Current price is polled per cycle. Ideally this would read from
        # a shared cache fed by the WebSocket feed instead of an HTTP
        # request per symbol — kept simple for now.
        price_data = kis.get_current_price(code)
        if not price_data:
            return

        current_price = float(price_data.get('stck_prpr', 0))
        if current_price == 0:
            return

        # Only sell rules exist: nothing held means nothing to do.
        if code not in self.holdings:
            return

        holding = self.holdings[code]
        holding_qty = holding['qty']

        # SELL logic: each trigger liquidates the entire position.
        if holding_qty > 0:
            # Stop loss.
            if setting.stop_loss_price and current_price <= setting.stop_loss_price:
                logger.info(f"Stop Loss Triggered for {code}. Price: {current_price}, SL: {setting.stop_loss_price}")
                self._place_order(db, code, 'sell', holding_qty, 0)  # 0 means Market Price
                return

            # Target profit.
            if setting.target_price and current_price >= setting.target_price:
                logger.info(f"Target Price Triggered for {code}. Price: {current_price}, TP: {setting.target_price}")
                self._place_order(db, code, 'sell', holding_qty, 0)
                return

    def _place_order(self, db: Session, code: str, type: str, qty: int, price: int):
        """Send an order to KIS, notify Telegram, and record it in `orders`.

        price == 0 is interpreted as a market order. On success the
        holdings cache is resynced after a short wait for execution.
        """
        logger.info(f"Placing Order: {code} {type} {qty} @ {price}")
        res = kis.place_order(code, type, qty, price)

        status = "FAILED"
        order_id = ""
        # rt_cd == '0' is KIS's success code.
        if res and res.get('rt_cd') == '0':
            status = "PENDING"
            order_id = res.get('output', {}).get('ODNO', '')
            logger.info(f"Order Success: {order_id}")
            notifier.send_message(f"🔔 주문 전송 완료\n[{type.upper()}] {code}\n수량: {qty}\n가격: {price if price > 0 else '시장가'}")

            # Give the broker a moment to execute, then resync holdings so
            # the UI and the sell logic see the new position.
            time.sleep(1)
            self.refresh_assets()

        else:
            logger.error(f"Order Failed: {res}")
            notifier.send_message(f"⚠️ 주문 실패\n[{type.upper()}] {code}\n이유: {res}")

        # Record the attempt (success or failure) to the orders table.
        new_order = Order(
            code=code,
            order_id=order_id,
            type=type.upper(),
            price=price,
            quantity=qty,
            status=status
        )
        db.add(new_order)
        db.commit()
|
||||
|
||||
# Module-level singleton; started/stopped by the application lifecycle.
trader = TradingBot()
|
||||
182
backend/websocket_manager.py
Normal file
182
backend/websocket_manager.py
Normal file
@@ -0,0 +1,182 @@
|
||||
|
||||
import asyncio
|
||||
import websockets
|
||||
import json
|
||||
import logging
|
||||
import datetime
|
||||
from typing import List, Set
|
||||
from database import SessionLocal, Stock, StockPrice
|
||||
from kis_api import kis
|
||||
|
||||
logger = logging.getLogger("WEBSOCKET")
|
||||
|
||||
class KisWebSocketManager:
    """Relays KIS real-time price data to connected frontend WebSockets.

    One background task (start_kis_socket) owns the upstream KIS
    connection with auto-reconnect; frontend clients register via
    connect_frontend and receive JSON price payloads.
    """

    def __init__(self):
        self.active_frontend_connections: List[any] = []
        self.subscribed_codes: Set[str] = set()
        self.running = False
        self.approval_key = None
        # Codes queued here are flushed by the socket loop as live
        # subscription requests while the upstream connection is open.
        self.msg_queue = asyncio.Queue()

        # Paper and live trading use different KIS WebSocket endpoints.
        if kis.is_paper:
            self.url = "ws://ops.koreainvestment.com:31000"
        else:
            self.url = "ws://ops.koreainvestment.com:21000"

    async def subscribe_stock(self, code):
        """Register *code* for real-time prices (idempotent).

        The code is remembered for reconnect sweeps AND queued so the
        running socket loop subscribes immediately. (The original file
        defined this method twice; the later no-op variant shadowed this
        queue-based one, so dynamic subscriptions only took effect after
        a reconnect — the duplicate is removed and the loop now drains
        the queue.)
        """
        if code in self.subscribed_codes:
            return
        self.subscribed_codes.add(code)
        await self.msg_queue.put(code)
        logger.info(f"Queued Subscription for {code}")

    async def connect_frontend(self, websocket):
        """Accept a frontend WebSocket and track it for broadcasts."""
        await websocket.accept()
        self.active_frontend_connections.append(websocket)
        logger.info(f"Frontend Client Connected. Total: {len(self.active_frontend_connections)}")

    def disconnect_frontend(self, websocket):
        """Forget a frontend WebSocket; safe to call for unknown sockets."""
        if websocket in self.active_frontend_connections:
            self.active_frontend_connections.remove(websocket)
            logger.info("Frontend Client Disconnected")

    async def broadcast_to_frontend(self, message: dict):
        """Send *message* to every frontend client, dropping dead ones."""
        # Iterate over a snapshot: disconnect_frontend mutates the list,
        # and removing while iterating would skip the next client.
        for connection in list(self.active_frontend_connections):
            try:
                await connection.send_json(message)
            except Exception as e:
                logger.error(f"Broadcast error: {e}")
                self.disconnect_frontend(connection)

    async def start_kis_socket(self):
        """Run the upstream KIS socket with auto-reconnect until stopped."""
        self.running = True
        logger.info(f"Starting KIS WebSocket Service... Target: {self.url}")

        while self.running:
            # 1. Ensure we hold a WebSocket approval key.
            if not self.approval_key:
                self.approval_key = kis.get_websocket_key()
                if not self.approval_key:
                    logger.error("Failed to get WebSocket Approval Key. Retrying in 10s...")
                    await asyncio.sleep(10)
                    continue
                logger.info(f"Got WS Key: {self.approval_key[:10]}...")

            # 2. Connect and pump messages.
            try:
                # KIS sends PINGPONG JSON messages instead of standard ping
                # frames, so disable the library's automatic ping.
                async with websockets.connect(self.url, ping_interval=None, open_timeout=20) as ws:
                    logger.info("Connected to KIS WebSocket Server")

                    # Discard queued codes: the sweep below re-subscribes
                    # the entire set, so queued entries would duplicate.
                    while not self.msg_queue.empty():
                        self.msg_queue.get_nowait()
                    for code in self.subscribed_codes:
                        await self._send_subscription(ws, code)

                    while self.running:
                        try:
                            # Short timeout keeps the loop responsive to
                            # stop requests and queued subscriptions.
                            msg = await asyncio.wait_for(ws.recv(), timeout=0.1)

                            # Frames starting with '0'/'1' carry realtime data.
                            if msg[0] in ['0', '1']:
                                await self._handle_realtime_data(msg)
                            else:
                                # JSON control message (system, PINGPONG).
                                try:
                                    data = json.loads(msg)
                                    if data.get('header', {}).get('tr_id') == 'PINGPONG':
                                        await ws.send(msg)  # Echo back
                                        continue
                                except:
                                    pass

                        except asyncio.TimeoutError:
                            pass

                        except websockets.ConnectionClosed:
                            logger.warning("KIS WS Closed. Reconnecting...")
                            break

                        # Flush dynamically queued subscriptions so new
                        # codes take effect without waiting for a reconnect.
                        while not self.msg_queue.empty():
                            await self._send_subscription(ws, self.msg_queue.get_nowait())

            except Exception as e:
                logger.error(f"WS Connection Error: {e}")
                # If auth failed (expired key?) we could clear approval_key
                # here to force a refresh; for now just back off and retry.
                await asyncio.sleep(5)

    async def _handle_realtime_data(self, msg: str):
        """Parse a pipe-delimited realtime frame and broadcast price ticks.

        Frame format: 0|TR_ID|DATA_CNT|field1^field2^...
        """
        try:
            parts = msg.split('|')
            if len(parts) < 4: return

            tr_id = parts[1]
            data_part = parts[3]

            if tr_id == "H0STCNT0":  # Domestic stock realtime price
                # Caret-delimited fields; offsets assumed from the KIS
                # H0STCNT0 layout: 0=code, 2=price, 4=change, 5=rate.
                values = data_part.split('^')
                code = values[0]
                price = values[2]
                change = values[4]
                rate = values[5]

                payload = {
                    "type": "PRICE",
                    "code": code,
                    "price": price,
                    "change": change,
                    "rate": rate,
                    "timestamp": datetime.datetime.now().isoformat()
                }
                await self.broadcast_to_frontend(payload)

                # DB persistence deliberately skipped: tick-rate writes
                # would hammer SQLite. Throttle if history is ever needed.

        except Exception as e:
            logger.error(f"Data Parse Error: {e} | Msg: {msg[:50]}")

    async def _send_subscription(self, ws, code):
        """Send an H0STCNT0 register request for *code* on the live socket."""
        body = {
            "header": {
                "approval_key": self.approval_key,
                "custtype": "P",
                "tr_type": "1",  # 1: Register, 2: Unregister
                "content-type": "utf-8"
            },
            "body": {
                "input": {
                    "tr_id": "H0STCNT0",
                    "tr_key": code
                }
            }
        }
        await ws.send(json.dumps(body))
        logger.info(f"Sent Subscription Request for {code}")
|
||||
|
||||
# Module-level singleton; start_kis_socket() is scheduled by the app startup.
ws_manager = KisWebSocketManager()
|
||||
Reference in New Issue
Block a user