Files
ai_stock/data_writer.py
2025-12-08 15:30:19 +08:00

493 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
CSV 数据落地模块(基于 DATA_SCHEMA.md 的简化实现)
- symbols.csv
- bars_1m.csv
- signals.csv
说明:
- 不做真正的 UpsertCSV 不擅长),通过读取现有行建立内存索引,避免重复写入关键键。
- 比率字段(如涨跌幅)采用小数存储,例如 4.02% 存 0.0402。
"""
import csv
import os
from datetime import datetime, timezone
from typing import Iterable, Dict, Any, List, Tuple
from utils_id import stable_symbol_id
# Directory that holds every CSV written by this module: a "data" folder
# located next to this source file.
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
# One path constant per CSV file, all directly under DATA_DIR.
SYMBOLS_CSV = os.path.join(DATA_DIR, "symbols.csv")
BARS_1M_CSV = os.path.join(DATA_DIR, "bars_1m.csv")
SIGNALS_CSV = os.path.join(DATA_DIR, "signals.csv")
FEATURES_1M_CSV = os.path.join(DATA_DIR, "features_1m.csv")
ETL_RUNS_CSV = os.path.join(DATA_DIR, "etl_runs.csv")
PREMARKET_BARS_CSV = os.path.join(DATA_DIR, "premarket_bars.csv")
PREMARKET_SIGNALS_CSV = os.path.join(DATA_DIR, "premarket_signals.csv")
# Make sure the data directory exists (executed once, at import time).
os.makedirs(DATA_DIR, exist_ok=True)
def _utc_now_iso() -> str:
    """Return the current UTC time formatted as ``YYYY-MM-DDTHH:MM:SSZ``."""
    current = datetime.now(timezone.utc)
    return current.strftime("%Y-%m-%dT%H:%M:%SZ")
def _floor_minute(dt: datetime) -> datetime:
    """Truncate *dt* to the start of its minute and return it as an aware UTC datetime.

    A naive datetime is assumed to already express UTC wall-clock time and is
    simply tagged with UTC. An aware datetime is *converted* to UTC first, so
    the returned value always denotes the same instant as the input (minus the
    dropped seconds/microseconds) instead of merely relabelling its time zone.
    """
    if dt.tzinfo is None or dt.utcoffset() is None:
        # Naive input: keep the clock reading and declare it to be UTC.
        as_utc = dt.replace(tzinfo=timezone.utc)
    else:
        # Aware input: shift the clock reading to UTC before truncating.
        as_utc = dt.astimezone(timezone.utc)
    return as_utc.replace(second=0, microsecond=0)
# ---------- symbols.csv ----------
# Column order of symbols.csv; passed to csv.DictWriter as the field names.
_SYMBOLS_HEADER = [
    "id",
    "symbol",
    "name",
    "exchange",
    "currency",
    "tick_size",
    "lot_size",
    "sector",
    "industry",
    "is_active",
    "first_seen_utc",
    "last_seen_utc",
]
def write_symbols(stocks: Iterable[Dict[str, Any]]) -> Dict[str, int]:
    """Merge basic stock info into symbols.csv and return a symbol -> symbol_id map.

    Rows already present in symbols.csv are loaded and indexed by
    ``(symbol, exchange)``. For each incoming stock a new row is created when
    its key is unknown; otherwise only ``last_seen_utc`` of the existing row is
    refreshed. The whole file is then rewritten.

    Args:
        stocks: Dicts carrying ``symbol`` and optionally ``name``, ``exchange``
            (defaults to ``"US"``) and ``currency`` (defaults to ``"USD"``).
            Entries without a symbol are skipped.

    Returns:
        Mapping from symbol to its integer id for every row in the file. When
        the same symbol exists on several exchanges the last row wins.
    """
    existing: Dict[Tuple[str, str], Dict[str, str]] = {}
    if os.path.exists(SYMBOLS_CSV):
        with open(SYMBOLS_CSV, "r", encoding="utf-8-sig") as f:
            for row in csv.DictReader(f):
                existing[(row["symbol"], row["exchange"])] = row

    now = _utc_now_iso()
    for s in stocks:
        symbol = s.get("symbol")
        if not symbol:
            # Without a symbol there is neither a usable key nor a stable id.
            continue
        exchange = (s.get("exchange") or "US").upper()
        currency = (s.get("currency") or "USD").upper()
        key = (symbol, exchange)
        record = existing.get(key)
        if record is None:
            existing[key] = {
                "id": str(stable_symbol_id(symbol, exchange)),
                "symbol": symbol,
                "name": s.get("name") or "",
                "exchange": exchange,
                "currency": currency,
                "tick_size": "",
                "lot_size": "",
                "sector": "",
                "industry": "",
                "is_active": "1",
                "first_seen_utc": now,
                "last_seen_utc": now,
            }
        else:
            record["last_seen_utc"] = now

    # Write to a sibling temp file and swap it in, so that a failure while
    # writing cannot leave symbols.csv truncated.
    tmp_path = SYMBOLS_CSV + ".tmp"
    with open(tmp_path, "w", newline="", encoding="utf-8-sig") as f:
        writer = csv.DictWriter(f, fieldnames=_SYMBOLS_HEADER)
        writer.writeheader()
        writer.writerows(existing.values())
    os.replace(tmp_path, SYMBOLS_CSV)

    return {sym: int(row["id"]) for (sym, _exchange), row in existing.items()}
# ---------- bars_1m.csv ----------
# Column order of bars_1m.csv; passed to csv.DictWriter as the field names.
_BARS_1M_HEADER = [
    "symbol_id",
    "symbol",
    "ts_utc",
    "open",
    "high",
    "low",
    "close",
    "volume",
    "vwap",
    "trades_count",
    "source",
    "session",
]
def _legacy_bar_cell(row: List[str], idx_map: Dict[str, int], column: str) -> str:
    """Return *row*'s value for *column*, or '' if the legacy file lacks it."""
    pos = idx_map.get(column)
    if pos is None or pos >= len(row):
        return ''
    return row[pos]


