Patrick WAN
committed on
Commit
·
918d1df
1
Parent(s):
52933b5
space
Browse files- app.py +135 -0
- envs/tsp_vector_env.py +9 -1
- requirements.txt +5 -0
app.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import gym
|
4 |
+
from models.attention_model_wrapper import Agent
|
5 |
+
# All inference in this script runs on the CPU.
device = 'cpu'
ckpt_path = './runs/tsp-v0__ppo_or__1__1678160003/ckpt/12000.pt'
agent = Agent(device=device, name='tsp').to(device)
# map_location remaps the stored tensors onto `device`; without it, a checkpoint
# saved from a CUDA device cannot be deserialized on a CPU-only host.
agent.load_state_dict(torch.load(ckpt_path, map_location=device))
|
9 |
+
|
10 |
+
from wrappers.syncVectorEnvPomo import SyncVectorEnv
|
11 |
+
from wrappers.recordWrapper import RecordEpisodeStatistics
|
12 |
+
|
13 |
+
# Seed used for every environment created by this script.
seed = 0

# Id and entry point under which the TSP environment is registered with gym.
env_id = 'tsp-v0'
env_entry_point = 'envs.tsp_vector_env:TSPVectorEnv'

gym.envs.register(id=env_id, entry_point=env_entry_point)
|
21 |
+
|
22 |
+
def make_env(env_id, seed, cfg=None):
    """Return a zero-argument factory that builds one seeded environment.

    Args:
        env_id: Environment id passed to ``gym.make``.
        seed: Seed applied to the environment and to its action and
            observation spaces.
        cfg: Optional mapping of keyword arguments forwarded to ``gym.make``.
            ``None`` (the default) forwards no extra arguments.

    Returns:
        A callable that creates the environment, wraps it in
        ``RecordEpisodeStatistics``, seeds it and returns it.
    """
    # Resolve the default here: a `{}` default in the signature would be a
    # single dict object shared by every call of make_env.
    env_kwargs = {} if cfg is None else cfg

    def thunk():
        env = gym.make(env_id, **env_kwargs)
        env = RecordEpisodeStatistics(env)
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk
|
31 |
+
|
32 |
+
|
33 |
+
def inference(data):
    """Roll out the loaded agent on one TSP instance given as input data.

    Args:
        data: Node data for the instance. ``len(data)`` is used as the number
            of nodes and the object itself is handed to the environment as
            ``eval_data_from_input`` (presumably an (n, 2) array of x/y
            coordinates; confirm against the environment).

    Returns:
        A tuple ``(resulting_traj, final_return)``: the action taken at each
        step (element ``[0, 0]`` of every step's action array), and the value
        found under ``info[0]['episode']['r']`` after the last step.
    """
    env_cfg = dict(
        n_traj=1,
        max_nodes=len(data),
        eval_data='from_input',
        eval_data_from_input=data,
    )
    envs = SyncVectorEnv([make_env(env_id, seed, env_cfg)])

    trajectories = []
    agent.eval()
    obs = envs.reset()
    done = np.array([False])
    while not done.all():
        # ALGO LOGIC: action logic
        with torch.no_grad():
            action, _ = agent(obs)
        # Convert once and reuse for both the env step and the recorded tour.
        step_action = action.cpu().numpy()
        obs, reward, done, info = envs.step(step_action)
        trajectories.append(step_action)

    final_return = info[0]['episode']['r']
    resulting_traj = np.array(trajectories)[:, 0, 0]
    return resulting_traj, final_return
|
53 |
+
|
54 |
+
# Ten default (x, y) points in the unit square, shown as the initial table contents.
default_data = np.array(
    [
        [0.5488135, 0.71518937],
        [0.60276338, 0.54488318],
        [0.4236548, 0.64589411],
        [0.43758721, 0.891773],
        [0.96366276, 0.38344152],
        [0.79172504, 0.52889492],
        [0.56804456, 0.92559664],
        [0.07103606, 0.0871293],
        [0.0202184, 0.83261985],
        [0.77815675, 0.87001215],
    ]
)
|
64 |
+
|
65 |
+
#@title Helper function for plotting
|
66 |
+
# colorline taken from https://nbviewer.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
|
67 |
+
import matplotlib.pyplot as plt
|
68 |
+
from matplotlib.collections import LineCollection
|
69 |
+
from matplotlib.colors import ListedColormap, BoundaryNorm
|
70 |
+
|
71 |
+
def make_segments(x, y):
    """Turn x and y coordinate sequences into line segments for LineCollection.

    Returns an array of shape (numlines, 2, 2): for each pair of consecutive
    points, the start point and the end point as (x, y).
    """
    # One row per point, then a singleton axis so starts/ends can be joined on axis 1.
    stacked = np.array([x, y]).T
    points = stacked.reshape(-1, 1, 2)

    starts = points[:-1]
    ends = points[1:]
    return np.concatenate((starts, ends), axis=1)
