// main.cpp
#include <fstream>
#include <set>
#include <stdio.h>
#include <vector>

#include <boost/algorithm/string.hpp>
#include <boost/format.hpp>
#include <boost/program_options.hpp>

#include <rapidjson/document.h>

#include <loguru.hpp>

#include "external/taywee_args.hpp"

#include "decision.hpp"
#include "isalgorithm.hpp"
#include "json_workload.hpp"
#include "network.hpp"
#include "pempek_assert.hpp"

#include "queue.hpp"
#include "scheds/easy_bf_fast.hpp"
#include "scheds/fcfs.hpp"
#include "scheds/bin_packing.hpp"
#include "scheds/bin_packing_energy.hpp"
#include "scheds/multicore_filler.hpp"

// File-wide using-directives: names from std and boost (string, set, format, str, ...)
// are used unqualified throughout this file.
using namespace std;
using namespace boost;

// Short namespace aliases. `r` is used below for the RapidJSON types (r::Document, r::Value).
namespace n = network;
namespace r = rapidjson;

/**
 * @brief Runs the message loop between the scheduler and its peer until the simulation ends.
 *
 * Forward declaration; the definition is at the bottom of this file.
 *
 * @param n Network endpoint used to read incoming messages and write the answers.
 * @param algo Scheduling algorithm whose event callbacks and make_decisions are invoked.
 * @param d Decision object, cleared before each message and serialized into each answer.
 * @param workload Workload into which submitted jobs are inserted.
 * @param call_make_decisions_on_single_nop When false, make_decisions is skipped for messages
 *        whose only event is a REQUESTED_CALL. Defaults to true.
 */
void run(Network &n, ISchedulingAlgorithm *algo, SchedulingDecision &d, Workload &workload,
    bool call_make_decisions_on_single_nop = true);

/** @def STR_HELPER(x)
 *  @brief Turns its argument into a string literal, exactly as written.
 *
 *  Uses the preprocessor stringification operator (#), which does NOT macro-expand
 *  its operand. It is the second stage of STR(x) and is rarely useful on its own.
 */
#define STR_HELPER(x) #x

/** @def STR(x)
 *  @brief Expands the macro @p x, then turns the expanded text into a string literal.
 *
 *  The extra level of indirection through STR_HELPER is what allows @p x to be
 *  macro-expanded before it is stringified.
 */
#define STR(x) STR_HELPER(x)


/**
 * @brief Entry point: parses the command line, builds the scheduler and runs it.
 *
 * Steps, in order:
 *  1. Declare and parse the command-line flags, then validate their values.
 *  2. Configure the logging verbosity.
 *  3. Create the workload, the resource selector, the variant options (from the
 *     command line or from a file), the job queue and the scheduling algorithm.
 *  4. Bind the network socket and hand control to run().
 *
 * @param argc Number of command-line arguments.
 * @param argv Command-line arguments.
 * @return 0 when help was displayed, when run() returned normally, or when an exception
 *         whose message is exactly "Connection lost" was caught;
 *         1 on a command-line parse/validation error, an unknown selection policy, an
 *         unreadable variant options file, or variant options that are not a JSON object.
 * @throws Any other std::exception raised during setup or by run() is logged and rethrown
 *         after the owned objects have been released.
 */
