import gradio as gr
import onnxruntime as ort
import numpy as np
import cv2
from huggingface_hub import hf_hub_download  # downloads model files from the Hub
# --- 1. GLOBAL SETUP: DOWNLOAD AND LOAD MODELS AT STARTUP ---
# This is the recommended way to use models in a Space.
try:
    print("Downloading and loading ONNX models from the Hub...")

    # Define the model repository ID.
    MODEL_REPO = "rtr46/meiki.text.detect.v0"

    # hf_hub_download fetches each file once, caches it, and returns the local path.
    tiny_model_path = hf_hub_download(repo_id=MODEL_REPO, filename="meiki.text.detect.tiny.v0.onnx")
    small_model_path = hf_hub_download(repo_id=MODEL_REPO, filename="meiki.text.detect.small.v0.onnx")

    # Use CPUExecutionProvider for broad compatibility.
    providers = ['CPUExecutionProvider']
    ort_session_tiny = ort.InferenceSession(tiny_model_path, providers=providers)
    ort_session_small = ort.InferenceSession(small_model_path, providers=providers)
    print("Models loaded successfully.")
except Exception as e:
    print(f"Error loading models: {e}")
    # Leave the sessions as None; detect_text raises a user-facing error in that case.
    ort_session_tiny = None
    ort_session_small = None
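# Note: hf_hub_download stores files in the standard Hub cache (by default
# ~/.cache/huggingface/hub, relocatable via the HF_HOME environment variable),
# so a restart reuses the already-downloaded .onnx files instead of re-fetching.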
# --- 2. HELPER FUNCTION: PREPROCESSING ---
def resize_and_pad(image: np.ndarray, size: int, is_color: bool):
    """Resize an image to fit a size x size square and center-pad it with black.

    Works for both grayscale (H, W) and color (H, W, 3) arrays. Returns the
    padded image plus the scale ratio and padding offsets needed to map
    detections back to the original image.
    """
    if is_color:
        h, w, _ = image.shape
    else:
        h, w = image.shape

    ratio = min(size / w, size / h)
    new_w, new_h = int(w * ratio), int(h * ratio)
    resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    if is_color:
        padded_image = np.zeros((size, size, 3), dtype=np.uint8)
    else:
        padded_image = np.zeros((size, size), dtype=np.uint8)

    pad_w, pad_h = (size - new_w) // 2, (size - new_h) // 2
    padded_image[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized_image
    return padded_image, ratio, pad_w, pad_h
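# A quick sanity check of the letterbox math (hypothetical numbers, shown as
# a sketch rather than executed at import time):
#
#   dummy = np.zeros((720, 1280, 3), dtype=np.uint8)
#   padded, ratio, pad_w, pad_h = resize_and_pad(dummy, 640, is_color=True)
#   # ratio == 0.5, padded.shape == (640, 640, 3), pad_w == 0, pad_h == 140
#
# detect_text below inverts this mapping with (coordinate - pad) / ratio.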
# --- 3. CORE INFERENCE FUNCTION ---
def detect_text(model_name, input_image, confidence_threshold):
    """Run text detection on the input image using the selected model."""
    if ort_session_tiny is None or ort_session_small is None:
        raise gr.Error("Models are not loaded. Please check the console logs for errors.")
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    if model_name == "tiny":
        session = ort_session_tiny
        model_size = 320
        is_color = False  # the tiny model expects a single-channel input
    else:  # "small"
        session = ort_session_small
        model_size = 640
        is_color = True
    output_image = input_image.copy()

    if is_color:
        image_for_model = input_image
    else:
        # gr.Image(type="numpy") delivers RGB arrays, so convert with RGB2GRAY.
        image_for_model = cv2.cvtColor(input_image, cv2.COLOR_RGB2GRAY)
    padded_image, ratio, pad_w, pad_h = resize_and_pad(image_for_model, model_size, is_color)

    # Normalize to [0, 1] and add batch (and channel) dimensions:
    # (1, 3, S, S) for the color model, (1, 1, S, S) for the grayscale one.
    img_normalized = padded_image.astype(np.float32) / 255.0
    if is_color:
        img_transposed = np.transpose(img_normalized, (2, 0, 1))
        input_tensor = np.expand_dims(img_transposed, axis=0)
    else:
        input_tensor = np.expand_dims(np.expand_dims(img_normalized, axis=0), axis=0)

    # Both models take two inputs in order: the image tensor and a (1, 2)
    # int64 tensor holding the padded input size.
    sizes_tensor = np.array([[model_size, model_size]], dtype=np.int64)
    input_names = [inp.name for inp in session.get_inputs()]
    inputs = {input_names[0]: input_tensor, input_names[1]: sizes_tensor}
    outputs = session.run(None, inputs)
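    # If either model's I/O layout is ever in doubt, the session can be
    # inspected directly (a debugging sketch, not part of the request path):
    #
    #   for inp in session.get_inputs():
    #       print(inp.name, inp.shape, inp.type)
    #   for out in session.get_outputs():
    #       print(out.name, out.shape, out.type)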
    if model_name == "tiny":
        # The tiny model returns boxes only, with no per-box confidence.
        boxes = outputs[0]
        scores = [1.0] * len(boxes)
    else:
        # The small model returns batched (labels, boxes, scores); take batch 0.
        _, boxes, scores = outputs
        boxes, scores = boxes[0], scores[0]
    box_count = 0
    for box, score in zip(boxes, scores):
        if score < confidence_threshold:
            continue
        box_count += 1

        # Map model-space coordinates back to the original image:
        # undo the centering pad, then undo the resize ratio.
        x_min, y_min, x_max, y_max = box
        final_x_min = int((x_min - pad_w) / ratio)
        final_y_min = int((y_min - pad_h) / ratio)
        final_x_max = int((x_max - pad_w) / ratio)
        final_y_max = int((y_max - pad_h) / ratio)

        # Colors are RGB here (Gradio arrays): green for "small", blue for "tiny".
        color = (0, 255, 0) if model_name == "small" else (0, 0, 255)
        cv2.rectangle(output_image, (final_x_min, final_y_min), (final_x_max, final_y_max), color, 2)

    print(f"Processed with '{model_name}' model. Found {box_count} boxes with confidence >= {confidence_threshold}.")
    return output_image
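# A minimal headless test of detect_text (hypothetical file names; note the
# function expects an RGB array, which is what Gradio passes in):
#
#   img = cv2.cvtColor(cv2.imread("sample.png"), cv2.COLOR_BGR2RGB)
#   result = detect_text("small", img, 0.4)
#   cv2.imwrite("out.png", cv2.cvtColor(result, cv2.COLOR_RGB2BGR))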
# --- 4. GRADIO INTERFACE ---
with gr.Blocks() as demo:
    gr.Markdown("# meiki text detect v0")
    gr.Markdown(
        "upload an image and choose a model to detect horizontal and vertical text lines. "
        "the **small** model is more accurate, especially for images with many text lines like manga, while the **tiny** model is much faster."
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="numpy", label="upload image")
            model_name = gr.Radio(
                ["tiny", "small"], label="choose model", value="small"
            )
            confidence_threshold = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.4, step=0.1, label="confidence threshold"
            )
            detect_button = gr.Button("detect text", variant="primary")
        with gr.Column():
            output_image = gr.Image(type="numpy", label="result")

    detect_button.click(
        fn=detect_text,
        inputs=[model_name, input_image, confidence_threshold],
        outputs=output_image
    )
# --- 5. LAUNCH THE APP ---
demo.launch()
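# The default launch() settings are sufficient on Spaces; for local runs one
# might instead pass, e.g., demo.launch(server_name="0.0.0.0", server_port=7860)
# to expose the app on the LAN.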