rudranighosh committed
Commit c5b561d · verified · 1 Parent(s): f933d7b

Upload app.py

Files changed (1)
  1. app/app.py +59 -0
app/app.py ADDED
@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
"""app

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/11wiIj_rvhSCb_ULZJmOMUhIwInf_QYiW
"""

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from peft import PeftModel
import os

app = FastAPI()

# Load the base model first
base_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
adapter_base_path = "./tinyllama-lora-finetuned"

# This assumes checkpoints are named like 'checkpoint-XXX'
checkpoints = [
    d for d in os.listdir(adapter_base_path)
    if os.path.isdir(os.path.join(adapter_base_path, d)) and d.startswith("checkpoint-")
]
if not checkpoints:
    raise FileNotFoundError(f"No checkpoints found in {adapter_base_path}")

# Sort checkpoints numerically to find the latest one (based on checkpoint number)
latest_checkpoint = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))[-1]
adapter_path = os.path.join(adapter_base_path, latest_checkpoint)  # Point to the latest checkpoint directory

print(f"Loading adapter from: {adapter_path}")

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float32)

# Load the LoRA adapter weights onto the base model from the specific checkpoint path
model = PeftModel.from_pretrained(base_model, adapter_path)
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

class Query(BaseModel):
    question: str

@app.post("/generate/")
def generate_answer(query: Query):
    input_text = f"### Question:\n{query.question}\n\n### Answer:\n"
    inputs = tokenizer(input_text, return_tensors="pt").to(device)
    # max_length caps the total sequence (prompt + completion) at 512 tokens
    outputs = model.generate(**inputs, max_length=512, num_return_sequences=1)
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Return only the text after the "### Answer:" marker, if present
    answer_start_index = decoded_output.find("### Answer:")
    if answer_start_index != -1:
        answer = decoded_output[answer_start_index + len("### Answer:"):].strip()
    else:
        answer = decoded_output.strip()

    return {"answer": answer}