# Source code for scwidgets.check._asserts

import functools
from collections import abc
from typing import Iterable, Union

import numpy as np

from ._check import AssertResult, Check

# Type alias: either a plain ``str`` or an ``AssertResult``.
# NOTE(review): by its name this is presumably the set of values an assert
# function may return (a message string or a full result) — confirm against
# how ``Check`` consumes assert outputs.
AssertFunctionOutputT = Union[str, AssertResult]


def assert_equal(
    output_parameters: Check.FunOutParamsT,
    output_references: Check.FunOutParamsT,
    parameters_to_check: Union[Iterable[int], str] = "all",
) -> AssertResult:
    """
    Check if output_parameters are equal to output_references using simple Python
    equality check.

    Each selected pair is compared with ``==`` and the result is used in a
    boolean context. Values whose ``==`` does not yield a single truth value
    (e.g. multi-element numpy arrays) therefore make this check raise.

    :param output_parameters: values produced by the checked function
    :param output_references: reference values, same length as
        ``output_parameters``
    :param parameters_to_check: indices to compare, or ``"all"`` (default) to
        compare every index
    :return: AssertResult holding the index, value and message of every
        parameter that differs from its reference (empty lists if all match)
    :raises AssertionError: if the two sequences differ in length
    :raises ValueError: if ``parameters_to_check`` is a string other than
        ``"all"``
    :raises TypeError: if ``parameters_to_check`` is neither str nor Iterable
    """
    # Explicit raise instead of ``assert``: the check must survive ``python -O``,
    # and AssertionError is kept so existing callers see the same exception.
    if len(output_parameters) != len(output_references):
        raise AssertionError(
            "output_parameters and output_references have to have the same length"
        )

    parameter_indices: Iterable[int]
    if isinstance(parameters_to_check, str):
        if parameters_to_check == "all":
            parameter_indices = range(len(output_parameters))
        else:
            raise ValueError(
                f'Got parameters_to_check="{parameters_to_check}" but only "all" '
                "is accepted as string"
            )
    elif isinstance(parameters_to_check, abc.Iterable):
        parameter_indices = parameters_to_check  # type: ignore[assignment]
    else:
        raise TypeError(
            "Only str and Iterable are accepted for parameters_to_check, "
            f"but got type {type(parameters_to_check)}."
        )

    failed_parameter_indices = []
    failed_parameter_values = []
    messages = []
    for i in parameter_indices:
        value = output_parameters[i]
        reference = output_references[i]
        if not value == reference:
            failed_parameter_indices.append(i)
            failed_parameter_values.append(value)
            messages.append(f"Expected {reference} but got {value}.")

    return AssertResult(
        assert_name="assert_equal",
        parameter_indices=failed_parameter_indices,
        parameter_values=failed_parameter_values,
        messages=messages,
    )
