overlapping_wfc.hpp
9.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
#pragma once
#include "wfc.hpp"
#include <iostream>
/**
 * Parameters controlling the overlapping WFC algorithm.
 */
struct OverlappingWFCOptions {
  bool periodic_input;    // Whether the input image wraps around (toric).
  bool periodic_output;   // Whether the output image wraps around (toric).
  unsigned out_height;    // Height of the generated image, in pixels.
  unsigned out_width;     // Width of the generated image, in pixels.
  unsigned symmetry;      // How many pattern symmetries are used (the order is defined in wfc).
  bool ground;            // Whether a ground row must be enforced (see init_ground).
  unsigned pattern_size;  // Side length, in pixels, of the square patterns.

  /**
   * Number of rows of the wave for these options.
   * A toric output has one wave cell per pixel; otherwise the last
   * pattern_size - 1 rows are covered by the bottom patterns.
   */
  unsigned get_wave_height() const noexcept {
    if (periodic_output) {
      return out_height;
    }
    return out_height + 1 - pattern_size;
  }

  /**
   * Number of columns of the wave for these options.
   * A toric output has one wave cell per pixel; otherwise the last
   * pattern_size - 1 columns are covered by the rightmost patterns.
   */
  unsigned get_wave_width() const noexcept {
    if (periodic_output) {
      return out_width;
    }
    return out_width + 1 - pattern_size;
  }
};
/**
 * Class generating a new image with the overlapping WFC algorithm.
 * The input image is cut into square patterns of side options.pattern_size,
 * the generic WFC solver assembles a wave of pattern ids whose neighbours
 * agree on their overlapping pixels, and the wave is converted back to pixels.
 */
template<typename T>
class OverlappingWFC {
private:
  /**
   * The input image. T is usually a color.
   */
  Array2D<T> input;

  /**
   * Options needed by the algorithm.
   */
  OverlappingWFCOptions options;

  /**
   * The distinct patterns extracted from the input (symmetries included).
   * The index of a pattern in this vector is its id in the wave.
   */
  vector<Array2D<T>> patterns;

  /**
   * The underlying generic WFC algorithm.
   */
  WFC wfc;

  /**
   * Constructor initializing the wfc.
   * This constructor is called by the other constructors.
   * This is necessary in order to initialize wfc only once.
   *
   * @param input      The input image.
   * @param options    The algorithm options.
   * @param seed       Seed forwarded to the WFC solver.
   * @param patterns   The extracted patterns and their occurrence counts
   *                   (the counts are passed to WFC as pattern weights).
   * @param propagator For each pattern and each of the 4 directions, the ids
   *                   of the patterns compatible with it (see generate_compatible).
   */
  OverlappingWFC(const Array2D<T>& input, const OverlappingWFCOptions& options, const int& seed,
                 const pair<vector<Array2D<T>>, vector<double>>& patterns,
                 const vector<array<vector<unsigned>, 4>>& propagator) noexcept
      : input(input),
        options(options),
        patterns(patterns.first),
        wfc(options.periodic_output, seed, patterns.second, propagator,
            options.get_wave_height(), options.get_wave_width()) {
    // If necessary, the ground is set.
    if (options.ground) {
      init_ground(wfc, input, patterns.first, options);
    }
  }

  /**
   * Constructor used only to call the other constructor with the propagator
   * computed from the extracted patterns.
   */
  OverlappingWFC(const Array2D<T>& input, const OverlappingWFCOptions& options, const int& seed,
                 const pair<vector<Array2D<T>>, vector<double>>& patterns) noexcept
      : OverlappingWFC(input, options, seed, patterns, generate_compatible(patterns.first)) {}

  /**
   * Init the ground of the output image.
   * The lowest middle pattern of the input is forced on every cell of the
   * lowest wave row (which also acts as the ceiling when the output is toric),
   * and is removed from every other wave cell. The result is then propagated.
   */
  static void init_ground(WFC& wfc, const Array2D<T>& input, const vector<Array2D<T>>& patterns,
                          const OverlappingWFCOptions& options) noexcept {
    const unsigned ground_pattern_id = get_ground_pattern_id(input, patterns, options);
    const unsigned wave_width = options.get_wave_width();
    const unsigned ground_row = options.get_wave_height() - 1;

    // Place the pattern in the ground: remove every other pattern there.
    for (unsigned j = 0; j < wave_width; j++) {
      for (unsigned p = 0; p < patterns.size(); p++) {
        if (p != ground_pattern_id) {
          wfc.remove_wave_pattern(ground_row, j, p);
        }
      }
    }

    // Remove the ground pattern from all the rows above the ground.
    for (unsigned i = 0; i < ground_row; i++) {
      for (unsigned j = 0; j < wave_width; j++) {
        wfc.remove_wave_pattern(i, j, ground_pattern_id);
      }
    }

    // Propagate the information with wfc.
    wfc.propagate();
  }

