Infrared
Loading...
Searching...
No Matches
cluster_tree.hpp
1#ifndef INFRARED_CLUSTER_TREE_HPP
2#define INFRARED_CLUSTER_TREE_HPP
3
4/*
5 * InfraRed --- A generic engine for Boltzmann sampling over constraint networks
6 * (C) Sebastian Will, 2018
7 *
8 * This file is part of the InfraRed source code.
9 *
10 * InfraRed provides a generic framework for tree decomposition-based
11 * Boltzmann sampling over constraint networks
12 */
13
20#include <set>
21#include <optional>
22
23#include "graph.hpp"
24#include "feature_network.hpp"
25
26namespace ired {
54 template<class FunValue=double, class EvaluationPolicy=StdEvaluationPolicy<FunValue>>
56
57 public:
59
62
69
71
75 : cluster() {}
77 : cluster(cluster) {}
79 };
80
82 struct edge_info_t {
84 : message(message) {}
86 };
87
89
92
97 }
98
105 explicit
107 : fn_( domains ) {
108 }
109
113 explicit
114 ClusterTree(std::vector<int> domsizes)
115 : fn_( domains_from_domsizes( domsizes ) ) {
116 }
117
124 ClusterTree(int num_vars, const FiniteDomain &domain)
125 : fn_(num_vars, domain) {
126 }
127
134 ClusterTree(int num_vars, int domsize)
135 : fn_( num_vars, FiniteDomain(0, domsize-1) ) {
136 }
137
139
141 const auto & feature_network() const {
142 return fn_;
143 }
144
155 auto
156 add_root_cluster(const std::vector<var_idx_t> &vars) {
157 return tree_.add_vertex(cluster_t(vars));
158 }
159
171 auto
172 add_child_cluster( vertex_descriptor_t parent, const std::vector<int> &vars) {
173 auto child = tree_.add_vertex(cluster_t(vars));
174 tree_.add_edge(parent, child, edge_info_t());
175 return child;
176 }
177
186 void
187 add_constraint( vertex_descriptor_t node, const std::shared_ptr<constraint_t> &x ) {
188 tree_[node].cluster.add_constraint( fn_.add_constraint(x) );
189 }
190
199 void
200 add_function( vertex_descriptor_t node, const std::shared_ptr<function_t> &x ) {
201 tree_[node].cluster.add_function( fn_.add_function(x) );
202 }
203
/**
 * @brief Evaluate the cluster tree by dynamic programming
 *
 * @returns the evaluation result; it is computed on the first call
 * and cached for subsequent calls
 */
auto evaluate();
216
227 return evaluate() != evaluation_policy_t::zero();
228 }
229
/**
 * @brief Generate a traceback
 *
 * @returns an assignment obtained by tracing back from the root
 */
