# 1. app/model/config.py
from dataclasses import dataclass, field
from typing import List

import torch


@dataclass
class ModelConfig:
    model_name: str = "gpt2"
    max_length: int = 128
    batch_size: int = 16
    learning_rate: float = 2e-5
    num_train_epochs: int = 3
    # A mutable default needs default_factory inside a dataclass.
    languages: List[str] = field(default_factory=lambda: ["YORUBA", "IGBO", "HAUSA"])
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    output_dir: str = "outputs"
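
# Usage sketch (assumption, not part of the original files): individual fields
# can be overridden per experiment while the remaining defaults are kept.
config = ModelConfig(model_name="gpt2-medium", batch_size=8)
print(config.device, config.languages)  # e.g. cuda ['YORUBA', 'IGBO', 'HAUSA']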

# 2. app/model/model.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from .config import ModelConfig


class NigerianLanguageModel:
    def __init__(self, config: ModelConfig):
        self.config = config
        self.setup_model()

    def setup_model(self):
        # Load the pretrained tokenizer and causal LM backbone.
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(self.config.model_name)
        self._setup_special_tokens()
        self.model.to(self.config.device)

    def _setup_special_tokens(self):
        # Register one language tag per configured language, e.g. "[YORUBA]",
        # then resize the embedding matrix to cover the new tokens.
        special_tokens = {
            "additional_special_tokens": [f"[{lang}]" for lang in self.config.languages]
        }
        self.tokenizer.add_special_tokens(special_tokens)
        self.model.resize_token_embeddings(len(self.tokenizer))
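
# Usage sketch (assumption, not part of the original files): prepend a language
# tag so generation is conditioned on the target language. The prompt text and
# sampling settings below are illustrative only.
config = ModelConfig()
lm = NigerianLanguageModel(config)

prompt = "[YORUBA] Bawo ni"
inputs = lm.tokenizer(prompt, return_tensors="pt").to(config.device)
output_ids = lm.model.generate(
    **inputs,
    max_length=config.max_length,
    do_sample=True,
    top_p=0.95,
    pad_token_id=lm.tokenizer.eos_token_id,  # GPT-2 has no pad token by default
)
print(lm.tokenizer.decode(output_ids[0], skip_special_tokens=True))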