Coverage for src/sparkle/CLI/compute_features.py: 91%
70 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 14:11 +0000
1#!/usr/bin/env python3
2"""Sparkle command to compute features for instances."""
4from __future__ import annotations
5import sys
6import argparse
7from pathlib import Path
9from runrunner.base import Run, Runner
11from sparkle.selector import Extractor
12from sparkle.platform.settings_objects import Settings
13from sparkle.structures import FeatureDataFrame
14from sparkle.instance import Instance_Set, InstanceSet
17from sparkle.CLI.help import global_variables as gv
18from sparkle.CLI.help import logging as sl
19from sparkle.CLI.help import argparse_custom as ac
20from sparkle.CLI.initialise import check_for_initialise
21from sparkle.CLI.help.nicknames import resolve_instance_name
def parser_function() -> argparse.ArgumentParser:
    """Construct and return the argument parser for this command."""
    parser = argparse.ArgumentParser(
        description="Sparkle command to Compute features "
                    "for instances using added extractors "
                    "and instances."
    )
    # Flag to force recomputation of already-known feature values
    parser.add_argument(*ac.RecomputeFeaturesArgument.names,
                        **ac.RecomputeFeaturesArgument.kwargs)
    # Settings arguments
    parser.add_argument(*ac.SettingsFileArgument.names,
                        **ac.SettingsFileArgument.kwargs)
    parser.add_argument(*Settings.OPTION_run_on.args,
                        **Settings.OPTION_run_on.kwargs)
    return parser
def compute_features(
    feature_data: Path | FeatureDataFrame,
    recompute: bool,
    run_on: Runner = Runner.SLURM,
) -> list[Run] | None:
    """Compute features for all instance and feature extractor combinations.

    A RunRunner run is submitted for the computation of the features.
    The results are then stored in the csv file specified by feature_data_csv_path.

    Args:
        feature_data: Feature Data Frame to use, or path to read it from.
        recompute: Specifies if features should be recomputed.
        run_on: Runner
            On which computer or cluster environment to run the solvers.
            Available: Runner.LOCAL, Runner.SLURM. Default: Runner.SLURM

    Returns:
        The list of submitted Slurm/Local jobs, or None when there is
        nothing to compute.
    """
    if isinstance(feature_data, Path):
        feature_data = FeatureDataFrame(feature_data)
    if recompute:
        feature_data.reset_dataframe()
    jobs = feature_data.remaining_jobs()

    # Lookup all instance sets to resolve the instance paths later
    instances: list[InstanceSet] = [
        Instance_Set(instance_dir)
        for instance_dir in gv.settings().DEFAULT_instance_dir.iterdir()
        if instance_dir.is_dir()
    ]

    # If there are no jobs, stop
    if not jobs:
        print(
            "No feature computation jobs to run; stopping execution! To recompute "
            "feature values use the --recompute flag."
        )
        return None
    cutoff = gv.settings().extractor_cutoff_time

    # Group the jobs by extractor/feature group so every (extractor, group)
    # pair is submitted as one run covering all of its instances.
    # BUGFIX: removed the dead `instance_paths = set()` the original created
    # here — it was never used and was immediately shadowed by the loop
    # variable of the submission loop below.
    grouped_job_list: dict[str, dict[str, list[str]]] = {}
    for instance_name, extractor_name, feature_group in jobs:
        feature_groups = grouped_job_list.setdefault(extractor_name, {})
        instance_path = resolve_instance_name(str(instance_name), instances)
        feature_groups.setdefault(feature_group, []).append(instance_path)

    sbatch_options = gv.settings().sbatch_settings
    slurm_prepend = gv.settings().slurm_job_prepend
    srun_options = ["-N1", "-n1"] + sbatch_options
    runs = []
    for extractor_name, feature_groups in grouped_job_list.items():
        extractor_path = gv.settings().DEFAULT_extractor_dir / extractor_name
        extractor = Extractor(extractor_path)
        for feature_group, instance_paths in feature_groups.items():
            run = extractor.run_cli(
                instance_paths,
                feature_data,
                cutoff,
                # Only pass the group when the extractor can compute a
                # single feature group at a time.
                feature_group if extractor.groupwise_computation else None,
                run_on,
                sbatch_options,
                srun_options,
                gv.settings().slurm_jobs_in_parallel,
                slurm_prepend,
                log_dir=sl.caller_log_dir,
            )
            runs.append(run)
    return runs
def main(argv: list[str]) -> None:
    """Main function of the compute features command."""
    # Build the parser and process the command line arguments
    args = parser_function().parse_args(argv)
    settings = gv.settings(args)
    run_on = settings.run_on

    # Log command call
    sl.log_command(sys.argv, settings.random_state)
    check_for_initialise()

    # Check if there are any feature extractors registered
    extractor_candidates = gv.settings().DEFAULT_extractor_dir.iterdir()
    if not any(candidate.is_dir() for candidate in extractor_candidates):
        print(
            "No feature extractors present! Add feature extractors to Sparkle "
            "by using the add_feature_extractor command."
        )
        sys.exit()

    # Start compute features
    print("Start computing features ...")
    compute_features(settings.DEFAULT_feature_data_path, args.recompute, run_on)

    # Write used settings to file
    gv.settings().write_used_settings()
    sys.exit(0)
# Script entry point: forward the CLI arguments (without the program name).
if __name__ == "__main__":
    main(sys.argv[1:])