// main.cpp
#include <stdio.h>

#include <algorithm>
#include <fstream>
#include <iostream>
#include <iterator>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

#include <boost/program_options.hpp>
#include <boost/algorithm/string.hpp>
#include <boost/format.hpp>

#include <rapidjson/document.h>

#include <loguru.hpp>

#include "external/taywee_args.hpp"

#include "isalgorithm.hpp"
#include "decision.hpp"
#include "network.hpp"
#include "json_workload.hpp"
#include "pempek_assert.hpp"

#include "algo/easy_bf.hpp"
#include "algo/fcfs.hpp"
#include "algo/rejecter.hpp"
#include "algo/sequencer.hpp"
#include "algo/carbone.hpp"


using namespace std;
using namespace boost;

namespace n = network;
namespace r = rapidjson;

/**
 * @brief Runs the scheduler's message loop until a SIMULATION_ENDS event is received.
 *
 * Each iteration reads one message from the network, forwards every event it
 * contains to the scheduling algorithm, optionally asks the algorithm to make
 * decisions, and writes the content of the decision object back to the network.
 *
 * @param[in,out] n The network endpoint messages are read from and written to
 * @param[in,out] algo The scheduling algorithm that receives the events
 * @param[in,out] d The decision object; cleared for each message and serialized as the reply
 * @param[in,out] workload The workload that submitted jobs are added to
 * @param[in] call_make_decisions_on_single_nop If false, make_decisions is skipped when
 *            the message contains a single REQUESTED_CALL event
 * @throws std::runtime_error on an empty message or an unknown event/NOTIFY type
 */
void run(Network & n, ISchedulingAlgorithm * algo, SchedulingDecision &d,
         Workload &workload, bool call_make_decisions_on_single_nop = true);

/** @def STR_HELPER(x)
 *  @brief Helper macro that turns its argument, exactly as written, into a string literal.
 *
 *  The argument is stringified without being macro-expanded first.
 */
#define STR_HELPER(x) #x

/** @def STR(x)
 *  @brief Macro to get a string literal (usable as a const char*) from the expansion of a macro.
 *
 *  The extra level of indirection through STR_HELPER lets x be macro-expanded
 *  before it is stringified.
 */
#define STR(x) STR_HELPER(x)

/** @def BATSCHED_VERSION
 *  @brief What batsched --version should return.
 *
 *  It is either set by CMake or set to vUNKNOWN_PLEASE_COMPILE_VIA_CMAKE.
 *  The fallback value is a bare token, not a string literal: it is converted
 *  to a string with STR() where the version is printed.
**/
#ifndef BATSCHED_VERSION
    #define BATSCHED_VERSION vUNKNOWN_PLEASE_COMPILE_VIA_CMAKE
#endif

