| """ Handy utility functions. """ | |
| from __future__ import annotations | |
| import asyncio | |
| import copy | |
| import functools | |
| import importlib | |
| import inspect | |
| import json | |
| import json.decoder | |
| import os | |
| import pkgutil | |
| import pprint | |
| import random | |
| import re | |
| import threading | |
| import time | |
| import traceback | |
| import typing | |
| import warnings | |
| from abc import ABC, abstractmethod | |
| from contextlib import contextmanager | |
| from io import BytesIO | |
| from numbers import Number | |
| from pathlib import Path | |
| from types import GeneratorType | |
| from typing import ( | |
| TYPE_CHECKING, | |
| Any, | |
| Callable, | |
| Iterator, | |
| Optional, | |
| TypeVar, | |
| ) | |
| import anyio | |
| import matplotlib | |
| import requests | |
| from gradio_client.serializing import Serializable | |
| from typing_extensions import ParamSpec | |
| import gradio | |
| from gradio.context import Context | |
| from gradio.strings import en | |
| if TYPE_CHECKING: # Only import for type checking (is False at runtime). | |
| from gradio.blocks import Block, BlockContext, Blocks | |
| from gradio.components import Component | |
| from gradio.routes import App | |
| JSON_PATH = os.path.join(os.path.dirname(gradio.__file__), "launches.json") | |
| GRADIO_VERSION = ( | |
| (pkgutil.get_data(__name__, "version.txt") or b"").decode("ascii").strip() | |
| ) | |
| P = ParamSpec("P") | |
| T = TypeVar("T") | |


def safe_get_lock() -> asyncio.Lock:
    """Get asyncio.Lock() without fear of getting an Exception.

    Needed because in reload mode we import the Blocks object outside
    the main thread.
    """
    try:
        asyncio.get_event_loop()
        return asyncio.Lock()
    except RuntimeError:
        return None  # type: ignore


class BaseReloader(ABC):
    @property
    @abstractmethod
    def running_app(self) -> App:
        ...

    def queue_changed(self, demo: Blocks):
        return (
            hasattr(self.running_app.blocks, "_queue") and not hasattr(demo, "_queue")
        ) or (
            not hasattr(self.running_app.blocks, "_queue") and hasattr(demo, "_queue")
        )

    def swap_blocks(self, demo: Blocks):
        assert self.running_app.blocks
        # Copy over the blocks to get new components and events but
        # not a new queue
        if hasattr(self.running_app.blocks, "_queue"):
            self.running_app.blocks._queue.blocks_dependencies = demo.dependencies
            demo._queue = self.running_app.blocks._queue
        self.running_app.blocks = demo


class SourceFileReloader(BaseReloader):
    def __init__(
        self,
        app: App,
        watch_dirs: list[str],
        watch_file: str,
        stop_event: threading.Event,
        change_event: threading.Event,
        demo_name: str = "demo",
    ) -> None:
        super().__init__()
        self.app = app
        self.watch_dirs = watch_dirs
        self.watch_file = watch_file
        self.stop_event = stop_event
        self.change_event = change_event
        self.demo_name = demo_name

    @property
    def running_app(self) -> App:
        return self.app

    def should_watch(self) -> bool:
        return not self.stop_event.is_set()

    def stop(self) -> None:
        self.stop_event.set()

    def alert_change(self):
        self.change_event.set()

    def swap_blocks(self, demo: Blocks):
        super().swap_blocks(demo)
        self.alert_change()


def watchfn(reloader: SourceFileReloader):
    """Watch python files in a given module.

    get_changes is taken from uvicorn's default file watcher.
    """
    import sys

    # The thread running watchfn will be the thread reloading
    # the app. So we need to modify this thread_data attr here
    # so that subsequent calls to reload don't launch the app
    from gradio.reload import reload_thread

    reload_thread.running_reload = True

    def get_changes() -> Path | None:
        for file in iter_py_files():
            try:
                mtime = file.stat().st_mtime
            except OSError:  # pragma: nocover
                continue

            old_time = mtimes.get(file)
            if old_time is None:
                mtimes[file] = mtime
                continue
            elif mtime > old_time:
                return file
        return None

    def iter_py_files() -> Iterator[Path]:
        for reload_dir in reload_dirs:
            for path in list(reload_dir.rglob("*.py")):
                yield path.resolve()

    module = None
    reload_dirs = [Path(dir_) for dir_ in reloader.watch_dirs]
    mtimes = {}
    while reloader.should_watch():
        changed = get_changes()
        if changed:
            print(f"Changes detected in: {changed}")
            # To simulate a fresh reload, delete all module references from sys.modules
            # for the modules in the package the change came from.
            dir_ = next(d for d in reload_dirs if is_in_or_equal(changed, d))
            modules = list(sys.modules)
            for k in modules:
                v = sys.modules[k]
                sourcefile = getattr(v, "__file__", None)
                # Do not reload `reload.py` to keep thread data
                if (
                    sourcefile
                    and dir_ == Path(inspect.getfile(gradio)).parent
                    and sourcefile.endswith("reload.py")
                ):
                    continue
                if sourcefile and is_in_or_equal(sourcefile, dir_):
                    del sys.modules[k]
            try:
                module = importlib.import_module(reloader.watch_file)
                module = importlib.reload(module)
            except Exception as e:
                print(
                    f"Reloading {reloader.watch_file} failed with the following exception: "
                )
                traceback.print_exception(None, value=e, tb=None)
                mtimes = {}
                continue

            demo = getattr(module, reloader.demo_name)
            if reloader.queue_changed(demo):
                print(
                    "Reloading failed. The new demo has a queue and the old one doesn't (or vice versa). "
                    "Please launch your demo again"
                )
            else:
                reloader.swap_blocks(demo)
            mtimes = {}


def colab_check() -> bool:
    """
    Check if interface is launching from Google Colab
    :return is_colab (bool): True or False
    """
    is_colab = False
    try:  # Check if running interactively using ipython.
        from IPython.core.getipython import get_ipython

        from_ipynb = get_ipython()
        if "google.colab" in str(from_ipynb):
            is_colab = True
    except (ImportError, NameError):
        pass
    return is_colab


def kaggle_check() -> bool:
    return bool(
        os.environ.get("KAGGLE_KERNEL_RUN_TYPE") or os.environ.get("GFOOTBALL_DATA_DIR")
    )


def sagemaker_check() -> bool:
    try:
        import boto3  # type: ignore

        client = boto3.client("sts")
        response = client.get_caller_identity()
        return "sagemaker" in response["Arn"].lower()
    except Exception:
        return False


def ipython_check() -> bool:
    """
    Check if interface is launching from iPython (not colab)
    :return is_ipython (bool): True or False
    """
    is_ipython = False
    try:  # Check if running interactively using ipython.
        from IPython.core.getipython import get_ipython

        if get_ipython() is not None:
            is_ipython = True
    except (ImportError, NameError):
        pass
    return is_ipython


def get_space() -> str | None:
    if os.getenv("SYSTEM") == "spaces":
        return os.getenv("SPACE_ID")
    return None


def is_zero_gpu_space() -> bool:
    return os.getenv("SPACES_ZERO_GPU") == "true"


def readme_to_html(article: str) -> str:
    try:
        response = requests.get(article, timeout=3)
        if response.status_code == requests.codes.ok:  # pylint: disable=no-member
            article = response.text
    except requests.exceptions.RequestException:
        pass
    return article


def show_tip(interface: gradio.Blocks) -> None:
    if interface.show_tips and random.random() < 1.5:
        tip: str = random.choice(en["TIPS"])
        print(f"Tip: {tip}")


def launch_counter() -> None:
    try:
        if not os.path.exists(JSON_PATH):
            launches = {"launches": 1}
            with open(JSON_PATH, "w+") as j:
                json.dump(launches, j)
        else:
            with open(JSON_PATH) as j:
                launches = json.load(j)
            launches["launches"] += 1
            if launches["launches"] in [25, 50, 150, 500, 1000]:
                print(en["BETA_INVITE"])
            with open(JSON_PATH, "w") as j:
                j.write(json.dumps(launches))
    except Exception:
        pass


def get_default_args(func: Callable) -> list[Any]:
    signature = inspect.signature(func)
    return [
        v.default if v.default is not inspect.Parameter.empty else None
        for v in signature.parameters.values()
    ]
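

# Illustrative example (a sketch; the lambda is hypothetical): defaults are
# returned in declaration order, with None standing in for "no default".
#
#     >>> get_default_args(lambda a, b=3: None)
#     [None, 3]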


def assert_configs_are_equivalent_besides_ids(
    config1: dict, config2: dict, root_keys: tuple = ("mode",)
):
    """Allows you to test if two different Blocks configs produce the same demo.

    Parameters:
        config1 (dict): nested dict with config from the first Blocks instance
        config2 (dict): nested dict with config from the second Blocks instance
        root_keys (Tuple): an iterable consisting of which keys to test for equivalence at
            the root level of the config. By default, only "mode" is tested,
            so keys like "version" are ignored.
    """
    config1 = copy.deepcopy(config1)
    config2 = copy.deepcopy(config2)
    pp = pprint.PrettyPrinter(indent=2)

    for key in root_keys:
        assert config1[key] == config2[key], f"Configs have different: {key}"

    assert len(config1["components"]) == len(
        config2["components"]
    ), "# of components are different"

    def assert_same_components(config1_id, config2_id):
        c1 = list(filter(lambda c: c["id"] == config1_id, config1["components"]))[0]
        c2 = list(filter(lambda c: c["id"] == config2_id, config2["components"]))[0]
        c1 = copy.deepcopy(c1)
        c1.pop("id")
        c2 = copy.deepcopy(c2)
        c2.pop("id")
        # pp.pformat returns the formatted string (pp.pprint would print and
        # return None, producing a useless "None does not match None" message)
        assert json.dumps(c1) == json.dumps(
            c2
        ), f"{pp.pformat(c1)} does not match {pp.pformat(c2)}"

    def same_children_recursive(children1, children2):
        for child1, child2 in zip(children1, children2):
            assert_same_components(child1["id"], child2["id"])
            if "children" in child1 or "children" in child2:
                same_children_recursive(child1["children"], child2["children"])

    children1 = config1["layout"]["children"]
    children2 = config2["layout"]["children"]
    same_children_recursive(children1, children2)

    for d1, d2 in zip(config1["dependencies"], config2["dependencies"]):
        for t1, t2 in zip(d1.pop("targets"), d2.pop("targets")):
            assert_same_components(t1, t2)
        for i1, i2 in zip(d1.pop("inputs"), d2.pop("inputs")):
            assert_same_components(i1, i2)
        for o1, o2 in zip(d1.pop("outputs"), d2.pop("outputs")):
            assert_same_components(o1, o2)

        assert d1 == d2, f"{d1} does not match {d2}"

    return True


def format_ner_list(input_string: str, ner_groups: list[dict[str, str | int]]):
    if len(ner_groups) == 0:
        return [(input_string, None)]

    output = []
    end = 0
    prev_end = 0

    for group in ner_groups:
        entity, start, end = group["entity_group"], group["start"], group["end"]
        output.append((input_string[prev_end:start], None))
        output.append((input_string[start:end], entity))
        prev_end = end

    output.append((input_string[end:], None))
    return output
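

# Illustrative example (a sketch with made-up NER output): the string is split
# into (text, entity-or-None) pairs covering the whole input.
#
#     >>> format_ner_list("I flew to NYC", [{"entity_group": "LOC", "start": 10, "end": 13}])
#     [('I flew to ', None), ('NYC', 'LOC'), ('', None)]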


def delete_none(_dict: dict, skip_value: bool = False) -> dict:
    """
    Delete keys whose values are None from a dictionary
    """
    for key, value in list(_dict.items()):
        if skip_value and key == "value":
            continue
        elif value is None:
            del _dict[key]
    return _dict
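

# Illustrative example (a sketch): the dict is mutated in place and returned;
# with skip_value=True, a None "value" key is preserved.
#
#     >>> delete_none({"a": None, "b": 1, "value": None}, skip_value=True)
#     {'b': 1, 'value': None}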


def resolve_singleton(_list: list[Any] | Any) -> Any:
    if len(_list) == 1:
        return _list[0]
    else:
        return _list
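

# Illustrative example (a sketch): one-element lists are unwrapped, anything
# else is returned unchanged.
#
#     >>> resolve_singleton([5])
#     5
#     >>> resolve_singleton([1, 2])
#     [1, 2]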


def component_or_layout_class(cls_name: str) -> type[Component] | type[BlockContext]:
    """
    Returns the component, template, or layout class with the given class name, or
    raises a ValueError if not found.

    Parameters:
        cls_name (str): lower-case string class name of a component
    Returns:
        cls: the component class
    """
    import gradio.blocks
    import gradio.components
    import gradio.layouts
    import gradio.templates

    components = [
        (name, cls)
        for name, cls in gradio.components.__dict__.items()
        if isinstance(cls, type)
    ]
    templates = [
        (name, cls)
        for name, cls in gradio.templates.__dict__.items()
        if isinstance(cls, type)
    ]
    layouts = [
        (name, cls)
        for name, cls in gradio.layouts.__dict__.items()
        if isinstance(cls, type)
    ]
    for name, cls in components + templates + layouts:
        if name.lower() == cls_name.replace("_", "") and (
            issubclass(cls, gradio.components.Component)
            or issubclass(cls, gradio.blocks.BlockContext)
        ):
            return cls
    raise ValueError(f"No such component or layout: {cls_name}")


def run_coro_in_background(func: Callable, *args, **kwargs):
    """
    Runs coroutines in background.

    Warning: only call this function from a scope where the FastAPI event loop is
    already running, e.g. route endpoints or Blocks.process_api; do not call it
    from Blocks.launch, where the event loop has not started yet. If you need to
    run a coroutine in the background in Blocks.launch(), use startup_events in
    routes.py instead.

    Example:
        utils.run_coro_in_background(fn, *args, **kwargs)

    Args:
        func: the coroutine function to schedule
        *args: positional arguments passed through to `func`
        **kwargs: keyword arguments passed through to `func`

    Returns:
        The asyncio.Task wrapping the scheduled coroutine.
    """
    event_loop = asyncio.get_event_loop()
    return event_loop.create_task(func(*args, **kwargs))


def run_sync_iterator_async(iterator):
    """Helper for yielding StopAsyncIteration from sync iterators."""
    try:
        return next(iterator)
    except StopIteration:
        # raise StopAsyncIteration here because coroutines can't raise
        # StopIteration themselves
        raise StopAsyncIteration() from None


class SyncToAsyncIterator:
    """Treat a synchronous iterator as async one."""

    def __init__(self, iterator, limiter) -> None:
        self.iterator = iterator
        self.limiter = limiter

    def __aiter__(self):
        return self

    async def __anext__(self):
        return await anyio.to_thread.run_sync(
            run_sync_iterator_async, self.iterator, limiter=self.limiter
        )
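

# Illustrative usage (a sketch; must run inside an async context, and the
# limiter is any anyio.CapacityLimiter bounding the worker threads used):
#
#     async def consume():
#         it = SyncToAsyncIterator(iter([1, 2, 3]), anyio.CapacityLimiter(1))
#         async for item in it:  # each next() call runs in a worker thread
#             print(item)        # 1, 2, 3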


async def async_iteration(iterator):
    # anext not introduced until 3.10 :(
    return await iterator.__anext__()


@contextmanager
def set_directory(path: Path | str):
    """Context manager that sets the working directory to the given path."""
    origin = Path().absolute()
    try:
        os.chdir(path)
        yield
    finally:
        os.chdir(origin)
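

# Illustrative usage (a sketch; "/tmp" is an arbitrary example path): the
# original working directory is restored even if the body raises.
#
#     with set_directory("/tmp"):
#         ...  # Path.cwd() is now /tmp
#     # Path.cwd() is back to what it was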


def sanitize_value_for_csv(value: str | Number) -> str | Number:
    """
    Sanitizes a value that is being written to a CSV file to prevent CSV injection attacks.
    Reference: https://owasp.org/www-community/attacks/CSV_Injection
    """
    if isinstance(value, Number):
        return value
    unsafe_prefixes = ["=", "+", "-", "@", "\t", "\n"]
    unsafe_sequences = [",=", ",+", ",-", ",@", ",\t", ",\n"]
    if any(value.startswith(prefix) for prefix in unsafe_prefixes) or any(
        sequence in value for sequence in unsafe_sequences
    ):
        value = f"'{value}"
    return value
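

# Illustrative example (a sketch with a classic injection payload): formula
# prefixes are neutralized with a leading single quote, numbers pass through.
#
#     >>> sanitize_value_for_csv("=HYPERLINK(...)")
#     "'=HYPERLINK(...)"
#     >>> sanitize_value_for_csv(42)
#     42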


def sanitize_list_for_csv(values: list[Any]) -> list[Any]:
    """
    Sanitizes a list of values (or a list of list of values) that is being written to a
    CSV file to prevent CSV injection attacks.
    """
    sanitized_values = []
    for value in values:
        if isinstance(value, list):
            sanitized_value = [sanitize_value_for_csv(v) for v in value]
            sanitized_values.append(sanitized_value)
        else:
            sanitized_value = sanitize_value_for_csv(value)
            sanitized_values.append(sanitized_value)
    return sanitized_values


def append_unique_suffix(name: str, list_of_names: list[str]):
    """Appends a numerical suffix to `name` so that it does not appear in `list_of_names`."""
    set_of_names: set[str] = set(list_of_names)  # for O(1) lookup
    if name not in set_of_names:
        return name
    else:
        suffix_counter = 1
        new_name = f"{name}_{suffix_counter}"
        while new_name in set_of_names:
            suffix_counter += 1
            new_name = f"{name}_{suffix_counter}"
        return new_name
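

# Illustrative example (a sketch): the suffix counter increments until the
# name is free.
#
#     >>> append_unique_suffix("run", ["run", "run_1"])
#     'run_2'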


def validate_url(possible_url: str) -> bool:
    headers = {"User-Agent": "gradio (https://gradio.app/; [email protected])"}
    try:
        head_request = requests.head(possible_url, headers=headers)
        # some URLs, such as AWS S3 presigned URLs, return a 405 or a 403 for HEAD requests
        if head_request.status_code == 405 or head_request.status_code == 403:
            return requests.get(possible_url, headers=headers).ok
        return head_request.ok
    except Exception:
        return False


def is_update(val):
    return isinstance(val, dict) and "update" in val.get("__type__", "")


def get_continuous_fn(fn: Callable, every: float) -> Callable:
    def continuous_fn(*args):
        while True:
            output = fn(*args)
            if isinstance(output, GeneratorType):
                yield from output
            else:
                yield output
            time.sleep(every)

    return continuous_fn
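

# Illustrative usage (a sketch): the wrapped function becomes an endless
# generator that re-runs `fn` every `every` seconds.
#
#     >>> gen = get_continuous_fn(time.time, every=1.0)()
#     >>> next(gen)  # yields a timestamp; the next call sleeps ~1s, then re-runs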


def function_wrapper(
    f, before_fn=None, before_args=None, after_fn=None, after_args=None
):
    before_args = [] if before_args is None else before_args
    after_args = [] if after_args is None else after_args
    if inspect.isasyncgenfunction(f):

        # functools.wraps preserves the wrapped function's name and metadata
        @functools.wraps(f)
        async def asyncgen_wrapper(*args, **kwargs):
            if before_fn:
                before_fn(*before_args)
            async for response in f(*args, **kwargs):
                yield response
            if after_fn:
                after_fn(*after_args)

        return asyncgen_wrapper

    elif asyncio.iscoroutinefunction(f):

        @functools.wraps(f)
        async def async_wrapper(*args, **kwargs):
            if before_fn:
                before_fn(*before_args)
            response = await f(*args, **kwargs)
            if after_fn:
                after_fn(*after_args)
            return response

        return async_wrapper

    elif inspect.isgeneratorfunction(f):

        @functools.wraps(f)
        def gen_wrapper(*args, **kwargs):
            if before_fn:
                before_fn(*before_args)
            yield from f(*args, **kwargs)
            if after_fn:
                after_fn(*after_args)

        return gen_wrapper

    else:

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            if before_fn:
                before_fn(*before_args)
            response = f(*args, **kwargs)
            if after_fn:
                after_fn(*after_args)
            return response

        return wrapper


def get_function_with_locals(fn: Callable, blocks: Blocks, event_id: str | None):
    def before_fn(blocks, event_id):
        from gradio.context import thread_data

        thread_data.blocks = blocks
        thread_data.event_id = event_id

    return function_wrapper(fn, before_fn=before_fn, before_args=(blocks, event_id))


async def cancel_tasks(task_ids: set[str]):
    matching_tasks = [
        task for task in asyncio.all_tasks() if task.get_name() in task_ids
    ]
    for task in matching_tasks:
        task.cancel()
    await asyncio.gather(*matching_tasks, return_exceptions=True)


def set_task_name(task, session_hash: str, fn_index: int, batch: bool):
    if not batch:
        task.set_name(f"{session_hash}_{fn_index}")


def get_cancel_function(
    dependencies: list[dict[str, Any]]
) -> tuple[Callable, list[int]]:
    fn_to_comp = {}
    for dep in dependencies:
        if Context.root_block:
            fn_index = next(
                i for i, d in enumerate(Context.root_block.dependencies) if d == dep
            )
            fn_to_comp[fn_index] = [
                Context.root_block.blocks[o] for o in dep["outputs"]
            ]

    async def cancel(session_hash: str) -> None:
        task_ids = {f"{session_hash}_{fn}" for fn in fn_to_comp}
        await cancel_tasks(task_ids)

    return (
        cancel,
        list(fn_to_comp.keys()),
    )


def get_type_hints(fn):
    # Importing gradio with the canonical abbreviation. Used in typing._eval_type.
    import gradio as gr  # noqa: F401
    from gradio import OAuthProfile, Request  # noqa: F401

    if inspect.isfunction(fn) or inspect.ismethod(fn):
        pass
    elif callable(fn):
        fn = fn.__call__
    else:
        return {}

    try:
        return typing.get_type_hints(fn)
    except TypeError:
        # On Python 3.9 or earlier, get_type_hints throws a TypeError if the
        # function has a type annotation that includes "|". We resort to parsing
        # the signature manually using inspect.signature.
        type_hints = {}
        sig = inspect.signature(fn)
        for name, param in sig.parameters.items():
            if param.annotation is inspect.Parameter.empty:
                continue
            if param.annotation == "gr.OAuthProfile | None":
                # Special case: we want to inject the OAuthProfile value even on Python 3.9
                type_hints[name] = Optional[OAuthProfile]
            if "|" in str(param.annotation):
                continue
            # To convert the string annotation to a class, we use the
            # internal typing._eval_type function. This is not ideal, but
            # it's the only way to do it without eval-ing the string.
            # Since the API is internal, it may change in the future.
            try:
                type_hints[name] = typing._eval_type(  # type: ignore
                    typing.ForwardRef(param.annotation), globals(), locals()
                )
            except (NameError, TypeError):
                pass
        return type_hints


def is_special_typed_parameter(name, parameter_types):
    """Checks if parameter has a type hint designating it as a gr.Request, gr.EventData or gr.OAuthProfile."""
    from gradio.helpers import EventData
    from gradio.oauth import OAuthProfile
    from gradio.routes import Request

    hint = parameter_types.get(name)
    if not hint:
        return False
    is_request = hint == Request
    is_oauth_arg = hint in (OAuthProfile, Optional[OAuthProfile])
    is_event_data = inspect.isclass(hint) and issubclass(hint, EventData)
    return is_request or is_event_data or is_oauth_arg


def check_function_inputs_match(fn: Callable, inputs: list, inputs_as_dict: bool):
    """
    Checks if the input component set matches the function's signature. Warns if
    the number of wired inputs cannot satisfy the function's arguments; returns a
    string error message if keyword-only arguments lack default values, and None
    otherwise.
    """
    signature = inspect.signature(fn)
    parameter_types = get_type_hints(fn)
    min_args = 0
    max_args = 0
    infinity = -1
    for name, param in signature.parameters.items():
        has_default = param.default != param.empty
        if param.kind in [param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD]:
            if not is_special_typed_parameter(name, parameter_types):
                if not has_default:
                    min_args += 1
                max_args += 1
        elif param.kind == param.VAR_POSITIONAL:
            max_args = infinity
        elif param.kind == param.KEYWORD_ONLY and not has_default:
            return f"Keyword-only args must have default values for function {fn}"
    arg_count = 1 if inputs_as_dict else len(inputs)
    if min_args == max_args and max_args != arg_count:
        warnings.warn(
            f"Expected {max_args} arguments for function {fn}, received {arg_count}."
        )
    if arg_count < min_args:
        warnings.warn(
            f"Expected at least {min_args} arguments for function {fn}, received {arg_count}."
        )
    if max_args != infinity and arg_count > max_args:
        warnings.warn(
            f"Expected maximum {max_args} arguments for function {fn}, received {arg_count}."
        )
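

# Illustrative example (a sketch): wiring a two-argument function to a single
# input component emits a UserWarning along the lines of
# "Expected 2 arguments for function <lambda ...>, received 1."
#
#     >>> check_function_inputs_match(lambda a, b: a + b, inputs=[0], inputs_as_dict=False)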


def concurrency_count_warning(queue: Callable[P, T]) -> Callable[P, T]:
    @functools.wraps(queue)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        _self, *positional = args
        if is_zero_gpu_space() and (
            len(positional) >= 1 or "concurrency_count" in kwargs
        ):
            warnings.warn(
                "Queue concurrency_count on ZeroGPU Spaces cannot be overridden "
                "and is always equal to Block's max_threads. "
                "Consider setting max_threads value on the Block instead"
            )
        return queue(*args, **kwargs)

    return wrapper


class TupleNoPrint(tuple):
    # Suppresses printing of the function's return value in notebooks
    def __repr__(self):
        return ""

    def __str__(self):
        return ""


class MatplotlibBackendMananger:
    def __enter__(self):
        self._original_backend = matplotlib.get_backend()
        matplotlib.use("agg")

    def __exit__(self, exc_type, exc_val, exc_tb):
        matplotlib.use(self._original_backend)


def tex2svg(formula, *args):
    with MatplotlibBackendMananger():
        import matplotlib.pyplot as plt

        fontsize = 20
        dpi = 300
        plt.rc("mathtext", fontset="cm")
        fig = plt.figure(figsize=(0.01, 0.01))
        fig.text(0, 0, rf"${formula}$", fontsize=fontsize)
        output = BytesIO()
        fig.savefig(
            output,
            dpi=dpi,
            transparent=True,
            format="svg",
            bbox_inches="tight",
            pad_inches=0.0,
        )
        plt.close(fig)
        output.seek(0)
        xml_code = output.read().decode("utf-8")
        svg_start = xml_code.index("<svg ")
        svg_code = xml_code[svg_start:]
        svg_code = re.sub(r"<metadata>.*<\/metadata>", "", svg_code, flags=re.DOTALL)
        svg_code = re.sub(r' width="[^"]+"', "", svg_code)
        height_match = re.search(r'height="([\d.]+)pt"', svg_code)
        if height_match:
            height = float(height_match.group(1))
            new_height = height / fontsize  # conversion from pt to em
            svg_code = re.sub(
                r'height="[\d.]+pt"', f'height="{new_height}em"', svg_code
            )
        copy_code = f"<span style='font-size: 0px'>{formula}</span>"
    return f"{copy_code}{svg_code}"


def abspath(path: str | Path) -> Path:
    """Returns absolute path of a str or Path path, but does not resolve symlinks."""
    path = Path(path)

    if path.is_absolute():
        return path

    # recursively check if there is a symlink within the path
    is_symlink = path.is_symlink() or any(
        parent.is_symlink() for parent in path.parents
    )

    if is_symlink or path == path.resolve():  # in case path couldn't be resolved
        return Path.cwd() / path
    else:
        return path.resolve()


def is_in_or_equal(path_1: str | Path, path_2: str | Path):
    """
    Returns True if path_1 is a descendant of (i.e. located within) path_2, or
    if the paths are the same; returns False otherwise.

    Parameters:
        path_1: str or Path (should be a file)
        path_2: str or Path (can be a file or directory)
    """
    path_1, path_2 = abspath(path_1), abspath(path_2)
    try:
        if str(path_1.relative_to(path_2)).startswith(".."):  # prevent path traversal
            return False
    except ValueError:
        return False
    return True
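

# Illustrative example (a sketch; paths are hypothetical): ".." traversal out
# of the parent directory is rejected.
#
#     >>> is_in_or_equal("/safe/dir/file.txt", "/safe/dir")
#     True
#     >>> is_in_or_equal("/safe/dir/../escape.txt", "/safe/dir")
#     False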


def get_serializer_name(block: Block) -> str | None:
    if not hasattr(block, "serialize"):
        return None

    def get_class_that_defined_method(meth: Callable):
        # Adapted from: https://stackoverflow.com/a/25959545/5209347
        if isinstance(meth, functools.partial):
            return get_class_that_defined_method(meth.func)
        if inspect.ismethod(meth) or (
            inspect.isbuiltin(meth)
            and getattr(meth, "__self__", None) is not None
            and getattr(meth.__self__, "__class__", None)
        ):
            for cls in inspect.getmro(meth.__self__.__class__):
                # Find the first serializer defined in gradio_client that
                # this class inherits from
                if issubclass(cls, Serializable) and "gradio_client" in cls.__module__:
                    return cls
                if meth.__name__ in cls.__dict__:
                    return cls
            meth = getattr(meth, "__func__", meth)  # fallback to __qualname__ parsing
        if inspect.isfunction(meth):
            cls = getattr(
                inspect.getmodule(meth),
                meth.__qualname__.split(".<locals>", 1)[0].rsplit(".", 1)[0],
                None,
            )
            if isinstance(cls, type):
                return cls
        return getattr(meth, "__objclass__", None)

    cls = get_class_that_defined_method(block.serialize)  # type: ignore
    if cls:
        return cls.__name__


HTML_TAG_RE = re.compile("<.*?>")


def remove_html_tags(raw_html: str | None) -> str:
    return re.sub(HTML_TAG_RE, "", raw_html or "")
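

# Illustrative example (a sketch): tags are stripped non-greedily; None is
# treated as an empty string.
#
#     >>> remove_html_tags("<p>Hello <b>world</b></p>")
#     'Hello world'
#     >>> remove_html_tags(None)
#     ''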


def find_user_stack_level() -> int:
    """
    Find the number of frames up the call stack to the first frame that is not
    inside the gradio package (useful as a `stacklevel` for `warnings.warn`).
    """
    frame = inspect.currentframe()

    n = 0
    while frame:
        fname = inspect.getfile(frame)
        if "/gradio/" not in fname.replace(os.sep, "/"):
            break
        frame = frame.f_back
        n += 1

    return n