Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| from copy import deepcopy | |
| from typing import TYPE_CHECKING, Optional, Union | |
| import fastapi | |
| from gradio_client.documentation import document, set_documentation_group | |
| from gradio import utils | |
| from gradio.data_classes import PredictBody | |
| from gradio.exceptions import Error | |
| from gradio.helpers import EventData | |
| if TYPE_CHECKING: | |
| from gradio.routes import App | |
# Presumably tags the objects documented after this point (e.g. `Request`) with
# the "routes" group used by gradio's documentation generator — confirm against
# gradio_client.documentation.set_documentation_group.
set_documentation_group("routes")
class Obj:
    """
    Using a class to convert dictionaries into objects. Used by the `Request` class.
    Credit: https://www.geeksforgeeks.org/convert-nested-python-dictionary-to-object/

    Every key of the source dictionary is reachable both as an attribute
    (`obj.key`) and as an item (`obj["key"]`). Nested dictionaries become nested
    `Obj` instances; lists stay lists, with their dictionary elements converted
    the same way.
    """

    def __init__(self, dict_):
        for key, value in dict_.items():
            # Store straight into the instance dict (as `__setitem__` does) instead
            # of using `setattr`, so keys that collide with special attributes such
            # as `__class__` are kept as plain data instead of raising.
            self.__dict__[key] = Obj._wrap(value)

    @staticmethod
    def _wrap(value):
        """Turn a dict into an `Obj` (recursing through lists); return other values unchanged."""
        if isinstance(value, dict):
            return Obj(value)
        if isinstance(value, list):
            # A list has no `.items()`, so it cannot itself become an `Obj`;
            # convert its elements instead.
            return [Obj._wrap(item) for item in value]
        return value

    def __getitem__(self, item):
        return self.__dict__[item]

    def __setitem__(self, item, value):
        self.__dict__[item] = value

    def __iter__(self):
        # Yields (key, value) pairs; nested `Obj` values are handed out as
        # (shallow) dicts.
        for key, value in self.__dict__.items():
            if isinstance(value, Obj):
                yield (key, dict(value))
            else:
                yield (key, value)

    def __contains__(self, item) -> bool:
        # Membership checks this level's keys and, recursively, the keys of any
        # nested `Obj` values.
        if item in self.__dict__:
            return True
        for value in self.__dict__.values():
            if isinstance(value, Obj) and item in value:
                return True
        return False

    def keys(self):
        return self.__dict__.keys()

    def values(self):
        return self.__dict__.values()

    def items(self):
        return self.__dict__.items()

    def __str__(self) -> str:
        return str(self.__dict__)

    def __repr__(self) -> str:
        return str(self.__dict__)
class Request:
    """
    A Gradio request object that can be used to access the request headers, cookies,
    query parameters and other information about the request from within the prediction
    function. The class is a thin wrapper around the fastapi.Request class. Attributes
    of this class include: `headers`, `client`, `query_params`, and `path_params`. If
    auth is enabled, the `username` attribute can be used to get the logged in user.
    Example:
        import gradio as gr
        def echo(name, request: gr.Request):
            print("Request headers dictionary:", request.headers)
            print("IP address:", request.client.host)
            return name
        io = gr.Interface(echo, "textbox", "textbox").launch()
    """

    def __init__(
        self,
        request: fastapi.Request | None = None,
        username: str | None = None,
        **kwargs,
    ):
        """
        Can be instantiated with either a fastapi.Request or by manually passing in
        attributes (needed for websocket-based queueing).
        Parameters:
            request: A fastapi.Request
            username: the logged-in user's name, if auth is enabled
            kwargs: attribute values served when no fastapi.Request is given
        """
        self.request = request
        self.username = username
        self.kwargs: dict = kwargs

    def dict_to_obj(self, d):
        """
        Convert a dict into an attribute-accessible `Obj`; return non-dicts unchanged.

        The conversion round-trips through JSON. A dict that `json` cannot
        serialize is returned as-is, so reading such an attribute degrades to
        plain dict access instead of raising.
        """
        if not isinstance(d, dict):
            return d
        try:
            serialized = json.dumps(d)
        except (TypeError, ValueError):
            return d
        return json.loads(serialized, object_hook=Obj)

    def __getattr__(self, name):
        """
        Resolve attributes not set on the instance.

        Reads them from the wrapped fastapi.Request when one was given, otherwise
        from the manually passed keyword attributes; dict values are converted
        with `dict_to_obj`.

        Raises:
            AttributeError: if no request was given and `name` was not passed in.
        """
        # Read instance state via `__dict__`: `__getattr__` also runs on instances
        # created without `__init__` (e.g. by `copy.copy` or unpickling), where
        # `self.request` would re-enter `__getattr__` and recurse forever.
        request = self.__dict__.get("request")
        if request is not None:
            return self.dict_to_obj(getattr(request, name))
        kwargs = self.__dict__.get("kwargs", {})
        try:
            obj = kwargs[name]
        except KeyError as ke:
            raise AttributeError(
                f"'Request' object has no attribute '{name}'"
            ) from ke
        return self.dict_to_obj(obj)
class FnIndexInferError(Exception):
    """Raised by `infer_fn_index` when no dependency matches the requested api_name."""
def infer_fn_index(app: App, api_name: str, body: PredictBody) -> int:
    """
    Resolve the index of the dependency a predict request targets.

    An explicit `body.fn_index` always wins; otherwise the position of the first
    dependency whose "api_name" equals `api_name` is returned.

    Raises:
        FnIndexInferError: if no fn_index was given and no dependency matches.
    """
    explicit_index = body.fn_index
    if explicit_index is not None:
        return explicit_index

    dependencies = app.get_blocks().dependencies
    for position, dependency in enumerate(dependencies):
        if dependency["api_name"] == api_name:
            return position
    raise FnIndexInferError(f"Could not infer fn_index for api_name {api_name}.")
