Coverage for sparkle/tools/configspace.py: 61%
241 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-29 10:17 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-29 10:17 +0000
1"""Extensions of the ConfigSpace lib."""
3from __future__ import annotations
4from typing_extensions import override, Iterable
5import ast
7import numpy as np
8from typing import Any
9from ConfigSpace import ConfigurationSpace
10from ConfigSpace.hyperparameters import Hyperparameter
11from ConfigSpace.types import Array, Mask, f64
13from ConfigSpace.conditions import (
14 Condition,
15 AndConjunction,
16 OrConjunction,
17 EqualsCondition,
18 GreaterThanCondition,
19 InCondition,
20 LessThanCondition,
21 NotEqualsCondition,
22)
24from ConfigSpace.forbidden import (
25 ForbiddenGreaterThanRelation,
26 ForbiddenLessThanRelation,
27 ForbiddenClause,
28 ForbiddenConjunction,
29 ForbiddenRelation,
30 ForbiddenInClause,
31 ForbiddenEqualsClause,
32 ForbiddenAndConjunction,
33)
35_SENTINEL = object()
def expression_to_configspace(
    expression: str | ast.Module,
    configspace: ConfigurationSpace,
    target_parameter: Hyperparameter = None,
) -> ForbiddenClause | Condition:
    """Convert a logic expression to a ConfigSpace expression.

    Args:
        expression: The expression to convert, either as source text or as an
            already parsed ``ast.Module``.
        configspace: The ConfigSpace to use.
        target_parameter: For conditions, will parse the expression as a condition
            under which the parameter will be active.

    Returns:
        Whatever ``recursive_conversion`` produces for the first statement of
        the expression.

    Raises:
        ValueError: If the text can not be parsed, or if the parsed module
            contains no statement at all (e.g. an empty string).
    """
    if isinstance(expression, str):
        try:
            expression = ast.parse(expression)
        except Exception as e:
            # Chain the original error so the parser traceback is not lost.
            raise ValueError(
                f"Could not parse expression: '{expression}', {e}"
            ) from e
    if isinstance(expression, ast.Module):
        if not expression.body:
            # Empty or comment-only input: there is nothing to convert.
            raise ValueError("Could not parse expression: no statement found.")
        # Only the first statement of the module is converted.
        expression = expression.body[0]
    return recursive_conversion(
        expression, configspace, target_parameter=target_parameter
    )
