Gabriel Okiri
Initial commit
4bb9d41
raw
history blame
405 Bytes
from dataclasses import dataclass
from typing import List, Optional, Sequence

import torch

@dataclass
class ModelConfig:
    """Default hyperparameters and runtime settings for a model run.

    Every field has a default, so ``ModelConfig()`` is valid, and any field
    can be overridden by keyword, e.g. ``ModelConfig(batch_size=32)``.
    """

    # Model identifier; "gpt2" looks like a Hugging Face checkpoint name
    # (NOTE(review): confirm against the loader that consumes it).
    model_name: str = "gpt2"
    # Presumably the maximum tokenized sequence length -- TODO confirm.
    max_length: int = 128
    batch_size: int = 16
    learning_rate: float = 2e-5
    num_train_epochs: int = 3
    # Language labels to process. Annotated as Sequence[str] (not List[str])
    # because the default is a tuple: dataclasses reject mutable list
    # defaults, and an immutable tuple is safe to share between instances.
    # Callers may still pass a list, e.g. ModelConfig(languages=["IGBO"]).
    languages: Sequence[str] = ("YORUBA", "IGBO", "HAUSA")
    # Resolved once, when this class body executes (i.e. at import time),
    # not per instance; pass device=... explicitly to override it.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # Relative path; presumably where run artifacts are written -- confirm.
    output_dir: str = "outputs"