init commit

This commit is contained in:
徐微
2025-12-08 15:30:19 +08:00
commit 09193a2288
39 changed files with 16688 additions and 0 deletions

492
data_writer.py Normal file
View File

@@ -0,0 +1,492 @@
# -*- coding: utf-8 -*-
"""
CSV persistence layer (simplified implementation of DATA_SCHEMA.md)
- symbols.csv
- bars_1m.csv
- signals.csv
Notes:
- No true upsert is performed (CSV is ill-suited for it); existing rows are
  read into an in-memory index so rows with duplicate keys are not rewritten.
- Ratio fields (e.g. percent change) are stored as decimals: 4.02% -> 0.0402.
"""
import csv
import os
from datetime import datetime, timezone
from typing import Iterable, Dict, Any, List, Tuple
from utils_id import stable_symbol_id
# All CSV outputs live in a ./data directory next to this module.
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
SYMBOLS_CSV = os.path.join(DATA_DIR, "symbols.csv")
BARS_1M_CSV = os.path.join(DATA_DIR, "bars_1m.csv")
SIGNALS_CSV = os.path.join(DATA_DIR, "signals.csv")
FEATURES_1M_CSV = os.path.join(DATA_DIR, "features_1m.csv")
ETL_RUNS_CSV = os.path.join(DATA_DIR, "etl_runs.csv")
PREMARKET_BARS_CSV = os.path.join(DATA_DIR, "premarket_bars.csv")
PREMARKET_SIGNALS_CSV = os.path.join(DATA_DIR, "premarket_signals.csv")
# Ensure the data directory exists (import-time side effect).
os.makedirs(DATA_DIR, exist_ok=True)
def _utc_now_iso() -> str:
return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
def _floor_minute(dt: datetime) -> datetime:
return dt.replace(second=0, microsecond=0, tzinfo=timezone.utc)
# ---------- symbols.csv ----------
# Column order for symbols.csv; unknown numeric fields are stored as "".
_SYMBOLS_HEADER = [
    "id","symbol","name","exchange","currency",
    "tick_size","lot_size","sector","industry",
    "is_active","first_seen_utc","last_seen_utc"
]
def write_symbols(stocks: Iterable[Dict[str, Any]]) -> Dict[str, int]:
    """Upsert basic stock info into symbols.csv; return a symbol -> id map.

    Each item in *stocks* should carry keys: symbol, name, exchange, currency.
    New (symbol, exchange) pairs get a stable id and first_seen_utc; existing
    rows only have their last_seen_utc refreshed. Items without a symbol are
    skipped (previously they produced bogus rows keyed on None/"").
    """
    existing: Dict[Tuple[str, str], Dict[str, str]] = {}
    if os.path.exists(SYMBOLS_CSV):
        with open(SYMBOLS_CSV, "r", encoding="utf-8-sig") as f:
            for row in csv.DictReader(f):
                existing[(row["symbol"], row["exchange"])] = row
    now = _utc_now_iso()
    # Create new rows / refresh existing ones.
    for s in stocks:
        symbol = s.get("symbol")
        if not symbol:
            continue  # guard: a blank symbol would poison the key space
        name = s.get("name")
        exchange = (s.get("exchange") or "US").upper()
        currency = (s.get("currency") or "USD").upper()
        key = (symbol, exchange)
        if key not in existing:
            sid = stable_symbol_id(symbol, exchange)
            existing[key] = {
                "id": str(sid),
                "symbol": symbol,
                "name": name or "",
                "exchange": exchange,
                "currency": currency,
                "tick_size": "",
                "lot_size": "",
                "sector": "",
                "industry": "",
                "is_active": "1",
                "first_seen_utc": now,
                "last_seen_utc": now,
            }
        else:
            existing[key]["last_seen_utc"] = now
    # Rewrite the whole file so it stays a single-header, de-duplicated CSV.
    with open(SYMBOLS_CSV, "w", newline="", encoding="utf-8-sig") as f:
        writer = csv.DictWriter(f, fieldnames=_SYMBOLS_HEADER)
        writer.writeheader()
        writer.writerows(existing.values())
    # NOTE: if one symbol is listed on multiple exchanges, the last row wins.
    # (The old `if k[0] == v["symbol"]` filter was always true and is dropped.)
    return {k[0]: int(v["id"]) for k, v in existing.items()}
