Coverage for src/sparkle/tools/configspace.py: 100%
243 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 15:31 +0000
1"""Extensions of the ConfigSpace lib."""
3from __future__ import annotations
4from typing_extensions import override, Iterable
5import ast
7import numpy as np
8from typing import Any
9from ConfigSpace import ConfigurationSpace
10from ConfigSpace.hyperparameters import Hyperparameter
11from ConfigSpace.types import Array, Mask, f64
13from ConfigSpace.conditions import (
14 Condition,
15 AndConjunction,
16 OrConjunction,
17 EqualsCondition,
18 GreaterThanCondition,
19 InCondition,
20 LessThanCondition,
21 NotEqualsCondition,
22)
24from ConfigSpace.forbidden import (
25 ForbiddenGreaterThanRelation,
26 ForbiddenLessThanRelation,
27 ForbiddenClause,
28 ForbiddenConjunction,
29 ForbiddenRelation,
30 ForbiddenInClause,
31 ForbiddenEqualsClause,
32 ForbiddenAndConjunction,
33)
# Unique marker used as the ``dict.get`` default so that a missing key can be
# told apart from any stored value (including None); checked with ``is``.
_SENTINEL = object()
def expression_to_configspace(
    expression: str | ast.Module,
    configspace: ConfigurationSpace,
    target_parameter: Hyperparameter = None,
) -> ForbiddenClause | Condition:
    """Convert a logic expression to ConfigSpace expression.

    Args:
        expression: The expression to convert, either as source text or as an
            already parsed module. It must contain exactly one statement.
        configspace: The ConfigSpace to use.
        target_parameter: For conditions, will parse the expression as a condition
            under which the parameter will be active.

    Returns:
        The converted ConfigSpace forbidden clause or condition.

    Raises:
        ValueError: If the expression cannot be parsed, is empty, or contains
            more than one statement.
    """
    if isinstance(expression, str):
        try:
            expression = ast.parse(expression)
        except Exception as e:
            # Any parse failure is reported uniformly as a ValueError, with the
            # original exception kept as the cause for debugging.
            raise ValueError(
                f"Could not parse expression: '{expression}', {e}"
            ) from e
    if isinstance(expression, ast.Module):
        # An empty module would otherwise fail with an IndexError, and extra
        # statements would be silently ignored.
        if len(expression.body) != 1:
            raise ValueError(
                "Expected exactly one expression, found "
                f"{len(expression.body)} statements."
            )
        expression = expression.body[0]
    return recursive_conversion(
        expression, configspace, target_parameter=target_parameter
    )
def recursive_conversion(
    item: ast.AST | list[ast.AST],
    configspace: ConfigurationSpace,
    target_parameter: Hyperparameter = None,
) -> ForbiddenClause | Condition:
    """Recursively parse the AST tree to a ConfigSpace expression.

    Args:
        item: The AST node to parse, or a list holding exactly one node.
        configspace: The ConfigSpace used to resolve names to hyperparameters.
        target_parameter: For conditions, will parse the expression as a condition
            under which the parameter will be active. When None, comparisons are
            converted to forbidden clauses/relations instead.

    Returns:
        A ConfigSpace expression. Leaf nodes return a hyperparameter, a constant,
        a bare name string or a list of constants.

    Raises:
        ValueError: If the expression contains unsupported or ambiguous elements.
        NotImplementedError: For operators that ConfigSpace cannot express.
    """
    if isinstance(item, list):
        if len(item) != 1:
            raise ValueError(f"Can not parse list of elements: {item}.")
        item = item[0]
    if isinstance(item, ast.Expr):
        return recursive_conversion(item.value, configspace, target_parameter)
    if isinstance(item, ast.Name):  # Convert to hyperparameter if one exists
        hp = configspace.get(item.id)
        return hp if hp is not None else item.id
    if isinstance(item, ast.Constant):
        return item.value
    if isinstance(item, ast.UnaryOp):
        # Signed numbers such as -1 are parsed as UnaryOp(USub, Constant(1))
        return _convert_signed_literal(item)
    if isinstance(item, (ast.Tuple, ast.Set, ast.List)):
        return _convert_sequence(item, configspace)
    if isinstance(item, ast.BinOp):
        raise NotImplementedError("Binary operations not supported by ConfigSpace.")
    if isinstance(item, ast.BoolOp):
        return _convert_bool_op(item, configspace, target_parameter)
    if isinstance(item, ast.Compare):
        return _convert_compare(item, configspace, target_parameter)
    raise ValueError(f"Unsupported type: {item}")


