tristan-deep commited on
Commit
609acff
·
1 Parent(s): 2279ee0

force jax backend

Browse files
Files changed (3) hide show
  1. app.py +2 -0
  2. main.py +3 -0
  3. requirements.txt +2 -1
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import os
2
  import time
3
 
 
 
4
  import gradio as gr
5
  import jax
6
  import numpy as np
 
1
  import os
2
  import time
3
 
4
+ os.environ["KERAS_BACKEND"] = "jax"
5
+
6
  import gradio as gr
7
  import jax
8
  import numpy as np
main.py CHANGED
@@ -1,6 +1,9 @@
1
  import copy
 
2
  from pathlib import Path
3
 
 
 
4
  import jax
5
  import keras
6
  import matplotlib.pyplot as plt
 
1
  import copy
2
+ import os
3
  from pathlib import Path
4
 
5
+ os.environ["KERAS_BACKEND"] = "jax"
6
+
7
  import jax
8
  import keras
9
  import matplotlib.pyplot as plt
requirements.txt CHANGED
@@ -3,4 +3,5 @@ jax[cuda12]
3
  tyro
4
  optuna
5
  gradio
6
- spaces
 
 
3
  tyro
4
  optuna
5
  gradio
6
+ spaces
7
+ keras==3.11