Correct gptq approach [wip]
backend/app/utils/model_validation.py
CHANGED
@@ -17,6 +17,8 @@ class ModelValidator:
         self.token = HF_TOKEN
         self.api = HfApi(token=self.token)
         self.headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
+        self.logger = logger
+        self.config_cache = {}

     async def check_model_card(self, model_id: str) -> Tuple[bool, str, Optional[Dict[str, Any]]]:
         """Check if model has a valid model card"""
@@ -86,48 +88,135 @@ class ModelValidator:
         base_model: str,
         revision: str
     ) -> Tuple[Optional[float], Optional[str]]:
-        """Get model size in billions of parameters"""
         try:
-            logger.info(LogFormatter.info(f"Checking model size for {model_info.modelId}"))
+            self.logger.info(LogFormatter.info(f"Checking model size for {model_info.modelId}"))
+
             # Check if model is adapter
-            is_adapter = any(
+            is_adapter = any(
+                s.rfilename == "adapter_config.json"
+                for s in model_info.siblings
+                if hasattr(s, 'rfilename')
+            )

+            # Get model size from safetensors
             model_size = None
             if is_adapter and base_model:
-                # For adapters,
-                adapter_meta = await self.get_safetensors_metadata(
+                # For adapters, combine adapter and base model sizes
+                adapter_meta = await self.get_safetensors_metadata(
+                    model_info.id,
+                    is_adapter=True,
+                    revision=revision
+                )
+                base_meta = await self.get_safetensors_metadata(
+                    base_model,
+                    revision="main"
+                )
+
                 if adapter_meta and base_meta:
                     adapter_size = sum(adapter_meta.parameter_count.values())
                     base_size = sum(base_meta.parameter_count.values())
                     model_size = adapter_size + base_size
             else:
-                # For regular models
-                meta = await self.get_safetensors_metadata(
+                # For regular models
+                meta = await self.get_safetensors_metadata(
+                    model_info.id,
+                    revision=revision
+                )
                 if meta:
-                    model_size = sum(meta.parameter_count.values())
+                    model_size = sum(meta.parameter_count.values())

             if model_size is None:
-                # If model size could not be determined, return an error
                 return None, "Model size could not be determined"

+            if model_size <= 0:
+                return None, "Invalid model size: must be positive"
+
+            # Only proceed with GPTQ adjustments if necessary
+            if precision == "GPTQ" or "gptq" in model_info.id.lower():
+                precision_bits = await self._get_precision_bits(
+                    model_info.id,
+                    revision
+                )
+
+                if precision_bits is None:
+                    return None, "Failed to determine precision bits"
+
+                # FIXED: We should divide by the size factor since quantization reduces size
+                size_factor = precision_bits / 32  # For 2-bit this is 2/32 = 1/16
+                self.logger.info(LogFormatter.info(
+                    f"Applying quantization factor: {size_factor}x (bits={precision_bits})"
+                ))
+            else:
+                size_factor = 1
+
+            # Convert to billions and apply quantization factor
+            model_size = model_size / 1e9  # Convert to billions
             model_size = round(size_factor * model_size, 3)

-            logger.info(LogFormatter.success(f"Model size: {model_size}B parameters"))
+            self.logger.info(LogFormatter.success(f"Model size: {model_size}B parameters"))
             return model_size, None

         except Exception as e:
-            logger.error(LogFormatter.error(f"Error while determining model size: {e}"))
+            self.logger.error(LogFormatter.error(f"Error while determining model size: {e}"))
             return None, str(e)

+    async def _get_precision_bits(
+        self,
+        model_id: str,
+        revision: str
+    ) -> Optional[int]:
+        """Get the precision bits from config.json, with caching."""
+        # Check cache first
+        cache_key = f"{model_id}_{revision}"
+        if cache_key in self.config_cache:
+            config_data = self.config_cache[cache_key]
+        else:
+            # Fetch config.json
+            config_url = f"https://huggingface.co/{model_id}/raw/{revision}/config.json"
+            try:
+                async with aiohttp.ClientSession() as session:
+                    async with session.get(config_url, headers=self.headers) as response:
+                        if response.status != 200:
+                            self.logger.warning(LogFormatter.warning(
+                                f"Failed to fetch config.json from {config_url}. Defaulting to 4 bits for GPTQ."
+                            ))
+                            return 4
+
+                        # Try to parse response as JSON regardless of content type
+                        try:
+                            text = await response.text()
+                            config_data = json.loads(text)
+                            self.config_cache[cache_key] = config_data
+                        except json.JSONDecodeError:
+                            self.logger.warning(LogFormatter.warning(
+                                f"Failed to parse config.json from {config_url}. Defaulting to 4 bits for GPTQ."
+                            ))
+                            return 4
+
+            except Exception as e:
+                self.logger.error(LogFormatter.error(
+                    f"Error fetching config.json: {e}. Defaulting to 4 bits."
+                ))
+                return 4
+
+        # Get precision bits from config
+        try:
+            precision_bits = config_data.get("quantization_config", {}).get("bits", 4)

+            # Validate precision bits
+            if precision_bits not in [2, 3, 4, 8]:
+                self.logger.error(LogFormatter.error(
+                    f"Unsupported precision_bits: {precision_bits}"
+                ))
+                return None
+
+            return precision_bits
+        except Exception as e:
+            self.logger.error(LogFormatter.error(
+                f"Error extracting precision bits from config: {e}. Defaulting to 4 bits."
+            ))
+            return 4
+
     async def check_chat_template(
         self,
         model_id: str,