Goekdeniz-Guelmez commited on
Commit
315b458
·
verified ·
1 Parent(s): 39de75b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +42 -4
README.md CHANGED
@@ -1,10 +1,48 @@
1
  ---
2
  title: README
3
- emoji: 🏆
4
- colorFrom: red
5
- colorTo: pink
6
  sdk: static
7
  pinned: false
8
  ---
9
 
10
- Edit this `README.md` markdown file to author your organization card.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: README
3
+ emoji:
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: static
7
  pinned: false
8
  ---
9
 
10
+ # MLX KAN
11
+
12
+ A community org for model weights compatible with `mlx-kan` powered by MLX.
13
+
14
+ GitHub link: https://github.com/Goekdeniz-Guelmez/mlx-kan
15
+
16
+ These are weights converted from pytorch and ready to be used.
17
+
18
+ ## How to install
19
+
20
+ ```
21
+ pip install mlx-kan
22
+ ```
23
+
24
+ ## Models
25
+
26
+ To load a model with pre-trained weights or create one from scratch:
27
+ ```python
28
+ from mlx_kan.kan import KAN
29
+
30
+ # Initialize and use KAN
31
+ kan_model = KAN([in_features * out_features] + [hidden_dim] * (num_layers - 1) + [num_classes])
32
+
33
+ def train(model, train_set, train_labels, num_epochs=100):
34
+ optimizer = optim.AdamW(learning_rate=0.0004, weight_decay=0.003) # Initialize a new optimizer for each model
35
+ loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
36
+
37
+ # For 1 step
38
+ loss, grads = loss_and_grad_fn(model, train_set, train_labels)
39
+ optimizer.update(model, grads)
40
+ mx.eval(model.parameters(), optimizer.state)
41
+ avg_loss = total_loss += loss.item()
42
+
43
+ # Update grid points here
44
+ for name, layer in model.__dict__.items():
45
+ if isinstance(layer, KANLinear):
46
+ with mx.no_grad():
47
+ layer.update_grid(train_set)
48
+ ```