def _upgrade_bars_file_if_needed():
    """Rewrite a legacy bars_1m.csv that has no ``session`` column.

    If the file exists and its header lacks ``session``, every data row is
    re-emitted in the current column order with ``session='regular'``.
    Columns missing from the legacy file (or from a short row) become empty
    strings. Nothing happens when the file is absent, empty, or already
    upgraded. Failures are reported on stdout and otherwise ignored, since the
    upgrade is best-effort.
    """
    if not os.path.exists(BARS_1M_CSV):
        return
    try:
        with open(BARS_1M_CSV, 'r', encoding='utf-8-sig') as f:
            rows = list(csv.reader(f))
        if not rows:
            return
        old_header = rows[0]
        if 'session' in old_header:
            return  # already upgraded
        # Column name -> position in the legacy file.
        idx_map = {col: i for i, col in enumerate(old_header)}
        new_rows = [list(_BARS_1M_HEADER)]
        for r in rows[1:]:
            if not r:
                continue
            new_rows.append([
                'regular' if col == 'session' else _legacy_bar_cell(r, idx_map, col)
                for col in _BARS_1M_HEADER
            ])
        # Write to a temp file and swap it in so a failed write cannot
        # truncate the existing bars file.
        tmp_path = BARS_1M_CSV + '.tmp'
        with open(tmp_path, 'w', newline='', encoding='utf-8-sig') as f:
            csv.writer(f).writerows(new_rows)
        os.replace(tmp_path, BARS_1M_CSV)
    except Exception as e:
        print(f"⚠️ bars_1m.csv 升级失败: {e}")
def append_bars_1m(stocks: Iterable[Dict[str, Any]], symbol_id_map: Dict[str, int], source: str = "eastmoney") -> List[Dict[str, Any]]:
    """Append the current snapshot to bars_1m.csv as approximate 1-minute bars.

    Only a snapshot price is available, so open/high/low/close all carry the
    same price (``eastmoney_price`` if truthy, else ``current_price``) and
    volume/vwap/trades_count are left empty. The timestamp is the current UTC
    time floored to the minute and ``session`` is always ``"regular"``.

    Args:
        stocks: Snapshot dicts with ``symbol`` and a price field.
        symbol_id_map: symbol -> id; ``stable_symbol_id`` is the fallback.
        source: Value stored in the ``source`` column.

    Returns:
        The rows that were appended (also when the list is empty).
    """
    now = _floor_minute(datetime.now(timezone.utc)).strftime("%Y-%m-%dT%H:%M:%SZ")
    rows: List[Dict[str, Any]] = []
    _upgrade_bars_file_if_needed()
    for s in stocks:
        symbol = s.get("symbol")
        price = s.get("eastmoney_price") or s.get("current_price")
        if price in (None, ""):
            continue
        # Skip prices that are not numeric; such a bar would be unusable
        # downstream (feature computation parses ``close`` as float).
        try:
            float(price)
        except (TypeError, ValueError):
            continue
        sid = symbol_id_map.get(symbol) or stable_symbol_id(symbol)
        rows.append({
            "symbol_id": sid,
            "symbol": symbol,
            "ts_utc": now,
            "open": price,
            "high": price,
            "low": price,
            "close": price,
            "volume": "",
            "vwap": "",
            "trades_count": "",
            "source": source,
            "session": "regular",
        })
    # Append; the header is written only when the file is created here.
    file_exists = os.path.exists(BARS_1M_CSV)
    with open(BARS_1M_CSV, "a", newline="", encoding="utf-8-sig") as f:
        writer = csv.DictWriter(f, fieldnames=_BARS_1M_HEADER)
        if not file_exists:
            writer.writeheader()
        writer.writerows(rows)
    return rows