# ---------- bars_1m.csv ----------
# Column order for bars_1m.csv; 'session' distinguishes regular/pre/after rows.
_BARS_1M_HEADER = [
    "symbol_id","symbol","ts_utc","open","high","low","close",
    "volume","vwap","trades_count","source","session"
]
def _upgrade_bars_file_if_needed():
    """One-time migration: add a 'session' column (default 'regular') to a
    legacy bars_1m.csv that lacks it.

    No-op when the file is absent, empty, or already has the column. Any
    failure is reported and leaves the original file untouched in memory terms
    (best effort, matching the original behavior).
    """
    if not os.path.exists(BARS_1M_CSV):
        return
    try:
        with open(BARS_1M_CSV, 'r', encoding='utf-8-sig') as f:
            reader = csv.DictReader(f)
            if reader.fieldnames is None or 'session' in reader.fieldnames:
                return  # empty file or already upgraded
            old_rows = list(reader)
        # Rewrite with the new header; missing legacy columns become ''.
        # (The old index-based lookup used idx_map.get(col, '') and then
        # indexed the row with '' — a TypeError on any absent column that
        # the broad except silently swallowed, so the upgrade never ran.)
        with open(BARS_1M_CSV, 'w', newline='', encoding='utf-8-sig') as f:
            writer = csv.DictWriter(f, fieldnames=_BARS_1M_HEADER)
            writer.writeheader()
            for r in old_rows:
                new_row = {col: r.get(col, '') for col in _BARS_1M_HEADER}
                new_row['session'] = 'regular'
                writer.writerow(new_row)
    except Exception as e:
        print(f"⚠️ bars_1m.csv 升级失败: {e}")
def append_bars_1m(stocks: Iterable[Dict[str, Any]], symbol_id_map: Dict[str, int], source: str = "eastmoney") -> List[Dict[str, Any]]:
    """Approximate the current snapshot as 1-minute bars in bars_1m.csv.

    Only a snapshot price is available, so open/high/low/close all use the
    current price and volume/vwap/trades_count stay empty.
    Returns the rows that were written (consumed by feature computation).
    """
    now = _floor_minute(datetime.now(timezone.utc)).strftime("%Y-%m-%dT%H:%M:%SZ")
    rows: List[Dict[str, Any]] = []
    _upgrade_bars_file_if_needed()
    for s in stocks:
        symbol = s.get("symbol")
        price = s.get("eastmoney_price") or s.get("current_price")
        if price in (None, ""):
            continue
        try:
            # Coerce to float for consistency with append_bars_session;
            # previously a non-numeric string was written through raw.
            price_f = float(price)
        except (TypeError, ValueError):
            continue
        sid = symbol_id_map.get(symbol) or stable_symbol_id(symbol)
        rows.append({
            "symbol_id": sid,
            "symbol": symbol,
            "ts_utc": now,
            "open": price_f,
            "high": price_f,
            "low": price_f,
            "close": price_f,
            "volume": "",
            "vwap": "",
            "trades_count": "",
            "source": source,
            "session": "regular",
        })
    # Append-only write; the header is emitted once when the file is created.
    file_exists = os.path.exists(BARS_1M_CSV)
    with open(BARS_1M_CSV, "a", newline="", encoding="utf-8-sig") as f:
        writer = csv.DictWriter(f, fieldnames=_BARS_1M_HEADER)
        if not file_exists:
            writer.writeheader()
        for r in rows:
            writer.writerow(r)
    return rows
def append_bars_session(stocks: Iterable[Dict[str, Any]], symbol_id_map: Dict[str, int], source: str = "futu", session: str = "pre") -> List[Dict[str, Any]]:
    """Append snapshot bars for a specific trading session (e.g. pre/after).

    Rows share bars_1m.csv with regular bars; the 'session' column tells them
    apart. Returns the list of row dicts that were written.
    """
    _upgrade_bars_file_if_needed()
    minute_ts = _floor_minute(datetime.now(timezone.utc)).strftime("%Y-%m-%dT%H:%M:%SZ")
    written: List[Dict[str, Any]] = []
    for snap in stocks:
        ticker = snap.get("symbol")
        raw_price = (
            snap.get("premarket_price")
            or snap.get("after_hours_price")
            or snap.get("futu_before_open_price")
        )
        if raw_price in (None, ""):
            continue
        try:
            px = float(raw_price)
        except Exception:
            continue  # skip unparsable prices rather than failing the batch
        row_id = symbol_id_map.get(ticker) or stable_symbol_id(ticker)
        written.append({
            "symbol_id": row_id,
            "symbol": ticker,
            "ts_utc": minute_ts,
            "open": px,
            "high": px,
            "low": px,
            "close": px,
            "volume": "",
            "vwap": "",
            "trades_count": "",
            "source": source,
            "session": session,
        })
    had_file = os.path.exists(BARS_1M_CSV)
    with open(BARS_1M_CSV, "a", newline="", encoding="utf-8-sig") as f:
        out = csv.DictWriter(f, fieldnames=_BARS_1M_HEADER)
        if not had_file:
            out.writeheader()
        out.writerows(written)
    return written
