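---
title: README
emoji: ⚡
colorFrom: blue
colorTo: green
sdk: static
pinned: false
---

# MLX KAN

A community organization hosting model weights compatible with `mlx-kan`, a Kolmogorov-Arnold Network (KAN) library powered by MLX.

GitHub: https://github.com/Goekdeniz-Guelmez/mlx-kan

The weights here were converted from PyTorch and are ready to use with `mlx-kan`.

## How to install

```
pip install mlx-kan
```

## Models

To create a model from scratch and train it, or to start from pre-trained weights:

```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx_kan.kan import KAN, KANLinear  # KANLinear import path is assumed; adjust if needed

# Hypothetical sizes; replace them with your dataset's dimensions
input_dim = 28 * 28  # e.g. a flattened 28x28 image
hidden_dim, num_layers, num_classes = 64, 3, 10

# Initialize a KAN from its layer widths: input, (num_layers - 1) hidden layers, output
kan_model = KAN([input_dim] + [hidden_dim] * (num_layers - 1) + [num_classes])

# To start from pre-trained weights instead, load a downloaded weights file
# (hypothetical path; the architecture must match the checkpoint):
# kan_model.load_weights("path/to/weights.safetensors")

def loss_fn(model, X, y):
    # Mean cross-entropy between the model's logits and integer class labels
    return nn.losses.cross_entropy(model(X), y, reduction="mean")

def train(model, train_set, train_labels, num_epochs=100):
    optimizer = optim.AdamW(learning_rate=0.0004, weight_decay=0.003)  # new optimizer for each model
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    total_loss = 0.0
    for _ in range(num_epochs):
        # One full-batch optimization step per epoch
        loss, grads = loss_and_grad_fn(model, train_set, train_labels)
        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state)
        total_loss += loss.item()

        # Update grid points. MLX only computes gradients inside transforms such as
        # value_and_grad, so no no_grad() context is needed here.
        for layer in model.modules():
            if isinstance(layer, KANLinear):
                layer.update_grid(train_set)

    return total_loss / num_epochs  # average loss over all epochs
```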
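The list passed to `KAN` gives the width of each layer: the input size, `num_layers - 1` hidden layers of size `hidden_dim`, and `num_classes` outputs. The plain-Python sketch below expands that list and pairs adjacent widths. It assumes, as is common in KAN implementations, that each adjacent `(in, out)` pair becomes one `KANLinear` layer; the helpers `kan_widths` and `layer_shapes` are hypothetical and not part of `mlx-kan`.

```python
def kan_widths(input_dim: int, hidden_dim: int, num_layers: int, num_classes: int) -> list[int]:
    """Build the width list: input, (num_layers - 1) hidden widths, output."""
    if num_layers < 1:
        raise ValueError("num_layers must be at least 1")
    return [input_dim] + [hidden_dim] * (num_layers - 1) + [num_classes]


def layer_shapes(widths: list[int]) -> list[tuple[int, int]]:
    """Pair adjacent widths; assumption: each (in, out) pair maps to one KANLinear layer."""
    return list(zip(widths[:-1], widths[1:]))


widths = kan_widths(28 * 28, 64, 3, 10)
print(widths)                # [784, 64, 64, 10]
print(layer_shapes(widths))  # [(784, 64), (64, 64), (64, 10)]
```

With `num_layers = 3` this yields three `(in, out)` pairs, so under the one-layer-per-pair assumption `num_layers` counts KAN layers rather than hidden layers.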