def assert_shape(
    output_parameters: Check.FunOutParamsT,
    output_references: Check.FunOutParamsT,
    parameters_to_check: Union[Iterable[int], str] = "auto",
) -> AssertResult:
    """
    Check that the shape of output parameters matches the reference.

    :param output_parameters: values produced by the checked function
    :param output_references: reference values, same length as
        ``output_parameters``
    :param parameters_to_check: indices to compare.
        ``"auto"`` (default) selects every index whose reference has a
        ``shape`` attribute; ``"all"`` selects every index.
    :return: AssertResult holding the index, value and message of every
        parameter whose ``shape`` differs from its reference, or that has no
        ``shape`` attribute at all
    :raises AssertionError: if the two sequences differ in length
    :raises ValueError: if ``parameters_to_check`` is a string other than
        ``"all"`` or ``"auto"``
    :raises TypeError: if ``parameters_to_check`` is neither str nor Iterable
    """
    # Explicit raise instead of ``assert``: the check must survive ``python -O``,
    # and AssertionError is kept so existing callers see the same exception.
    if len(output_parameters) != len(output_references):
        raise AssertionError(
            "output_parameters and output_references have to have the same length"
        )

    parameter_indices: Iterable[int]
    if isinstance(parameters_to_check, str):
        if parameters_to_check == "auto":
            parameter_indices = [
                i
                for i, reference in enumerate(output_references)
                if hasattr(reference, "shape")
            ]
        elif parameters_to_check == "all":
            parameter_indices = range(len(output_parameters))
        else:
            raise ValueError(
                f'Got parameters_to_check="{parameters_to_check}" but only "all" '
                'and "auto" are accepted as string'
            )
    elif isinstance(parameters_to_check, abc.Iterable):
        parameter_indices = parameters_to_check  # type: ignore[assignment]
    else:
        raise TypeError(
            "Only str and Iterable are accepted for parameters_to_check, "
            f"but got type {type(parameters_to_check)}."
        )

    failed_parameter_indices = []
    failed_parameter_values = []
    messages = []
    for i in parameter_indices:
        value = output_parameters[i]
        reference = output_references[i]
        # An output without a shape (e.g. a plain list) is a failed check for
        # that parameter, not an error of the check itself.
        if not hasattr(value, "shape"):
            message = (
                f"Expected shape {reference.shape} "
                f"but got {type(value)} without a shape."
            )
        elif value.shape != reference.shape:
            message = f"Expected shape {reference.shape} but got {value.shape}."
        else:
            continue
        failed_parameter_indices.append(i)
        failed_parameter_values.append(value)
        messages.append(message)

    return AssertResult(
        assert_name="assert_shape",
        parameter_indices=failed_parameter_indices,
        parameter_values=failed_parameter_values,
        messages=messages,
    )
def assert_numpy_allclose(
    output_parameters: Check.FunOutParamsT,
    output_references: Check.FunOutParamsT,
    parameters_to_check: Union[Iterable[int], str] = "auto",
    rtol=1e-05,
    atol=1e-08,
    equal_nan=False,
) -> AssertResult:
    """
    Check if output_parameters are numerically close to output_references using
    numpy.allclose().

    For every parameter that is not close, the message reports:

    - the summed absolute difference ``sum(|a - b|)``;
    - the summed relative difference ``sum(|a - b| / max(|a|, |b|))``, with the
      element-wise maximum taken after broadcasting and 1 used where both
      operands are zero.

    :param output_parameters: values produced by the checked function
    :param output_references: reference values, same length as
        ``output_parameters``
    :param parameters_to_check: indices to compare.
        ``"auto"`` (default) selects every index whose reference
        numpy.allclose() can process (probed by comparing the reference with
        itself); ``"all"`` selects every index.
    :param rtol: relative tolerance forwarded to numpy.allclose()
    :param atol: absolute tolerance forwarded to numpy.allclose()
    :param equal_nan: forwarded to numpy.allclose()
    :return: AssertResult holding the index, value and message of every
        parameter that is not close to its reference
    :raises AssertionError: if the two sequences differ in length
    :raises ValueError: if ``parameters_to_check`` is a string other than
        ``"all"`` or ``"auto"``
    :raises TypeError: if ``parameters_to_check`` is neither str nor Iterable
    """
    # Explicit raise instead of ``assert``: the check must survive ``python -O``,
    # and AssertionError is kept so existing callers see the same exception.
    if len(output_parameters) != len(output_references):
        raise AssertionError(
            "output_parameters and output_references have to have the same length"
        )

    parameter_indices: Iterable[int]
    if isinstance(parameters_to_check, str):
        if parameters_to_check == "auto":
            parameter_indices = []
            for i, reference in enumerate(output_references):
                # Deliberate best-effort probe: any reference numpy cannot
                # compare is simply skipped in "auto" mode.
                try:
                    np.allclose(reference, reference)
                except Exception:
                    continue
                parameter_indices.append(i)
        elif parameters_to_check == "all":
            parameter_indices = range(len(output_parameters))
        else:
            raise ValueError(
                f'Got parameters_to_check="{parameters_to_check}" but only "all" '
                'and "auto" are accepted as string'
            )
    elif isinstance(parameters_to_check, abc.Iterable):
        parameter_indices = parameters_to_check  # type: ignore[assignment]
    else:
        raise TypeError(
            "Only str and Iterable are accepted for parameters_to_check, "
            f"but got type {type(parameters_to_check)}."
        )

    failed_parameter_indices = []
    failed_parameter_values = []
    messages = []
    for i in parameter_indices:
        is_allclose = np.allclose(
            output_parameters[i],
            output_references[i],
            atol=atol,
            rtol=rtol,
            equal_nan=equal_nan,
        )
        if is_allclose:
            continue

        output_parameters_i_arr = np.asarray(output_parameters[i])
        output_references_i_arr = np.asarray(output_references[i])
        diff = np.abs(output_parameters_i_arr - output_references_i_arr)
        abs_diff = np.sum(diff)
        # np.maximum is element-wise and broadcasts exactly like the subtraction
        # above (and like np.allclose), so it is correct for any rank and for
        # broadcast-compatible shapes.
        rel_diff_dividend = np.maximum(
            np.abs(output_parameters_i_arr), np.abs(output_references_i_arr)
        )
        # When both are zero the diff is also zero, so divide by 1 instead of 0
        # to avoid a division-by-zero warning / nan.
        rel_diff_dividend = np.where(rel_diff_dividend == 0.0, 1.0, rel_diff_dividend)
        rel_diff = np.sum(diff / rel_diff_dividend)
        message = (
            f"Output is not close to reference absolute difference "
            f"is {abs_diff}, relative difference is {rel_diff}."
        )
        failed_parameter_indices.append(i)
        failed_parameter_values.append(output_parameters[i])
        messages.append(message)

    return AssertResult(
        assert_name="assert_numpy_allclose",
        parameter_indices=failed_parameter_indices,
        parameter_values=failed_parameter_values,
        messages=messages,
    )
