SamuelJaja committed
Commit 60ae325 · verified · 1 Parent(s): 722a723

Update app.py

Files changed (1): app.py +78 -26
app.py CHANGED
@@ -1,10 +1,15 @@
 import os
+import logging
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from peft import PeftModel
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel

+# Configure logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
 # Initialize FastAPI
 app = FastAPI()

@@ -17,43 +22,90 @@ HF_TOKEN = os.getenv("HF_TOKEN")
 if HF_TOKEN is None:
     raise ValueError("Hugging Face token not found. Please set the HF_TOKEN environment variable.")

-# Define the offload directory
-OFFLOAD_DIR = "/app/offload"
-os.makedirs(OFFLOAD_DIR, exist_ok=True)
-
-# Load tokenizer with authentication
-tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN, cache_dir=OFFLOAD_DIR)
-
-# Load base model with offloading
-model = AutoModelForCausalLM.from_pretrained(
-    BASE_MODEL,
-    torch_dtype=torch.float16,
-    device_map="auto",
-    token=HF_TOKEN,
-    cache_dir=OFFLOAD_DIR,
-    offload_folder=OFFLOAD_DIR  # Specify the offload directory
-)
-
-# Load fine-tuned weights
-model = PeftModel.from_pretrained(
-    model,
-    FINETUNED_MODEL,
-    token=HF_TOKEN,
-    cache_dir=OFFLOAD_DIR,
-    offload_folder=OFFLOAD_DIR  # Ensure offloading is consistent
-)
+# Define the cache directory
+CACHE_DIR = "/app/cache"
+os.makedirs(CACHE_DIR, exist_ok=True)
+
+try:
+    # Load tokenizer with authentication
+    logger.info(f"Loading tokenizer from {BASE_MODEL}")
+    tokenizer = AutoTokenizer.from_pretrained(
+        BASE_MODEL,
+        token=HF_TOKEN,
+        cache_dir=CACHE_DIR
+    )
+
+    # Load base model with simplified configuration
+    logger.info(f"Loading base model from {BASE_MODEL}")
+    model = AutoModelForCausalLM.from_pretrained(
+        BASE_MODEL,
+        torch_dtype=torch.float16,
+        device_map="auto",
+        token=HF_TOKEN,
+        cache_dir=CACHE_DIR,
+        trust_remote_code=True
+    )
+
+    # Load fine-tuned adapter with simplified approach
+    logger.info(f"Loading adapter from {FINETUNED_MODEL}")
+    adapter_model = PeftModel.from_pretrained(
+        model,
+        FINETUNED_MODEL,
+        token=HF_TOKEN,
+        device_map="auto",
+        torch_dtype=torch.float16,
+        is_trainable=False  # Set to False for inference
+    )
+
+    # Merge adapter weights with base model for better performance (optional)
+    logger.info("Merging adapter weights with base model")
+    model = adapter_model.merge_and_unload()
+
+    logger.info("Model loading completed successfully")
+except Exception as e:
+    logger.error(f"Error loading model: {str(e)}")
+    raise

 # Define request body
 class Query(BaseModel):
     text: str
+    max_tokens: int = 200
+    temperature: float = 0.7

 # Define the text generation endpoint
 @app.post("/generate")
 async def generate_text(query: Query):
     try:
+        logger.info(f"Generating text for input: {query.text[:50]}...")
+
         inputs = tokenizer(query.text, return_tensors="pt").to(model.device)
-        output = model.generate(**inputs, max_new_tokens=200)
+
+        with torch.no_grad():
+            output = model.generate(
+                **inputs,
+                max_new_tokens=query.max_tokens,
+                temperature=query.temperature,
+                do_sample=True if query.temperature > 0 else False
+            )
+
         response_text = tokenizer.decode(output[0], skip_special_tokens=True)
+        logger.info("Text generation successful")
+
         return {"response": response_text}
     except Exception as e:
+        logger.error(f"Error in text generation: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
+
+# Health check endpoint
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy"}
+
+# Model info endpoint
+@app.get("/info")
+async def model_info():
+    return {
+        "base_model": BASE_MODEL,
+        "adapter_model": FINETUNED_MODEL,
+        "device": str(model.device)
+    }