# ip_image_gen / loader.py
# (Hugging Face file-view header residue: author nsfwalex, commit 82635c8 "update",
#  raw / history / blame links, 3.2 kB)
import os
from huggingface_hub import hf_hub_download
def load_script(file_str: str):
    """
    Download one file from the Hugging Face Hub into the current directory.

    Args:
        file_str: Spec of the form 'owner/repo/[subfolder/...]/filename',
            e.g. 'myorg/myrepo/mysubfolder/myscript.py'. The first two segments
            form the repo_id, the last segment is the filename, and any segments
            in between form the subfolder path. Surrounding whitespace is ignored.

    Returns:
        The local path reported by ``hf_hub_download`` on success, or ``None``
        if the spec is invalid or the download fails. This is a best-effort
        helper: errors are printed, never raised.

    The HF access token, if any, is read from the ``HF_TOKEN`` environment
    variable.
    """
    try:
        parts = file_str.strip().split("/")
        # Need at least owner, repo and filename. Empty segments (from '//',
        # or a leading/trailing '/') would produce an invalid repo_id or an
        # empty filename and only fail later with a confusing Hub error.
        if len(parts) < 3 or not all(parts):
            raise ValueError(
                f"Invalid file specification '{file_str}'. "
                f"Expected format: 'repo_id/[subfolder]/filename'"
            )

        # First two parts form the repo_id (e.g. 'myorg/myrepo').
        repo_id = "/".join(parts[:2])
        # Last part is the actual filename (e.g. 'myscript.py').
        filename = parts[-1]
        # Anything between the repo_id and the filename is a subfolder path.
        subfolder = "/".join(parts[2:-1]) or None

        # None means anonymous access (public repos only).
        hf_token = os.getenv("HF_TOKEN")

        file_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
            token=hf_token,
            local_dir=".",  # download into the current working directory
        )
        print(f"Downloaded {filename} from {repo_id} to {file_path}")
        return file_path
    except Exception as e:
        # Deliberate best-effort boundary: callers check for None.
        print(f"Error downloading the script '{file_str}': {e}")
        return None
def load_scripts():
    """
    Download a list of scripts from the Hugging Face Hub and execute the last one.

    1. Read the file-list spec from the ``FILE_LIST`` environment variable.
    2. Download that file list using ``load_script()``.
    3. Read its non-blank lines; each line is another spec downloaded with
       ``load_script()``.
    4. Execute the script named on the LAST line of the list (the entry point)
       in this module's global namespace.

    Every failure is reported with ``print()`` and ends the run early; nothing
    is raised. Returns ``None``.

    SECURITY: step 4 ``exec``s arbitrary code fetched from the Hub. Only point
    ``FILE_LIST`` at repositories you control.
    """
    import traceback

    file_list = os.getenv("FILE_LIST", "").strip()
    if not file_list:
        print("No FILE_LIST environment variable set. Nothing to download.")
        return

    # Step 1: download the file list itself.
    file_list_path = load_script(file_list)
    if not file_list_path or not os.path.exists(file_list_path):
        print(f"Could not download or find file list: {file_list_path}")
        return

    # Step 2: read the non-blank lines. An explicit UTF-8 encoding keeps the
    # result independent of the platform's locale.
    try:
        with open(file_list_path, "r", encoding="utf-8") as f:
            lines = [s for s in (line.strip() for line in f) if s]
    except Exception as e:
        print(f"Error reading file list: {e}")
        return

    # Step 3: download every listed file; results stay aligned with `lines`.
    results = [load_script(file_str) for file_str in lines]
    if not results:
        return

    # Step 4: execute the entry point (the last listed file). If that download
    # failed, do NOT fall back to an earlier file: that would run a script
    # that was never meant to be the entry point.
    last_file_path = results[-1]
    if not last_file_path:
        print(
            f"Could not download the last script '{lines[-1]}'; "
            f"nothing will be executed."
        )
        return

    print(f"Executing the last downloaded script: {last_file_path}")
    try:
        with open(last_file_path, "r", encoding="utf-8") as f:
            source = f.read()
        # compile() with the real path so tracebacks name the downloaded file.
        exec(compile(source, last_file_path, "exec"), globals())
    except Exception as e:
        print(f"Error executing the last downloaded script: {e}")
        # Keep the full stack for debugging instead of just the message.
        traceback.print_exc()
# Entry point: this call sits at module level with no
# `if __name__ == "__main__":` guard. Running OR importing this module
# therefore immediately triggers the download-and-exec sequence in
# load_scripts().
load_scripts()