import os
import json
import time
import random
import re
import platform
import warnings
from collections import defaultdict
from datetime import date, datetime, timedelta
from io import StringIO

import gradio as gr
import numpy as np
import pandas as pd
import finnhub
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from huggingface_hub import hf_hub_download, list_repo_files
from llama_cpp import Llama

# Suppress gRPC (Google Cloud) log noise
os.environ['GRPC_VERBOSITY'] = 'ERROR'
os.environ['GRPC_TRACE'] = ''
# Suppress other warnings
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=FutureWarning)
# ---------- CONFIGURATION ----------------------------------------------------
# Local GGUF model config (CPU-only HF Spaces, ~16 GB RAM)
# Default to a non-Qwen, LLaMA-architecture model to ensure llama.cpp compatibility
GGUF_REPO = os.getenv("GGUF_REPO", "QuantFactory/Meta-Llama-3.1-8B-Instruct-GGUF")
# Default to a lighter quant to reduce RAM (can be overridden via env)
GGUF_FILENAME = os.getenv("GGUF_FILENAME", "Meta-Llama-3.1-8B-Instruct.Q4_K_S.gguf")
N_CTX = int(os.getenv("LLAMA_N_CTX", "2048"))
N_THREADS = int(os.getenv("LLAMA_N_THREADS", str(os.cpu_count() or 4)))
N_BATCH = int(os.getenv("LLAMA_N_BATCH", "128"))
LLM_TEMPERATURE = float(os.getenv("LLAMA_TEMPERATURE", "0.2"))
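# Rough RAM budget (approximate figures, not measured): the Q4_K_S quant of an
# 8B model is ~4.7 GB of weights, plus the KV cache for the 2048-token context,
# which leaves comfortable headroom on a 16 GB CPU Space.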
# KV-cache quantization override (applied at Llama() init below)
LLAMA_KV_TYPE_K = os.getenv("LLAMA_KV_TYPE_K", "q5_0")
LLAMA_KV_TYPE_V = os.getenv("LLAMA_KV_TYPE_V", "q4_0")
# Optional: use a pre-mounted local GGUF path to avoid any downloads
GGUF_LOCAL_PATH = os.getenv("GGUF_LOCAL_PATH", "").strip() or None
# Optional: alternate non-Qwen repo fallback (e.g. a repo that contains LLaMA-arch GGUFs)
GGUF_REPO_ALT = os.getenv("GGUF_REPO_ALT", "").strip() or None
# RapidAPI configuration
RAPIDAPI_HOST = "alpha-vantage.p.rapidapi.com"
# Load Finnhub API keys from a single secret (multiple keys separated by newlines)
FINNHUB_KEYS_RAW = os.getenv("FINNHUB_KEYS", "")
FINNHUB_KEYS = [key.strip() for key in FINNHUB_KEYS_RAW.split('\n') if key.strip()]
# Load RapidAPI keys from a single secret (multiple keys separated by newlines)
RAPIDAPI_KEYS_RAW = os.getenv("RAPIDAPI_KEYS", "")
RAPIDAPI_KEYS = [key.strip() for key in RAPIDAPI_KEYS_RAW.split('\n') if key.strip()]
# Placeholder kept for compatibility; no Google keys are needed with the local model
GOOGLE_API_KEYS = []
# Validate that we have at least one key for each external data service
if not FINNHUB_KEYS:
    print("⚠️ Warning: No Finnhub API keys found in secrets")
if not RAPIDAPI_KEYS:
    print("⚠️ Warning: No RapidAPI keys found in secrets")
# Randomly pick a Google key to start with, if any are ever configured (unused here)
GOOGLE_API_KEY = random.choice(GOOGLE_API_KEYS) if GOOGLE_API_KEYS else None
print("=" * 50) | |
print("🚀 FinRobot Forecaster Starting Up...") | |
print("=" * 50) | |
if FINNHUB_KEYS: | |
print(f"📊 Finnhub API: {len(FINNHUB_KEYS)} keys loaded") | |
else: | |
print("📊 Finnhub API: Not configured") | |
if RAPIDAPI_KEYS: | |
print(f"📈 RapidAPI Alpha Vantage: {RAPIDAPI_HOST} ({len(RAPIDAPI_KEYS)} keys loaded)") | |
else: | |
print("📈 RapidAPI Alpha Vantage: Not configured") | |
print("🧠 Local LLM (llama.cpp) will be used: "+GGUF_REPO+"/"+GGUF_FILENAME) | |
print("✅ Application started successfully!") | |
print("=" * 50) | |
# Download the GGUF model and initialize llama.cpp
_LLM = None
_TOKENS_PER_SECOND_INFO = None

def _resolve_and_download_gguf(repo_id: str, preferred_filename: str) -> str:
    """Resolve the correct GGUF filename (case-sensitive) and download it.

    Strategy:
      0) Use GGUF_LOCAL_PATH directly if it is set and exists
      1) Try the preferred filename directly
      2) List repo files and pick a case-insensitive match
      3) Prefer files carrying the same quant tag (e.g. Q5_K_M), ignoring case
      4) Fall back to any .gguf in the repo (then to GGUF_REPO_ALT)
    """
    # 0) If a local path is provided, use it directly
    if GGUF_LOCAL_PATH and os.path.exists(GGUF_LOCAL_PATH):
        print(f"➡️ Using local GGUF at {GGUF_LOCAL_PATH}")
        return GGUF_LOCAL_PATH
    # 1) Direct attempt
    try:
        return hf_hub_download(repo_id=repo_id, filename=preferred_filename, local_dir="/home/user/.cache/hf")
    except Exception:
        pass
    # 2) List repo files
    try:
        files = list_repo_files(repo_id=repo_id, repo_type="model")
        ggufs = [f for f in files if f.lower().endswith(".gguf")]
        # Prefer non-Qwen models to avoid the 'qwen3' architecture, unsupported in some llama.cpp builds
        ggufs_non_qwen = [f for f in ggufs if "qwen" not in f.lower()]
        preferred_pool = ggufs_non_qwen or ggufs
        if not ggufs:
            raise RuntimeError("No .gguf files found in repo")
        # Strong allowlist, in preference order (non-Qwen variants)
        strong_order = [
            "Fin-o1-14B.Q5_K_S.gguf",
            "Fin-o1-14B.Q6_K.gguf",
            "Fin-o1-14B.Q4_K_S.gguf",
        ]
        for fname in strong_order:
            if fname in preferred_pool:
                return hf_hub_download(repo_id=repo_id, filename=fname, local_dir="/home/user/.cache/hf")
        # Case-insensitive exact match
        lower_map = {f.lower(): f for f in preferred_pool}
        pref_lower = preferred_filename.lower()
        if pref_lower in lower_map:
            return hf_hub_download(repo_id=repo_id, filename=lower_map[pref_lower], local_dir="/home/user/.cache/hf")
        # 3) Extract the quant token from the preferred name, e.g. Q5_K_M or Q6_K
        m = re.search(r"q\d+[_a-z]*", pref_lower)
        quant = m.group(0) if m else None
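        # Example (hypothetical filename): for "Meta-Llama-3.1-8B-Instruct.Q4_K_S.gguf"
        # the regex yields quant == "q4_k_s", which is then matched case-insensitively
        # against every .gguf filename in the repo.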
        if quant:
            # Find any file containing that quant token (case-insensitive)
            candidates = [f for f in preferred_pool if quant in f.lower()]
            # Prefer the Fin-o1-14B prefix if there are multiple candidates
            candidates.sort(key=lambda s: (not s.startswith("Fin-o1-14B"), s))
            if candidates:
                return hf_hub_download(repo_id=repo_id, filename=candidates[0], local_dir="/home/user/.cache/hf")
        # 4) Fallback: first non-Qwen .gguf (alphabetical)
        preferred_pool.sort()
        return hf_hub_download(repo_id=repo_id, filename=preferred_pool[0], local_dir="/home/user/.cache/hf")
    except Exception as e:
        # As a final attempt, try the alternate repo if one was provided
        if GGUF_REPO_ALT:
            try:
                print(f"ℹ️ Trying alternate repo: {GGUF_REPO_ALT}")
                files = list_repo_files(repo_id=GGUF_REPO_ALT, repo_type="model")
                ggufs = [f for f in files if f.lower().endswith(".gguf") and "qwen" not in f.lower()]
                ggufs.sort()
                if ggufs:
                    return hf_hub_download(repo_id=GGUF_REPO_ALT, filename=ggufs[0], local_dir="/home/user/.cache/hf")
            except Exception as ee:
                raise ee from e  # keep the original failure in the exception chain
        raise e
try:
    print("⬇️ Downloading GGUF model from Hugging Face Hub if not cached...")
    gguf_path = _resolve_and_download_gguf(GGUF_REPO, GGUF_FILENAME)
    print(f"✅ Model file ready: {gguf_path}")
    print("🚀 Initializing llama.cpp (CPU)")
    _LLM = Llama(
        model_path=gguf_path,
        n_ctx=N_CTX,
        n_threads=N_THREADS,
        n_batch=N_BATCH,
        use_mlock=False,
        use_mmap=True,
        logits_all=False,
        # Use the configured KV types rather than hardcoded literals, so the
        # LLAMA_KV_TYPE_K / LLAMA_KV_TYPE_V env overrides actually take effect
        kv_overrides={"type_k": LLAMA_KV_TYPE_K, "type_v": LLAMA_KV_TYPE_V},
    )
    print("✅ Llama initialized")
except Exception as e:
    print(f"❌ Failed to initialize local LLM: {e}")
    _LLM = None
# Configure the Finnhub client (if keys are available)
if FINNHUB_KEYS:
    # Configure with the first key for initial setup
    finnhub_client = finnhub.Client(api_key=FINNHUB_KEYS[0])
    print(f"✅ Finnhub configured with {len(FINNHUB_KEYS)} keys")
else:
    finnhub_client = None
    print("⚠️ Finnhub not configured - will use mock news data")
# Create a requests session with a retry strategy
def create_session():
    session = requests.Session()
    retry_strategy = Retry(
        total=3,
        backoff_factor=1,
        status_forcelist=[429, 500, 502, 503, 504],
    )
    adapter = HTTPAdapter(max_retries=retry_strategy)
    session.mount("http://", adapter)
    session.mount("https://", adapter)
    return session

# Global session shared by all HTTP calls
requests_session = create_session()
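# With backoff_factor=1 and total=3, urllib3 retries a failing request with
# exponentially growing sleeps (roughly 1s, 2s, 4s - the exact schedule depends
# on the urllib3 version), and only for the listed rate-limit/server-error codes.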
SYSTEM_PROMPT = (
    "You are a seasoned stock-market analyst. "
    "Given recent company news and optional basic financials, "
    "return:\n"
    "[Positive Developments] – 2-4 bullets\n"
    "[Potential Concerns] – 2-4 bullets\n"
    "[Prediction & Analysis] – a one-week price outlook with rationale."
)
# ---------- UTILITY HELPERS ---------------------------------------------------
def today() -> str:
    return date.today().strftime("%Y-%m-%d")

def n_weeks_before(date_string: str, n: int) -> str:
    return (datetime.strptime(date_string, "%Y-%m-%d") -
            timedelta(days=7 * n)).strftime("%Y-%m-%d")
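# Example: n_weeks_before("2024-05-20", 3) == "2024-04-29"; a negative n steps
# forward, so n_weeks_before(curday, -1) is the date one week AFTER curday
# (used below to phrase the prediction horizon).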
# ---------- DATA FETCHING -----------------------------------------------------
def get_stock_data(symbol: str, steps: list[str]) -> pd.DataFrame:
    # Try every RapidAPI Alpha Vantage key in turn
    for rapidapi_key in RAPIDAPI_KEYS:
        try:
            print(f"📈 Fetching stock data for {symbol} via RapidAPI (key: {rapidapi_key[:8]}...)")
            # RapidAPI Alpha Vantage endpoint
            url = f"https://{RAPIDAPI_HOST}/query"
            headers = {
                "X-RapidAPI-Host": RAPIDAPI_HOST,
                "X-RapidAPI-Key": rapidapi_key
            }
            params = {
                "function": "TIME_SERIES_DAILY",
                "symbol": symbol,
                "outputsize": "full",
                "datatype": "csv"
            }
            # Retry up to 3 times with the current RapidAPI key
            for attempt in range(3):
                try:
                    resp = requests_session.get(url, headers=headers, params=params, timeout=30)
                    if not resp.ok:
                        print(f"RapidAPI HTTP error {resp.status_code} with key {rapidapi_key[:8]}..., attempt {attempt + 1}")
                        time.sleep(2 ** attempt)
                        continue
                    text = resp.text.strip()
                    if text.startswith("{"):
                        # A JSON body (instead of CSV) signals an API error or rate limit
                        info = resp.json()
                        msg = info.get("Note") or info.get("Error Message") or info.get("Information") or str(info)
                        if "rate limit" in msg.lower() or "quota" in msg.lower():
                            print(f"RapidAPI rate limit hit with key {rapidapi_key[:8]}..., trying next key")
                            break  # Move on to the next key
                        raise RuntimeError(f"RapidAPI Alpha Vantage Error: {msg}")
                    # Parse CSV data
                    df = pd.read_csv(StringIO(text))
                    date_col = "timestamp" if "timestamp" in df.columns else df.columns[0]
                    df[date_col] = pd.to_datetime(df[date_col])
                    df = df.sort_values(date_col).set_index(date_col)
                    data = {"Start Date": [], "End Date": [], "Start Price": [], "End Price": []}
                    for i in range(len(steps) - 1):
                        s_date = pd.to_datetime(steps[i])
                        e_date = pd.to_datetime(steps[i+1])
                        seg = df.loc[s_date:e_date]
                        if seg.empty:
                            raise RuntimeError(
                                f"RapidAPI Alpha Vantage cannot get {symbol} data for {steps[i]} – {steps[i+1]}"
                            )
                        data["Start Date"].append(seg.index[0])
                        data["Start Price"].append(seg["close"].iloc[0])
                        data["End Date"].append(seg.index[-1])
                        data["End Price"].append(seg["close"].iloc[-1])
                    time.sleep(1)  # RapidAPI has generous limits, but stay polite
                    print(f"✅ Successfully retrieved {symbol} data via RapidAPI (key: {rapidapi_key[:8]}...)")
                    return pd.DataFrame(data)
                except requests.exceptions.Timeout:
                    print(f"RapidAPI timeout with key {rapidapi_key[:8]}..., attempt {attempt + 1}")
                    if attempt < 2:
                        time.sleep(5 * (attempt + 1))
                        continue
                    else:
                        break
                except requests.exceptions.RequestException as e:
                    print(f"RapidAPI request error with key {rapidapi_key[:8]}..., attempt {attempt + 1}: {e}")
                    if attempt < 2:
                        time.sleep(3)
                        continue
                    else:
                        break
        except Exception as e:
            print(f"RapidAPI Alpha Vantage failed with key {rapidapi_key[:8]}...: {e}")
            continue  # Try the next key
    # Fallback: generate mock data if every RapidAPI key failed
    print("⚠️ All RapidAPI keys failed, using mock data for demonstration...")
    return create_mock_stock_data(symbol, steps)
def create_mock_stock_data(symbol: str, steps: list[str]) -> pd.DataFrame:
    """Generate mock price data for demos when the API is unavailable."""
    data = {"Start Date": [], "End Date": [], "Start Price": [], "End Price": []}
    # Different base prices for different well-known symbols
    base_prices = {
        "AAPL": 180.0, "MSFT": 350.0, "GOOGL": 140.0,
        "TSLA": 200.0, "NVDA": 450.0, "AMZN": 150.0
    }
    base_price = base_prices.get(symbol.upper(), 150.0)
    for i in range(len(steps) - 1):
        s_date = pd.to_datetime(steps[i])
        e_date = pd.to_datetime(steps[i+1])
        # Random-walk prices with a slight upward drift
        start_price = base_price + np.random.normal(0, 5)
        end_price = start_price + np.random.normal(2, 8)
        data["Start Date"].append(s_date)
        data["Start Price"].append(round(start_price, 2))
        data["End Date"].append(e_date)
        data["End Price"].append(round(end_price, 2))
        base_price = end_price  # Carry the price forward into the next week
    return pd.DataFrame(data)
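# A minimal sketch of the expected shape (values are random):
#   create_mock_stock_data("AAPL", ["2024-01-01", "2024-01-08", "2024-01-15"])
# returns a two-row DataFrame with Start/End Date and Start/End Price columns,
# one row per consecutive pair of step dates.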
def current_basics(symbol: str, curday: str) -> dict:
    # Check whether Finnhub is configured
    if not FINNHUB_KEYS:
        print(f"⚠️ Finnhub not configured, skipping financial basics for {symbol}")
        return {}
    # Try every Finnhub API key in turn
    for api_key in FINNHUB_KEYS:
        try:
            client = finnhub.Client(api_key=api_key)
            raw = client.company_basic_financials(symbol, "all")
            if not raw["series"]:
                continue
            # Merge the per-metric series into one dict per reporting period
            merged = defaultdict(dict)
            for metric, vals in raw["series"]["quarterly"].items():
                for v in vals:
                    merged[v["period"]][metric] = v["v"]
            # Keep the most recent period on or before curday
            latest = max((p for p in merged if p <= curday), default=None)
            if latest is None:
                continue
            d = dict(merged[latest])
            d["period"] = latest
            return d
        except Exception as e:
            print(f"Error getting basics for {symbol} with key {api_key[:8]}...: {e}")
            time.sleep(2)  # Brief delay before trying the next key
            continue
    return {}
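# Example return value (hypothetical metric names, for illustration only):
#   {"peBasicExclExtraTTM": 27.4, "roeTTM": 0.31, ..., "period": "2024-03-31"}
# i.e. the latest quarterly snapshot on or before curday, flattened to one dict.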
def attach_news(symbol: str, df: pd.DataFrame) -> pd.DataFrame:
    news_col = []
    for _, row in df.iterrows():
        start = row["Start Date"].strftime("%Y-%m-%d")
        end = row["End Date"].strftime("%Y-%m-%d")
        time.sleep(2)  # Generous delay to avoid rate limits
        # If Finnhub is not configured, fall back to mock news immediately
        if not FINNHUB_KEYS:
            print(f"⚠️ Finnhub not configured, using mock news for {symbol}")
            news_data = create_mock_news(symbol, start, end)
            news_col.append(json.dumps(news_data))
            continue
        # Try every Finnhub API key in turn
        news_data = []
        for api_key in FINNHUB_KEYS:
            try:
                client = finnhub.Client(api_key=api_key)
                weekly = client.company_news(symbol, _from=start, to=end)
                weekly_fmt = [
                    {
                        "date"    : datetime.fromtimestamp(n["datetime"]).strftime("%Y%m%d%H%M%S"),
                        "headline": n["headline"],
                        "summary" : n["summary"],
                    }
                    for n in weekly
                ]
                weekly_fmt.sort(key=lambda x: x["date"])
                news_data = weekly_fmt
                break  # Success - stop trying keys
            except Exception as e:
                print(f"Error with Finnhub key {api_key[:8]}... for {symbol} from {start} to {end}: {e}")
                time.sleep(3)  # Brief delay before trying the next key
                continue
        # If no real news came back, fall back to mock news
        if not news_data:
            news_data = create_mock_news(symbol, start, end)
        news_col.append(json.dumps(news_data))
    df["News"] = news_col
    return df
def create_mock_news(symbol: str, start: str, end: str) -> list:
    """Generate mock news items when the API is unavailable."""
    # Match the "%Y%m%d%H%M%S" date format used for real Finnhub news above
    start_compact = start.replace("-", "")
    end_compact = end.replace("-", "")
    mock_news = [
        {
            "date": f"{start_compact}120000",
            "headline": f"{symbol} Shows Strong Performance in Recent Trading",
            "summary": f"Company {symbol} has demonstrated resilience in the current market conditions with positive investor sentiment."
        },
        {
            "date": f"{end_compact}090000",
            "headline": f"Analysts Maintain Positive Outlook for {symbol}",
            "summary": f"Financial analysts continue to recommend {symbol} based on strong fundamentals and growth prospects."
        }
    ]
    return mock_news
# ---------- PROMPT CONSTRUCTION -----------------------------------------------
def sample_news(news: list[str], k: int = 5) -> list[str]:
    if len(news) <= k:
        return news
    return [news[i] for i in sorted(random.sample(range(len(news)), k))]
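# Sampling indices and then sorting keeps the surviving items in their original
# (chronological) order, e.g. sample_news(["a", "b", "c", "d", "e", "f"], 3)
# might return ["a", "c", "f"] but never ["f", "a", "c"].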
def make_prompt(symbol: str, df: pd.DataFrame, curday: str, use_basics=False) -> str:
    # Try every Finnhub API key to fetch the company profile
    company_blurb = f"[Company Introduction]:\n{symbol} is a publicly traded company.\n"
    if FINNHUB_KEYS:
        for api_key in FINNHUB_KEYS:
            try:
                client = finnhub.Client(api_key=api_key)
                prof = client.company_profile2(symbol=symbol)
                company_blurb = (
                    f"[Company Introduction]:\n{prof['name']} operates in the "
                    f"{prof['finnhubIndustry']} sector ({prof['country']}). "
                    f"Founded {prof['ipo']}, market cap {prof['marketCapitalization']:.1f} "
                    f"{prof['currency']}; ticker {symbol} on {prof['exchange']}.\n"
                )
                break  # Success - stop trying keys
            except Exception as e:
                print(f"Error getting company profile for {symbol} with key {api_key[:8]}...: {e}")
                time.sleep(2)  # Brief delay before trying the next key
                continue
    else:
        print(f"⚠️ Finnhub not configured, using basic company info for {symbol}")
    # Past-weeks block
    past_block = ""
    for _, row in df.iterrows():
        term = "increased" if row["End Price"] > row["Start Price"] else "decreased"
        head = (f"From {row['Start Date']:%Y-%m-%d} to {row['End Date']:%Y-%m-%d}, "
                f"{symbol}'s stock price {term} from "
                f"{row['Start Price']:.2f} to {row['End Price']:.2f}.")
        news_items = json.loads(row["News"])
        summaries = [
            f"[Headline] {n['headline']}\n[Summary] {n['summary']}\n"
            for n in news_items
            if not n["summary"].startswith("Looking for stock market analysis")
        ]
        past_block += "\n" + head + "\n" + "".join(sample_news(summaries, 5))
    # Optional basic financials
    if use_basics:
        basics = current_basics(symbol, curday)
        if basics:
            basics_txt = "\n".join(f"{k}: {v}" for k, v in basics.items() if k != "period")
            basics_block = (f"\n[Basic Financials] (reported {basics['period']}):\n{basics_txt}\n")
        else:
            basics_block = "\n[Basic Financials]: not available\n"
    else:
        basics_block = "\n[Basic Financials]: not requested\n"
    # n = -1 steps one week forward, giving the end of the prediction horizon
    horizon = f"{curday} to {n_weeks_before(curday, -1)}"
    final_user_msg = (
        company_blurb
        + past_block
        + basics_block
        # Note: every segment below must be an f-string, or {symbol} stays literal
        + f"\nBased on all information before {curday}, analyse positive "
        f"developments and potential concerns for {symbol}, then predict its "
        f"price movement for next week ({horizon})."
    )
    return final_user_msg
# ---------- LLM CALL -----------------------------------------------------------
def chat_completion(prompt: str,
                    model: str = "local-llama-cpp",
                    temperature: float = LLM_TEMPERATURE,
                    stream: bool = False,
                    symbol: str = "STOCK") -> str:
    if _LLM is None:
        print(f"⚠️ Local LLM not available, using mock response for {symbol}")
        return create_mock_ai_response(symbol)
    # Let llama.cpp apply the chat template embedded in the GGUF metadata, so the
    # prompt format always matches the loaded model (a hand-rolled ChatML template
    # would only suit Qwen-style models, not the default LLaMA 3.1).
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    try:
        if stream:
            out_text = []
            for chunk in _LLM.create_chat_completion(
                messages=messages,
                max_tokens=1024,
                temperature=temperature,
                top_p=0.9,
                repeat_penalty=1.1,
                stream=True,
            ):
                delta = chunk.get("choices", [{}])[0].get("delta", {}).get("content", "")
                if delta:
                    print(delta, end="", flush=True)
                    out_text.append(delta)
            print()
            return "".join(out_text)
        else:
            res = _LLM.create_chat_completion(
                messages=messages,
                max_tokens=1024,
                temperature=temperature,
                top_p=0.9,
                repeat_penalty=1.1,
            )
            return res["choices"][0]["message"]["content"].strip()
    except Exception as e:
        print(f"❌ LLM inference error: {e}")
        return create_mock_ai_response(symbol)
def create_mock_ai_response(symbol: str) -> str:
    """Generate a mock AI response when the local LLM is unavailable."""
    return f"""
[Positive Developments]
• Strong market position and brand recognition for {symbol}
• Recent quarterly earnings showing growth potential
• Positive analyst sentiment and institutional investor interest
• Technological innovation and market expansion opportunities

[Potential Concerns]
• Market volatility and economic uncertainty
• Competitive pressures in the industry
• Regulatory changes that may impact operations
• Global economic factors affecting stock performance

[Prediction & Analysis]
Based on the current market conditions and company fundamentals, {symbol} is expected to show moderate growth over the next week. The stock may experience some volatility but should maintain an upward trend with a potential price increase of 2-5%. This prediction is based on current market sentiment and technical analysis patterns.

Note: This is a demonstration response using mock data. For real investment decisions, please consult with qualified financial professionals.
"""
# ---------- DEBUG / DIAGNOSTICS -------------------------------------------------
def _safe_version(mod_name: str) -> str:
    try:
        mod = __import__(mod_name)
        ver = getattr(mod, "__version__", None)
        if ver is None and hasattr(mod, "version"):
            try:
                ver = mod.version.__version__  # type: ignore[attr-defined]
            except Exception:
                ver = None
        return str(ver) if ver is not None else "unknown"
    except Exception:
        return "not installed"
def collect_debug_info() -> dict:
    info = {}
    # Model / app
    info["model_repo"] = GGUF_REPO
    info["model_filename"] = GGUF_FILENAME
    info["llm_initialized"] = _LLM is not None
    info["llama_n_ctx"] = N_CTX
    info["llama_n_threads"] = N_THREADS
    info["llama_n_batch"] = N_BATCH
    # Runtime
    info["python_version"] = platform.python_version()
    info["platform"] = platform.platform()
    info["machine"] = platform.machine()
    info["processor"] = platform.processor()
    # Libraries
    info["libraries"] = {
        "gradio": _safe_version("gradio"),
        "pandas": _safe_version("pandas"),
        "requests": _safe_version("requests"),
        "finnhub": _safe_version("finnhub"),
        "huggingface_hub": _safe_version("huggingface_hub"),
        "llama_cpp": _safe_version("llama_cpp"),
        "torch": _safe_version("torch"),
    }
    # Torch details (if available)
    try:
        import torch  # type: ignore
        cuda_available = bool(getattr(torch.cuda, "is_available", lambda: False)())
        cuda_count = int(getattr(torch.cuda, "device_count", lambda: 0)())
        devices = []
        if cuda_available and cuda_count > 0:
            for i in range(cuda_count):
                dev = {"index": i}
                try:
                    dev["name"] = torch.cuda.get_device_name(i)
                except Exception:
                    dev["name"] = "unknown"
                try:
                    props = torch.cuda.get_device_properties(i)
                    dev["total_mem_gb"] = round(getattr(props, "total_memory", 0) / (1024**3), 2)
                    dev["multi_processor_count"] = getattr(props, "multi_processor_count", None)
                    dev["major"] = getattr(props, "major", None)
                    dev["minor"] = getattr(props, "minor", None)
                except Exception:
                    pass
                try:
                    # These require a CUDA context; guard individually
                    dev["mem_reserved_gb"] = round(torch.cuda.memory_reserved(i) / (1024**3), 3)
                    dev["mem_allocated_gb"] = round(torch.cuda.memory_allocated(i) / (1024**3), 3)
                except Exception:
                    pass
                devices.append(dev)
        info["torch"] = {
            "version": getattr(torch, "__version__", "unknown"),
            "cuda_available": cuda_available,
            "cuda_device_count": cuda_count,
            "devices": devices,
        }
    except Exception:
        info["torch"] = {"available": False}
    # CPU / RAM (prefer psutil)
    try:
        import psutil  # type: ignore
        vm = psutil.virtual_memory()
        info["system"] = {
            "cpu_percent": psutil.cpu_percent(interval=0.4),
            "ram_total_gb": round(vm.total / (1024**3), 2),
            "ram_used_gb": round((vm.total - vm.available) / (1024**3), 2),
            "ram_percent": vm.percent,
        }
    except Exception:
        info["system"] = {"cpu_percent": "n/a", "ram_percent": "n/a"}
    # API key availability (counts only - never the keys themselves)
    info["api_keys"] = {
        "finnhub_keys_count": len(FINNHUB_KEYS),
        "rapidapi_keys_count": len(RAPIDAPI_KEYS),
    }
    return info
# ---------- MAIN PREDICTION FUNCTION --------------------------------------------
def predict(symbol: str = "AAPL",
            curday: str = "",
            n_weeks: int = 3,
            use_basics: bool = False,
            stream: bool = False) -> tuple[str, str]:
    # Resolve the date at call time: a `curday=today()` default would be evaluated
    # once at import time and go stale in a long-running Space
    curday = curday or today()
    try:
        steps = [n_weeks_before(curday, n) for n in range(n_weeks + 1)][::-1]
        df = get_stock_data(symbol, steps)
        df = attach_news(symbol, df)
        prompt_info = make_prompt(symbol, df, curday, use_basics)
        answer = chat_completion(prompt_info, stream=stream, symbol=symbol)
        return prompt_info, answer
    except Exception as e:
        error_msg = f"Error in prediction: {str(e)}"
        print(f"Prediction error: {e}")  # Log for debugging
        return error_msg, error_msg
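# Minimal usage sketch (runs the whole pipeline, assuming keys are configured):
#   prompt, answer = predict("AAPL", n_weeks=3, use_basics=True)
# `steps` above becomes [curday-3w, curday-2w, curday-1w, curday], oldest first.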
# ---------- HUGGING FACE SPACES INTERFACE ---------------------------------------
def hf_predict(symbol, n_weeks, use_basics):
    # 1. Get the current day
    curday = date.today().strftime("%Y-%m-%d")
    # 2. Run the prediction pipeline
    prompt, answer = predict(
        symbol=symbol.upper(),
        curday=curday,
        n_weeks=int(n_weeks),
        use_basics=bool(use_basics),
        stream=False
    )
    return prompt, answer
# ---------- GRADIO INTERFACE ----------------------------------------- | |
def create_interface(): | |
with gr.Blocks( | |
title="FinRobot Forecaster", | |
theme=gr.themes.Soft(), | |
css=""" | |
.gradio-container { | |
max-width: 1200px !important; | |
margin: auto !important; | |
} | |
#model_prompt_textbox textarea { | |
overflow-y: auto !important; | |
max-height: none !important; | |
min-height: 400px !important; | |
resize: vertical !important; | |
white-space: pre-wrap !important; | |
word-wrap: break-word !important; | |
height: auto !important; | |
} | |
#model_prompt_textbox { | |
height: auto !important; | |
} | |
#analysis_results_textbox textarea { | |
overflow-y: auto !important; | |
max-height: none !important; | |
min-height: 400px !important; | |
resize: vertical !important; | |
white-space: pre-wrap !important; | |
word-wrap: break-word !important; | |
height: auto !important; | |
} | |
#analysis_results_textbox { | |
height: auto !important; | |
} | |
.textarea textarea { | |
overflow-y: auto !important; | |
max-height: 500px !important; | |
resize: vertical !important; | |
} | |
.textarea { | |
height: auto !important; | |
min-height: 300px !important; | |
} | |
.gradio-textbox { | |
height: auto !important; | |
max-height: none !important; | |
} | |
.gradio-textbox textarea { | |
height: auto !important; | |
max-height: none !important; | |
overflow-y: auto !important; | |
} | |
""" | |
) as demo: | |
gr.Markdown(""" | |
# 🤖 FinRobot Forecaster | |
**AI-powered stock market analysis and prediction using advanced language models** | |
This application analyzes stock market data, company news, and financial metrics to provide comprehensive market insights and predictions. | |
⚠️ **Note**: Free API keys have daily rate limits. If you encounter errors, the app will use mock data for demonstration purposes. | |
""") | |
        with gr.Row():
            with gr.Column(scale=1):
                symbol = gr.Textbox(
                    label="Stock Symbol",
                    value="AAPL",
                    placeholder="Enter stock symbol (e.g., AAPL, MSFT, GOOGL)",
                    info="Enter the ticker symbol of the stock you want to analyze"
                )
                n_weeks = gr.Slider(
                    1, 6,
                    value=3,
                    step=1,
                    label="Historical Weeks to Analyze",
                    info="Number of weeks of historical data to include in analysis"
                )
                use_basics = gr.Checkbox(
                    label="Include Basic Financials",
                    value=True,
                    info="Include basic financial metrics in the analysis"
                )
                btn = gr.Button(
                    "🚀 Run Analysis",
                    variant="primary"
                )
            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.Tab("📊 Analysis Results"):
                        gr.Markdown("**AI Analysis & Prediction**")
                        output_answer = gr.Textbox(
                            label="",
                            lines=40,
                            show_copy_button=True,
                            interactive=False,
                            placeholder="AI analysis and predictions will appear here...",
                            container=True,
                            scale=1,
                            elem_id="analysis_results_textbox"
                        )
                    with gr.Tab("🔍 Model Prompt"):
                        gr.Markdown("**Generated Prompt**")
                        output_prompt = gr.Textbox(
                            label="",
                            lines=40,
                            show_copy_button=True,
                            interactive=False,
                            placeholder="Generated prompt will appear here...",
                            container=True,
                            scale=1,
                            elem_id="model_prompt_textbox"
                        )
                    with gr.Tab("🛠️ Debug Info"):
                        gr.Markdown("**Runtime Diagnostics**")
                        debug_json = gr.JSON(label="Debug Data", value=None)
                        refresh_btn = gr.Button("🔄 Refresh Debug Info")
        # Examples
        gr.Examples(
            examples=[
                ["AAPL", 3, False],
                ["MSFT", 4, True],
                ["GOOGL", 2, False],
                ["TSLA", 5, True],
                ["NVDA", 3, True]
            ],
            inputs=[symbol, n_weeks, use_basics],
            label="💡 Try these examples"
        )
        # Event handlers
        btn.click(
            fn=hf_predict,
            inputs=[symbol, n_weeks, use_basics],
            outputs=[output_prompt, output_answer],
            show_progress=True
        )

        # Debug tab handlers
        def _collect_debug_info_wrapper():
            try:
                return collect_debug_info()
            except Exception as e:
                return {"error": str(e)}

        refresh_btn.click(
            fn=_collect_debug_info_wrapper,
            inputs=[],
            outputs=[debug_json],
            show_progress=False
        )
        # Populate on load
        demo.load(
            fn=_collect_debug_info_wrapper,
            inputs=None,
            outputs=[debug_json]
        )
        # Footer
        gr.Markdown("""
        ---
        **Disclaimer**: This application is for educational and research purposes only.
        The predictions and analysis provided should not be considered as financial advice.
        Always consult with qualified financial professionals before making investment decisions.
        """)
    return demo
# ---------- MAIN EXECUTION ----------------------------------------------------------
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        debug=False,
        quiet=True
    )