Patrick WAN commited on
Commit
918d1df
·
1 Parent(s): 52933b5
Files changed (3) hide show
  1. app.py +135 -0
  2. envs/tsp_vector_env.py +9 -1
  3. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import gym
4
+ from models.attention_model_wrapper import Agent
5
+ device = 'cpu'
6
+ ckpt_path = './runs/tsp-v0__ppo_or__1__1678160003/ckpt/12000.pt'
7
+ agent = Agent(device=device, name='tsp').to(device)
8
+ agent.load_state_dict(torch.load(ckpt_path))
9
+
10
+ from wrappers.syncVectorEnvPomo import SyncVectorEnv
11
+ from wrappers.recordWrapper import RecordEpisodeStatistics
12
+
13
+ env_id = 'tsp-v0'
14
+ env_entry_point = 'envs.tsp_vector_env:TSPVectorEnv'
15
+ seed = 0
16
+
17
+ gym.envs.register(
18
+ id=env_id,
19
+ entry_point=env_entry_point,
20
+ )
21
+
22
+ def make_env(env_id, seed, cfg={}):
23
+ def thunk():
24
+ env = gym.make(env_id, **cfg)
25
+ env = RecordEpisodeStatistics(env)
26
+ env.seed(seed)
27
+ env.action_space.seed(seed)
28
+ env.observation_space.seed(seed)
29
+ return env
30
+ return thunk
31
+
32
+
33
+ def inference(data):
34
+ envs = SyncVectorEnv([make_env(env_id, seed, dict(n_traj=1,
35
+ max_nodes = len(data),
36
+ eval_data = 'from_input',
37
+ eval_data_from_input = data))])
38
+
39
+ trajectories = []
40
+ agent.eval()
41
+ obs = envs.reset()
42
+ done = np.array([False])
43
+ while not done.all():
44
+ # ALGO LOGIC: action logic
45
+ with torch.no_grad():
46
+ action, logits = agent(obs)
47
+ obs, reward, done, info = envs.step(action.cpu().numpy())
48
+ trajectories.append(action.cpu().numpy())
49
+ nodes_coordinates = obs['observations'][0]
50
+ final_return = info[0]['episode']['r']
51
+ resulting_traj = np.array(trajectories)[:,0,0]
52
+ return resulting_traj, final_return
53
+
54
+ default_data = np.array([[0.5488135 , 0.71518937],
55
+ [0.60276338, 0.54488318],
56
+ [0.4236548 , 0.64589411],
57
+ [0.43758721, 0.891773 ],
58
+ [0.96366276, 0.38344152],
59
+ [0.79172504, 0.52889492],
60
+ [0.56804456, 0.92559664],
61
+ [0.07103606, 0.0871293 ],
62
+ [0.0202184 , 0.83261985],
63
+ [0.77815675, 0.87001215],])
64
+
65
+ #@title Helper function for plotting
66
+ # colorline taken from https://nbviewer.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
67
+ import matplotlib.pyplot as plt
68
+ from matplotlib.collections import LineCollection
69
+ from matplotlib.colors import ListedColormap, BoundaryNorm
70
+
71
+ def make_segments(x, y):
72
+ '''
73
+ Create list of line segments from x and y coordinates, in the correct format for LineCollection:
74
+ an array of the form numlines x (points per line) x 2 (x and y) array
75
+ '''
76
+
77
+ points = np.array([x, y]).T.reshape(-1, 1, 2)
78
+ segments = np.concatenate([points[:-1], points[1:]], axis=1)
79
+
80
+ return segments
81
+
82
+ def colorline(x, y, z=None, cmap=plt.get_cmap('copper'), norm=plt.Normalize(0.0, 1.0), linewidth=1, alpha=1.0):
83
+ '''
84
+ Plot a colored line with coordinates x and y
85
+ Optionally specify colors in the array z
86
+ Optionally specify a colormap, a norm function and a line width
87
+ '''
88
+
89
+ # Default colors equally spaced on [0,1]:
90
+ if z is None:
91
+ z = np.linspace(0.3, 1.0, len(x))
92
+
93
+ # Special case if a single number:
94
+ if not hasattr(z, "__iter__"): # to check for numerical input -- this is a hack
95
+ z = np.array([z])
96
+
97
+ z = np.asarray(z)
98
+
99
+ segments = make_segments(x, y)
100
+ lc = LineCollection(segments, array=z, cmap=cmap, norm=norm, linewidth=linewidth, alpha=alpha)
101
+
102
+ ax = plt.gca()
103
+ ax.add_collection(lc)
104
+
105
+ return lc
106
+
107
+ def plot(coords):
108
+ fig = plt.figure()
109
+ x,y = coords.T
110
+ lc = colorline(x,y,cmap='Reds')
111
+ plt.axis('square')
112
+ return fig
113
+
114
+ import gradio as gr
115
+
116
+ def run_inference(data):
117
+ data = data.astype(float).to_numpy()
118
+ resulting_traj, final_return = inference(data)
119
+ result_text = f'Planned Tour:\t{resulting_traj}\nTotal tour length:\t{final_return[0]:.2f}'
120
+ return [plot(data[resulting_traj]),result_text]
121
+
122
+ demo = gr.Interface(run_inference, gr.Dataframe(
123
+ label = 'Input',
124
+ headers=['x','y'],
125
+ row_count=10,
126
+ col_count=(2, "fixed"),
127
+ max_rows = 10,
128
+ value = default_data.tolist(),
129
+ overflow_row_behaviour = 'show_ends'
130
+ ),
131
+ [gr.Plot(label= 'Results Visualization'),
132
+ gr.Code(label= 'Results',
133
+ interactive=False)])
134
+ demo.launch(share = True)
135
+
envs/tsp_vector_env.py CHANGED
@@ -25,6 +25,9 @@ class TSPVectorEnv(gym.Env):
25
  self.eval_data = False
