diff --git a/CODES-compile-instructions.sh b/CODES-compile-instructions.sh
index 23f862ae..2b4f781e 100644
--- a/CODES-compile-instructions.sh
+++ b/CODES-compile-instructions.sh
@@ -37,6 +37,17 @@ else
echo "Using existing ross checkout: $(realpath ross)"
fi
+
+if [ "$torch_enable" = 1 ]; then
+ make_args_codes=(
+ "${make_args_codes[@]}"
+ )
+else
+ make_args_codes=(
+ "${make_args_codes[@]}"
+ )
+fi
+
if [ $swm_enable = 1 ]; then
if [ ! -d argobots/.git ]; then
git clone https://github.com/pmodels/argobots --depth=1
@@ -192,53 +203,41 @@ fi
-# Make system pkg-config metadata visible even when Conda's pkg-config is active.
-# This is needed for libzmq.pc on systems where ZeroMQ is installed through the OS
-# but the active Conda environment's pkg-config only searches Conda pkgconfig dirs.
-if ! pkg-config --exists libzmq 2>/dev/null; then
- for pcdir in \
- /usr/lib/x86_64-linux-gnu/pkgconfig \
- /usr/lib64/pkgconfig \
- /usr/lib/pkgconfig \
- /usr/local/lib/pkgconfig \
- /usr/local/lib64/pkgconfig \
- /opt/homebrew/lib/pkgconfig \
- /usr/share/pkgconfig
- do
- if [ -d "$pcdir" ]; then
- export PKG_CONFIG_PATH="$pcdir:${PKG_CONFIG_PATH:-}"
- fi
- done
-fi
+if [ "$torch_enable" = 1 ]; then
+ # Make system pkg-config metadata visible even when Conda's pkg-config is active.
+ # This is needed for libzmq.pc on systems where ZeroMQ is installed through the OS
+ # but the active Conda environment's pkg-config only searches Conda pkgconfig dirs.
+ if ! pkg-config --exists libzmq 2>/dev/null; then
+ for pcdir in \
+ /usr/lib/x86_64-linux-gnu/pkgconfig \
+ /usr/lib64/pkgconfig \
+ /usr/lib/pkgconfig \
+ /usr/local/lib/pkgconfig \
+ /usr/local/lib64/pkgconfig \
+ /opt/homebrew/lib/pkgconfig \
+ /usr/share/pkgconfig
+ do
+ if [ -d "$pcdir" ]; then
+ export PKG_CONFIG_PATH="$pcdir:${PKG_CONFIG_PATH:-}"
+ fi
+ done
+ fi
+
+ if ! pkg-config --exists libzmq 2>/dev/null; then
+ echo "WARNING: pkg-config still cannot find libzmq.pc." >&2
+ echo " If ZMQML requester support fails to build, install the ZeroMQ development package" >&2
+ echo " or set PKG_CONFIG_PATH to the directory containing libzmq.pc." >&2
+ fi
-if ! pkg-config --exists libzmq 2>/dev/null; then
- echo "WARNING: pkg-config still cannot find libzmq.pc." >&2
- echo " If ZMQML fails to build, install the ZeroMQ development package" >&2
- echo " or set PKG_CONFIG_PATH to the directory containing libzmq.pc." >&2
+ # Build local ZMQML requester library required by director-client.C.
+ pushd codes/src/surrogate/zmqml
+ make clean
+ make
+ test -f libzmqmlrequester.so
+ test -f zmqmlrequester.h
+ popd
fi
-# Build local ZMQML requester library required by director-client.C
-pushd codes/src/surrogate/zmqml
-make clean
-make
-test -f libzmqmlrequester.so
-test -f zmqmlrequester.h
-popd
-
-# Make imported zmqmlrequester target visible to doc/example and tests.
-python3 - <<'INNERPY'
-from pathlib import Path
-cm = Path("codes/src/CMakeLists.txt")
-text = cm.read_text()
-old = "add_library(zmqmlrequester SHARED IMPORTED )"
-new = "add_library(zmqmlrequester SHARED IMPORTED GLOBAL)"
-if old in text:
- cm.write_text(text.replace(old, new))
-elif new in text:
- pass
-else:
- raise SystemExit("Could not find zmqmlrequester imported target line in codes/src/CMakeLists.txt")
-INNERPY
mkdir -p codes/build
pushd codes/build
@@ -368,10 +367,8 @@ make_args_codes=(
-DCMAKE_USE_WIN32_THREADS_INIT=0
-DCMAKE_BUILD_TYPE=Debug -DBUILD_TESTING=ON
-DCMAKE_INSTALL_PREFIX="$(realpath bin)"
- -DZMQML_BUILD_PATH="$(realpath "$CUR_DIR/codes/src/surrogate/zmqml")"
- -DZeroMQ_INCLUDE_DIR=/usr/include
- -DZeroMQ_LIBRARY=/usr/lib/x86_64-linux-gnu/libzmq.so
)
+
if [ $swm_enable = 1 ]; then
make_args_codes=(
"${make_args_codes[@]}"
@@ -390,6 +387,10 @@ if [ "$torch_enable" = 1 ]; then
"${make_args_codes[@]}"
-DUSE_TORCH=true
-DTorch_DIR="${torch_dir}"
+ -DUSE_ZMQML=true
+ -DZMQML_BUILD_PATH="$(realpath "$CUR_DIR/codes/src/surrogate/zmqml")"
+ -DZeroMQ_INCLUDE_DIR=/usr/include
+ -DZeroMQ_LIBRARY=/usr/lib/x86_64-linux-gnu/libzmq.so
)
if [ -n "${CUDA_HOME:-}" ]; then
@@ -412,7 +413,10 @@ if [ "$torch_enable" = 1 ]; then
)
fi
else
- make_args_codes=("${make_args_codes[@]}" -DUSE_TORCH=false)
+ make_args_codes=(
+ "${make_args_codes[@]}"
+ -DUSE_TORCH=false
+ )
fi
cmake .. "${make_args_codes[@]}"
diff --git a/README.md b/README.md
index 92b61cf4..56820db6 100644
--- a/README.md
+++ b/README.md
@@ -139,6 +139,7 @@ This repo uses [clang-format](https://clang.llvm.org/docs/ClangFormat.html) to k
- **Emacs:** see [clang-format.el](https://clang.llvm.org/docs/ClangFormat.html#emacs-integration).
To reformat a file manually: `clang-format -i path/to/file.c`. CI runs `clang-format --dry-run --Werror` on every PR and rejects any drift, so PRs with unformatted code don't merge.
+Note: The CI uses clang-format major release version 20, so you should format your files with that version.
### Determinism
diff --git a/codes/surrogate/director-client.h b/codes/surrogate/director-client.h
index aaea3d09..fa563c6a 100644
--- a/codes/surrogate/director-client.h
+++ b/codes/surrogate/director-client.h
@@ -125,8 +125,7 @@ extern "C" {
extern void director_lp_register_model(const char*);
-
-
+extern void director_record_external_zmq_latency(double processing_sec, double total_sec);
/*
extern void director_parse_args(char *args, int **args_array, int *length);
static void director_issue_codes_event(director_state * s, tw_lpid nw_lpid, int dir_registered_event_type, tw_stime ts, tw_lp* lp);
@@ -142,5 +141,4 @@ extern void dir_test_finalize(director_state* s, tw_lp* lp);
#ifdef __cplusplus
}
#endif
-
#endif
diff --git a/doc/example/kb.dfdally-72-zeromq-director.conf.in b/doc/example/kb.dfdally-72-zeromq-director.conf.in
index 656959c4..fdb77ec7 100644
--- a/doc/example/kb.dfdally-72-zeromq-director.conf.in
+++ b/doc/example/kb.dfdally-72-zeromq-director.conf.in
@@ -23,19 +23,19 @@ LPGROUPS
DIRECTOR
{
- start_iter="${DIRECTOR_START_ITER}";
- end_iter="${DIRECTOR_END_ITER}";
+ start_iter="${START_ITER}";
+ end_iter="${END_ITER}";
# Optional one-shot pause/retrain/resume pipeline.
# First implementation is intended for --synch=1.
- retrain_enabled="${DIRECTOR_RETRAIN_ENABLED}";
- retrain_iter="${DIRECTOR_RETRAIN_ITER}";
- retrain_save_path="${DIRECTOR_RETRAIN_SAVE_PATH}";
+ retrain_enabled="${RETRAIN_ENABLED}";
+ retrain_iter="${RETRAIN_ITER}";
+ retrain_save_path="${RETRAIN_SAVE_PATH}";
# Optional second surrogate window after retraining.
- second_surrogate_enabled="${DIRECTOR_SECOND_SURROGATE_ENABLED}";
- second_start_iter="${DIRECTOR_SECOND_START_ITER}";
- second_end_iter="${DIRECTOR_SECOND_END_ITER}";
+ second_surrogate_enabled="${SECOND_SURROGATE_ENABLED}";
+ second_start_iter="${SECOND_START_ITER}";
+ second_end_iter="${SECOND_END_ITER}";
# Common modes:
#
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 99430538..4e834b90 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -1,3 +1,4 @@
+option(USE_ZMQML "Enable ZeroMQ ML requester support" OFF)
cmake_print_variables(CMAKE_CURRENT_SOURCE_DIR)
find_package(FLEX REQUIRED)
@@ -124,68 +125,38 @@ if(USE_TORCH)
list(APPEND LIBS_TO_LINK ${TORCH_LIBRARIES})
endif()
-# ZMQML / director-client (opt-in). When USE_ZMQML=ON, callers must
-# point ZMQML_BUILD_PATH at a directory containing libzmqmlrequester.so
-# (build it via src/surrogate/zmqml/Makefile, or set ZMQML_BUILD_PATH to
-# wherever you installed it). When OFF (the default), CODES builds with
-# no surrogate/director-client linkage; configs that reference
-# "dir-nw-lp" will fail at runtime because the LP type isn't registered.
-option(USE_ZMQML "Build the director-client + zmqml surrogate integration" OFF)
+# ZMQML requester support
if(USE_ZMQML)
- if(NOT ZMQML_BUILD_PATH)
- message(FATAL_ERROR
- "USE_ZMQML=ON but ZMQML_BUILD_PATH is unset.\n"
- "Build src/surrogate/zmqml/libzmqmlrequester.so first, then "
- "reconfigure with -DZMQML_BUILD_PATH=
.")
- endif()
list(APPEND SRCS surrogate/director-client.C)
-endif()
-add_library(codes STATIC ${SRCS})
-
-list(APPEND LIBS_TO_LINK ${MPI_C_LIBRARIES})
-target_include_directories(codes INTERFACE ${MPI_C_INCLUDE_PATH})
+ if(NOT DEFINED ZMQML_BUILD_PATH)
+ message(FATAL_ERROR "USE_ZMQML is ON, but ZMQML_BUILD_PATH is not defined.")
+ endif()
-# set(LIBS_TO_LINK
-# PkgConfig::ROSS
-# ${DUMPI_LIB}
-# PkgConfig::ARGOBOTS
-# PkgConfig::SWM
-# )
+ if(NOT EXISTS "${ZMQML_BUILD_PATH}/libzmqmlrequester.so")
+ message(FATAL_ERROR "USE_ZMQML is ON, but ${ZMQML_BUILD_PATH}/libzmqmlrequester.so does not exist. Re-run CODES-compile-instructions.sh so the local requester library is built before configuring CODES.")
+ endif()
-#LINK DUMPI
-# target_link_libraries(codes PUBLIC ${DUPMI_LIB})
-if(USE_DUMPI)
- target_include_directories(codes PUBLIC ${DUMPI_INCLUDE})
-endif()
+ pkg_check_modules(PC_ZeroMQ QUIET zmq)
+ find_path(ZeroMQ_INCLUDE_DIR NAMES zmq.hpp PATHS ${PC_ZeroMQ_INCLUDE_DIRS})
+ find_library(ZeroMQ_LIBRARY NAMES zmq PATHS ${PC_ZeroMQ_LIBRARY_DIRS})
-#LINK ARGOBOTS, SWM and UNION
-# target_link_libraries(codes PUBLIC PkgConfig::ARGOBOTS)
-if(USE_ONLINE)
- if(USE_SWM)
- target_include_directories(codes PUBLIC ${ARGOBOTS_INCLUDE_DIRS})
- # target_link_libraries(codes PUBLIC PkgConfig::SWM)
- target_include_directories(codes PUBLIC ${SWM_INCLUDE_DIRS})
+ if(NOT ZeroMQ_LIBRARY)
+ message(FATAL_ERROR "USE_ZMQML is ON, but libzmq was not found.")
endif()
- if(USE_UNION)
- target_include_directories(codes PUBLIC ${ARGOBOTS_INCLUDE_DIRS})
- # target_link_libraries(codes PUBLIC PkgConfig::SWM)
- target_include_directories(codes PUBLIC ${SWM_INCLUDE_DIRS})
- target_include_directories(codes PUBLIC ${UNION_INCLUDE_DIRS})
- endif()
-endif()
-if(USE_ZMQML)
add_library(zmqmlrequester SHARED IMPORTED GLOBAL)
set_target_properties(zmqmlrequester PROPERTIES
IMPORTED_LOCATION "${ZMQML_BUILD_PATH}/libzmqmlrequester.so"
INTERFACE_INCLUDE_DIRECTORIES "${ZMQML_BUILD_PATH}")
- target_compile_definitions(codes PUBLIC USE_ZMQML)
endif()
#LINK ROSS
# target_link_libraries(codes PUBLIC #{pkgcfg_lib_ROSS_ROSS})
# target_link_libraries(codes PUBLIC PkgConfig::ROSS)
+
+add_library(codes ${SRCS})
+
target_include_directories(codes PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}
${ROSS_INCLUDE_DIRS}
@@ -194,11 +165,16 @@ target_include_directories(codes PUBLIC
${PROJECT_SOURCE_DIR}/codes
${PROJECT_SOURCE_DIR}/src
${PROJECT_SOURCE_DIR}/src/modelconfig
- $<$:$>
)
target_link_libraries(codes PUBLIC ${LIBS_TO_LINK})
+if(USE_ZMQML)
+ target_compile_definitions(codes PUBLIC USE_ZMQML)
+ target_include_directories(codes PUBLIC "${ZMQML_BUILD_PATH}" ${ZeroMQ_INCLUDE_DIR})
+ target_link_libraries(codes PUBLIC zmqmlrequester ${ZeroMQ_LIBRARY})
+endif()
+
get_target_property(CODES_INCLUDE_DIRS codes INCLUDE_DIRECTORIES)
cmake_print_variables(CODES_INCLUDE_DIRS)
@@ -227,18 +203,12 @@ if(USE_DUMPI)
list(APPEND CODES_TARGETS model-net-dumpi-traces-dump)
endif()
-# ZMQ — only resolved + linked when USE_ZMQML is on; otherwise nothing
-# in the codes library calls into libzmq.
-if(USE_ZMQML)
- pkg_check_modules(PC_ZeroMQ QUIET zmq)
- find_path(ZeroMQ_INCLUDE_DIR NAMES zmq.hpp PATHS ${PC_ZeroMQ_INCLUDE_DIRS})
- find_library(ZeroMQ_LIBRARY NAMES zmq PATHS ${PC_ZeroMQ_LIBRARY_DIRS})
-endif()
-
foreach(tar IN LISTS CODES_TARGETS)
target_include_directories(${tar} PUBLIC ${CODES_INCLUDE_DIRS} ${ROSS_INCLUDE_DIRS})
target_link_libraries(${tar} PUBLIC codes ${LIBS_TO_LINK})
+
if(USE_ZMQML)
+ target_include_directories(${tar} PUBLIC "${ZMQML_BUILD_PATH}" ${ZeroMQ_INCLUDE_DIR})
target_link_libraries(${tar} PUBLIC zmqmlrequester ${ZeroMQ_LIBRARY})
endif()
endforeach()
diff --git a/src/networks/model-net/dragonfly-dally.C b/src/networks/model-net/dragonfly-dally.C
index b4c064cf..d5470dfa 100644
--- a/src/networks/model-net/dragonfly-dally.C
+++ b/src/networks/model-net/dragonfly-dally.C
@@ -22,6 +22,7 @@
#include "codes/model-net-method.h"
#include "codes/model-net-lp.h"
#include "codes/surrogate/init.h"
+#include "codes/surrogate/director-client.h"
#ifdef USE_TORCH
#include "codes/surrogate/packet-latency-predictor/torch-jit.h"
#endif
@@ -64,11 +65,13 @@
* resolve it to null.)
*/
#ifdef USE_ZMQML
-extern std::vector zmqml_director_request(const std::string& surrogate_family,
- const std::string& surrogate_backend,
- const std::string& operation,
- const std::vector& args,
- const std::string& bindata);
+extern std::vector
+zmqml_director_request(const std::string& surrogate_family, const std::string& surrogate_backend,
+ const std::string& operation, const std::vector& args,
+ const std::string& bindata) __attribute__((weak));
+
+extern "C" void director_record_external_zmq_latency(double processing_sec, double total_sec)
+ __attribute__((weak));
extern void director_record_zmq_latency_stats(const char* label,
const std::vector& ret,
@@ -297,7 +300,20 @@ static std::vector dfdally_event_time_director_request_with_latency
(double)(finish.tv_nsec - start.tv_nsec) / 1000000000.0;
#ifdef USE_ZMQML
- director_record_zmq_latency_stats(label, ret, local_latency_sec);
+ double zmq_processing_time = 0.0;
+
+ if (ret.size() > 1) {
+ char* endptr = NULL;
+ double parsed = strtod(ret[1].c_str(), &endptr);
+
+ if (endptr != ret[1].c_str() && isfinite(parsed) && parsed >= 0.0) {
+ zmq_processing_time = parsed;
+ }
+ }
+
+ if (director_record_external_zmq_latency) {
+ director_record_external_zmq_latency(zmq_processing_time, local_latency_sec);
+ }
#endif
return ret;
@@ -2695,14 +2711,8 @@ static void dragonfly_read_config(const char* anno, dragonfly_param* params) {
char event_time_inference_enabled_str[MAX_NAME_LENGTH];
event_time_inference_enabled_str[0] = '\0';
- char const* inferencing_enabled_env = getenv("INFERENCING_ENABLED");
- if (inferencing_enabled_env && strlen(inferencing_enabled_env) > 0) {
- snprintf(event_time_inference_enabled_str, sizeof(event_time_inference_enabled_str), "%s",
- inferencing_enabled_env);
- } else {
- configuration_get_value(&config, "DIRECTOR", "inferencing_enabled", anno,
- event_time_inference_enabled_str, MAX_NAME_LENGTH);
- }
+ configuration_get_value(&config, "DIRECTOR", "inferencing_enabled", anno,
+ event_time_inference_enabled_str, MAX_NAME_LENGTH);
/*
* Do not expose a separate event-time inference flag.
@@ -2737,14 +2747,8 @@ static void dragonfly_read_config(const char* anno, dragonfly_param* params) {
char event_time_training_enabled_str[MAX_NAME_LENGTH];
event_time_training_enabled_str[0] = '\0';
- char const* training_enabled_env = getenv("TRAINING_ENABLED");
- if (training_enabled_env && strlen(training_enabled_env) > 0) {
- snprintf(event_time_training_enabled_str, sizeof(event_time_training_enabled_str), "%s",
- training_enabled_env);
- } else {
- configuration_get_value(&config, "DIRECTOR", "training_enabled", anno,
- event_time_training_enabled_str, MAX_NAME_LENGTH);
- }
+ configuration_get_value(&config, "DIRECTOR", "training_enabled", anno,
+ event_time_training_enabled_str, MAX_NAME_LENGTH);
event_time_surrogate_family_selected =
strcmp(event_time_surrogate_family_str, "event-time") == 0;
@@ -2765,15 +2769,6 @@ static void dragonfly_read_config(const char* anno, dragonfly_param* params) {
event_time_zmq_flush_registered = 1;
}
- if (dfdally_surrogate_debug_prints) {
- fprintf(stderr,
- "[event-time records] family=%s training_enabled=%s send_to_zmq=%d batch_size=%d\n",
- event_time_surrogate_family_str, event_time_training_enabled_str,
- event_time_training_records_enabled, event_time_zmq_batch_size);
- fflush(stderr);
- }
-
-
// START Surrogate configuration
char enable_str[MAX_NAME_LENGTH];
enable_str[0] = '\0';
@@ -2798,6 +2793,14 @@ static void dragonfly_read_config(const char* anno, dragonfly_param* params) {
dfdally_surrogate_debug_prints = dfdally_string_is_true(debug_prints_str);
+ if (dfdally_surrogate_debug_prints) {
+ fprintf(stderr,
+ "[event-time records] family=%s training_enabled=%s send_to_zmq=%d batch_size=%d\n",
+ event_time_surrogate_family_str, event_time_training_enabled_str,
+ event_time_training_records_enabled, event_time_zmq_batch_size);
+ fflush(stderr);
+ }
+
// if surrogate mode has been set up
if (enable_network_surrogate) {
struct network_surrogate_config surr_conf = {
diff --git a/src/surrogate/director-client.C b/src/surrogate/director-client.C
index 3b115dc0..b6f50757 100644
--- a/src/surrogate/director-client.C
+++ b/src/surrogate/director-client.C
@@ -23,7 +23,7 @@
#define DIR_ZMQ_CMD_LENGTH 64
#define DIR_ZMQ_ARG_LENGTH 2048
-#define DIR_MAX_PREDICTION 5
+#define DIR_MAX_PREDICTION 1
#define DIR_MAX_TRAINING_RECORDS 10
/*
* The Python iteration-time model currently uses history_len=2 and horizon=3,
@@ -86,29 +86,40 @@ std::vector director_client_request_family(const char* surrogate_fa
int surrogate_enabled = 0;
int inferencing_enabled = 1;
-
-void director_record_zmq_latency_stats(const char* label, const std::vector& ret,
- double local_latency_sec) {
- (void)label;
-
+static void director_record_zmq_latency_values(double processing_sec, double total_sec) {
if (evaluate_perf != 1) {
return;
}
- director_zmq_total_elapsed_times.push_back(local_latency_sec);
+ if (!std::isfinite(processing_sec) || processing_sec < 0.0) {
+ processing_sec = 0.0;
+ }
+
+ if (!std::isfinite(total_sec) || total_sec < 0.0) {
+ total_sec = 0.0;
+ }
+ director_zmq_processing_times.push_back(processing_sec);
+ director_zmq_total_elapsed_times.push_back(total_sec);
+}
+
+static void director_record_zmq_latency_stats(const char* label,
+ const std::vector& ret,
+ double local_latency_sec) {
double zmq_processing_time = 0.0;
+
if (ret.size() > 1) {
- try {
- zmq_processing_time = std::stod(ret[1]);
- } catch (...) {
- if (director_debug_prints) {
- fprintf(stderr,
- "[DIR] Warning: could not parse zmq processing time from reply field "
- "ret[1]=%s\n",
- ret[1].c_str());
- fflush(stderr);
- }
+ char* endptr = NULL;
+ double parsed = strtod(ret[1].c_str(), &endptr);
+
+ if (endptr != ret[1].c_str() && std::isfinite(parsed) && parsed >= 0.0) {
+ zmq_processing_time = parsed;
+ } else if (director_debug_prints) {
+ fprintf(stderr,
+ "[DIR] Warning: could not parse zmq processing time from reply field ret[1]=%s "
+ "request=%s\n",
+ ret[1].c_str(), label ? label : "");
+ fflush(stderr);
}
} else if (director_debug_prints) {
fprintf(stderr,
@@ -117,9 +128,12 @@ void director_record_zmq_latency_stats(const char* label, const std::vectordirector_id,
- DIR_MAX_PREDICTION); // num-of-args;num-record
+ sprintf(args, "%d;%llu;%d;", 2, (unsigned long long)s->director_id,
+ DIR_MAX_PREDICTION); // num-of-args;client-id;num-predictions
// The Python side primarily uses records previously sent through
// send-records. Keep the payload empty for now rather than sending
@@ -1246,31 +1260,29 @@ void director_event_handler_commit(director_state* s, tw_bf* bf, director_messag
static void director_reduce_and_print_zmq_latency_stat(const char* stat_name,
const std::vector& local_values) {
- unsigned long long local_count = (unsigned long long)local_values.size();
-
+ unsigned long long local_count = 0;
double local_sum = 0.0;
double local_sq_sum = 0.0;
double local_min = std::numeric_limits::infinity();
double local_max = -std::numeric_limits::infinity();
for (double value : local_values) {
- local_sum += value;
- local_sq_sum += value * value;
-
- if (value < local_min) {
- local_min = value;
+ if (!std::isfinite(value)) {
+ continue;
}
- if (value > local_max) {
- local_max = value;
- }
+ local_count++;
+ local_sum += value;
+ local_sq_sum += value * value;
+ local_min = std::min(local_min, value);
+ local_max = std::max(local_max, value);
}
unsigned long long global_count = 0;
double global_sum = 0.0;
double global_sq_sum = 0.0;
- double global_min = 0.0;
- double global_max = 0.0;
+ double global_min = std::numeric_limits::infinity();
+ double global_max = -std::numeric_limits::infinity();
MPI_Reduce(&local_count, &global_count, 1, MPI_UNSIGNED_LONG_LONG, MPI_SUM, 0, MPI_COMM_CODES);
@@ -1293,14 +1305,14 @@ static void director_reduce_and_print_zmq_latency_stat(const char* stat_name,
double variance = global_sq_sum / (double)global_count - mean * mean;
/*
- * Floating-point roundoff can make variance slightly negative when
- * values are very close together.
+ * Floating-point roundoff can make variance slightly negative when values
+ * are very close together.
*/
if (variance < 0.0 && variance > -1.0e-18) {
variance = 0.0;
}
- double stddev = sqrt(variance);
+ double stddev = variance > 0.0 ? sqrt(variance) : 0.0;
std::cout << std::setprecision(9) << std::fixed << "==DIR_STATS " << stat_name
<< ": requests = " << global_count << ", mean = " << mean << ", min = " << global_min
@@ -1335,6 +1347,7 @@ static void director_print_zmq_latency_stats_once(void) {
director_zmq_total_elapsed_times);
}
+
void director_finalize(director_state* s, tw_lp* lp) {
director_print_zmq_latency_stats_once();
@@ -1365,8 +1378,7 @@ tw_lptype dir_lp = {(init_f)director_init,
extern void director_lp_register_model(const char* dir_lp_name) {
int num_dir_per_mgrp = codes_mapping_get_lp_count("MODELNET_GRP", 1, "dir-nw-lp", NULL, 0);
if (num_dir_per_mgrp > 0) {
- lp_type_register(dir_lp_name, &dir_lp); // DIRECTOR addition - register type
- //printf("\n==DIR: Registered\n");
+ lp_type_register(dir_lp_name, &dir_lp);
}
}
diff --git a/src/surrogate/zmqml/model/mliterationtime.py b/src/surrogate/zmqml/model/mliterationtime.py
index 6fda3461..ce9877ce 100644
--- a/src/surrogate/zmqml/model/mliterationtime.py
+++ b/src/surrogate/zmqml/model/mliterationtime.py
@@ -302,6 +302,51 @@ def _predict_once(self, client_id: int, history: List[float], iteration: int) ->
return np.asarray(cleaned, dtype=np.float64)
+ def _global_fallback_prediction(self, requested_horizon: int) -> List[float]:
+ requested_horizon = max(1, int(requested_horizon))
+
+ recent_values: List[float] = []
+ for model in self.models.values():
+ if model.records:
+ recent_values.extend(model.records[-max(1, self.history_len):])
+
+ recent_values = _as_positive_finite(recent_values)
+
+ if recent_values:
+ value = float(np.median(np.asarray(recent_values, dtype=np.float64)))
+ elif self.y_mean is not None and len(self.y_mean) > 0:
+ value = float(self.y_mean[0])
+ else:
+ # Match the older iteration-time fallback scale instead of using 1.0.
+ value = 2_000_000.0
+
+ if not np.isfinite(value) or value <= 0.0:
+ value = 2_000_000.0
+
+ return [value for _ in range(requested_horizon)]
+
+ def _global_fallback_prediction(self, requested_horizon: int) -> List[float]:
+ requested_horizon = max(1, int(requested_horizon))
+
+ recent_values: List[float] = []
+ for model in self.models.values():
+ if model.records:
+ recent_values.extend(model.records[-max(1, self.history_len):])
+
+ recent_values = _as_positive_finite(recent_values)
+
+ if recent_values:
+ value = float(np.median(np.asarray(recent_values, dtype=np.float64)))
+ elif self.y_mean is not None and len(self.y_mean) > 0:
+ value = float(self.y_mean[0])
+ else:
+ value = 2_000_000.0
+
+ if not np.isfinite(value) or value <= 0.0:
+ value = 2_000_000.0
+
+ return [value for _ in range(requested_horizon)]
+
def predict(self, client_id: int, requested_horizon: int | None = None) -> List[float]:
client_id = int(client_id)
requested_horizon = int(requested_horizon or self.horizon)
@@ -310,7 +355,15 @@ def predict(self, client_id: int, requested_horizon: int | None = None) -> List[
model = self.get(client_id)
if not model.records:
- return model._fallback_prediction(requested_horizon)
+ fallback = self._global_fallback_prediction(requested_horizon)
+ if self.debug:
+ print(
+ "[IterationTimeModelRegistry] predict global-fallback "
+ f"client={client_id} requested_horizon={requested_horizon} "
+ f"trained={int(self.trained)} predictions={fallback}",
+ flush=True,
+ )
+ return fallback
# Predict in chunks if requested_horizon > self.horizon.
out: List[float] = []
@@ -326,6 +379,9 @@ def predict(self, client_id: int, requested_horizon: int | None = None) -> List[
history.append(float(value))
iteration += 1
+ if not out:
+ out = self._global_fallback_prediction(requested_horizon)
+
if self.debug:
print(
"[IterationTimeModelRegistry] predict "
@@ -337,6 +393,7 @@ def predict(self, client_id: int, requested_horizon: int | None = None) -> List[
return out
+
def save(self, path: str) -> None:
out_path = Path(path)
if out_path.parent:
diff --git a/src/surrogate/zmqml/zmqmlserver.py b/src/surrogate/zmqml/zmqmlserver.py
index fb4b7944..4e65588c 100755
--- a/src/surrogate/zmqml/zmqmlserver.py
+++ b/src/surrogate/zmqml/zmqmlserver.py
@@ -20,6 +20,8 @@
from model.mliterationtime import IterationTimeModelRegistry
from model.mleventtime import EventTimeModel
import csv
+import io
+import pickle
from pathlib import Path
import os
@@ -37,19 +39,197 @@
launch_id = count(start=1) # unique for launched thread
launched_threads = {} # id:obj. keep track of active threads. remove the thread once it finished
-training_records = {} # client_id:[]
-iteration_time_models = IterationTimeModelRegistry(
- history_len=int(os.environ.get("ZMQML_ITERATION_HISTORY_LEN", "4")),
- horizon=int(os.environ.get("ZMQML_ITERATION_HORIZON", "30")),
- ridge_alpha=float(os.environ.get("ZMQML_ITERATION_RIDGE_ALPHA", "1.0")),
- train_stride=int(os.environ.get("ZMQML_ITERATION_TRAIN_STRIDE", "3")),
-)
-event_time_model = EventTimeModel(
- min_rows=int(os.environ.get("ZMQML_EVENT_TIME_MIN_ROWS", "32")),
- max_epochs=int(os.environ.get("ZMQML_EVENT_TIME_EPOCHS", "80")),
- lr=float(os.environ.get("ZMQML_EVENT_TIME_LR", "0.001")),
- hidden_dim=int(os.environ.get("ZMQML_EVENT_TIME_HIDDEN_DIM", "64")),
-)
+training_records = {} # client_id -> []
+
+ITERATION_MODEL_KWARGS = {
+ "history_len": int(os.environ.get("ZMQML_ITERATION_HISTORY_LEN", "4")),
+ "horizon": int(os.environ.get("ZMQML_ITERATION_HORIZON", "30")),
+ "ridge_alpha": float(os.environ.get("ZMQML_ITERATION_RIDGE_ALPHA", "1.0")),
+ "train_stride": int(os.environ.get("ZMQML_ITERATION_TRAIN_STRIDE", "3")),
+}
+
+EVENT_TIME_MODEL_KWARGS = {
+ "min_rows": int(os.environ.get("ZMQML_EVENT_TIME_MIN_ROWS", "32")),
+ "max_epochs": int(os.environ.get("ZMQML_EVENT_TIME_EPOCHS", "80")),
+ "lr": float(os.environ.get("ZMQML_EVENT_TIME_LR", "0.001")),
+ "hidden_dim": int(os.environ.get("ZMQML_EVENT_TIME_HIDDEN_DIM", "64")),
+}
+
+DEFAULT_TERMINAL_MODEL_SCOPE = "terminal"
+DEFAULT_TERMINAL_MODEL_KEY = "global"
+
+# The current dragonfly-dally event-time rows use numeric LP-type fields.
+# By default, current_lp_type=0 is treated as a terminal LP; all other LP
+# types get switch-local models. Override if the enum changes.
+EVENT_TIME_TERMINAL_LP_TYPES = {
+ token.strip()
+ for token in os.environ.get("ZMQML_EVENT_TIME_TERMINAL_LP_TYPES", "0").split(",")
+ if token.strip()
+}
+
+
+def normalize_model_identity(model_scope: str | None = None, model_key: str | None = None) -> tuple[str, str, str]:
+ scope = str(model_scope or "").strip()
+ key = str(model_key or "").strip()
+
+ if not scope:
+ scope = DEFAULT_TERMINAL_MODEL_SCOPE
+
+ if scope in ("terminal", "term", "client", "global"):
+ scope = DEFAULT_TERMINAL_MODEL_SCOPE
+ key = DEFAULT_TERMINAL_MODEL_KEY
+ elif scope in ("router", "switch", "switch-lp", "router-lp"):
+ scope = "switch"
+ if not key:
+ key = "unknown"
+ else:
+ if not key:
+ key = DEFAULT_TERMINAL_MODEL_KEY
+
+ model_id = f"{scope}:{key}"
+ return scope, key, model_id
+
+
+def model_identity_from_real_args(
+ real_args: list[str],
+ *,
+ offset: int,
+ default_scope: str = DEFAULT_TERMINAL_MODEL_SCOPE,
+ default_key: str = DEFAULT_TERMINAL_MODEL_KEY,
+) -> tuple[str, str, str]:
+ if len(real_args) >= offset + 2:
+ return normalize_model_identity(real_args[offset], real_args[offset + 1])
+ return normalize_model_identity(default_scope, default_key)
+
+
+
+
+class ScopedEventTimeModelRegistry:
+ def __init__(self, kwargs: dict):
+ self.kwargs = dict(kwargs)
+ self.models: dict[str, EventTimeModel] = {}
+ self.debug = False
+ self.model_versions: dict[str, int] = {}
+
+ def set_debug(self, enabled: bool) -> None:
+ self.debug = bool(enabled)
+ for model in self.models.values():
+ model.set_debug(self.debug)
+
+ def get(self, model_scope: str | None = None, model_key: str | None = None) -> EventTimeModel:
+ _, _, model_id = normalize_model_identity(model_scope, model_key)
+ if model_id not in self.models:
+ model = EventTimeModel(**self.kwargs)
+ model.set_debug(self.debug)
+ self.models[model_id] = model
+ self.model_versions.setdefault(model_id, 0)
+ return self.models[model_id]
+
+ def model_id(self, model_scope: str | None = None, model_key: str | None = None) -> str:
+ return normalize_model_identity(model_scope, model_key)[2]
+
+ def train_or_update(self, model_scope: str | None = None, model_key: str | None = None) -> bool:
+ if model_scope in (None, "", "all", "*"):
+ trained_any = False
+ for model_id, model in self.models.items():
+ trained = model.train_or_update()
+ if trained:
+ self.model_versions[model_id] = self.model_versions.get(model_id, 0) + 1
+ trained_any = trained or trained_any
+ return trained_any
+
+ model_id = self.model_id(model_scope, model_key)
+ model = self.get(model_scope, model_key)
+ trained = model.train_or_update()
+ if trained:
+ self.model_versions[model_id] = self.model_versions.get(model_id, 0) + 1
+ return trained
+
+ @staticmethod
+ def _safe_filename(model_id: str) -> str:
+ return model_id.replace("/", "_").replace(":", "__")
+
+ def save(self, path: str | Path, model_scope: str | None = None, model_key: str | None = None) -> None:
+ path = Path(path)
+
+ if model_scope not in (None, "", "all", "*"):
+ model = self.get(model_scope, model_key)
+ if path.parent:
+ path.parent.mkdir(parents=True, exist_ok=True)
+ model.save(path)
+ return
+
+ path.mkdir(parents=True, exist_ok=True)
+ manifest = {
+ "format": "scoped-event-time-directory-v1",
+ "models": {},
+ }
+
+ for model_id, model in sorted(self.models.items()):
+ filename = self._safe_filename(model_id) + ".pt"
+ model.save(path / filename)
+ manifest["models"][model_id] = filename
+
+ (path / "manifest.json").write_text(json.dumps(manifest, indent=2, sort_keys=True))
+
+ def load(self, path: str | Path, model_scope: str | None = None, model_key: str | None = None) -> None:
+ path = Path(path)
+
+ if path.is_dir():
+ manifest_path = path / "manifest.json"
+ if not manifest_path.exists():
+ raise FileNotFoundError(f"missing scoped event-time manifest: {manifest_path}")
+
+ manifest = json.loads(manifest_path.read_text())
+ self.models.clear()
+ self.model_versions.clear()
+ for model_id, filename in manifest.get("models", {}).items():
+ scope, key = model_id.split(":", 1)
+ model = self.get(scope, key)
+ model.load(path / filename)
+ model.set_debug(self.debug)
+ self.model_versions[model_id] = self.model_versions.get(model_id, 0) + 1
+ return
+
+ # Backward-compatible load of an old single event-time model file.
+ model = self.get(model_scope, model_key)
+ model.load(path)
+ model.set_debug(self.debug)
+ model_id = self.model_id(model_scope, model_key)
+ self.model_versions[model_id] = self.model_versions.get(model_id, 0) + 1
+
+ def status(self, model_scope: str | None = None, model_key: str | None = None) -> dict[str, str]:
+ if model_scope in (None, "", "all", "*"):
+ model_ids = sorted(self.models)
+ else:
+ model_ids = [self.model_id(model_scope, model_key)]
+
+ total_rows = 0
+ trained_models = 0
+ entries = []
+
+ for model_id in model_ids:
+ model = self.models.get(model_id)
+ if model is None:
+ continue
+ total_rows += len(model.rows)
+ trained_models += int(bool(model.trained))
+ entries.append(
+ f"{model_id}:rows={len(model.rows)},trained={int(model.trained)},examples={model.training_examples}"
+ )
+
+ return {
+ "model_count": str(len(model_ids)),
+ "trained_models": str(trained_models),
+ "total_rows": str(total_rows),
+ "models": ";".join(entries),
+ }
+
+
+iteration_time_models = IterationTimeModelRegistry(**ITERATION_MODEL_KWARGS)
+event_time_models = ScopedEventTimeModelRegistry(EVENT_TIME_MODEL_KWARGS)
+
+event_time_model = event_time_models.get()
iteration_model_path = os.environ.get("ZMQML_ITERATION_MODEL_PATH", "").strip()
event_time_model_path = os.environ.get("ZMQML_EVENT_TIME_MODEL_PATH", "").strip()
event_time_record_log_path = os.environ.get("ZMQML_EVENT_TIME_RECORD_LOG_PATH", "").strip()
@@ -58,10 +238,7 @@
app_alloc_path = os.environ.get("ZMQML_APP_ALLOC_PATH", "").strip()
auto_train_on_records = os.environ.get(
- "ZMQML_AUTO_TRAIN_ON_RECORDS", "1"
-).strip().lower() in ("1", "true", "yes", "on")
-event_time_auto_train_on_records = os.environ.get(
- "ZMQML_EVENT_TIME_AUTO_TRAIN_ON_RECORDS", "0"
+ "ZMQML_AUTO_TRAIN_ON_RECORDS", "0"
).strip().lower() in ("1", "true", "yes", "on")
iteration_model_version = 0
@@ -83,10 +260,11 @@
)
if event_time_model_path:
- event_time_model.load(event_time_model_path)
+ event_time_models.load(event_time_model_path)
+ event_time_model = event_time_models.get()
event_time_model_version = 1
print(
- f"[zmqmlserver] loaded event-time model: {event_time_model_path}",
+ f"[zmqmlserver] loaded event-time model(s): {event_time_model_path}",
flush=True,
)
@@ -228,7 +406,7 @@ def set_director_debug_prints(args):
director_debug_prints = raw in ("1", "true", "yes", "on", "enabled")
iteration_time_models.set_debug(director_debug_prints)
- event_time_model.set_debug(director_debug_prints)
+ event_time_models.set_debug(director_debug_prints)
if director_debug_prints:
print(f"[zmqmlserver] director_debug_prints=1", flush=True)
@@ -335,13 +513,17 @@ def receivedata(args, bindata):
#
# receive training records
#
+
def receiverecords(args, bindata):
status = "failed"
st = time.time()
- num_args = int(args[0]) # 1st arg is num of args
- client = int(args[1]) # 2nd arg is client id
- num_records = int(args[2]) # 3rd arg is num records
+ real_args = _real_command_args(args)
+ if len(real_args) < 2:
+ return ("failed", time.time() - st)
+
+ client = int(real_args[0])
+ num_records = int(real_args[1])
records_str = str(bindata.decode('utf-8'))
records_str = records_str.strip()
@@ -361,17 +543,10 @@ def receiverecords(args, bindata):
training_records[client].extend(parsed_records)
- # Keep the raw records available for offline/pretraining workflows.
- # By default this preserves the old behavior and trains immediately.
- # Set ZMQML_AUTO_TRAIN_ON_RECORDS=0 for pure-PDES collection or
- # frozen pretrained inference runs.
if parsed_records:
append_record_log(client, parsed_records)
model = iteration_time_models.get(client)
- # Enrich the ML model with app_id metadata when available.
- # The C++ protocol still sends client + timing values, while the Python
- # server infers app_id from ZMQML_APP_ALLOC_PATH.
app_id = client_app_map.get(client, -1) if "client_app_map" in globals() else -1
if hasattr(model, "set_app_id"):
@@ -400,19 +575,17 @@ def receiverecords(args, bindata):
return (status, elapsed_time)
-#
-# do inference to get predictions
-#
def launch_iteration_time_inferencing(args, bindata):
status = "failed"
st = time.time()
- num_args = int(args[0]) # 1st arg is num of args
- client = int(args[1]) # 2nd arg is client id
- num_steps = int(args[2]) # 3rd arg is num steps to predict
+ real_args = _real_command_args(args)
+ if len(real_args) < 2:
+ return ("failed", time.time() - st, "")
+
+ client = int(real_args[0])
+ num_steps = int(real_args[1])
- # Optional recent-context payload. The normal path uses records previously
- # received through send-records, but accepting context keeps the API flexible.
records_str = str(bindata.decode('utf-8'))
records_str = records_str.strip()
@@ -448,11 +621,6 @@ def launch_iteration_time_inferencing(args, bindata):
elapsed_time = time.time() - st
return (status, elapsed_time, inferences_str)
-
-
-
-
-
def event_time_payload_has_header(payload: str) -> bool:
for line in payload.splitlines():
line = line.strip()
@@ -502,27 +670,101 @@ def append_event_time_record_log(payload: str) -> None:
f.write("\n")
+def event_time_model_identity_from_row(row: dict) -> tuple[str, str, str]:
+ raw_lp_type = str(row.get("current_lp_type", "")).strip()
+ raw_gid = str(row.get("current_lp_gid", "")).strip()
+
+ if raw_lp_type in EVENT_TIME_TERMINAL_LP_TYPES:
+ return normalize_model_identity("terminal", "global")
+
+ return normalize_model_identity("switch", raw_gid or "unknown")
+
+
+def iter_event_time_rows_from_payload(raw_payload: str):
+ payload = event_time_payload_with_header(raw_payload)
+ if not payload.strip():
+ return
+
+ reader = csv.DictReader(io.StringIO(payload))
+ if reader.fieldnames:
+ reader.fieldnames = [str(name).strip().lstrip("#").strip() for name in reader.fieldnames]
+
+ for row in reader:
+ clean = {str(k).strip().lstrip("#").strip(): v for k, v in row.items()}
+ yield clean
+
+
def receive_event_time_records(args, bindata):
st = time.time()
raw_payload = bindata.decode("utf-8", errors="replace").strip()
- payload = event_time_payload_with_header(raw_payload)
- loaded_rows = event_time_model.add_records_text(payload) if payload else 0
+ loaded_rows = 0
+
+ # Important performance rule:
+ # Do NOT call EventTimeModel.add_records_text(...) once per row.
+ # Event-time batches can contain 65K+ rows, and per-row parsing/routing makes
+ # pure-PDES collection much slower than the old single-global model path.
+ #
+ # Instead, parse once, group rows by scoped model id, then call
+ # add_records_text(...) once per scoped model per C++ batch.
+ grouped_rows: dict[str, list[dict]] = {}
+ grouped_identity: dict[str, tuple[str, str]] = {}
+
+ for row in iter_event_time_rows_from_payload(raw_payload) or []:
+ model_scope, model_key, model_id = event_time_model_identity_from_row(row)
+ grouped_rows.setdefault(model_id, []).append(row)
+ grouped_identity[model_id] = (model_scope, model_key)
+
+ per_model_loaded: dict[str, int] = {}
+
+ for model_id, rows in grouped_rows.items():
+ model_scope, model_key = grouped_identity[model_id]
+ model = event_time_models.get(model_scope, model_key)
+
+ if not rows:
+ continue
+
+ # Build one CSV payload for this model. Preserve the field order from
+ # the first row and include any later extra keys defensively.
+ fieldnames = list(rows[0].keys())
+ seen = set(fieldnames)
+ for row in rows[1:]:
+ for key in row.keys():
+ if key not in seen:
+ fieldnames.append(key)
+ seen.add(key)
+
+ buf = io.StringIO()
+ writer = csv.DictWriter(buf, fieldnames=fieldnames, extrasaction="ignore")
+ writer.writeheader()
+ writer.writerows(rows)
+
+ accepted = model.add_records_text(buf.getvalue())
+ loaded_rows += accepted
+ per_model_loaded[model_id] = accepted
+
+ if accepted > 0 and auto_train_on_records:
+ model.train_or_update()
if loaded_rows > 0:
append_event_time_record_log(raw_payload)
- if event_time_auto_train_on_records:
- event_time_model.train_or_update()
- director_debug(
- f"[event-time records] loaded_rows={loaded_rows} "
- f"total_rows={len(event_time_model.rows)} "
- f"trained={int(event_time_model.trained)}"
- )
+ if director_debug_prints:
+ # Keep this compact. Printing the full per_model dict for 100+ switch
+ # models is noisy and can itself become expensive.
+ nonempty_models = sum(1 for v in per_model_loaded.values() if v > 0)
+ min_rows = min(per_model_loaded.values()) if per_model_loaded else 0
+ max_rows = max(per_model_loaded.values()) if per_model_loaded else 0
+ sample = sorted(per_model_loaded.items())[:8]
+ print(
+ f"[event-time records] loaded_rows={loaded_rows} "
+ f"models={nonempty_models} min_rows_per_model={min_rows} "
+ f"max_rows_per_model={max_rows} sample={sample}",
+ flush=True,
+ )
return ("done", time.time() - st, loaded_rows)
-
def load_event_time_records_csv_command(args):
st = time.time()
real_args = _real_command_args(args)
@@ -542,22 +784,44 @@ def load_event_time_records_csv_command(args):
"error": f"event-time records path does not exist: {path}",
}
- loaded_rows = event_time_model.load_csv(path)
+ loaded_rows = 0
+ files = sorted(path.rglob("*")) if path.is_dir() else [path]
- return {
+ for child in files:
+ if not child.is_file() or child.suffix.lower() not in (".csv", ".txt", ".log"):
+ continue
+ status, _et, child_rows = receive_event_time_records(["0"], child.read_text().encode("utf-8"))
+ if status == "done":
+ loaded_rows += int(child_rows)
+
+ ret = {
"status": "done",
"et": str(time.time() - st),
"path": str(path),
"loaded_rows": str(loaded_rows),
- "total_rows": str(len(event_time_model.rows)),
}
+ ret.update(event_time_models.status())
+ return ret
def train_event_time_model_command(args):
global event_time_model_version
st = time.time()
- trained = event_time_model.train_or_update()
+ real_args = _real_command_args(args)
+
+ target = real_args[0] if real_args else "all"
+ if target in ("", "all", "*"):
+ trained = event_time_models.train_or_update("all", "")
+ model_scope = "all"
+ model_key = ""
+ model_id = "all"
+ elif len(real_args) >= 2:
+ model_scope, model_key, model_id = normalize_model_identity(real_args[0], real_args[1])
+ trained = event_time_models.train_or_update(model_scope, model_key)
+ else:
+ model_scope, model_key, model_id = normalize_model_identity("switch", target)
+ trained = event_time_models.train_or_update(model_scope, model_key)
if trained:
event_time_model_version += 1
@@ -565,18 +829,17 @@ def train_event_time_model_command(args):
ret = {
"status": "done" if trained else "failed",
"et": str(time.time() - st),
+ "target": model_id,
"model_version": str(event_time_model_version),
}
- ret.update(event_time_model.status())
+ ret.update(event_time_models.status(model_scope, model_key))
if not trained:
- ret["error"] = "event-time model was not trained; load enough generic event-time rows first"
+ ret["error"] = "event-time model was not trained; load enough scoped event-time rows first"
print(
- f"[event-time model-train-command] trained={int(trained)} "
- f"rows={len(event_time_model.rows)} "
- f"training_examples={event_time_model.training_examples} "
- f"model_version={event_time_model_version}",
+ f"[event-time model-train-command] target={model_id} trained={int(trained)} "
+ f"model_version={event_time_model_version} status={event_time_models.status(model_scope, model_key)}",
flush=True,
)
@@ -595,21 +858,29 @@ def save_event_time_model_command(args):
}
model_path = Path(real_args[0])
- if model_path.parent:
- model_path.parent.mkdir(parents=True, exist_ok=True)
+ target = real_args[1] if len(real_args) >= 2 else "all"
- event_time_model.save(model_path)
+ if target in ("", "all", "*"):
+ event_time_models.save(model_path, "all", "")
+ model_id = "all"
+ elif len(real_args) >= 3:
+ model_scope, model_key, model_id = normalize_model_identity(real_args[1], real_args[2])
+ event_time_models.save(model_path, model_scope, model_key)
+ else:
+ model_scope, model_key, model_id = normalize_model_identity("switch", target)
+ event_time_models.save(model_path, model_scope, model_key)
return {
"status": "done",
"et": str(time.time() - st),
"path": str(model_path),
+ "target": model_id,
"model_version": str(event_time_model_version),
}
def load_event_time_model_command(args):
- global event_time_model_version
+ global event_time_model_version, event_time_model
st = time.time()
real_args = _real_command_args(args)
@@ -629,25 +900,50 @@ def load_event_time_model_command(args):
"error": f"model path does not exist: {model_path}",
}
- event_time_model.load(model_path)
+ target = real_args[1] if len(real_args) >= 2 else "all"
+
+ if target in ("", "all", "*"):
+ event_time_models.load(model_path)
+ model_id = "all"
+ elif len(real_args) >= 3:
+ model_scope, model_key, model_id = normalize_model_identity(real_args[1], real_args[2])
+ event_time_models.load(model_path, model_scope, model_key)
+ else:
+ model_scope, model_key, model_id = normalize_model_identity("switch", target)
+ event_time_models.load(model_path, model_scope, model_key)
+
+ event_time_model = event_time_models.get()
event_time_model_version += 1
return {
"status": "done",
"et": str(time.time() - st),
"path": str(model_path),
+ "target": model_id,
"model_version": str(event_time_model_version),
}
def event_time_model_status_command(args):
st = time.time()
+ real_args = _real_command_args(args)
+
+ if not real_args or real_args[0] in ("", "all", "*"):
+ model_scope = "all"
+ model_key = ""
+ model_id = "all"
+ elif len(real_args) >= 2:
+ model_scope, model_key, model_id = normalize_model_identity(real_args[0], real_args[1])
+ else:
+ model_scope, model_key, model_id = normalize_model_identity("switch", real_args[0])
+
ret = {
"status": "done",
"et": str(time.time() - st),
+ "target": model_id,
"model_version": str(event_time_model_version),
}
- ret.update(event_time_model.status())
+ ret.update(event_time_models.status(model_scope, model_key))
return ret
@@ -663,15 +959,24 @@ def launch_event_time_inferencing(args, bindata):
requested_count = 1
payload = bindata.decode("utf-8", errors="replace").strip()
- predictions = event_time_model.predict_from_text(
+ rows = list(iter_event_time_rows_from_payload(payload) or [])
+
+ if rows:
+ model_scope, model_key, model_id = event_time_model_identity_from_row(rows[0])
+ else:
+ model_scope, model_key, model_id = normalize_model_identity()
+
+ model = event_time_models.get(model_scope, model_key)
+ predictions = model.predict_from_text(
payload,
requested_count=max(1, requested_count),
)
predictions_str = " ".join(str(float(x)) for x in predictions)
director_debug(
- f"[event-time inference] requested_count={requested_count} "
- f"payload_bytes={len(payload)} predictions={predictions_str}"
+ f"[event-time inference] model={model_id} requested_count={requested_count} "
+ f"payload_bytes={len(payload)} trained={int(model.trained)} "
+ f"rows={len(model.rows)} predictions={predictions_str}"
)
return ("done", time.time() - st, predictions_str)
@@ -695,6 +1000,7 @@ def _real_command_args(args):
return out
+
def train_iteration_time_model_command(args):
global iteration_model_version
@@ -875,7 +1181,6 @@ def load_iteration_records_csv_command(args):
"loaded_clients": str(len(loaded_clients)),
}
-
def load_iteration_time_model_command(args):
global iteration_model_version
@@ -916,6 +1221,7 @@ def load_iteration_time_model_command(args):
}
+
def iteration_time_model_status_command(args):
st = time.time()
real_args = _real_command_args(args)
@@ -960,8 +1266,6 @@ def iteration_time_model_status_command(args):
"clients": ";".join(per_client),
}
-
-# Backwards-compatible wrapper for the old command name.
def launch_surrogate_inferencing(args, bindata):
return launch_iteration_time_inferencing(args, bindata)
@@ -991,7 +1295,7 @@ def director_request_command(msg, bindata):
family = str(msg.get("surrogate_family", "iteration-time")).strip()
operation = str(msg.get("operation", "")).strip()
backend = str(msg.get("surrogate_backend", "")).strip()
- args = _real_command_args(msg.get("args", []))
+ args = msg.get("args", [])
operation_aliases = {
"status": "model-status",