diff --git a/CODES-compile-instructions.sh b/CODES-compile-instructions.sh index 23f862ae..2b4f781e 100644 --- a/CODES-compile-instructions.sh +++ b/CODES-compile-instructions.sh @@ -37,6 +37,17 @@ else echo "Using existing ross checkout: $(realpath ross)" fi + +if [ "$torch_enable" = 1 ]; then + make_args_codes=( + "${make_args_codes[@]}" + ) +else + make_args_codes=( + "${make_args_codes[@]}" + ) +fi + if [ $swm_enable = 1 ]; then if [ ! -d argobots/.git ]; then git clone https://github.com/pmodels/argobots --depth=1 @@ -192,53 +203,41 @@ fi -# Make system pkg-config metadata visible even when Conda's pkg-config is active. -# This is needed for libzmq.pc on systems where ZeroMQ is installed through the OS -# but the active Conda environment's pkg-config only searches Conda pkgconfig dirs. -if ! pkg-config --exists libzmq 2>/dev/null; then - for pcdir in \ - /usr/lib/x86_64-linux-gnu/pkgconfig \ - /usr/lib64/pkgconfig \ - /usr/lib/pkgconfig \ - /usr/local/lib/pkgconfig \ - /usr/local/lib64/pkgconfig \ - /opt/homebrew/lib/pkgconfig \ - /usr/share/pkgconfig - do - if [ -d "$pcdir" ]; then - export PKG_CONFIG_PATH="$pcdir:${PKG_CONFIG_PATH:-}" - fi - done -fi +if [ "$torch_enable" = 1 ]; then + # Make system pkg-config metadata visible even when Conda's pkg-config is active. + # This is needed for libzmq.pc on systems where ZeroMQ is installed through the OS + # but the active Conda environment's pkg-config only searches Conda pkgconfig dirs. + if ! pkg-config --exists libzmq 2>/dev/null; then + for pcdir in \ + /usr/lib/x86_64-linux-gnu/pkgconfig \ + /usr/lib64/pkgconfig \ + /usr/lib/pkgconfig \ + /usr/local/lib/pkgconfig \ + /usr/local/lib64/pkgconfig \ + /opt/homebrew/lib/pkgconfig \ + /usr/share/pkgconfig + do + if [ -d "$pcdir" ]; then + export PKG_CONFIG_PATH="$pcdir:${PKG_CONFIG_PATH:-}" + fi + done + fi + + if ! pkg-config --exists libzmq 2>/dev/null; then + echo "WARNING: pkg-config still cannot find libzmq.pc." >&2 + echo " If ZMQML requester support fails to build, install the ZeroMQ development package" >&2 + echo " or set PKG_CONFIG_PATH to the directory containing libzmq.pc." >&2 + fi -if ! pkg-config --exists libzmq 2>/dev/null; then - echo "WARNING: pkg-config still cannot find libzmq.pc." >&2 - echo " If ZMQML fails to build, install the ZeroMQ development package" >&2 - echo " or set PKG_CONFIG_PATH to the directory containing libzmq.pc." >&2 + # Build local ZMQML requester library required by director-client.C. + pushd codes/src/surrogate/zmqml + make clean + make + test -f libzmqmlrequester.so + test -f zmqmlrequester.h + popd fi -# Build local ZMQML requester library required by director-client.C -pushd codes/src/surrogate/zmqml -make clean -make -test -f libzmqmlrequester.so -test -f zmqmlrequester.h -popd - -# Make imported zmqmlrequester target visible to doc/example and tests. -python3 - <<'INNERPY' -from pathlib import Path -cm = Path("codes/src/CMakeLists.txt") -text = cm.read_text() -old = "add_library(zmqmlrequester SHARED IMPORTED )" -new = "add_library(zmqmlrequester SHARED IMPORTED GLOBAL)" -if old in text: - cm.write_text(text.replace(old, new)) -elif new in text: - pass -else: - raise SystemExit("Could not find zmqmlrequester imported target line in codes/src/CMakeLists.txt") -INNERPY mkdir -p codes/build pushd codes/build @@ -368,10 +367,8 @@ make_args_codes=( -DCMAKE_USE_WIN32_THREADS_INIT=0 -DCMAKE_BUILD_TYPE=Debug -DBUILD_TESTING=ON -DCMAKE_INSTALL_PREFIX="$(realpath bin)" - -DZMQML_BUILD_PATH="$(realpath "$CUR_DIR/codes/src/surrogate/zmqml")" - -DZeroMQ_INCLUDE_DIR=/usr/include - -DZeroMQ_LIBRARY=/usr/lib/x86_64-linux-gnu/libzmq.so ) + if [ $swm_enable = 1 ]; then make_args_codes=( "${make_args_codes[@]}" @@ -390,6 +387,10 @@ if [ "$torch_enable" = 1 ]; then "${make_args_codes[@]}" -DUSE_TORCH=true -DTorch_DIR="${torch_dir}" + -DUSE_ZMQML=true + -DZMQML_BUILD_PATH="$(realpath "$CUR_DIR/codes/src/surrogate/zmqml")" + -DZeroMQ_INCLUDE_DIR=/usr/include + -DZeroMQ_LIBRARY=/usr/lib/x86_64-linux-gnu/libzmq.so ) if [ -n "${CUDA_HOME:-}" ]; then @@ -412,7 +413,10 @@ if [ "$torch_enable" = 1 ]; then ) fi else - make_args_codes=("${make_args_codes[@]}" -DUSE_TORCH=false) + make_args_codes=( + "${make_args_codes[@]}" + -DUSE_TORCH=false + ) fi cmake .. "${make_args_codes[@]}" diff --git a/README.md b/README.md index 92b61cf4..56820db6 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,7 @@ This repo uses [clang-format](https://clang.llvm.org/docs/ClangFormat.html) to k - **Emacs:** see [clang-format.el](https://clang.llvm.org/docs/ClangFormat.html#emacs-integration). To reformat a file manually: `clang-format -i path/to/file.c`. CI runs `clang-format --dry-run --Werror` on every PR and rejects any drift, so PRs with unformatted code don't merge. +Note: The CI uses clang-format major release version 20, so you should format your files with that version. ### Determinism diff --git a/codes/surrogate/director-client.h b/codes/surrogate/director-client.h index aaea3d09..fa563c6a 100644 --- a/codes/surrogate/director-client.h +++ b/codes/surrogate/director-client.h @@ -125,8 +125,7 @@ extern "C" { extern void director_lp_register_model(const char*); - - +extern void director_record_external_zmq_latency(double processing_sec, double total_sec); /* extern void director_parse_args(char *args, int **args_array, int *length); static void director_issue_codes_event(director_state * s, tw_lpid nw_lpid, int dir_registered_event_type, tw_stime ts, tw_lp* lp); @@ -142,5 +141,4 @@ extern void dir_test_finalize(director_state* s, tw_lp* lp); #ifdef __cplusplus } #endif - #endif diff --git a/doc/example/kb.dfdally-72-zeromq-director.conf.in b/doc/example/kb.dfdally-72-zeromq-director.conf.in index 656959c4..fdb77ec7 100644 --- a/doc/example/kb.dfdally-72-zeromq-director.conf.in +++ b/doc/example/kb.dfdally-72-zeromq-director.conf.in @@ -23,19 +23,19 @@ LPGROUPS DIRECTOR { - start_iter="${DIRECTOR_START_ITER}"; - end_iter="${DIRECTOR_END_ITER}"; + start_iter="${START_ITER}"; + end_iter="${END_ITER}"; # Optional one-shot pause/retrain/resume pipeline. # First implementation is intended for --synch=1. - retrain_enabled="${DIRECTOR_RETRAIN_ENABLED}"; - retrain_iter="${DIRECTOR_RETRAIN_ITER}"; - retrain_save_path="${DIRECTOR_RETRAIN_SAVE_PATH}"; + retrain_enabled="${RETRAIN_ENABLED}"; + retrain_iter="${RETRAIN_ITER}"; + retrain_save_path="${RETRAIN_SAVE_PATH}"; # Optional second surrogate window after retraining. - second_surrogate_enabled="${DIRECTOR_SECOND_SURROGATE_ENABLED}"; - second_start_iter="${DIRECTOR_SECOND_START_ITER}"; - second_end_iter="${DIRECTOR_SECOND_END_ITER}"; + second_surrogate_enabled="${SECOND_SURROGATE_ENABLED}"; + second_start_iter="${SECOND_START_ITER}"; + second_end_iter="${SECOND_END_ITER}"; # Common modes: # diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 99430538..4e834b90 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,3 +1,4 @@ +option(USE_ZMQML "Enable ZeroMQ ML requester support" OFF) cmake_print_variables(CMAKE_CURRENT_SOURCE_DIR) find_package(FLEX REQUIRED) @@ -124,68 +125,38 @@ if(USE_TORCH) list(APPEND LIBS_TO_LINK ${TORCH_LIBRARIES}) endif() -# ZMQML / director-client (opt-in). When USE_ZMQML=ON, callers must -# point ZMQML_BUILD_PATH at a directory containing libzmqmlrequester.so -# (build it via src/surrogate/zmqml/Makefile, or set ZMQML_BUILD_PATH to -# wherever you installed it). When OFF (the default), CODES builds with -# no surrogate/director-client linkage; configs that reference -# "dir-nw-lp" will fail at runtime because the LP type isn't registered. -option(USE_ZMQML "Build the director-client + zmqml surrogate integration" OFF) +# ZMQML requester support if(USE_ZMQML) - if(NOT ZMQML_BUILD_PATH) - message(FATAL_ERROR - "USE_ZMQML=ON but ZMQML_BUILD_PATH is unset.\n" - "Build src/surrogate/zmqml/libzmqmlrequester.so first, then " - "reconfigure with -DZMQML_BUILD_PATH=.") - endif() list(APPEND SRCS surrogate/director-client.C) -endif() -add_library(codes STATIC ${SRCS}) - -list(APPEND LIBS_TO_LINK ${MPI_C_LIBRARIES}) -target_include_directories(codes INTERFACE ${MPI_C_INCLUDE_PATH}) + if(NOT DEFINED ZMQML_BUILD_PATH) + message(FATAL_ERROR "USE_ZMQML is ON, but ZMQML_BUILD_PATH is not defined.") + endif() -# set(LIBS_TO_LINK -# PkgConfig::ROSS -# ${DUMPI_LIB} -# PkgConfig::ARGOBOTS -# PkgConfig::SWM -# ) + if(NOT EXISTS "${ZMQML_BUILD_PATH}/libzmqmlrequester.so") + message(FATAL_ERROR "USE_ZMQML is ON, but ${ZMQML_BUILD_PATH}/libzmqmlrequester.so does not exist. Re-run CODES-compile-instructions.sh so the local requester library is built before configuring CODES.") + endif() -#LINK DUMPI -# target_link_libraries(codes PUBLIC ${DUPMI_LIB}) -if(USE_DUMPI) - target_include_directories(codes PUBLIC ${DUMPI_INCLUDE}) -endif() + pkg_check_modules(PC_ZeroMQ QUIET zmq) + find_path(ZeroMQ_INCLUDE_DIR NAMES zmq.hpp PATHS ${PC_ZeroMQ_INCLUDE_DIRS}) + find_library(ZeroMQ_LIBRARY NAMES zmq PATHS ${PC_ZeroMQ_LIBRARY_DIRS}) -#LINK ARGOBOTS, SWM and UNION -# target_link_libraries(codes PUBLIC PkgConfig::ARGOBOTS) -if(USE_ONLINE) - if(USE_SWM) - target_include_directories(codes PUBLIC ${ARGOBOTS_INCLUDE_DIRS}) - # target_link_libraries(codes PUBLIC PkgConfig::SWM) - target_include_directories(codes PUBLIC ${SWM_INCLUDE_DIRS}) + if(NOT ZeroMQ_LIBRARY) + message(FATAL_ERROR "USE_ZMQML is ON, but libzmq was not found.") endif() - if(USE_UNION) - target_include_directories(codes PUBLIC ${ARGOBOTS_INCLUDE_DIRS}) - # target_link_libraries(codes PUBLIC PkgConfig::SWM) - target_include_directories(codes PUBLIC ${SWM_INCLUDE_DIRS}) - target_include_directories(codes PUBLIC ${UNION_INCLUDE_DIRS}) - endif() -endif() -if(USE_ZMQML) add_library(zmqmlrequester SHARED IMPORTED GLOBAL) set_target_properties(zmqmlrequester PROPERTIES IMPORTED_LOCATION "${ZMQML_BUILD_PATH}/libzmqmlrequester.so" INTERFACE_INCLUDE_DIRECTORIES "${ZMQML_BUILD_PATH}") - target_compile_definitions(codes PUBLIC USE_ZMQML) endif() #LINK ROSS # target_link_libraries(codes PUBLIC #{pkgcfg_lib_ROSS_ROSS}) # target_link_libraries(codes PUBLIC PkgConfig::ROSS) + +add_library(codes ${SRCS}) + target_include_directories(codes PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} ${ROSS_INCLUDE_DIRS} @@ -194,11 +165,16 @@ target_include_directories(codes PUBLIC ${PROJECT_SOURCE_DIR}/codes ${PROJECT_SOURCE_DIR}/src ${PROJECT_SOURCE_DIR}/src/modelconfig - $<$:$> ) target_link_libraries(codes PUBLIC ${LIBS_TO_LINK}) +if(USE_ZMQML) + target_compile_definitions(codes PUBLIC USE_ZMQML) + target_include_directories(codes PUBLIC "${ZMQML_BUILD_PATH}" ${ZeroMQ_INCLUDE_DIR}) + target_link_libraries(codes PUBLIC zmqmlrequester ${ZeroMQ_LIBRARY}) +endif() + get_target_property(CODES_INCLUDE_DIRS codes INCLUDE_DIRECTORIES) cmake_print_variables(CODES_INCLUDE_DIRS) @@ -227,18 +203,12 @@ if(USE_DUMPI) list(APPEND CODES_TARGETS model-net-dumpi-traces-dump) endif() -# ZMQ — only resolved + linked when USE_ZMQML is on; otherwise nothing -# in the codes library calls into libzmq. -if(USE_ZMQML) - pkg_check_modules(PC_ZeroMQ QUIET zmq) - find_path(ZeroMQ_INCLUDE_DIR NAMES zmq.hpp PATHS ${PC_ZeroMQ_INCLUDE_DIRS}) - find_library(ZeroMQ_LIBRARY NAMES zmq PATHS ${PC_ZeroMQ_LIBRARY_DIRS}) -endif() - foreach(tar IN LISTS CODES_TARGETS) target_include_directories(${tar} PUBLIC ${CODES_INCLUDE_DIRS} ${ROSS_INCLUDE_DIRS}) target_link_libraries(${tar} PUBLIC codes ${LIBS_TO_LINK}) + if(USE_ZMQML) + target_include_directories(${tar} PUBLIC "${ZMQML_BUILD_PATH}" ${ZeroMQ_INCLUDE_DIR}) target_link_libraries(${tar} PUBLIC zmqmlrequester ${ZeroMQ_LIBRARY}) endif() endforeach() diff --git a/src/networks/model-net/dragonfly-dally.C b/src/networks/model-net/dragonfly-dally.C index b4c064cf..d5470dfa 100644 --- a/src/networks/model-net/dragonfly-dally.C +++ b/src/networks/model-net/dragonfly-dally.C @@ -22,6 +22,7 @@ #include "codes/model-net-method.h" #include "codes/model-net-lp.h" #include "codes/surrogate/init.h" +#include "codes/surrogate/director-client.h" #ifdef USE_TORCH #include "codes/surrogate/packet-latency-predictor/torch-jit.h" #endif @@ -64,11 +65,13 @@ * resolve it to null.) */ #ifdef USE_ZMQML -extern std::vector zmqml_director_request(const std::string& surrogate_family, - const std::string& surrogate_backend, - const std::string& operation, - const std::vector& args, - const std::string& bindata); +extern std::vector +zmqml_director_request(const std::string& surrogate_family, const std::string& surrogate_backend, + const std::string& operation, const std::vector& args, + const std::string& bindata) __attribute__((weak)); + +extern "C" void director_record_external_zmq_latency(double processing_sec, double total_sec) + __attribute__((weak)); extern void director_record_zmq_latency_stats(const char* label, const std::vector& ret, @@ -297,7 +300,20 @@ static std::vector dfdally_event_time_director_request_with_latency (double)(finish.tv_nsec - start.tv_nsec) / 1000000000.0; #ifdef USE_ZMQML - director_record_zmq_latency_stats(label, ret, local_latency_sec); + double zmq_processing_time = 0.0; + + if (ret.size() > 1) { + char* endptr = NULL; + double parsed = strtod(ret[1].c_str(), &endptr); + + if (endptr != ret[1].c_str() && isfinite(parsed) && parsed >= 0.0) { + zmq_processing_time = parsed; + } + } + + if (director_record_external_zmq_latency) { + director_record_external_zmq_latency(zmq_processing_time, local_latency_sec); + } #endif return ret; @@ -2695,14 +2711,8 @@ static void dragonfly_read_config(const char* anno, dragonfly_param* params) { char event_time_inference_enabled_str[MAX_NAME_LENGTH]; event_time_inference_enabled_str[0] = '\0'; - char const* inferencing_enabled_env = getenv("INFERENCING_ENABLED"); - if (inferencing_enabled_env && strlen(inferencing_enabled_env) > 0) { - snprintf(event_time_inference_enabled_str, sizeof(event_time_inference_enabled_str), "%s", - inferencing_enabled_env); - } else { - configuration_get_value(&config, "DIRECTOR", "inferencing_enabled", anno, - event_time_inference_enabled_str, MAX_NAME_LENGTH); - } + configuration_get_value(&config, "DIRECTOR", "inferencing_enabled", anno, + event_time_inference_enabled_str, MAX_NAME_LENGTH); /* * Do not expose a separate event-time inference flag. @@ -2737,14 +2747,8 @@ static void dragonfly_read_config(const char* anno, dragonfly_param* params) { char event_time_training_enabled_str[MAX_NAME_LENGTH]; event_time_training_enabled_str[0] = '\0'; - char const* training_enabled_env = getenv("TRAINING_ENABLED"); - if (training_enabled_env && strlen(training_enabled_env) > 0) { - snprintf(event_time_training_enabled_str, sizeof(event_time_training_enabled_str), "%s", - training_enabled_env); - } else { - configuration_get_value(&config, "DIRECTOR", "training_enabled", anno, - event_time_training_enabled_str, MAX_NAME_LENGTH); - } + configuration_get_value(&config, "DIRECTOR", "training_enabled", anno, + event_time_training_enabled_str, MAX_NAME_LENGTH); event_time_surrogate_family_selected = strcmp(event_time_surrogate_family_str, "event-time") == 0; @@ -2765,15 +2769,6 @@ static void dragonfly_read_config(const char* anno, dragonfly_param* params) { event_time_zmq_flush_registered = 1; } - if (dfdally_surrogate_debug_prints) { - fprintf(stderr, - "[event-time records] family=%s training_enabled=%s send_to_zmq=%d batch_size=%d\n", - event_time_surrogate_family_str, event_time_training_enabled_str, - event_time_training_records_enabled, event_time_zmq_batch_size); - fflush(stderr); - } - - // START Surrogate configuration char enable_str[MAX_NAME_LENGTH]; enable_str[0] = '\0'; @@ -2798,6 +2793,14 @@ static void dragonfly_read_config(const char* anno, dragonfly_param* params) { dfdally_surrogate_debug_prints = dfdally_string_is_true(debug_prints_str); + if (dfdally_surrogate_debug_prints) { + fprintf(stderr, + "[event-time records] family=%s training_enabled=%s send_to_zmq=%d batch_size=%d\n", + event_time_surrogate_family_str, event_time_training_enabled_str, + event_time_training_records_enabled, event_time_zmq_batch_size); + fflush(stderr); + } + // if surrogate mode has been set up if (enable_network_surrogate) { struct network_surrogate_config surr_conf = { diff --git a/src/surrogate/director-client.C b/src/surrogate/director-client.C index 3b115dc0..b6f50757 100644 --- a/src/surrogate/director-client.C +++ b/src/surrogate/director-client.C @@ -23,7 +23,7 @@ #define DIR_ZMQ_CMD_LENGTH 64 #define DIR_ZMQ_ARG_LENGTH 2048 -#define DIR_MAX_PREDICTION 5 +#define DIR_MAX_PREDICTION 1 #define DIR_MAX_TRAINING_RECORDS 10 /* * The Python iteration-time model currently uses history_len=2 and horizon=3, @@ -86,29 +86,40 @@ std::vector director_client_request_family(const char* surrogate_fa int surrogate_enabled = 0; int inferencing_enabled = 1; - -void director_record_zmq_latency_stats(const char* label, const std::vector& ret, - double local_latency_sec) { - (void)label; - +static void director_record_zmq_latency_values(double processing_sec, double total_sec) { if (evaluate_perf != 1) { return; } - director_zmq_total_elapsed_times.push_back(local_latency_sec); + if (!std::isfinite(processing_sec) || processing_sec < 0.0) { + processing_sec = 0.0; + } + + if (!std::isfinite(total_sec) || total_sec < 0.0) { + total_sec = 0.0; + } + director_zmq_processing_times.push_back(processing_sec); + director_zmq_total_elapsed_times.push_back(total_sec); +} + +static void director_record_zmq_latency_stats(const char* label, + const std::vector& ret, + double local_latency_sec) { double zmq_processing_time = 0.0; + if (ret.size() > 1) { - try { - zmq_processing_time = std::stod(ret[1]); - } catch (...) { - if (director_debug_prints) { - fprintf(stderr, - "[DIR] Warning: could not parse zmq processing time from reply field " - "ret[1]=%s\n", - ret[1].c_str()); - fflush(stderr); - } + char* endptr = NULL; + double parsed = strtod(ret[1].c_str(), &endptr); + + if (endptr != ret[1].c_str() && std::isfinite(parsed) && parsed >= 0.0) { + zmq_processing_time = parsed; + } else if (director_debug_prints) { + fprintf(stderr, + "[DIR] Warning: could not parse zmq processing time from reply field ret[1]=%s " + "request=%s\n", + ret[1].c_str(), label ? label : ""); + fflush(stderr); } } else if (director_debug_prints) { fprintf(stderr, @@ -117,9 +128,12 @@ void director_record_zmq_latency_stats(const char* label, const std::vectordirector_id, - DIR_MAX_PREDICTION); // num-of-args;num-record + sprintf(args, "%d;%llu;%d;", 2, (unsigned long long)s->director_id, + DIR_MAX_PREDICTION); // num-of-args;client-id;num-predictions // The Python side primarily uses records previously sent through // send-records. Keep the payload empty for now rather than sending @@ -1246,31 +1260,29 @@ void director_event_handler_commit(director_state* s, tw_bf* bf, director_messag static void director_reduce_and_print_zmq_latency_stat(const char* stat_name, const std::vector& local_values) { - unsigned long long local_count = (unsigned long long)local_values.size(); - + unsigned long long local_count = 0; double local_sum = 0.0; double local_sq_sum = 0.0; double local_min = std::numeric_limits::infinity(); double local_max = -std::numeric_limits::infinity(); for (double value : local_values) { - local_sum += value; - local_sq_sum += value * value; - - if (value < local_min) { - local_min = value; + if (!std::isfinite(value)) { + continue; } - if (value > local_max) { - local_max = value; - } + local_count++; + local_sum += value; + local_sq_sum += value * value; + local_min = std::min(local_min, value); + local_max = std::max(local_max, value); } unsigned long long global_count = 0; double global_sum = 0.0; double global_sq_sum = 0.0; - double global_min = 0.0; - double global_max = 0.0; + double global_min = std::numeric_limits::infinity(); + double global_max = -std::numeric_limits::infinity(); MPI_Reduce(&local_count, &global_count, 1, MPI_UNSIGNED_LONG_LONG, MPI_SUM, 0, MPI_COMM_CODES); @@ -1293,14 +1305,14 @@ static void director_reduce_and_print_zmq_latency_stat(const char* stat_name, double variance = global_sq_sum / (double)global_count - mean * mean; /* - * Floating-point roundoff can make variance slightly negative when - * values are very close together. + * Floating-point roundoff can make variance slightly negative when values + * are very close together. */ if (variance < 0.0 && variance > -1.0e-18) { variance = 0.0; } - double stddev = sqrt(variance); + double stddev = variance > 0.0 ? sqrt(variance) : 0.0; std::cout << std::setprecision(9) << std::fixed << "==DIR_STATS " << stat_name << ": requests = " << global_count << ", mean = " << mean << ", min = " << global_min @@ -1335,6 +1347,7 @@ static void director_print_zmq_latency_stats_once(void) { director_zmq_total_elapsed_times); } + void director_finalize(director_state* s, tw_lp* lp) { director_print_zmq_latency_stats_once(); @@ -1365,8 +1378,7 @@ tw_lptype dir_lp = {(init_f)director_init, extern void director_lp_register_model(const char* dir_lp_name) { int num_dir_per_mgrp = codes_mapping_get_lp_count("MODELNET_GRP", 1, "dir-nw-lp", NULL, 0); if (num_dir_per_mgrp > 0) { - lp_type_register(dir_lp_name, &dir_lp); // DIRECTOR addition - register type - //printf("\n==DIR: Registered\n"); + lp_type_register(dir_lp_name, &dir_lp); } } diff --git a/src/surrogate/zmqml/model/mliterationtime.py b/src/surrogate/zmqml/model/mliterationtime.py index 6fda3461..ce9877ce 100644 --- a/src/surrogate/zmqml/model/mliterationtime.py +++ b/src/surrogate/zmqml/model/mliterationtime.py @@ -302,6 +302,51 @@ def _predict_once(self, client_id: int, history: List[float], iteration: int) -> return np.asarray(cleaned, dtype=np.float64) + def _global_fallback_prediction(self, requested_horizon: int) -> List[float]: + requested_horizon = max(1, int(requested_horizon)) + + recent_values: List[float] = [] + for model in self.models.values(): + if model.records: + recent_values.extend(model.records[-max(1, self.history_len):]) + + recent_values = _as_positive_finite(recent_values) + + if recent_values: + value = float(np.median(np.asarray(recent_values, dtype=np.float64))) + elif self.y_mean is not None and len(self.y_mean) > 0: + value = float(self.y_mean[0]) + else: + # Match the older iteration-time fallback scale instead of using 1.0. + value = 2_000_000.0 + + if not np.isfinite(value) or value <= 0.0: + value = 2_000_000.0 + + return [value for _ in range(requested_horizon)] + + def _global_fallback_prediction(self, requested_horizon: int) -> List[float]: + requested_horizon = max(1, int(requested_horizon)) + + recent_values: List[float] = [] + for model in self.models.values(): + if model.records: + recent_values.extend(model.records[-max(1, self.history_len):]) + + recent_values = _as_positive_finite(recent_values) + + if recent_values: + value = float(np.median(np.asarray(recent_values, dtype=np.float64))) + elif self.y_mean is not None and len(self.y_mean) > 0: + value = float(self.y_mean[0]) + else: + value = 2_000_000.0 + + if not np.isfinite(value) or value <= 0.0: + value = 2_000_000.0 + + return [value for _ in range(requested_horizon)] + def predict(self, client_id: int, requested_horizon: int | None = None) -> List[float]: client_id = int(client_id) requested_horizon = int(requested_horizon or self.horizon) @@ -310,7 +355,15 @@ def predict(self, client_id: int, requested_horizon: int | None = None) -> List[ model = self.get(client_id) if not model.records: - return model._fallback_prediction(requested_horizon) + fallback = self._global_fallback_prediction(requested_horizon) + if self.debug: + print( + "[IterationTimeModelRegistry] predict global-fallback " + f"client={client_id} requested_horizon={requested_horizon} " + f"trained={int(self.trained)} predictions={fallback}", + flush=True, + ) + return fallback # Predict in chunks if requested_horizon > self.horizon. out: List[float] = [] @@ -326,6 +379,9 @@ def predict(self, client_id: int, requested_horizon: int | None = None) -> List[ history.append(float(value)) iteration += 1 + if not out: + out = self._global_fallback_prediction(requested_horizon) + if self.debug: print( "[IterationTimeModelRegistry] predict " @@ -337,6 +393,7 @@ def predict(self, client_id: int, requested_horizon: int | None = None) -> List[ return out + def save(self, path: str) -> None: out_path = Path(path) if out_path.parent: diff --git a/src/surrogate/zmqml/zmqmlserver.py b/src/surrogate/zmqml/zmqmlserver.py index fb4b7944..4e65588c 100755 --- a/src/surrogate/zmqml/zmqmlserver.py +++ b/src/surrogate/zmqml/zmqmlserver.py @@ -20,6 +20,8 @@ from model.mliterationtime import IterationTimeModelRegistry from model.mleventtime import EventTimeModel import csv +import io +import pickle from pathlib import Path import os @@ -37,19 +39,197 @@ launch_id = count(start=1) # unique for launched thread launched_threads = {} # id:obj. keep track of active threads. remove the thread once it finished -training_records = {} # client_id:[] -iteration_time_models = IterationTimeModelRegistry( - history_len=int(os.environ.get("ZMQML_ITERATION_HISTORY_LEN", "4")), - horizon=int(os.environ.get("ZMQML_ITERATION_HORIZON", "30")), - ridge_alpha=float(os.environ.get("ZMQML_ITERATION_RIDGE_ALPHA", "1.0")), - train_stride=int(os.environ.get("ZMQML_ITERATION_TRAIN_STRIDE", "3")), -) -event_time_model = EventTimeModel( - min_rows=int(os.environ.get("ZMQML_EVENT_TIME_MIN_ROWS", "32")), - max_epochs=int(os.environ.get("ZMQML_EVENT_TIME_EPOCHS", "80")), - lr=float(os.environ.get("ZMQML_EVENT_TIME_LR", "0.001")), - hidden_dim=int(os.environ.get("ZMQML_EVENT_TIME_HIDDEN_DIM", "64")), -) +training_records = {} # client_id -> [] + +ITERATION_MODEL_KWARGS = { + "history_len": int(os.environ.get("ZMQML_ITERATION_HISTORY_LEN", "4")), + "horizon": int(os.environ.get("ZMQML_ITERATION_HORIZON", "30")), + "ridge_alpha": float(os.environ.get("ZMQML_ITERATION_RIDGE_ALPHA", "1.0")), + "train_stride": int(os.environ.get("ZMQML_ITERATION_TRAIN_STRIDE", "3")), +} + +EVENT_TIME_MODEL_KWARGS = { + "min_rows": int(os.environ.get("ZMQML_EVENT_TIME_MIN_ROWS", "32")), + "max_epochs": int(os.environ.get("ZMQML_EVENT_TIME_EPOCHS", "80")), + "lr": float(os.environ.get("ZMQML_EVENT_TIME_LR", "0.001")), + "hidden_dim": int(os.environ.get("ZMQML_EVENT_TIME_HIDDEN_DIM", "64")), +} + +DEFAULT_TERMINAL_MODEL_SCOPE = "terminal" +DEFAULT_TERMINAL_MODEL_KEY = "global" + +# The current dragonfly-dally event-time rows use numeric LP-type fields. +# By default, current_lp_type=0 is treated as a terminal LP; all other LP +# types get switch-local models. Override if the enum changes. +EVENT_TIME_TERMINAL_LP_TYPES = { + token.strip() + for token in os.environ.get("ZMQML_EVENT_TIME_TERMINAL_LP_TYPES", "0").split(",") + if token.strip() +} + + +def normalize_model_identity(model_scope: str | None = None, model_key: str | None = None) -> tuple[str, str, str]: + scope = str(model_scope or "").strip() + key = str(model_key or "").strip() + + if not scope: + scope = DEFAULT_TERMINAL_MODEL_SCOPE + + if scope in ("terminal", "term", "client", "global"): + scope = DEFAULT_TERMINAL_MODEL_SCOPE + key = DEFAULT_TERMINAL_MODEL_KEY + elif scope in ("router", "switch", "switch-lp", "router-lp"): + scope = "switch" + if not key: + key = "unknown" + else: + if not key: + key = DEFAULT_TERMINAL_MODEL_KEY + + model_id = f"{scope}:{key}" + return scope, key, model_id + + +def model_identity_from_real_args( + real_args: list[str], + *, + offset: int, + default_scope: str = DEFAULT_TERMINAL_MODEL_SCOPE, + default_key: str = DEFAULT_TERMINAL_MODEL_KEY, +) -> tuple[str, str, str]: + if len(real_args) >= offset + 2: + return normalize_model_identity(real_args[offset], real_args[offset + 1]) + return normalize_model_identity(default_scope, default_key) + + + + +class ScopedEventTimeModelRegistry: + def __init__(self, kwargs: dict): + self.kwargs = dict(kwargs) + self.models: dict[str, EventTimeModel] = {} + self.debug = False + self.model_versions: dict[str, int] = {} + + def set_debug(self, enabled: bool) -> None: + self.debug = bool(enabled) + for model in self.models.values(): + model.set_debug(self.debug) + + def get(self, model_scope: str | None = None, model_key: str | None = None) -> EventTimeModel: + _, _, model_id = normalize_model_identity(model_scope, model_key) + if model_id not in self.models: + model = EventTimeModel(**self.kwargs) + model.set_debug(self.debug) + self.models[model_id] = model + self.model_versions.setdefault(model_id, 0) + return self.models[model_id] + + def model_id(self, model_scope: str | None = None, model_key: str | None = None) -> str: + return normalize_model_identity(model_scope, model_key)[2] + + def train_or_update(self, model_scope: str | None = None, model_key: str | None = None) -> bool: + if model_scope in (None, "", "all", "*"): + trained_any = False + for model_id, model in self.models.items(): + trained = model.train_or_update() + if trained: + self.model_versions[model_id] = self.model_versions.get(model_id, 0) + 1 + trained_any = trained or trained_any + return trained_any + + model_id = self.model_id(model_scope, model_key) + model = self.get(model_scope, model_key) + trained = model.train_or_update() + if trained: + self.model_versions[model_id] = self.model_versions.get(model_id, 0) + 1 + return trained + + @staticmethod + def _safe_filename(model_id: str) -> str: + return model_id.replace("/", "_").replace(":", "__") + + def save(self, path: str | Path, model_scope: str | None = None, model_key: str | None = None) -> None: + path = Path(path) + + if model_scope not in (None, "", "all", "*"): + model = self.get(model_scope, model_key) + if path.parent: + path.parent.mkdir(parents=True, exist_ok=True) + model.save(path) + return + + path.mkdir(parents=True, exist_ok=True) + manifest = { + "format": "scoped-event-time-directory-v1", + "models": {}, + } + + for model_id, model in sorted(self.models.items()): + filename = self._safe_filename(model_id) + ".pt" + model.save(path / filename) + manifest["models"][model_id] = filename + + (path / "manifest.json").write_text(json.dumps(manifest, indent=2, sort_keys=True)) + + def load(self, path: str | Path, model_scope: str | None = None, model_key: str | None = None) -> None: + path = Path(path) + + if path.is_dir(): + manifest_path = path / "manifest.json" + if not manifest_path.exists(): + raise FileNotFoundError(f"missing scoped event-time manifest: {manifest_path}") + + manifest = json.loads(manifest_path.read_text()) + self.models.clear() + self.model_versions.clear() + for model_id, filename in manifest.get("models", {}).items(): + scope, key = model_id.split(":", 1) + model = self.get(scope, key) + model.load(path / filename) + model.set_debug(self.debug) + self.model_versions[model_id] = self.model_versions.get(model_id, 0) + 1 + return + + # Backward-compatible load of an old single event-time model file. + model = self.get(model_scope, model_key) + model.load(path) + model.set_debug(self.debug) + model_id = self.model_id(model_scope, model_key) + self.model_versions[model_id] = self.model_versions.get(model_id, 0) + 1 + + def status(self, model_scope: str | None = None, model_key: str | None = None) -> dict[str, str]: + if model_scope in (None, "", "all", "*"): + model_ids = sorted(self.models) + else: + model_ids = [self.model_id(model_scope, model_key)] + + total_rows = 0 + trained_models = 0 + entries = [] + + for model_id in model_ids: + model = self.models.get(model_id) + if model is None: + continue + total_rows += len(model.rows) + trained_models += int(bool(model.trained)) + entries.append( + f"{model_id}:rows={len(model.rows)},trained={int(model.trained)},examples={model.training_examples}" + ) + + return { + "model_count": str(len(model_ids)), + "trained_models": str(trained_models), + "total_rows": str(total_rows), + "models": ";".join(entries), + } + + +iteration_time_models = IterationTimeModelRegistry(**ITERATION_MODEL_KWARGS) +event_time_models = ScopedEventTimeModelRegistry(EVENT_TIME_MODEL_KWARGS) + +event_time_model = event_time_models.get() iteration_model_path = os.environ.get("ZMQML_ITERATION_MODEL_PATH", "").strip() event_time_model_path = os.environ.get("ZMQML_EVENT_TIME_MODEL_PATH", "").strip() event_time_record_log_path = os.environ.get("ZMQML_EVENT_TIME_RECORD_LOG_PATH", "").strip() @@ -58,10 +238,7 @@ app_alloc_path = os.environ.get("ZMQML_APP_ALLOC_PATH", "").strip() auto_train_on_records = os.environ.get( - "ZMQML_AUTO_TRAIN_ON_RECORDS", "1" -).strip().lower() in ("1", "true", "yes", "on") -event_time_auto_train_on_records = os.environ.get( - "ZMQML_EVENT_TIME_AUTO_TRAIN_ON_RECORDS", "0" + "ZMQML_AUTO_TRAIN_ON_RECORDS", "0" ).strip().lower() in ("1", "true", "yes", "on") iteration_model_version = 0 @@ -83,10 +260,11 @@ ) if event_time_model_path: - event_time_model.load(event_time_model_path) + event_time_models.load(event_time_model_path) + event_time_model = event_time_models.get() event_time_model_version = 1 print( - f"[zmqmlserver] loaded event-time model: {event_time_model_path}", + f"[zmqmlserver] loaded event-time model(s): {event_time_model_path}", flush=True, ) @@ -228,7 +406,7 @@ def set_director_debug_prints(args): director_debug_prints = raw in ("1", "true", "yes", "on", "enabled") iteration_time_models.set_debug(director_debug_prints) - event_time_model.set_debug(director_debug_prints) + event_time_models.set_debug(director_debug_prints) if director_debug_prints: print(f"[zmqmlserver] director_debug_prints=1", flush=True) @@ -335,13 +513,17 @@ def receivedata(args, bindata): # # receive training records # + def receiverecords(args, bindata): status = "failed" st = time.time() - num_args = int(args[0]) # 1st arg is num of args - client = int(args[1]) # 2nd arg is client id - num_records = int(args[2]) # 3rd arg is num records + real_args = _real_command_args(args) + if len(real_args) < 2: + return ("failed", time.time() - st) + + client = int(real_args[0]) + num_records = int(real_args[1]) records_str = str(bindata.decode('utf-8')) records_str = records_str.strip() @@ -361,17 +543,10 @@ def receiverecords(args, bindata): training_records[client].extend(parsed_records) - # Keep the raw records available for offline/pretraining workflows. - # By default this preserves the old behavior and trains immediately. - # Set ZMQML_AUTO_TRAIN_ON_RECORDS=0 for pure-PDES collection or - # frozen pretrained inference runs. if parsed_records: append_record_log(client, parsed_records) model = iteration_time_models.get(client) - # Enrich the ML model with app_id metadata when available. - # The C++ protocol still sends client + timing values, while the Python - # server infers app_id from ZMQML_APP_ALLOC_PATH. app_id = client_app_map.get(client, -1) if "client_app_map" in globals() else -1 if hasattr(model, "set_app_id"): @@ -400,19 +575,17 @@ def receiverecords(args, bindata): return (status, elapsed_time) -# -# do inference to get predictions -# def launch_iteration_time_inferencing(args, bindata): status = "failed" st = time.time() - num_args = int(args[0]) # 1st arg is num of args - client = int(args[1]) # 2nd arg is client id - num_steps = int(args[2]) # 3rd arg is num steps to predict + real_args = _real_command_args(args) + if len(real_args) < 2: + return ("failed", time.time() - st, "") + + client = int(real_args[0]) + num_steps = int(real_args[1]) - # Optional recent-context payload. The normal path uses records previously - # received through send-records, but accepting context keeps the API flexible. records_str = str(bindata.decode('utf-8')) records_str = records_str.strip() @@ -448,11 +621,6 @@ def launch_iteration_time_inferencing(args, bindata): elapsed_time = time.time() - st return (status, elapsed_time, inferences_str) - - - - - def event_time_payload_has_header(payload: str) -> bool: for line in payload.splitlines(): line = line.strip() @@ -502,27 +670,101 @@ def append_event_time_record_log(payload: str) -> None: f.write("\n") +def event_time_model_identity_from_row(row: dict) -> tuple[str, str, str]: + raw_lp_type = str(row.get("current_lp_type", "")).strip() + raw_gid = str(row.get("current_lp_gid", "")).strip() + + if raw_lp_type in EVENT_TIME_TERMINAL_LP_TYPES: + return normalize_model_identity("terminal", "global") + + return normalize_model_identity("switch", raw_gid or "unknown") + + +def iter_event_time_rows_from_payload(raw_payload: str): + payload = event_time_payload_with_header(raw_payload) + if not payload.strip(): + return + + reader = csv.DictReader(io.StringIO(payload)) + if reader.fieldnames: + reader.fieldnames = [str(name).strip().lstrip("#").strip() for name in reader.fieldnames] + + for row in reader: + clean = {str(k).strip().lstrip("#").strip(): v for k, v in row.items()} + yield clean + + def receive_event_time_records(args, bindata): st = time.time() raw_payload = bindata.decode("utf-8", errors="replace").strip() - payload = event_time_payload_with_header(raw_payload) - loaded_rows = event_time_model.add_records_text(payload) if payload else 0 + loaded_rows = 0 + + # Important performance rule: + # Do NOT call EventTimeModel.add_records_text(...) once per row. + # Event-time batches can contain 65K+ rows, and per-row parsing/routing makes + # pure-PDES collection much slower than the old single-global model path. + # + # Instead, parse once, group rows by scoped model id, then call + # add_records_text(...) once per scoped model per C++ batch. + grouped_rows: dict[str, list[dict]] = {} + grouped_identity: dict[str, tuple[str, str]] = {} + + for row in iter_event_time_rows_from_payload(raw_payload) or []: + model_scope, model_key, model_id = event_time_model_identity_from_row(row) + grouped_rows.setdefault(model_id, []).append(row) + grouped_identity[model_id] = (model_scope, model_key) + + per_model_loaded: dict[str, int] = {} + + for model_id, rows in grouped_rows.items(): + model_scope, model_key = grouped_identity[model_id] + model = event_time_models.get(model_scope, model_key) + + if not rows: + continue + + # Build one CSV payload for this model. Preserve the field order from + # the first row and include any later extra keys defensively. + fieldnames = list(rows[0].keys()) + seen = set(fieldnames) + for row in rows[1:]: + for key in row.keys(): + if key not in seen: + fieldnames.append(key) + seen.add(key) + + buf = io.StringIO() + writer = csv.DictWriter(buf, fieldnames=fieldnames, extrasaction="ignore") + writer.writeheader() + writer.writerows(rows) + + accepted = model.add_records_text(buf.getvalue()) + loaded_rows += accepted + per_model_loaded[model_id] = accepted + + if accepted > 0 and auto_train_on_records: + model.train_or_update() if loaded_rows > 0: append_event_time_record_log(raw_payload) - if event_time_auto_train_on_records: - event_time_model.train_or_update() - director_debug( - f"[event-time records] loaded_rows={loaded_rows} " - f"total_rows={len(event_time_model.rows)} " - f"trained={int(event_time_model.trained)}" - ) + if director_debug_prints: + # Keep this compact. Printing the full per_model dict for 100+ switch + # models is noisy and can itself become expensive. + nonempty_models = sum(1 for v in per_model_loaded.values() if v > 0) + min_rows = min(per_model_loaded.values()) if per_model_loaded else 0 + max_rows = max(per_model_loaded.values()) if per_model_loaded else 0 + sample = sorted(per_model_loaded.items())[:8] + print( + f"[event-time records] loaded_rows={loaded_rows} " + f"models={nonempty_models} min_rows_per_model={min_rows} " + f"max_rows_per_model={max_rows} sample={sample}", + flush=True, + ) return ("done", time.time() - st, loaded_rows) - def load_event_time_records_csv_command(args): st = time.time() real_args = _real_command_args(args) @@ -542,22 +784,44 @@ def load_event_time_records_csv_command(args): "error": f"event-time records path does not exist: {path}", } - loaded_rows = event_time_model.load_csv(path) + loaded_rows = 0 + files = sorted(path.rglob("*")) if path.is_dir() else [path] - return { + for child in files: + if not child.is_file() or child.suffix.lower() not in (".csv", ".txt", ".log"): + continue + status, _et, child_rows = receive_event_time_records(["0"], child.read_text().encode("utf-8")) + if status == "done": + loaded_rows += int(child_rows) + + ret = { "status": "done", "et": str(time.time() - st), "path": str(path), "loaded_rows": str(loaded_rows), - "total_rows": str(len(event_time_model.rows)), } + ret.update(event_time_models.status()) + return ret def train_event_time_model_command(args): global event_time_model_version st = time.time() - trained = event_time_model.train_or_update() + real_args = _real_command_args(args) + + target = real_args[0] if real_args else "all" + if target in ("", "all", "*"): + trained = event_time_models.train_or_update("all", "") + model_scope = "all" + model_key = "" + model_id = "all" + elif len(real_args) >= 2: + model_scope, model_key, model_id = normalize_model_identity(real_args[0], real_args[1]) + trained = event_time_models.train_or_update(model_scope, model_key) + else: + model_scope, model_key, model_id = normalize_model_identity("switch", target) + trained = event_time_models.train_or_update(model_scope, model_key) if trained: event_time_model_version += 1 @@ -565,18 +829,17 @@ def train_event_time_model_command(args): ret = { "status": "done" if trained else "failed", "et": str(time.time() - st), + "target": model_id, "model_version": str(event_time_model_version), } - ret.update(event_time_model.status()) + ret.update(event_time_models.status(model_scope, model_key)) if not trained: - ret["error"] = "event-time model was not trained; load enough generic event-time rows first" + ret["error"] = "event-time model was not trained; load enough scoped event-time rows first" print( - f"[event-time model-train-command] trained={int(trained)} " - f"rows={len(event_time_model.rows)} " - f"training_examples={event_time_model.training_examples} " - f"model_version={event_time_model_version}", + f"[event-time model-train-command] target={model_id} trained={int(trained)} " + f"model_version={event_time_model_version} status={event_time_models.status(model_scope, model_key)}", flush=True, ) @@ -595,21 +858,29 @@ def save_event_time_model_command(args): } model_path = Path(real_args[0]) - if model_path.parent: - model_path.parent.mkdir(parents=True, exist_ok=True) + target = real_args[1] if len(real_args) >= 2 else "all" - event_time_model.save(model_path) + if target in ("", "all", "*"): + event_time_models.save(model_path, "all", "") + model_id = "all" + elif len(real_args) >= 3: + model_scope, model_key, model_id = normalize_model_identity(real_args[1], real_args[2]) + event_time_models.save(model_path, model_scope, model_key) + else: + model_scope, model_key, model_id = normalize_model_identity("switch", target) + event_time_models.save(model_path, model_scope, model_key) return { "status": "done", "et": str(time.time() - st), "path": str(model_path), + "target": model_id, "model_version": str(event_time_model_version), } def load_event_time_model_command(args): - global event_time_model_version + global event_time_model_version, event_time_model st = time.time() real_args = _real_command_args(args) @@ -629,25 +900,50 @@ def load_event_time_model_command(args): "error": f"model path does not exist: {model_path}", } - event_time_model.load(model_path) + target = real_args[1] if len(real_args) >= 2 else "all" + + if target in ("", "all", "*"): + event_time_models.load(model_path) + model_id = "all" + elif len(real_args) >= 3: + model_scope, model_key, model_id = normalize_model_identity(real_args[1], real_args[2]) + event_time_models.load(model_path, model_scope, model_key) + else: + model_scope, model_key, model_id = normalize_model_identity("switch", target) + event_time_models.load(model_path, model_scope, model_key) + + event_time_model = event_time_models.get() event_time_model_version += 1 return { "status": "done", "et": str(time.time() - st), "path": str(model_path), + "target": model_id, "model_version": str(event_time_model_version), } def event_time_model_status_command(args): st = time.time() + real_args = _real_command_args(args) + + if not real_args or real_args[0] in ("", "all", "*"): + model_scope = "all" + model_key = "" + model_id = "all" + elif len(real_args) >= 2: + model_scope, model_key, model_id = normalize_model_identity(real_args[0], real_args[1]) + else: + model_scope, model_key, model_id = normalize_model_identity("switch", real_args[0]) + ret = { "status": "done", "et": str(time.time() - st), + "target": model_id, "model_version": str(event_time_model_version), } - ret.update(event_time_model.status()) + ret.update(event_time_models.status(model_scope, model_key)) return ret @@ -663,15 +959,24 @@ def launch_event_time_inferencing(args, bindata): requested_count = 1 payload = bindata.decode("utf-8", errors="replace").strip() - predictions = event_time_model.predict_from_text( + rows = list(iter_event_time_rows_from_payload(payload) or []) + + if rows: + model_scope, model_key, model_id = event_time_model_identity_from_row(rows[0]) + else: + model_scope, model_key, model_id = normalize_model_identity() + + model = event_time_models.get(model_scope, model_key) + predictions = model.predict_from_text( payload, requested_count=max(1, requested_count), ) predictions_str = " ".join(str(float(x)) for x in predictions) director_debug( - f"[event-time inference] requested_count={requested_count} " - f"payload_bytes={len(payload)} predictions={predictions_str}" + f"[event-time inference] model={model_id} requested_count={requested_count} " + f"payload_bytes={len(payload)} trained={int(model.trained)} " + f"rows={len(model.rows)} predictions={predictions_str}" ) return ("done", time.time() - st, predictions_str) @@ -695,6 +1000,7 @@ def _real_command_args(args): return out + def train_iteration_time_model_command(args): global iteration_model_version @@ -875,7 +1181,6 @@ def load_iteration_records_csv_command(args): "loaded_clients": str(len(loaded_clients)), } - def load_iteration_time_model_command(args): global iteration_model_version @@ -916,6 +1221,7 @@ def load_iteration_time_model_command(args): } + def iteration_time_model_status_command(args): st = time.time() real_args = _real_command_args(args) @@ -960,8 +1266,6 @@ def iteration_time_model_status_command(args): "clients": ";".join(per_client), } - -# Backwards-compatible wrapper for the old command name. def launch_surrogate_inferencing(args, bindata): return launch_iteration_time_inferencing(args, bindata) @@ -991,7 +1295,7 @@ def director_request_command(msg, bindata): family = str(msg.get("surrogate_family", "iteration-time")).strip() operation = str(msg.get("operation", "")).strip() backend = str(msg.get("surrogate_backend", "")).strip() - args = _real_command_args(msg.get("args", [])) + args = msg.get("args", []) operation_aliases = { "status": "model-status",