"""
TPTT configuration: merges a backbone config (e.g., Llama) with custom TPTT parameters.

Author: Fabien FURFARO
"""

import logging
import os
import re
from typing import Any, Dict, List, Optional, Union

import torch
from jinja2 import Environment, FileSystemLoader
from transformers import AutoConfig, PretrainedConfig

logger = logging.getLogger(__name__)


def convert_sets_to_lists(obj: Any) -> Any:
    """Recursively convert sets to lists so the LoRA config can be serialized."""
    if isinstance(obj, set):
        return list(obj)
    if isinstance(obj, dict):
        return {k: convert_sets_to_lists(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_sets_to_lists(x) for x in obj]
    return obj
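
# Example: nested sets become JSON-serializable lists (element order may vary):
# convert_sets_to_lists({"target_modules": {"q_proj", "v_proj"}})
# -> {"target_modules": ["q_proj", "v_proj"]}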


class TpttConfig(PretrainedConfig):
    """
    Configuration class for the TPTT model.

    This class merges the backbone config (e.g., Llama) with custom TPTT parameters.
    """

    model_type = "tptt"
    auto_map = {
        "AutoModelForCausalLM": "modeling_tptt.TpttModel",
        "AutoConfig": "configuration_tptt.TpttConfig",
    }
    architectures = ["TpttModel"]

    RECURRENT_MODES = {
        "delta_rule": {
            "order": 1,
            "gate_type": "k",
            "linear": True,
            "trick": "derivative",
        },
        "delta_rule_v": {
            "order": 1,
            "gate_type": "v",
            "linear": True,
            "trick": "derivative",
        },
        "delta_rule_kv": {
            "order": 1,
            "gate_type": "kv",
            "linear": True,
            "trick": "derivative",
        },
        "delta_rule_gelu": {
            "order": 1,
            "gate_type": "k",
            "linear": False,
            "trick": "derivative",
        },
        "delta_product": {
            "order": 2,
            "gate_type": "k",
            "linear": True,
            "trick": "derivative",
        },
        "delta_product_r": {
            "order": 2,
            "gate_type": "k",
            "linear": True,
            "trick": "rotative",
        },
        "delta_product_c": {
            "order": 2,
            "gate_type": "k",
            "linear": True,
            "trick": "combined",
        },
    }
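    # Names not listed in this preset table (e.g. "delta_product_3_kv") are
    # parsed on the fly by parse_mode_name() below.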

    def __init__(
        self,
        base_model_config: Optional[Union[dict, PretrainedConfig]] = None,
        base_model_name: str = "meta-llama/Llama-3.2-1B",
        base_model_subfolder: Optional[str] = None,
        name_or_path: Optional[str] = None,
        target_modules_names: Optional[List[str]] = None,
        operator_mode: str = "delta_rule",
        max_self_attn_length: Optional[int] = None,
        base_scale_attn: bool = False,
        mag_weight: float = 0.5,
        cross_gate: bool = False,
        max_chunk_size: int = 64,
        linear_precision: Union[str, torch.dtype] = "float32",
        lora_config: Optional[dict] = None,
        padding_side: Optional[str] = None,
        bidirectional: bool = False,
        pooling_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        # Resolve the backbone config: accept a dict or a PretrainedConfig,
        # or fetch it from the hub when none is provided.
        if base_model_config is not None:
            if isinstance(base_model_config, PretrainedConfig):
                base_model_config = base_model_config.to_dict()
        else:
            base_model_config = AutoConfig.from_pretrained(
                base_model_name, **kwargs
            ).to_dict()

        # Merge the backbone fields into this config.
        for k, v in base_model_config.items():
            setattr(self, k, v)

        self.base_model_name = base_model_name
        self.base_model_subfolder = base_model_subfolder

        # Derive a default name from the backbone when none is given.
        if name_or_path is not None:
            self._name_or_path = name_or_path
        else:
            if "/" in base_model_name:
                self._name_or_path = "Titans-" + base_model_name.split("/", 1)[1]
            else:
                self._name_or_path = "Titans-" + base_model_name

        self.target_modules_names = target_modules_names or [
            "attn",
            "self_attn",
            "attention",
        ]
        self.operator_mode = operator_mode
        self.base_scale_attn = base_scale_attn
        self.mag_weight = mag_weight
        self.cross_gate = cross_gate
        self.max_chunk_size = max_chunk_size
        self.max_self_attn_length = max_self_attn_length
        # Store the linear-attention precision as a plain string (e.g. "float32").
        if isinstance(linear_precision, torch.dtype):
            linear_precision = str(linear_precision).replace("torch.", "")
        self.linear_precision = linear_precision

        # Make the LoRA config JSON-serializable (enum -> value, set -> list).
        self.lora_config = lora_config
        if lora_config is not None:
            if hasattr(self.lora_config.get("peft_type"), "value"):
                self.lora_config["peft_type"] = self.lora_config["peft_type"].value
            self.lora_config = convert_sets_to_lists(self.lora_config)

        self.padding_side = padding_side
        self.bidirectional = bidirectional
        if self.bidirectional:
            logger.warning(
                "Bidirectional mode is enabled; inputs must be non-causal and unpadded."
            )
        self.pooling_config = pooling_config

        super().__init__(**kwargs)

        # Re-assert class-level identity fields after the parent constructor.
        self.model_type = self.__class__.model_type
        self.auto_map = self.__class__.auto_map
        self.architectures = self.__class__.architectures

        if self.padding_side is None:
            self.padding_side = "right"
            logger.warning("padding_side is None, defaulting to 'right'.")

        # Resolve the recurrent operator: use a known preset, otherwise parse
        # the mode name dynamically.
        if operator_mode not in self.__class__.RECURRENT_MODES:
            self.recurrent_config = parse_mode_name(operator_mode)
        else:
            self.recurrent_config = self.__class__.RECURRENT_MODES[operator_mode]
        logger.info("Using recurrent mode: %s", get_mode_name(**self.recurrent_config))


TpttConfig.register_for_auto_class()
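
# Minimal usage sketch (values are illustrative; from_pretrained may need
# network access unless base_model_config is passed explicitly):
# config = TpttConfig(base_model_name="meta-llama/Llama-3.2-1B",
#                     operator_mode="delta_product_r")
# config.recurrent_config
# -> {"order": 2, "gate_type": "k", "linear": True, "trick": "rotative"}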


def parse_mode_name(name: str) -> dict:
    """Parse a recurrent mode name into its configuration dict."""
    if name.startswith("delta_product"):
        parts = name.split("_")

        # Defaults for the "delta_product" family.
        base_len = 2  # "delta" + "product"
        order = 2
        gate_type = "k"
        linear = True
        trick = "derivative"

        idx = base_len

        # Optional numeric order suffix, e.g. "delta_product_3".
        if len(parts) > idx and parts[idx].isdigit():
            order = int(parts[idx])
            idx += 1

        remaining = parts[idx:]

        # Optional trick suffix: "r" (rotative) or "c" (combined).
        if remaining and remaining[-1] in ("r", "c"):
            trick = {"r": "rotative", "c": "combined"}[remaining[-1]]
            remaining = remaining[:-1]

        # Optional non-linearity marker.
        if remaining and remaining[-1] == "gelu":
            linear = False
            remaining = remaining[:-1]

        # Anything left is the gate type, e.g. "kv".
        if remaining:
            gate_type = "_".join(remaining)
        return {
            "order": order,
            "gate_type": gate_type,
            "linear": linear,
            "trick": trick,
        }

    m = re.match(r"^delta_rule(?:_(kv|v|k))?(_gelu)?$", name)
    if m:
        return {
            "order": 1,
            "gate_type": m.group(1) if m.group(1) else "k",
            "linear": not bool(m.group(2)),
            "trick": "derivative",
        }
    raise ValueError(f"Unknown mode: {name}")
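
# Examples of the naming scheme handled above:
# parse_mode_name("delta_rule_v") -> order 1, gate "v", linear, derivative trick
# parse_mode_name("delta_product_3_kv_gelu_c")
# -> order 3, gate "kv", non-linear (gelu), combined trick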


def get_mode_name(
    order: int = 1, gate_type: str = "k", linear: bool = True, trick: str = "derivative"
) -> str:
    """Build a recurrent mode name from its parameters."""
    base = (
        "delta_rule"
        if order == 1
        else ("delta_product" if order == 2 else f"delta_product_{order}")
    )
    parts = []
    if gate_type != "k":
        parts.append(gate_type)
    if not linear:
        parts.append("gelu")
    if order >= 2 and trick != "derivative":
        parts.append({"rotative": "r", "combined": "c"}.get(trick, trick))
    return base + (("_" + "_".join(parts)) if parts else "")
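
# get_mode_name round-trips parse_mode_name for canonical names, e.g.:
# get_mode_name(order=2, trick="rotative") -> "delta_product_r"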


def render_template(template_path: str, variables: dict) -> str:
    """Load and render a Jinja2 template from any file path."""
    env = Environment(loader=FileSystemLoader(os.path.dirname(template_path)))
    template = env.get_template(os.path.basename(template_path))
    return template.render(**variables)
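
# Sketch (hypothetical path): render a template with the variables that
# generate_model_card passes, i.e. "model_id" and "config":
# render_template("templates/model_card_template.md",
#                 {"model_id": "demo", "config": {}})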


def write_model_card(output_path: str, content: str):
    """Write the generated content into README.md."""
    os.makedirs(output_path, exist_ok=True)
    readme_path = os.path.join(output_path, "README.md")
    with open(readme_path, "w", encoding="utf-8") as f:
        f.write(content)


def generate_model_card(
    output_path: str,
    config: Union[dict, object],
    template: Optional[str] = None,
    extra_variables: Optional[Dict] = None,
):
    """
    Generate a README.md file from a Jinja2 template and a configuration.

    - template can be either:
        * a full path to a template file
        * a short name (e.g., "model_card") -> looked up inside the default
          templates directory next to this module
    """
    if template is None:
        template = "model_card_template"

    # Accept either a direct file path or a short template name.
    if os.path.exists(template):
        template_path = template
    else:
        default_templates_dir = os.path.join(os.path.dirname(__file__), "templates")
        template_path = os.path.join(default_templates_dir, f"{template}.md")

    if not os.path.exists(template_path):
        raise FileNotFoundError(f"Template not found: {template_path}")

    variables = {
        "model_id": os.path.basename(output_path),
        "config": config,
    }
    if extra_variables:
        variables.update(extra_variables)

    content = render_template(template_path, variables)
    write_model_card(output_path, content)
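
# Sketch: writes README.md under output_path using the bundled template
# ("templates/model_card_template.md" next to this file); my_config is a
# placeholder for a TpttConfig or plain dict:
# generate_model_card("out/Titans-Llama-3.2-1B", config=my_config, template=None)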