shunk031 committed on
Commit
b3850a5
·
1 Parent(s): d16c9dc

deploy: fc0c10e734116107123b6dce81a6df2cbbf84dfe

Browse files
Files changed (2) hide show
  1. layout-occlusion.py +147 -0
  2. requirements.txt +89 -0
layout-occlusion.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Union
3
+
4
+ import datasets as ds
5
+ import evaluate
6
+ import numpy as np
7
+ import numpy.typing as npt
8
+ from PIL import Image
9
+
10
+ _DESCRIPTION = r"""\
11
+ Computes the average pixel value of areas covered by elements in S.
12
+ """
13
+
14
+ _KWARGS_DESCRIPTION = """\
15
+ FIXME
16
+ """
17
+
18
# BibTeX entry for the PosterLayout paper (CVPR 2023); handed to
# ``MetricInfo(citation=...)`` in ``_info``.
_CITATION = """\
@inproceedings{hsu2023posterlayout,
title={Posterlayout: A new benchmark and approach for content-aware visual-textual presentation layout},
author={Hsu, Hsiao Yuan and He, Xiangteng and Peng, Yuxin and Kong, Hao and Zhang, Qing},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={6018--6026},
year={2023}
}
"""
27
+
28
+
29
class LayoutOcculusion(evaluate.Metric):
    """Occlusion metric for content-aware layouts.

    For every sample the two saliency maps are merged with an element-wise
    maximum, the valid predicted boxes are rasterised onto the canvas, and the
    mean saliency under the covered pixels is taken. The metric is the average
    of those per-sample values.

    The class name keeps its original spelling ("Occulusion") because it is
    the public identifier callers import.
    """

    def __init__(
        self,
        canvas_width: int,
        canvas_height: int,
        **kwargs,
    ) -> None:
        """Store the canvas size used to scale boxes and resize saliency maps.

        Args:
            canvas_width: Canvas width in pixels.
            canvas_height: Canvas height in pixels.
            **kwargs: Forwarded unchanged to ``evaluate.Metric``.
        """
        super().__init__(**kwargs)
        self.canvas_width = canvas_width
        self.canvas_height = canvas_height

    def _info(self) -> evaluate.EvaluationModuleInfo:
        """Describe the metric (texts, input features, reference code)."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=ds.Features(
                {
                    "predictions": ds.Sequence(ds.Sequence(ds.Value("float64"))),
                    "gold_labels": ds.Sequence(ds.Sequence(ds.Value("int64"))),
                    "saliency_maps_1": ds.Sequence(ds.Value("string")),
                    "saliency_maps_2": ds.Sequence(ds.Value("string")),
                }
            ),
            codebase_urls=[
                "https://github.com/PKU-ICST-MIPL/PosterLayout-CVPR2023/blob/main/eval.py#L144-L171"
            ],
        )

    def load_saliency_map(
        self,
        filepath: Union[os.PathLike, List[os.PathLike]],
    ) -> npt.NDArray[np.float64]:
        """Load a saliency map as a float array scaled to ``[0, 1]``.

        Args:
            filepath: Path of the image, or a single-element list holding it.

        Returns:
            Array of shape ``(canvas_height, canvas_width)``: the image
            converted to 8-bit grayscale, resized to the canvas if its size
            differs, and divided by 255.

        Raises:
            AssertionError: If ``filepath`` is a list whose length is not 1.
        """
        if isinstance(filepath, list):
            assert len(filepath) == 1, filepath
            filepath = filepath[0]

        canvas_size = (self.canvas_width, self.canvas_height)

        # The context manager releases the file handle deterministically
        # instead of leaving it to the garbage collector.
        with Image.open(filepath) as source_image:  # type: ignore
            gray_image = source_image.convert("L")
            if gray_image.size != canvas_size:
                gray_image = gray_image.resize(canvas_size)
            map_arr = np.array(gray_image)

        return map_arr / 255.0

    def get_rid_of_invalid(
        self, predictions: npt.NDArray[np.float64], gold_labels: npt.NDArray[np.int64]
    ) -> npt.NDArray[np.int64]:
        """Zero the label of every box that is too small once clipped to the canvas.

        A box is dropped when the absolute area of its canvas-clipped extent is
        below 0.1% of the canvas area (``(width / 100) * (height / 100) * 10``).

        Args:
            predictions: Boxes as ``[xl, yl, xr, yr]`` in pixel coordinates,
                indexed as ``predictions[sample][box]``.
            gold_labels: Labels indexed as ``gold_labels[sample, box]``.
                Modified IN PLACE.

        Returns:
            The same ``gold_labels`` object, with labels of dropped boxes set
            to 0.

        Raises:
            AssertionError: If the two inputs differ in length.
        """
        assert len(predictions) == len(gold_labels)

        # Loop-invariant area threshold.
        min_area = (self.canvas_width / 100) * (self.canvas_height / 100) * 10

        for i, prediction in enumerate(predictions):
            for j, (xl, yl, xr, yr) in enumerate(prediction):
                # Clip to the canvas only for the area test; the boxes
                # themselves are left untouched.
                xl = max(0, xl)
                yl = max(0, yl)
                xr = min(self.canvas_width, xr)
                yr = min(self.canvas_height, yr)
                if abs((xr - xl) * (yr - yl)) < min_area and gold_labels[i, j]:
                    gold_labels[i, j] = 0
        return gold_labels

    def _compute(
        self,
        *,
        predictions: Union[npt.NDArray[np.float64], List[List[float]]],
        gold_labels: Union[npt.NDArray[np.int64], List[int]],
        saliency_maps_1: List[os.PathLike],
        saliency_maps_2: List[os.PathLike],
    ) -> float:
        """Compute the occlusion score averaged over all samples.

        Args:
            predictions: Per-sample boxes ``[xl, yl, xr, yr]`` as fractions of
                the canvas size; scaled to pixels here.
            gold_labels: Per-sample integer labels, one per box; boxes with
                label 0 are ignored.
            saliency_maps_1: Per-sample path of the first saliency map.
            saliency_maps_2: Per-sample path of the second saliency map.

        Returns:
            The mean, over samples, of the average saliency under the pixels
            covered by the valid boxes. Samples with no covered pixel or zero
            saliency contribute 0.

        Raises:
            AssertionError: If the four inputs differ in length.
        """
        # Validate before any array is indexed so a length mismatch is
        # reported as such rather than as an unrelated indexing error.
        assert (
            len(predictions)
            == len(gold_labels)
            == len(saliency_maps_1)
            == len(saliency_maps_2)
        )

        # ``np.array`` copies, so the caller's data is never modified. A float
        # dtype is forced so the in-place scaling below also works when the
        # boxes arrive as integers.
        predictions = np.array(predictions, dtype=np.float64)
        gold_labels = np.array(gold_labels)

        # Even columns are x coordinates, odd columns are y coordinates.
        predictions[:, :, ::2] *= self.canvas_width
        predictions[:, :, 1::2] *= self.canvas_height

        gold_labels = self.get_rid_of_invalid(
            predictions=predictions, gold_labels=gold_labels
        )

        num_predictions = len(predictions)
        score = 0.0

        samples = zip(predictions, gold_labels, saliency_maps_1, saliency_maps_2)
        for prediction, gold_label, smap_1, smap_2 in samples:
            smap_arr = np.maximum(
                self.load_saliency_map(smap_1),
                self.load_saliency_map(smap_2),
            )
            cal_mask = np.zeros_like(smap_arr)

            boxes = prediction.astype(int)  # truncates toward zero
            keep = (np.array(gold_label, dtype=int) > 0).reshape(-1)

            # Negative coordinates would be read by NumPy as "count from the
            # end" and paint the wrong pixels, so clamp them to 0. Values past
            # the far edge need no handling: slicing already stops there.
            valid_boxes = np.clip(boxes[keep], 0, None)

            for xl, yl, xr, yr in valid_boxes:
                cal_mask[yl:yr, xl:xr] = 1

            total_area = np.sum(cal_mask)
            total_sal = np.sum(smap_arr[cal_mask == 1])
            if total_sal and total_area:
                score += total_sal / total_area

        return score / num_predictions
requirements.txt ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1 ; python_version >= "3.9" and python_version < "4.0"
2
+ aiohttp==3.9.3 ; python_version >= "3.9" and python_version < "4.0"
3
+ aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "4.0"
4
+ altair==5.2.0 ; python_version >= "3.9" and python_version < "4.0"
5
+ annotated-types==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
6
+ anyio==4.2.0 ; python_version >= "3.9" and python_version < "4.0"
7
+ arrow==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
8
+ async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.11"
9
+ attrs==23.2.0 ; python_version >= "3.9" and python_version < "4.0"
10
+ binaryornot==0.4.4 ; python_version >= "3.9" and python_version < "4.0"
11
+ certifi==2024.2.2 ; python_version >= "3.9" and python_version < "4.0"
12
+ chardet==5.2.0 ; python_version >= "3.9" and python_version < "4.0"
13
+ charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "4.0"
14
+ click==8.1.7 ; python_version >= "3.9" and python_version < "4.0"
15
+ colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0"
16
+ contourpy==1.2.0 ; python_version >= "3.9" and python_version < "4.0"
17
+ cookiecutter==2.5.0 ; python_version >= "3.9" and python_version < "4.0"
18
+ cycler==0.12.1 ; python_version >= "3.9" and python_version < "4.0"
19
+ datasets==2.17.0 ; python_version >= "3.9" and python_version < "4.0"
20
+ dill==0.3.8 ; python_version >= "3.9" and python_version < "4.0"
21
+ evaluate[template]==0.4.1 ; python_version >= "3.9" and python_version < "4.0"
22
+ exceptiongroup==1.2.0 ; python_version >= "3.9" and python_version < "3.11"
23
+ fastapi==0.109.2 ; python_version >= "3.9" and python_version < "4.0"
24
+ ffmpy==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
25
+ filelock==3.13.1 ; python_version >= "3.9" and python_version < "4.0"
26
+ fonttools==4.48.1 ; python_version >= "3.9" and python_version < "4.0"
27
+ frozenlist==1.4.1 ; python_version >= "3.9" and python_version < "4.0"
28
+ fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "4.0"
29
+ fsspec[http]==2023.10.0 ; python_version >= "3.9" and python_version < "4.0"
30
+ gradio-client==0.10.0 ; python_version >= "3.9" and python_version < "4.0"
31
+ gradio==4.18.0 ; python_version >= "3.9" and python_version < "4.0"
32
+ h11==0.14.0 ; python_version >= "3.9" and python_version < "4.0"
33
+ httpcore==1.0.2 ; python_version >= "3.9" and python_version < "4.0"
34
+ httpx==0.26.0 ; python_version >= "3.9" and python_version < "4.0"
35
+ huggingface-hub==0.20.3 ; python_version >= "3.9" and python_version < "4.0"
36
+ idna==3.6 ; python_version >= "3.9" and python_version < "4.0"
37
+ importlib-resources==6.1.1 ; python_version >= "3.9" and python_version < "4.0"
38
+ jinja2==3.1.3 ; python_version >= "3.9" and python_version < "4.0"
39
+ jsonschema-specifications==2023.12.1 ; python_version >= "3.9" and python_version < "4.0"
40
+ jsonschema==4.21.1 ; python_version >= "3.9" and python_version < "4.0"
41
+ kiwisolver==1.4.5 ; python_version >= "3.9" and python_version < "4.0"
42
+ markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "4.0"
43
+ markupsafe==2.1.5 ; python_version >= "3.9" and python_version < "4.0"
44
+ matplotlib==3.8.2 ; python_version >= "3.9" and python_version < "4.0"
45
+ mdurl==0.1.2 ; python_version >= "3.9" and python_version < "4.0"
46
+ multidict==6.0.5 ; python_version >= "3.9" and python_version < "4.0"
47
+ multiprocess==0.70.16 ; python_version >= "3.9" and python_version < "4.0"
48
+ numpy==1.26.4 ; python_version >= "3.9" and python_version < "4.0"
49
+ orjson==3.9.13 ; python_version >= "3.9" and python_version < "4.0"
50
+ packaging==23.2 ; python_version >= "3.9" and python_version < "4.0"
51
+ pandas==2.2.0 ; python_version >= "3.9" and python_version < "4.0"
52
+ pillow==10.2.0 ; python_version >= "3.9" and python_version < "4.0"
53
+ pyarrow-hotfix==0.6 ; python_version >= "3.9" and python_version < "4.0"
54
+ pyarrow==15.0.0 ; python_version >= "3.9" and python_version < "4.0"
55
+ pydantic-core==2.16.2 ; python_version >= "3.9" and python_version < "4.0"
56
+ pydantic==2.6.1 ; python_version >= "3.9" and python_version < "4.0"
57
+ pydub==0.25.1 ; python_version >= "3.9" and python_version < "4.0"
58
+ pygments==2.17.2 ; python_version >= "3.9" and python_version < "4.0"
59
+ pyparsing==3.1.1 ; python_version >= "3.9" and python_version < "4.0"
60
+ python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "4.0"
61
+ python-multipart==0.0.9 ; python_version >= "3.9" and python_version < "4.0"
62
+ python-slugify==8.0.4 ; python_version >= "3.9" and python_version < "4.0"
63
+ pytz==2024.1 ; python_version >= "3.9" and python_version < "4.0"
64
+ pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "4.0"
65
+ referencing==0.33.0 ; python_version >= "3.9" and python_version < "4.0"
66
+ requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
67
+ responses==0.18.0 ; python_version >= "3.9" and python_version < "4.0"
68
+ rich==13.7.0 ; python_version >= "3.9" and python_version < "4.0"
69
+ rpds-py==0.17.1 ; python_version >= "3.9" and python_version < "4.0"
70
+ ruff==0.2.1 ; python_version >= "3.9" and python_version < "4.0"
71
+ semantic-version==2.10.0 ; python_version >= "3.9" and python_version < "4.0"
72
+ shellingham==1.5.4 ; python_version >= "3.9" and python_version < "4.0"
73
+ six==1.16.0 ; python_version >= "3.9" and python_version < "4.0"
74
+ sniffio==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
75
+ starlette==0.36.3 ; python_version >= "3.9" and python_version < "4.0"
76
+ text-unidecode==1.3 ; python_version >= "3.9" and python_version < "4.0"
77
+ tomlkit==0.12.0 ; python_version >= "3.9" and python_version < "4.0"
78
+ toolz==0.12.1 ; python_version >= "3.9" and python_version < "4.0"
79
+ tqdm==4.66.2 ; python_version >= "3.9" and python_version < "4.0"
80
+ typer[all]==0.9.0 ; python_version >= "3.9" and python_version < "4.0"
81
+ types-python-dateutil==2.8.19.20240106 ; python_version >= "3.9" and python_version < "4.0"
82
+ typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "4.0"
83
+ tzdata==2024.1 ; python_version >= "3.9" and python_version < "4.0"
84
+ urllib3==2.2.0 ; python_version >= "3.9" and python_version < "4.0"
85
+ uvicorn==0.27.1 ; python_version >= "3.9" and python_version < "4.0"
86
+ websockets==11.0.3 ; python_version >= "3.9" and python_version < "4.0"
87
+ xxhash==3.4.1 ; python_version >= "3.9" and python_version < "4.0"
88
+ yarl==1.9.4 ; python_version >= "3.9" and python_version < "4.0"
89
+ zipp==3.17.0 ; python_version >= "3.9" and python_version < "3.10"