def compile_gr_request(
    app: App,
    body: PredictBody,
    fn_index_inferred: int,
    username: Optional[str],
    request: Optional[fastapi.Request],
):
    """
    Build the gradio `Request` object(s) for a predict call.

    If the dependency at `fn_index_inferred` cancels other jobs, `body.data` is
    replaced in place by a one-element list holding the session hash. When
    `body.request` carries manually passed request attributes, those are used
    (one `Request` per entry for batched calls); otherwise the fastapi `request`
    is wrapped.

    Raises:
        ValueError: if `body.request` is empty/None and no fastapi `request` is given.
    """
    dependency = app.get_blocks().dependencies[fn_index_inferred]
    # A cancelling dependency only needs the current session hash as input.
    if dependency["cancels"]:
        body.data = [body.session_hash]

    if not body.request:
        if request is None:
            raise ValueError("request must be provided if body.request is None")
        return Request(username=username, request=request)

    if body.batched:
        return [Request(username=username, **attrs) for attrs in body.request]

    assert isinstance(body.request, dict)
    return Request(username=username, **body.request)
def restore_session_state(app: App, body: PredictBody):
    """
    Return the `(session_state, iterators)` pair for the session in `body`.

    The first time a session hash is seen, its state dict is seeded with deep
    copies of the `value` of every block marked `stateful` and stored in
    `app.state_holder`. If `body.fn_index` is flagged in
    `app.iterators_to_reset` for this session, the flag is consumed and an empty
    iterator dict is returned instead of the session's stored iterators.
    Requests without a session hash get fresh, unsaved empty dicts.
    """
    fn_index = body.fn_index
    session_hash = getattr(body, "session_hash", None)
    if session_hash is None:
        return {}, {}

    if session_hash not in app.state_holder:
        all_blocks = app.get_blocks().blocks
        app.state_holder[session_hash] = {
            block_id: deepcopy(getattr(block, "value", None))
            for block_id, block in all_blocks.items()
            if getattr(block, "stateful", False)
        }
    session_state = app.state_holder[session_hash]

    # The should_reset set keeps track of the fn_indices
    # that have been cancelled. When a job is cancelled,
    # the /reset route will mark the jobs as having been reset.
    # That way if the cancel job finishes BEFORE the job being cancelled
    # the job being cancelled will not overwrite the state of the iterator.
    pending_resets = app.iterators_to_reset[session_hash]
    if fn_index in pending_resets:
        pending_resets.remove(fn_index)
        return session_state, {}
    return session_state, app.iterators[session_hash]
async def call_process_api(
    app: App,
    body: PredictBody,
    gr_request: Union[Request, list[Request]],
    fn_index_inferred,
):
    """
    Run one request through `Blocks.process_api` and return its output.

    Parameters:
        app: the gradio App; supplies the Blocks, per-session state, stored
            iterators and pending streams.
        body: the request body (inputs, session hash, event id/data, batching flag).
        gr_request: the gradio `Request` object(s) passed through to `process_api`.
        fn_index_inferred: index of the dependency to run.
    Returns:
        The output dict from `process_api` with its "iterator" entry removed. If a
        batched dependency was called with a single (unbatched) sample, "data" is
        unwrapped back to that one sample.
    Raises:
        Error: if `process_api` returns an `Error` instead of an output dict.
        Any exception from `process_api` is re-raised after `None` has been
        appended to the streams still pending for this session's stored iterator.
    """
    session_state, iterators = restore_session_state(app=app, body=body)
    dependency = app.get_blocks().dependencies[fn_index_inferred]
    target = dependency["targets"][0] if dependency["targets"] else None
    # Compare with None: a target id of 0 is falsy yet may still be a valid key
    # of `blocks`.
    event_data = EventData(
        app.get_blocks().blocks.get(target) if target is not None else None,
        body.event_data,
    )
    event_id = getattr(body, "event_id", None)
    fn_index = body.fn_index
    session_hash = getattr(body, "session_hash", None)

    inputs = body.data
    # A batched dependency invoked with a single sample: wrap the inputs into a
    # one-element batch here and unwrap the output data at the end.
    batch_in_single_out = not body.batched and dependency["batch"]
    if batch_in_single_out:
        inputs = [inputs]

    try:
        with utils.MatplotlibBackendMananger():
            output = await app.get_blocks().process_api(
                fn_index=fn_index_inferred,
                inputs=inputs,
                request=gr_request,
                state=session_state,
                iterators=iterators,
                session_hash=session_hash,
                event_id=event_id,
                event_data=event_data,
            )
        # Check for an Error result before treating `output` as a dict: calling
        # `.pop` on an Error would raise AttributeError and hide the real error.
        if isinstance(output, Error):
            raise output
        iterator = output.pop("iterator", None)
        # restore_session_state only hands back stored iterators for requests
        # with a session hash, so don't park them under a `None` key.
        if session_hash is not None:
            app.iterators[session_hash][fn_index] = iterator
    except BaseException:
        # Close off any streams that are still open for the iterator this
        # session had stored before the call.
        previous_iterator = iterators.get(fn_index, None)
        if previous_iterator is not None:
            run_id = id(previous_iterator)
            pending_streams: dict[int, list] = (
                app.get_blocks().pending_streams[session_hash].get(run_id, {})
            )
            for stream in pending_streams.values():
                stream.append(None)
        raise

    if batch_in_single_out:
        output["data"] = output["data"][0]
    return output