Coverage for sparkle/tools/configspace.py: 61%

241 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-09-29 10:17 +0000

1"""Extensions of the ConfigSpace lib.""" 

2 

3from __future__ import annotations 

4from typing_extensions import override, Iterable 

5import ast 

6 

7import numpy as np 

8from typing import Any 

9from ConfigSpace import ConfigurationSpace 

10from ConfigSpace.hyperparameters import Hyperparameter 

11from ConfigSpace.types import Array, Mask, f64 

12 

13from ConfigSpace.conditions import ( 

14 Condition, 

15 AndConjunction, 

16 OrConjunction, 

17 EqualsCondition, 

18 GreaterThanCondition, 

19 InCondition, 

20 LessThanCondition, 

21 NotEqualsCondition, 

22) 

23 

24from ConfigSpace.forbidden import ( 

25 ForbiddenGreaterThanRelation, 

26 ForbiddenLessThanRelation, 

27 ForbiddenClause, 

28 ForbiddenConjunction, 

29 ForbiddenRelation, 

30 ForbiddenInClause, 

31 ForbiddenEqualsClause, 

32 ForbiddenAndConjunction, 

33) 

34 

35_SENTINEL = object() 

36 

37 

def expression_to_configspace(
    expression: str | ast.Module,
    configspace: ConfigurationSpace,
    target_parameter: Hyperparameter = None,
) -> ForbiddenClause | Condition:
    """Convert a logic expression to ConfigSpace expression.

    Args:
        expression: The expression to convert, either as source text or as an
            already parsed ``ast.Module``.
        configspace: The ConfigSpace to use.
        target_parameter: For conditions, will parse the expression as a condition
            under which the parameter will be active.

    Returns:
        Whatever ``recursive_conversion`` produces for the first statement of
        the expression.

    Raises:
        ValueError: If the text can not be parsed, or if the parsed module
            contains no statement at all (e.g. an empty string).
    """
    if isinstance(expression, str):
        source = expression
        try:
            expression = ast.parse(source)
        except Exception as e:
            # Keep the original parser error attached as the cause.
            raise ValueError(f"Could not parse expression: '{source}', {e}") from e
    if isinstance(expression, ast.Module):
        if not expression.body:
            # An empty module would otherwise fail with a bare IndexError.
            raise ValueError("Could not parse expression: no statement found.")
        # Only the first statement of the module is converted.
        expression = expression.body[0]
    return recursive_conversion(
        expression, configspace, target_parameter=target_parameter
    )

61 

62 