def _convert_signed_literal(item: ast.UnaryOp) -> Any:
    """Evaluate a unary-operator numeric literal such as ``-1`` or ``+2.5``.

    Raises:
        ValueError: If the node is not a signed numeric literal (e.g. ``not a``).
    """
    try:
        return ast.literal_eval(item)
    except (ValueError, TypeError, SyntaxError) as e:
        raise ValueError(f"Unsupported type: {item}") from e


def _convert_sequence(
    item: ast.Tuple | ast.Set | ast.List, configspace: ConfigurationSpace
) -> list[Any]:
    """Convert a tuple/set/list literal to a list of constant values.

    Bare names are kept as (unquoted) string values, unless the name is a
    hyperparameter of ``configspace``; parameters are not allowed in collections.

    Raises:
        ValueError: If an element is a parameter name or not a constant.
    """
    values = []
    for element in item.elts:
        if isinstance(element, ast.Constant):
            values.append(element.value)
        elif isinstance(element, ast.UnaryOp):
            values.append(_convert_signed_literal(element))
        elif isinstance(element, ast.Name):
            # ConfigurationSpace is a mapping of names, so `in` checks the keys
            if element.id in configspace:
                raise ValueError(
                    f"Only constants allowed in tuples. Found: {item.elts}"
                )
            values.append(element.id)  # String value was parsed as a name
        else:
            # Previously such elements were silently dropped from the result
            raise ValueError(f"Only constants allowed in tuples. Found: {item.elts}")
    return values


def _convert_bool_op(
    item: ast.BoolOp,
    configspace: ConfigurationSpace,
    target_parameter: Hyperparameter | None,
) -> ForbiddenClause | Condition:
    """Convert an ``and``/``or`` expression to a (forbidden) conjunction."""
    values = [
        recursive_conversion(v, configspace, target_parameter) for v in item.values
    ]
    is_condition = target_parameter is not None
    if isinstance(item.op, ast.Or):
        if is_condition:
            return OrConjunction(*values)
        return ForbiddenOrConjunction(*values)
    if isinstance(item.op, ast.And):
        if is_condition:
            return AndConjunction(*values)
        return ForbiddenAndConjunction(*values)
    raise ValueError(f"Unknown boolean operator: {item.op}")


def _coerce_to_hyperparameter_type(
    hyperparameter: Hyperparameter, value: Any, operator: ast.cmpop
) -> Any:
    """Cast a comparison operand to the type of the hyperparameter's default.

    For ``in`` the operand is always made a list, so membership is checked
    against a collection (``hp in ["a"]``) rather than a string (``hp in "a"``).
    For other operators a single-element collection is unwrapped to a scalar.
    """
    if isinstance(operator, ast.In) and (
        not isinstance(value, Iterable) or isinstance(value, str)
    ):
        value = [value]
    if isinstance(value, Iterable) and not isinstance(value, str):
        value = [type(hyperparameter.default_value)(v) for v in value]
        if len(value) == 1 and not isinstance(operator, ast.In):
            value = value[0]
    elif isinstance(value, int):
        value = type(hyperparameter.default_value)(value)
    return value


