gliner2-multi-v1-onnx / example.py
AleksanderObuchowski's picture
Upload folder using huggingface_hub
4f0c01d verified
#!/usr/bin/env python3
"""
GLiNER2 ONNX inference example.

Requirements:
    pip install onnxruntime numpy tokenizers

Usage:
    python example.py
"""
import re
import numpy as np
import onnxruntime as ort
from tokenizers import Tokenizer
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Path of the ONNX model file; main() opens it with ort.InferenceSession.
MODEL_PATH = "model.onnx"
# Path of the tokenizer definition; main() loads it with Tokenizer.from_file.
TOKENIZER_PATH = "tokenizer.json"
# Default minimum span score for decode_entities() to report an entity.
THRESHOLD = 0.5
# Number of span widths (1..MAX_WIDTH words) that build_inputs() enumerates
# for every start word.
MAX_WIDTH = 8
# NOTE(review): not used by any code visible in this example; presumably the
# vocabulary id of the "[SEP_TEXT]" token -- confirm against tokenizer.json.
SEP_TEXT_ID = 250103
# Sample sentence and entity labels that main() runs the model on.
TEXT = "Steve Jobs founded Apple Inc. in Cupertino, California on April 1, 1976."
LABELS = ["person", "company", "city", "date"]
# ---------------------------------------------------------------------------
# Word splitter (WhitespaceTokenSplitter)
# ---------------------------------------------------------------------------
# Alternatives are tried left to right, so the more specific token shapes
# (URLs, e-mail addresses, @handles) win over the generic word pattern.
WORD_RE = re.compile(
    r"(?:https?://[^\s]+|www\.[^\s]+)"  # URLs
    r"|[a-z0-9._%+-]+@[a-z0-9.-]+\.[a-z]{2,}"  # e-mail addresses
    r"|@[a-z0-9_]+"  # @handles
    r"|\w+(?:[-_]\w+)*"  # words, optionally joined by '-' or '_'
    r"|\S",  # any other single non-whitespace character
    re.IGNORECASE,
)


def split_words(text: str) -> list[tuple[str, int, int]]:
    """Split *text* into lowercased words with character offsets.

    Args:
        text: The raw input string.

    Returns:
        A list of ``(word, char_start, char_end)`` tuples in order of
        appearance. ``word`` is taken from the lowercased text, while
        ``char_start``/``char_end`` (end exclusive) always index into the
        original ``text`` so callers can slice it safely.
    """
    lowered = text.lower()
    matches = WORD_RE.finditer(lowered)

    if len(lowered) == len(text):
        # Common case: lowercasing kept every character in place, so offsets
        # into the lowered string are already offsets into ``text``.
        return [(m.group(), m.start(), m.end()) for m in matches]

    # Some characters grow when lowercased (e.g. U+0130 becomes "i" followed
    # by a combining dot), which would shift every later offset. Record, for
    # each lowered character, the index of the original character it came
    # from, and translate match offsets through that table.
    origin: list[int] = []
    for index, char in enumerate(text):
        origin.extend([index] * len(char.lower()))

    # Every alternative of WORD_RE consumes at least one character, so
    # ``m.end() - 1`` is always a valid index inside the match.
    return [
        (m.group(), origin[m.start()], origin[m.end() - 1] + 1)
        for m in matches
    ]