26
  self.eval_partition = "test"
27
  self.eval_data_idx = 0
 
 
 
28
  assign_env_config(self, kwargs)
29
 
30
  obs_dict = {"observations": spaces.Box(low=0, high=1, shape=(self.max_nodes, 2))}
@@ -50,7 +53,9 @@ class TSPVectorEnv(gym.Env):
50
  self.last = np.zeros(self.n_traj, dtype=int) # idx of the first elem
51
  self.first = np.zeros(self.n_traj, dtype=int) # idx of the first elem
52
 
53
- if self.eval_data:
 
 
54
  self._load_orders()
55
  else:
56
  self._generate_orders()
@@ -59,6 +64,9 @@ class TSPVectorEnv(gym.Env):
59
  self.done = False
60
  return self.state
61
 
 
 
 
62
  def _load_orders(self):
63
  self.nodes = np.array(TSPDataset[self.eval_partition, self.max_nodes, self.eval_data_idx])
64
 
 
25
  self.eval_data = False
26
  self.eval_partition = "test"
27
  self.eval_data_idx = 0
28
+
29
+ self.eval_data_from_input = None
30
+
31
  assign_env_config(self, kwargs)
32
 
33
  obs_dict = {"observations": spaces.Box(low=0, high=1, shape=(self.max_nodes, 2))}
 
53
  self.last = np.zeros(self.n_traj, dtype=int) # idx of the first elem
54
  self.first = np.zeros(self.n_traj, dtype=int) # idx of the first elem
55
 
56
+ if self.eval_data == 'from_input':
57
+ self._load_orders_from_input()
58
+ elif self.eval_data:
59
  self._load_orders()
60
  else:
61
  self._generate_orders()
 
64
  self.done = False
65
  return self.state
66
 
67
+ def _load_orders_from_input(self):
68
+ self.nodes = self.eval_data_from_input.copy()
69
+
70
  def _load_orders(self):
71
  self.nodes = np.array(TSPDataset[self.eval_partition, self.max_nodes, self.eval_data_idx])
72
 
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ matplotlib
2
+ torch==1.13.0
3
+ torchvision==0.14.0
4
+ gym==0.23.1
5
+ numpy==1.23.4