# Author: mjarrett — updated for the 8B model (commit 29969bf).
# NOTE(review): these three lines were unprefixed commit metadata pasted into
# the source and made the file a SyntaxError; converted to comments.
import logging
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
import torch
import tarfile
from huggingface_hub import HfApi
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Dump the environment for debugging, redacting anything secret-looking:
# any token-bearing variable, plus the `granite` secret itself.
redacted_env = {
    key: ("****" if (key == "granite" or "TOKEN" in key) else val)
    for key, val in os.environ.items()
}
logger.info("Environment variables: %s", redacted_env)
# Hub locations and local output paths for this fine-tuning run.
model_path = "ibm-granite/granite-3.3-8b-instruct"        # base model to fine-tune
dataset_path = "mycholpath/ascii-json"                    # private training dataset
output_dir = "/app/granite-8b-finetuned-ascii"            # trainer output directory
output_tarball = "/app/granite-8b-finetuned-ascii.tar.gz"  # local artifact bundle
model_repo = "mycholpath/granite-8b-finetuned-ascii"      # destination model repo
artifact_repo = "mycholpath/granite-finetuned-artifacts"  # destination dataset repo for the tarball

# The HF token arrives via a single `granite` secret of the form
# "HF_TOKEN=<token>" (HF Space settings). Fail fast if missing or malformed.
_TOKEN_PREFIX = "HF_TOKEN="
granite_var = os.getenv("granite")
if not granite_var or not granite_var.startswith(_TOKEN_PREFIX):
    logger.error("granite environment variable is not set or invalid. Expected format: HF_TOKEN=<token>.")
    raise ValueError("granite environment variable is not set or invalid. Please set it in HF Space settings.")
# Strip only the leading prefix. The previous str.replace("HF_TOKEN=", "")
# would also delete the substring anywhere else inside the token value.
hf_token = granite_var[len(_TOKEN_PREFIX):]
logger.info("HF_TOKEN extracted from granite (value hidden for security)")
# Tokenizer: authenticated load into a writable cache dir.
# Use the module logger (not the root logger via logging.info) so log records
# carry this module's name consistently with the rest of the file.
logger.info("Loading tokenizer...")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, token=hf_token, cache_dir="/tmp/hf_cache", trust_remote_code=True
    )
    # Alias pad to EOS and pad on the right — presumably the base tokenizer
    # ships without a dedicated pad token (common for causal LMs); confirm
    # against the Granite tokenizer config.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'right'
except Exception as e:
    logger.error(f"Failed to load tokenizer: {str(e)}")
    raise
# Base model in fp16, device placement delegated to accelerate ("auto").
# Use the module logger for consistency with the rest of the file
# (logging.info went to the root logger).
logger.info("Loading model...")
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        token=hf_token,
        torch_dtype=torch.float16,  # NOTE(review): newer transformers prefer `dtype=` — confirm installed version
        device_map="auto",
        cache_dir="/tmp/hf_cache",
        trust_remote_code=True
    )
except Exception as e:
    logger.error(f"Failed to load model: {str(e)}")
    raise
# LoRA adapter: rank-16 low-rank updates applied to the attention query and
# value projections only; the frozen base weights stay untouched.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    bias="none",
)
model = get_peft_model(model, lora_config)
# Private dataset load, authenticated with the same token.
# Module logger used instead of the root logger for consistent record names.
logger.info("Preparing to load private dataset...")
logger.info("Using HF_TOKEN from granite for private dataset authentication")
try:
    dataset = load_dataset(dataset_path, split="train", token=hf_token)
    logger.info(f"Dataset loaded successfully: {len(dataset)} examples")
except Exception as e:
    logger.error(f"Failed to load dataset: {str(e)}")
    raise
def formatting_prompts_func(example):
    """Render one dataset row as a single training text.

    Joins the row's `prompt` and `completion` fields with a newline and
    returns the result wrapped in a one-element list (the shape SFTTrainer's
    formatting_func consumes).
    """
    text = "\n".join((example['prompt'], example['completion']))
    return [text]
# Use SFTConfig for training arguments
sft_config = SFTConfig(
    output_dir=output_dir,
    num_train_epochs=5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,  # effectively unused: eval_strategy is "no" and eval_dataset is None
    gradient_accumulation_steps=4,  # effective batch size = 4 * 4 = 16 per device
    learning_rate=2e-4,
    weight_decay=0.01,
    eval_strategy="no",
    save_steps=50,
    logging_steps=10,
    fp16=True,  # matches the fp16 base-model weights loaded above
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    max_seq_length=768,
    dataset_text_field=None,  # text comes from formatting_prompts_func, not a dataset column
    packing=False
)
# Build the trainer and run training. Module logger used (not logging.info)
# for consistency with the rest of the file.
logger.info("Starting training...")
try:
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,  # NOTE(review): newer trl renamed this to processing_class — confirm installed version
        train_dataset=dataset,
        eval_dataset=None,
        formatting_func=formatting_prompts_func,
        args=sft_config
    )
except Exception as e:
    logger.error(f"Failed to initialize SFTTrainer: {str(e)}")
    raise
# Deliberately not wrapped in try/except: a training failure should abort the run.
trainer.train()
logger.info("Saving fine-tuned model...")
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)  # keep tokenizer alongside the adapter weights
# Bundle the output directory into a gzip tarball so the run's artifacts can
# be retrieved as a single file.
try:
    with tarfile.open(output_tarball, "w:gz") as archive:
        archive.add(output_dir, arcname=os.path.basename(output_dir))
    logger.info(f"Model tarball created: {output_tarball}")
except Exception as e:
    logger.error(f"Failed to create model tarball: {str(e)}")
    raise
# Push the fine-tuned weights + tokenizer to a private model repo. A failure
# here is deliberately non-fatal: the tarball upload below is the fallback
# delivery path.
try:
    api = HfApi()
    logger.info(f"Creating model repository: {model_repo}")
    api.create_repo(
        repo_id=model_repo,
        repo_type="model",
        private=True,
        exist_ok=True,  # idempotent across re-runs
        token=hf_token,
    )
    logger.info(f"Uploading model to {model_repo}")
    api.upload_folder(
        repo_id=model_repo,
        repo_type="model",
        folder_path=output_dir,
        create_pr=False,
        token=hf_token,
    )
    logger.info(f"Fine-tuned model uploaded to {model_repo}")
except Exception as e:
    logger.error(f"Failed to upload model to HF Hub: {str(e)}")
    logger.warning("Continuing to tarball upload despite model upload failure")
# Upload the tarball to a private dataset repo that serves as an artifact
# store. Unlike the model upload above, a failure here is fatal.
try:
    api = HfApi()
    logger.info(f"Creating dataset repository: {artifact_repo}")
    api.create_repo(
        repo_id=artifact_repo,
        repo_type="dataset",
        token=hf_token,
        private=True,
        exist_ok=True
    )
    logger.info(f"Uploading tarball to {artifact_repo}")
    api.upload_file(
        path_or_fileobj=output_tarball,
        path_in_repo="granite-8b-finetuned-ascii.tar.gz",
        repo_id=artifact_repo,
        repo_type="dataset",  # fix: the missing comma here was a SyntaxError
        token=hf_token
    )
    logger.info(f"Tarball uploaded to {artifact_repo}/granite-8b-finetuned-ascii.tar.gz")
except Exception as e:
    logger.error(f"Failed to upload tarball to HF Hub: {str(e)}")
    raise