def assert_type(
    output_parameters: Check.FunOutParamsT,
    output_references: Check.FunOutParamsT,
    parameters_to_check: Union[Iterable[int], str] = "all",
) -> AssertResult:
    """
    Check that output parameters have the correct type.

    A parameter passes when ``isinstance(parameter, type(reference))`` holds,
    so instances of subclasses of the reference's type are accepted.

    :param output_parameters: values produced by the checked function
    :param output_references: reference values, same length as
        ``output_parameters``
    :param parameters_to_check: indices to compare, or ``"all"`` (default) to
        compare every index
    :return: AssertResult holding the index, value and message of every
        parameter whose type does not match its reference
    :raises AssertionError: if the two sequences differ in length
    :raises ValueError: if ``parameters_to_check`` is a string other than
        ``"all"``
    :raises TypeError: if ``parameters_to_check`` is neither str nor Iterable
    """
    # Explicit raise instead of ``assert``: the check must survive ``python -O``,
    # and AssertionError is kept so existing callers see the same exception.
    if len(output_parameters) != len(output_references):
        raise AssertionError(
            "output_parameters and output_references have to have the same length"
        )

    parameter_indices: Iterable[int]
    if isinstance(parameters_to_check, str):
        if parameters_to_check == "all":
            parameter_indices = range(len(output_parameters))
        else:
            raise ValueError(
                f'Got parameters_to_check="{parameters_to_check}" but only "all" '
                "is accepted as string"
            )
    elif isinstance(parameters_to_check, abc.Iterable):
        parameter_indices = parameters_to_check  # type: ignore[assignment]
    else:
        raise TypeError(
            "Only str and Iterable are accepted for parameters_to_check, "
            f"but got type {type(parameters_to_check)}."
        )

    failed_parameter_indices = []
    failed_parameter_values = []
    messages = []
    for i in parameter_indices:
        value = output_parameters[i]
        expected_type = type(output_references[i])
        if not isinstance(value, expected_type):
            failed_parameter_indices.append(i)
            failed_parameter_values.append(value)
            messages.append(f"Expected type {expected_type} but got {type(value)}.")

    return AssertResult(
        assert_name="assert_type",
        parameter_indices=failed_parameter_indices,
        parameter_values=failed_parameter_values,
        messages=messages,
    )
