Coverage for src / sparkle / CLI / compute_features.py: 88%

95 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 15:31 +0000

1#!/usr/bin/env python3 

2"""Sparkle command to compute features for instances.""" 

3 

4from __future__ import annotations 

5import sys 

6import argparse 

7 

8from runrunner.base import Run, Runner 

9 

10from sparkle.selector import Extractor 

11from sparkle.platform.settings_objects import Settings 

12from sparkle.structures import FeatureDataFrame 

13from sparkle.instance import Instance_Set, InstanceSet 

14 

15 

16from sparkle.CLI.help import global_variables as gv 

17from sparkle.CLI.help import logging as sl 

18from sparkle.CLI.help import argparse_custom as ac 

19from sparkle.CLI.initialise import check_for_initialise 

20from sparkle.CLI.help.nicknames import resolve_object_name, resolve_instance_name 

21 

22 

def parser_function() -> argparse.ArgumentParser:
    """Define the command line arguments."""
    parser = argparse.ArgumentParser(
        description="Sparkle command to Compute features "
        "for instances using added extractors "
        "and instances."
    )
    # All plain CLI arguments share the names/kwargs descriptor shape.
    cli_arguments = (
        ac.InstanceSetPathsArgument,
        ac.ExtractorsArgument,
        ac.RecomputeFeaturesArgument,
        # Settings arguments
        ac.SettingsFileArgument,
    )
    for argument in cli_arguments:
        parser.add_argument(*argument.names, **argument.kwargs)
    # The run-on option comes from Settings and uses args/kwargs instead.
    parser.add_argument(*Settings.OPTION_run_on.args, **Settings.OPTION_run_on.kwargs)
    return parser

41 

42 

def compute_features(
    feature_data: FeatureDataFrame,
    recompute: bool,
    run_on: Runner = Runner.SLURM,
) -> list[Run]:
    """Compute features for all instance and feature extractor combinations.

    A RunRunner run is submitted for the computation of the features.
    The results are then stored in the DataFrame behind ``feature_data``.

    Args:
        feature_data: Feature Data Frame to use
        recompute: Specifies if features should be recomputed.
        run_on: Runner
            On which computer or cluster environment to run the solvers.
            Available: Runner.LOCAL, Runner.SLURM. Default: Runner.SLURM

    Returns:
        The list of submitted Slurm or Local jobs
        (empty when there is nothing to compute).
    """
    if recompute:
        feature_data.reset_dataframe()
    jobs = feature_data.remaining_jobs()

    # If there are no jobs, stop before doing any further (wasted) work.
    if not jobs:
        print(
            "No feature computation jobs to run; stopping execution! To recompute "
            "feature values use the --recompute flag."
        )
        # Return an empty list to honour the annotated return type.
        return []

    # Lookup all instance sets to resolve the instance paths later.
    instances: list[InstanceSet] = [
        Instance_Set(instance_dir)
        for instance_dir in gv.settings().DEFAULT_instance_dir.iterdir()
        if instance_dir.is_dir()
    ]

    cutoff = gv.settings().extractor_cutoff_time
    # Group the jobs by extractor / feature group.
    grouped_job_list: dict[str, dict[str, list[str]]] = {}
    for instance_name, extractor_name, feature_group in jobs:
        instance_path = resolve_instance_name(str(instance_name), instances)
        grouped_job_list.setdefault(extractor_name, {}).setdefault(
            feature_group, []
        ).append(instance_path)

    sbatch_options = gv.settings().sbatch_settings
    slurm_prepend = gv.settings().slurm_job_prepend
    srun_options = ["-N1", "-n1"] + sbatch_options
    runs = []
    for extractor_name, feature_groups in grouped_job_list.items():
        extractor_path = gv.settings().DEFAULT_extractor_dir / extractor_name
        extractor = Extractor(extractor_path)
        for feature_group, instance_paths in feature_groups.items():
            run = extractor.run_cli(
                instance_paths,
                feature_data,
                cutoff,
                # Pass the group only when the extractor supports groupwise runs.
                feature_group if extractor.groupwise_computation else None,
                run_on,
                sbatch_options,
                srun_options,
                gv.settings().slurm_jobs_in_parallel,
                slurm_prepend,
                log_dir=sl.caller_log_dir,
            )
            runs.append(run)
    return runs

115 

116 

def main(argv: list[str]) -> None:
    """Main function of the compute features command."""
    # Define command line arguments
    parser = parser_function()

    # Process command line arguments
    args = parser.parse_args(argv)
    settings = gv.settings(args)
    run_on = settings.run_on

    # Log command call
    sl.log_command(sys.argv, settings.random_state)
    check_for_initialise()

    # Check if there are any feature extractors registered
    if not any(p.is_dir() for p in gv.settings().DEFAULT_extractor_dir.iterdir()):
        print(
            "No feature extractors present! Add feature extractors to Sparkle "
            "by using the add_feature_extractor command."
        )
        sys.exit()

    # Load feature data
    feature_data = FeatureDataFrame(settings.DEFAULT_feature_data_path)

    # Filter instances or extractors
    if args.instance_path:
        instances = []
        for instance_arg in args.instance_path:
            instance_set: InstanceSet = resolve_object_name(
                instance_arg,
                gv.instance_set_nickname_mapping,
                settings.DEFAULT_instance_dir,
                Instance_Set,
            )
            if instance_set is None:
                raise ValueError(
                    f"Argument Error! Could not resolve instance: '{instance_arg}'"
                )
            instances.extend(instance_set.instance_names)

        # Drop every instance not selected on the command line.
        for instance in feature_data.instances:
            if instance not in instances:
                feature_data.remove_instances(instance)
        if feature_data.num_instances == 0:
            raise ValueError("Argument Error! No instances left after filtering.")
    if args.extractors:
        extractor_names = []
        for extractor_arg in args.extractors:
            extractor: Extractor = resolve_object_name(
                extractor_arg,
                nickname_dict=gv.extractor_nickname_mapping,
                target_dir=settings.DEFAULT_extractor_dir,
                class_name=Extractor,
            )
            if extractor is None:
                # Report the original argument, not the (None) resolution result.
                raise ValueError(
                    f"Argument Error! Could not resolve extractor: '{extractor_arg}'"
                )
            extractor_names.append(extractor.name)
        # Drop every extractor not selected on the command line.
        for extractor_name in feature_data.extractors:
            if extractor_name not in extractor_names:
                feature_data.remove_extractor(extractor_name)
        if feature_data.num_extractors == 0:
            raise ValueError(
                "Argument Error! No feature extractors left after filtering."
            )

    # Start compute features
    print("Start computing features ...")
    compute_features(feature_data, args.recompute, run_on)

    # Write used settings to file
    gv.settings().write_used_settings()
    sys.exit(0)

193 

194 

if __name__ == "__main__":
    # Script entry point: forward the CLI arguments (without the program name).
    main(sys.argv[1:])