int main(int argc, char ** argv)
{
    const set<string> variants_set = {"easy_bf", "fcfs", "rejecter", "sequencer", "carbone"}; 
    const set<string> policies_set = {"basic", "contiguous"};
    const set<string> queue_orders_set = {"fcfs", "lcfs", "desc_bounded_slowdown", "desc_slowdown",
                                          "asc_size", "desc_size", "asc_walltime", "desc_walltime"};
    const set<string> verbosity_levels_set = {"debug", "info", "quiet", "silent"};

    const string variants_string = "{" + boost::algorithm::join(variants_set, ", ") + "}";
    const string policies_string = "{" + boost::algorithm::join(policies_set, ", ") + "}";
    const string queue_orders_string = "{" + boost::algorithm::join(queue_orders_set, ", ") + "}";
    const string verbosity_levels_string = "{" + boost::algorithm::join(verbosity_levels_set, ", ") + "}";

    ISchedulingAlgorithm * algo = nullptr;
    ResourceSelector * selector = nullptr;
    Queue * queue = nullptr;
    SortableJobOrder * order = nullptr;

    args::ArgumentParser parser("A Batsim-compatible scheduler in C++.");
    args::HelpFlag flag_help(parser, "help", "Display this help menu", {'h', "help"});
    args::CompletionFlag completion(parser, {"complete"});

    args::ValueFlag<double> flag_rjms_delay(parser, "delay", "Sets the expected time that the RJMS takes to do some things like killing a job", {'d', "rjms_delay"}, 5.0);
    args::ValueFlag<string> flag_selection_policy(parser, "policy", "Sets the resource selection policy. Available values are " + policies_string, {'p', "policy"}, "basic");
    args::ValueFlag<string> flag_socket_endpoint(parser, "endpoint", "Sets the socket endpoint.", {'s', "socket-endpoint"}, "tcp://*:28000");
    args::ValueFlag<string> flag_scheduling_variant(parser, "variant", "Sets the scheduling variant. Available values are " + variants_string, {'v', "variant"}, "rejecter");
    args::ValueFlag<string> flag_variant_options(parser, "options", "Sets the scheduling variant options. Must be formatted as a JSON object.", {"variant_options"}, "{}");
    args::ValueFlag<string> flag_variant_options_filepath(parser, "options-filepath", "Sets the scheduling variant options as the content of the given filepath. Overrides the variant_options options.", {"variant_options_filepath"}, "");
    args::ValueFlag<string> flag_queue_order(parser, "order", "Sets the queue order. Available values are " + queue_orders_string, {'o', "queue_order"}, "fcfs");
    args::ValueFlag<string> flag_verbosity_level(parser, "verbosity-level", "Sets the verbosity level. Available values are " + verbosity_levels_string, {"verbosity"}, "info");
    args::ValueFlag<bool> flag_call_make_decisions_on_single_nop(parser, "flag", "If set to true, make_decisions will be called after single NOP messages.", {"call_make_decisions_on_single_nop"}, true);
    args::Flag flag_version(parser, "version", "Shows batsched version", {"version"});

    try
    {
        parser.ParseCLI(argc, argv);

        if (flag_rjms_delay.Get() < 0)
            throw args::ValidationError(str(format("Invalid '%1%' parameter value (%2%): Must be non-negative.")
                                            % flag_rjms_delay.Name()
                                            % flag_rjms_delay.Get()));

        if (queue_orders_set.find(flag_queue_order.Get()) == queue_orders_set.end())
            throw args::ValidationError(str(format("Invalid '%1%' value (%2%): Not in %3%")
                                            % flag_queue_order.Name()
                                            % flag_queue_order.Get()
                                            % queue_orders_string));

        if (variants_set.find(flag_scheduling_variant.Get()) == variants_set.end())
            throw args::ValidationError(str(format("Invalid '%1%' value (%2%): Not in %3%")
                                            % flag_scheduling_variant.Name()
                                            % flag_scheduling_variant.Get()
                                            % variants_string));

        if (verbosity_levels_set.find(flag_verbosity_level.Get()) == verbosity_levels_set.end())
            throw args::ValidationError(str(format("Invalid '%1%' value (%2%): Not in %3%")
                                            % flag_verbosity_level.Name()
                                            % flag_verbosity_level.Get()
                                            % verbosity_levels_string));
    }
    catch(args::Help&)
    {
        parser.helpParams.addDefault = true;
        printf("%s", parser.Help().c_str());
        return 0;
    }
    catch (args::Completion & e)
    {
        printf("%s", e.what());
        return 0;
    }
    catch(args::ParseError & e)
    {
        printf("%s\n", e.what());
        return 1;
    }
    catch(args::ValidationError & e)
    {
        printf("%s\n", e.what());
        return 1;
    }

    if (flag_version)
    {
        printf("%s\n", STR(BATSCHED_VERSION));
        return 0;
    }

    string socket_endpoint = flag_socket_endpoint.Get();
    string scheduling_variant = flag_scheduling_variant.Get();
    string selection_policy = flag_selection_policy.Get();
    string queue_order = flag_queue_order.Get();
    string variant_options = flag_variant_options.Get();
    string variant_options_filepath = flag_variant_options_filepath.Get();
    string verbosity_level = flag_verbosity_level.Get();
    double rjms_delay = flag_rjms_delay.Get();
    bool call_make_decisions_on_single_nop = flag_call_make_decisions_on_single_nop.Get();

    try
    {
        // Logging configuration
        if (verbosity_level == "debug")
            loguru::g_stderr_verbosity = loguru::Verbosity_1;
        else if (verbosity_level == "quiet")
            loguru::g_stderr_verbosity = loguru::Verbosity_WARNING;
        else if (verbosity_level == "silent")
            loguru::g_stderr_verbosity = loguru::Verbosity_OFF;
        else
            loguru::g_stderr_verbosity = loguru::Verbosity_INFO;

        // Workload creation
        Workload w;
        w.set_rjms_delay(rjms_delay);

        // Scheduling parameters
        SchedulingDecision decision;

        // Queue order
        if (queue_order == "fcfs")
            order = new FCFSOrder;
        else if (queue_order == "lcfs")
            order = new LCFSOrder;
        else if (queue_order == "desc_bounded_slowdown")
            order = new DescendingBoundedSlowdownOrder(1);
        else if (queue_order == "desc_slowdown")
            order = new DescendingSlowdownOrder;
        else if (queue_order == "asc_size")
            order = new AscendingSizeOrder;
        else if (queue_order == "desc_size")
            order = new DescendingSizeOrder;
        else if (queue_order == "asc_walltime")
            order = new AscendingWalltimeOrder;
        else if (queue_order == "desc_walltime")
            order = new DescendingWalltimeOrder;

        queue = new Queue(order);

        // Resource selector
        if (selection_policy == "basic")
            selector = new BasicResourceSelector;
        else if (selection_policy == "contiguous")
            selector = new ContiguousResourceSelector;
        else
        {
            printf("Invalid resource selection policy '%s'. Available options are %s\n", selection_policy.c_str(), policies_string.c_str());
            return 1;
        }

        // Scheduling variant options
        if (!variant_options_filepath.empty())
        {
            ifstream variants_options_file(variant_options_filepath);

            if (variants_options_file.is_open())
            {
                // Let's put the whole file content into one string
                variants_options_file.seekg(0, ios::end);
                variant_options.reserve(variants_options_file.tellg());
                variants_options_file.seekg(0, ios::beg);

                variant_options.assign((std::istreambuf_iterator<char>(variants_options_file)),
                                        std::istreambuf_iterator<char>());
            }
            else
            {
                printf("Couldn't open variants options file '%s'. Aborting.\n", variant_options_filepath.c_str());
                return 1;
            }
        }

        rapidjson::Document json_doc_variant_options;
        json_doc_variant_options.Parse(variant_options.c_str());
        if (!json_doc_variant_options.IsObject())
        {
            printf("Invalid variant options: Not a JSON object. variant_options='%s'\n", variant_options.c_str());
            return 1;
        }
        LOG_F(1, "variant_options = '%s'", variant_options.c_str());

        // Scheduling variant
        if (scheduling_variant == "easy_bf")
            algo = new EasyBackfilling(&w, &decision, queue, selector, rjms_delay, &json_doc_variant_options);
        else if (scheduling_variant == "fcfs")
            algo = new FCFS(&w, &decision, queue, selector, rjms_delay, &json_doc_variant_options);
        else if (scheduling_variant == "rejecter")
            algo = new Rejecter(&w, &decision, queue, selector, rjms_delay, &json_doc_variant_options);
        else if (scheduling_variant == "sequencer")
            algo = new Sequencer(&w, &decision, queue, selector, rjms_delay, &json_doc_variant_options);
        else if (scheduling_variant == "carbone")
            algo = new CarboneAlgorithm(&w, &decision, queue, selector, rjms_delay, &json_doc_variant_options);
        // Network
        Network n;
        n.bind(socket_endpoint);

        // Run the simulation
        run(n, algo, decision, w, call_make_decisions_on_single_nop);
    }
    catch(const std::exception & e)
    {
        string what = e.what();

        if (what == "Connection lost")
        {
            LOG_F(ERROR, "%s", what.c_str());
        }
        else
        {
            LOG_F(ERROR, "%s", what.c_str());

            delete queue;
            delete order;

            delete algo;
            delete selector;

            throw;
        }
    }

    delete queue;
    delete order;

    delete algo;
    delete selector;

    return 0;
}

