Coverage for sparkle/CLI/compute_features.py: 92%

72 statements  

coverage.py v7.10.7, created at 2025-09-29 10:17 +0000

#!/usr/bin/env python3
"""Sparkle command to compute features for instances."""

from __future__ import annotations
import sys
import argparse
from pathlib import Path

from runrunner.base import Run, Runner

from sparkle.selector import Extractor
from sparkle.platform.settings_objects import Settings
from sparkle.structures import FeatureDataFrame
from sparkle.instance import Instance_Set, InstanceSet

from sparkle.CLI.help import global_variables as gv
from sparkle.CLI.help import logging as sl
from sparkle.CLI.help import argparse_custom as ac
from sparkle.CLI.initialise import check_for_initialise
from sparkle.CLI.help.nicknames import resolve_instance_name


def parser_function() -> argparse.ArgumentParser:
    """Define the command line arguments."""
    parser = argparse.ArgumentParser(
        description="Sparkle command to compute features for instances "
                    "using the added feature extractors.")
    parser.add_argument(*ac.RecomputeFeaturesArgument.names,
                        **ac.RecomputeFeaturesArgument.kwargs)
    # Settings arguments
    parser.add_argument(*ac.SettingsFileArgument.names,
                        **ac.SettingsFileArgument.kwargs)
    parser.add_argument(*Settings.OPTION_run_on.args,
                        **Settings.OPTION_run_on.kwargs)
    return parser
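
# For illustration, an assumed invocation of the resulting CLI. The exact flag
# spellings are defined in ac.RecomputeFeaturesArgument and
# Settings.OPTION_run_on, so treat the names below as assumptions:
#   $ python -m sparkle.CLI.compute_features --recompute --run-on local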


def compute_features(
        feature_data: Path | FeatureDataFrame,
        recompute: bool,
        run_on: Runner = Runner.SLURM) -> list[Run] | None:
    """Compute features for all instance and feature extractor combinations.

    A RunRunner run is submitted for the computation of the features.
    The results are then stored in the CSV file backing feature_data.

    Args:
        feature_data: Feature data frame to use, or a path to read it from.
        recompute: Specifies whether all features should be recomputed.
        run_on: On which computer or cluster environment to run the extractors.
            Available: Runner.LOCAL, Runner.SLURM. Default: Runner.SLURM.

    Returns:
        The list of Slurm or local runs, or None when there is nothing
        to compute.
    """

    if isinstance(feature_data, Path):
        feature_data = FeatureDataFrame(feature_data)
    if recompute:
        feature_data.reset_dataframe()
    jobs = feature_data.remaining_jobs()

    # Look up all instance sets to resolve the instance paths later
    instances: list[InstanceSet] = []
    for instance_dir in gv.settings().DEFAULT_instance_dir.iterdir():
        if instance_dir.is_dir():
            instances.append(Instance_Set(instance_dir))

    # If there are no jobs, stop
    if not jobs:
        print("No feature computation jobs to run; stopping execution! "
              "To recompute feature values use the --recompute flag.")
        return None

    cutoff = gv.settings().extractor_cutoff_time
    grouped_job_list: dict[str, dict[str, list[str]]] = {}

    # Group the jobs by extractor and feature group
    for instance_name, extractor_name, feature_group in jobs:
        if extractor_name not in grouped_job_list:
            grouped_job_list[extractor_name] = {}
        if feature_group not in grouped_job_list[extractor_name]:
            grouped_job_list[extractor_name][feature_group] = []
        instance_path = resolve_instance_name(str(instance_name), instances)
        grouped_job_list[extractor_name][feature_group].append(instance_path)

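    # For illustration with hypothetical names: grouped_job_list now maps
    # extractor -> feature group -> instance paths, e.g.
    #   {"extractor_a": {"base": ["Instances/SetA/x.cnf", ...],
    #                    "probing": ["Instances/SetA/x.cnf", ...]}}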

    # Cap the number of parallel jobs at the number of pending jobs
    parallel_jobs = min(len(jobs), gv.settings().slurm_jobs_in_parallel)
    sbatch_options = gv.settings().sbatch_settings
    slurm_prepend = gv.settings().slurm_job_prepend
    srun_options = ["-N1", "-n1"] + sbatch_options
    runs = []

    for extractor_name, feature_groups in grouped_job_list.items():
        extractor_path = gv.settings().DEFAULT_extractor_dir / extractor_name
        extractor = Extractor(extractor_path)
        for feature_group, instance_paths in feature_groups.items():
            # Extractors that support groupwise computation are run per
            # feature group; the others compute all their features at once
            run = extractor.run_cli(
                instance_paths,
                feature_data,
                cutoff,
                feature_group if extractor.groupwise_computation else None,
                run_on,
                sbatch_options,
                srun_options,
                parallel_jobs,
                slurm_prepend,
                log_dir=sl.caller_log_dir)
            runs.append(run)
    return runs


def main(argv: list[str]) -> None:
    """Main function of the compute features command."""
    # Define command line arguments
    parser = parser_function()

    # Process command line arguments
    args = parser.parse_args(argv)
    settings = gv.settings(args)
    run_on = settings.run_on

    # Log command call
    sl.log_command(sys.argv, settings.random_state)
    check_for_initialise()

    # Check whether any feature extractors are registered
    if not any(p.is_dir() for p in gv.settings().DEFAULT_extractor_dir.iterdir()):
        print("No feature extractors present! Add feature extractors to Sparkle "
              "by using the add_feature_extractor command.")
        sys.exit()

    # Start computing features
    print("Start computing features ...")
    compute_features(settings.DEFAULT_feature_data_path, args.recompute, run_on)

    # Write used settings to file
    gv.settings().write_used_settings()
    sys.exit(0)


if __name__ == "__main__":
    main(sys.argv[1:])
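
For reference, a minimal sketch of calling compute_features directly from Python rather than through the CLI. It assumes an initialised Sparkle platform with instances and feature extractors already added; the feature-data path below is hypothetical (a real run would use settings.DEFAULT_feature_data_path), and the blocking run.wait() is an assumption about the runrunner Run interface.

from pathlib import Path

from runrunner.base import Runner

from sparkle.CLI.compute_features import compute_features

# Hypothetical path; a real platform would use settings.DEFAULT_feature_data_path.
feature_data_csv = Path("Output/Feature_Data/feature_data.csv")

# Run the extractors locally; returns the runs, or None when nothing is pending.
runs = compute_features(feature_data_csv, recompute=False, run_on=Runner.LOCAL)
if runs:
    for run in runs:
        run.wait()  # assumption: runrunner Run objects expose a blocking wait()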