def recursive_conversion(
    item: ast.mod,
    configspace: ConfigurationSpace,
    target_parameter: Hyperparameter = None,
) -> ForbiddenClause | Condition:
    """Recursively parse the AST tree to a ConfigSpace expression.

    Args:
        item: The item to parse. A list is accepted only if it holds exactly
            one node, which is then parsed instead.
        configspace: The ConfigSpace to use.
        target_parameter: For conditions, will parse the expression as a condition
            under which the parameter will be active.

    Returns:
        A ConfigSpace expression. Leaf nodes yield intermediate results instead:
        the hyperparameter (or the bare name if unknown) for names, the plain
        value for constants and a list of values for tuples/sets/lists.

    Raises:
        ValueError: For empty or multi-element lists, non-constant collection
            elements, chained comparisons and unsupported nodes or operators.
        NotImplementedError: For binary operations and ``is`` / ``is not``.
    """
    if isinstance(item, list):
        # Exactly one node is required; an empty list used to raise IndexError.
        if len(item) != 1:
            raise ValueError(f"Can not parse list of elements: {item}.")
        item = item[0]
    if isinstance(item, ast.Expr):
        return recursive_conversion(item.value, configspace, target_parameter)
    if isinstance(item, ast.Name):  # Convert to hyperparameter
        hp = configspace.get(item.id)
        return hp if hp is not None else item.id
    if isinstance(item, ast.Constant):
        return item.value
    if isinstance(item, ast.UnaryOp):
        # Signed numeric literals such as ``-1`` are parsed as UnaryOp nodes.
        try:
            return ast.literal_eval(item)
        except (ValueError, TypeError) as e:
            raise ValueError(f"Unsupported type: {item}") from e
    if isinstance(item, (ast.Tuple, ast.Set, ast.List)):
        values = []
        for element in item.elts:
            if isinstance(element, ast.Constant):
                values.append(element.value)
            elif isinstance(element, ast.Name):
                # A bare name is a string value unless it names a parameter.
                if configspace.get(element.id) is not None:
                    raise ValueError(
                        f"Only constants allowed in tuples. Found: {item.elts}"
                    )
                values.append(element.id)  # String value was interpreted as parameter
            elif isinstance(element, ast.UnaryOp):
                # Keep signed literals instead of silently dropping them.
                try:
                    values.append(ast.literal_eval(element))
                except (ValueError, TypeError) as e:
                    raise ValueError(
                        f"Only constants allowed in tuples. Found: {item.elts}"
                    ) from e
            else:
                # Dropping the element would silently change the constraint.
                raise ValueError(
                    f"Only constants allowed in tuples. Found: {item.elts}"
                )
        return values
    if isinstance(item, ast.BinOp):
        raise NotImplementedError("Binary operations not supported by ConfigSpace.")
    if isinstance(item, ast.BoolOp):
        values = [
            recursive_conversion(v, configspace, target_parameter) for v in item.values
        ]
        if isinstance(item.op, ast.Or):
            if target_parameter:
                return OrConjunction(*values)
            return ForbiddenOrConjunction(*values)
        elif isinstance(item.op, ast.And):
            if target_parameter:
                return AndConjunction(*values)
            return ForbiddenAndConjunction(*values)
        else:
            raise ValueError(f"Unknown boolean operator: {item.op}")
    if isinstance(item, ast.Compare):
        if len(item.ops) > 1:
            raise ValueError(f"Only single comparisons allowed. Found: {item.ops}")
        left = recursive_conversion(item.left, configspace, target_parameter)
        right = recursive_conversion(item.comparators, configspace, target_parameter)
        operator = item.ops[0]
        if isinstance(left, Hyperparameter):  # Convert to HP type
            if isinstance(right, Iterable) and not isinstance(right, str):
                right = [type(left.default_value)(v) for v in right]
                # A one-element collection is a scalar unless used with ``in``.
                if len(right) == 1 and not isinstance(operator, ast.In):
                    right = right[0]
            elif isinstance(right, int):
                right = type(left.default_value)(right)

        # NOTE(review): the relation classes below are also built when ``right``
        # is a constant rather than a hyperparameter; confirm this is intended.
        if isinstance(operator, ast.Lt):
            if target_parameter:
                return LessThanCondition(target_parameter, left, right)
            return ForbiddenLessThanRelation(left=left, right=right)
        if isinstance(operator, ast.LtE):
            if target_parameter:
                raise ValueError("LessThanEquals not supported for conditions.")
            return ForbiddenLessThanEqualsRelation(left=left, right=right)
        if isinstance(operator, ast.Gt):
            if target_parameter:
                return GreaterThanCondition(target_parameter, left, right)
            return ForbiddenGreaterThanRelation(left=left, right=right)
        if isinstance(operator, ast.GtE):
            if target_parameter:
                raise ValueError("GreaterThanEquals not supported for conditions.")
            return ForbiddenGreaterThanEqualsRelation(left=left, right=right)
        if isinstance(operator, ast.Eq):
            if target_parameter:
                return EqualsCondition(target_parameter, left, right)
            return ForbiddenEqualsClause(hyperparameter=left, value=right)
        if isinstance(operator, ast.In):
            if target_parameter:
                return InCondition(target_parameter, left, right)
            return ForbiddenInClause(hyperparameter=left, values=right)
        if isinstance(operator, ast.NotEq):
            if target_parameter:
                return NotEqualsCondition(target_parameter, left, right)
            raise ValueError("NotEq operator not supported for ForbiddenClauses.")
        # The following classes do not (yet?) exist in configspace
        if isinstance(operator, ast.NotIn):
            raise ValueError("NotIn operator not supported for ForbiddenClauses.")
        if isinstance(operator, ast.Is):
            raise NotImplementedError("Is operator not supported.")
        if isinstance(operator, ast.IsNot):
            raise NotImplementedError("IsNot operator not supported.")
    raise ValueError(f"Unsupported type: {item}")

172 

173 