def _convert_compare(
    item: ast.Compare,
    configspace: ConfigurationSpace,
    target_parameter: Hyperparameter | None,
) -> ForbiddenClause | ForbiddenRelation | Condition:
    """Convert a single comparison to a condition or forbidden clause/relation.

    Raises:
        ValueError: For chained comparisons or operators without a ConfigSpace
            counterpart in the requested mode.
        NotImplementedError: For ``is`` / ``is not``.
    """
    if len(item.ops) > 1:
        raise ValueError(f"Only single comparisons allowed. Found: {item.ops}")
    left = recursive_conversion(item.left, configspace, target_parameter)
    right = recursive_conversion(item.comparators, configspace, target_parameter)
    operator = item.ops[0]
    if isinstance(left, Hyperparameter):
        right = _coerce_to_hyperparameter_type(left, right, operator)
    is_condition = target_parameter is not None
    # Forbidden relations need a hyperparameter on both sides; comparing a
    # hyperparameter with a constant requires the single-value clauses instead.
    use_clause = isinstance(left, Hyperparameter) and not isinstance(
        right, Hyperparameter
    )

    if isinstance(operator, ast.Lt):
        if is_condition:
            return LessThanCondition(target_parameter, left, right)
        if use_clause:
            return ForbiddenLessThanClause(hyperparameter=left, value=right)
        return ForbiddenLessThanRelation(left=left, right=right)
    if isinstance(operator, ast.LtE):
        if is_condition:
            raise ValueError("LessThanEquals not supported for conditions.")
        if use_clause:
            return ForbiddenLessEqualsClause(hyperparameter=left, value=right)
        return ForbiddenLessThanEqualsRelation(left=left, right=right)
    if isinstance(operator, ast.Gt):
        if is_condition:
            return GreaterThanCondition(target_parameter, left, right)
        if use_clause:
            return ForbiddenGreaterThanClause(hyperparameter=left, value=right)
        return ForbiddenGreaterThanRelation(left=left, right=right)
    if isinstance(operator, ast.GtE):
        if is_condition:
            raise ValueError("GreaterThanEquals not supported for conditions.")
        if use_clause:
            return ForbiddenGreaterEqualsClause(hyperparameter=left, value=right)
        return ForbiddenGreaterThanEqualsRelation(left=left, right=right)
    if isinstance(operator, ast.Eq):
        if is_condition:
            return EqualsCondition(target_parameter, left, right)
        return ForbiddenEqualsClause(hyperparameter=left, value=right)
    if isinstance(operator, ast.In):
        if is_condition:
            return InCondition(target_parameter, left, right)
        return ForbiddenInClause(hyperparameter=left, values=right)
    if isinstance(operator, ast.NotEq):
        if is_condition:
            return NotEqualsCondition(target_parameter, left, right)
        raise ValueError("NotEq operator not supported for ForbiddenClauses.")
    # The following classes do not (yet?) exist in configspace
    if isinstance(operator, ast.NotIn):
        raise ValueError("NotIn operator not supported for ForbiddenClauses.")
    if isinstance(operator, ast.Is):
        raise NotImplementedError("Is operator not supported.")
    if isinstance(operator, ast.IsNot):
        raise NotImplementedError("IsNot operator not supported.")
    raise ValueError(f"Unsupported type: {item}")
class ForbiddenLessThanEqualsRelation(ForbiddenLessThanRelation):
    """Forbid configurations in which ``left`` is less than or equal to ``right``.

    Both sides are hyperparameters. If either side has no value (absent from the
    value dict, or NaN in the vector representation) the relation does not
    forbid anything.
    """

    _RELATION_STR = "LESSEQUAL"

    def __repr__(self: ForbiddenLessThanEqualsRelation) -> str:
        """Render the relation as ``Forbidden: <left> <= <right>``."""
        return f"Forbidden: {self.left.name} <= {self.right.name}"

    @override
    def is_forbidden_value(
        self: ForbiddenLessThanEqualsRelation, values: dict[str, Any]
    ) -> bool:
        """Return whether ``left <= right`` holds for the given actual values."""
        # Look up left first, then right, stopping as soon as one is missing.
        operands = []
        for hyperparameter in (self.left, self.right):
            operand = values.get(hyperparameter.name, _SENTINEL)
            if operand is _SENTINEL:
                return False
            operands.append(operand)
        lhs, rhs = operands
        return lhs <= rhs  # type: ignore

    @override
    def is_forbidden_vector(
        self: ForbiddenLessThanEqualsRelation, vector: Array[f64]
    ) -> bool:
        """Return whether ``left <= right`` holds for a single vector."""
        lhs_vec: f64 = vector[self.vector_ids[0]]  # type: ignore
        rhs_vec: f64 = vector[self.vector_ids[1]]  # type: ignore
        if not (np.isnan(lhs_vec) or np.isnan(rhs_vec)):
            # Compare actual values rather than their vector representations.
            lhs = self.left.to_value(lhs_vec)
            rhs = self.right.to_value(rhs_vec)
            return lhs <= rhs  # type: ignore
        return False

    @override
    def is_forbidden_vector_array(
        self: ForbiddenLessThanEqualsRelation, arr: Array[f64]
    ) -> Mask:
        """Return a mask of the columns of ``arr`` forbidden by this relation."""
        lhs_entries = arr[self.vector_ids[0]]
        rhs_entries = arr[self.vector_ids[1]]
        has_both = np.logical_not(np.isnan(lhs_entries) | np.isnan(rhs_entries))
        mask = np.zeros_like(has_both)
        mask[has_both] = self.left.to_value(
            lhs_entries[has_both]
        ) <= self.right.to_value(rhs_entries[has_both])
        return mask