# ---------- premarket-specific snapshots and signals ----------
# Column order for premarket_bars.csv (ts_et is a human-readable ET stamp).
_PREMARKET_BARS_HEADER = [
    'symbol_id','symbol','name','ts_utc','ts_et','price','change','change_ratio','volume','source','session','raw_file'
]
# Column order for premarket_signals.csv.
_PREMARKET_SIGNALS_HEADER = [
    'id','symbol_id','symbol','generated_at_utc','generated_at_et','signal_type','direction','score','reason','params_json','model_name','version','expires_at_utc'
]
def append_premarket_bars(rows: List[Dict[str, Any]], symbol_id_map: Dict[str, int], source: str = 'futu') -> None:
    """Append premarket capture rows to premarket_bars.csv.

    Each row should carry: symbol, name, premarket_price, premarket_change,
    premarket_change_ratio (raw percent string or decimal), ts (ET "HH:MM").
    Ratios are normalized to decimals (3.21% -> 0.0321) before writing.
    """
    if not rows:
        return
    file_exists = os.path.exists(PREMARKET_BARS_CSV)
    now_utc = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
    # Human-readable Eastern-time stamp (best effort; '' if zoneinfo fails).
    try:
        from zoneinfo import ZoneInfo
        ts_et_full = datetime.now(ZoneInfo('America/New_York')).strftime('%Y-%m-%dT%H:%M:%S')
    except Exception:
        ts_et_full = ''
    with open(PREMARKET_BARS_CSV, 'a', newline='', encoding='utf-8-sig') as f:
        writer = csv.DictWriter(f, fieldnames=_PREMARKET_BARS_HEADER)
        if not file_exists:
            writer.writeheader()
        for r in rows:
            symbol = r.get('symbol')
            if not symbol:
                continue
            price = r.get('premarket_price')
            if price in (None,'','-'):
                continue
            try:
                price_f = float(price)
            except Exception:
                continue
            # Ratio may arrive as "3.21%" / "-3.21%" / "0.0321" / "3.21" / "".
            ratio_raw = r.get('premarket_change_ratio')
            ratio_val = 0.0
            if ratio_raw not in (None,''):
                txt = str(ratio_raw).strip()
                try:
                    if txt.endswith('%'):
                        ratio_val = float(txt.replace('%',''))/100.0
                    else:
                        num = float(txt)
                        # Heuristic: |num| > 1 cannot be a sane decimal ratio,
                        # so treat it as a bare percent value. The previous
                        # check (`abs(num) > 1 and abs(num) >= 2`) effectively
                        # required |num| >= 2, misreading e.g. 1.5 as +150%.
                        ratio_val = num/100.0 if abs(num) > 1 else num
                except Exception:
                    ratio_val = 0.0
            sid = symbol_id_map.get(symbol) or stable_symbol_id(symbol)
            writer.writerow({
                'symbol_id': sid,
                'symbol': symbol,
                'name': r.get('name',''),
                'ts_utc': now_utc,
                'ts_et': ts_et_full,
                'price': price_f,
                'change': r.get('premarket_change',''),
                'change_ratio': ratio_val,
                'volume': '',
                'source': source,
                'session': 'pre',
                'raw_file': '',
            })
