myyim committed
Commit c438216 · verified · 1 parent: e632118

Upload 2 files

Files changed (2)
  1. app.py +82 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,82 @@
+ import torch
+ from transformers import (
+     PaliGemmaProcessor,
+     PaliGemmaForConditionalGeneration,
+ )
+ import streamlit as st
+ from PIL import Image
+ from transformers.image_utils import load_image
+ import os
+
+ # access token: set HF_TOKEN in the Space secrets
+ token = os.environ.get('HF_TOKEN')
+
+ # PaliGemma checkpoint
+ model_id = "google/paligemma2-3b-pt-896"
+
+ @st.cache_resource
+ def model_setup(model_id):
+     model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto", token=token).eval()
+     processor = PaliGemmaProcessor.from_pretrained(model_id, token=token)
+     return model, processor
+
+ def runModel(prompt, image):
+     model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(model.device)
+     input_len = model_inputs["input_ids"].shape[-1]
+     with torch.inference_mode():
+         generation = model.generate(**model_inputs, max_new_tokens=1000, do_sample=False)
+         generation = generation[0][input_len:]
+         return processor.decode(generation, skip_special_tokens=True)
+
+ def initialize():
+     # reset chat history whenever a new image is uploaded
+     st.session_state.messages = []
+
+ ### load model
+ model, processor = model_setup(model_id)
+
+ ### upload a file
+ uploaded_file = st.file_uploader("Choose an image", on_change=initialize)
+
+ if uploaded_file:
+     st.image(uploaded_file)
+     image = Image.open(uploaded_file).convert("RGB")
+
+     # tasks
+     task = st.radio(
+         "Task",
+         ('Caption', 'OCR', 'Segment', 'Enter your prompt'),
+         horizontal=True)
+
+     # display chat messages from history on app rerun
+     for message in st.session_state.messages:
+         with st.chat_message(message["role"]):
+             st.markdown(message["content"])
+
+     if task == 'Enter your prompt':
+         if prompt := st.chat_input("Type here!", key="question"):
+             # display user message in chat message container
+             with st.chat_message("user"):
+                 st.markdown(prompt)
+             # add user message to chat history
+             st.session_state.messages.append({"role": "user", "content": prompt})
+             # run the VLM
+             response = runModel(prompt, image)
+             # display assistant response in chat message container
+             with st.chat_message("assistant"):
+                 st.markdown(response)
+             # add assistant response to chat history
+             st.session_state.messages.append({"role": "assistant", "content": response})
+     else:
+         # display user message in chat message container
+         with st.chat_message("user"):
+             st.markdown(task)
+         # add user message to chat history
+         st.session_state.messages.append({"role": "user", "content": task})
+         # run the VLM
+         response = runModel(task, image)
+         # display assistant response in chat message container
+         with st.chat_message("assistant"):
+             st.markdown(response)
+         # add assistant response to chat history
+         st.session_state.messages.append({"role": "assistant", "content": response})
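For a quick sanity check of the same generate path outside Streamlit, a minimal sketch (the image URL below is a placeholder; HF_TOKEN must grant access to the gated google/paligemma2-3b-pt-896 weights, and "caption en" is one of PaliGemma's documented task prefixes):

import os
import torch
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
from transformers.image_utils import load_image

token = os.environ.get('HF_TOKEN')
model_id = "google/paligemma2-3b-pt-896"
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto", token=token).eval()
processor = PaliGemmaProcessor.from_pretrained(model_id, token=token)

image = load_image("https://example.com/test.jpg")  # placeholder URL
inputs = processor(text="caption en", images=image, return_tensors="pt").to(torch.bfloat16).to(model.device)
input_len = inputs["input_ids"].shape[-1]
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
# decode only the newly generated tokens, as runModel does above
print(processor.decode(output[0][input_len:], skip_special_tokens=True))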
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ transformers
+ torch
+ accelerate
+ pillow
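Note: app.py also imports streamlit, which is not pinned here; on a Streamlit Space the SDK supplies it. For a local run, install it alongside these packages (pip install -r requirements.txt streamlit), export HF_TOKEN, and launch with streamlit run app.py.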