# postpones evaluation of annotations
# see https://stackoverflow.com/a/33533514
from __future__ import annotations
import functools
import inspect
import re
import sys
import types
from copy import deepcopy
from platform import python_version
from types import TracebackType
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
import IPython.core.ultratb
from .._utils import Formatter
ExecutionInfo = Tuple[
Union[None, type], # BaseException type
Union[None, BaseException],
Union[None, TracebackType],
]
[docs]
class Check:
"""
A check verifies the correctness of a function for a set of inputs parameters using
a list of univariate and bivariate asserts with the option to obscure the reference
outputs.
:param function_to_check:
The function must that accepts each input parameters in :params
input_parameters:
:param inputs_parameters:
A dict or a list of dictionaries each containing the argument name and its value
as (key, value) pair that is used as input for the function
:param function_to_check:
:param outputs_references:
A list or a list of lists each containing the expected output of the function
:param function_to_check: of :param function_to_check: for the inputs in the
:param input_parameters:
:param asserts:
A list of assert functions. An assert function can the output parameters of
:param function_to_check: to run assert. If output references has been set it
can take additional output references to compare with. If a fingerprint is given
then the fingerprints are compared while assert functions with a single argument
are always applied on the output parameters.
:param fingerprint:
A one-way function that takes as input the output parameters of function :param
function_to_check: and obscures the :param output_references:.
:param suppress_fingerprint_asserts:
Specifies if the assert messages that use the fingerprint function output for
tests are surpressed. The message might be confusing to a student as the output
is converted by the fingerprint function.
:param stop_on_assert_error_raised:
Specifies if running the asserts is stopped as soon as an error is raised in an
assert. If a lot of asserts are specified the printing of a lot of error
tracebacks might make debugging harder.
"""
FunInParamT = TypeVar("FunInParamT", bound=Any)
FunOutParamsT = Tuple[Any, ...]
FingerprintT = TypeVar("FingerprintT", bound=Any)
AssertFunT = Union[
Callable[[FunOutParamsT, FunOutParamsT], str],
Callable[[FingerprintT, FingerprintT], str],
Callable[[FunOutParamsT], str],
Callable[[], str],
]
def __init__(
self,
function_to_check: Callable[..., FunOutParamsT],
asserts: Union[List[AssertFunT], AssertFunT],
inputs_parameters: Optional[
Union[List[Dict[str, FunInParamT]], Dict[str, FunInParamT]]
] = None,
outputs_references: Optional[Union[List[tuple], tuple]] = None,
fingerprint: Optional[
Callable[[Check.FunOutParamsT], Check.FingerprintT]
] = None,
suppress_fingerprint_asserts: bool = True,
stop_on_assert_error_raised: bool = True,
):
self._function_to_check = function_to_check
self._asserts = []
self._nullvariate_asserts: List[Callable[[], str]] = []
self._univariate_asserts: List[Callable[[tuple], str]] = []
self._bivariate_asserts = []
if not (isinstance(asserts, list)):
asserts = [asserts]
for i, assert_f in enumerate(asserts):
nb_positional_arguments = len(
[
parameters
for parameters in inspect.signature(assert_f).parameters.values()
if parameters.default is inspect._empty
]
)
self._asserts.append(assert_f)
if nb_positional_arguments == 0:
self._nullvariate_asserts.append(assert_f) # type: ignore[arg-type]
elif nb_positional_arguments == 1:
if inputs_parameters is None:
raise ValueError(
"For functions taking two input arguments we need "
"inputs_parameters."
)
# type checker cannot infer type change
self._univariate_asserts.append(assert_f) # type: ignore[arg-type]
elif nb_positional_arguments == 2:
if inputs_parameters is None or outputs_references is None:
raise ValueError(
"For functions taking two input arguments we need "
"inputs_parameters and outputs_references."
)
self._bivariate_asserts.append(assert_f)
else:
raise ValueError(
f"Only assert function with 1 or 2 positional arguments are allowed"
f"but assert function {i} has {nb_positional_arguments} positional"
f"arguments"
)
# We cannot verify if the number of input argumets match because they can be
# hidden in **kwargs
if isinstance(inputs_parameters, dict):
inputs_parameters = [inputs_parameters]
if inputs_parameters is not None and outputs_references is not None:
if isinstance(outputs_references, tuple):
outputs_references = [outputs_references]
assert len(inputs_parameters) == len(outputs_references), (
"Number of inputs_parameters and outputs_references are mismatching: "
"len inputs parameters != len outputs parameters "
f"[{len(inputs_parameters)} != {len(outputs_references)}]."
)
self._inputs_parameters = [] if inputs_parameters is None else inputs_parameters
self._outputs_references = (
[] if outputs_references is None else outputs_references
)
self._fingerprint = fingerprint
self._suppress_fingerprint_asserts = suppress_fingerprint_asserts
self._stop_on_assert_error_raised = stop_on_assert_error_raised
@property
def function_to_check(self) -> Callable[..., FunOutParamsT]:
return self._function_to_check
@function_to_check.setter
def function_to_check(self, function_to_check: Callable[..., FunOutParamsT]):
self._function_to_check = function_to_check
@property
def fingerprint(self):
return deepcopy(self._fingerprint)
@property
def asserts(self):
return deepcopy(self._asserts)
@property
def nullvariate_asserts(self):
return deepcopy(self._nullvariate_asserts)
@property
def univariate_asserts(self):
return deepcopy(self._univariate_asserts)
@property
def bivariate_asserts(self):
return deepcopy(self._bivariate_asserts)
@property
def inputs_parameters(self):
return deepcopy(self._inputs_parameters)
@property
def outputs_references(self):
return deepcopy(self._outputs_references)
@property
def nb_conducted_asserts(self):
return len(self._asserts) * len(self._inputs_parameters) + len(
self._nullvariate_asserts
)
def compute_outputs(self):
outputs = []
for input_parameters in self._inputs_parameters:
output = self._function_to_check(**input_parameters)
if not (isinstance(output, tuple)):
output = (output,)
if self._fingerprint is not None:
output = self._fingerprint(*output)
if not (isinstance(output, tuple)):
output = (output,)
outputs.append(output)
return outputs
def compute_and_set_references(self):
self._outputs_references = self.compute_outputs()
[docs]
def check_function(self) -> CheckResult:
"""
Returns for each input (first depth list) the result message for each assert
(second depth list). If a result message is empty, the assert was successful,
otherwise it contains information about the failure.
"""
if len(self._bivariate_asserts) > 0:
if self._outputs_references is None:
raise ValueError(
"outputs_references are None but asserts exist that require "
"outputs_references (second positional argument)"
)
assert len(self._inputs_parameters) == len(self._outputs_references), (
"Number of inputs and reference outputs "
"are mismatching: len inputs parameters != len outputs parameters "
f"[{len(self._inputs_parameters)} != {len(self._outputs_references)}]."
)
check_result = CheckResult()
for assert_f in self._nullvariate_asserts:
try:
assert_result = assert_f()
check_result.append(assert_result, assert_f, {})
except Exception:
excution_info = sys.exc_info()
check_result.append(excution_info, assert_f, {})
if self._stop_on_assert_error_raised:
return check_result
for i, input_parameters in enumerate(self._inputs_parameters):
output = self._function_to_check(**input_parameters)
if not (isinstance(output, tuple)):
output = (output,)
for uni_assert_f in self._univariate_asserts:
try:
assert_result = uni_assert_f(output)
check_result.append(assert_result, uni_assert_f, input_parameters)
except Exception:
excution_info = sys.exc_info()
check_result.append(excution_info, uni_assert_f, input_parameters)
if self._stop_on_assert_error_raised:
return check_result
if self._fingerprint is not None:
try:
output = self._fingerprint(*output)
except ( # we do not raise here since it is passed to widget output # noqa B040
Exception
) as exception:
if python_version() >= "3.11":
exception.add_note(
"An error was raised in fingerprint function, "
" most likely because your output type is wrong."
)
excution_info = sys.exc_info()
check_result.append(
excution_info, self._fingerprint, input_parameters
)
return check_result
if not (isinstance(output, tuple)):
output = (output,)
for assert_f in self._bivariate_asserts: # type: ignore[assignment]
assert len(output) == len(
self._outputs_references[i] # type: ignore[index]
), (
"Number of output parameters and reference output parameters "
"are mismatching: "
"len output parameters != len outputs references "
f"[{len(output)} != "
f"{len(self._outputs_references[i])}]." # type: ignore[index]
)
try:
assert_result = assert_f(
output,
self._outputs_references[i], # type: ignore[index, call-arg]
)
except Exception:
excution_info = sys.exc_info()
check_result.append(excution_info, assert_f, input_parameters)
if self._stop_on_assert_error_raised:
return check_result
check_result.append(
assert_result,
assert_f,
input_parameters,
self._suppress_fingerprint_asserts
and self._fingerprint is not None,
)
return check_result
class CheckResult:
def __init__(self):
self._assert_results = []
self._assert_names = []
self._inputs_parameters = []
self._suppress_assert_messages = []
def append(
self,
assert_result: Union[str, AssertResult, ExecutionInfo],
assert_f: Optional[Check.AssertFunT] = None,
input_parameters: Optional[dict] = None,
suppress_assert_message: Optional[bool] = False,
):
self._assert_results.append(assert_result)
if isinstance(assert_result, AssertResult):
self._assert_names.append(assert_result.assert_name)
else:
self._assert_names.append(self._get_name_from_assert(assert_f))
self._inputs_parameters.append(input_parameters)
self._suppress_assert_messages.append(suppress_assert_message)
@property
def successful(self):
return (
len(
[
result
for result in self._assert_results
if (isinstance(result, str) and result != "")
or (isinstance(result, AssertResult) and not (result.successful))
]
)
== 0
)
def message(self) -> str:
messages = []
for i, result in enumerate(self._assert_results):
if (isinstance(result, str) and result == "") or (
isinstance(result, AssertResult) and result.successful
):
message = Formatter.color_assert_success(
f"{self._assert_names[i]} passed",
)
if len(self._inputs_parameters[i]) > 0:
message += Formatter.color_assert_success(" for input\n")
input_parameters_message = "\n".join(
[
f" {Formatter.color_assert_info(param_name)}: "
f"{param_value!r}"
for param_name, param_value in self._inputs_parameters[
i
].items()
]
)
if input_parameters_message != "":
message += input_parameters_message
else:
message = Formatter.color_assert_failed(
f"{self._assert_names[i]} failed",
)
if len(self._inputs_parameters[i]) > 0 or not (
self._suppress_assert_messages[i]
):
message += Formatter.color_assert_failed(" for\n")
input_parameters_message = "\n".join(
[
f" {Formatter.color_assert_info(param_name)}: "
f"{param_value!r}"
for param_name, param_value in self._inputs_parameters[
i
].items()
]
)
assert_message = ""
if input_parameters_message != "":
assert_message += input_parameters_message
assert_result = ""
if isinstance(result, tuple) and len(result) == 3:
# Execution info
tb = IPython.core.ultratb.VerboseTB()
assert_result = tb.text(*result)
elif not (self._suppress_assert_messages[i]):
if hasattr(result, "message"):
assert_result = f"{result.message()}"
else:
assert_result = f"{Formatter.color_assert_failed(result)}"
if assert_result != "":
assert_message += "\n" + assert_result
if assert_message != "":
# adds "| " to the beginning of each line
assert_message = re.sub(
r"(^)",
r"\1" + f"{Formatter.color_assert_failed('|')} ",
assert_message,
flags=re.M,
)
message += f"{assert_message}"
messages.append(message)
return "\n".join(messages)
def _get_name_from_assert(self, assert_f: Any) -> str:
if isinstance(assert_f, types.FunctionType):
return assert_f.__name__
elif isinstance(assert_f, functools.partial):
return assert_f.func.__name__
else:
return str(assert_f)
@property
def assert_results(self):
return deepcopy(self._assert_results)
@property
def assert_names(self):
return deepcopy(self._assert_names)
@property
def inputs_parameters(self):
return deepcopy(self._inputs_parameters)
[docs]
class AssertResult:
"""
:param assert_name:
TODO
:param parameter_indices:
TODO
TODO...
"""
def __init__(
self,
assert_name: str,
parameter_indices: Union[int, List[int]],
parameter_values: Union[Any, List[Any]],
messages: Union[str, List[str]],
):
self._assert_name = assert_name
# we do not include parameter_values in the check because it can be a list
# by type definition
if isinstance(parameter_indices, list) or isinstance(messages, list):
if (
not (isinstance(parameter_indices, list))
or not (isinstance(parameter_values, list))
or not (isinstance(messages, list))
):
raise ValueError(
"If one of the inputs parameter_indices, paramater_values or "
"messages is a list, then all must be lists of the same size."
)
elif len(parameter_indices) != len(parameter_values) or len(
parameter_indices
) != len(messages):
raise ValueError(
"If one of the inputs parameter_indices, paramater_values or "
"messages is a list, then all must be lists of the same size, "
"but got len(parameter_indices), len(parameter_values), "
f"len(messages) [{len(parameter_indices)}, "
f"{len(parameter_values)}, {len(messages)}]"
)
if not (isinstance(parameter_indices, list)):
parameter_indices = [parameter_indices]
self._parameter_indices = parameter_indices
if not (isinstance(parameter_values, list)):
parameter_values = [parameter_values]
self._parameter_values = parameter_values
if not (isinstance(messages, list)):
messages = [messages]
self._messages = messages
def message(self) -> str:
message = ""
for i in range(len(self._parameter_indices)):
message += (
Formatter.color_assert_info(f"> output {self._parameter_indices[i]}: ")
+ f"{self._parameter_values[i]}\n"
+ Formatter.color_assert_failed(self._messages[i])
)
return message
@property
def assert_name(self) -> str:
return self._assert_name
@property
def successful(self):
return len(self._parameter_indices) == 0