AC2513 commited on
Commit
3a8e58d
·
1 Parent(s): 8808cc1

added media size limit

Browse files
Files changed (1) hide show
  1. app.py +34 -2
app.py CHANGED
@@ -24,6 +24,9 @@ load_dotenv(dotenv_path)
24
  model_12_id = os.getenv("MODEL_12_ID", "google/gemma-3-12b-it")
25
  model_3n_id = os.getenv("MODEL_3N_ID", "google/gemma-3n-E4B-it")
26
 
 
 
 
27
  input_processor = Gemma3Processor.from_pretrained(model_12_id)
28
 
29
  model_12 = Gemma3ForConditionalGeneration.from_pretrained(
@@ -41,7 +44,26 @@ model_3n = Gemma3nForConditionalGeneration.from_pretrained(
41
  )
42
 
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def get_frames(video_path: str, max_images: int) -> list[tuple[Image.Image, float]]:
 
 
 
45
  frames: list[tuple[Image.Image, float]] = []
46
  capture = cv2.VideoCapture(video_path)
47
  if not capture.isOpened():
@@ -91,14 +113,24 @@ def process_user_input(message: dict, max_images: int) -> list[dict]:
91
  result_content = [{"type": "text", "text": message["text"]}]
92
 
93
  for file_path in message["files"]:
 
 
 
 
 
 
 
94
  if file_path.endswith((".mp4", ".mov")):
95
- result_content = [*result_content, *process_video(file_path, max_images)]
 
 
 
 
96
  else:
97
  result_content = [*result_content, {"type": "image", "url": file_path}]
98
 
99
  return result_content
100
 
101
-
102
  def process_history(history: list[dict]) -> list[dict]:
103
  messages = []
104
  content_buffer = []
 
24
  model_12_id = os.getenv("MODEL_12_ID", "google/gemma-3-12b-it")
25
  model_3n_id = os.getenv("MODEL_3N_ID", "google/gemma-3n-E4B-it")
26
 
27
+ MAX_VIDEO_SIZE = 100 * 1024 * 1024 # 100 MB
28
+ MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10 MB
29
+
30
  input_processor = Gemma3Processor.from_pretrained(model_12_id)
31
 
32
  model_12 = Gemma3ForConditionalGeneration.from_pretrained(
 
44
  )
45
 
46
 
47
+ def check_file_size(file_path: str) -> bool:
48
+ if not os.path.exists(file_path):
49
+ raise ValueError(f"File not found: {file_path}")
50
+
51
+ file_size = os.path.getsize(file_path)
52
+
53
+ if file_path.lower().endswith((".mp4", ".mov")):
54
+ if file_size > MAX_VIDEO_SIZE:
55
+ raise ValueError(f"Video file too large: {file_size / (1024*1024):.1f}MB. Maximum allowed: {MAX_VIDEO_SIZE / (1024*1024):.0f}MB")
56
+ else:
57
+ if file_size > MAX_IMAGE_SIZE:
58
+ raise ValueError(f"Image file too large: {file_size / (1024*1024):.1f}MB. Maximum allowed: {MAX_IMAGE_SIZE / (1024*1024):.0f}MB")
59
+
60
+ return True
61
+
62
+
63
  def get_frames(video_path: str, max_images: int) -> list[tuple[Image.Image, float]]:
64
+ # Check file size before processing
65
+ check_file_size(video_path)
66
+
67
  frames: list[tuple[Image.Image, float]] = []
68
  capture = cv2.VideoCapture(video_path)
69
  if not capture.isOpened():
 
113
  result_content = [{"type": "text", "text": message["text"]}]
114
 
115
  for file_path in message["files"]:
116
+ try:
117
+ check_file_size(file_path)
118
+ except ValueError as e:
119
+ logger.error(f"File size check failed: {e}")
120
+ result_content.append({"type": "text", "text": f"Error: {str(e)}"})
121
+ continue
122
+
123
  if file_path.endswith((".mp4", ".mov")):
124
+ try:
125
+ result_content = [*result_content, *process_video(file_path, max_images)]
126
+ except Exception as e:
127
+ logger.error(f"Video processing failed: {e}")
128
+ result_content.append({"type": "text", "text": f"Error processing video: {str(e)}"})
129
  else:
130
  result_content = [*result_content, {"type": "image", "url": file_path}]
131
 
132
  return result_content
133
 
 
134
  def process_history(history: list[dict]) -> list[dict]:
135
  messages = []
136
  content_buffer = []