alozowski (HF staff) · Commit fb82c68 · Parent: 04e1112

Correct gptq approach [wip]

Files changed (1):
  1. backend/app/utils/model_validation.py (+108 −19)
backend/app/utils/model_validation.py CHANGED
@@ -17,6 +17,8 @@ class ModelValidator:
         self.token = HF_TOKEN
         self.api = HfApi(token=self.token)
         self.headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
+        self.logger = logger
+        self.config_cache = {}
 
     async def check_model_card(self, model_id: str) -> Tuple[bool, str, Optional[Dict[str, Any]]]:
         """Check if model has a valid model card"""
@@ -86,48 +88,135 @@ class ModelValidator:
         base_model: str,
         revision: str
     ) -> Tuple[Optional[float], Optional[str]]:
-        """Get model size in billions of parameters"""
         try:
-            logger.info(LogFormatter.info(f"Checking model size for {model_info.modelId}"))
-
+            self.logger.info(LogFormatter.info(f"Checking model size for {model_info.modelId}"))
+
             # Check if model is adapter
-            is_adapter = any(s.rfilename == "adapter_config.json" for s in model_info.siblings if hasattr(s, 'rfilename'))
+            is_adapter = any(
+                s.rfilename == "adapter_config.json"
+                for s in model_info.siblings
+                if hasattr(s, 'rfilename')
+            )
 
-            # Try to get size from safetensors first
+            # Get model size from safetensors
             model_size = None
-
             if is_adapter and base_model:
-                # For adapters, we need both adapter and base model sizes
-                adapter_meta = await self.get_safetensors_metadata(model_info.id, is_adapter=True, revision=revision)
-                base_meta = await self.get_safetensors_metadata(base_model, revision="main")
-
+                # For adapters, combine adapter and base model sizes
+                adapter_meta = await self.get_safetensors_metadata(
+                    model_info.id,
+                    is_adapter=True,
+                    revision=revision
+                )
+                base_meta = await self.get_safetensors_metadata(
+                    base_model,
+                    revision="main"
+                )
+
                 if adapter_meta and base_meta:
                     adapter_size = sum(adapter_meta.parameter_count.values())
                     base_size = sum(base_meta.parameter_count.values())
                     model_size = adapter_size + base_size
             else:
-                # For regular models, just get the model size
-                meta = await self.get_safetensors_metadata(model_info.id, revision=revision)
+                # For regular models
+                meta = await self.get_safetensors_metadata(
+                    model_info.id,
+                    revision=revision
+                )
                 if meta:
-                    model_size = sum(meta.parameter_count.values())  # total params
+                    model_size = sum(meta.parameter_count.values())
 
             if model_size is None:
-                # If model size could not be determined, return an error
                 return None, "Model size could not be determined"
 
-            # Adjust size for GPTQ models
-            size_factor = 8 if (precision == "GPTQ" or "gptq" in model_info.id.lower()) else 1
-            model_size = model_size / 1e9  # Convert to billions, assuming float16
+            if model_size <= 0:
+                return None, "Invalid model size: must be positive"
+
+            # Only proceed with GPTQ adjustments if necessary
+            if precision == "GPTQ" or "gptq" in model_info.id.lower():
+                precision_bits = await self._get_precision_bits(
+                    model_info.id,
+                    revision
+                )
+
+                if precision_bits is None:
+                    return None, "Failed to determine precision bits"
+
+                # FIXED: scale by bits/32, since quantization shrinks the stored size
+                size_factor = precision_bits / 32  # For 2-bit this is 2/32 = 1/16
+                self.logger.info(LogFormatter.info(
+                    f"Applying quantization factor: {size_factor}x (bits={precision_bits})"
+                ))
+            else:
+                size_factor = 1
+
+            # Convert to billions and apply the quantization factor
+            model_size = model_size / 1e9  # Convert to billions
             model_size = round(size_factor * model_size, 3)
 
-            logger.info(LogFormatter.success(f"Model size: {model_size}B parameters"))
+            self.logger.info(LogFormatter.success(f"Model size: {model_size}B parameters"))
             return model_size, None
 
         except Exception as e:
-            logger.error(LogFormatter.error(f"Error while determining model size: {e}"))
+            self.logger.error(LogFormatter.error(f"Error while determining model size: {e}"))
             return None, str(e)
 
+    async def _get_precision_bits(
+        self,
+        model_id: str,
+        revision: str
+    ) -> Optional[int]:
+        """Get the precision bits from config.json, with caching."""
+        # Check cache first
+        cache_key = f"{model_id}_{revision}"
+        if cache_key in self.config_cache:
+            config_data = self.config_cache[cache_key]
+        else:
+            # Fetch config.json
+            config_url = f"https://huggingface.co/{model_id}/raw/{revision}/config.json"
+            try:
+                async with aiohttp.ClientSession() as session:
+                    async with session.get(config_url, headers=self.headers) as response:
+                        if response.status != 200:
+                            self.logger.warning(LogFormatter.warning(
+                                f"Failed to fetch config.json from {config_url}. Defaulting to 4 bits for GPTQ."
+                            ))
+                            return 4
+
+                        # Try to parse response as JSON regardless of content type
+                        try:
+                            text = await response.text()
+                            config_data = json.loads(text)
+                            self.config_cache[cache_key] = config_data
+                        except json.JSONDecodeError:
+                            self.logger.warning(LogFormatter.warning(
+                                f"Failed to parse config.json from {config_url}. Defaulting to 4 bits for GPTQ."
+                            ))
+                            return 4
+
+            except Exception as e:
+                self.logger.error(LogFormatter.error(
+                    f"Error fetching config.json: {e}. Defaulting to 4 bits."
+                ))
+                return 4
+
+        # Get precision bits from config
+        try:
+            precision_bits = config_data.get("quantization_config", {}).get("bits", 4)
+
+            # Validate precision bits
+            if precision_bits not in [2, 3, 4, 8]:
+                self.logger.error(LogFormatter.error(
+                    f"Unsupported precision_bits: {precision_bits}"
+                ))
+                return None
+
+            return precision_bits
+        except Exception as e:
+            self.logger.error(LogFormatter.error(
+                f"Error extracting precision bits from config: {e}. Defaulting to 4 bits."
+            ))
+            return 4
+
     async def check_chat_template(
         self,
         model_id: str,
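Note on the arithmetic above: the commit replaces the old flat `size_factor = 8` for GPTQ with `size_factor = precision_bits / 32`, so the reported size now shrinks with the bit width instead of growing. A minimal standalone sketch of that computation with illustrative numbers (the helper name `effective_size_b` is ours, not part of the repo):

```python
from typing import Optional

def effective_size_b(param_count: int, precision_bits: Optional[int]) -> float:
    """Mirror the commit's arithmetic: convert to billions, then scale by bits/32."""
    size_factor = precision_bits / 32 if precision_bits is not None else 1
    return round(size_factor * (param_count / 1e9), 3)

print(effective_size_b(7_000_000_000, 4))     # 4-bit GPTQ: 4/32 = 0.125 -> 0.875
print(effective_size_b(7_000_000_000, 2))     # 2-bit GPTQ: 2/32 = 1/16 -> 0.438
print(effective_size_b(7_000_000_000, None))  # non-GPTQ: unscaled -> 7.0
```

Under this scheme a quantized checkpoint is sized by its compressed footprint rather than its logical parameter count, which appears to be the behavioral change the [wip] commit message flags.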
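The new `_get_precision_bits` fetches `config.json` over raw HTTP with `aiohttp` and memoizes the parsed config in `self.config_cache`; for this hunk to run, `aiohttp` and `json` must already be imported at the top of `model_validation.py`, since the diff does not touch the imports. As a hedged alternative (not what the commit does), the same lookup could go through `huggingface_hub`, which handles auth and local file caching itself; the helper name `read_gptq_bits` is ours:

```python
import json
from huggingface_hub import hf_hub_download

def read_gptq_bits(model_id: str, revision: str = "main", default: int = 4) -> int:
    """Read quantization_config.bits from config.json, defaulting to 4 like the commit."""
    try:
        path = hf_hub_download(model_id, filename="config.json", revision=revision)
        with open(path) as f:
            config = json.load(f)
        return config.get("quantization_config", {}).get("bits", default)
    except Exception:
        # Mirror the commit's fallback: assume 4-bit GPTQ when the config is unavailable
        return default
```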
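Finally, a hypothetical end-to-end driver for the changed path. The name of the enclosing method is cut off in this hunk; the sketch assumes it is `get_model_size`, that `model_info` and `precision` are parameters alongside the visible `base_model` and `revision`, and that `ModelValidator()` takes no constructor arguments (consistent with the `__init__` hunk above):

```python
import asyncio
from huggingface_hub import HfApi

from app.utils.model_validation import ModelValidator  # import path assumed from the repo layout

async def main() -> None:
    info = HfApi().model_info("TheBloke/Llama-2-7B-GPTQ")  # a public GPTQ checkpoint
    validator = ModelValidator()
    # Method name and the model_info/precision parameters are assumptions; only
    # base_model and revision appear in the visible hunk.
    size_b, err = await validator.get_model_size(
        model_info=info,
        precision="GPTQ",
        base_model="",
        revision="main",
    )
    print(size_b, err)

asyncio.run(main())
```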