class ForbiddenGreaterThanEqualsRelation(ForbiddenGreaterThanRelation):
    """Forbid configurations in which ``left`` is greater than or equal to ``right``.

    Both sides are hyperparameters. If either side has no value (absent from the
    value dict, or NaN in the vector representation) the relation does not
    forbid anything.
    """

    _RELATION_STR = "GREATEREQUAL"

    def __repr__(self: ForbiddenGreaterThanEqualsRelation) -> str:
        """Render the relation as ``Forbidden: <left> >= <right>``."""
        return f"Forbidden: {self.left.name} >= {self.right.name}"

    @override
    def is_forbidden_value(
        self: ForbiddenGreaterThanEqualsRelation, values: dict[str, Any]
    ) -> bool:
        """Return whether ``left >= right`` holds for the given actual values."""
        # Look up left first, then right, stopping as soon as one is missing.
        operands = []
        for hyperparameter in (self.left, self.right):
            operand = values.get(hyperparameter.name, _SENTINEL)
            if operand is _SENTINEL:
                return False
            operands.append(operand)
        lhs, rhs = operands
        return lhs >= rhs  # type: ignore

    @override
    def is_forbidden_vector(
        self: ForbiddenGreaterThanEqualsRelation, vector: Array[f64]
    ) -> bool:
        """Return whether ``left >= right`` holds for a single vector."""
        lhs_vec: f64 = vector[self.vector_ids[0]]  # type: ignore
        rhs_vec: f64 = vector[self.vector_ids[1]]  # type: ignore
        if not (np.isnan(lhs_vec) or np.isnan(rhs_vec)):
            # Compare actual values rather than their vector representations.
            lhs = self.left.to_value(lhs_vec)
            rhs = self.right.to_value(rhs_vec)
            return lhs >= rhs  # type: ignore
        return False

    @override
    def is_forbidden_vector_array(
        self: ForbiddenGreaterThanEqualsRelation, arr: Array[f64]
    ) -> Mask:
        """Return a mask of the columns of ``arr`` forbidden by this relation."""
        lhs_entries = arr[self.vector_ids[0]]
        rhs_entries = arr[self.vector_ids[1]]
        has_both = np.logical_not(np.isnan(lhs_entries) | np.isnan(rhs_entries))
        mask = np.zeros_like(has_both)
        mask[has_both] = self.left.to_value(
            lhs_entries[has_both]
        ) >= self.right.to_value(rhs_entries[has_both])
        return mask
