---
title: README
emoji: ⚡
colorFrom: blue
colorTo: green
sdk: static
pinned: false
---

# MLX KAN

A community organization for model weights compatible with `mlx-kan`, a Kolmogorov-Arnold Network (KAN) implementation powered by MLX.

GitHub link: https://github.com/Goekdeniz-Guelmez/mlx-kan

The weights are converted from PyTorch and ready to use.

## How to install

```bash
pip install mlx-kan
```

## Models

To create a model from scratch and train it (pre-trained weights can instead be loaded into a model of matching shape with MLX's `Module.load_weights`):

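```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

from mlx_kan.kan import KAN

# Example sizes (illustrative; set these for your own data)
in_features, out_features = 28, 28  # e.g. a flattened 28x28 image
hidden_dim = 64
num_layers = 2
num_classes = 10

# Initialize and use KAN: the list holds the width of each layer
kan_model = KAN([in_features * out_features] + [hidden_dim] * (num_layers - 1) + [num_classes])


def loss_fn(model, X, y):
    # Mean cross-entropy between the model's logits and integer class labels
    return nn.losses.cross_entropy(model(X), y, reduction="mean")


def train(model, train_set, train_labels, num_epochs=100):
    optimizer = optim.AdamW(learning_rate=0.0004, weight_decay=0.003)  # Initialize a new optimizer for each model
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    total_loss = 0.0
    for epoch in range(num_epochs):
        # One full-batch optimization step per epoch
        loss, grads = loss_and_grad_fn(model, train_set, train_labels)
        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state)
        total_loss += loss.item()

        # Update grid points here. MLX has no no_grad() context: gradients are only
        # computed inside value_and_grad, so plain calls are enough.
        # Each layer's grid is fit to that layer's own input, so the data is
        # propagated layer by layer (assumes the KANLinear layers live in model.layers).
        x = train_set
        for layer in model.layers:
            layer.update_grid(x)
            x = layer(x)
        mx.eval(model.parameters())

    avg_loss = total_loss / num_epochs
    return avg_loss
```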
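
The list passed to `KAN` sets the width of each layer. The plain-Python sketch below expands that expression for some illustrative sizes so the resulting architecture is easy to check. The helpers `kan_layer_widths` and `layer_shapes` are hypothetical, and pairing consecutive widths into one `(in, out)` layer each is an assumption about how `KAN` builds its layers.

```python
def kan_layer_widths(in_features, out_features, hidden_dim, num_layers, num_classes):
    """Return the list of layer widths passed to KAN (hypothetical helper)."""
    # Flattened input size, then (num_layers - 1) hidden widths, then the output size
    return [in_features * out_features] + [hidden_dim] * (num_layers - 1) + [num_classes]


def layer_shapes(widths):
    """Pair consecutive widths into (in, out) shapes, one per layer (assumed)."""
    return list(zip(widths, widths[1:]))


widths = kan_layer_widths(28, 28, hidden_dim=64, num_layers=3, num_classes=10)
print(widths)                # [784, 64, 64, 10]
print(layer_shapes(widths))  # [(784, 64), (64, 64), (64, 10)]
```

With this reading, `num_layers` is the number of layers (`len(widths) - 1`), so `num_layers=3` yields two hidden widths and three weight layers.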