Commit
·
609acff
1
Parent(s):
2279ee0
force jax backend
Browse files- app.py +2 -0
- main.py +3 -0
- requirements.txt +2 -1
app.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
import os
|
2 |
import time
|
3 |
|
|
|
|
|
4 |
import gradio as gr
|
5 |
import jax
|
6 |
import numpy as np
|
|
|
1 |
import os
|
2 |
import time
|
3 |
|
4 |
+
os.environ["KERAS_BACKEND"] = "jax"
|
5 |
+
|
6 |
import gradio as gr
|
7 |
import jax
|
8 |
import numpy as np
|
main.py
CHANGED
@@ -1,6 +1,9 @@
|
|
1 |
import copy
|
|
|
2 |
from pathlib import Path
|
3 |
|
|
|
|
|
4 |
import jax
|
5 |
import keras
|
6 |
import matplotlib.pyplot as plt
|
|
|
1 |
import copy
|
2 |
+
import os
|
3 |
from pathlib import Path
|
4 |
|
5 |
+
os.environ["KERAS_BACKEND"] = "jax"
|
6 |
+
|
7 |
import jax
|
8 |
import keras
|
9 |
import matplotlib.pyplot as plt
|
requirements.txt
CHANGED
@@ -3,4 +3,5 @@ jax[cuda12]
|
|
3 |
tyro
|
4 |
optuna
|
5 |
gradio
|
6 |
-
spaces
|
|
|
|
3 |
tyro
|
4 |
optuna
|
5 |
gradio
|
6 |
+
spaces
|
7 |
+
keras==3.11
|