/**
 * @brief Runs the scheduler's message loop until a SIMULATION_ENDS event is received.
 *
 * For each message read from the network, the decision object is cleared, every
 * event of the message is forwarded to the matching callback of the algorithm,
 * make_decisions is called (unless disabled for single REQUESTED_CALL messages)
 * and the content of the decision object is written back to the network.
 *
 * @param[in,out] n The network endpoint messages are read from and written to
 * @param[in,out] algo The scheduling algorithm that receives the events
 * @param[in,out] d The decision object; cleared for each message and serialized as the reply
 * @param[in,out] workload The workload that submitted (or queried) jobs are added to
 * @param[in] call_make_decisions_on_single_nop If false, make_decisions is skipped when
 *            the message contains a single REQUESTED_CALL event
 * @throws std::runtime_error if the message is empty, is not a JSON object holding a
 *         numeric "now" and an "events" array, or contains an unknown event or
 *         NOTIFY type
 */
void run(Network & n, ISchedulingAlgorithm * algo, SchedulingDecision & d,
         Workload & workload, bool call_make_decisions_on_single_nop)
{
    bool simulation_finished = false;

    while (!simulation_finished)
    {
        string received_message;
        n.read(received_message);

        if (boost::trim_copy(received_message).empty())
            throw runtime_error("Empty message received (connection lost ?)");

        d.clear();

        r::Document doc;
        doc.Parse(received_message.c_str());

        // Accessing members of a document that failed to parse (or that lacks
        // the expected fields) would trigger a rapidjson assertion: fail cleanly.
        if (doc.HasParseError() || !doc.IsObject() ||
            !doc.HasMember("now") || !doc["now"].IsNumber() ||
            !doc.HasMember("events") || !doc["events"].IsArray())
        {
            throw runtime_error("Invalid message received: expected a JSON object with a numeric 'now' and an 'events' array");
        }

        double message_date = doc["now"].GetDouble();
        double current_date = message_date;
        bool requested_callback_received = false;

        // Let's handle all received events
        const r::Value & events_array = doc["events"];

        for (unsigned int event_i = 0; event_i < events_array.Size(); ++event_i)
        {
            const r::Value & event_object = events_array[event_i];
            const std::string event_type = event_object["type"].GetString();
            current_date = event_object["timestamp"].GetDouble();
            const r::Value & event_data = event_object["data"];

            if (event_type == "SIMULATION_BEGINS")
            {
                int nb_resources;
                // Do this for retrocompatibility with batsim 2 API
                if (event_data.HasMember("nb_compute_resources"))
                {
                    nb_resources = event_data["nb_compute_resources"].GetInt();
                }
                else
                {
                    nb_resources = event_data["nb_resources"].GetInt();
                }

                algo->set_nb_machines(nb_resources);
                algo->on_simulation_start(current_date, event_data["config"]);
            }
            else if (event_type == "SIMULATION_ENDS")
            {
                algo->on_simulation_end(current_date);
                // The remaining events of this message are still handled and a
                // reply is still sent before the loop stops.
                simulation_finished = true;
            }
            else if (event_type == "JOB_SUBMITTED")
            {
                string job_id = event_data["job_id"].GetString();

                workload.add_job_from_json_object(event_data["job"], job_id, current_date);

                algo->on_job_release(current_date, {job_id});
            }
            else if (event_type == "JOB_COMPLETED")
            {
                string job_id = event_data["job_id"].GetString();
                workload[job_id]->completion_time = current_date;
                algo->on_job_end(current_date, {job_id});
            }
            else if (event_type == "RESOURCE_STATE_CHANGED")
            {
                IntervalSet resources = IntervalSet::from_string_hyphen(event_data["resources"].GetString(), " ");
                string new_state = event_data["state"].GetString();
                algo->on_machine_state_changed(current_date, resources, std::stoi(new_state));
            }
            else if (event_type == "JOB_KILLED")
            {
                const r::Value & job_ids_map = event_data["job_progress"];
                PPK_ASSERT_ERROR(job_ids_map.GetType() == r::kObjectType);

                // The killed job identifiers are the keys of the job_progress object
                vector<string> job_ids;

                for (auto itr = job_ids_map.MemberBegin(); itr != job_ids_map.MemberEnd(); ++itr)
                {
                    string job_id = itr->name.GetString();
                    job_ids.push_back(job_id);
                }

                algo->on_job_killed(current_date, job_ids);
            }
            else if (event_type == "REQUESTED_CALL")
            {
                requested_callback_received = true;
                algo->on_requested_call(current_date);
            }
            else if (event_type == "ANSWER")
            {
                for (auto itr = event_data.MemberBegin(); itr != event_data.MemberEnd(); ++itr)
                {
                    string key_value = itr->name.GetString();

                    if (key_value == "consumed_energy")
                    {
                        double consumed_joules = itr->value.GetDouble();
                        algo->on_answer_energy_consumption(current_date, consumed_joules);
                    }
                    else
                    {
                        PPK_ASSERT_ERROR(false, "Unknown ANSWER type received '%s'", key_value.c_str());
                    }
                }
            }
            else if (event_type == "QUERY")
            {
                const r::Value & requests = event_data["requests"];

                for (auto itr = requests.MemberBegin(); itr != requests.MemberEnd(); ++itr)
                {
                    string key_value = itr->name.GetString();

                    if (key_value == "estimate_waiting_time")
                    {
                        const r::Value & request_object = itr->value;
                        string job_id = request_object["job_id"].GetString();
                        workload.add_job_from_json_object(request_object["job"], job_id, current_date);

                        algo->on_query_estimate_waiting_time(current_date, job_id);
                    }
                    else
                    {
                        PPK_ASSERT_ERROR(false, "Unknown QUERY type received '%s'", key_value.c_str());
                    }
                }
            }
            else if (event_type == "NOTIFY")
            {
                string notify_type = event_data["type"].GetString();

                if (notify_type == "no_more_static_job_to_submit")
                {
                    algo->on_no_more_static_job_to_submit_received(current_date);
                }
                else if (notify_type == "no_more_external_event_to_occur")
                {
                    algo->on_no_more_external_event_to_occur(current_date);
                }
                else if (notify_type == "event_machine_available")
                {
                    IntervalSet resources = IntervalSet::from_string_hyphen(event_data["resources"].GetString(), " ");
                    algo->on_machine_available_notify_event(current_date, resources);
                }
                else if (notify_type == "event_machine_unavailable")
                {
                    IntervalSet resources = IntervalSet::from_string_hyphen(event_data["resources"].GetString(), " ");
                    algo->on_machine_unavailable_notify_event(current_date, resources);
                }
                else
                {
                    throw runtime_error("Unknown NOTIFY type received. Type = " + notify_type);
                }
            }
            else if (event_type == "carbon_co2")
            {
                // Each value stays at 0 when its field is absent from the event data
                int val_dcA = 0, val_dcB = 0, val_dcC = 0;

                if (event_data.HasMember("some_field_dcA"))
                {
                    val_dcA = event_data["some_field_dcA"].GetInt();
                }
                if (event_data.HasMember("another_field_dcB"))
                {
                    val_dcB = event_data["another_field_dcB"].GetInt();
                }
                if (event_data.HasMember("additional_field_dcC"))
                {
                    val_dcC = event_data["additional_field_dcC"].GetInt();
                }

                // Only CarboneAlgorithm handles this event; with any other
                // algorithm the event is reported on stderr and ignored.
                CarboneAlgorithm * carboneAlgo = dynamic_cast<CarboneAlgorithm *>(algo);
                if (carboneAlgo != nullptr)
                {
                    carboneAlgo->updateCarbonFootprint(val_dcA, val_dcB, val_dcC);
                }
                else
                {
                    std::cerr << "Erreur : l'algorithme en cours n'est pas une instance de CarboneAlgorithm." << std::endl;
                }
            }
            else
            {
                throw runtime_error("Unknown event received. Type = " + event_type);
            }
        }

        bool requested_callback_only = requested_callback_received && (events_array.Size() == 1);

        // make_decisions is skipped only when it has been disabled for single
        // NOP messages and the message held nothing but one REQUESTED_CALL.
        if (call_make_decisions_on_single_nop || !requested_callback_only)
        {
            SortableJobOrder::UpdateInformation update_info(current_date);
            algo->make_decisions(message_date, &update_info, nullptr);
            algo->clear_recent_data_structures();
        }

        // The reply is dated at the message date or at the last decision date, whichever is later
        message_date = max(message_date, d.last_date());

        const string & message_to_send = d.content(message_date);
        n.write(message_to_send);
    }
}