def append_premarket_signals(signals: List[Dict[str, Any]], symbol_id_map: Dict[str, int]) -> None:
    """Append premarket signals to premarket_signals.csv.

    Each signal should carry: symbol, direction (BUY/SELL), reason, and
    optionally params (serialized into params_json).
    De-dup rule: the same (symbol, direction) is written at most once per
    UTC second — checked against the existing file AND the current batch.
    """
    if not signals:
        return
    file_exists = os.path.exists(PREMARKET_SIGNALS_CSV)
    model_name, version = _def_model
    now_utc = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
    try:
        from zoneinfo import ZoneInfo
        now_et = datetime.now(ZoneInfo('America/New_York')).strftime('%Y-%m-%dT%H:%M:%S')
    except Exception:
        now_et = ''
    # Seed the de-dup set from rows already on disk.
    seen = set()
    if file_exists:
        with open(PREMARKET_SIGNALS_CSV,'r',encoding='utf-8-sig') as f:
            reader = csv.DictReader(f)
            for row in reader:
                seen.add((row['symbol'],row['direction'],row['generated_at_utc']))
    with open(PREMARKET_SIGNALS_CSV,'a',newline='',encoding='utf-8-sig') as f:
        writer = csv.DictWriter(f, fieldnames=_PREMARKET_SIGNALS_HEADER)
        if not file_exists:
            writer.writeheader()
        for sig in signals:
            symbol = sig.get('symbol')
            direction = sig.get('direction')
            if not symbol or not direction:
                continue
            key = (symbol,direction,now_utc)
            if key in seen:
                continue
            # Record the key so duplicates inside this batch are skipped too
            # (previously only rows already on disk were considered, so the
            # same signal appearing twice in `signals` was written twice).
            seen.add(key)
            sid = symbol_id_map.get(symbol) or stable_symbol_id(symbol)
            params_obj = sig.get('params') or {}
            writer.writerow({
                'id': f'{sid}-{now_utc}',
                'symbol_id': sid,
                'symbol': symbol,
                'generated_at_utc': now_utc,
                'generated_at_et': now_et,
                'signal_type': sig.get('signal_type','premarket_alert'),
                'direction': direction,
                'score': sig.get('score',''),
                'reason': sig.get('reason',''),
                'params_json': json.dumps(params_obj, ensure_ascii=False),
                'model_name': model_name,
                'version': version,
                'expires_at_utc': '',
            })
# ---------- signals.csv ----------
# Column order for signals.csv.
_SIGNALS_HEADER = [
    "id","symbol_id","symbol","generated_at_utc",
    "signal_type","direction","score","horizon",
    "params_json","model_name","version","expires_at_utc"
]
# Default (model_name, version) stamped onto every generated signal row.
_def_model = ("rule_threshold", "v1")
# NOTE(review): mid-file import; consider moving it to the top-of-file group.
import json
def append_signals(signals: Iterable[Dict[str, Any]], symbol_id_map: Dict[str, int]) -> None:
    """Append strategy signals to signals.csv with approximate de-duplication.

    Each signal should carry: symbol, type (BUY/SELL); reason/score optional.
    De-dup key is (symbol, generated_at_utc, direction), checked against both
    the existing file and the rows written earlier in this call.
    """
    file_exists = os.path.exists(SIGNALS_CSV)
    seen_keys = set()
    if file_exists:
        with open(SIGNALS_CSV, "r", encoding="utf-8-sig") as f:
            reader = csv.DictReader(f)
            for row in reader:
                seen_keys.add((row["symbol"], row["generated_at_utc"], row.get("direction")))
    model_name, version = _def_model
    with open(SIGNALS_CSV, "a", newline="", encoding="utf-8-sig") as f:
        writer = csv.DictWriter(f, fieldnames=_SIGNALS_HEADER)
        if not file_exists:
            writer.writeheader()
        for sig in signals:
            symbol = sig.get("symbol")
            direction = sig.get("type") or sig.get("direction")
            gen_at = sig.get('generated_at_utc') or _utc_now_iso()
            key = (symbol, gen_at, direction)
            if key in seen_keys:
                continue
            # Record the key so duplicates within this batch are skipped too
            # (previously only rows already on disk were considered).
            seen_keys.add(key)
            sid = symbol_id_map.get(symbol) or stable_symbol_id(symbol)
            writer.writerow({
                "id": f"{sid}-{gen_at}",
                "symbol_id": sid,
                "symbol": symbol,
                "generated_at_utc": gen_at,
                "signal_type": "momentum",
                "direction": direction,
                "score": sig.get("confidence", ""),
                "horizon": "intraday",
                "params_json": json.dumps({"reason": sig.get("reason", "")}, ensure_ascii=False),
                "model_name": model_name,
                "version": version,
                "expires_at_utc": "",
            })