def recursive_conversion(
    item: ast.mod,
    configspace: ConfigurationSpace,
    target_parameter: Hyperparameter = None,
) -> ForbiddenClause | Condition:
    """Recursively parse the AST tree to a ConfigSpace expression.

    Args:
        item: The item to parse. A list holding exactly one node is unwrapped.
        configspace: The ConfigSpace to use.
        target_parameter: For conditions, will parse the expression as a condition
            under which the parameter will be active.

    Returns:
        A ConfigSpace expression for boolean operations and comparisons. Leaf
        nodes yield intermediate values instead: a name resolves to the matching
        hyperparameter (or to the name itself as a string when the space has no
        such hyperparameter), a constant to its value, and a tuple/set/list to a
        list of values.

    Raises:
        ValueError: For lists that do not hold exactly one element, parameters
            or non-literal elements inside a tuple/set/list, chained
            comparisons, and unsupported operators or node types.
        NotImplementedError: For binary operations and ``is`` / ``is not``.
    """
    if isinstance(item, list):
        if len(item) != 1:
            raise ValueError(f"Can not parse list of elements: {item}.")
        item = item[0]
    if isinstance(item, ast.Expr):
        return recursive_conversion(item.value, configspace, target_parameter)
    if isinstance(item, ast.Name):  # Convert to hyperparameter
        hp = configspace.get(item.id)
        return hp if hp is not None else item.id
    if isinstance(item, ast.Constant):
        return item.value
    if isinstance(item, ast.UnaryOp):
        # Signed literals such as ``-1`` are parsed as a UnaryOp around a Constant.
        try:
            return ast.literal_eval(item)
        except (ValueError, TypeError) as e:
            raise ValueError(f"Unsupported type: {item}") from e
    if isinstance(item, (ast.Tuple, ast.Set, ast.List)):
        values = []
        for element in item.elts:
            if isinstance(element, ast.Name):  # Check if its a parameter
                # Look the name up the same way single names are resolved above.
                if configspace.get(element.id) is not None:
                    raise ValueError(
                        f"Only constants allowed in tuples. Found: {item.elts}"
                    )
                # String value was interpreted as parameter
                values.append(element.id)
                continue
            try:
                # Covers plain constants as well as signed literals like ``-1``.
                values.append(ast.literal_eval(element))
            except (ValueError, TypeError) as e:
                raise ValueError(
                    f"Only constants allowed in tuples. Found: {item.elts}"
                ) from e
        return values
    if isinstance(item, ast.BinOp):
        raise NotImplementedError("Binary operations not supported by ConfigSpace.")
    if isinstance(item, ast.BoolOp):
        values = [
            recursive_conversion(v, configspace, target_parameter) for v in item.values
        ]
        if isinstance(item.op, ast.Or):
            if target_parameter:
                return OrConjunction(*values)
            return ForbiddenOrConjunction(*values)
        elif isinstance(item.op, ast.And):
            if target_parameter:
                return AndConjunction(*values)
            return ForbiddenAndConjunction(*values)
        else:
            raise ValueError(f"Unknown boolean operator: {item.op}")
    if isinstance(item, ast.Compare):
        if len(item.ops) > 1:
            raise ValueError(f"Only single comparisons allowed. Found: {item.ops}")
        left = recursive_conversion(item.left, configspace, target_parameter)
        right = recursive_conversion(item.comparators, configspace, target_parameter)
        operator = item.ops[0]
        if isinstance(left, Hyperparameter):  # Convert to HP type
            if isinstance(right, Iterable) and not isinstance(right, str):
                right = [type(left.default_value)(v) for v in right]
                # A one-element collection is a scalar, unless used with ``in``.
                if len(right) == 1 and not isinstance(operator, ast.In):
                    right = right[0]
            elif isinstance(right, int):
                right = type(left.default_value)(right)

        if isinstance(operator, ast.Lt):
            if target_parameter:
                return LessThanCondition(target_parameter, left, right)
            return ForbiddenLessThanRelation(left=left, right=right)
        if isinstance(operator, ast.LtE):
            if target_parameter:
                raise ValueError("LessThanEquals not supported for conditions.")
            return ForbiddenLessThanEqualsRelation(left=left, right=right)
        if isinstance(operator, ast.Gt):
            if target_parameter:
                return GreaterThanCondition(target_parameter, left, right)
            return ForbiddenGreaterThanRelation(left=left, right=right)
        if isinstance(operator, ast.GtE):
            if target_parameter:
                raise ValueError("GreaterThanEquals not supported for conditions.")
            return ForbiddenGreaterThanEqualsRelation(left=left, right=right)
        if isinstance(operator, ast.Eq):
            if target_parameter:
                return EqualsCondition(target_parameter, left, right)
            return ForbiddenEqualsClause(hyperparameter=left, value=right)
        if isinstance(operator, ast.In):
            if target_parameter:
                return InCondition(target_parameter, left, right)
            return ForbiddenInClause(hyperparameter=left, values=right)
        if isinstance(operator, ast.NotEq):
            if target_parameter:
                return NotEqualsCondition(target_parameter, left, right)
            raise ValueError("NotEq operator not supported for ForbiddenClauses.")
        # The following classes do not (yet?) exist in configspace
        if isinstance(operator, ast.NotIn):
            raise ValueError("NotIn operator not supported for ForbiddenClauses.")
        if isinstance(operator, ast.Is):
            raise NotImplementedError("Is operator not supported.")
        if isinstance(operator, ast.IsNot):
            raise NotImplementedError("IsNot operator not supported.")
    raise ValueError(f"Unsupported type: {item}")