class ForbiddenGreaterThanClause(ForbiddenEqualsClause):
    """A ForbiddenGreaterThanClause.

    It forbids a value from the value range of a hyperparameter to be
    *greater than* `value`.

    For example, ``ForbiddenGreaterThanClause(cs["a"], 2)`` forbids the value of
    the hyperparameter *a* to be greater than 2.

    Args:
        hyperparameter: The hyperparameter on which the restriction is placed.
        value: The bound; values strictly greater than it are forbidden.
    """

    def __repr__(self: ForbiddenGreaterThanClause) -> str:
        """Return a string representation of the ForbiddenGreaterThanClause."""
        return f"Forbidden: {self.hyperparameter.name} > {self.value!r}"

    @override
    def is_forbidden_value(
        self: ForbiddenGreaterThanClause, values: dict[str, Any]
    ) -> bool:
        """Check if the value is forbidden.

        A hyperparameter without an entry in ``values`` is never forbidden.
        """
        current = values.get(self.hyperparameter.name, _SENTINEL)
        # Ordering the sentinel against self.value would raise a TypeError.
        if current is _SENTINEL:
            return False
        return current > self.value  # type: ignore

    @override
    def is_forbidden_vector(
        self: ForbiddenGreaterThanClause, vector: Array[f64]
    ) -> bool:
        """Check if the vector is forbidden.

        Comparisons with NaN evaluate to False, so NaN entries are never forbidden.
        """
        return vector[self.vector_id] > self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(
        self: ForbiddenGreaterThanClause, arr: Array[f64]
    ) -> Mask:
        """Check which columns of the vector array are forbidden."""
        return np.greater(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenGreaterThanClause) -> dict[str, Any]:
        """Convert the ForbiddenGreaterThanClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "GREATER",
            "value": self.value,
        }
class ForbiddenGreaterEqualsClause(ForbiddenEqualsClause):
    """A ForbiddenGreaterEqualsClause.

    It forbids a value from the value range of a hyperparameter to be
    *greater or equal to* `value`.

    For example, ``ForbiddenGreaterEqualsClause(cs["a"], 2)`` forbids the value of
    the hyperparameter *a* to be greater than or equal to 2.

    Args:
        hyperparameter: The hyperparameter on which the restriction is placed.
        value: The bound; values greater than or equal to it are forbidden.
    """

    def __repr__(self: ForbiddenGreaterEqualsClause) -> str:
        """Return a string representation of the ForbiddenGreaterEqualsClause."""
        return f"Forbidden: {self.hyperparameter.name} >= {self.value!r}"

    @override
    def is_forbidden_value(
        self: ForbiddenGreaterEqualsClause, values: dict[str, Any]
    ) -> bool:
        """Check if the value is forbidden.

        A hyperparameter without an entry in ``values`` is never forbidden.
        """
        current = values.get(self.hyperparameter.name, _SENTINEL)
        # Ordering the sentinel against self.value would raise a TypeError.
        if current is _SENTINEL:
            return False
        return current >= self.value  # type: ignore

    @override
    def is_forbidden_vector(
        self: ForbiddenGreaterEqualsClause, vector: Array[f64]
    ) -> bool:
        """Check if the vector is forbidden.

        Comparisons with NaN evaluate to False, so NaN entries are never forbidden.
        """
        return vector[self.vector_id] >= self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(
        self: ForbiddenGreaterEqualsClause, arr: Array[f64]
    ) -> Mask:
        """Check which columns of the vector array are forbidden."""
        return np.greater_equal(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenGreaterEqualsClause) -> dict[str, Any]:
        """Convert the ForbiddenGreaterEqualsClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "GREATEREQUAL",
            "value": self.value,
        }
