mjarrett committed
Commit 29969bf · 1 Parent(s): 76ec1cb

updated for 8B model

Files changed (6)
  1. Dockerfile +21 -0
  2. README.md +3 -3
  3. app.py +64 -0
  4. finetune.py +176 -0
  5. handler.py +47 -0
  6. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,21 @@
+ # Adapted from https://huggingface.co/docs/hub/spaces-sdks-docker
+ FROM python:3.10-slim
+
+ # Create non-root user
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ WORKDIR /app
+
+ # Install dependencies
+ COPY --chown=user requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade pip && \
+     pip install --no-cache-dir -r requirements.txt
+
+ # Copy scripts
+ COPY --chown=user finetune.py /app/finetune.py
+ COPY --chown=user app.py /app/app.py
+
+ # Run finetune and start API
+ CMD ["bash", "-c", "python finetune.py && uvicorn app:app --host 0.0.0.0 --port 7860"]
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
- title: Granite 3.1 8b Instruct Ascii
- emoji: 👀
+ title: Granite 2b Finetuning
+ emoji: 🌖
  colorFrom: yellow
- colorTo: green
+ colorTo: gray
  sdk: docker
  pinned: false
  license: apache-2.0
app.py ADDED
@@ -0,0 +1,64 @@
+ from fastapi import FastAPI, HTTPException
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+ import logging
+ from pydantic import BaseModel
+ import os
+ import tarfile
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Debug environment variables
+ logger.info("Environment variables: %s", {k: "****" if "TOKEN" in k or k == "granite" else v for k, v in os.environ.items()})
+
+ app = FastAPI()
+
+ model_tarball = "/app/granite-8b-finetuned-ascii.tar.gz"
+ model_path = "/app/granite-8b-finetuned-ascii"
+
+ # Extract tarball if model directory doesn't exist
+ if not os.path.exists(model_path):
+     logger.info(f"Extracting model tarball: {model_tarball}")
+     try:
+         with tarfile.open(model_tarball, "r:gz") as tar:
+             tar.extractall(path="/app")
+         logger.info("Model tarball extracted successfully")
+     except Exception as e:
+         logger.error(f"Failed to extract model tarball: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Model tarball extraction failed: {str(e)}")
+
+ try:
+     logger.info("Loading tokenizer and model")
+     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+     tokenizer.padding_side = 'right'
+     model = AutoModelForCausalLM.from_pretrained(
+         model_path,
+         torch_dtype=torch.float16,
+         device_map="auto",
+         trust_remote_code=True
+     )
+     logger.info("Model and tokenizer loaded successfully")
+ except Exception as e:
+     logger.error(f"Failed to load model or tokenizer: {str(e)}")
+     raise HTTPException(status_code=500, detail=f"Model initialization failed: {str(e)}")
+
+ class EditRequest(BaseModel):
+     text: str
+
+ @app.get("/")
+ def greet_json():
+     return {"status": "Model is ready", "model": model_path}
+
+ @app.post("/generate")
+ async def generate(request: EditRequest):
+     try:
+         prompt = f"Edit this AsciiDoc sentence: {request.text}"
+         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+         outputs = model.generate(**inputs, max_length=200)
+         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         logger.info(f"Generated response for prompt: {prompt}")
+         return {"response": response}
+     except Exception as e:
+         logger.error(f"Generation failed: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
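For reference, a hedged client sketch for the /generate route; the host is an assumption, and note that because the endpoint decodes the full output sequence, the returned text includes the prompt prefix.

import requests  # assumed client-side dependency

resp = requests.post(
    "http://localhost:7860/generate",  # hypothetical host for the Space
    json={"text": "this sentence have a error"},
    timeout=300,  # generation on an 8B model can be slow
)
resp.raise_for_status()
print(resp.json()["response"])  # includes the "Edit this AsciiDoc sentence: ..." prefix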
finetune.py ADDED
@@ -0,0 +1,176 @@
+ import logging
+ import os
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import LoraConfig, get_peft_model
+ from trl import SFTTrainer, SFTConfig
+ from datasets import load_dataset
+ import torch
+ import tarfile
+ from huggingface_hub import HfApi
+
+ logging.basicConfig(level=logging.DEBUG)
+ logger = logging.getLogger(__name__)
+
+ # Debug environment variables
+ logger.info("Environment variables: %s", {k: "****" if "TOKEN" in k or k == "granite" else v for k, v in os.environ.items()})
+
+ model_path = "ibm-granite/granite-3.3-8b-instruct"
+ dataset_path = "mycholpath/ascii-json"
+ output_dir = "/app/granite-8b-finetuned-ascii"
+ output_tarball = "/app/granite-8b-finetuned-ascii.tar.gz"
+ model_repo = "mycholpath/granite-8b-finetuned-ascii"
+ artifact_repo = "mycholpath/granite-finetuned-artifacts"
+
+ # Get HF token from granite environment variable
+ granite_var = os.getenv("granite")
+ if not granite_var or not granite_var.startswith("HF_TOKEN="):
+     logger.error("granite environment variable is not set or invalid. Expected format: HF_TOKEN=<token>.")
+     raise ValueError("granite environment variable is not set or invalid. Please set it in HF Space settings.")
+ hf_token = granite_var.replace("HF_TOKEN=", "")
+ logger.info("HF_TOKEN extracted from granite (value hidden for security)")
+
+ logger.info("Loading tokenizer...")
+ try:
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_path, token=hf_token, cache_dir="/tmp/hf_cache", trust_remote_code=True
+     )
+     tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.padding_side = 'right'
+ except Exception as e:
+     logger.error(f"Failed to load tokenizer: {str(e)}")
+     raise
+
+ logger.info("Loading model...")
+ try:
+     model = AutoModelForCausalLM.from_pretrained(
+         model_path,
+         token=hf_token,
+         torch_dtype=torch.float16,
+         device_map="auto",
+         cache_dir="/tmp/hf_cache",
+         trust_remote_code=True
+     )
+ except Exception as e:
+     logger.error(f"Failed to load model: {str(e)}")
+     raise
+
+ lora_config = LoraConfig(
+     r=16,
+     lora_alpha=32,
+     target_modules=["q_proj", "v_proj"],
+     lora_dropout=0.05,
+     bias="none",
+     task_type="CAUSAL_LM"
+ )
+ model = get_peft_model(model, lora_config)
+
+ logger.info("Preparing to load private dataset...")
+ logger.info("Using HF_TOKEN from granite for private dataset authentication")
+ try:
+     dataset = load_dataset(dataset_path, split="train", token=hf_token)
+     logger.info(f"Dataset loaded successfully: {len(dataset)} examples")
+ except Exception as e:
+     logger.error(f"Failed to load dataset: {str(e)}")
+     raise
+
+ def formatting_prompts_func(example):
+     # SFTTrainer maps this over batched examples, so return one string per row
+     return [f"{p}\n{c}" for p, c in zip(example["prompt"], example["completion"])]
+
+ # Use SFTConfig for training arguments
+ sft_config = SFTConfig(
+     output_dir=output_dir,
+     num_train_epochs=5,
+     per_device_train_batch_size=4,
+     per_device_eval_batch_size=4,
+     gradient_accumulation_steps=4,
+     learning_rate=2e-4,
+     weight_decay=0.01,
+     eval_strategy="no",
+     save_steps=50,
+     logging_steps=10,
+     fp16=True,
+     max_grad_norm=0.3,
+     warmup_ratio=0.03,
+     lr_scheduler_type="cosine",
+     max_seq_length=768,
+     dataset_text_field=None,
+     packing=False
+ )
+
+ logger.info("Starting training...")
+ try:
+     trainer = SFTTrainer(
+         model=model,
+         tokenizer=tokenizer,
+         train_dataset=dataset,
+         eval_dataset=None,
+         formatting_func=formatting_prompts_func,
+         args=sft_config
+     )
+ except Exception as e:
+     logger.error(f"Failed to initialize SFTTrainer: {str(e)}")
+     raise
+
+ trainer.train()
+
+ logger.info("Saving fine-tuned model...")
+ trainer.save_model(output_dir)
+ tokenizer.save_pretrained(output_dir)
+
+ # Create tarball for local retrieval
+ try:
+     with tarfile.open(output_tarball, "w:gz") as tar:
+         tar.add(output_dir, arcname=os.path.basename(output_dir))
+     logger.info(f"Model tarball created: {output_tarball}")
+ except Exception as e:
+     logger.error(f"Failed to create model tarball: {str(e)}")
+     raise
+
+ # Upload model to HF Hub
+ try:
+     api = HfApi()
+     logger.info(f"Creating model repository: {model_repo}")
+     api.create_repo(
+         repo_id=model_repo,
+         repo_type="model",
+         token=hf_token,
+         private=True,
+         exist_ok=True
+     )
+     logger.info(f"Uploading model to {model_repo}")
+     api.upload_folder(
+         folder_path=output_dir,
+         repo_id=model_repo,
+         repo_type="model",
+         token=hf_token,
+         create_pr=False
+     )
+     logger.info(f"Fine-tuned model uploaded to {model_repo}")
+ except Exception as e:
+     logger.error(f"Failed to upload model to HF Hub: {str(e)}")
+     logger.warning("Continuing to tarball upload despite model upload failure")
+
+ # Upload tarball to HF Hub dataset repository
+ try:
+     api = HfApi()
+     logger.info(f"Creating dataset repository: {artifact_repo}")
+     api.create_repo(
+         repo_id=artifact_repo,
+         repo_type="dataset",
+         token=hf_token,
+         private=True,
+         exist_ok=True
+     )
+     logger.info(f"Uploading tarball to {artifact_repo}")
+     api.upload_file(
+         path_or_fileobj=output_tarball,
+         path_in_repo="granite-8b-finetuned-ascii.tar.gz",
+         repo_id=artifact_repo,
+         repo_type="dataset",
+         token=hf_token
+     )
+     logger.info(f"Tarball uploaded to {artifact_repo}/granite-8b-finetuned-ascii.tar.gz")
+ except Exception as e:
+     logger.error(f"Failed to upload tarball to HF Hub: {str(e)}")
+     raise
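Once the run finishes, the adapter can be pulled back down from either repo. Below is a minimal sketch for fetching the tarball artifact and loading the LoRA adapter on top of the base model; the read token is a placeholder, and it assumes the upload steps above succeeded (trainer.save_model on a PEFT-wrapped model writes an adapter, not a full checkpoint).

import tarfile
from huggingface_hub import hf_hub_download
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Fetch the training artifact from the private dataset repo
tarball = hf_hub_download(
    repo_id="mycholpath/granite-finetuned-artifacts",
    filename="granite-8b-finetuned-ascii.tar.gz",
    repo_type="dataset",
    token="hf_xxx",  # placeholder; substitute a real read token
)
with tarfile.open(tarball, "r:gz") as tar:
    tar.extractall(path=".")

# Load the base model, then apply the saved LoRA adapter
base = AutoModelForCausalLM.from_pretrained(
    "ibm-granite/granite-3.3-8b-instruct", device_map="auto"
)
model = PeftModel.from_pretrained(base, "granite-8b-finetuned-ascii")
tokenizer = AutoTokenizer.from_pretrained("granite-8b-finetuned-ascii")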
handler.py ADDED
@@ -0,0 +1,47 @@
+ import json
+ from transformers import pipeline
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         self.pipeline = pipeline("text-generation", model=path, device=0)
+
+     def __call__(self, data):
+         inputs = data.get("inputs", "")
+         style_guide = data.get("style_guide", "Apply general AsciiDoc best practices.")
+         max_tokens = data.get("max_tokens", 2048)
+
+         system_prompt = f"""
+ You are an expert technical editor specializing in AsciiDoc document correction. Your task is to analyze the provided AsciiDoc text and suggest corrections based on the following style guide:
+ {style_guide}
+
+ **Output Requirements**:
+ - Return corrections **only** in valid JSON format, enclosed in curly braces: {{"corrections": [...]}}.
+ - Each correction must include:
+   - "original_line": The exact line from the input text.
+   - "corrected_line": The corrected version of the line.
+   - "explanation": A brief reason for the correction.
+ - If no corrections are needed, return: {{"corrections": []}}.
+ - Ensure the JSON is complete, valid, and concise to avoid truncation.
+ - Do **not** include any text, comments, or explanations outside the JSON object.
+ - Do **not** include placeholder text like "<original AsciiDoc line>".
+ - Only correct lines with AsciiDoc syntax, style, or technical accuracy issues (e.g., missing punctuation, incorrect headers, malformed attributes like :gls_prefix:).
+
+ Analyze the following AsciiDoc lines and provide corrections in JSON format:
+ """
+         prompt = f"{system_prompt}\n{inputs}"
+
+         try:
+             response = self.pipeline(
+                 prompt,
+                 max_new_tokens=max_tokens,
+                 temperature=0.3,
+                 return_full_text=False
+             )[0]["generated_text"].strip()
+             json_start = response.find('{')
+             json_end = response.rfind('}') + 1
+             if json_start == -1 or json_end == 0:  # rfind returns -1 on a miss, making json_end 0
+                 return {"corrections": []}
+             correction_json = json.loads(response[json_start:json_end])
+             return correction_json
+         except Exception:
+             return {"corrections": []}
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ transformers==4.46.0
+ torch==2.4.1
+ datasets==3.0.1
+ peft==0.13.2
+ trl==0.11.4
+ accelerate==1.0.1
+ huggingface_hub==0.25.2
+ fastapi==0.115.2
+ uvicorn==0.32.0