def append_bars_session(stocks: Iterable[Dict[str, Any]], symbol_id_map: Dict[str, int], source: str = "futu", session: str = "pre") -> List[Dict[str, Any]]:
    """Append snapshot bars for a specific trading session to bars_1m.csv.

    These rows live in the same file as the regular bars and are told apart by
    the ``session`` column. The price is the first truthy value among
    ``premarket_price``, ``after_hours_price`` and ``futu_before_open_price``;
    entries whose price is missing or not convertible to float are skipped.
    Returns the list of rows that were appended.
    """
    _upgrade_bars_file_if_needed()
    minute_stamp = _floor_minute(datetime.now(timezone.utc)).strftime("%Y-%m-%dT%H:%M:%SZ")
    collected: List[Dict[str, Any]] = []
    for stock in stocks:
        ticker = stock.get("symbol")
        raw_price = (
            stock.get("premarket_price")
            or stock.get("after_hours_price")
            or stock.get("futu_before_open_price")
        )
        if raw_price in (None, ""):
            continue
        try:
            value = float(raw_price)
        except Exception:
            continue
        ident = symbol_id_map.get(ticker) or stable_symbol_id(ticker)
        collected.append({
            "symbol_id": ident,
            "symbol": ticker,
            "ts_utc": minute_stamp,
            # A snapshot has a single price, so all four OHLC fields share it.
            **dict.fromkeys(("open", "high", "low", "close"), value),
            "volume": "",
            "vwap": "",
            "trades_count": "",
            "source": source,
            "session": session,
        })
    header_needed = not os.path.exists(BARS_1M_CSV)
    with open(BARS_1M_CSV, "a", newline="", encoding="utf-8-sig") as out:
        bar_writer = csv.DictWriter(out, fieldnames=_BARS_1M_HEADER)
        if header_needed:
            bar_writer.writeheader()
        bar_writer.writerows(collected)
    return collected
# ---------- premarket-only snapshots and signals ----------
# Column order of premarket_bars.csv.
_PREMARKET_BARS_HEADER = [
    'symbol_id',
    'symbol',
    'name',
    'ts_utc',
    'ts_et',
    'price',
    'change',
    'change_ratio',
    'volume',
    'source',
    'session',
    'raw_file',
]
# Column order of premarket_signals.csv.
_PREMARKET_SIGNALS_HEADER = [
    'id',
    'symbol_id',
    'symbol',
    'generated_at_utc',
    'generated_at_et',
    'signal_type',
    'direction',
    'score',
    'reason',
    'params_json',
    'model_name',
    'version',
    'expires_at_utc',
]
def _parse_premarket_ratio(ratio_raw: Any) -> float:
    """Normalise a raw change ratio to a decimal fraction (3.21% -> 0.0321).

    Accepted inputs: ``"3.21%"`` / ``"-3.21%"`` (explicit percent), a plain
    number with absolute value above 1 (taken as a percent value, e.g. 3.21),
    or a plain number with absolute value up to 1 (taken as already decimal,
    e.g. 0.0321). Missing or unparsable input yields 0.0.

    NOTE: a bare number cannot be classified with certainty; the ``> 1``
    threshold is a heuristic, so a percent value below 1 written without the
    ``%`` sign is read as a decimal fraction.
    """
    if ratio_raw in (None, ''):
        return 0.0
    txt = str(ratio_raw).strip()
    try:
        if txt.endswith('%'):
            return float(txt.replace('%', '')) / 100.0
        num = float(txt)
    except ValueError:
        return 0.0
    return num / 100.0 if abs(num) > 1 else num