  /**
   * Return the id of the lowest middle pattern of the input, i.e. the pattern
   * whose top-left pixel is (input.height - 1, input.width / 2).
   * NOTE(review): for this position to be among the extracted patterns,
   * presumably get_sub_array wraps around and the input must be periodic
   * (or pattern_size == 1). Confirm against Array2D::get_sub_array.
   */
  static unsigned get_ground_pattern_id(const Array2D<T>& input, const vector<Array2D<T>>& patterns,
                                        const OverlappingWFCOptions& options) noexcept {
    // Get the pattern.
    Array2D<T> ground_pattern =
        input.get_sub_array(input.height - 1, input.width / 2, options.pattern_size, options.pattern_size);
    // Retrieve the id of the pattern.
    for (unsigned i = 0; i < patterns.size(); i++) {
      if (ground_pattern == patterns[i]) {
        return i;
      }
    }
    // The pattern is expected to exist.
    assert(false);
    return 0;
  }

  /**
   * Return the list of distinct patterns, together with the number of times
   * each one appears in the input (counting the requested symmetries).
   * At most 8 symmetries exist (4 rotations, each optionally reflected);
   * options.symmetry is clamped to that number.
   */
  static pair<vector<Array2D<T>>, vector<double>> get_patterns(const Array2D<T>& input,
                                                               const OverlappingWFCOptions& options) noexcept {
    const unsigned size = options.pattern_size;
    const unsigned max_symmetries = 8;
    unordered_map<Array2D<T>, unsigned> patterns_id;
    vector<Array2D<T>> patterns;
    // The number of time a pattern is seen in the input image.
    vector<double> patterns_frequency;
    vector<Array2D<T>> symmetries(max_symmetries, Array2D<T>(size, size));

    // Clamp so that an out-of-range option cannot index past `symmetries`.
    const unsigned nb_symmetries = options.symmetry < max_symmetries ? options.symmetry : max_symmetries;

    // A non-toric input only yields patterns fully inside it. Guard against
    // unsigned underflow when the input is smaller than a pattern.
    const unsigned max_i = options.periodic_input
                               ? input.height
                               : (input.height >= size ? input.height - size + 1 : 0);
    const unsigned max_j = options.periodic_input
                               ? input.width
                               : (input.width >= size ? input.width - size + 1 : 0);

    for (unsigned i = 0; i < max_i; i++) {
      for (unsigned j = 0; j < max_j; j++) {
        // Compute the symmetries of every pattern in the image.
        symmetries[0].data = input.get_sub_array(i, j, size, size).data;
        symmetries[1].data = symmetries[0].reflected().data;
        symmetries[2].data = symmetries[0].rotated().data;
        symmetries[3].data = symmetries[2].reflected().data;
        symmetries[4].data = symmetries[2].rotated().data;
        symmetries[5].data = symmetries[4].reflected().data;
        symmetries[6].data = symmetries[4].rotated().data;
        symmetries[7].data = symmetries[6].reflected().data;

        // The number of symmetries in the options defines which symmetries are used.
        for (unsigned k = 0; k < nb_symmetries; k++) {
          auto res = patterns_id.insert(make_pair(symmetries[k], patterns.size()));
          if (!res.second) {
            // The pattern already exists: just increase its number of appearances.
            patterns_frequency[res.first->second] += 1;
          } else {
            patterns.push_back(symmetries[k]);
            patterns_frequency.push_back(1);
          }
        }
      }
    }
    return {patterns, patterns_frequency};
  }