class ForbiddenLessThanEqualsRelation(ForbiddenLessThanRelation):
    """Forbid the left hyperparameter being less than or equal to the right one."""

    _RELATION_STR = "LESSEQUAL"

    def __repr__(self: ForbiddenLessThanEqualsRelation) -> str:
        """Return a readable ``left <= right`` description of the relation."""
        return f"Forbidden: {self.left.name} <= {self.right.name}"

    @override
    def is_forbidden_value(
        self: ForbiddenLessThanEqualsRelation, values: dict[str, Any]
    ) -> bool:
        """Check actual values; a missing hyperparameter is never forbidden."""
        # The comparison uses the real values, not their vector representation.
        lhs = values.get(self.left.name, _SENTINEL)
        rhs = values.get(self.right.name, _SENTINEL)
        if lhs is _SENTINEL or rhs is _SENTINEL:
            return False
        return lhs <= rhs  # type: ignore

    @override
    def is_forbidden_vector(
        self: ForbiddenLessThanEqualsRelation, vector: Array[f64]
    ) -> bool:
        """Check a single vector; NaN entries are never forbidden."""
        lhs: f64 = vector[self.vector_ids[0]]  # type: ignore
        rhs: f64 = vector[self.vector_ids[1]]  # type: ignore
        if np.isnan(lhs) or np.isnan(rhs):
            return False
        # Map both entries back to actual values before comparing.
        lhs_value = self.left.to_value(lhs)
        rhs_value = self.right.to_value(rhs)
        return lhs_value <= rhs_value  # type: ignore

    @override
    def is_forbidden_vector_array(
        self: ForbiddenLessThanEqualsRelation, arr: Array[f64]
    ) -> Mask:
        """Return a mask of forbidden entries; NaN positions stay False."""
        lhs = arr[self.vector_ids[0]]
        rhs = arr[self.vector_ids[1]]
        comparable = ~np.isnan(lhs) & ~np.isnan(rhs)
        mask = np.zeros_like(comparable)
        lhs_values = self.left.to_value(lhs[comparable])
        rhs_values = self.right.to_value(rhs[comparable])
        mask[comparable] = lhs_values <= rhs_values
        return mask

222 

223 

class ForbiddenGreaterThanEqualsRelation(ForbiddenGreaterThanRelation):
    """Forbid the left hyperparameter being greater than or equal to the right one."""

    _RELATION_STR = "GREATEREQUAL"

    def __repr__(self: ForbiddenGreaterThanEqualsRelation) -> str:
        """Return a readable ``left >= right`` description of the relation."""
        return f"Forbidden: {self.left.name} >= {self.right.name}"

    @override
    def is_forbidden_value(
        self: ForbiddenGreaterThanEqualsRelation, values: dict[str, Any]
    ) -> bool:
        """Check actual values; a missing hyperparameter is never forbidden."""
        lhs = values.get(self.left.name, _SENTINEL)
        rhs = values.get(self.right.name, _SENTINEL)
        if lhs is _SENTINEL or rhs is _SENTINEL:
            return False
        return lhs >= rhs  # type: ignore

    @override
    def is_forbidden_vector(
        self: ForbiddenGreaterThanEqualsRelation, vector: Array[f64]
    ) -> bool:
        """Check a single vector; NaN entries are never forbidden."""
        lhs: f64 = vector[self.vector_ids[0]]  # type: ignore
        rhs: f64 = vector[self.vector_ids[1]]  # type: ignore
        if np.isnan(lhs) or np.isnan(rhs):
            return False
        # The comparison uses the real values, not their vector representation.
        lhs_value = self.left.to_value(lhs)
        rhs_value = self.right.to_value(rhs)
        return lhs_value >= rhs_value  # type: ignore

    @override
    def is_forbidden_vector_array(
        self: ForbiddenGreaterThanEqualsRelation, arr: Array[f64]
    ) -> Mask:
        """Return a mask of forbidden entries; NaN positions stay False."""
        lhs = arr[self.vector_ids[0]]
        rhs = arr[self.vector_ids[1]]
        comparable = ~np.isnan(lhs) & ~np.isnan(rhs)
        mask = np.zeros_like(comparable)
        lhs_values = self.left.to_value(lhs[comparable])
        rhs_values = self.right.to_value(rhs[comparable])
        mask[comparable] = lhs_values >= rhs_values
        return mask

271 

272 

