Coverage for sparkle/CLI/compute_features.py: 90%

82 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-05 14:48 +0000

1#!/usr/bin/env python3 

2"""Sparkle command to compute features for instances.""" 

3from __future__ import annotations 

4import sys 

5import argparse 

6from pathlib import Path 

7 

8import runrunner as rrr 

9from runrunner.base import Runner, Status, Run 

10 

11from sparkle.solver import Extractor 

12from sparkle.CLI.help import global_variables as gv 

13from sparkle.CLI.help import logging as sl 

14from sparkle.platform.settings_objects import SettingState 

15from sparkle.CLI.help import argparse_custom as ac 

16from sparkle.platform import COMMAND_DEPENDENCIES, CommandName 

17from sparkle.CLI.initialise import check_for_initialise 

18from sparkle.structures import FeatureDataFrame 

19 

20 

def parser_function() -> argparse.ArgumentParser:
    """Build the argument parser for the compute features command."""
    parser = argparse.ArgumentParser(
        description="Sparkle command to Compute features "
                    "for instances using added extractors "
                    "and instances.")
    # Register each shared CLI argument definition with the parser.
    for argument in (ac.RecomputeFeaturesArgument,
                     ac.SettingsFileArgument,
                     ac.RunOnArgument):
        parser.add_argument(*argument.names, **argument.kwargs)
    return parser

34 

35 

def compute_features(
        feature_data: Path | FeatureDataFrame,
        recompute: bool,
        run_on: Runner = Runner.SLURM) -> Run:
    """Submit feature computation jobs for every instance/extractor pair.

    A RunRunner run is submitted for the computation of the features.
    The results are stored in the CSV file backing the feature data frame.

    Args:
        feature_data: Feature data frame to use, or a path to read it from.
        recompute: If True, reset the data frame so all features are recomputed.
        run_on: On which computer or cluster environment to run the jobs.
            Available: Runner.LOCAL, Runner.SLURM. Default: Runner.SLURM.

    Returns:
        The Slurm or Local run object, or None when there is nothing to do.
    """
    if isinstance(feature_data, Path):
        feature_data = FeatureDataFrame(feature_data)
    if recompute:
        feature_data.reset_dataframe()
    jobs = feature_data.remaining_jobs()

    # Nothing left to compute: inform the user and bail out.
    if not jobs:
        print("No feature computation jobs to run; stopping execution! To recompute "
              "feature values use the --recompute flag.")
        return None
    cutoff = gv.settings().get_general_extractor_cutoff_time()
    features_core = Path(__file__).parent.resolve() / "core" / "compute_features.py"
    commands = []
    extractor_cache = {}
    instances_seen = set()
    # Create one job for each instance/extractor (and, when supported,
    # feature group) combination.
    for instance_path, extractor_name, feature_group in jobs:
        extractor_path = gv.settings().DEFAULT_extractor_dir / extractor_name
        instances_seen.add(instance_path)
        base_cmd = (f"{features_core} "
                    f"--instance {instance_path} "
                    f"--extractor {extractor_path} "
                    f"--feature-csv {feature_data.csv_filepath} "
                    f"--cutoff {cutoff} "
                    f"--log-dir {sl.caller_log_dir}")
        # Instantiate each Extractor only once per extractor name.
        if extractor_name not in extractor_cache:
            extractor_cache[extractor_name] = Extractor(extractor_path)
        extractor = extractor_cache[extractor_name]
        if extractor.groupwise_computation:
            # Extractor job can be parallelised, thus creating i * e * g jobs
            commands.append(base_cmd + f" --feature-group {feature_group}")
        else:
            commands.append(base_cmd)

    print(f"The number of compute jobs: {len(commands)}")

    parallel_jobs = min(len(commands), gv.settings().get_number_of_jobs_in_parallel())
    sbatch_options = gv.settings().get_slurm_extra_options(as_args=True)
    run = rrr.add_to_queue(
        runner=run_on,
        cmd=commands,
        name=f"Compute Features: {len(extractor_cache)} Extractors on "
             f"{len(instances_seen)} instances",
        parallel_jobs=parallel_jobs,
        base_dir=sl.caller_log_dir,
        sbatch_options=sbatch_options,
        srun_options=["-N1", "-n1"] + sbatch_options)

    if run_on == Runner.SLURM:
        print(f"Running the extractors through Slurm with Job IDs: {run.run_id}")
    elif run_on == Runner.LOCAL:
        print("Waiting for the local calculations to finish.")
        # NOTE(review): run.wait() appears to block until the whole run is
        # finished, which would make the per-job progress loop below report
        # only the final count — confirm runrunner semantics before changing.
        run.wait()
        for job in run.jobs:
            jobs_done = sum(j.status == Status.COMPLETED for j in run.jobs)
            print(f"Executing Progress: {jobs_done} out of {len(run.jobs)}")
            if jobs_done == len(run.jobs):
                break
            job.wait()
        print("Computing features done!")

    return run

121 

122 

def main(argv: list[str]) -> None:
    """Main function of the compute features command."""
    # Log command call
    sl.log_command(sys.argv)

    # Define and process the command line arguments
    args = parser_function().parse_args(argv)

    check_for_initialise(COMMAND_DEPENDENCIES[CommandName.COMPUTE_FEATURES])

    # Read the settings file first, so the remaining command line options can
    # override values coming from the file.
    if ac.set_by_user(args, "settings_file"):
        gv.settings().read_settings_ini(args.settings_file, SettingState.CMD_LINE)
    if args.run_on is not None:
        gv.settings().set_run_on(args.run_on.value, SettingState.CMD_LINE)
    run_on = gv.settings().get_run_on()

    # Abort when no feature extractor has been added to the platform yet.
    extractor_dir = gv.settings().DEFAULT_extractor_dir
    if not any(p.is_dir() for p in extractor_dir.iterdir()):
        print("No feature extractors present! Add feature extractors to Sparkle "
              "by using the add_feature_extractor command.")
        sys.exit()

    # Start compute features
    print("Start computing features ...")
    compute_features(gv.settings().DEFAULT_feature_data_path, args.recompute, run_on)

    # Write used settings to file
    gv.settings().write_used_settings()
    sys.exit(0)

158 

159 

if __name__ == "__main__":
    # Script entry point: forward the CLI arguments (without the program name).
    main(sys.argv[1:])