  /**
   * Return true if pattern1 is compatible with pattern2 when pattern2 is at a
   * distance (dy, dx) from pattern1, i.e. if both patterns have the same
   * pixels on their intersection.
   */
  static bool agrees(const Array2D<T>& pattern1, const Array2D<T>& pattern2, int dy, int dx) noexcept {
    // Bounds, in pattern1 coordinates, of the intersection of the two patterns.
    unsigned xmin = dx < 0 ? 0 : dx;
    unsigned xmax = dx < 0 ? dx + pattern2.width : pattern1.width;
    unsigned ymin = dy < 0 ? 0 : dy;
    // Use pattern1's height here. The original used pattern1.width by mistake.
    unsigned ymax = dy < 0 ? dy + pattern2.height : pattern1.height;

    // Iterate on every pixel contained in the intersection of the two patterns.
    for (unsigned y = ymin; y < ymax; y++) {
      for (unsigned x = xmin; x < xmax; x++) {
        // Check if the color is the same in the two patterns in that pixel.
        if (pattern1.get(y, x) != pattern2.get(y - dy, x - dx)) {
          return false;
        }
      }
    }
    return true;
  }

  /**
   * Precompute the function agrees(pattern1, pattern2, dy, dx).
   * If agrees(pattern1, pattern2, dy, dx), then compatible[pattern1][direction] contains pattern2,
   * where direction is the direction defined by (dy, dx) (see direction.hpp).
   */
  static vector<array<vector<unsigned>, 4>> generate_compatible(const vector<Array2D<T>>& patterns) noexcept {
    vector<array<vector<unsigned>, 4>> compatible(patterns.size());
    // Iterate on every pattern1, direction and pattern2.
    for (unsigned pattern1 = 0; pattern1 < patterns.size(); pattern1++) {
      for (unsigned direction = 0; direction < 4; direction++) {
        for (unsigned pattern2 = 0; pattern2 < patterns.size(); pattern2++) {
          if (agrees(patterns[pattern1], patterns[pattern2], directions_y[direction], directions_x[direction])) {
            compatible[pattern1][direction].push_back(pattern2);
          }
        }
      }
    }
    return compatible;
  }

  /**
   * Transform a 2D array containing the patterns id to a 2D array containing the pixels.
   * Every wave cell contributes the top-left pixel of its pattern. When the
   * output is not toric, the last wave column/row/corner also paint the
   * pattern_size - 1 extra columns/rows that no wave cell starts on.
   */
  Array2D<T> to_image(const Array2D<unsigned>& output_patterns) const noexcept {
    Array2D<T> output = Array2D<T>(options.out_height, options.out_width);
    const unsigned wave_height = options.get_wave_height();
    const unsigned wave_width = options.get_wave_width();

    for (unsigned y = 0; y < wave_height; y++) {
      for (unsigned x = 0; x < wave_width; x++) {
        output.get(y, x) = patterns[output_patterns.get(y, x)].get(0, 0);
      }
    }

    if (options.periodic_output) {
      // The wave has exactly one cell per output pixel.
      return output;
    }

    const unsigned last_y = wave_height - 1;
    const unsigned last_x = wave_width - 1;

    // Right border: remaining columns of the patterns in the last wave column.
    for (unsigned y = 0; y < wave_height; y++) {
      const Array2D<T>& pattern = patterns[output_patterns.get(y, last_x)];
      for (unsigned dx = 1; dx < options.pattern_size; dx++) {
        output.get(y, last_x + dx) = pattern.get(0, dx);
      }
    }

    // Bottom border: remaining rows of the patterns in the last wave row.
    for (unsigned x = 0; x < wave_width; x++) {
      const Array2D<T>& pattern = patterns[output_patterns.get(last_y, x)];
      for (unsigned dy = 1; dy < options.pattern_size; dy++) {
        output.get(last_y + dy, x) = pattern.get(dy, 0);
      }
    }

    // Bottom-right corner: the remaining block of the last pattern.
    const Array2D<T>& pattern = patterns[output_patterns.get(last_y, last_x)];
    for (unsigned dy = 1; dy < options.pattern_size; dy++) {
      for (unsigned dx = 1; dx < options.pattern_size; dx++) {
        output.get(last_y + dy, last_x + dx) = pattern.get(dy, dx);
      }
    }
    return output;
  }

public:
  /**
   * The constructor used by the user.
   */
  OverlappingWFC(const Array2D<T>& input, const OverlappingWFCOptions& options, int seed) noexcept
      : OverlappingWFC(input, options, seed, get_patterns(input, options)) {}

  /**
   * Run the WFC algorithm, and return the result if the algorithm succeeded.
   */
  nonstd::optional<Array2D<T>> run() noexcept {
    nonstd::optional<Array2D<unsigned>> result = wfc.run();
    if (result.has_value()) {
      return to_image(*result);
    }
    return nullopt;
  }
};