# ---------- features_1m.csv ----------
# Column order for features_1m.csv (simple rolling features over close prices).
_FEATURES_1M_HEADER = [
    'symbol_id','symbol','ts_utc','price','return_1m','ma_5','ma_15','vol_15'
]
def _load_existing_prices() -> Dict[str, List[Tuple[str, float]]]:
    """Load (ts_utc, close) pairs per symbol from bars_1m.csv, time-sorted.

    Returns an empty dict when the file does not exist; rows with an
    unparsable close price are skipped.
    """
    history: Dict[str, List[Tuple[str, float]]] = {}
    if not os.path.exists(BARS_1M_CSV):
        return history
    with open(BARS_1M_CSV, 'r', encoding='utf-8-sig') as f:
        for record in csv.DictReader(f):
            sym = record['symbol']
            stamp = record['ts_utc']
            try:
                close_px = float(record['close'])
            except Exception:
                continue
            history.setdefault(sym, []).append((stamp, close_px))
    # Appends should already be chronological; sort defensively anyway.
    for pairs in history.values():
        pairs.sort(key=lambda item: item[0])
    return history
def append_features_1m(new_bar_rows: List[Dict[str, Any]]) -> None:
    """Compute simple per-minute features for freshly written bars and append
    them to features_1m.csv (return_1m, ma_5, ma_15, vol_15).
    """
    if not new_bar_rows:
        return
    history = _load_existing_prices()
    out_rows: List[Dict[str, Any]] = []
    for bar in new_bar_rows:
        sym = bar['symbol']
        sid = bar['symbol_id']
        stamp = bar['ts_utc']
        try:
            px = float(bar['close'])
        except Exception:
            continue
        series = history.get(sym, [])
        # The bar was appended to the CSV before this call, so it is usually
        # already the last element; add it defensively if it is missing.
        if not series or series[-1][0] != stamp:
            series.append((stamp, px))
        last = len(series) - 1
        # 1-minute simple return vs. the previous bar (0.0 when unavailable).
        ret = 0.0
        if last >= 1 and series[last - 1][1] != 0:
            ret = (px / series[last - 1][1]) - 1
        prices5 = [p for _, p in series[max(0, last - 4):last + 1]]
        prices15 = [p for _, p in series[max(0, last - 14):last + 1]]
        avg5 = sum(prices5) / len(prices5) if prices5 else px
        avg15 = sum(prices15) / len(prices15) if prices15 else px
        # Sample standard deviation over the 15-bar window.
        sigma15 = 0.0
        n15 = len(prices15)
        if n15 > 1:
            sigma15 = (sum((p - avg15) ** 2 for p in prices15) / (n15 - 1)) ** 0.5
        out_rows.append({
            'symbol_id': sid,
            'symbol': sym,
            'ts_utc': stamp,
            'price': px,
            'return_1m': ret,
            'ma_5': avg5,
            'ma_15': avg15,
            'vol_15': sigma15,
        })
    had_file = os.path.exists(FEATURES_1M_CSV)
    with open(FEATURES_1M_CSV, 'a', newline='', encoding='utf-8-sig') as f:
        writer = csv.DictWriter(f, fieldnames=_FEATURES_1M_HEADER)
        if not had_file:
            writer.writeheader()
        writer.writerows(out_rows)
# ---------- etl_runs.csv ----------
# Column order for etl_runs.csv (one row per ETL loop iteration).
_ETL_RUNS_HEADER = [
    'run_ts_utc','loop','fetched_count','signal_count','duration_seconds','errors'
]
def append_etl_run(loop: int, fetched: int, signals: int, duration: float, errors: int = 0) -> None:
    """Append one ETL run record (counters + wall-clock duration) to etl_runs.csv."""
    record = {
        'run_ts_utc': _utc_now_iso(),
        'loop': loop,
        'fetched_count': fetched,
        'signal_count': signals,
        'duration_seconds': f'{duration:.3f}',
        'errors': errors,
    }
    new_file = not os.path.exists(ETL_RUNS_CSV)
    with open(ETL_RUNS_CSV, 'a', newline='', encoding='utf-8-sig') as f:
        writer = csv.DictWriter(f, fieldnames=_ETL_RUNS_HEADER)
        if new_file:
            writer.writeheader()
        writer.writerow(record)