rahul7star's picture
Create app.py
9a464d1 verified
raw
history blame
1.55 kB
import logging
import os
import threading
import uuid

from fastapi import FastAPI, BackgroundTasks
from huggingface_hub import snapshot_download

from flux_train import build_job
from toolkit.job import run_job
app = FastAPI()

# Source dataset on the Hugging Face Hub and the sub-folder holding the
# training files that run_lora_training downloads.
REPO_ID = "rahul7star/ohamlab"
FOLDER_IN_REPO = "filter-demo/upload_20250708_041329_9c5c81"

# Values handed to build_job: the concept sentence and the LoRA name
# (the name is also used as the prefix of the /tmp download directory).
CONCEPT_SENTENCE = "ohamlab style"
LORA_NAME = "ohami_filter_autorun"

# Hub access token from the environment; "" when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# In-process training state: written by run_lora_training, read by the
# /status and /train endpoints.
status = dict(running=False, last_job=None, error=None)
def run_lora_training():
try:
status.update({"running": True, "error": None})
local_dir = f"/tmp/{LORA_NAME}-{uuid.uuid4()}"
snapshot_download(
repo_id=REPO_ID,
repo_type="dataset",
allow_patterns=[f"{FOLDER_IN_REPO}/*"],
local_dir=local_dir,
local_dir_use_symlinks=False
)
training_path = os.path.join(local_dir, FOLDER_IN_REPO)
job = build_job(CONCEPT_SENTENCE, training_path, LORA_NAME)
run_job(job)
status.update({"running": False, "last_job": job})
except Exception as e:
status.update({"running": False, "error": str(e)})
@app.get("/")
def root():
    """Liveness endpoint: confirm the LoRA training service is reachable."""
    payload = {"message": "LoRA training FastAPI is live."}
    return payload
@app.get("/status")
def get_status():
    """Expose the shared training-state dict (running, last_job, error)."""
    current_state = status
    return current_state
# Guards the check-and-claim of status["running"] below: sync endpoints can be
# served concurrently from FastAPI's thread pool, so the test and the set must
# happen atomically.
_train_start_lock = threading.Lock()


@app.post("/train")
def start_training(background_tasks: BackgroundTasks):
    """Queue a LoRA training run unless one is already in progress.

    The ``running`` flag is claimed here, before the task is queued. If it
    were only set when run_lora_training starts executing, every request
    arriving in the gap would pass the check and schedule its own run.

    Args:
        background_tasks: FastAPI-injected task collection; run_lora_training
            is added to it when a run is accepted.

    Returns:
        dict: ``{"message": ...}`` saying whether a run was scheduled.
    """
    with _train_start_lock:
        if status["running"]:
            return {"message": "A training job is already running."}
        # NOTE(review): if the queued task never executes, this flag stays set
        # until the process restarts; run_lora_training resets it in a finally.
        status.update({"running": True, "error": None})
    background_tasks.add_task(run_lora_training)
    return {"message": "Training started in background."}