wfc.hpp
3.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#pragma once
#include <cmath>
#include <limits>
#include <random>
#include <unordered_map>
#include <utility>
#include "optional.hpp"
#include "utils/array2D.hpp"
#include "wave.hpp"
#include "propagator.hpp"
using namespace std;
using namespace nonstd;
/**
 * Class containing the generic WFC (Wave Function Collapse) algorithm.
 *
 * run() repeatedly collapses the lowest-entropy cell of the wave to a single
 * pattern (observe()) and propagates the consequences (propagate()), until
 * either every cell is decided (success) or a contradiction is found (failure).
 */
class WFC {
private:
  /**
   * The random number generator.
   */
  minstd_rand gen;

  /**
   * The wave, indicating which patterns can be put in which cell.
   */
  Wave wave;

  /**
   * The distribution of the patterns as given in input.
   * Indexed by pattern id; used as (unnormalized) weights when a cell is
   * collapsed in observe().
   */
  const vector<double> patterns_frequencies;

  /**
   * The number of distinct patterns.
   */
  const unsigned nb_patterns;

  /**
   * The propagator, used to propagate the information in the wave.
   */
  Propagator propagator;

  /**
   * True if the output is periodic.
   */
  const bool periodic_output;

  /**
   * Transform the wave to a valid output (a 2d array of patterns that aren't
   * in contradiction).
   * This function should be used only when every cell of the wave is defined.
   * If a cell still allows several patterns, the one with the highest index
   * is written.
   */
  Array2D<unsigned> wave_to_output() const noexcept {
    Array2D<unsigned> output_patterns(wave.height, wave.width);
    for (unsigned i = 0; i < wave.size; i++) {
      for (unsigned k = 0; k < nb_patterns; k++) {
        if (wave.get(i, k)) {
          output_patterns.data[i] = k;
        }
      }
    }
    return output_patterns;
  }

public:
  /**
   * Basic constructor initializing the algorithm.
   *
   * @param periodic_output      true if the output is periodic.
   * @param seed                 seed of the random number generator.
   * @param patterns_frequencies weight of each pattern, indexed by pattern id.
   * @param propagator           per-pattern adjacency data handed to the
   *                             Propagator; its size defines nb_patterns.
   * @param wave_height          height of the wave.
   * @param wave_width           width of the wave.
   *
   * Members are initialized in declaration order, and this list relies on it.
   * `wave` reads the `patterns_frequencies` parameter before that parameter
   * is moved into the member. `nb_patterns` reads `propagator.size()` before
   * the `propagator` parameter is moved into the Propagator member.
   */
  WFC(bool periodic_output, int seed, vector<double> patterns_frequencies,
      vector<array<vector<unsigned>, 4>> propagator, unsigned wave_height,
      unsigned wave_width) noexcept
      : gen(seed), wave(wave_height, wave_width, patterns_frequencies),
        patterns_frequencies(std::move(patterns_frequencies)),
        nb_patterns(propagator.size()),
        propagator(wave.height, wave.width, periodic_output,
                   std::move(propagator)),
        periodic_output(periodic_output) {}

  /**
   * Run the algorithm, and return a result if it succeeded.
   *
   * @return the chosen pattern of every cell on success, nullopt if a
   *         contradiction was reached.
   */
  optional<Array2D<unsigned>> run() noexcept {
    while (true) {
      // Define the value of an undefined cell.
      const ObserveStatus result = observe();
      // Check if the algorithm has terminated.
      if (result == failure) {
        return nullopt;
      }
      if (result == success) {
        return wave_to_output();
      }
      // Propagate the information.
      propagator.propagate(wave);
    }
  }

  /**
   * Return value of observe.
   */
  enum ObserveStatus {
    success,     // WFC has finished and has succeeded.
    failure,     // WFC has finished and failed.
    to_continue  // WFC isn't finished.
  };

  /**
   * Define the value of the cell with lowest entropy.
   *
   * A pattern is drawn among those still allowed in that cell, with
   * probability proportional to patterns_frequencies. Every other allowed
   * pattern is then removed from the cell and registered with the propagator.
   *
   * @return failure on contradiction, success when every cell is decided,
   *         to_continue otherwise.
   */
  ObserveStatus observe() noexcept {
    // Get the cell with lowest entropy.
    const int argmin = wave.get_min_entropy(gen);
    // -2: there is a contradiction, so the algorithm has failed.
    if (argmin == -2) {
      return failure;
    }
    // -1: every cell is decided, so the algorithm has succeeded. The caller
    // (run()) builds the output, so nothing needs to be computed here.
    if (argmin == -1) {
      return success;
    }

    // Total weight of the patterns still allowed in the cell.
    double s = 0;
    for (unsigned k = 0; k < nb_patterns; k++) {
      if (wave.get(argmin, k)) {
        s += patterns_frequencies[k];
      }
    }

    // Choose an allowed pattern according to the pattern distribution.
    // The draw lies in [0, s). Only allowed patterns take part, so a pattern
    // already removed from this cell can never be chosen.
    std::uniform_real_distribution<> dis(0, s);
    double random_value = dis(gen);
    unsigned chosen_value = nb_patterns - 1;
    for (unsigned k = 0; k < nb_patterns; k++) {
      if (!wave.get(argmin, k)) {
        continue;
      }
      // Keep the last allowed pattern as a fallback in case floating-point
      // rounding leaves random_value non-negative after the last subtraction.
      chosen_value = k;
      random_value -= patterns_frequencies[k];
      if (random_value < 0) {
        break;
      }
    }

    // Define the cell with the chosen pattern: remove each other pattern that
    // is still allowed. Patterns already absent are skipped so they are not
    // registered with the propagator a second time.
    for (unsigned k = 0; k < nb_patterns; k++) {
      if (k != chosen_value && wave.get(argmin, k)) {
        propagator.add_to_propagator(argmin / wave.width, argmin % wave.width,
                                     k);
        wave.set(argmin, k, false);
      }
    }
    return to_continue;
  }

  /**
   * Propagate the information of the wave.
   */
  void propagate() noexcept {
    propagator.propagate(wave);
  }

  /**
   * Remove pattern from cell (i,j).
   *
   * Does nothing if the pattern is already absent from that cell. Otherwise
   * the removal is registered with the propagator via add_to_propagator().
   * Propagation itself is not performed here.
   */
  void remove_wave_pattern(unsigned i, unsigned j, unsigned pattern) noexcept {
    if (wave.get(i, j, pattern)) {
      wave.set(i, j, pattern, false);
      propagator.add_to_propagator(i, j, pattern);
    }
  }
};