import logging
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
import torch
import tarfile
from huggingface_hub import HfApi
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Debug environment variables (token-like values are masked)
logger.info("Environment variables: %s", {k: "****" if "TOKEN" in k or k == "granite" else v for k, v in os.environ.items()})
model_path = "ibm-granite/granite-3.3-8b-instruct"
dataset_path = "mycholpath/ascii-json"
output_dir = "/app/granite-8b-finetuned-ascii"
output_tarball = "/app/granite-8b-finetuned-ascii.tar.gz"
model_repo = "mycholpath/granite-8b-finetuned-ascii"
artifact_repo = "mycholpath/granite-finetuned-artifacts"
# Get HF token from the "granite" environment variable (a Space secret)
granite_var = os.getenv("granite")
if not granite_var or not granite_var.startswith("HF_TOKEN="):
    logger.error("granite environment variable is not set or invalid. Expected format: HF_TOKEN=<token>.")
    raise ValueError("granite environment variable is not set or invalid. Please set it in HF Space settings.")
hf_token = granite_var.replace("HF_TOKEN=", "")
logger.info("HF_TOKEN extracted from granite (value hidden for security)")
logging.info("Loading tokenizer...") | |
try: | |
tokenizer = AutoTokenizer.from_pretrained( | |
model_path, token=hf_token, cache_dir="/tmp/hf_cache", trust_remote_code=True | |
) | |
tokenizer.pad_token = tokenizer.eos_token | |
tokenizer.padding_side = 'right' | |
except Exception as e: | |
logger.error(f"Failed to load tokenizer: {str(e)}") | |
raise | |
logging.info("Loading model...") | |
try: | |
model = AutoModelForCausalLM.from_pretrained( | |
model_path, | |
token=hf_token, | |
torch_dtype=torch.float16, | |
device_map="auto", | |
cache_dir="/tmp/hf_cache", | |
trust_remote_code=True | |
) | |
except Exception as e: | |
logger.error(f"Failed to load model: {str(e)}") | |
raise | |
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
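# PEFT's print_trainable_parameters() confirms only the LoRA adapters train;
# with r=16 on q_proj/v_proj this should be a small fraction of the 8B weights
model.print_trainable_parameters()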
logging.info("Preparing to load private dataset...") | |
logger.info("Using HF_TOKEN from granite for private dataset authentication") | |
try: | |
dataset = load_dataset(dataset_path, split="train", token=hf_token) | |
logger.info(f"Dataset loaded successfully: {len(dataset)} examples") | |
except Exception as e: | |
logger.error(f"Failed to load dataset: {str(e)}") | |
raise | |
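# Sanity check: the formatting function below assumes 'prompt' and 'completion'
# columns, so fail early if the dataset schema does not match
missing_cols = {"prompt", "completion"} - set(dataset.column_names)
if missing_cols:
    raise ValueError(f"Dataset is missing expected columns: {missing_cols}")
logger.debug("First training example: %s", dataset[0])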
# SFTTrainer maps the dataset in batches, so the formatting function receives
# lists of prompts/completions and must return one formatted string per example
def formatting_prompts_func(example):
    return [f"{p}\n{c}" for p, c in zip(example["prompt"], example["completion"])]
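# Quick self-test on a tiny synthetic batch (hypothetical data) to catch
# formatting mistakes before committing to a long training run
assert formatting_prompts_func(
    {"prompt": ["Draw a cat"], "completion": ["=^.^="]}
) == ["Draw a cat\n=^.^="]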
# Use SFTConfig for training arguments
# (effective batch size = 4 per device x 4 accumulation steps = 16)
sft_config = SFTConfig(
    output_dir=output_dir,
    num_train_epochs=5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    weight_decay=0.01,
    eval_strategy="no",
    save_steps=50,
    logging_steps=10,
    fp16=True,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    max_seq_length=768,
    dataset_text_field=None,  # formatting_func supplies the text
    packing=False
)
logging.info("Starting training...") | |
try: | |
trainer = SFTTrainer( | |
model=model, | |
tokenizer=tokenizer, | |
train_dataset=dataset, | |
eval_dataset=None, | |
formatting_func=formatting_prompts_func, | |
args=sft_config | |
) | |
except Exception as e: | |
logger.error(f"Failed to initialize SFTTrainer: {str(e)}") | |
raise | |
trainer.train()
logger.info("Saving fine-tuned model...")
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)
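# Note: save_model on a PEFT model writes only the LoRA adapter weights, not a
# full 8B checkpoint. If a standalone merged model is wanted instead (an
# assumption, not required for the uploads below), merge before saving:
#   merged = trainer.model.merge_and_unload()
#   merged.save_pretrained(output_dir)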
# Create tarball for local retrieval
try:
    with tarfile.open(output_tarball, "w:gz") as tar:
        tar.add(output_dir, arcname=os.path.basename(output_dir))
    logger.info(f"Model tarball created: {output_tarball}")
except Exception as e:
    logger.error(f"Failed to create model tarball: {str(e)}")
    raise
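# Log the tarball size as a quick sanity check: an adapter-plus-tokenizer
# tarball should be tens of MB, not the multi-GB size of a full 8B checkpoint
logger.info("Tarball size: %.1f MB", os.path.getsize(output_tarball) / 1e6)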
# Upload model to HF Hub
try:
    api = HfApi()
    logger.info(f"Creating model repository: {model_repo}")
    api.create_repo(
        repo_id=model_repo,
        repo_type="model",
        token=hf_token,
        private=True,
        exist_ok=True
    )
    logger.info(f"Uploading model to {model_repo}")
    api.upload_folder(
        folder_path=output_dir,
        repo_id=model_repo,
        repo_type="model",
        token=hf_token,
        create_pr=False
    )
    logger.info(f"Fine-tuned model uploaded to {model_repo}")
except Exception as e:
    logger.error(f"Failed to upload model to HF Hub: {str(e)}")
    logger.warning("Continuing to tarball upload despite model upload failure")
# Upload tarball to HF Hub dataset repository
try:
    api = HfApi()
    logger.info(f"Creating dataset repository: {artifact_repo}")
    api.create_repo(
        repo_id=artifact_repo,
        repo_type="dataset",
        token=hf_token,
        private=True,
        exist_ok=True
    )
    logger.info(f"Uploading tarball to {artifact_repo}")
    api.upload_file(
        path_or_fileobj=output_tarball,
        path_in_repo="granite-8b-finetuned-ascii.tar.gz",
        repo_id=artifact_repo,
        repo_type="dataset",
        token=hf_token
    )
    logger.info(f"Tarball uploaded to {artifact_repo}/granite-8b-finetuned-ascii.tar.gz")
except Exception as e:
    logger.error(f"Failed to upload tarball to HF Hub: {str(e)}")
    raise