Spaces:
Sleeping
Sleeping
import logging
import os
import uuid

from fastapi import BackgroundTasks, FastAPI
from huggingface_hub import snapshot_download

from flux_train import build_job
from toolkit.job import run_job
# FastAPI application instance for this module.
app = FastAPI()

# Source dataset: the Hub repo and the sub-folder of training images to fetch.
REPO_ID = "rahul7star/ohamlab"
FOLDER_IN_REPO = "filter-demo/upload_20250708_041329_9c5c81"

# Training parameters forwarded to build_job().
CONCEPT_SENTENCE = "ohamlab style"
LORA_NAME = "ohami_filter_autorun"

# Hub access token from the environment; empty string when unset.
HF_TOKEN = os.getenv("HF_TOKEN", "")

# Shared in-process job state: written by the background training task and
# returned by the status handler.
status = dict(running=False, last_job=None, error=None)
def run_lora_training():
    """Download the training images from the Hub and run one LoRA training job.

    Meant to be scheduled as a background task. Outcome is reported only
    through the module-level ``status`` dict:

    * on entry ``running`` is set to True and ``error`` is cleared;
    * on success ``last_job`` holds the job object returned by ``build_job``;
    * on failure ``error`` holds the exception message and the traceback is
      logged;
    * ``running`` is always reset to False on exit, even if the run is
      interrupted by something that is not an ``Exception``.

    Exceptions are not re-raised: a background task has no caller to
    receive them.
    """
    status.update({"running": True, "error": None})
    try:
        # Each run downloads into its own fresh directory.
        local_dir = f"/tmp/{LORA_NAME}-{uuid.uuid4()}"
        snapshot_download(
            repo_id=REPO_ID,
            repo_type="dataset",
            # Fetch only the one upload folder, not the whole dataset repo.
            allow_patterns=[f"{FOLDER_IN_REPO}/*"],
            local_dir=local_dir,
            # NOTE(review): deprecated/ignored by recent huggingface_hub
            # releases; kept for compatibility with older versions.
            local_dir_use_symlinks=False,
            # Pass the configured token explicitly; None when HF_TOKEN is
            # empty so huggingface_hub uses its default credential lookup
            # instead of receiving an empty string.
            token=HF_TOKEN or None,
        )
        training_path = os.path.join(local_dir, FOLDER_IN_REPO)
        job = build_job(CONCEPT_SENTENCE, training_path, LORA_NAME)
        run_job(job)
        status["last_job"] = job
    except Exception as e:
        # Top-level boundary of a background task: record and log, don't raise.
        logging.getLogger(__name__).exception("LoRA training failed")
        # Some exceptions stringify to "", which would look like "no error".
        status["error"] = str(e) or type(e).__name__
    finally:
        # Guarantees the "already running" guard can never get stuck at True.
        status["running"] = False
def root():
    """Return a fixed message confirming the service is up."""
    # NOTE(review): no route decorator is visible for this handler; confirm
    # it is registered on `app` (presumably as GET "/").
    payload = {"message": "LoRA training FastAPI is live."}
    return payload
def get_status():
    """Return a snapshot of the background training state.

    Returns:
        A dict with keys ``running`` (bool), ``last_job`` (the job from the
        most recent successful run, or None) and ``error`` (the message of
        the most recent failure, or None; cleared when a new run starts).

    A shallow copy is returned rather than the shared ``status`` dict, so the
    response is serialized from one consistent snapshot even if the training
    task updates ``status`` meanwhile, and callers cannot mutate shared state.
    """
    # NOTE(review): no route decorator is visible for this handler; confirm
    # it is registered on `app`.
    return dict(status)
def start_training(background_tasks: BackgroundTasks):
    """Schedule a LoRA training run unless one is already in progress.

    Args:
        background_tasks: FastAPI-injected task queue; ``run_lora_training``
            is added to it and runs after the response has been sent.

    Returns:
        A dict with a human-readable ``message``.
    """
    # NOTE(review): no route decorator is visible for this handler; confirm
    # it is registered on `app`.
    if status["running"]:
        return {"message": "A training job is already running."}
    # Claim the "running" slot *before* scheduling. Background tasks only
    # start after the response is sent, so without this a second request
    # arriving in that gap would pass the check above and launch a duplicate
    # training run.
    status.update({"running": True, "error": None})
    background_tasks.add_task(run_lora_training)
    return {"message": "Training started in background."}