class ForbiddenLessThanClause(ForbiddenEqualsClause):
    """A ForbiddenLessThanClause.

    It forbids a value from the value range of a hyperparameter to be
    *less than* `value`.

    For example, ``ForbiddenLessThanClause(cs["a"], 2)`` forbids the value of
    the hyperparameter *a* to be less than 2.

    Args:
        hyperparameter: The hyperparameter on which the restriction is placed.
        value: The bound; values strictly less than it are forbidden.
    """

    def __repr__(self: ForbiddenLessThanClause) -> str:
        """Return a string representation of the ForbiddenLessThanClause."""
        return f"Forbidden: {self.hyperparameter.name} < {self.value!r}"

    @override
    def is_forbidden_value(
        self: ForbiddenLessThanClause, values: dict[str, Any]
    ) -> bool:
        """Check if the value is forbidden.

        A hyperparameter without an entry in ``values`` is never forbidden.
        """
        current = values.get(self.hyperparameter.name, _SENTINEL)
        # Ordering the sentinel against self.value would raise a TypeError.
        if current is _SENTINEL:
            return False
        return current < self.value  # type: ignore

    @override
    def is_forbidden_vector(self: ForbiddenLessThanClause, vector: Array[f64]) -> bool:
        """Check if the vector is forbidden.

        Comparisons with NaN evaluate to False, so NaN entries are never forbidden.
        """
        return vector[self.vector_id] < self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(
        self: ForbiddenLessThanClause, arr: Array[f64]
    ) -> Mask:
        """Check which columns of the vector array are forbidden."""
        return np.less(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenLessThanClause) -> dict[str, Any]:
        """Convert the ForbiddenLessThanClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "LESS",
            "value": self.value,
        }
class ForbiddenLessEqualsClause(ForbiddenEqualsClause):
    """A ForbiddenLessEqualsClause.

    It forbids a value from the value range of a hyperparameter to be
    *less or equal to* `value`.

    For example, ``ForbiddenLessEqualsClause(cs["a"], 2)`` forbids the value of
    the hyperparameter *a* to be less than or equal to 2.

    Args:
        hyperparameter: The hyperparameter on which the restriction is placed.
        value: The bound; values less than or equal to it are forbidden.
    """

    def __repr__(self: ForbiddenLessEqualsClause) -> str:
        """Return a string representation of the ForbiddenLessEqualsClause."""
        return f"Forbidden: {self.hyperparameter.name} <= {self.value!r}"

    @override
    def is_forbidden_value(
        self: ForbiddenLessEqualsClause, values: dict[str, Any]
    ) -> bool:
        """Check if the value is forbidden.

        A hyperparameter without an entry in ``values`` is never forbidden.
        """
        current = values.get(self.hyperparameter.name, _SENTINEL)
        # Ordering the sentinel against self.value would raise a TypeError.
        if current is _SENTINEL:
            return False
        return current <= self.value  # type: ignore

    @override
    def is_forbidden_vector(self: ForbiddenLessEqualsClause, vector: Array[f64]) -> bool:
        """Check if the vector is forbidden.

        Comparisons with NaN evaluate to False, so NaN entries are never forbidden.
        """
        return vector[self.vector_id] <= self.vector_value  # type: ignore

    @override
    def is_forbidden_vector_array(
        self: ForbiddenLessEqualsClause, arr: Array[f64]
    ) -> Mask:
        """Check which columns of the vector array are forbidden."""
        # Must mirror is_forbidden_vector: forbidden when value <= bound.
        return np.less_equal(arr[self.vector_id], self.vector_value, dtype=np.bool_)

    @override
    def to_dict(self: ForbiddenLessEqualsClause) -> dict[str, Any]:
        """Convert the ForbiddenLessEqualsClause to a dictionary."""
        return {
            "name": self.hyperparameter.name,
            "type": "LESSEQUAL",
            "value": self.value,
        }
class ForbiddenOrConjunction(ForbiddenConjunction):
    """A ForbiddenOrConjunction.

    The ForbiddenOrConjunction combines forbidden-clauses, which allows to
    build powerful constraints. A configuration is forbidden as soon as *any*
    of the combined components forbids it.

    ```python exec="true", source="material-block" result="python"
    from ConfigSpace import (
        ConfigurationSpace,
        ForbiddenEqualsClause,
        ForbiddenInClause,
    )
    from sparkle.tools.configspace import ForbiddenOrConjunction

    cs = ConfigurationSpace({"a": [1, 2, 3], "b": [2, 5, 6]})
    forbidden_clause_a = ForbiddenEqualsClause(cs["a"], 2)
    forbidden_clause_b = ForbiddenInClause(cs["b"], [2])

    forbidden_clause = ForbiddenOrConjunction(forbidden_clause_a, forbidden_clause_b)

    cs.add(forbidden_clause)
    print(cs)
    ```

    Args:
        *args: forbidden clauses, which should be combined
    """

    components: tuple[ForbiddenClause | ForbiddenConjunction | ForbiddenRelation, ...]
    """Components of the conjunction."""

    dlcs: tuple[ForbiddenClause | ForbiddenRelation, ...]
    """Descendant literal clauses of the conjunction.

    These are the base forbidden clauses/relations that are part of conjunctions.

    !!! note

        This will only store a unique set of the descendant clauses, no duplicates.
    """

    def __repr__(self: ForbiddenOrConjunction) -> str:
        """Return a string representation of the ForbiddenOrConjunction."""
        return "(" + " || ".join([str(c) for c in self.components]) + ")"

    @override
    def is_forbidden_value(self: ForbiddenOrConjunction, values: dict[str, Any]) -> bool:
        """Check if the value is forbidden by any of the components."""
        # A generator lets any() stop at the first component that forbids.
        return any(
            forbidden.is_forbidden_value(values) for forbidden in self.components
        )

    @override
    def is_forbidden_vector(self: ForbiddenOrConjunction, vector: Array[f64]) -> bool:
        """Check if the vector is forbidden by any of the components."""
        return any(
            forbidden.is_forbidden_vector(vector) for forbidden in self.components
        )

    @override
    def is_forbidden_vector_array(self: ForbiddenOrConjunction, arr: Array[f64]) -> Mask:
        """Check which columns of the vector array are forbidden by any component."""
        # Start from "nothing forbidden" and OR in each component's mask.
        forbidden_mask: Mask = np.zeros(shape=arr.shape[1], dtype=np.bool_)
        for forbidden in self.components:
            forbidden_mask |= forbidden.is_forbidden_vector_array(arr)

        return forbidden_mask

    @override
    def to_dict(self: ForbiddenOrConjunction) -> dict[str, Any]:
        """Convert the ForbiddenOrConjunction to a dictionary."""
        return {
            "type": "OR",
            "clauses": [component.to_dict() for component in self.components],
        }