auto traceback();
254
266 auto restricted_traceback(const std::set<var_idx_t> &variables,
267 const assignment_t &assignment);
268
269 private:
271 tree_t tree_;
272
273 bool evaluated_ = false;
274 fun_value_t evaluation_result_;
275 bool single_empty_rooted_ = false;
276
278
279 std::optional<std::set<var_idx_t>> trace_variables;
280
// Connect all trees of the forest below one empty pseudo root, unless
// this was done before.
// @returns the (possibly new) root
auto single_empty_root();
285
286 auto domains_from_domsizes( std::vector<int> &domsizes ) {
287 auto domains = FiniteDomainVector();
288 for (auto x: domsizes) {
289 domains.push_back( FiniteDomain( 0, x - 1 ) );
290 }
291 return domains;
292 }
293
294 void
295 dfs_evaluate(vertex_descriptor_t v) {
296 // notes: this terminates if there are no children
297 // as well, we can know, that tree_ is a tree, without any
298 // cycles
299
300 for(auto &e: tree_.adj_edges(v)) {
301 // recursively evaluate subtree at e.target()
302 dfs_evaluate(e.target());
303
304 // then, compute the message for the current edge
305 // from cluster child to cluster parent
306
307 const auto &parent = tree_[ e.source() ].cluster;
308 const auto &child = tree_[ e.target() ].cluster;
309
310 auto sep = child.sep_vars(parent);
311 auto diff = child.diff_vars(parent);
312
313 // concat sep + diff
314 auto sep_diff = sep;
315 sep_diff.insert(sep_diff.end(),diff.begin(),diff.end());
316
317 auto message = std::make_unique<message_t>(sep, fn_);
318
319 auto a = assignment_t(fn_.domains());
320
321 auto it = a.make_iterator
322 (sep_diff,
323 fn_,
324 child.constraints(),
325 child.functions(),
326 //evaluate 0-ary functions
327 a.eval_determined(child.functions(), evaluation_policy_t())
328 );
329
330 fun_value_t x = evaluation_policy_t::zero();
331
332 it.register_finish_stage2_hook
333 (sep.size(),
334 [&message,&a,&x] () {
335 message->set(a, x);
336 x = evaluation_policy_t::zero();
337 } );
338
339 for(; ! it.finished() ; ++it ) {
340 x = evaluation_policy_t::plus( x, it.value() );
341 }
342
343 // register message in cn, such that it persists!
344 auto msg = fn_.add_function(std::move(message));
345 // then, register in cluster parent
346 tree_[ e.source() ].cluster.add_function(msg);
347
348 // ... and as edge property
349 e.message = msg;
350 }
351 }
352
// Recursive traceback below vertex v.
//
// For every child edge of v, this assigns the child cluster's difference
// variables (those not in the parent cluster) in the in-out assignment a.
// The values come from an enumeration driven by the evaluation policy's
// selector. The traceback then recurses into the child.
// If trace_variables is set, only difference variables contained in it
// are assigned.
//
// Requires dfs_evaluate to have stored a message on each edge: in the
// unrestricted case the selector total is read via (*e.message), and
// edge_info_t() leaves that pointer null.
360 void
361 dfs_traceback(vertex_descriptor_t v, assignment_t &a) {
362
363 for(const auto &e: tree_.adj_edges(v)) {
364
365 const auto &parent = tree_[ e.source() ].cluster;
366 const auto &child = tree_[ e.target() ].cluster;
367
// variables of the child that are not in the parent
368 auto diff = child.diff_vars(parent);
// set when restriction removed some variables; the stored message then
// no longer matches the enumerated variables
369 bool diff_changed = false;
370
371 if (trace_variables) { // if the trace is restricted,
372 // then remove other vars from diff
// NOTE(review): unqualified remove_if relies on ADL (diff is presumably a
// std::vector) and on <algorithm> being included transitively
373 auto new_end = remove_if(diff.begin(), diff.end(),
374 [&](var_idx_t i) {
375 return trace_variables->find(i) == trace_variables->end();
376 });
377 diff_changed = new_end != diff.end();
378 diff.erase(new_end, diff.end());
379 }
380
381 if (! diff.empty()) {
382
// presumably marks the diff variables as undetermined so that the
// iterator enumerates them
383 a.set_undet(diff);
384
// debug check: the child's constraints that are already determined under
// a must hold (evaluated with the boolean policy)
385 assert(a.eval_determined(child.constraints(),
386 StdEvaluationPolicy<bool>()));
387
388 auto it = a.make_iterator(diff, fn_,
389 child.constraints(),
390 child.functions(),
391 a.eval_determined(child.functions(),
393 );
394
395 // -----
396 // determine value for selector
397 auto value = evaluation_policy_t::zero();
398 if (diff_changed) {
399 // recompute value for selector
400 for( ; ! it.finished(); ++it ) {
401 value = evaluation_policy_t::plus( value, it.value() );
402 }
// rewind the iterator for the selection pass below (presumably restarts
// the enumeration)
403 it.reset();
404 } else {
405 // get value by evaluating message:
406 // evaluate message at partial assignment a;
407 value = (*e.message)(a);
408 }
409 // initialize the Selector
410 auto selector = typename evaluation_policy_t::selector(value);
411
412 // enumerate all combinations of values assigned to diff variables;
413 // recompute SUM_a(PROD_f(f(a))), where a runs over these
414 // combinations
415 auto x = evaluation_policy_t::zero();
416 for( ; ! it.finished(); ++it ) {
417 x = evaluation_policy_t::plus( x, it.value() );
418
// stop at the first combination the selector accepts; presumably a then
// keeps the values of that combination (TODO confirm iterator semantics)
419 if (selector.select(x)) {
420 break;
421 }
422 }
423 }
424 // trace back from target
425 dfs_traceback(e.target(), a);
426 }
427 }
428
429 }; // end class ClusterTree
430
431 // return single empty cluster that roots the tree; if such a
432 // cluster exists or was generated before, simply return it;
433 // otherwise, construct a new empty cluster, connect it as parent
434 // to all existing roots, and return it.
435 template<class FunValue, class EvaluationPolicy>
436 auto
437 ClusterTree<FunValue,EvaluationPolicy>::single_empty_root() {
438 if (single_empty_rooted_) {return root_;}
439
440 // find all root nodes of the tree
441
442 std::set<vertex_descriptor_t> old_roots;
443
444 for (int i=0; i < int(tree_.size()); ++i) {
445 old_roots.insert(i);
446 }
447
448 for (int i=0; i < int(tree_.size()); ++i) {
449 for (auto &e: tree_.adj_edges(i)) {
450 old_roots.erase( e.target() );
451 }
452 }
453
454 if ( old_roots.size()==1 && tree_[ *old_roots.begin() ].cluster.empty() ) {
455 root_ = *old_roots.begin();
456 } else {
457 // insert new root and point to all old roots
458 root_ = tree_.add_vertex();
459
460 for (auto old_root : old_roots) {
461 tree_.add_edge(root_, old_root);
462 }
463 }
464
465 single_empty_rooted_ = true;
466
467 return root_;
468 }
469
470 template<class FunValue, class EvaluationPolicy>
471 auto
472 ClusterTree<FunValue,EvaluationPolicy>
473 ::evaluate() {
474 auto root = single_empty_root();
475
476 if (!evaluated_) {
477 dfs_evaluate(root);
478
479 auto a = assignment_t(fn_.domains());
480 evaluation_result_ = a.eval_determined(tree_[root].cluster.functions(), evaluation_policy_t());
481
482 evaluated_ = true;
483 }
484 return evaluation_result_;
485 }
486
487 template<class FunValue, class EvaluationPolicy>
488 auto
490
491 assert(evaluated_);
492 //assert(is_consistent());
493
494 auto a = assignment_t(fn_.domains());
495
496 dfs_traceback(single_empty_root(), a);
497
498 return a;
499 }
500
501 template<class FunValue, class EvaluationPolicy>
502 auto
504 ::restricted_traceback(const std::set<var_idx_t> &variables,
505 const assignment_t &assignment) {
506
507 assert(evaluated_);
508
509 //trace_variables = std::set<var_idx_t>(variables.begin(),variables.end());
510 trace_variables = variables;
511
512 auto a = assignment_t(assignment);
513
514 dfs_traceback(single_empty_root(), a);
515
516 trace_variables.reset();
517
518 return a;
519 }
520}
521
522#endif
A tree of clusters (=variables, functions, constraints)
Definition cluster_tree.hpp:55
ClusterTree(int num_vars, const FiniteDomain &domain)
Construct with uniform domains.
Definition cluster_tree.hpp:124
FeatureNetwork< FunValue, EvaluationPolicy > feature_network_t
Definition cluster_tree.hpp:58
typename tree_t::vertex_descriptor_t vertex_descriptor_t
type of identifiers of vertices (typically 'long int')
Definition cluster_tree.hpp:91
bool is_consistent()
Check consistency.
Definition cluster_tree.hpp:226
ClusterTree(std::vector< int > domsizes)
construct from vector of domain sizes
Definition cluster_tree.hpp:114
~ClusterTree()
Definition cluster_tree.hpp:138
auto traceback()
Generate a traceback.
Definition cluster_tree.hpp:489
auto evaluate()
Evaluate the cluster tree (by DP)
Definition cluster_tree.hpp:473
ClusterTree(int num_vars, int domsize)
Construct with uniform domains.
Definition cluster_tree.hpp:134
auto restricted_traceback(const std::set< var_idx_t > &variables, const assignment_t &assignment)
Generate a restricted traceback.
Definition cluster_tree.hpp:504
typename feature_network_t::constraint_t constraint_t
Definition cluster_tree.hpp:67
auto add_root_cluster(const std::vector< var_idx_t > &vars)
add new root cluster to the tree
Definition cluster_tree.hpp:156
ClusterTree(const FiniteDomainVector &domains)
Construct with variable domains.
Definition cluster_tree.hpp:106
graph::adjacency_list< vertex_info_t, edge_info_t > tree_t
Definition cluster_tree.hpp:88
typename feature_network_t::cluster_t cluster_t
Definition cluster_tree.hpp:64
typename feature_network_t::var_idx_t var_idx_t
Definition cluster_tree.hpp:63
const auto & feature_network() const
read access to the feature network
Definition cluster_tree.hpp:141
typename feature_network_t::evaluation_policy_t evaluation_policy_t
evaluation policy type
Definition cluster_tree.hpp:61
typename feature_network_t::function_t function_t
Definition cluster_tree.hpp:68
void add_constraint(vertex_descriptor_t node, const std::shared_ptr< constraint_t > &x)
add new constraint to cluster
Definition cluster_tree.hpp:187
ClusterTree()
construct empty
Definition cluster_tree.hpp:96
auto add_child_cluster(vertex_descriptor_t parent, const std::vector< int > &vars)
add new child cluster to the tree
Definition cluster_tree.hpp:172
void add_function(vertex_descriptor_t node, const std::shared_ptr< function_t > &x)
add new function to cluster
Definition cluster_tree.hpp:200
typename feature_network_t::fun_value_t fun_value_t
Definition cluster_tree.hpp:66
typename feature_network_t::assignment_t assignment_t
Definition cluster_tree.hpp:65
the feature network
Definition feature_network.hpp:208
FunValue fun_value_t
Definition feature_network.hpp:211
auto add_function(const std::shared_ptr< function_t > &x)
add function
Definition feature_network.hpp:298
int var_idx_t
Definition feature_network.hpp:209
Function< FunValue > function_t
Definition feature_network.hpp:214
EvaluationPolicy evaluation_policy_t
Definition feature_network.hpp:217
auto add_constraint(const std::shared_ptr< constraint_t > &x)
add constraint
Definition feature_network.hpp:269
Cluster< fun_value_t > cluster_t
suitable cluster type for the constraint network
Definition feature_network.hpp:220
Assignment assignment_t
Definition feature_network.hpp:213
Constraint constraint_t
Definition feature_network.hpp:215
Definition finite_domain.hpp:29
A materialized function.
Definition functions.hpp:233
auto & add_edge(vertex_descriptor_t source, vertex_descriptor_t target, const edge_info_t &e)
add edge
Definition graph.hpp:78
vertex_descriptor_t add_vertex(const vertex_t &v)
add vertex
Definition graph.hpp:63
Definition assignment.hpp:29
std::vector< FiniteDomain > FiniteDomainVector
Definition finite_domain.hpp:148
information at an edge of the tree
Definition cluster_tree.hpp:82
edge_info_t(function_t *message=nullptr)
Definition cluster_tree.hpp:83
function_t * message
Definition cluster_tree.hpp:85
information at a vertex (=cluster/bag) of the tree
Definition cluster_tree.hpp:73
vertex_info_t(const cluster_t &cluster)
Definition cluster_tree.hpp:76
vertex_info_t()
Definition cluster_tree.hpp:74
cluster_t cluster
Definition cluster_tree.hpp:78