class ForbiddenLessThanEqualsRelation(ForbiddenLessThanRelation):
    """Forbidden relation: the left hyperparameter is at most the right one."""

    _RELATION_STR = "LESSEQUAL"

    def __repr__(self: ForbiddenLessThanEqualsRelation) -> str:
        """Describe the relation in a human readable form."""
        return f"Forbidden: {self.left.name} <= {self.right.name}"

    @override
    def is_forbidden_value(
        self: ForbiddenLessThanEqualsRelation, values: dict[str, Any]
    ) -> bool:
        """Tell whether the given values violate the relation.

        Returns False as soon as one of the two hyperparameters has no value.
        """
        # Works on the actual values, never on the vector representation.
        lhs = values.get(self.left.name, _SENTINEL)
        if lhs is not _SENTINEL:
            rhs = values.get(self.right.name, _SENTINEL)
            if rhs is not _SENTINEL:
                return lhs <= rhs  # type: ignore
        return False

    @override
    def is_forbidden_vector(
        self: ForbiddenLessThanEqualsRelation, vector: Array[f64]
    ) -> bool:
        """Tell whether a single vector violates the relation."""
        raw_left: f64 = vector[self.vector_ids[0]]  # type: ignore
        raw_right: f64 = vector[self.vector_ids[1]]  # type: ignore
        if not (np.isnan(raw_left) or np.isnan(raw_right)):
            # Entries are mapped back to actual values before comparing.
            actual_left = self.left.to_value(raw_left)
            actual_right = self.right.to_value(raw_right)
            return actual_left <= actual_right  # type: ignore
        return False

    @override
    def is_forbidden_vector_array(
        self: ForbiddenLessThanEqualsRelation, arr: Array[f64]
    ) -> Mask:
        """Tell, per entry, whether the vector array violates the relation."""
        lhs = arr[self.vector_ids[0]]
        rhs = arr[self.vector_ids[1]]
        # Entries where either side is NaN stay False in the mask.
        usable = ~np.isnan(lhs) & ~np.isnan(rhs)
        mask = np.zeros_like(usable)
        left_values = self.left.to_value(lhs[usable])
        right_values = self.right.to_value(rhs[usable])
        mask[usable] = left_values <= right_values
        return mask
class ForbiddenGreaterThanEqualsRelation(ForbiddenGreaterThanRelation):
    """Forbidden relation: the left hyperparameter is at least the right one."""

    _RELATION_STR = "GREATEREQUAL"

    def __repr__(self: ForbiddenGreaterThanEqualsRelation) -> str:
        """Describe the relation in a human readable form."""
        return f"Forbidden: {self.left.name} >= {self.right.name}"

    @override
    def is_forbidden_value(
        self: ForbiddenGreaterThanEqualsRelation, values: dict[str, Any]
    ) -> bool:
        """Tell whether the given values violate the relation.

        Returns False as soon as one of the two hyperparameters has no value.
        """
        lhs = values.get(self.left.name, _SENTINEL)
        if lhs is not _SENTINEL:
            rhs = values.get(self.right.name, _SENTINEL)
            if rhs is not _SENTINEL:
                return lhs >= rhs  # type: ignore
        return False

    @override
    def is_forbidden_vector(
        self: ForbiddenGreaterThanEqualsRelation, vector: Array[f64]
    ) -> bool:
        """Tell whether a single vector violates the relation."""
        raw_left: f64 = vector[self.vector_ids[0]]  # type: ignore
        raw_right: f64 = vector[self.vector_ids[1]]  # type: ignore
        if not (np.isnan(raw_left) or np.isnan(raw_right)):
            # Entries are mapped back to actual values before comparing.
            actual_left = self.left.to_value(raw_left)
            actual_right = self.right.to_value(raw_right)
            return actual_left >= actual_right  # type: ignore
        return False

    @override
    def is_forbidden_vector_array(
        self: ForbiddenGreaterThanEqualsRelation, arr: Array[f64]
    ) -> Mask:
        """Tell, per entry, whether the vector array violates the relation."""
        lhs = arr[self.vector_ids[0]]
        rhs = arr[self.vector_ids[1]]
        # Entries where either side is NaN stay False in the mask.
        usable = ~np.isnan(lhs) & ~np.isnan(rhs)
        mask = np.zeros_like(usable)
        left_values = self.left.to_value(lhs[usable])
        right_values = self.right.to_value(rhs[usable])
        mask[usable] = left_values >= right_values
        return mask
