# document_redaction_vlm/tools/word_segmenter.py
import os
from typing import Dict, List, Optional, Tuple
import cv2
import numpy as np
from tools.config import OUTPUT_FOLDER, SAVE_WORD_SEGMENTER_OUTPUT_IMAGES
# Adaptive thresholding parameters
BLOCK_SIZE_FACTOR = 1.5 # Multiplier for adaptive threshold block size
C_VALUE = 2 # Constant subtracted from mean in adaptive thresholding
# Word segmentation search parameters
INITIAL_KERNEL_WIDTH_FACTOR = 0.0 # Starting kernel width factor for Stage 2 search
INITIAL_VALLEY_THRESHOLD_FACTOR = (
0.0 # Starting valley threshold factor for Stage 1 search
)
MAIN_VALLEY_THRESHOLD_FACTOR = (
0.15 # Primary valley threshold factor for word separation
)
MIN_SPACE_FACTOR = 0.2 # Minimum space width relative to character width
MATCH_TOLERANCE = 0 # Tolerance for word count matching
# Noise removal parameters
MIN_AREA_THRESHOLD = 6 # Minimum component area to be considered valid text
DEFAULT_TRIM_PERCENTAGE = (
0.2 # Percentage to trim from top/bottom for vertical cropping
)
# Skew detection parameters
MIN_SKEW_THRESHOLD = 0.5 # Ignore angles smaller than this (likely noise)
MAX_SKEW_THRESHOLD = 15.0 # Angles larger than this are extreme and likely errors
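# Illustrative sketch (comments only, not executed): for a 600-px-wide line
# crop containing 30 non-space characters, the pipeline below estimates
# avg_char_width = 600 / 30 = 20 px, giving an adaptive-threshold block size
# of int(20 * BLOCK_SIZE_FACTOR) = 30, rounded up to the odd value 31, and a
# minimum inter-word gap of int(20 * MIN_SPACE_FACTOR) = 4 px.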
def _sanitize_filename(filename: str, max_length: int = 100) -> str:
"""
Sanitizes a string to be used as a valid filename.
Removes or replaces invalid characters for Windows/Linux file systems.
Args:
filename: The string to sanitize
max_length: Maximum length of the sanitized filename
Returns:
A sanitized string safe for use in file names
"""
if not filename:
return "unnamed"
# Replace spaces with underscores
sanitized = filename.replace(" ", "_")
# Remove or replace invalid characters for Windows/Linux
# Invalid: < > : " / \ | ? *
invalid_chars = '<>:"/\\|?*'
for char in invalid_chars:
sanitized = sanitized.replace(char, "_")
# Remove control characters
sanitized = "".join(
char for char in sanitized if ord(char) >= 32 or char in "\n\r\t"
)
# Remove leading/trailing dots and spaces (Windows doesn't allow these)
sanitized = sanitized.strip(". ")
# Replace multiple consecutive underscores with a single one
while "__" in sanitized:
sanitized = sanitized.replace("__", "_")
# Truncate if too long
if len(sanitized) > max_length:
sanitized = sanitized[:max_length]
# Ensure it's not empty after sanitization
if not sanitized:
sanitized = "unnamed"
return sanitized
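# Example (illustrative): _sanitize_filename('report: "Q1/Q2" <draft>')
# returns 'report_Q1_Q2_draft_' - spaces and invalid characters become
# underscores, runs of underscores collapse to one, and leading/trailing
# dots and spaces are stripped.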
class AdaptiveSegmenter:
"""
    Line-to-word segmentation pipeline. It features:
1. Adaptive Thresholding.
2. Targeted Noise Removal using Connected Component Analysis.
3. The robust two-stage adaptive search (Valley -> Kernel).
4. CCA for final pixel-perfect refinement.
"""
def __init__(self, output_folder: str = OUTPUT_FOLDER):
self.output_folder = output_folder
self.fallback_segmenter = HybridWordSegmenter()
def _correct_orientation(
self, gray_image: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
"""
Detects and corrects 90-degree orientation issues.
"""
h, w = gray_image.shape
center = (w // 2, h // 2)
block_size = 21
if h < block_size:
block_size = h if h % 2 != 0 else h - 1
if block_size > 3:
binary = cv2.adaptiveThreshold(
gray_image,
255,
cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
cv2.THRESH_BINARY_INV,
block_size,
4,
)
else:
_, binary = cv2.threshold(
gray_image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
)
opening_kernel = np.ones((2, 2), np.uint8)
binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, opening_kernel)
coords = np.column_stack(np.where(binary > 0))
if len(coords) < 50:
M_orient = cv2.getRotationMatrix2D(center, 0, 1.0)
return gray_image, M_orient
ymin, xmin = coords.min(axis=0)
ymax, xmax = coords.max(axis=0)
box_height = ymax - ymin
box_width = xmax - xmin
orientation_angle = 0.0
if box_height > box_width:
orientation_angle = 90.0
else:
M_orient = cv2.getRotationMatrix2D(center, 0, 1.0)
return gray_image, M_orient
M_orient = cv2.getRotationMatrix2D(center, orientation_angle, 1.0)
new_w, new_h = h, w
M_orient[0, 2] += (new_w - w) / 2
M_orient[1, 2] += (new_h - h) / 2
oriented_gray = cv2.warpAffine(
gray_image,
M_orient,
(new_w, new_h),
flags=cv2.INTER_CUBIC,
borderMode=cv2.BORDER_REPLICATE,
)
return oriented_gray, M_orient
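    # Worked sketch (illustrative): for a 40x300 (h x w) crop whose ink
    # bounding box is wider than tall, the image and an identity matrix are
    # returned unchanged. For a 300x40 crop of 90-degree-rotated text the box
    # is taller than wide, so getRotationMatrix2D(center, 90.0, 1.0) is built
    # and its translation terms shifted to re-centre the content on the
    # swapped 40x300 canvas.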
def _deskew_image(self, gray_image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""
        Detects and corrects skew by normalizing the angle returned by cv2.minAreaRect.
"""
h, w = gray_image.shape
block_size = 21
if h < block_size:
block_size = h if h % 2 != 0 else h - 1
if block_size > 3:
binary = cv2.adaptiveThreshold(
gray_image,
255,
cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
cv2.THRESH_BINARY_INV,
block_size,
4,
)
else:
_, binary = cv2.threshold(
gray_image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
)
opening_kernel = np.ones((2, 2), np.uint8)
binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, opening_kernel)
coords = np.column_stack(np.where(binary > 0))
if len(coords) < 50:
M = cv2.getRotationMatrix2D((w // 2, h // 2), 0, 1.0)
return gray_image, M
rect = cv2.minAreaRect(coords[:, ::-1])
rect_width, rect_height = rect[1]
angle = rect[2]
if rect_width < rect_height:
rect_width, rect_height = rect_height, rect_width
angle += 90
if angle > 45:
angle -= 90
elif angle < -45:
angle += 90
correction_angle = angle
if abs(correction_angle) < MIN_SKEW_THRESHOLD:
correction_angle = 0.0
elif abs(correction_angle) > MAX_SKEW_THRESHOLD:
correction_angle = 0.0
center = (w // 2, h // 2)
M = cv2.getRotationMatrix2D(center, correction_angle, 1.0)
deskewed_gray = cv2.warpAffine(
gray_image,
M,
(w, h),
flags=cv2.INTER_CUBIC,
borderMode=cv2.BORDER_REPLICATE,
)
return deskewed_gray, M
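    # Note (illustrative): recent OpenCV releases report minAreaRect angles
    # in roughly [0, 90), measured against the side returned as "width". The
    # swap-and-fold above maps most results into [-45, 45]; anything below
    # MIN_SKEW_THRESHOLD or above MAX_SKEW_THRESHOLD is treated as noise and
    # left uncorrected.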
def _get_boxes_from_profile(
self,
binary_image: np.ndarray,
stable_avg_char_width: float,
min_space_factor: float,
valley_threshold_factor: float,
) -> List:
"""
Extracts word bounding boxes from vertical projection profile.
"""
img_h, img_w = binary_image.shape
vertical_projection = np.sum(binary_image, axis=0)
peaks = vertical_projection[vertical_projection > 0]
if len(peaks) == 0:
return []
avg_peak_height = np.mean(peaks)
valley_threshold = int(avg_peak_height * valley_threshold_factor)
min_space_width = int(stable_avg_char_width * min_space_factor)
patched_projection = vertical_projection.copy()
in_gap = False
gap_start = 0
for x, col_sum in enumerate(patched_projection):
if col_sum <= valley_threshold and not in_gap:
in_gap = True
gap_start = x
elif col_sum > valley_threshold and in_gap:
in_gap = False
if (x - gap_start) < min_space_width:
patched_projection[gap_start:x] = int(avg_peak_height)
unlabeled_boxes = []
in_word = False
start_x = 0
for x, col_sum in enumerate(patched_projection):
if col_sum > valley_threshold and not in_word:
start_x = x
in_word = True
elif col_sum <= valley_threshold and in_word:
# [NOTE] Returns full height stripe
unlabeled_boxes.append((start_x, 0, x - start_x, img_h))
in_word = False
if in_word:
unlabeled_boxes.append((start_x, 0, img_w - start_x, img_h))
return unlabeled_boxes
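    # Toy example (illustrative): given a column-sum profile
    #   [9, 9, 0, 0, 0, 8, 8, 8, 0, 0]
    # with valley_threshold 0 and min_space_width 2, the 3-column valley at
    # x = 2..4 survives as a word break, producing two full-height stripes
    # (0, 0, 2, img_h) and (5, 0, 3, img_h); narrower valleys are patched
    # over and treated as intra-word gaps.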
def _enforce_logical_constraints(
self, output: Dict[str, List], image_width: int, image_height: int
) -> Dict[str, List]:
"""
Enforces geometric sanity checks with 2D awareness.
"""
if not output or not output["text"]:
return output
num_items = len(output["text"])
boxes = []
for i in range(num_items):
boxes.append(
{
"text": output["text"][i],
"left": int(output["left"][i]),
"top": int(output["top"][i]),
"width": int(output["width"][i]),
"height": int(output["height"][i]),
"conf": output["conf"][i],
}
)
valid_boxes = []
for box in boxes:
x0 = max(0, box["left"])
y0 = max(0, box["top"])
x1 = min(image_width, box["left"] + box["width"])
y1 = min(image_height, box["top"] + box["height"])
w = x1 - x0
h = y1 - y0
if w > 0 and h > 0:
box["left"] = x0
box["top"] = y0
box["width"] = w
box["height"] = h
valid_boxes.append(box)
boxes = valid_boxes
is_vertical = image_height > (image_width * 1.2)
if is_vertical:
boxes.sort(key=lambda b: (b["top"], b["left"]))
else:
boxes.sort(key=lambda b: (b["left"], -b["width"]))
final_pass_boxes = []
if boxes:
keep_indices = [True] * len(boxes)
for i in range(len(boxes)):
for j in range(len(boxes)):
if i == j:
continue
b1 = boxes[i]
b2 = boxes[j]
x_nested = (b1["left"] >= b2["left"] - 2) and (
b1["left"] + b1["width"] <= b2["left"] + b2["width"] + 2
)
y_nested = (b1["top"] >= b2["top"] - 2) and (
b1["top"] + b1["height"] <= b2["top"] + b2["height"] + 2
)
                    if x_nested and y_nested and b1["text"] == b2["text"]:
                        area1 = b1["width"] * b1["height"]
                        area2 = b2["width"] * b2["height"]
                        # Drop the smaller duplicate; break exact-area ties by
                        # index so identical boxes are not both removed.
                        if area1 < area2 or (area1 == area2 and i > j):
                            keep_indices[i] = False
for i, keep in enumerate(keep_indices):
if keep:
final_pass_boxes.append(boxes[i])
boxes = final_pass_boxes
if is_vertical:
boxes.sort(key=lambda b: (b["top"], b["left"]))
else:
boxes.sort(key=lambda b: (b["left"], -b["width"]))
for i in range(len(boxes)):
for j in range(i + 1, len(boxes)):
b1 = boxes[i]
b2 = boxes[j]
x_overlap = min(
b1["left"] + b1["width"], b2["left"] + b2["width"]
) - max(b1["left"], b2["left"])
y_overlap = min(
b1["top"] + b1["height"], b2["top"] + b2["height"]
) - max(b1["top"], b2["top"])
if x_overlap > 0 and y_overlap > 0:
if is_vertical:
if b1["top"] < b2["top"]:
new_h = max(1, b2["top"] - b1["top"])
b1["height"] = new_h
else:
if b1["left"] < b2["left"]:
b1_right = b1["left"] + b1["width"]
b2_right = b2["left"] + b2["width"]
left_slice_width = max(0, b2["left"] - b1["left"])
right_slice_width = max(0, b1_right - b2_right)
if (
b1_right > b2_right
and right_slice_width > left_slice_width
):
b1["left"] = b2_right
b1["width"] = right_slice_width
else:
b1["width"] = max(1, left_slice_width)
cleaned_output = {
k: [] for k in ["text", "left", "top", "width", "height", "conf"]
}
if is_vertical:
boxes.sort(key=lambda b: (b["top"], b["left"]))
else:
boxes.sort(key=lambda b: (b["left"], -b["width"]))
for box in boxes:
for key in cleaned_output.keys():
cleaned_output[key].append(box[key])
return cleaned_output
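    # Illustrative sketch: on a horizontal line, two different-word boxes at
    # left=10/width=60 and left=50/width=40 overlap by 20 px, so the overlap
    # pass trims the first box's width to 50 - 10 = 40 and the boxes tile
    # the line without double-covering pixels.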
def _is_geometry_valid(
self,
boxes: List[Tuple[int, int, int, int]],
words: List[str],
expected_height: float = 0,
) -> bool:
"""
Validates if the detected boxes are physically plausible.
[FIX] Improved robustness for punctuation and mixed-case text.
"""
if len(boxes) != len(words):
return False
baseline = expected_height
# Use median only if provided expected height is unreliable
if baseline < 5:
heights = [b[3] for b in boxes]
if heights:
baseline = np.median(heights)
if baseline < 5:
return True
for i, box in enumerate(boxes):
word = words[i]
# [FIX] Check for punctuation/symbols. They are allowed to be small.
# If word is just punctuation, skip geometry checks
is_punctuation = not any(c.isalnum() for c in word)
if is_punctuation:
continue
# Standard checks for alphanumeric words
num_chars = len(word)
if num_chars < 1:
continue
width = box[2]
height = box[3]
# [FIX] Only reject height if it's REALLY small compared to baseline
# A period might be small, but we skipped that check above.
# This check ensures a real word like "The" isn't 2 pixels tall.
if height < (baseline * 0.20):
return False
avg_char_width = width / num_chars
min_expected = baseline * 0.20
# Only reject if it fails BOTH absolute (4px) and relative checks
if avg_char_width < min_expected and avg_char_width < 4:
# Exception: If the word is 1 char long (e.g. "I", "l", "1"), allow it to be skinny.
if num_chars == 1 and avg_char_width >= 2:
continue
return False
return True
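    # Illustrative sketch: with a 40-px baseline, a 50x30 box for "The"
    # passes (avg char width 50 / 3 > 4 px and height 30 > 40 * 0.20), a
    # 2-px-tall box for the same word is rejected, and a lone "." skips the
    # geometry checks entirely.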
def segment(
self,
line_data: Dict[str, List],
line_image: np.ndarray,
min_space_factor=MIN_SPACE_FACTOR,
match_tolerance=MATCH_TOLERANCE,
        image_name: Optional[str] = None,
) -> Tuple[Dict[str, List], bool]:
if (
line_image is None
or not isinstance(line_image, np.ndarray)
or line_image.size == 0
):
return ({}, False)
# Allow grayscale (2 dims) or color (3 dims)
if len(line_image.shape) < 2:
return ({}, False)
if not line_data or not line_data.get("text") or len(line_data["text"]) == 0:
return ({}, False)
line_text = line_data["text"][0]
words = line_text.split()
# Early return if 1 or fewer words
if len(words) <= 1:
img_h, img_w = line_image.shape[:2]
one_word_result = self.fallback_segmenter.convert_line_to_word_level(
line_data, img_w, img_h
)
return (one_word_result, False)
line_number = line_data["line"][0]
safe_image_name = _sanitize_filename(image_name or "image", max_length=50)
safe_line_number = _sanitize_filename(str(line_number), max_length=10)
safe_shortened_line_text = _sanitize_filename(line_text, max_length=10)
if SAVE_WORD_SEGMENTER_OUTPUT_IMAGES:
os.makedirs(self.output_folder, exist_ok=True)
output_path = f"{self.output_folder}/word_segmentation/{safe_image_name}_{safe_line_number}_{safe_shortened_line_text}_original.png"
os.makedirs(f"{self.output_folder}/word_segmentation", exist_ok=True)
cv2.imwrite(output_path, line_image)
if len(line_image.shape) == 3:
gray = cv2.cvtColor(line_image, cv2.COLOR_BGR2GRAY)
else:
gray = line_image.copy()
# ========================================================================
# IMAGE PREPROCESSING (Deskew / Rotate)
# ========================================================================
oriented_gray, M_orient = self._correct_orientation(gray)
deskewed_gray, M_skew = self._deskew_image(oriented_gray)
# Combine matrices: M_total = M_skew * M_orient
M_orient_3x3 = np.vstack([M_orient, [0, 0, 1]])
M_skew_3x3 = np.vstack([M_skew, [0, 0, 1]])
M_total_3x3 = M_skew_3x3 @ M_orient_3x3
M = M_total_3x3[0:2, :] # Extract 2x3 affine matrix
# Apply transformation to the original color image
h, w = deskewed_gray.shape
deskewed_line_image = cv2.warpAffine(
line_image,
M,
(w, h),
flags=cv2.INTER_CUBIC,
borderMode=cv2.BORDER_REPLICATE,
)
# [FIX] Create Local Line Data that matches the deskewed/rotated image dimensions.
# This prevents the fallback segmenter from using vertical dimensions on a horizontal image.
local_line_data = {
"text": line_data["text"],
"conf": line_data["conf"],
"left": [0], # Local coordinate system starts at 0
"top": [0],
"width": [w], # Use the ROTATED width
"height": [h], # Use the ROTATED height
"line": line_data.get("line", [0]),
}
if SAVE_WORD_SEGMENTER_OUTPUT_IMAGES:
os.makedirs(self.output_folder, exist_ok=True)
output_path = f"{self.output_folder}/word_segmentation/{safe_image_name}_{safe_line_number}_{safe_shortened_line_text}_deskewed.png"
cv2.imwrite(output_path, deskewed_line_image)
# ========================================================================
# MAIN SEGMENTATION PIPELINE
# ========================================================================
approx_char_count = len(line_data["text"][0].replace(" ", ""))
if approx_char_count == 0:
return {}, False
img_h, img_w = deskewed_gray.shape
estimated_char_height = img_h * 0.6
avg_char_width_approx = img_w / approx_char_count
block_size = int(avg_char_width_approx * BLOCK_SIZE_FACTOR)
if block_size % 2 == 0:
block_size += 1
if block_size < 3:
block_size = 3
# --- Binarization ---
binary_adaptive = cv2.adaptiveThreshold(
deskewed_gray,
255,
cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
cv2.THRESH_BINARY_INV,
block_size,
C_VALUE,
)
otsu_thresh_val, _ = cv2.threshold(
deskewed_gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
)
strict_thresh_val = otsu_thresh_val * 0.75
_, binary_strict = cv2.threshold(
deskewed_gray, strict_thresh_val, 255, cv2.THRESH_BINARY_INV
)
binary = cv2.bitwise_and(binary_adaptive, binary_strict)
if SAVE_WORD_SEGMENTER_OUTPUT_IMAGES:
output_path = f"{self.output_folder}/word_segmentation/{safe_image_name}_{safe_line_number}_{safe_shortened_line_text}_binary.png"
cv2.imwrite(output_path, binary)
# --- Morphological Closing ---
morph_width = max(3, int(avg_char_width_approx * 0.40))
morph_height = max(2, int(avg_char_width_approx * 0.1))
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (morph_width, morph_height))
closed_binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, iterations=1)
# --- Noise Removal ---
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
closed_binary, 8, cv2.CV_32S
)
clean_binary = np.zeros_like(binary)
force_fallback = False
significant_labels = 0
if num_labels > 1:
# Only count components with area > 3 pixels
significant_labels = np.sum(stats[1:, cv2.CC_STAT_AREA] > 3)
if approx_char_count > 0 and significant_labels > (approx_char_count * 12):
force_fallback = True
if num_labels > 1:
areas = stats[1:, cv2.CC_STAT_AREA]
if len(areas) == 0:
clean_binary = binary
areas = np.array([0])
else:
p1 = np.percentile(areas, 1)
img_h, img_w = binary.shape
                # Local estimate for noise filtering only; do not overwrite
                # the line-level estimated_char_height used by later checks.
                noise_char_height = img_h * 0.7
                estimated_min_letter_area = max(
                    2, int(noise_char_height * 0.2 * noise_char_height * 0.15)
                )
area_threshold = max(
MIN_AREA_THRESHOLD, min(p1, estimated_min_letter_area)
)
                # Look for a large jump in the sorted component areas that
                # separates speck noise from real glyph components.
sorted_areas = np.sort(areas)
area_diffs = np.diff(sorted_areas)
if len(sorted_areas) > 10 and len(area_diffs) > 0:
jump_threshold = np.percentile(area_diffs, 95)
significant_jump_thresh = max(10, jump_threshold * 3)
jump_indices = np.where(area_diffs > significant_jump_thresh)[0]
if len(jump_indices) > 0:
gap_idx = jump_indices[0]
area_before_gap = sorted_areas[gap_idx]
final_threshold = max(area_before_gap + 1, area_threshold)
final_threshold = min(final_threshold, 15)
area_threshold = final_threshold
for i in range(1, num_labels):
if stats[i, cv2.CC_STAT_AREA] >= area_threshold:
clean_binary[labels == i] = 255
else:
clean_binary = binary
# --- Vertical Cropping ---
horizontal_projection = np.sum(clean_binary, axis=1)
y_start = 0
non_zero_rows = np.where(horizontal_projection > 0)[0]
if len(non_zero_rows) > 0:
p_top = int(np.percentile(non_zero_rows, 5))
p_bottom = int(np.percentile(non_zero_rows, 95))
core_height = p_bottom - p_top
trim_pixels = int(core_height * 0.1)
y_start = max(0, p_top + trim_pixels)
y_end = min(clean_binary.shape[0], p_bottom - trim_pixels)
if y_end - y_start < 5:
y_start = p_top
y_end = p_bottom
analysis_image = clean_binary[y_start:y_end, :]
else:
analysis_image = clean_binary
if SAVE_WORD_SEGMENTER_OUTPUT_IMAGES:
output_path = f"{self.output_folder}/word_segmentation/{safe_image_name}_{safe_line_number}_{safe_shortened_line_text}_clean_binary.png"
cv2.imwrite(output_path, analysis_image)
# --- Adaptive Search ---
best_boxes = None
successful_binary_image = None
if not force_fallback:
words = line_data["text"][0].split()
target = len(words)
backup_boxes_s1 = None
# STAGE 1
for v_factor in np.arange(INITIAL_VALLEY_THRESHOLD_FACTOR, 0.60, 0.02):
curr_boxes = self._get_boxes_from_profile(
analysis_image, avg_char_width_approx, min_space_factor, v_factor
)
diff = abs(target - len(curr_boxes))
is_geom_valid = self._is_geometry_valid(
curr_boxes, words, estimated_char_height
)
if diff == 0:
if is_geom_valid:
best_boxes = curr_boxes
successful_binary_image = analysis_image
break
else:
if backup_boxes_s1 is None:
backup_boxes_s1 = curr_boxes
if diff == 1 and backup_boxes_s1 is None and is_geom_valid:
backup_boxes_s1 = curr_boxes
# STAGE 2 (if needed)
if best_boxes is None:
backup_boxes_s2 = None
for k_factor in np.arange(INITIAL_KERNEL_WIDTH_FACTOR, 0.5, 0.02):
k_w = max(1, int(avg_char_width_approx * k_factor))
s2_bin = cv2.morphologyEx(
clean_binary, cv2.MORPH_CLOSE, np.ones((1, k_w), np.uint8)
)
s2_img = (
s2_bin[y_start:y_end, :] if len(non_zero_rows) > 0 else s2_bin
)
if s2_img is None or s2_img.size == 0:
continue
curr_boxes = self._get_boxes_from_profile(
s2_img,
avg_char_width_approx,
min_space_factor,
MAIN_VALLEY_THRESHOLD_FACTOR,
)
diff = abs(target - len(curr_boxes))
is_geom_valid = self._is_geometry_valid(
curr_boxes, words, estimated_char_height
)
if diff == 0 and is_geom_valid:
best_boxes = curr_boxes
successful_binary_image = s2_bin
break
if diff == 1 and backup_boxes_s2 is None and is_geom_valid:
backup_boxes_s2 = curr_boxes
if best_boxes is None:
if backup_boxes_s1 is not None:
best_boxes = backup_boxes_s1
successful_binary_image = analysis_image
elif backup_boxes_s2 is not None:
best_boxes = backup_boxes_s2
successful_binary_image = clean_binary
final_output = None
used_fallback = False
if best_boxes is None:
# --- FALLBACK WITH ROTATED DATA ---
used_fallback = True
# [FIX] Use local_line_data (rotated dims) instead of line_data (original dims)
final_output = self.fallback_segmenter.refine_words_bidirectional(
local_line_data, deskewed_line_image
)
else:
# --- CCA Refinement ---
unlabeled_boxes = best_boxes
if successful_binary_image is analysis_image:
cca_source_image = clean_binary
else:
cca_source_image = successful_binary_image
num_labels, _, stats, _ = cv2.connectedComponentsWithStats(
cca_source_image, 8, cv2.CV_32S
)
cca_img_h, cca_img_w = cca_source_image.shape[:2]
component_assignments = {}
num_proc = min(len(words), len(unlabeled_boxes))
min_valid_component_area = estimated_char_height * 2
for j in range(1, num_labels):
comp_x = stats[j, cv2.CC_STAT_LEFT]
comp_w = stats[j, cv2.CC_STAT_WIDTH]
comp_area = stats[j, cv2.CC_STAT_AREA]
comp_r = comp_x + comp_w
comp_center_x = comp_x + comp_w / 2
comp_y = stats[j, cv2.CC_STAT_TOP]
comp_h = stats[j, cv2.CC_STAT_HEIGHT]
comp_center_y = comp_y + comp_h / 2
if comp_center_y < cca_img_h * 0.1 or comp_center_y > cca_img_h * 0.9:
continue
if comp_area < min_valid_component_area:
continue
best_box_idx = None
max_overlap = 0
best_center_distance = float("inf")
component_center_in_box = False
                # Assign this component to the word box with the best
                # horizontal overlap.
                for i in range(num_proc):
box_x, box_y, box_w, box_h = unlabeled_boxes[i]
box_r = box_x + box_w
box_center_x = box_x + box_w / 2
if comp_w > box_w * 1.5:
continue
if comp_x < box_r and box_x < comp_r:
overlap_start = max(comp_x, box_x)
overlap_end = min(comp_r, box_r)
overlap = overlap_end - overlap_start
if overlap > 0:
center_in_box = box_x <= comp_center_x < box_r
center_distance = abs(comp_center_x - box_center_x)
if center_in_box:
if not component_center_in_box or overlap > max_overlap:
component_center_in_box = True
best_center_distance = center_distance
max_overlap = overlap
best_box_idx = i
elif not component_center_in_box:
if center_distance < best_center_distance or (
center_distance == best_center_distance
and overlap > max_overlap
):
best_center_distance = center_distance
max_overlap = overlap
best_box_idx = i
if best_box_idx is not None:
component_assignments[j] = best_box_idx
refined_boxes_list = []
for i in range(num_proc):
word_label = words[i]
components_in_box = [
stats[j] for j, b in component_assignments.items() if b == i
]
use_original_box = False
if not components_in_box:
use_original_box = True
else:
min_x = min(c[cv2.CC_STAT_LEFT] for c in components_in_box)
min_y = min(c[cv2.CC_STAT_TOP] for c in components_in_box)
max_r = max(
c[cv2.CC_STAT_LEFT] + c[cv2.CC_STAT_WIDTH]
for c in components_in_box
)
max_b = max(
c[cv2.CC_STAT_TOP] + c[cv2.CC_STAT_HEIGHT]
for c in components_in_box
)
cca_h = max(1, max_b - min_y)
if cca_h < (estimated_char_height * 0.35):
use_original_box = True
if use_original_box:
box_x, box_y, box_w, box_h = unlabeled_boxes[i]
adjusted_box_y = y_start + box_y
refined_boxes_list.append(
{
"text": word_label,
"left": box_x,
"top": adjusted_box_y,
"width": box_w,
"height": box_h,
"conf": line_data["conf"][0],
}
)
else:
refined_boxes_list.append(
{
"text": word_label,
"left": min_x,
"top": min_y,
"width": max(1, max_r - min_x),
"height": cca_h,
"conf": line_data["conf"][0],
}
)
# Check validity
cca_check_list = [
(b["left"], b["top"], b["width"], b["height"])
for b in refined_boxes_list
]
if not self._is_geometry_valid(
cca_check_list, words, estimated_char_height
):
if abs(len(refined_boxes_list) - len(words)) > 1:
best_boxes = None # Trigger fallback
else:
final_output = {
k: []
for k in ["text", "left", "top", "width", "height", "conf"]
}
for box in refined_boxes_list:
for key in final_output.keys():
final_output[key].append(box[key])
else:
final_output = {
k: [] for k in ["text", "left", "top", "width", "height", "conf"]
}
for box in refined_boxes_list:
for key in final_output.keys():
final_output[key].append(box[key])
# --- REPEAT FALLBACK IF VALIDATION FAILED ---
if best_boxes is None and not used_fallback:
used_fallback = True
# [FIX] Use local_line_data here too
final_output = self.fallback_segmenter.refine_words_bidirectional(
local_line_data, deskewed_line_image
)
# ========================================================================
# COORDINATE TRANSFORMATION (Map back to Original)
# ========================================================================
M_inv = cv2.invertAffineTransform(M)
remapped_boxes_list = []
for i in range(len(final_output["text"])):
left, top = final_output["left"][i], final_output["top"][i]
width, height = final_output["width"][i], final_output["height"][i]
# Map the 4 corners
corners = np.array(
[
[left, top],
[left + width, top],
[left + width, top + height],
[left, top + height],
],
dtype="float32",
)
corners_expanded = np.expand_dims(corners, axis=1)
original_corners = cv2.transform(corners_expanded, M_inv)
squeezed_corners = original_corners.squeeze(axis=1)
# Get axis aligned bounding box in original space
min_x = int(np.min(squeezed_corners[:, 0]))
max_x = int(np.max(squeezed_corners[:, 0]))
min_y = int(np.min(squeezed_corners[:, 1]))
max_y = int(np.max(squeezed_corners[:, 1]))
remapped_boxes_list.append(
{
"text": final_output["text"][i],
"left": min_x,
"top": min_y,
"width": max_x - min_x,
"height": max_y - min_y,
"conf": final_output["conf"][i],
}
)
remapped_output = {k: [] for k in final_output.keys()}
for box in remapped_boxes_list:
for key in remapped_output.keys():
remapped_output[key].append(box[key])
img_h, img_w = line_image.shape[:2]
remapped_output = self._enforce_logical_constraints(
remapped_output, img_w, img_h
)
# ========================================================================
# FINAL SAFETY NET
# ========================================================================
words = line_data["text"][0].split()
target_count = len(words)
current_count = len(remapped_output["text"])
has_collapsed_boxes = any(w < 3 for w in remapped_output["width"])
if current_count > 0:
total_text_len = sum(len(t) for t in remapped_output["text"])
total_box_width = sum(remapped_output["width"])
avg_width_pixels = total_box_width / max(1, total_text_len)
else:
avg_width_pixels = 0
is_suspiciously_thin = avg_width_pixels < 4
if current_count != target_count or is_suspiciously_thin or has_collapsed_boxes:
used_fallback = True
# [FIX] Do NOT use original line_image/line_data here.
# Use the local_line_data + deskewed_line_image pipeline,
# then transform back using M_inv (same as above).
# 1. Run fallback on rotated data
temp_local_output = self.fallback_segmenter.refine_words_bidirectional(
local_line_data, deskewed_line_image
)
# 2. If bidirectional failed to split correctly, use purely mathematical split on rotated data
if len(temp_local_output["text"]) != target_count:
h, w = deskewed_line_image.shape[:2]
temp_local_output = self.fallback_segmenter.convert_line_to_word_level(
local_line_data, w, h
)
# 3. Transform the result back to original coordinates (M_inv)
# (Repeating the transformation logic for the safety net result)
remapped_boxes_list = []
for i in range(len(temp_local_output["text"])):
left, top = temp_local_output["left"][i], temp_local_output["top"][i]
width, height = (
temp_local_output["width"][i],
temp_local_output["height"][i],
)
corners = np.array(
[
[left, top],
[left + width, top],
[left + width, top + height],
[left, top + height],
],
dtype="float32",
)
corners_expanded = np.expand_dims(corners, axis=1)
original_corners = cv2.transform(corners_expanded, M_inv)
squeezed_corners = original_corners.squeeze(axis=1)
min_x = int(np.min(squeezed_corners[:, 0]))
max_x = int(np.max(squeezed_corners[:, 0]))
min_y = int(np.min(squeezed_corners[:, 1]))
max_y = int(np.max(squeezed_corners[:, 1]))
remapped_boxes_list.append(
{
"text": temp_local_output["text"][i],
"left": min_x,
"top": min_y,
"width": max_x - min_x,
"height": max_y - min_y,
"conf": temp_local_output["conf"][i],
}
)
remapped_output = {k: [] for k in temp_local_output.keys()}
for box in remapped_boxes_list:
for key in remapped_output.keys():
remapped_output[key].append(box[key])
if SAVE_WORD_SEGMENTER_OUTPUT_IMAGES:
output_path = f"{self.output_folder}/word_segmentation/{safe_image_name}_{safe_shortened_line_text}_final_boxes.png"
os.makedirs(f"{self.output_folder}/word_segmentation", exist_ok=True)
output_image_vis = line_image.copy()
for i in range(len(remapped_output["text"])):
x, y, w, h = (
int(remapped_output["left"][i]),
int(remapped_output["top"][i]),
int(remapped_output["width"][i]),
int(remapped_output["height"][i]),
)
cv2.rectangle(output_image_vis, (x, y), (x + w, y + h), (0, 255, 0), 2)
cv2.imwrite(output_path, output_image_vis)
return remapped_output, used_fallback
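# Example usage (illustrative; `line_crop` is a placeholder for a BGR crop of
# a single text line, and the dict mirrors the line-level OCR layout used
# throughout this module):
#
#   segmenter = AdaptiveSegmenter()
#   line_data = {
#       "text": ["hello world"], "conf": [0.98],
#       "left": [0], "top": [0], "width": [320], "height": [48], "line": [1],
#   }
#   word_data, used_fallback = segmenter.segment(
#       line_data, line_crop, image_name="page_1"
#   )
#   # word_data["text"] == ["hello", "world"], with one box per word in the
#   # coordinate system of line_crop.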
class HybridWordSegmenter:
"""
Implements a two-step approach for word segmentation:
1. Proportional estimation based on text.
2. Image-based refinement with a "Bounded Scan" to prevent
over-correction.
"""
def convert_line_to_word_level(
self, line_data: Dict[str, List], image_width: int, image_height: int
) -> Dict[str, List]:
"""
Step 1: Converts line-level OCR results to word-level by using a
robust proportional estimation method.
Guarantees output box count equals input word count.
"""
output = {
"text": list(),
"left": list(),
"top": list(),
"width": list(),
"height": list(),
"conf": list(),
}
if not line_data or not line_data.get("text"):
return output
i = 0 # Assuming a single line
line_text = line_data["text"][i]
line_left = float(line_data["left"][i])
line_top = float(line_data["top"][i])
line_width = float(line_data["width"][i])
line_height = float(line_data["height"][i])
line_conf = line_data["conf"][i]
if not line_text.strip():
return output
words = line_text.split()
if not words:
return output
num_chars = len("".join(words))
num_spaces = len(words) - 1
if num_chars == 0:
return output
if (num_chars * 2 + num_spaces) > 0:
char_space_ratio = 2.0
estimated_space_width = line_width / (
num_chars * char_space_ratio + num_spaces
)
avg_char_width = estimated_space_width * char_space_ratio
else:
avg_char_width = line_width / (num_chars if num_chars > 0 else 1)
estimated_space_width = avg_char_width
# [SAFETY CHECK] Ensure we never estimate a character width of ~0
avg_char_width = max(3.0, avg_char_width)
min_word_width = max(5.0, avg_char_width * 0.5)
current_left = line_left
for word in words:
raw_word_width = len(word) * avg_char_width
# Force the box to have a legible size
word_width = max(min_word_width, raw_word_width)
clamped_left = max(0, min(current_left, image_width))
# We do NOT clamp the width against image_width here because that
# causes the "0 width" bug if current_left is at the edge.
# It is better to have a box go off-screen than be 0-width.
output["text"].append(word)
output["left"].append(clamped_left)
output["top"].append(line_top)
output["width"].append(word_width)
output["height"].append(line_height)
output["conf"].append(line_conf)
current_left += word_width + estimated_space_width
return output
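    # Worked sketch (illustrative): for line_width 300 and text "to be",
    # num_chars = 4 and num_spaces = 1, so estimated_space_width =
    # 300 / (4 * 2 + 1) = 33.3 px and avg_char_width = 66.7 px; "to" gets a
    # 133-px box at line_left and "be" starts one space width after it.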
def _run_single_pass(
self,
initial_boxes: List[Dict],
vertical_projection: np.ndarray,
max_scan_distance: int,
img_w: int,
direction: str = "ltr",
) -> List[Dict]:
"""
Helper function to run one pass of refinement.
IMPROVED: Uses local minima detection for cursive script where
perfect zero-gaps (white space) might not exist.
"""
refined_boxes = [box.copy() for box in initial_boxes]
if direction == "ltr":
last_corrected_right_edge = 0
indices = range(len(refined_boxes))
else: # rtl
next_corrected_left_edge = img_w
indices = range(len(refined_boxes) - 1, -1, -1)
for i in indices:
box = refined_boxes[i]
left = int(box["left"])
right = int(box["left"] + box["width"])
left = max(0, min(left, img_w - 1))
right = max(0, min(right, img_w - 1))
new_left, new_right = left, right
# --- Boundary search with improved gap detection ---
# Priority 1: True gap (zero projection)
# Priority 2: Valley with lowest ink density (thinnest connection)
if direction == "ltr" or direction == "both": # Scan right logic
if right < img_w:
scan_limit = min(img_w, right + max_scan_distance)
search_range = range(right, scan_limit)
best_x = right
min_density = float("inf")
found_zero = False
# Look for the best cut in the window
for x in search_range:
density = vertical_projection[x]
if density == 0:
new_right = x
found_zero = True
break
if density < min_density:
min_density = density
best_x = x
if not found_zero:
# No clear gap found, cut at thinnest point (minimum density)
new_right = best_x
if direction == "rtl" or direction == "both": # Scan left logic
if left > 0:
scan_limit = max(0, left - max_scan_distance)
search_range = range(left, scan_limit, -1)
best_x = left
min_density = float("inf")
found_zero = False
for x in search_range:
density = vertical_projection[x]
if density == 0:
new_left = x
found_zero = True
break
if density < min_density:
min_density = density
best_x = x
if not found_zero:
new_left = best_x
# --- Directional de-overlapping (strict stitching) ---
if direction == "ltr":
if new_left < last_corrected_right_edge:
new_left = last_corrected_right_edge
# Ensure valid width
if new_right <= new_left:
new_right = new_left + 1
last_corrected_right_edge = new_right
else: # rtl
if new_right > next_corrected_left_edge:
new_right = next_corrected_left_edge
# Ensure valid width
if new_left >= new_right:
new_left = new_right - 1
next_corrected_left_edge = new_left
box["left"] = new_left
box["width"] = max(1, new_right - new_left)
return refined_boxes
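    # Illustrative sketch: with projection [5, 5, 0, 4, 4] and an LTR pass, a
    # box whose right edge sits at x = 1 scans rightward, finds the zero-ink
    # column at x = 2 and cuts there; if no zero column existed within
    # max_scan_distance, the cut would fall on the lowest-density column
    # instead (the thinnest connection, which matters for cursive script).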
def refine_words_bidirectional(
self,
line_data: Dict[str, List],
line_image: np.ndarray,
) -> Dict[str, List]:
"""
Refines boxes using a more robust bidirectional scan and averaging.
Includes ADAPTIVE NOISE REMOVAL to filter specks based on font size.
"""
        if line_image is None or not line_data or not line_data.get("text"):
            return line_data
        # Early return if 1 or fewer words
        words = line_data["text"][0].split()
        if len(words) <= 1:
            img_h, img_w = line_image.shape[:2]
            return self.convert_line_to_word_level(line_data, img_w, img_h)
        # --- PRE-PROCESSING: Stricter Binarization ---
        # Accept both colour and grayscale crops
        if len(line_image.shape) == 3:
            gray = cv2.cvtColor(line_image, cv2.COLOR_BGR2GRAY)
        else:
            gray = line_image.copy()
# 1. Calculate standard Otsu threshold first
otsu_thresh_val, _ = cv2.threshold(
gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
)
# 2. Apply "Strictness Factor" to remove dark noise
# 0.75 means "Only keep pixels that are in the darkest 75% of what Otsu thought was foreground"
# This effectively filters out light-gray noise shadows.
strict_thresh_val = otsu_thresh_val * 0.75
_, binary = cv2.threshold(gray, strict_thresh_val, 255, cv2.THRESH_BINARY_INV)
img_h, img_w = binary.shape
# [NEW STEP 1] Morphological Opening
# Physically erodes small protrusions and dust (2x2 pixels or smaller)
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
binary_clean = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)
# [NEW STEP 2] Adaptive Component Filtering
# Instead of hardcoded pixels, we filter relative to the line's text size.
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
binary_clean, 8, cv2.CV_32S
)
# Get heights of all components (excluding background)
heights = stats[1:, cv2.CC_STAT_HEIGHT]
if len(heights) > 0:
# Calculate Median Height of "significant" parts (ignore tiny noise for the median calculation)
# We assume valid text is at least 20% of the image height
significant_heights = heights[heights > img_h * 0.2]
if len(significant_heights) > 0:
median_h = np.median(significant_heights)
else:
median_h = np.median(heights)
# Define Thresholds based on Text Size
# 1. Main Threshold: Keep parts taller than 30% of median letter height
min_height_thresh = median_h * 0.30
clean_binary = np.zeros_like(binary)
for i in range(1, num_labels):
h = stats[i, cv2.CC_STAT_HEIGHT]
w = stats[i, cv2.CC_STAT_WIDTH]
area = stats[i, cv2.CC_STAT_AREA]
# Logic: Keep the component IF:
# A. It is tall enough to be a letter part (h > threshold)
# B. OR it is a "Dot" (Period / i-dot):
# - Height is small (< threshold)
# - Width is ALSO small (roughly square, prevents flat dash/scratch noise)
# - Area is reasonable (> 2px)
is_tall_enough = h > min_height_thresh
is_dot = (
(h <= min_height_thresh) and (w <= min_height_thresh) and (area > 2)
)
if is_tall_enough or is_dot:
clean_binary[labels == i] = 255
# Use the adaptively cleaned image for projection
vertical_projection = np.sum(clean_binary, axis=0)
else:
# Fallback if no components found (unlikely)
vertical_projection = np.sum(binary, axis=0)
        # --- Character blob detection (used to bound the scan distance) ---
char_blobs = []
in_blob = False
blob_start = 0
for x, col_sum in enumerate(vertical_projection):
if col_sum > 0 and not in_blob:
blob_start = x
in_blob = True
elif col_sum == 0 and in_blob:
char_blobs.append((blob_start, x))
in_blob = False
if in_blob:
char_blobs.append((blob_start, img_w))
if not char_blobs:
return self.convert_line_to_word_level(line_data, img_w, img_h)
# [PREVIOUS FIX] Bounded Scan Distance
total_chars = len("".join(words))
if total_chars > 0:
geom_avg_char_width = img_w / total_chars
else:
geom_avg_char_width = 10
blob_avg_char_width = np.mean([end - start for start, end in char_blobs])
safe_avg_char_width = min(blob_avg_char_width, geom_avg_char_width * 1.5)
max_scan_distance = int(safe_avg_char_width * 2.0)
# [PREVIOUS FIX] Safety Floor
min_safe_box_width = max(4, int(safe_avg_char_width * 0.5))
estimated_data = self.convert_line_to_word_level(line_data, img_w, img_h)
if not estimated_data["text"]:
return estimated_data
initial_boxes = []
for i in range(len(estimated_data["text"])):
initial_boxes.append(
{
"text": estimated_data["text"][i],
"left": estimated_data["left"][i],
"top": estimated_data["top"][i],
"width": estimated_data["width"][i],
"height": estimated_data["height"][i],
"conf": estimated_data["conf"][i],
}
)
# --- STEP 1 & 2: Perform bidirectional refinement passes ---
ltr_boxes = self._run_single_pass(
initial_boxes, vertical_projection, max_scan_distance, img_w, "ltr"
)
rtl_boxes = self._run_single_pass(
initial_boxes, vertical_projection, max_scan_distance, img_w, "rtl"
)
# --- STEP 3: Combine results using best edge from each pass ---
combined_boxes = [box.copy() for box in initial_boxes]
for i in range(len(combined_boxes)):
final_left = ltr_boxes[i]["left"]
rtl_right = rtl_boxes[i]["left"] + rtl_boxes[i]["width"]
combined_boxes[i]["left"] = final_left
combined_boxes[i]["width"] = max(min_safe_box_width, rtl_right - final_left)
# --- STEP 4: Contiguous stitching to eliminate gaps ---
for i in range(len(combined_boxes) - 1):
if combined_boxes[i + 1]["left"] <= combined_boxes[i]["left"]:
combined_boxes[i + 1]["left"] = (
combined_boxes[i]["left"] + min_safe_box_width
)
for i in range(len(combined_boxes) - 1):
curr = combined_boxes[i]
nxt = combined_boxes[i + 1]
gap_width = nxt["left"] - curr["left"]
curr["width"] = max(min_safe_box_width, gap_width)
# Convert back to output dict
final_output = {k: [] for k in estimated_data.keys()}
for box in combined_boxes:
if box["width"] >= min_safe_box_width:
for key in final_output.keys():
final_output[key].append(box[key])
return final_output
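if __name__ == "__main__":
    # Minimal smoke test (illustrative): render a synthetic two-word line and
    # segment it. The text, box values and image name are placeholders, not
    # real OCR output.
    canvas = np.full((48, 320, 3), 255, dtype=np.uint8)
    cv2.putText(canvas, "hello", (10, 34), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 0), 2)
    cv2.putText(canvas, "world", (170, 34), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 0), 2)
    demo_line = {
        "text": ["hello world"], "conf": [0.99],
        "left": [0], "top": [0], "width": [320], "height": [48], "line": [1],
    }
    words_out, fell_back = AdaptiveSegmenter().segment(
        demo_line, canvas, image_name="demo"
    )
    print(words_out.get("text"), "fallback:", fell_back)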