int main(int argc, char **argv)
{
    const set<string> variants_set = { "fcfs", "easy_bf", "multicore_filler", "bin_packing", "bin_packing_energy" };
    const set<string> policies_set = { "basic", "contiguous" };
    const set<string> verbosity_levels_set = { "debug", "info", "quiet", "silent" };

    // Human-readable "{a, b, c}" lists, used in help texts and error messages
    const string variants_string = "{" + boost::algorithm::join(variants_set, ", ") + "}";
    const string policies_string = "{" + boost::algorithm::join(policies_set, ", ") + "}";
    const string verbosity_levels_string = "{" + boost::algorithm::join(verbosity_levels_set, ", ") + "}";

    ISchedulingAlgorithm *algo = nullptr;
    ResourceSelector *selector = nullptr;
    Queue *queue = nullptr;
    SortableJobOrder *order = nullptr;

    // Single place that frees everything main() owns, so that every exit path
    // (normal end, early return, rethrow) releases the same objects in the same
    // order: queue, order, algo, selector. Pointers are reset afterwards, which
    // makes a second call harmless.
    auto release_resources = [&]()
    {
        delete queue;
        delete order;

        delete algo;
        delete selector;

        queue = nullptr;
        order = nullptr;
        algo = nullptr;
        selector = nullptr;
    };

    args::ArgumentParser parser("Custom fork of batsched, a Batsim-compatible scheduler in C++.");
    args::HelpFlag flag_help(parser, "help", "Display this help menu", { 'h', "help" });

    args::ValueFlag<double> flag_rjms_delay(parser, "delay",
        "Sets the expected time that the RJMS takes to do some things like killing a job", { 'd', "rjms_delay" }, 5.0);
    args::ValueFlag<string> flag_selection_policy(parser, "policy",
        "Sets the resource selection policy. Available values are " + policies_string, { 'p', "policy" }, "basic");
    args::ValueFlag<string> flag_socket_endpoint(
        parser, "endpoint", "Sets the socket endpoint.", { 's', "socket-endpoint" }, "tcp://*:28000");
    args::ValueFlag<string> flag_scheduling_variant(parser, "variant",
        "Sets the scheduling variant. Available values are " + variants_string, { 'v', "variant" }, "multicore_filler");
    args::ValueFlag<string> flag_variant_options(parser, "options",
        "Sets the scheduling variant options. Must be formatted as a JSON object.", { "variant_options" }, "{}");
    args::ValueFlag<string> flag_variant_options_filepath(parser, "options-filepath",
        "Sets the scheduling variant options as the content of the given filepath. Overrides the variant_options "
        "options.", { "variant_options_filepath" }, "");
    args::ValueFlag<int> flag_nb_cores_per_machine(parser, "nb_cores_per_machine",
        "Sets the number of cores per machine in the plateform (warning: must match the platform file)",
        { "nb_cores" }, DEFAULT_NB_CORE_PER_MACHINE);
    args::ValueFlag<string> flag_verbosity_level(parser, "verbosity-level",
        "Sets the verbosity level. Available values are " + verbosity_levels_string, { "verbosity" }, "info");
    args::ValueFlag<bool> flag_call_make_decisions_on_single_nop(parser, "flag",
        "If set to true, make_decisions will be called after single NOP messages.",
        { "call_make_decisions_on_single_nop" }, true);

    // Nothing is allocated yet at this point, so the returns below need no cleanup.
    try
    {
        parser.ParseCLI(argc, argv);

        if (flag_rjms_delay.Get() < 0)
            throw args::ValidationError(str(format("Invalid '%1%' parameter value (%2%): Must be non-negative.")
                % flag_rjms_delay.Name() % flag_rjms_delay.Get()));

        if (flag_nb_cores_per_machine.Get() < 0)
            throw args::ValidationError(str(format("Invalid '%1%' parameter value (%2%): Must be non-negative.")
                % flag_nb_cores_per_machine.Name() % flag_nb_cores_per_machine.Get()));

        if (variants_set.find(flag_scheduling_variant.Get()) == variants_set.end())
            throw args::ValidationError(str(format("Invalid '%1%' value (%2%): Not in %3%")
                % flag_scheduling_variant.Name() % flag_scheduling_variant.Get() % variants_string));

        if (verbosity_levels_set.find(flag_verbosity_level.Get()) == verbosity_levels_set.end())
            throw args::ValidationError(str(format("Invalid '%1%' value (%2%): Not in %3%")
                % flag_verbosity_level.Name() % flag_verbosity_level.Get() % verbosity_levels_string));
    }
    catch (const args::Help &)
    {
        parser.helpParams.addDefault = true;
        printf("%s", parser.Help().c_str());
        return 0;
    }
    catch (const args::ParseError &e)
    {
        printf("%s\n", e.what());
        return 1;
    }
    catch (const args::ValidationError &e)
    {
        printf("%s\n", e.what());
        return 1;
    }


    string socket_endpoint = flag_socket_endpoint.Get();
    string scheduling_variant = flag_scheduling_variant.Get();
    string selection_policy = flag_selection_policy.Get();
    string variant_options = flag_variant_options.Get();
    string variant_options_filepath = flag_variant_options_filepath.Get();
    string verbosity_level = flag_verbosity_level.Get();
    double rjms_delay = flag_rjms_delay.Get();
    int nb_cores_per_machine = flag_nb_cores_per_machine.Get();
    bool call_make_decisions_on_single_nop = flag_call_make_decisions_on_single_nop.Get();

    try
    {
        // Logging configuration ("info" is the fallback)
        if (verbosity_level == "debug")
            loguru::g_stderr_verbosity = loguru::Verbosity_1;
        else if (verbosity_level == "quiet")
            loguru::g_stderr_verbosity = loguru::Verbosity_WARNING;
        else if (verbosity_level == "silent")
            loguru::g_stderr_verbosity = loguru::Verbosity_OFF;
        else
            loguru::g_stderr_verbosity = loguru::Verbosity_INFO;

        // Workload creation
        Workload w;
        w.set_rjms_delay(rjms_delay);

        // Scheduling parameters
        SchedulingDecision decision;

        // Resource selector
        if (selection_policy == "basic")
            selector = new BasicResourceSelector;
        else if (selection_policy == "contiguous")
            selector = new ContiguousResourceSelector;
        else
        {
            printf("Invalid resource selection policy '%s'. Available options are %s\n", selection_policy.c_str(),
                policies_string.c_str());
            release_resources();
            return 1;
        }

        // Scheduling variant options: the file content, when given, replaces the command-line value
        if (!variant_options_filepath.empty())
        {
            ifstream variants_options_file(variant_options_filepath);

            if (variants_options_file.is_open())
            {
                // Let's put the whole file content into one string.
                // tellg() yields -1 when the position cannot be determined (e.g. the
                // stream is not seekable). Passing that to reserve() would be converted
                // to a huge unsigned size and throw, so only pre-allocate on a valid size.
                variants_options_file.seekg(0, ios::end);
                const std::streamoff file_size = variants_options_file.tellg();
                if (file_size > 0)
                    variant_options.reserve(static_cast<string::size_type>(file_size));
                variants_options_file.seekg(0, ios::beg);

                variant_options.assign(
                    (std::istreambuf_iterator<char>(variants_options_file)), std::istreambuf_iterator<char>());
            }
            else
            {
                printf("Couldn't open variants options file '%s'. Aborting.\n", variant_options_filepath.c_str());
                release_resources(); // the selector has already been allocated
                return 1;
            }
        }

        rapidjson::Document json_doc_variant_options;
        json_doc_variant_options.Parse(variant_options.c_str());
        if (!json_doc_variant_options.IsObject())
        {
            printf("Invalid variant options: Not a JSON object. variant_options='%s'\n", variant_options.c_str());
            release_resources(); // the selector has already been allocated
            return 1;
        }
        LOG_F(1, "variant_options = '%s'", variant_options.c_str());

        // Job order ("fcfs" gets neither an order nor a queue)
        if (scheduling_variant == "easy_bf" || scheduling_variant == "multicore_filler"){
            order = new FCFSOrder;
            queue = new Queue(order);
        }
        if (scheduling_variant == "bin_packing" || scheduling_variant == "bin_packing_energy"){
            order = new DescendingSizeOrder;
            queue = new Queue(order);
        }

        // Scheduling variant
        if (scheduling_variant == "fcfs")
            algo = new FCFS(&w, &decision, nullptr, selector, rjms_delay, &json_doc_variant_options);

        else if (scheduling_variant == "easy_bf")
            algo = new EasyBackfillingFast(&w, &decision, queue, selector, rjms_delay, &json_doc_variant_options);

        else if (scheduling_variant == "multicore_filler")
            algo = new MulticoreFiller(&w, &decision, queue, selector, rjms_delay, &json_doc_variant_options);

        else if (scheduling_variant == "bin_packing")
            algo = new BinPacking(&w, &decision, queue, selector, rjms_delay, &json_doc_variant_options);

        else if (scheduling_variant == "bin_packing_energy")
            algo = new BinPackingEnergy(&w, &decision, queue, selector, rjms_delay, &json_doc_variant_options);

        // Nb_core_per_machine
        // TODO (wait for batsim5.0): retrieve it from the platform file at SIMULATION_BEGIN
        // algo is non-null here: scheduling_variant was validated against variants_set above,
        // and every entry of that set has a branch in the chain above.
        algo->set_nb_cores_per_machine(nb_cores_per_machine);

        // Network
        Network n;
        n.bind(socket_endpoint);

        // Run the simulation
        run(n, algo, decision, w, call_make_decisions_on_single_nop);
    }
    catch (const std::exception &e)
    {
        const string what = e.what();
        LOG_F(ERROR, "%s", what.c_str());

        // "Connection lost" is logged but not propagated: execution falls through to
        // the regular cleanup below and the process exits with code 0.
        // NOTE(review): nothing visible in this file throws exactly this message;
        // presumably it comes from Network - confirm.
        if (what != "Connection lost")
        {
            // Release explicitly before rethrowing: nothing after this point will do it.
            release_resources();
            throw;
        }
    }

    release_resources();

    return 0;
}