# ---------------------------------------------------------------------------
# Build model inputs
# ---------------------------------------------------------------------------
def build_inputs(
    tokenizer: Tokenizer,
    text: str,
    labels: list[str],
) -> tuple[dict[str, np.ndarray], list[tuple[str, int, int]]]:
    """Construct all five ONNX input tensors and return ``(feeds, words)``.

    The prompt is a pre-tokenized word sequence of the form
    ``( [P] entities ( [E] label1 [E] label2 ... ) ) [SEP_TEXT] word1 word2 ...``
    which is encoded in one call so that every sub-word token can be traced
    back to the word it belongs to.

    Args:
        tokenizer: Tokenizer used to encode the pre-tokenized sequence.
        text: Raw input text; it is split with ``split_words``.
        labels: Entity labels; a multi-word label contributes one schema
            word per whitespace-separated part.

    Returns:
        A tuple ``(feeds, words)`` where ``feeds`` maps input names to int64
        arrays:

        * ``input_ids``: shape ``(1, seq_len)``, the token ids.
        * ``attention_mask``: shape ``(1, seq_len)``, all ones.
        * ``text_positions``: 1-D, index of the first token of each text word.
        * ``schema_positions``: 1-D, index of the first token of the ``[P]``
          marker followed by each ``[E]`` marker.
        * ``span_idx``: shape ``(1, num_words * MAX_WIDTH, 2)``, inclusive
          ``(start, end)`` word indices, ``(0, 0)`` where a span would run
          past the last word.

        ``words`` is the ``split_words`` output for ``text``.

    Raises:
        ValueError: If a text word or a ``[P]``/``[E]`` marker produced no
            token, so its position cannot be determined.
    """
    words = split_words(text)
    word_strings = [w for w, _, _ in words]

    # -- Schema tokens --
    # Format: ( [P] entities ( [E] label1 [E] label2 ... ) ) [SEP_TEXT] word1 word2 ...
    schema_tokens = ["(", "[P]", "entities", "("]
    for label in labels:
        schema_tokens.append("[E]")
        schema_tokens.extend(label.split())
    schema_tokens.append(")")
    schema_tokens.append(")")

    # Full pre-tokenized sequence: schema + [SEP_TEXT] + words
    full_sequence = schema_tokens + ["[SEP_TEXT]"] + word_strings

    # Tokenize
    encoding = tokenizer.encode(
        full_sequence, is_pretokenized=True, add_special_tokens=False
    )
    token_ids = encoding.ids
    word_ids = encoding.word_ids  # maps each token to its word in full_sequence
    num_schema_words = len(schema_tokens) + 1  # +1 for [SEP_TEXT]

    # One pass over the tokens records where each word starts, instead of
    # rescanning the whole token list for every word.
    first_token_of: dict[int, int] = {}
    for tok_pos, wid in enumerate(word_ids):
        if wid is not None and wid not in first_token_of:
            first_token_of[wid] = tok_pos

    # -- input_ids and attention_mask --
    seq_len = len(token_ids)
    input_ids = np.array([token_ids], dtype=np.int64)
    attention_mask = np.ones((1, seq_len), dtype=np.int64)

    # -- text_positions: first token index for each text word --
    text_positions = []
    for word_idx, word in enumerate(word_strings):
        first_token = first_token_of.get(num_schema_words + word_idx)
        if first_token is None:
            # Explicit raise rather than ``assert`` so the check survives -O.
            raise ValueError(
                f"Word {word_idx} ('{word}') not found in token mapping"
            )
        text_positions.append(first_token)
    text_positions = np.array(text_positions, dtype=np.int64)

    # -- schema_positions: [P] position, then each [E] position --
    schema_positions = []
    for schema_idx, tok in enumerate(schema_tokens):
        if tok not in ("[P]", "[E]"):
            continue
        first_token = first_token_of.get(schema_idx)
        if first_token is None:
            # Skipping silently would shift every later marker and misalign
            # the positions with the labels, so fail loudly instead.
            raise ValueError(
                f"Schema marker {tok} (schema word {schema_idx}) not found in token mapping"
            )
        schema_positions.append(first_token)
    schema_positions = np.array(schema_positions, dtype=np.int64)

    # -- span_idx: for every start word, MAX_WIDTH candidate spans --
    # Exactly MAX_WIDTH entries are emitted per start word; spans that would
    # run past the last word are replaced by the (0, 0) padding pair.
    num_words = len(word_strings)
    spans = []
    for start in range(num_words):
        for width in range(1, MAX_WIDTH + 1):
            end = start + width
            if end <= num_words:
                spans.append((start, end - 1))
            else:
                spans.append((0, 0))  # padding
    span_idx = np.array(spans, dtype=np.int64).reshape(1, -1, 2)

    feeds = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "text_positions": text_positions,
        "schema_positions": schema_positions,
        "span_idx": span_idx,
    }
    return feeds, words
# ---------------------------------------------------------------------------
# Post-processing
# ---------------------------------------------------------------------------
def decode_entities(
    span_scores: np.ndarray,
    words: list[tuple[str, int, int]],
    labels: list[str],
    text: str,
    threshold: float = THRESHOLD,
) -> list[dict]:
    """Turn span scores into a position-sorted list of entity records.

    Args:
        span_scores: Score array indexed as
            ``[0, field, start_word, width_index]``; only the first entry
            of the leading axis is read.
        words: ``(word, char_start, char_end)`` tuples for the text words.
        labels: Label names, one per field index.
        text: The text that the character offsets in ``words`` refer to.
        threshold: Minimum score (inclusive) for a span to be reported.

    Returns:
        A list of dicts with keys ``label``, ``text``, ``start``, ``end``
        and ``score``, ordered by ``(start, end)``.
    """
    per_field = span_scores[0]  # indexed [field, start_word, width_index]
    word_count = len(words)
    found = []

    for field_idx, label in enumerate(labels):
        for first in range(per_field.shape[1]):
            for offset in range(per_field.shape[2]):
                confidence = per_field[field_idx, first, offset]
                last = first + offset  # inclusive index of the final word
                # ``not >=`` (rather than ``<``) so NaN scores are rejected.
                if not confidence >= threshold or last >= word_count:
                    continue
                begin = words[first][1]
                finish = words[last][2]
                found.append({
                    "label": label,
                    "text": text[begin:finish],
                    "start": begin,
                    "end": finish,
                    "score": float(confidence),
                })

    # Stable sort: ties keep their label/start/width discovery order.
    return sorted(found, key=lambda e: (e["start"], e["end"]))
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Run the sample text through the model and print the entities found."""
    # Tokenizer first, then the ONNX session.
    tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
    session = ort.InferenceSession(MODEL_PATH)

    # Prepare the feeds, run the model and keep its first output.
    feeds, words = build_inputs(tokenizer, TEXT, LABELS)
    span_scores = session.run(None, feeds)[0]

    entities = decode_entities(span_scores, words, LABELS, TEXT)

    # Report the run configuration, then one line per entity.
    print(f"Text: {TEXT}")
    print(f"Labels: {LABELS}")
    print(f"Threshold: {THRESHOLD}")
    print()

    if not entities:
        print(" (no entities found)")
        return

    for ent in entities:
        print(f" {ent['label']:>10} {ent['score']:.3f} [{ent['start']:3d}:{ent['end']:3d}] {ent['text']}")


if __name__ == "__main__":
    main()