def append_premarket_bars(rows: List[Dict[str, Any]], symbol_id_map: Dict[str, int], source: str = 'futu') -> None:
    """Append scraped pre-market rows to premarket_bars.csv.

    Args:
        rows: Dicts with ``symbol``, ``name``, ``premarket_price``,
            ``premarket_change`` and ``premarket_change_ratio`` (a percent
            string or a decimal). Rows without a symbol, or whose price is
            missing, ``'-'`` or not numeric, are skipped.
        symbol_id_map: symbol -> id; ``stable_symbol_id`` is the fallback.
        source: Value stored in the ``source`` column.

    Every written row carries the current UTC timestamp, the current
    America/New_York timestamp (empty if that zone cannot be resolved),
    ``session='pre'`` and empty ``volume`` / ``raw_file`` fields.
    """
    if not rows:
        return
    file_exists = os.path.exists(PREMARKET_BARS_CSV)
    now_utc = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
    # Eastern-time string, kept only to make manual inspection easier.
    try:
        from zoneinfo import ZoneInfo
        ts_et_full = datetime.now(ZoneInfo('America/New_York')).strftime('%Y-%m-%dT%H:%M:%S')
    except Exception:
        ts_et_full = ''
    with open(PREMARKET_BARS_CSV, 'a', newline='', encoding='utf-8-sig') as f:
        writer = csv.DictWriter(f, fieldnames=_PREMARKET_BARS_HEADER)
        if not file_exists:
            writer.writeheader()
        for r in rows:
            symbol = r.get('symbol')
            if not symbol:
                continue
            price = r.get('premarket_price')
            if price in (None, '', '-'):
                continue
            try:
                price_f = float(price)
            except (TypeError, ValueError):
                continue
            sid = symbol_id_map.get(symbol) or stable_symbol_id(symbol)
            writer.writerow({
                'symbol_id': sid,
                'symbol': symbol,
                'name': r.get('name', ''),
                'ts_utc': now_utc,
                'ts_et': ts_et_full,
                'price': price_f,
                'change': r.get('premarket_change', ''),
                'change_ratio': _parse_premarket_ratio(r.get('premarket_change_ratio')),
                'volume': '',
                'source': source,
                'session': 'pre',
                'raw_file': '',
            })
def append_premarket_signals(signals: List[Dict[str, Any]], symbol_id_map: Dict[str, int]) -> None:
    """Append pre-market signals to premarket_signals.csv.

    Args:
        signals: Dicts with ``symbol`` and ``direction`` (both required; the
            entry is skipped otherwise) plus optional ``signal_type``
            (default ``'premarket_alert'``), ``score``, ``reason`` and
            ``params`` (an object serialised into the ``params_json`` column).
        symbol_id_map: symbol -> id; ``stable_symbol_id`` is the fallback.

    De-duplication key is (symbol, direction, generated_at_utc), checked
    against the rows already in the file and against rows written earlier in
    the same call. All rows of one call share the same timestamp.
    """
    if not signals:
        return
    file_exists = os.path.exists(PREMARKET_SIGNALS_CSV)
    model_name, version = _def_model
    now_utc = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
    try:
        from zoneinfo import ZoneInfo
        now_et = datetime.now(ZoneInfo('America/New_York')).strftime('%Y-%m-%dT%H:%M:%S')
    except Exception:
        now_et = ''
    # Keys already present in the file.
    seen = set()
    if file_exists:
        with open(PREMARKET_SIGNALS_CSV, 'r', encoding='utf-8-sig') as f:
            for row in csv.DictReader(f):
                seen.add((row['symbol'], row['direction'], row['generated_at_utc']))
    with open(PREMARKET_SIGNALS_CSV, 'a', newline='', encoding='utf-8-sig') as f:
        writer = csv.DictWriter(f, fieldnames=_PREMARKET_SIGNALS_HEADER)
        if not file_exists:
            writer.writeheader()
        for sig in signals:
            symbol = sig.get('symbol')
            direction = sig.get('direction')
            if not symbol or not direction:
                continue
            key = (symbol, direction, now_utc)
            if key in seen:
                continue
            # Remember the key so a repeated signal later in this same batch
            # is not written a second time.
            seen.add(key)
            sid = symbol_id_map.get(symbol) or stable_symbol_id(symbol)
            params_obj = sig.get('params') or {}
            writer.writerow({
                'id': f'{sid}-{now_utc}',
                'symbol_id': sid,
                'symbol': symbol,
                'generated_at_utc': now_utc,
                'generated_at_et': now_et,
                'signal_type': sig.get('signal_type', 'premarket_alert'),
                'direction': direction,
                'score': sig.get('score', ''),
                'reason': sig.get('reason', ''),
                'params_json': json.dumps(params_obj, ensure_ascii=False),
                'model_name': model_name,
                'version': version,
                'expires_at_utc': '',
            })
