PAVULURI KIRAN committed on
Commit
b81dd02
·
1 Parent(s): f5dfe77

Updated FastAPI app and requirements

Browse files
Files changed (3) hide show
  1. app.py +14 -4
  2. requirement.txt +0 -6
  3. requirements.txt +1 -0
app.py CHANGED
@@ -1,6 +1,6 @@
1
  from fastapi import FastAPI, File, UploadFile
2
  import torch
3
- from transformers import AutoProcessor, LlavaForConditionalGeneration
4
  from PIL import Image
5
  import io
6
  import base64
@@ -8,17 +8,27 @@ import base64
8
  # Initialize FastAPI app
9
  app = FastAPI()
10
 
11
- # Load the model and processor from Hugging Face
12
  model_name = "mervinpraison/Llama-3.2-11B-Vision-Radiology-mini"
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
 
 
 
 
 
 
 
 
 
15
  processor = AutoProcessor.from_pretrained(model_name)
16
- model = LlavaForConditionalGeneration.from_pretrained(model_name).to(device)
 
 
17
 
18
  @app.post("/predict/")
19
  async def predict(file: UploadFile = File(...)):
20
  try:
21
- # Read image
22
  image_bytes = await file.read()
23
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
24
 
 
1
  from fastapi import FastAPI, File, UploadFile
2
  import torch
3
+ from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
4
  from PIL import Image
5
  import io
6
  import base64
 
8
  # Initialize FastAPI app
9
  app = FastAPI()
10
 
11
+ # Model details
12
  model_name = "mervinpraison/Llama-3.2-11B-Vision-Radiology-mini"
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
+ # Configure 4-bit quantization using bitsandbytes
16
+ quantization_config = BitsAndBytesConfig(
17
+ load_in_4bit=True, # Use 4-bit quantization (set False for 8-bit)
18
+ bnb_4bit_compute_dtype=torch.float16, # Use float16 for faster computation
19
+ bnb_4bit_use_double_quant=True, # Double quantization for memory efficiency
20
+ )
21
+
22
+ # Load processor and quantized model
23
  processor = AutoProcessor.from_pretrained(model_name)
24
+ model = AutoModelForImageTextToText.from_pretrained(
25
+ model_name, quantization_config=quantization_config
26
+ ).to(device)
27
 
28
  @app.post("/predict/")
29
  async def predict(file: UploadFile = File(...)):
30
  try:
31
+ # Read and process the image
32
  image_bytes = await file.read()
33
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
34
 
requirement.txt DELETED
@@ -1,6 +0,0 @@
1
- fastapi
2
- uvicorn
3
- torch
4
- transformers
5
- pillow
6
- python-multipart
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -4,3 +4,4 @@ torch
4
  transformers
5
  pillow
6
  python-multipart
 
 
4
  transformers
5
  pillow
6
  python-multipart
7
+ bitsandbytes