Source code for scwidgets.code._widget_code_input

import ast
import copy
import inspect
import re
import sys
import textwrap
import traceback
import types
import warnings
from functools import wraps
from typing import Any, List, Optional, Tuple, Union

from widget_code_input import WidgetCodeInput
from widget_code_input.utils import (
    CodeValidationError,
    format_syntax_error_msg,
    is_valid_variable_name,
)

from ..check import Check


class CodeInput(WidgetCodeInput):
    """
    Small wrapper around WidgetCodeInput that controls the output

    :param function:
        A Python function to be parsed automatically. Note that the parsing
        may alter the original formatting or lose certain syntactical
        nuances. If this behavior is undesired, provide the function
        explicitly using other parameters.
    :param function_name:
        The name of the function
    :param function_parameters:
        The parameters as a continuous string as specified in the signature
        of the function. e.g for `foo(x, y = 5)` it should be `"x, y = 5"`
    :param docstring:
        The docstring of the function
    :param function_body:
        The function definition without indentation
    :param builtins:
        A dictionary containing variable names and values that are added to
        the globals __builtins__ and thus available on initialization
    :param code_theme:
        Name of the editor theme, must be one of ``valid_code_themes``
    """

    # this list is retrieved from
    # https://github.com/osscar-org/widget-code-input/blob/eb10ca0baee65dd3bf62c9ec5d9cb2f152932ff5/js/widget.js#L249-L253
    valid_code_themes = ["nord", "solarizedLight", "basicLight"]

    def __init__(
        self,
        function: Optional[types.FunctionType] = None,
        function_name: Optional[str] = None,
        function_parameters: Optional[str] = None,
        docstring: Optional[str] = None,
        function_body: Optional[str] = None,
        builtins: Optional[dict[str, Any]] = None,
        code_theme: str = "basicLight",
    ):
        # Explicitly given pieces always win over what is parsed from `function`
        if function is not None:
            if function_name is None:
                function_name = function.__name__
            if function_parameters is None:
                function_parameters = self.get_function_parameters(function)
            if docstring is None:
                docstring = self.get_docstring(function)
            if function_body is None:
                function_body = self.get_function_body(function)

        # default parameters from WidgetCodeInput
        if function_name is None:
            raise ValueError("function_name must be given if no function is given.")
        function_parameters = "" if function_parameters is None else function_parameters
        function_body = "" if function_body is None else function_body
        self._builtins = {} if builtins is None else builtins

        # Reject an unknown theme before the widget itself is constructed
        if code_theme not in CodeInput.valid_code_themes:
            raise ValueError(
                f"Given code_theme {code_theme!r} invalid. Please use one of "
                f"the values {CodeInput.valid_code_themes}"
            )

        super().__init__(
            function_name, function_parameters, docstring, function_body, code_theme
        )

    @property
    def unwrapped_function(self) -> types.FunctionType:
        """
        Returns the compiled function object.

        This can be assigned to a variable and then called, for instance:

            func = widget.unwrapped_function  # This can raise an error
            retval = func(parameters)

        :raise SyntaxError: if the function name is not a valid identifier
        :raise CodeValidationError: if the function code has syntax errors
        """
        # `__builtins__` is a dict in imported modules but the `builtins`
        # module itself when this file is executed as `__main__`
        builtins_namespace = globals()["__builtins__"]
        if isinstance(builtins_namespace, types.ModuleType):
            builtins_namespace = vars(builtins_namespace)

        # we shallow copy the builtins to be able to overwrite it
        # if self.builtins changes
        globals_dict = {
            "__builtins__": copy.copy(builtins_namespace),
            "__name__": "__main__",
            "__doc__": None,
            "__package__": None,
        }
        globals_dict["__builtins__"].update(self._builtins)

        if not is_valid_variable_name(self.function_name):
            raise SyntaxError("Invalid function name '{}'".format(self.function_name))

        # Optionally one could do a ast.parse here already, to check syntax
        # before execution
        try:
            exec(
                compile(self.full_function_code, __name__, "exec", dont_inherit=True),
                globals_dict,
            )
        except SyntaxError as exc:
            raise CodeValidationError(
                format_syntax_error_msg(exc), orig_exc=exc
            ) from exc

        return globals_dict[self.function_name]

    def __call__(self, *args, **kwargs) -> Check.FunOutParamsT:
        """Calls the wrapped function"""
        return self.function(*args, **kwargs)

    def compatible_with_signature(self, parameters: List[str]) -> str:
        """
        This function checks if the arguments are compatible with the function
        signature and returns an explanatory message if this is not the case.

        :param parameters: names of the arguments that will be passed
        :return: an empty string if compatible, otherwise a message naming the
            first parameter of the function that is not in `parameters`
        """
        if "**" in self.function_parameters:
            # function has keyword arguments so it is compatible
            return ""
        for parameter_name in inspect.signature(self.function).parameters:
            if parameter_name not in parameters:
                return (
                    f"The input parameter {parameter_name} is not compatible with "
                    "the function code."
                )
        return ""

    @property
    def function_parameters_name(self) -> List[str]:
        """
        Returns the names of the function parameters

        The parameter string is parsed as a signature, so annotations and
        default values (e.g. ``"x: int, y=5"``) do not leak into the names.
        Names of ``*args``/``**kwargs`` parameters are returned without stars.
        If the string cannot be parsed the former whitespace based splitting is
        used as best effort.
        """
        try:
            module = ast.parse(f"def _({self.function_parameters}): pass")
        except SyntaxError:
            return self.function_parameters.replace(",", "").split(" ")

        arguments = module.body[0].args
        names = [arg.arg for arg in arguments.posonlyargs + arguments.args]
        if arguments.vararg is not None:
            names.append(arguments.vararg.arg)
        names.extend(arg.arg for arg in arguments.kwonlyargs)
        if arguments.kwarg is not None:
            names.append(arguments.kwarg.arg)
        return names

    @staticmethod
    def get_docstring(function: types.FunctionType) -> Union[str, None]:
        """
        Returns the docstring of a function, if it exists, with the common
        indentation removed and double-quote characters stripped from both
        ends. Leading and trailing newlines are kept. Returns None if the
        function has no docstring.
        """
        docstring = function.__doc__
        if docstring is None:
            return None
        return textwrap.dedent(docstring).strip('"""')  # noqa: B005

    @staticmethod
    def _get_function_source_and_def(
        function: types.FunctionType,
    ) -> Tuple[str, ast.FunctionDef]:
        """
        Returns the dedented source of `function` together with its parsed
        function definition node.

        :raise ValueError: if the source does not consist of exactly one
            ``def`` function definition
        """
        function_source = textwrap.dedent(inspect.getsource(function))
        module = ast.parse(function_source)
        if len(module.body) != 1:
            raise ValueError(
                f"Expected code with one function definition but found {module.body}"
            )
        function_definition = module.body[0]
        if not isinstance(function_definition, ast.FunctionDef):
            raise ValueError(
                f"While parsing code found {module.body[0]}"
                " but only ast.FunctionDef is supported."
            )
        return function_source, function_definition

    @staticmethod
    def get_function_parameters(function: types.FunctionType) -> str:
        """
        Returns the parameters of a function as a continuous string, e.g for
        `foo(x, y = 5)` it would return `"x, y=5"`. Annotations and default
        values are copied verbatim from the source; positional-only, ``*args``,
        keyword-only and ``**kwargs`` parameters are included.
        """
        function_source, function_definition = CodeInput._get_function_source_and_def(
            function
        )
        arguments = function_definition.args

        def render(arg: ast.arg, default: Optional[ast.expr] = None, prefix: str = ""):
            """Formats one parameter following PEP 8."""
            text = f"{prefix}{arg.arg}"
            annotated = arg.annotation is not None
            if annotated:
                annotation = ast.get_source_segment(function_source, arg.annotation)
                text = f"{text}: {annotation}"
            if default is not None:
                default_val = ast.get_source_segment(function_source, default)
                # PEP 8: spaces around "=" only when an annotation is present
                text = f"{text} = {default_val}" if annotated else f"{text}={default_val}"
            return text

        # `defaults` belongs to the tail of positional-only + regular arguments
        positional = arguments.posonlyargs + arguments.args
        n_without_default = len(positional) - len(arguments.defaults)
        positional_defaults = [None] * n_without_default + list(arguments.defaults)

        function_parameters = []
        n_positional_only = len(arguments.posonlyargs)
        for i, (arg, default) in enumerate(zip(positional, positional_defaults)):
            function_parameters.append(render(arg, default))
            if n_positional_only and i == n_positional_only - 1:
                function_parameters.append("/")

        if arguments.vararg is not None:
            function_parameters.append(render(arguments.vararg, prefix="*"))
        elif arguments.kwonlyargs:
            # bare star marks the start of keyword-only parameters
            function_parameters.append("*")

        # kw_defaults holds None for keyword-only parameters without default
        for arg, default in zip(arguments.kwonlyargs, arguments.kw_defaults):
            function_parameters.append(render(arg, default))

        if arguments.kwarg is not None:
            function_parameters.append(render(arguments.kwarg, prefix="**"))

        return ", ".join(function_parameters)

    @staticmethod
    def get_function_body(function: types.FunctionType) -> str:
        """
        Extracts the body of the given function, removing the signature,
        the docstring, and adjusting indentation appropriately.

        The function is located with `ast`, so decorators, multi-line
        signatures and triple-quoted strings inside the body are handled.
        Comments between the signature (or docstring) and the first statement
        are kept.

        :raise ValueError: if the source is not a single ``def`` definition
        """
        function_source, function_definition = CodeInput._get_function_source_and_def(
            function
        )
        source_lines = function_source.splitlines()
        statements = function_definition.body

        # `floor` is the 0-based index of the first line that may belong to
        # the body: the line after the docstring, or after the `def` line
        first_statement = statements[0]
        has_docstring = (
            isinstance(first_statement, ast.Expr)
            and isinstance(first_statement.value, ast.Constant)
            and isinstance(first_statement.value.value, str)
        )
        if has_docstring:
            floor = first_statement.end_lineno
            statements = statements[1:]
        else:
            floor = function_definition.lineno

        if not statements:
            # docstring-only function, only comments can be left
            return textwrap.dedent("\n".join(source_lines[floor:])).strip()

        first_statement = statements[0]
        # a decorated inner definition starts at its first decorator
        start_lineno = min(
            [first_statement.lineno]
            + [
                decorator.lineno
                for decorator in getattr(first_statement, "decorator_list", [])
            ]
        )
        first_line = source_lines[start_lineno - 1]
        if start_lineno == first_statement.lineno:
            # col_offset counts UTF-8 bytes, not characters
            lead = first_line.encode("utf-8")[: first_statement.col_offset].decode(
                "utf-8"
            )
        else:
            lead = first_line[: len(first_line) - len(first_line.lstrip())]

        if lead.strip():
            # body starts on the same line as the signature or the docstring,
            # e.g. `def f(x): return x`
            body_lines = [first_line[len(lead) :]] + source_lines[start_lineno:]
            return "\n".join(body_lines).strip()

        # keep comments and blank lines directly preceding the first statement
        start = start_lineno - 1
        while start > floor:
            previous = source_lines[start - 1].strip()
            if previous and not previous.startswith("#"):
                break
            start -= 1

        def dedent(line: str) -> str:
            """Removes the body indentation; shallower lines are left as is."""
            if not line.strip():
                return ""
            return line[len(lead) :] if line.startswith(lead) else line

        return "\n".join(dedent(line) for line in source_lines[start:]).strip()

    @property
    def function(self) -> types.FunctionType:
        """
        Returns the compiled function object wrapped by an try-catch block
        raising a `CodeValidationError`.

        This can be assigned to a variable and then called, for instance:

            func = widget.function  # This can raise a CodeValidationError
            retval = func(parameters)

        :raise CodeValidationError: if the function code has syntax errors
        :raise SyntaxError: if the function name is not a valid identifier
        """

        def catch_exceptions(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                """Wrap and check exceptions to return a longer and clearer
                exception."""
                try:
                    return func(*args, **kwargs)
                except Exception as exc:
                    err_msg = format_generic_error_msg(exc, code_widget=self)
                    raise CodeValidationError(err_msg, orig_exc=exc) from exc

            return wrapper

        return catch_exceptions(self.unwrapped_function)

    @property
    def builtins(self) -> dict[str, Any]:
        """Names and values made available as builtins to the function code."""
        return self._builtins

    @builtins.setter
    def builtins(self, value: dict[str, Any]):
        self._builtins = value
# Temporary fix until https://github.com/osscar-org/widget-code-input/pull/26
# is merged
def format_generic_error_msg(exc, code_widget):
    """
    Returns a string reproducing the traceback of a typical error.
    This includes line numbers, as well as neighboring lines.

    It will require also the code_widget instance, to get the actual source
    code.

    :note: the traceback attached to `exc` is used; if it has none, the
        traceback of the exception currently being handled is used instead.

    :param exc: The exception that is being processed.
    :param code_widget: the instance of the code widget with the code that
        raised the exception.
    :return: the formatted message, or ``str(exc)`` (after emitting a warning)
        if no traceback frame of the widget code could be found.
    """
    error_class = type(exc)
    tb = exc.__traceback__
    if tb is None:
        tb = sys.exc_info()[2]
    frame_summaries = traceback.extract_tb(tb)

    # The widget code is compiled with a pseudo filename: upstream
    # widget_code_input uses "widget_code_input", code compiled in this
    # module uses the module's `__name__`.
    code_filenames = ("widget_code_input", __name__)

    # The correct frame summary corresponding to widget_code_input is not
    # always at the end therefore we loop through all of them
    wci_frame_summary = None
    for frame_summary in frame_summaries:
        if frame_summary.filename in code_filenames:
            wci_frame_summary = frame_summary

    if wci_frame_summary is None:
        warnings.warn(
            "Could not find traceback frame corresponding to "
            "widget_code_input, we output whole error message.",
            stacklevel=2,
        )
        return str(exc)

    line_number = wci_frame_summary.lineno
    code_lines = code_widget.full_function_code.splitlines()

    message_lines = [f"{error_class.__name__} in code input: {str(exc)}"]

    # The code might have changed since it was compiled; without a valid line
    # only the header can be reported
    if line_number is None or not 1 <= line_number <= len(code_lines):
        return "\n".join(message_lines) + "\n"

    # Show up to two lines of context on each side, aligned with the marker
    marker = "---> "
    padding = " " * len(marker)
    first_shown = max(1, line_number - 2)
    last_shown = min(len(code_lines), line_number + 2)
    for number in range(first_shown, last_shown + 1):
        prefix = marker if number == line_number else padding
        message_lines.append(f"{prefix}{number:4d} {code_lines[number - 1]}")

    return "\n".join(message_lines) + "\n"