# ---------- signals.csv ----------
# Column order of signals.csv; passed to csv.DictWriter as the field names.
_SIGNALS_HEADER = [
    "id",
    "symbol_id",
    "symbol",
    "generated_at_utc",
    "signal_type",
    "direction",
    "score",
    "horizon",
    "params_json",
    "model_name",
    "version",
    "expires_at_utc",
]
# (model_name, version) pair stamped on every signal row.
_def_model = (
    "rule_threshold",
    "v1",
)
import json
def append_signals(signals: Iterable[Dict[str, Any]], symbol_id_map: Dict[str, int]) -> None:
    """Append strategy signals to signals.csv with approximate de-duplication.

    Args:
        signals: Dicts with ``symbol`` (required; entries without it are
            skipped), the direction under ``type`` or ``direction``
            (BUY/SELL), and optional ``generated_at_utc`` (defaults to the
            current UTC time), ``confidence`` (stored as ``score``) and
            ``reason`` (stored inside ``params_json``).
        symbol_id_map: symbol -> id; ``stable_symbol_id`` is the fallback.

    De-duplication key is (symbol, generated_at_utc, direction), checked
    against rows already in the file and against rows written earlier in the
    same call. ``signal_type`` is always ``"momentum"`` and ``horizon`` is
    always ``"intraday"``.
    """
    file_exists = os.path.exists(SIGNALS_CSV)
    seen_keys = set()
    if file_exists:
        with open(SIGNALS_CSV, "r", encoding="utf-8-sig") as f:
            for row in csv.DictReader(f):
                seen_keys.add((row["symbol"], row["generated_at_utc"], row.get("direction")))
    model_name, version = _def_model
    with open(SIGNALS_CSV, "a", newline="", encoding="utf-8-sig") as f:
        writer = csv.DictWriter(f, fieldnames=_SIGNALS_HEADER)
        if not file_exists:
            writer.writeheader()
        for sig in signals:
            symbol = sig.get("symbol")
            if not symbol:
                # No symbol means no id and no usable de-duplication key.
                continue
            direction = sig.get("type") or sig.get("direction")
            gen_at = sig.get('generated_at_utc') or _utc_now_iso()
            key = (symbol, gen_at, direction)
            if key in seen_keys:
                continue
            # Remember the key so a repeated signal later in this same batch
            # is not written a second time.
            seen_keys.add(key)
            sid = symbol_id_map.get(symbol) or stable_symbol_id(symbol)
            writer.writerow({
                "id": f"{sid}-{gen_at}",
                "symbol_id": sid,
                "symbol": symbol,
                "generated_at_utc": gen_at,
                "signal_type": "momentum",
                "direction": direction,
                "score": sig.get("confidence", ""),
                "horizon": "intraday",
                "params_json": json.dumps({"reason": sig.get("reason", "")}, ensure_ascii=False),
                "model_name": model_name,
                "version": version,
                "expires_at_utc": "",
            })
# ---------- features_1m.csv ----------
# Column order of features_1m.csv; passed to csv.DictWriter as the field names.
_FEATURES_1M_HEADER = [
    'symbol_id',
    'symbol',
    'ts_utc',
    'price',
    'return_1m',
    'ma_5',
    'ma_15',
    'vol_15',
]
def _load_existing_prices() -> Dict[str, List[Tuple[str, float]]]:
    """Read bars_1m.csv into a per-symbol list of ``(ts_utc, close)`` pairs.

    Rows whose ``close`` cannot be parsed as a float are ignored. Each list is
    sorted by timestamp string. An empty dict is returned when the bars file
    does not exist.
    """
    history: Dict[str, List[Tuple[str, float]]] = {}
    if os.path.exists(BARS_1M_CSV):
        with open(BARS_1M_CSV, 'r', encoding='utf-8-sig') as bars_file:
            for record in csv.DictReader(bars_file):
                ticker = record['symbol']
                stamp = record['ts_utc']
                try:
                    close_px = float(record['close'])
                except Exception:
                    continue
                history.setdefault(ticker, []).append((stamp, close_px))
        # Appended rows should already be chronological; sort defensively.
        for points in history.values():
            points.sort(key=lambda point: point[0])
    return history