class ForbiddenGreaterThanClause(ForbiddenEqualsClause):
    """A ForbiddenGreaterThanClause.

    It forbids a value from the value range of a hyperparameter to be
    *greater than* `value`.

    Args:
        hyperparameter: Hyperparameter on which a restriction will be made
        value: forbidden lower bound; any value above it is forbidden
    """

    def __repr__(self: ForbiddenGreaterThanClause) -> str:
        """Return a string representation of the ForbiddenGreaterThanClause."""
        return f"Forbidden: {self.hyperparameter.name} > {self.value!r}"

    @override
    def is_forbidden_value(
        self: ForbiddenGreaterThanClause, values: dict[str, Any]
    ) -> bool:
        """Check if the value is forbidden.

        Returns False when the hyperparameter has no entry in ``values``.
        """
        value = values.get(self.hyperparameter.name, _SENTINEL)
        if value is _SENTINEL:
            # The sentinel is a bare object and can not be order-compared.
            return False
        return value > self.value  # type: ignore

    @override
    def is_forbidden_vector(
        self: ForbiddenGreaterThanClause, vector: Array[f64]
    ) -> bool:
        """Check if the vector is forbidden."""
        return vector[self.vector_id] > self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(
        self: ForbiddenGreaterThanClause, arr: Array[f64]
    ) -> Mask:
        """Check if the vector array is forbidden."""
        return np.greater(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenGreaterThanClause) -> dict[str, Any]:
        """Convert the ForbiddenGreaterThanClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "GREATER",
            "value": self.value,
        }
class ForbiddenGreaterEqualsClause(ForbiddenEqualsClause):
    """A ForbiddenGreaterEqualsClause.

    It forbids a value from the value range of a hyperparameter to be
    *greater or equal to* `value`.

    Args:
        hyperparameter: Hyperparameter on which a restriction will be made
        value: forbidden lower bound; it and any value above it are forbidden
    """

    def __repr__(self: ForbiddenGreaterEqualsClause) -> str:
        """Return a string representation of the ForbiddenGreaterEqualsClause."""
        return f"Forbidden: {self.hyperparameter.name} >= {self.value!r}"

    @override
    def is_forbidden_value(
        self: ForbiddenGreaterEqualsClause, values: dict[str, Any]
    ) -> bool:
        """Check if the value is forbidden.

        Returns False when the hyperparameter has no entry in ``values``.
        """
        value = values.get(self.hyperparameter.name, _SENTINEL)
        if value is _SENTINEL:
            # The sentinel is a bare object and can not be order-compared.
            return False
        return value >= self.value  # type: ignore

    @override
    def is_forbidden_vector(
        self: ForbiddenGreaterEqualsClause, vector: Array[f64]
    ) -> bool:
        """Check if the vector is forbidden."""
        return vector[self.vector_id] >= self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(
        self: ForbiddenGreaterEqualsClause, arr: Array[f64]
    ) -> Mask:
        """Check if the vector array is forbidden."""
        return np.greater_equal(
            arr[self.vector_id], self.vector_value, dtype=np.bool_
        )

    @override
    def to_dict(self: ForbiddenGreaterEqualsClause) -> dict[str, Any]:
        """Convert the ForbiddenGreaterEqualsClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "GREATEREQUAL",
            "value": self.value,
        }
class ForbiddenLessThanClause(ForbiddenEqualsClause):
    """A ForbiddenLessThanClause.

    It forbids a value from the value range of a hyperparameter to be
    *less than* `value`.

    Args:
        hyperparameter: Hyperparameter on which a restriction will be made
        value: forbidden upper bound; any value below it is forbidden
    """

    def __repr__(self: ForbiddenLessThanClause) -> str:
        """Return a string representation of the ForbiddenLessThanClause."""
        return f"Forbidden: {self.hyperparameter.name} < {self.value!r}"

    @override
    def is_forbidden_value(
        self: ForbiddenLessThanClause, values: dict[str, Any]
    ) -> bool:
        """Check if the value is forbidden.

        Returns False when the hyperparameter has no entry in ``values``.
        """
        value = values.get(self.hyperparameter.name, _SENTINEL)
        if value is _SENTINEL:
            # The sentinel is a bare object and can not be order-compared.
            return False
        return value < self.value  # type: ignore

    @override
    def is_forbidden_vector(self: ForbiddenLessThanClause, vector: Array[f64]) -> bool:
        """Check if the vector is forbidden."""
        return vector[self.vector_id] < self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(
        self: ForbiddenLessThanClause, arr: Array[f64]
    ) -> Mask:
        """Check if the vector array is forbidden."""
        return np.less(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenLessThanClause) -> dict[str, Any]:
        """Convert the ForbiddenLessThanClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "LESS",
            "value": self.value,
        }
class ForbiddenLessEqualsClause(ForbiddenEqualsClause):
    """A ForbiddenLessEqualsClause.

    It forbids a value from the value range of a hyperparameter to be
    *less or equal to* `value`.

    Args:
        hyperparameter: Hyperparameter on which a restriction will be made
        value: forbidden upper bound; it and any value below it are forbidden
    """

    def __repr__(self: ForbiddenLessEqualsClause) -> str:
        """Return a string representation of the ForbiddenLessEqualsClause."""
        return f"Forbidden: {self.hyperparameter.name} <= {self.value!r}"

    @override
    def is_forbidden_value(
        self: ForbiddenLessEqualsClause, values: dict[str, Any]
    ) -> bool:
        """Check if the value is forbidden.

        Returns False when the hyperparameter has no entry in ``values``.
        """
        value = values.get(self.hyperparameter.name, _SENTINEL)
        if value is _SENTINEL:
            # The sentinel is a bare object and can not be order-compared.
            return False
        return value <= self.value  # type: ignore

    @override
    def is_forbidden_vector(self: ForbiddenLessEqualsClause, vector: Array[f64]) -> bool:
        """Check if the vector is forbidden."""
        return vector[self.vector_id] <= self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(
        self: ForbiddenLessEqualsClause, arr: Array[f64]
    ) -> Mask:
        """Check if the vector array is forbidden."""
        # Must agree with is_forbidden_vector: less-or-equal, not greater-or-equal.
        return np.less_equal(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenLessEqualsClause) -> dict[str, Any]:
        """Convert the ForbiddenLessEqualsClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "LESSEQUAL",
            "value": self.value,
        }