/**
 * @brief Runs the message loop until a SIMULATION_ENDS event has been processed.
 *
 * For each iteration: reads one message from the network, parses it as JSON,
 * dispatches every entry of its "events" array to the matching callback of
 * @p algo, optionally calls make_decisions, then writes the content of @p d
 * back on the network.
 *
 * @param n Network endpoint used to read messages and write answers.
 * @param algo Scheduling algorithm receiving the event callbacks.
 * @param d Decision object; cleared at the start of each iteration and serialized at its end.
 * @param workload Workload updated on JOB_SUBMITTED / JOB_COMPLETED / QUERY events.
 * @param call_make_decisions_on_single_nop When false, make_decisions is skipped for messages
 *        whose only event is a REQUESTED_CALL.
 * @throws std::runtime_error if the message is empty, is not valid JSON, lacks the expected
 *         "now"/"events" fields, contains a malformed event, or contains an unknown event
 *         or NOTIFY type.
 */
void run(Network &n, ISchedulingAlgorithm *algo, SchedulingDecision &d, Workload &workload,
    bool call_make_decisions_on_single_nop)
{
    bool simulation_finished = false;

    while (!simulation_finished)
    {
        string received_message;
        n.read(received_message);

        if (boost::trim_copy(received_message).empty())
            throw runtime_error("Empty message received (connection lost ?)");

        d.clear();

        r::Document doc;
        doc.Parse(received_message.c_str());

        // Validate the envelope before indexing into it: accessing a missing member or
        // a value of the wrong type only trips an assertion inside RapidJSON, which is
        // not a usable error report (and is compiled out when assertions are disabled).
        if (doc.HasParseError())
            throw runtime_error("Invalid message received: not valid JSON (error at offset "
                + std::to_string(doc.GetErrorOffset()) + ")");

        if (!doc.IsObject() || !doc.HasMember("now") || !doc["now"].IsNumber()
            || !doc.HasMember("events") || !doc["events"].IsArray())
            throw runtime_error(
                "Invalid message received: expected a JSON object with a numeric 'now' and an 'events' array");

        double message_date = doc["now"].GetDouble();
        // Tracks the timestamp of the event being handled; after the loop it holds the
        // timestamp of the last event (or message_date if there was no event).
        double current_date = message_date;
        bool requested_callback_received = false;

        // Let's handle all received events
        const r::Value &events_array = doc["events"];

        for (unsigned int event_i = 0; event_i < events_array.Size(); ++event_i)
        {
            const r::Value &event_object = events_array[event_i];

            if (!event_object.IsObject()
                || !event_object.HasMember("type") || !event_object["type"].IsString()
                || !event_object.HasMember("timestamp") || !event_object["timestamp"].IsNumber()
                || !event_object.HasMember("data"))
                throw runtime_error(
                    "Invalid event received: expected an object with 'type', 'timestamp' and 'data'");

            const std::string event_type = event_object["type"].GetString();
            current_date = event_object["timestamp"].GetDouble();
            const r::Value &event_data = event_object["data"];

            if (event_type == "SIMULATION_BEGINS")
            {
                int nb_resources;
                // DO this for retrocompatibility with batsim 2 API
                if (event_data.HasMember("nb_compute_resources"))
                {
                    nb_resources = event_data["nb_compute_resources"].GetInt();
                }
                else
                {
                    nb_resources = event_data["nb_resources"].GetInt();
                }

                algo->set_nb_machines(nb_resources);
                algo->on_simulation_start(current_date, event_data["config"]);
            }
            else if (event_type == "SIMULATION_ENDS")
            {
                algo->on_simulation_end(current_date);
                // The remaining events of this message are still processed and the
                // answer is still sent; the outer loop stops afterwards.
                simulation_finished = true;
            }
            else if (event_type == "JOB_SUBMITTED")
            {
                string job_id = event_data["job_id"].GetString();

                workload.add_job_from_json_object(event_data["job"], job_id, current_date);

                algo->on_job_release(current_date, { job_id });
            }
            else if (event_type == "JOB_COMPLETED")
            {
                string job_id = event_data["job_id"].GetString();
                workload[job_id]->completion_time = current_date;
                algo->on_job_end(current_date, { job_id });
            }
            else if (event_type == "RESOURCE_STATE_CHANGED")
            {
                IntervalSet resources = IntervalSet::from_string_hyphen(event_data["resources"].GetString(), " ");
                // The state arrives as a string and is forwarded as an integer
                string new_state = event_data["state"].GetString();
                algo->on_machine_state_changed(current_date, resources, std::stoi(new_state));
            }
            else if (event_type == "JOB_KILLED")
            {
                // Only the keys (job identifiers) of the job_progress object are used
                const r::Value &job_ids_map = event_data["job_progress"];
                PPK_ASSERT_ERROR(job_ids_map.GetType() == r::kObjectType);

                vector<string> job_ids;

                for (auto itr = job_ids_map.MemberBegin(); itr != job_ids_map.MemberEnd(); ++itr)
                {
                    string job_id = itr->name.GetString();
                    job_ids.push_back(job_id);
                }

                algo->on_job_killed(current_date, job_ids);
            }
            else if (event_type == "REQUESTED_CALL")
            {
                requested_callback_received = true;
                algo->on_requested_call(current_date);
            }
            else if (event_type == "ANSWER")
            {
                for (auto itr = event_data.MemberBegin(); itr != event_data.MemberEnd(); ++itr)
                {
                    string key_value = itr->name.GetString();

                    if (key_value == "consumed_energy")
                    {
                        double consumed_joules = itr->value.GetDouble();
                        algo->on_answer_energy_consumption(current_date, consumed_joules);
                    }
                    else
                    {
                        PPK_ASSERT_ERROR(false, "Unknown ANSWER type received '%s'", key_value.c_str());
                    }
                }
            }
            else if (event_type == "QUERY")
            {
                const r::Value &requests = event_data["requests"];

                for (auto itr = requests.MemberBegin(); itr != requests.MemberEnd(); ++itr)
                {
                    string key_value = itr->name.GetString();

                    if (key_value == "estimate_waiting_time")
                    {
                        // The queried job is added to the workload before the algorithm is asked about it
                        const r::Value &request_object = itr->value;
                        string job_id = request_object["job_id"].GetString();
                        workload.add_job_from_json_object(request_object["job"], job_id, current_date);

                        algo->on_query_estimate_waiting_time(current_date, job_id);
                    }
                    else
                    {
                        PPK_ASSERT_ERROR(false, "Unknown QUERY type received '%s'", key_value.c_str());
                    }
                }
            }
            else if (event_type == "NOTIFY")
            {
                string notify_type = event_data["type"].GetString();

                if (notify_type == "no_more_static_job_to_submit")
                {
                    algo->on_no_more_static_job_to_submit_received(current_date);
                }
                else if (notify_type == "no_more_external_event_to_occur")
                {
                    algo->on_no_more_external_event_to_occur(current_date);
                }
                else if (notify_type == "event_machine_available")
                {
                    IntervalSet resources = IntervalSet::from_string_hyphen(event_data["resources"].GetString(), " ");
                    algo->on_machine_available_notify_event(current_date, resources);
                }
                else if (notify_type == "event_machine_unavailable")
                {
                    IntervalSet resources = IntervalSet::from_string_hyphen(event_data["resources"].GetString(), " ");
                    algo->on_machine_unavailable_notify_event(current_date, resources);
                }
                else
                {
                    throw runtime_error("Unknown NOTIFY type received. Type = " + notify_type);
                }
            }
            else
            {
                throw runtime_error("Unknown event received. Type = " + event_type);
            }
        }

        // True when the message contained a REQUESTED_CALL event and nothing else
        const bool requested_callback_only = requested_callback_received && (events_array.Size() == 1);

        // make_decisions is skipped only when the caller opted out of single-NOP calls
        // AND this message was such a single REQUESTED_CALL.
        if (call_make_decisions_on_single_nop || !requested_callback_only)
        {
            SortableJobOrder::UpdateInformation update_info(current_date);
            algo->make_decisions(message_date, &update_info, nullptr);
            algo->clear_recent_data_structures();
        }

        // The answer is dated at the later of the message date and the last decision date
        message_date = max(message_date, d.last_date());

        const string &message_to_send = d.content(message_date);
        n.write(message_to_send);
    }
}