class ForbiddenGreaterThanClause(ForbiddenEqualsClause):
    """A ForbiddenGreaterThanClause.

    It forbids a value from the value range of a hyperparameter to be
    *greater than* `value`.

    Example: forbids the value of the hyperparameter *a* to be greater than 2.

    Args:
        hyperparameter: Hyperparameter on which a restriction will be made
        value: threshold; every value greater than it is forbidden
    """

    def __repr__(self: ForbiddenGreaterThanClause) -> str:
        """Return a string representation of the ForbiddenGreaterThanClause."""
        return f"Forbidden: {self.hyperparameter.name} > {self.value!r}"

    @override
    def is_forbidden_value(
        self: ForbiddenGreaterThanClause, values: dict[str, Any]
    ) -> bool:
        """Check if the value is forbidden.

        A hyperparameter that is absent from ``values`` is never forbidden.
        """
        value = values.get(self.hyperparameter.name, _SENTINEL)
        if value is _SENTINEL:
            # Ordering the sentinel object against a value raises TypeError.
            return False
        return value > self.value  # type: ignore

    @override
    def is_forbidden_vector(
        self: ForbiddenGreaterThanClause, vector: Array[f64]
    ) -> bool:
        """Check if the vector is forbidden."""
        return vector[self.vector_id] > self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(
        self: ForbiddenGreaterThanClause, arr: Array[f64]
    ) -> Mask:
        """Check if the vector array is forbidden."""
        return np.greater(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenGreaterThanClause) -> dict[str, Any]:
        """Convert the ForbiddenGreaterThanClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "GREATER",
            "value": self.value,
        }

321 

322 

class ForbiddenGreaterEqualsClause(ForbiddenEqualsClause):
    """A ForbiddenGreaterEqualsClause.

    It forbids a value from the value range of a hyperparameter to be
    *greater or equal to* `value`.

    Example: forbids the value of the hyperparameter *a* to be greater or
    equal to 2.

    Args:
        hyperparameter: Hyperparameter on which a restriction will be made
        value: threshold; every value greater than or equal to it is forbidden
    """

    def __repr__(self: ForbiddenGreaterEqualsClause) -> str:
        """Return a string representation of the ForbiddenGreaterEqualsClause."""
        return f"Forbidden: {self.hyperparameter.name} >= {self.value!r}"

    @override
    def is_forbidden_value(
        self: ForbiddenGreaterEqualsClause, values: dict[str, Any]
    ) -> bool:
        """Check if the value is forbidden.

        A hyperparameter that is absent from ``values`` is never forbidden.
        """
        value = values.get(self.hyperparameter.name, _SENTINEL)
        if value is _SENTINEL:
            # Ordering the sentinel object against a value raises TypeError.
            return False
        return value >= self.value  # type: ignore

    @override
    def is_forbidden_vector(
        self: ForbiddenGreaterEqualsClause, vector: Array[f64]
    ) -> bool:
        """Check if the vector is forbidden."""
        return vector[self.vector_id] >= self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(
        self: ForbiddenGreaterEqualsClause, arr: Array[f64]
    ) -> Mask:
        """Check if the vector array is forbidden."""
        return np.greater_equal(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenGreaterEqualsClause) -> dict[str, Any]:
        """Convert the ForbiddenGreaterEqualsClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "GREATEREQUAL",
            "value": self.value,
        }

371 

372 

class ForbiddenLessThanClause(ForbiddenEqualsClause):
    """A ForbiddenLessThanClause.

    It forbids a value from the value range of a hyperparameter to be
    *less than* `value`.

    Args:
        hyperparameter: Hyperparameter on which a restriction will be made
        value: threshold; every value less than it is forbidden
    """

    def __repr__(self: ForbiddenLessThanClause) -> str:
        """Return a string representation of the ForbiddenLessThanClause."""
        return f"Forbidden: {self.hyperparameter.name} < {self.value!r}"

    @override
    def is_forbidden_value(
        self: ForbiddenLessThanClause, values: dict[str, Any]
    ) -> bool:
        """Check if the value is forbidden.

        A hyperparameter that is absent from ``values`` is never forbidden.
        """
        value = values.get(self.hyperparameter.name, _SENTINEL)
        if value is _SENTINEL:
            # Ordering the sentinel object against a value raises TypeError.
            return False
        return value < self.value  # type: ignore

    @override
    def is_forbidden_vector(self: ForbiddenLessThanClause, vector: Array[f64]) -> bool:
        """Check if the vector is forbidden."""
        return vector[self.vector_id] < self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(
        self: ForbiddenLessThanClause, arr: Array[f64]
    ) -> Mask:
        """Check if the vector array is forbidden."""
        return np.less(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenLessThanClause) -> dict[str, Any]:
        """Convert the ForbiddenLessThanClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "LESS",
            "value": self.value,
        }

417 

418 

class ForbiddenLessEqualsClause(ForbiddenEqualsClause):
    """A ForbiddenLessEqualsClause.

    It forbids a value from the value range of a hyperparameter to be
    *less or equal to* `value`.

    Args:
        hyperparameter: Hyperparameter on which a restriction will be made
        value: threshold; every value less than or equal to it is forbidden
    """

    def __repr__(self: ForbiddenLessEqualsClause) -> str:
        """Return a string representation of the ForbiddenLessEqualsClause."""
        return f"Forbidden: {self.hyperparameter.name} <= {self.value!r}"

    @override
    def is_forbidden_value(
        self: ForbiddenLessEqualsClause, values: dict[str, Any]
    ) -> bool:
        """Check if the value is forbidden.

        A hyperparameter that is absent from ``values`` is never forbidden.
        """
        value = values.get(self.hyperparameter.name, _SENTINEL)
        if value is _SENTINEL:
            # Ordering the sentinel object against a value raises TypeError.
            return False
        return value <= self.value  # type: ignore

    @override
    def is_forbidden_vector(self: ForbiddenLessEqualsClause, vector: Array[f64]) -> bool:
        """Check if the vector is forbidden."""
        return vector[self.vector_id] <= self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(
        self: ForbiddenLessEqualsClause, arr: Array[f64]
    ) -> Mask:
        """Check if the vector array is forbidden."""
        # Must be less_equal to agree with is_forbidden_vector and __repr__.
        return np.less_equal(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenLessEqualsClause) -> dict[str, Any]:
        """Convert the ForbiddenLessEqualsClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "LESSEQUAL",
            "value": self.value,
        }