def assert_numpy_sub_dtype(
    output_parameters: Union[Check.FunOutParamsT, tuple[Check.FingerprintT]],
    numpy_type: Union[np.dtype, type],
    parameters_to_check: Union[Iterable[int], str] = "all",
) -> AssertResult:
    """
    Check that output parameters have the correct numpy sub-dtype.

    Each checked parameter must be a ``numpy.ndarray`` whose dtype satisfies
    ``numpy.issubdtype(parameter.dtype, numpy_type)``. A parameter that is not
    an ndarray is reported once and its dtype is not inspected.

    :param output_parameters: values produced by the checked function
    :param numpy_type: numpy dtype or scalar type (e.g. ``np.floating``) that
        every checked parameter's dtype must be a sub-dtype of
    :param parameters_to_check: indices to check, or ``"all"`` (default) to
        check every index
    :return: AssertResult named ``f"assert_numpy_{type_name}_sub_dtype"``
        holding the index, value and message of every failing parameter
    :raises ValueError: if ``parameters_to_check`` is a string other than
        ``"all"``
    :raises TypeError: if ``parameters_to_check`` is neither str nor Iterable
    """
    # Display name of the requested type, used in messages and the assert name.
    if isinstance(numpy_type, np.dtype):
        type_name = numpy_type.type.__name__
    else:
        type_name = numpy_type.__name__

    parameter_indices: Iterable[int]
    # Test for str first: a str is itself Iterable and would otherwise be
    # iterated character by character.
    if isinstance(parameters_to_check, str):
        if parameters_to_check == "all":
            parameter_indices = range(len(output_parameters))
        else:
            raise ValueError(
                f'Got parameters_to_check="{parameters_to_check}" but only "all" '
                "is accepted as string"
            )
    elif isinstance(parameters_to_check, abc.Iterable):
        parameter_indices = parameters_to_check  # type: ignore[assignment]
    else:
        raise TypeError(
            "Only str and Iterable are accepted for parameters_to_check, "
            f"but got type {type(parameters_to_check)}."
        )

    failed_parameter_indices = []
    failed_parameter_values = []
    messages = []
    for i in parameter_indices:
        value = output_parameters[i]
        if not isinstance(value, np.ndarray):
            failed_parameter_indices.append(i)
            failed_parameter_values.append(value)
            messages.append(
                f"Output expected to be numpy array " f"but got {type(value)}."
            )
            # Skip the dtype check: non-arrays may lack ``.dtype`` (lists, Python
            # floats) and each index must be reported only once.
            continue
        if not np.issubdtype(value.dtype, numpy_type):
            failed_parameter_indices.append(i)
            failed_parameter_values.append(value)
            messages.append(
                f"Output expected to be sub dtype "
                f"numpy.{type_name} but got "
                f"numpy.{value.dtype.type.__name__}."
            )

    return AssertResult(
        assert_name=f"assert_numpy_{type_name}_sub_dtype",
        parameter_indices=failed_parameter_indices,
        parameter_values=failed_parameter_values,
        messages=messages,
    )
# assert_numpy_sub_dtype with numpy_type fixed to np.floating: every checked
# parameter must be a numpy array whose dtype is a floating-point sub-dtype.
# Takes the remaining arguments (output_parameters, parameters_to_check) of
# assert_numpy_sub_dtype; since np.floating.__name__ is "floating", the
# resulting AssertResult is named "assert_numpy_floating_sub_dtype".
assert_numpy_floating_sub_dtype = functools.partial(
    assert_numpy_sub_dtype, numpy_type=np.floating
)