|
81 |
+
|
82 |
+
def colorline(x, y, z=None, cmap=plt.get_cmap('copper'), norm=plt.Normalize(0.0, 1.0), linewidth=1, alpha=1.0):
    """Add a colored line through the points (x, y) to the current axes.

    Args:
        x, y: Coordinate sequences of the line.
        z: Optional color values passed to the collection as its ``array``.
            When omitted, ``len(x)`` values equally spaced from 0.3 to 1.0
            are used. A single non-iterable value is wrapped into a
            one-element array.
        cmap, norm, linewidth, alpha: Forwarded to ``LineCollection``.

    Returns:
        The ``LineCollection`` that was added to the current axes.
    """
    if z is None:
        # Default colors equally spaced on [0.3, 1.0].
        color_values = np.linspace(0.3, 1.0, len(x))
    elif hasattr(z, "__iter__"):
        color_values = z
    else:
        # Single number: lift it to a one-element array (same hack as upstream).
        color_values = np.array([z])
    color_values = np.asarray(color_values)

    collection = LineCollection(
        make_segments(x, y),
        array=color_values,
        cmap=cmap,
        norm=norm,
        linewidth=linewidth,
        alpha=alpha,
    )
    plt.gca().add_collection(collection)
    return collection
|
106 |
+
|
107 |
+
def plot(coords):
    """Draw the path through ``coords`` on a new figure and return the figure.

    Args:
        coords: Array whose transpose unpacks into the x and y sequences
            (presumably shape (n, 2); confirm against callers). Points are
            connected in the order given.

    Returns:
        The newly created matplotlib figure.
    """
    fig = plt.figure()
    x, y = coords.T
    # The returned collection is not needed here; it is already on the axes.
    colorline(x, y, cmap='Reds')
    plt.axis('square')
    return fig
|
113 |
+
|
114 |
+
import gradio as gr
|
115 |
+
|
116 |
+
def run_inference(data):
    """Solve the instance in ``data`` and return a plot plus a text summary.

    Args:
        data: Table of points; must support ``astype(float).to_numpy()``
            (presumably a pandas DataFrame with x and y columns; confirm
            against the UI component that calls this).

    Returns:
        A two-element list: the figure of the points visited in tour order,
        and a string with the planned tour and its total length.
    """
    coords = data.astype(float).to_numpy()
    tour, tour_return = inference(coords)
    summary = f'Planned Tour:\t{tour}\nTotal tour length:\t{tour_return[0]:.2f}'
    # Reorder the points by the planned tour before drawing.
    figure = plot(coords[tour])
    return [figure, summary]
|
121 |
+
|
122 |
+
# Input: a table of (x, y) points with two fixed columns, pre-filled with default_data.
_input_table = gr.Dataframe(
    label='Input',
    headers=['x', 'y'],
    row_count=10,
    col_count=(2, "fixed"),
    max_rows=10,
    value=default_data.tolist(),
    overflow_row_behaviour='show_ends',
)

# Outputs: the tour plot and a read-only text box with the tour summary.
_result_outputs = [
    gr.Plot(label='Results Visualization'),
    gr.Code(label='Results', interactive=False),
]

demo = gr.Interface(run_inference, _input_table, _result_outputs)
demo.launch(share=True)
|
135 |
+
|
envs/tsp_vector_env.py
CHANGED
@@ -25,6 +25,9 @@ class TSPVectorEnv(gym.Env):
|
|
25 |
self.eval_data = False
|
26 |
self.eval_partition = "test"
|
27 |
self.eval_data_idx = 0
|
|
|
|
|
|
|
28 |
assign_env_config(self, kwargs)
|
29 |
|
30 |
obs_dict = {"observations": spaces.Box(low=0, high=1, shape=(self.max_nodes, 2))}
|
@@ -50,7 +53,9 @@ class TSPVectorEnv(gym.Env):
|
|
50 |
self.last = np.zeros(self.n_traj, dtype=int) # idx of the first elem
|
51 |
self.first = np.zeros(self.n_traj, dtype=int) # idx of the first elem
|
52 |
|
53 |
-
if self.eval_data:
|
|
|
|
|
54 |
self._load_orders()
|
55 |
else:
|
56 |
self._generate_orders()
|
@@ -59,6 +64,9 @@ class TSPVectorEnv(gym.Env):
|
|
59 |
self.done = False
|
60 |
return self.state
|
61 |
|
|
|
|
|
|
|
62 |
def _load_orders(self):
|
63 |
self.nodes = np.array(TSPDataset[self.eval_partition, self.max_nodes, self.eval_data_idx])
|
64 |
|
|
|
25 |
self.eval_data = False
|
26 |
self.eval_partition = "test"
|
27 |
self.eval_data_idx = 0
|
28 |
+
|
29 |
+
self.eval_data_from_input = None
|
30 |
+
|
31 |
assign_env_config(self, kwargs)
|
32 |
|
33 |
obs_dict = {"observations": spaces.Box(low=0, high=1, shape=(self.max_nodes, 2))}
|
|
|
53 |
self.last = np.zeros(self.n_traj, dtype=int) # idx of the first elem
|
54 |
self.first = np.zeros(self.n_traj, dtype=int) # idx of the first elem
|
55 |
|
56 |
+
if self.eval_data == 'from_input':
|
57 |
+
self._load_orders_from_input()
|
58 |
+
elif self.eval_data:
|
59 |
self._load_orders()
|
60 |
else:
|
61 |
self._generate_orders()
|
|
|
64 |
self.done = False
|
65 |
return self.state
|
66 |
|
67 |
+
def _load_orders_from_input(self):
|
68 |
+
self.nodes = self.eval_data_from_input.copy()
|
69 |
+
|
70 |
def _load_orders(self):
|
71 |
self.nodes = np.array(TSPDataset[self.eval_partition, self.max_nodes, self.eval_data_idx])
|
72 |
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
matplotlib
|
2 |
+
torch==1.13.0
|
3 |
+
torchvision==0.14.0
|
4 |
+
gym==0.23.1
|
5 |
+
numpy==1.23.4
|