# Source code for rsopt.configuration.parameters

from numpy import ndarray
import numpy as np
from pykern.pkcollections import PKDict

# Keys every parameter specification must provide, in the order they are read.
_EXTERNAL_PARAMETER_CATEGORIES = (
    'min',
    'max',
    'start',
)
# Keys a parameter specification may omit.
_OPTIONAL_PARAMETER_CATEGORIES = (
    'samples',
)


def _validate_parameter(name, min, max, start):
    """Check that a parameter's bounds and start value are consistent.

    :param name: parameter name, used only in error messages.
    :param min: lower bound; must be strictly less than ``max``.
    :param max: upper bound.
    :param start: initial value; must satisfy ``min <= start <= max``.
    :raises AssertionError: if either condition is violated.
    """
    # Explicit raises (not ``assert``) so validation still runs under
    # ``python -O``; AssertionError is kept for callers that already catch it.
    # ``min``/``max`` shadow builtins but are kept for interface compatibility.
    if not (min < max):
        raise AssertionError(f"Parameter {name} invalid: min >= max")
    if not (min <= start <= max):
        raise AssertionError(f"Parameter {name} invalid: start is not between [min,max]")


def read_parameter_array(obj):
    """Read an array of N parameters, one parameter per row.

    Each row is organized as either ``(name, min, max, start)`` or
    ``(name, min, max, start, samples)``.

    :param obj: (ndarray) parameter rows, e.g. a 2D ``dtype=object`` array or
        a 1D structured array whose records have 4 or 5 fields.
    :return: generator of ``(name, [min, max, start, samples])`` pairs;
        ``samples`` is ``None`` for 4-element rows.
    :raises IndexError: if a row does not have 4 or 5 entries.
    """
    for row in obj:
        width = len(row)
        if width not in (4, 5):
            raise IndexError("Input parameters are not length 4 or 5")
        # ``tolist()`` returns a list for an ndarray row but a tuple for a
        # structured-array record; normalize to a list before extending.
        values = list(row.tolist()[1:])
        if width == 4:
            values.append(None)
        yield row[0], values


def read_parameter_dict(obj):
    """Read parameters from a mapping of ``name -> {field: value}``.

    :param obj: (dict or PKDict) maps each parameter name to a mapping that
        must contain every key in ``_EXTERNAL_PARAMETER_CATEGORIES`` and may
        contain keys from ``_OPTIONAL_PARAMETER_CATEGORIES``.
    :return: generator of ``(name, values)`` pairs; ``values`` lists the
        required fields followed by the optional ones (``None`` when absent).
    :raises KeyError: if a required field is missing.
    """
    for param_name, spec in obj.items():
        required = [spec[field] for field in _EXTERNAL_PARAMETER_CATEGORIES]
        optional = [spec.get(field, None) for field in _OPTIONAL_PARAMETER_CATEGORIES]
        yield param_name, required + optional


# Maps the container type of a parameter specification to the generator that
# reads it. The dict subclass PKDict has its own entry, presumably because
# callers look up by exact ``type(obj)``; TODO confirm at the call site.
PARAMETER_READERS = {
    ndarray: read_parameter_array,
    dict: read_parameter_dict,
    PKDict: read_parameter_dict,
}


class Parameters:
    """Ordered collection of optimization parameters and their bounds.

    Each parameter is stored in ``self.parameters`` as a dict keyed by the
    names in ``self.fields`` (lower bound, upper bound, start, samples);
    ``self._NAMES`` records the order in which parameters were added.
    """

    def __init__(self):
        self.parameters = {}
        self._NAMES = []
        self._LOWER_BOUND = 'lb'
        self._UPPER_BOUND = 'ub'
        self._START = 'start'
        self._SAMPLES = 'samples'
        # Order matches the value lists the readers produce:
        # (min, max, start, samples).
        self.fields = (self._LOWER_BOUND, self._UPPER_BOUND, self._START, self._SAMPLES)

    def parse(self, name, values):
        """Validate and register one parameter.

        :param name: parameter name; must not already be registered.
        :param values: sequence ``(min, max, start[, samples])``.
        :raises KeyError: if ``name`` was already registered.
        """
        if name in self._NAMES:
            raise KeyError(f'Parameter {name} is defined multiple times')
        _validate_parameter(name, *values[:3])
        self._NAMES.append(name)
        # zip stops at the shorter sequence, so an absent samples entry is
        # simply not stored (get_samples treats it as None).
        self.parameters[name] = dict(zip(self.fields, values))

    def _field_array(self, field):
        """Return ``field`` for every parameter, in registration order, as an array."""
        return np.array([self.parameters[name][field] for name in self._NAMES])

    def get_parameter_names(self):
        """Return the list of parameter names in registration order."""
        return self._NAMES

    def get_lower_bound(self):
        """Return the lower bounds as a numpy array in registration order."""
        return self._field_array(self._LOWER_BOUND)

    def get_upper_bound(self):
        """Return the upper bounds as a numpy array in registration order."""
        return self._field_array(self._UPPER_BOUND)

    def get_start(self):
        """Return the start values as a numpy array in registration order."""
        return self._field_array(self._START)

    def get_samples(self):
        """Return the per-parameter samples values in registration order.

        :return: list with one entry per parameter. The entries are all
            ``None`` when no parameter set ``samples``.
        :raises ValueError: if ``samples`` was set for some parameters but not others.
        """
        # samples is optional, so nothing upstream has validated it yet.
        samples = [self.parameters[name].get(self._SAMPLES) for name in self._NAMES]
        is_set = [value is not None for value in samples]
        if any(is_set) and not all(is_set):
            raise ValueError("Not all parameters had samples field set")
        return samples