Program Listing for File train_model.cpp
↰ Return to documentation for file (lib/train_model.cpp)
#include "AnalysisGraph.hpp"
#include "ModelStatus.hpp"
#include "data.hpp"
#include "TrainingStopper.hpp"
#include "Logger.hpp"
#include <tqdm.hpp>
#include <range/v3/all.hpp>
#include <nlohmann/json.hpp>
#ifdef TIME
#include "utils.hpp"
#include "Timer.hpp"
// #include "CSVWriter.hpp"
#endif
using namespace std;
using tq::trange;
using json = nlohmann::json;
void AnalysisGraph::train_model(int start_year,
int start_month,
int end_year,
int end_month,
int res,
int burn,
string country,
string state,
string county,
map<string, string> units,
InitialBeta initial_beta,
InitialDerivative initial_derivative,
bool use_heuristic,
bool use_continuous) {
this->training_range = make_pair(make_pair(start_year, start_month),
make_pair( end_year, end_month));
this->n_timesteps = this->calculate_num_timesteps(start_year, start_month,
end_year, end_month);
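// Modeling timesteps are uniformly spaced: a gap of one unit between
// consecutive timesteps and no gap before the first one.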
this->modeling_timestep_gaps.clear();
this->modeling_timestep_gaps = vector<double>(this->n_timesteps, 1.0);
this->modeling_timestep_gaps[0] = 0;
if(this->n_timesteps > 0) {
if (!synthetic_data_experiment && !causemos_call) {
// Delphi is run locally using observation data from delphi.db
// For a synthetic data experiment, the observed state sequence is
// generated.
// For a CauseMos call, the observation sequences are provided in the create
// model JSON call and the observed state sequence is set in the method
// AnalysisGraph::set_observed_state_sequence_from_json_data(), which is
// defined in causemos_integration.cpp
this->set_observed_state_sequence_from_data(country, state, county);
}
this->run_train_model(res, burn, HeadNodeModel::HNM_NAIVE, initial_beta,
initial_derivative, use_heuristic, use_continuous);
}
}
void AnalysisGraph::run_train_model(int res,
int burn,
HeadNodeModel head_node_model,
InitialBeta initial_beta,
InitialDerivative initial_derivative,
bool use_heuristic,
bool use_continuous,
int train_start_timestep,
int train_timesteps,
unordered_map<string, int> concept_periods,
unordered_map<string, string> concept_center_measures,
unordered_map<string, string> concept_models,
unordered_map<string, double> concept_min_vals,
unordered_map<string, double> concept_max_vals,
unordered_map<string, function<double(unsigned int, double)>> ext_concepts) {
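// Progress reported per MCMC iteration over burn-in plus retained samples;
// scaled by 0.99, presumably to reserve the final fraction of the progress
// bar for the model-writing stage.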
double training_step = 0.99 / (res + burn);
TrainingStopper training_stopper;
Logger logger;
logger.info("AnalysisGraph::run_train_model");
ModelStatus ms(this->id);
ms.enter_working_state();
this->trained = false;
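// Train over the full observed state sequence unless the caller supplies an
// explicit (non-negative) number of training timesteps.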
if (train_timesteps < 0) {
this->n_timesteps = this->observed_state_sequence.size();
}
else {
this->n_timesteps = train_timesteps;
}
unordered_set<int> train_vertices =
unordered_set<int>
(this->node_indices().begin(), this->node_indices().end());
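// Concepts whose derivatives are supplied externally are recorded and
// excluded from the set of vertices sampled during training.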
for (const auto & [ concept, deriv_func ] : ext_concepts) {
try {
int vert_id = this->name_to_vertex.at(concept);
this->external_concepts[vert_id] = deriv_func;
train_vertices.erase(vert_id);
}
catch (const std::out_of_range& oor) {
cout << "\nERROR: train_model - Concept " << concept << " is not in CAG!\n";
}
}
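// Apply per-concept overrides: seasonal period, centering measure, head node
// model, and optional minimum/maximum value bounds. Concepts not present in
// the CAG are reported and skipped.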
for (const auto & [ concept, period ] : concept_periods) {
try {
int vert_id = this->name_to_vertex.at(concept);
Node &n = (*this)[vert_id];
n.period = period;
}
catch (const std::out_of_range& oor) {
cout << "\nERROR: train_model - Concept " << concept << " is not in CAG!\n";
}
}
for (const auto & [ concept, center_measure ] : concept_center_measures) {
try {
int vert_id = this->name_to_vertex.at(concept);
Node &n = (*this)[vert_id];
n.center_measure = center_measure;
}
catch (const std::out_of_range& oor) {
cout << "\nERROR: train_model - Concept " << concept << " is not in CAG!\n";
}
}
for (const auto & [ concept, model ] : concept_models) {
try {
int vert_id = this->name_to_vertex.at(concept);
Node &n = (*this)[vert_id];
n.model = model;
}
catch (const std::out_of_range& oor) {
cout << "\nERROR: train_model - Concept " << concept << " is not in CAG!\n";
}
}
for (const auto & [ concept, min_val ] : concept_min_vals) {
try {
int vert_id = this->name_to_vertex.at(concept);
Node &n = (*this)[vert_id];
n.min_val = min_val;
n.has_min = true;
}
catch (const std::out_of_range& oor) {
cout << "\nERROR: train_model - Concept " << concept << " is not in CAG!\n";
}
}
for (const auto & [ concept, max_val ] : concept_max_vals) {
try {
int vert_id = this->name_to_vertex.at(concept);
Node &n = (*this)[vert_id];
n.max_val = max_val;
n.has_max = true;
}
catch (const std::out_of_range& oor) {
cout << "\nERROR: train_model - Concept " << concept << " is not in CAG!\n";
}
}
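// Build the pool of concepts sampled during MCMC. Head nodes with
// period == 1 are demoted to body nodes so they are not modeled seasonally.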
this->head_node_model = head_node_model;
this->concept_sample_pool.clear();
for (int vert : train_vertices) {
Node& n = (*this)[vert];
if (this->head_nodes.find(vert) == this->head_nodes.end()) {
this->concept_sample_pool.push_back(vert);
} else if (n.period == 1) {
// Head nodes with period > 1 are modeled using seasonality.
// period == 1 => this is not a seasonal node.
// To prevent it from being modeled as seasonal, remove it from the head nodes.
// TODO: There is a terminology confusion since this is still a head node,
//       but we are removing it from head_nodes to prevent it from being
//       modeled seasonally.
this->concept_sample_pool.push_back(vert);
this->head_nodes.erase(vert);
this->body_nodes.insert(vert);
}
}
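// Only edges whose theta is not frozen take part in MCMC sampling; any
// previously sampled thetas are discarded.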
this->edge_sample_pool.clear();
for (EdgeDescriptor ed : this->edges()) {
this->graph[ed].sampled_thetas.clear();
if (!this->graph[ed].is_frozen()) {
this->edge_sample_pool.push_back(ed);
}
}
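// Initialize sampler parameters and allocate the log likelihood trace for
// every burn-in and retained sample.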
this->initialize_parameters(res, initial_beta, initial_derivative,
use_heuristic, use_continuous);
this->log_likelihoods.clear();
this->log_likelihoods = vector<double>(burn + this->res, 0);
this->MAP_sample_number = -1;
#ifdef TIME
this->create_mcmc_part_timing_file();
// int n_nodes = this->num_nodes();
// int n_edges = this->num_edges();
// string filename = string("mcmc_timing_embeded_") +
// to_string(n_nodes) + "-" +
// to_string(n_edges) + "_" +
// delphi::utils::get_timestamp() + ".csv";
// this->writer = CSVWriter(filename);
// vector<string> headings = {"Nodes", "Edges", "Wall Clock Time (ns)", "CPU Time (ns)", "Sample Type"};
// writer.write_row(headings.begin(), headings.end());
// cout << filename << endl;
#endif
string text = "Burning " + to_string(burn) + " samples out...";
logger.info(" " + text);
logger.info(" # log_likelihood");
// cout << "\n" << text << endl;
for (int i : trange(burn)) {
ms.increment_progress(training_step);
{
#ifdef TIME
// durations.first.clear();
durations.second.clear();
durations.second.push_back(this->timing_run_number);
// durations.first.push_back("Nodes");
durations.second.push_back(this->num_nodes());
// durations.first.push_back("Edges");
durations.second.push_back(this->num_edges());
Timer t = Timer("train", durations);
#endif
this->sample_from_posterior();
}
#ifdef TIME
// durations.first.push_back("sample type");
durations.second.push_back(this->coin_flip < this->coin_flip_thresh? 1 : 0);
writer.write_row(durations.second.begin(), durations.second.end());
#endif
this->log_likelihoods[i] = this->log_likelihood;
char buf[200];
sprintf(buf, "%4d %.10f", i, this->log_likelihood);
logger.info(" " + string(buf));
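// Track the maximum a posteriori (MAP) sample seen during burn-in; it is
// provisionally stored in the last slot of the sample collections.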
if (this->log_likelihood > this->log_likelihood_MAP) {
this->log_likelihood_MAP = this->log_likelihood;
this->transition_matrix_collection[this->res - 1] = this->A_original;
this->initial_latent_state_collection[this->res - 1] = this->s0;
this->log_likelihoods[burn + this->res - 1] = this->log_likelihood;
this->MAP_sample_number = this->res - 1;
}
if(training_stopper.stop_training(this->log_likelihoods, i)) {
string text = "Model training stopped early at sample " + to_string(i);
logger.info(" " + text);
cout << text << endl;
break;
}
}
cout << "\nSampling " << this->res << " samples from posterior..." << endl;
for (int i : trange(this->res - 1)) {
{
ms.increment_progress(training_step);
#ifdef TIME
// durations.first.clear();
durations.second.clear();
durations.second.push_back(this->timing_run_number);
// durations.first.push_back("Nodes");
durations.second.push_back(this->num_nodes());
// durations.first.push_back("Edges");
durations.second.push_back(this->num_edges());
Timer t = Timer("Train", durations);
#endif
this->sample_from_posterior();
}
#ifdef TIME
// durations.first.push_back("Sample Type");
durations.second.push_back(this->coin_flip < this->coin_flip_thresh? 1 : 0);
writer.write_row(durations.second.begin(), durations.second.end());
#endif
this->transition_matrix_collection[i] = this->A_original;
this->initial_latent_state_collection[i] = this->s0;
if (this->log_likelihood > this->log_likelihood_MAP) {
this->log_likelihood_MAP = this->log_likelihood;
this->MAP_sample_number = i;
}
for (auto e : this->edges()) {
this->graph[e].sampled_thetas.push_back(this->graph[e].get_theta());
}
this->log_likelihoods[burn + i] = this->log_likelihood;
/*
this->latent_mean_collection[i] = vector<double>(num_verts);
this->latent_std_collection[i] = vector<double>(num_verts);
this->latent_mean_std_collection[i] = vector<
unordered_map<int, pair<double, double>>>(num_verts);
for (int v : this->node_indices()) {
Node &n = (*this)[v];
this->latent_mean_collection[i][v] = n.mean;
this->latent_std_collection[i][v] = n.std;
this->latent_mean_std_collection[i][v] = n.partition_mean_std;
}
*/
}
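// Draw one more posterior sample for the final collection slot and update
// the MAP bookkeeping if it improves the best log likelihood seen so far.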
if (this->MAP_sample_number < int(this->res)) {
this->sample_from_posterior();
this->transition_matrix_collection[this->res - 1] = this->A_original;
this->initial_latent_state_collection[this->res - 1] = this->s0;
this->log_likelihoods[burn + this->res - 1] = this->log_likelihood;
if ((this->log_likelihood > this->log_likelihood_MAP) or (this->log_likelihood_MAP == -1)) {
this->log_likelihood_MAP = this->log_likelihood;
this->MAP_sample_number = this->res - 1;
}
for (auto e : this->edges()) {
this->graph[e].sampled_thetas.push_back(this->graph[e].get_theta());
}
this->log_likelihoods[burn + this->res - 1] = this->log_likelihood;
} else {
this->MAP_sample_number = this->res - 1;
}
this->trained = true;
ms.enter_writing_state();
RNG::release_instance();
}
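// A minimal training loop: burn in, then retain res posterior samples, with
// no progress reporting, early stopping, or MAP tracking.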
void AnalysisGraph::run_train_model_2(int res,
int burn,
InitialBeta initial_beta,
InitialDerivative initial_derivative,
bool use_heuristic,
bool use_continuous
) {
this->initialize_parameters(res, initial_beta, initial_derivative,
use_heuristic, use_continuous);
cout << "\nBurning " << burn << " samples out..." << endl;
for (int i : trange(burn)) {
this->sample_from_posterior();
}
cout << "\nSampling " << this->res << " samples from posterior..." << endl;
for (int i : trange(this->res)) {
this->sample_from_posterior();
this->transition_matrix_collection[i] = this->A_original;
this->initial_latent_state_collection[i] = this->s0;
for (auto e : this->edges()) {
this->graph[e].sampled_thetas.push_back(this->graph[e].get_theta());
}
}
this->trained = true;
RNG::release_instance();
}
/*
============================================================================
Private: Get Training Data Sequence
============================================================================
*/
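// Build the observed state sequence by querying delphi.db month by month
// over the training range.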
void AnalysisGraph::set_observed_state_sequence_from_data(string country,
string state,
string county) {
this->observed_state_sequence.clear();
// Access (concept is a vertex in the CAG)
// [ timestep ][ concept ][ indicator ][ observation ]
this->observed_state_sequence = ObservedStateSequence(this->n_timesteps);
int year = this->training_range.first.first;
int month = this->training_range.first.second;
for (int ts = 0; ts < this->n_timesteps; ts++) {
this->observed_state_sequence[ts] =
get_observed_state_from_data(year, month, country, state, county);
if (month == 12) {
year++;
month = 1;
}
else {
month++;
}
}
}
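// Gather, for a single (year, month), the observations of every indicator
// attached to each concept in the CAG.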
vector<vector<vector<double>>> AnalysisGraph::get_observed_state_from_data(
int year, int month, string country, string state, string county) {
using ranges::to;
using ranges::views::transform;
int num_verts = this->num_vertices();
// Access (concept is a vertex in the CAG)
// [ concept ][ indicator ][ observation ]
vector<vector<vector<double>>> observed_state(num_verts);
for (int v = 0; v < num_verts; v++) {
vector<Indicator>& indicators = (*this)[v].indicators;
for (auto& ind : indicators) {
vector<double> vals = get_observations_for(ind.get_name(),
country,
state,
county,
year,
month,
ind.get_unit(),
this->data_heuristic);
observed_state[v].push_back(vals);
}
}
return observed_state;
}