def append_features_1m(new_bar_rows: List[Dict[str, Any]]) -> None:
    """Compute per-bar features for freshly written bars and append them to features_1m.csv.

    For each row in *new_bar_rows* (dicts with ``symbol``, ``symbol_id``,
    ``ts_utc`` and ``close``) the price history of that symbol is taken from
    bars_1m.csv and the following values are derived at the latest point:

    - ``return_1m``: ``price / previous_price - 1`` (0.0 without a previous
      point or when the previous price is 0).
    - ``ma_5`` / ``ma_15``: mean of up to the last 5 / 15 prices, current
      price included.
    - ``vol_15``: sample standard deviation (n - 1) over the ``ma_15`` window,
      0.0 when the window has fewer than two prices.

    Rows whose ``close`` is not numeric are skipped. Nothing is done for an
    empty input list.
    """
    if not new_bar_rows:
        return
    price_history = _load_existing_prices()
    feature_rows: List[Dict[str, Any]] = []
    for r in new_bar_rows:
        symbol = r['symbol']
        sid = r['symbol_id']
        ts = r['ts_utc']
        try:
            price = float(r['close'])
        except (TypeError, ValueError):
            continue
        # setdefault (not get) so that a series started for a symbol unknown
        # to the file is kept and seen by later rows of the same batch.
        series = price_history.setdefault(symbol, [])
        # The bar is normally already in the file; add it when it is not.
        if not series or series[-1][0] != ts:
            series.append((ts, price))
        idx = len(series) - 1
        # return_1m
        ret_1m = 0.0
        if idx >= 1:
            prev_price = series[idx - 1][1]
            if prev_price != 0:
                ret_1m = (price / prev_price) - 1
        # ma_5
        window5 = [p for _, p in series[max(0, idx - 4):idx + 1]]
        ma_5 = sum(window5) / len(window5) if window5 else price
        # ma_15
        window15 = [p for _, p in series[max(0, idx - 14):idx + 1]]
        ma_15 = sum(window15) / len(window15) if window15 else price
        # vol_15: sample standard deviation of the 15-point window
        vol_15 = 0.0
        if len(window15) > 1:
            var = sum((p - ma_15) ** 2 for p in window15) / (len(window15) - 1)
            vol_15 = var ** 0.5
        feature_rows.append({
            'symbol_id': sid,
            'symbol': symbol,
            'ts_utc': ts,
            'price': price,
            'return_1m': ret_1m,
            'ma_5': ma_5,
            'ma_15': ma_15,
            'vol_15': vol_15,
        })
    file_exists = os.path.exists(FEATURES_1M_CSV)
    with open(FEATURES_1M_CSV, 'a', newline='', encoding='utf-8-sig') as f:
        writer = csv.DictWriter(f, fieldnames=_FEATURES_1M_HEADER)
        if not file_exists:
            writer.writeheader()
        writer.writerows(feature_rows)
# ---------- etl_runs.csv ----------
# Column order of etl_runs.csv; passed to csv.DictWriter as the field names.
_ETL_RUNS_HEADER = [
    'run_ts_utc',
    'loop',
    'fetched_count',
    'signal_count',
    'duration_seconds',
    'errors',
]
def append_etl_run(loop: int, fetched: int, signals: int, duration: float, errors: int = 0) -> None:
    """Append one run record to etl_runs.csv.

    The row stores the current UTC timestamp, the loop number, the fetched and
    signal counts, the duration formatted with three decimals, and the error
    count. The header is written only when the file is created by this call.
    """
    header_needed = not os.path.exists(ETL_RUNS_CSV)
    with open(ETL_RUNS_CSV, 'a', newline='', encoding='utf-8-sig') as log_file:
        run_writer = csv.DictWriter(log_file, fieldnames=_ETL_RUNS_HEADER)
        if header_needed:
            run_writer.writeheader()
        record = {
            'run_ts_utc': _utc_now_iso(),
            'loop': loop,
            'fetched_count': fetched,
            'signal_count': signals,
            'duration_seconds': f'{duration:.3f}',
            'errors': errors,
        }
        run_writer.writerow(record)