class ForbiddenOrConjunction(ForbiddenConjunction):
    """A ForbiddenOrConjunction.

    Combines several forbidden-clauses into one; the combination reports
    forbidden as soon as any of its components does.

    ```python exec="true", source="material-block" result="python"
    from ConfigSpace import (
        ConfigurationSpace,
        ForbiddenEqualsClause,
        ForbiddenInClause,
    )
    from sparkle.tools.configspace import ForbiddenOrConjunction

    cs = ConfigurationSpace({"a": [1, 2, 3], "b": [2, 5, 6]})
    forbidden_clause_a = ForbiddenEqualsClause(cs["a"], 2)
    forbidden_clause_b = ForbiddenInClause(cs["b"], [2])

    forbidden_clause = ForbiddenOrConjunction(forbidden_clause_a, forbidden_clause_b)

    cs.add(forbidden_clause)
    print(cs)
    ```

    Args:
        *args: forbidden clauses, which should be combined
    """

    components: tuple[ForbiddenClause | ForbiddenConjunction | ForbiddenRelation, ...]
    """Components of the conjunction."""

    dlcs: tuple[ForbiddenClause | ForbiddenRelation, ...]
    """Descendant literal clauses of the conjunction.

    These are the base forbidden clauses/relations that are part of conjunctions.

    !!! note

        This will only store a unique set of the descendant clauses, no duplicates.
    """

    def __repr__(self: ForbiddenOrConjunction) -> str:
        """Render the components joined by ``||`` inside parentheses."""
        joined = " || ".join(map(str, self.components))
        return f"({joined})"

    @override
    def is_forbidden_value(self: ForbiddenOrConjunction, values: dict[str, Any]) -> bool:
        """Tell whether any component forbids the given values."""
        # Every component is evaluated before the results are combined.
        verdicts = [part.is_forbidden_value(values) for part in self.components]
        return any(verdicts)

    @override
    def is_forbidden_vector(self: ForbiddenOrConjunction, vector: Array[f64]) -> bool:
        """Tell whether any component forbids the given vector."""
        for part in self.components:
            if part.is_forbidden_vector(vector):
                return True
        return False

    @override
    def is_forbidden_vector_array(self: ForbiddenOrConjunction, arr: Array[f64]) -> Mask:
        """Tell, per entry, whether any component forbids the vector array."""
        # Start from all False and OR in the mask of each component.
        combined: Mask = np.zeros(shape=arr.shape[1], dtype=np.bool_)
        for part in self.components:
            combined |= part.is_forbidden_vector_array(arr)

        return combined

    @override
    def to_dict(self: ForbiddenOrConjunction) -> dict[str, Any]:
        """Serialise the conjunction and its components to a dictionary."""
        clauses = []
        for part in self.components:
            clauses.append(part.to_dict())
        return {"type": "OR", "clauses": clauses}