Nirav-Madhani commited on
Commit
f8cb635
·
verified ·
1 Parent(s): 89208ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -35
app.py CHANGED
@@ -8,25 +8,20 @@ import os
8
  import io
9
  import base64
10
  from typing import List
11
- from fastapi.openapi.docs import get_swagger_ui_html
12
 
13
- # Set JAX to use CPU platform (adjust if GPU is needed)
14
  os.environ['JAX_PLATFORMS'] = 'cpu'
15
 
16
- # Load the model once globally
17
  model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")
18
 
19
  # Initialize FastAPI app
20
- app = FastAPI(
21
- title="Octo Model Inference API",
22
- docs_url="/" # Swagger UI at root
23
- )
24
 
25
- # Define request body model
26
  class InferenceRequest(BaseModel):
27
- image_base64: List[str] # List of base64-encoded images in time sequence
28
  task: str = "pick up the fork" # Default task
29
- window_size: int = 2 # Default window size, configurable
30
 
31
  # Health check endpoint
32
  @app.get("/health")
@@ -37,14 +32,7 @@ async def health_check():
37
  @app.post("/predict")
38
  async def predict(request: InferenceRequest, dataset_name: str = "bridge_dataset"):
39
  try:
40
- # Validate input
41
- if len(request.image_base64) < request.window_size:
42
- raise HTTPException(
43
- status_code=400,
44
- detail=f"At least {request.window_size} images required for the specified window size"
45
- )
46
-
47
- # Process images
48
  images = []
49
  for img_base64 in request.image_base64:
50
  if img_base64.startswith("data:image"):
@@ -54,11 +42,11 @@ async def predict(request: InferenceRequest, dataset_name: str = "bridge_dataset
54
  img = np.array(img)
55
  images.append(img)
56
 
57
- # Stack all images and add batch dimension
58
  img_array = np.stack(images)[np.newaxis, ...] # Shape: (1, T, 256, 256, 3)
59
  observation = {
60
  "image_primary": img_array,
61
- "timestep_pad_mask": np.full((1, len(images)), True, dtype=bool) # Shape: (1, T)
62
  }
63
 
64
  # Create task and predict actions
@@ -69,20 +57,8 @@ async def predict(request: InferenceRequest, dataset_name: str = "bridge_dataset
69
  unnormalization_statistics=model.dataset_statistics[dataset_name]["action"],
70
  rng=jax.random.PRNGKey(0)
71
  )
72
- actions = actions[0] # Remove batch dimension, Shape: (horizon, action_dim)
73
-
74
- # Convert to list for JSON response
75
- actions_list = actions.tolist()
76
 
77
- return {"actions": actions_list}
78
  except Exception as e:
79
- raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
80
-
81
- # Custom Swagger UI route (optional)
82
- @app.get("/docs", include_in_schema=False)
83
- async def custom_swagger_ui_html():
84
- return get_swagger_ui_html(
85
- openapi_url=app.openapi_url,
86
- title=app.title + " - Swagger UI",
87
- oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
88
- )
 
8
  import io
9
  import base64
10
  from typing import List
 
11
 
12
+ # Set JAX to use CPU (adjust to GPU if available)
13
  os.environ['JAX_PLATFORMS'] = 'cpu'
14
 
15
+ # Load Octo 1.5 model globally
16
  model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")
17
 
18
  # Initialize FastAPI app
19
+ app = FastAPI(title="Octo 1.5 Inference API")
 
 
 
20
 
21
+ # Request body model
22
  class InferenceRequest(BaseModel):
23
+ image_base64: List[str] # List of base64-encoded images
24
  task: str = "pick up the fork" # Default task
 
25
 
26
  # Health check endpoint
27
  @app.get("/health")
 
32
  @app.post("/predict")
33
  async def predict(request: InferenceRequest, dataset_name: str = "bridge_dataset"):
34
  try:
35
+ # Decode and process images
 
 
 
 
 
 
 
36
  images = []
37
  for img_base64 in request.image_base64:
38
  if img_base64.startswith("data:image"):
 
42
  img = np.array(img)
43
  images.append(img)
44
 
45
+ # Stack images with batch dimension
46
  img_array = np.stack(images)[np.newaxis, ...] # Shape: (1, T, 256, 256, 3)
47
  observation = {
48
  "image_primary": img_array,
49
+ "timestep_pad_mask": np.ones((1, len(images)), dtype=bool) # Shape: (1, T)
50
  }
51
 
52
  # Create task and predict actions
 
57
  unnormalization_statistics=model.dataset_statistics[dataset_name]["action"],
58
  rng=jax.random.PRNGKey(0)
59
  )
60
+ actions = actions[0] # Remove batch dimension, Shape: (T, action_dim)
 
 
 
61
 
62
+ return {"actions": actions.tolist()}
63
  except Exception as e:
64
+ raise HTTPException(status_code=500, detail=f"Error: {str(e)}")