zerishdorelser committed
Commit c46d8ad · verified · 1 Parent(s): f852217

Upload 6 files

Files changed (6)
  1. .gitignore +18 -0
  2. app.py +123 -0
  3. functions.py +323 -0
  4. type2.py +86 -0
  5. type3.py +70 -0
  6. type4.py +133 -0
.gitignore ADDED
@@ -0,0 +1,18 @@
+ # Virtual environment
+ venv/
+ .venv/
+ ENV/
+
+ # Python cache/compiled files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # IDE-specific files
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+
+ # Environment variables
+ .env
app.py ADDED
@@ -0,0 +1,123 @@
+ import streamlit as st
+ import torch
+ from transformers import pipeline
+ from transformers import DetrImageProcessor, DetrForObjectDetection
+ from transformers import CLIPProcessor, CLIPModel
+ from transformers import BlipProcessor, BlipForQuestionAnswering
+ #from transformers import YolosImageProcessor, YolosForObjectDetection
+ from PIL import Image
+ from functions import *
+ import io
+
+
+ # Load models
+ @st.cache_resource
+ def load_models():
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
+     processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
+     clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+     clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+     sales_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
+     sales_model = BlipForQuestionAnswering.from_pretrained(
+         "Salesforce/blip-vqa-base",
+         torch_dtype=torch.float16 if device == "cuda" else torch.float32
+     ).to(device)
+
+     return {
+         "detector": model,
+         "processor": processor,
+         "clip": clip_model,
+         "clip process": clip_processor,
+         #"t5 token": t5_tokenizer,
+         #"t5": t5_model,
+         'story_teller': pipeline("text-generation", model="nickypro/tinyllama-15M"),
+         "sales process": sales_processor,
+         "sales model": sales_model,
+         "device": device
+     }
+
+
+ def main():
+     st.header("📱 Nano AI Image Analyzer")
+
+     uploaded_file = st.file_uploader("Upload image")  # , type=['.PNG','png','jpg','jpeg'])
+     models = load_models()
+     st.write("Models loaded")
+
+     #im2 = detect_objects(image_path=image, models=models)
+     #st.write(im2)
+     #st.write("done")
+     #annotated_image = draw_bounding_boxes(image, im2)
+     #st.image(annotated_image, caption="Detected Objects", use_container_width=True)
+
+     # Buttons UI
+     if uploaded_file is not None:
+         image_bytes = uploaded_file.getvalue()
+         st.write("Filename:", uploaded_file.name)
+         image = Image.open(uploaded_file).convert('RGB')
+         st.image(image, caption="Uploaded Image", width=200)  # use_container_width=False
+
+         col1, col2, col3 = st.columns([0.33, 0.33, 0.33])
+
+         with col1:
+             detect = st.button("🔍 Detect Objects", key="btn1")
+         with col2:
+             describe = st.button("📝 Describe Image", key="btn2")
+         with col3:
+             story = st.button("📖 Generate Story", key="btn3",
+                               help="The story is generated from the caption")
+
+
+         if detect:
+             with st.spinner("Detecting objects..."):
+                 try:
+                     detections = detect_objects(image.copy(), models)
+                     annotated_image = draw_bounding_boxes(image, detections)
+                     st.image(annotated_image, caption="Detected Objects", use_column_width=True)
+                     show_detection_table(detections)
+                 except Exception as e:
+                     st.error(f"Detection failed ({e}); try another image.")
+
+         elif describe:
+             with st.spinner("Trying to describe..."):
+                 description = get_image_description(image.copy(), models)
+                 st.write(description)
+
+         elif story:
+             #st.write('btn3 clicked')
+             with st.spinner("Getting a story..."):
+                 description = get_image_description(image.copy(), models)
+                 story = generate_story(description, models)
+                 st.write(story)
+
+         # Chat interface
+         if "messages" not in st.session_state:
+             st.session_state.messages = []
+
+         chat_container = st.container(height=400)
+         with chat_container:
+             for message in st.session_state.messages:
+                 with st.chat_message(message["role"]):
+                     st.markdown(message["content"])
+
+         if prompt := st.chat_input("Ask about the image"):
+             st.session_state.messages.append({"role": "user", "content": prompt})
+             with st.chat_message("user"):
+                 st.markdown(prompt)
+
+             with st.chat_message("assistant"):
+                 with st.spinner("Thinking..."):
+                     response = answer_question(image,
+                                                prompt,
+                                                models["sales process"],
+                                                models["sales model"],
+                                                models["device"])
+                     #response = "response sample"
+                     st.markdown(response)
+             st.session_state.messages.append({"role": "assistant", "content": response})
+
+
+ if __name__ == "__main__":
+     main()
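
The "Describe Image" button above runs zero-shot CLIP scoring over a fixed candidate list (get_image_description in functions.py below). A stripped-down sketch of that scoring step, assuming a local image at a placeholder path and a shortened candidate list:

import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

candidates = ["a dog", "a cat", "a scenic view", "a group of people"]  # shortened candidate list
image = Image.open("photo.jpg").convert("RGB")                         # placeholder path

inputs = clip_processor(text=candidates, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = clip_model(**inputs)
probs = outputs.logits_per_image.softmax(dim=1)  # one probability per candidate text
print(candidates[probs.argmax()])                # best-matching description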
functions.py ADDED
@@ -0,0 +1,323 @@
+ from PIL import Image, ImageDraw
+ from transformers import DetrImageProcessor, DetrForObjectDetection
+ import numpy as np
+ import torch
+ import pandas as pd
+ import streamlit as st
+ from pathlib import Path
+
+ def safe_image_open(uploaded_file):
+     try:
+         # Convert to lowercase and remove spaces
+         filename = Path(uploaded_file.name).stem.lower().replace(" ", "_") + ".png"
+         image = Image.open(uploaded_file).convert("RGB")
+         return image
+     except Exception as e:
+         st.error(f"Error loading image: {str(e)}")
+         return None
+
+ def QA(image, question, models):
+     inputs = models['sales process'](image, question, return_tensors='pt')
+     out = models['sales model'].generate(**inputs)
+     # Decode the generated token ids into text
+     return models['sales process'].decode(out[0], skip_special_tokens=True)
+
+ def answer_question(image, question, processor, model, device):
+     inputs = processor(image, question, return_tensors="pt").to(device)
+     outputs = model.generate(**inputs, max_new_tokens=100)
+     return processor.decode(outputs[0], skip_special_tokens=True)
+
+ def generate_story(caption, models):
+     """Generate a short story"""
+     #caption = "a beautiful landscape"
+     return models['story_teller'](
+         f"Write story about: {caption}",
+         max_length=500,
+         do_sample=True,
+         temperature=0.7
+     )[0]['generated_text']
+
+ def generate_story2(prompt, models):
+     input_text = f"Write a short story about {prompt}"
+     input_ids = models["t5 token"].encode(input_text, return_tensors="pt", max_length=64, truncation=True)
+     output_ids = models["t5"].generate(input_ids, max_length=512)
+     story = models["t5 token"].decode(output_ids[0], skip_special_tokens=True)
+     return story
+
+ def get_image_description(image_path, models):
+     image = image_path
+     text_inputs = ["a dog", "a cat", "a man", "a woman", "a child", "group of friends",
+                    "a scenic view", "a cityscape", "a forest", "a beach", "a mountain", "a group of people", "a car", "a bird",
+                    "a beautiful landscape", "a couple in love", "an animal", "amazing space",
+                    "incredible earth", "motion", "singularity", "anime", "emotions",
+                    "sorrow", "joy"]
+
+     inputs = models["clip process"](text=text_inputs, images=image, return_tensors="pt", padding=True)
+     outputs = models["clip"](**inputs)
+     logits_per_image = outputs.logits_per_image
+     probs = logits_per_image.softmax(dim=1)
+     best = text_inputs[probs.argmax()]
+     return best
+
+ def show_detection_table(detection_text):
+     """
+     Convert detection text into a formatted Streamlit table.
+
+     Args:
+         detection_text: String with one detection per line in the format "[x1, y1, x2, y2] label score"
+
+     Returns:
+         None. Displays a Streamlit table with columns: Object Type, Box Coordinates, Score.
+     """
+     # Parse each line into a list of dictionaries
+     detections = []
+     for line in detection_text.strip().split('\n'):
+         if not line:
+             continue
+
+         # Parse the components
+         bbox_part, label, score = line.rsplit(' ', 2)
+         bbox = bbox_part.strip('[]')
+
+         detections.append({
+             'Object Type': label,
+             'Box Coordinates': f"[{bbox}]",
+             'Score': float(score)
+         })
+
+     # Convert to DataFrame
+     df = pd.DataFrame(detections)
+
+     # Format the score column
+     df['Score'] = df['Score'].map('{:.2f}'.format)
+
+     # Display in Streamlit with some styling
+     st.dataframe(
+         df,
+         column_config={
+             "Object Type": "Object Type",
+             "Box Coordinates": "Box [x1,y1,x2,y2]",
+             "Score": st.column_config.NumberColumn(
+                 "Confidence",
+                 format="%.2f",
+             )
+         },
+         hide_index=True,
+         use_container_width=True
+     )
+
+ def draw_bounding_boxes(image, detection_text):
+     """
+     Draw bounding boxes on an image, with different colors for people, cars, and other objects.
+
+     Args:
+         image: PIL Image object
+         detection_text: String with one detection per line in the format "[x1, y1, x2, y2] label score"
+
+     Returns:
+         PIL Image with bounding boxes drawn
+     """
+     # Create a drawing context
+     draw = ImageDraw.Draw(image)
+
+     # Define colors
+     PERSON_COLOR = (255, 0, 0)    # Red for people
+     CAR_COLOR = (255, 165, 0)     # Orange for cars
+     OTHER_COLOR = (0, 255, 0)     # Green for other objects
+     TEXT_COLOR = (255, 255, 255)  # White text
+
+     # Parse each detection line
+     for line in detection_text.strip().split('\n'):
+         if not line:
+             continue
+
+         # Parse the detection info
+         bbox_part, label, score = line.rsplit(' ', 2)
+         bbox = list(map(int, bbox_part.strip('[]').split(',')))
+         confidence = float(score)
+
+         # Determine box color
+         #box_color = PERSON_COLOR if label == 'person' else OTHER_COLOR
+         if label == "person":
+             box_color = PERSON_COLOR
+         elif label == "car":
+             box_color = CAR_COLOR
+         else:
+             box_color = OTHER_COLOR
+
+         # Draw bounding box
+         draw.rectangle(
+             [(bbox[0], bbox[1]), (bbox[2], bbox[3])],
+             outline=box_color,
+             width=3
+         )
+
+         # Draw label with confidence
+         label_text = f"{label} {confidence:.2f}"
+         text_position = (bbox[0], bbox[1] - 15)
+
+         # Draw text background
+         text_bbox = draw.textbbox(text_position, label_text)
+         draw.rectangle(
+             [(text_bbox[0] - 2, text_bbox[1] - 2), (text_bbox[2] + 2, text_bbox[3] + 2)],
+             fill=box_color
+         )
+
+         # Draw text
+         draw.text(
+             text_position,
+             label_text,
+             fill=TEXT_COLOR
+         )
+
+     return image
+
+ def detect_objects(image_path, models):
+     """
+     Detects objects in the provided image.
+
+     Args:
+         image_path (PIL.Image.Image): The image to run detection on.
+
+     Returns:
+         str: One detected object per line, formatted as '[x1, y1, x2, y2] class_name confidence_score'.
+     """
+     image = image_path
+
+     #processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
+     #model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
+     processor = models['processor']
+     model = models['detector']
+
+     inputs = processor(images=image, return_tensors="pt")
+     outputs = model(**inputs)
+
+     # convert outputs (bounding boxes and class logits) to COCO API
+     # let's only keep detections with score > 0.9
+     target_sizes = torch.tensor([image.size[::-1]])
+     results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
+
+     detections = ""
+     for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+         detections += '[{}, {}, {}, {}]'.format(int(box[0]), int(box[1]), int(box[2]), int(box[3]))
+         detections += ' {}'.format(model.config.id2label[int(label)])
+         detections += ' {}\n'.format(float(score))
+
+     return detections
+
+ def detect_objects4(image, models):
+     processor = models['processor']
+     model = models['detector']
+     inputs = processor(images=image, return_tensors="pt")
+     outputs = model(**inputs)
+     target_sizes = torch.tensor([image.size[::-1]])
+     results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
+     for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+         box = [round(i, 2) for i in box.tolist()]
+         print(
+             f"Detected {model.config.id2label[label.item()]} with confidence "
+             f"{round(score.item(), 3)} at location {box}"
+         )
+
+ def detect_objects3(image, models, threshold=0.7):
+     """Object detection with bounding boxes using DETR"""
+     if not isinstance(image, Image.Image):
+         image = Image.open(image)
+
+     processor = models['processor']
+     model = models['detector']
+
+     # Preprocess image
+     inputs = processor(images=image, return_tensors="pt")
+
+     # Run model
+     outputs = model(**inputs)
+
+     # Get original image size (height, width)
+     target_size = torch.tensor([image.size[::-1]])
+
+     # Post-process results
+     results = processor.post_process_object_detection(outputs, target_sizes=target_size, threshold=threshold)[0]
+
+     # Draw results
+     draw = ImageDraw.Draw(image)
+     formatted_results = []
+
+     for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+         box = box.tolist()
+         label_text = model.config.id2label[label.item()]
+         score_val = score.item()
+
+         # Draw box
+         draw.rectangle(
+             [(box[0], box[1]), (box[2], box[3])],
+             outline="red",
+             width=3
+         )
+         draw.text(
+             (box[0], box[1] - 10),
+             f"{label_text} ({score_val:.2f})",
+             fill="red"
+         )
+
+         formatted_results.append({
+             "label": label_text,
+             "score": score_val,
+             "box": {
+                 "xmin": box[0],
+                 "ymin": box[1],
+                 "xmax": box[2],
+                 "ymax": box[3]
+             }
+         })
+
+     return image, formatted_results
+
+
+ def detect_objects2(image, models):
+     """Function 1: Object detection with bounding boxes"""
+     results = models['detector'](image)
+
+     # Draw bounding boxes
+     draw = ImageDraw.Draw(image)
+     for result in results:
+         box = result['box']
+         draw.rectangle(
+             [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
+             outline="red",
+             width=3
+         )
+         draw.text(
+             (box['xmin'], box['ymin'] - 10),
+             f"{result['label']} ({result['score']:.2f})",
+             fill="red"
+         )
+     return image, results
+
+
+ """@st.cache_resource
+ def load_light_models():
+     # Load lighter version of models with proper DETR handling
+     models = {}
+
+     # Load DETR components separately
+     with st.spinner("Loading object detection model..."):
+         models['detr_processor'] = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
+         models['detr_model'] = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
+
+     # Use pipeline for captioning
+     with st.spinner("Loading captioning model..."):
+         models['captioner'] = pipeline(
+             "image-to-text",
+             model="Salesforce/blip-image-captioning-base"
+         )
+
+     return models"""
+
+ """@st.cache_resource
+ def load_models():
+     return {
+         # Using tiny models for faster loading
+         'detector': pipeline("object-detection", model="hustvl/yolos-tiny")
+         #'captioner': pipeline("image-to-text", model="Salesforce/blip-image-captioning-base"),
+         #'story_teller': pipeline("text-generation", model="gpt2")
+     }"""
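
For reference, a minimal sketch of how the detection helpers above compose outside of Streamlit; the models dict is built with only the keys detect_objects() reads, and photo.jpg is a placeholder path:

from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection
from functions import detect_objects, draw_bounding_boxes

# Only the keys detect_objects() actually reads from the models dict
models = {
    "processor": DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm"),
    "detector": DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm"),
}

image = Image.open("photo.jpg").convert("RGB")     # placeholder path
detections = detect_objects(image.copy(), models)  # "[x1, y1, x2, y2] label score" per line
annotated = draw_bounding_boxes(image, detections)
annotated.save("annotated.jpg")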
type2.py ADDED
@@ -0,0 +1,86 @@
+ import streamlit as st
+ from PIL import Image, ImageDraw
+ from transformers import pipeline
+
+ # Tiny models only
+ @st.cache_resource
+ def load_models():
+     return {
+         # Tiny object classifier (5MB)
+         #'detector': pipeline("image-classification", model="google/mobilenet_v1.0_224"),
+
+         # Micro captioning model (45MB)
+         #'captioner': pipeline("image-to-text", model="bipin/image-caption-generator"),
+
+         # Nano story generator (33MB)
+         'story_teller': pipeline("text-generation", model="sshleifer/tiny-gpt2")
+     }
+
+ def analyze_image(image, models):
+     """Combined analysis to minimize model loads"""
+     results = {}
+
+     # Object classification (not detection)
+     with st.spinner("Identifying contents..."):
+         results['objects'] = models['detector'](image)
+
+     # Image captioning
+     with st.spinner("Generating caption..."):
+         results['caption'] = models['captioner'](image)[0]['generated_text']
+
+     return results
+
+ def generate_story(caption, models):
+     """Generate short story"""
+     return models['story_teller'](
+         f"Write a 3-sentence story about: {caption}",
+         max_length=100,
+         do_sample=True,
+         temperature=0.7
+     )[0]['generated_text']
+
+ def main():
+     st.title("📱 Nano AI Image Analyzer")
+
+     uploaded_file = st.file_uploader("Choose image...", type=["jpg", "png"])
+
+     if uploaded_file:
+         image = Image.open(uploaded_file).convert("RGB")
+         st.image(image, use_column_width=True)
+
+         models = load_models()
+         analysis = None
+
+         col1, col2, col3 = st.columns(3)
+
+         with col1:
+             if st.button("🔍 Analyze", key="analyze"):
+                 analysis = analyze_image(image, models)
+                 st.session_state.analysis = analysis
+
+                 st.subheader("Main Objects")
+                 for obj in analysis['objects'][:3]:
+                     st.write(f"- {obj['label']} ({obj['score']:.0%})")
+
+         with col2:
+             if st.button("📝 Describe", key="describe"):
+                 if 'analysis' not in st.session_state:
+                     st.warning("Analyze first!")
+                 else:
+                     st.subheader("Caption")
+                     st.write(st.session_state.analysis['caption'])
+
+         with col3:
+             if st.button("📖 Mini Story", key="story"):
+                 if 'analysis' not in st.session_state:
+                     st.warning("Analyze first!")
+                 else:
+                     story = generate_story(
+                         st.session_state.analysis['caption'],
+                         models
+                     )
+                     st.subheader("Short Story")
+                     st.write(story)
+
+ if __name__ == "__main__":
+     main()
type3.py ADDED
@@ -0,0 +1,70 @@
+ import streamlit as st
+ from PIL import Image
+ from transformers import BlipProcessor, Blip2ForConditionalGeneration, BlipForQuestionAnswering
+ import torch
+
+ @st.cache_resource
+ def load_blip_model():
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
+     model = BlipForQuestionAnswering.from_pretrained(
+         "Salesforce/blip-vqa-base",
+         torch_dtype=torch.float16 if device == "cuda" else torch.float32
+     ).to(device)
+     return processor, model, device
+
+ def answer_question(image, question, processor, model, device):
+     inputs = processor(image, question, return_tensors="pt").to(device)
+     outputs = model.generate(**inputs, max_new_tokens=100)
+     return processor.decode(outputs[0], skip_special_tokens=True)
+
+ # Streamlit App
+ def main():
+     st.title("Image Chat Assistant")
+
+     # Load model
+     processor, model, device = load_blip_model()
+
+     # Image upload
+     uploaded_file = st.file_uploader("Upload image", type=["jpg", "png", "jpeg"])
+
+
+     if uploaded_file:
+         image = Image.open(uploaded_file)
+         st.image(image, use_column_width=True)
+
+         col1, col2, col3 = st.columns([0.33, 0.33, 0.33])
+
+         with col1:
+             detect = st.button("🔍 Detect Objects", key="btn1")
+
+         with col2:
+             describe = st.button("📝 Describe Image", key="btn2")
+         with col3:
+             story = st.button("📖 Generate Story", key="btn3")
+
+         # Chat interface
+         if "messages" not in st.session_state:
+             st.session_state.messages = []
+
+         chat_container = st.container(height=400)
+         with chat_container:
+
+             for message in st.session_state.messages:
+                 with st.chat_message(message["role"]):
+                     st.markdown(message["content"])
+
+         if prompt := st.chat_input("Ask about the image"):
+             st.session_state.messages.append({"role": "user", "content": prompt})
+             with st.chat_message("user"):
+                 st.markdown(prompt)
+
+             with st.chat_message("assistant"):
+                 with st.spinner("Thinking..."):
+                     response = answer_question(image, prompt, processor, model, device)
+                     #response = "response sample"
+                     st.markdown(response)
+             st.session_state.messages.append({"role": "assistant", "content": response})
+
+ if __name__ == "__main__":
+     main()
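
The same BLIP VQA call, sketched outside the Streamlit app; the model is loaded in full precision for simplicity, and photo.jpg and the question are placeholders:

import torch
from PIL import Image
from transformers import BlipProcessor, BlipForQuestionAnswering

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device)

image = Image.open("photo.jpg").convert("RGB")  # placeholder path
inputs = processor(image, "What is in the picture?", return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(outputs[0], skip_special_tokens=True))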
type4.py ADDED
@@ -0,0 +1,133 @@
+ import streamlit as st
+ from transformers import DetrImageProcessor, DetrForObjectDetection
+ from PIL import Image, ImageDraw
+ import torch
+ import re
+
+ @st.cache_resource
+ def load_detection_model():
+     processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
+     model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
+     return processor, model
+
+ def parse_detection_text(detection_text):
+     """Robust parsing of detection text with error handling"""
+     detections = []
+     pattern = r'\[([\d\s,]+)\]\s+([a-zA-Z\s]+)\s+([\d.]+)'
+
+     for line in detection_text.split('\n'):
+         if not line.strip():
+             continue
+
+         try:
+             match = re.match(pattern, line)
+             if match:
+                 coords = [int(x.strip()) for x in match.group(1).split(',')]
+                 label = match.group(2).strip()
+                 score = float(match.group(3))
+
+                 if len(coords) == 4:
+                     detections.append({
+                         'box': {'xmin': coords[0], 'ymin': coords[1],
+                                 'xmax': coords[2], 'ymax': coords[3]},
+                         'label': label,
+                         'score': score
+                     })
+         except (ValueError, AttributeError) as e:
+             st.warning(f"Skipping malformed detection line: {line}")
+             continue
+
+     return detections
+
+ def detect_objects(image, processor, model):
+     """Run DETR object detection with proper error handling"""
+     try:
+         inputs = processor(images=image, return_tensors="pt")
+         outputs = model(**inputs)
+
+         target_sizes = torch.tensor([image.size[::-1]])
+         results = processor.post_process_object_detection(
+             outputs,
+             target_sizes=target_sizes,
+             threshold=0.7
+         )[0]
+
+         detection_text = ""
+         for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+             detection_text += f"[{int(box[0])}, {int(box[1])}, {int(box[2])}, {int(box[3])}] " \
+                               f"{model.config.id2label[label.item()]} {score.item()}\n"
+
+         return detection_text, results
+
+     except Exception as e:
+         st.error(f"Detection failed: {str(e)}")
+         return "", None
+
+ def draw_boxes(image, detections):
+     """Draw bounding boxes with different colors for different classes"""
+     draw = ImageDraw.Draw(image)
+     color_map = {
+         'person': 'red',
+         'cell phone': 'blue',
+         'default': 'green'
+     }
+
+     for det in detections:
+         box = det['box']
+         label = det['label']
+         color = color_map.get(label.lower(), color_map['default'])
+
+         draw.rectangle(
+             [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
+             outline=color,
+             width=3
+         )
+         draw.text(
+             (box['xmin'], box['ymin'] - 15),
+             f"{label} ({det['score']:.2f})",
+             fill=color
+         )
+     return image
+
+ def main():
+     st.title("Object Detection with DETR")
+     processor, model = load_detection_model()
+
+     uploaded_file = st.file_uploader("Upload image", type=["jpg", "png", "jpeg"])
+
+     if uploaded_file:
+         image = Image.open(uploaded_file)
+         st.image(image, caption="Original Image", use_column_width=True)
+
+         if st.button("Detect Objects"):
+             with st.spinner("Detecting objects..."):
+                 detection_text, results = detect_objects(image, processor, model)
+
+             if detection_text:
+                 st.subheader("Detection Results")
+
+                 # Show raw detections
+                 with st.expander("Raw Detection Output"):
+                     st.text(detection_text)
+
+                 # Show parsed results
+                 detections = parse_detection_text(detection_text)
+                 if detections:
+                     annotated_image = draw_boxes(image.copy(), detections)
+                     st.image(annotated_image, caption="Detected Objects", use_column_width=True)
+
+                     # Display in table
+                     st.subheader("Detected Objects")
+                     st.table([
+                         {
+                             "Object": d["label"],
+                             "Confidence": f"{d['score']:.2%}",
+                             "Position": f"({d['box']['xmin']}, {d['box']['ymin']}) to ({d['box']['xmax']}, {d['box']['ymax']})"
+                         }
+                         for d in detections
+                     ])
+                 else:
+                     st.warning("No valid detections found")
+
+ if __name__ == "__main__":
+     main()
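
To illustrate the text format parse_detection_text() expects, a small standalone check of the same regex; the sample line and its values are made up:

import re

# Same pattern as parse_detection_text() above
pattern = r'\[([\d\s,]+)\]\s+([a-zA-Z\s]+)\s+([\d.]+)'
sample = "[34, 50, 220, 410] person 0.9978"  # hypothetical detect_objects() output line

match = re.match(pattern, sample)
coords = [int(x.strip()) for x in match.group(1).split(',')]
print(coords, match.group(2).strip(), float(match.group(3)))
# -> [34, 50, 220, 410] person 0.9978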