dnnsdunca commited on
Commit
9cb6f8b
·
verified ·
1 Parent(s): d8b629c

training dataset

Browse files

model training data

Files changed (1) hide show
  1. training +21 -0
training ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tensorflow.keras.models import Model
2
+ from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
3
+ from tensorflow.keras.optimizers import Adam
4
+
5
+ # Assuming base_model is your pre-loaded Mixtral 8x7b model
6
+ base_model.trainable = False # Freeze the layers
7
+
8
+ # Add custom layers
9
+ x = base_model.output
10
+ x = GlobalAveragePooling2D()(x)
11
+ x = Dense(1024, activation='relu')(x)
12
+ predictions = Dense(num_classes, activation='softmax')(x)
13
+
14
+ # This is the model we will train
15
+ model = Model(inputs=base_model.input, outputs=predictions)
16
+
17
+ # Compile the model
18
+ model.compile(optimizer=Adam(lr=0.001), loss='categorical_crossentropy', metrics=['accuracy'])
19
+
20
+ # Train the model
21
+ model.fit(train_dataset, validation_data=val_dataset, epochs=10)