463 

464 

class ForbiddenOrConjunction(ForbiddenConjunction):
    """A ForbiddenOrConjunction.

    Combines forbidden clauses so that a configuration is forbidden as soon as
    at least one of the combined clauses forbids it.

    ```python exec="true", source="material-block" result="python"
    from ConfigSpace import (
        ConfigurationSpace,
        ForbiddenEqualsClause,
        ForbiddenInClause,
    )
    from sparkle.tools.configspace import ForbiddenOrConjunction

    cs = ConfigurationSpace({"a": [1, 2, 3], "b": [2, 5, 6]})
    forbidden_clause_a = ForbiddenEqualsClause(cs["a"], 2)
    forbidden_clause_b = ForbiddenInClause(cs["b"], [2])

    forbidden_clause = ForbiddenOrConjunction(forbidden_clause_a, forbidden_clause_b)

    cs.add(forbidden_clause)
    print(cs)
    ```

    Args:
        *args: forbidden clauses, which should be combined
    """

    components: tuple[ForbiddenClause | ForbiddenConjunction | ForbiddenRelation, ...]
    """Components of the conjunction."""

    dlcs: tuple[ForbiddenClause | ForbiddenRelation, ...]
    """Descendant literal clauses of the conjunction.

    These are the base forbidden clauses/relations that are part of conjunctions.

    !!! note

        This will only store a unique set of the descendant clauses, no duplicates.
    """

    def __repr__(self: ForbiddenOrConjunction) -> str:
        """Return the components joined by ``||`` inside parentheses."""
        joined = " || ".join([str(component) for component in self.components])
        return "(" + joined + ")"

    @override
    def is_forbidden_value(self: ForbiddenOrConjunction, values: dict[str, Any]) -> bool:
        """Return True if any component forbids the given values."""
        # Every component is evaluated before the results are combined.
        verdicts = []
        for component in self.components:
            verdicts.append(component.is_forbidden_value(values))
        return any(verdicts)

    @override
    def is_forbidden_vector(self: ForbiddenOrConjunction, vector: Array[f64]) -> bool:
        """Return True if any component forbids the given vector."""
        return any(
            component.is_forbidden_vector(vector) for component in self.components
        )

    @override
    def is_forbidden_vector_array(self: ForbiddenOrConjunction, arr: Array[f64]) -> Mask:
        """Return the element-wise OR of the masks of all components."""
        # Start with nothing forbidden and accumulate each component's mask.
        combined: Mask = np.zeros(shape=arr.shape[1], dtype=np.bool_)
        for component in self.components:
            combined |= component.is_forbidden_vector_array(arr)

        return combined

    @override
    def to_dict(self: ForbiddenOrConjunction) -> dict[str, Any]:
        """Serialise the conjunction and its components to a dictionary."""
        clauses = [component.to_dict() for component in self.components]
        return {
            "type": "OR",
            "clauses": clauses,
        }

538 "clauses": [component.to_dict() for component in self.components], 

539 }