Spaces:
Running
Running
File size: 4,083 Bytes
590a604 ee1a8a3 1fbc47b 7f15aed 1fbc47b 7f15aed 1fbc47b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
"""
Inference script for the LexiMind multitask model.
Command-line interface for running summarization, emotion detection, and topic
classification on arbitrary text inputs.
Author: Oliver Perrin
Date: December 2025
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import List, cast
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from src.data.tokenization import TokenizerConfig
from src.inference import EmotionPrediction, TopicPrediction, create_inference_pipeline
def _load_texts(positional: List[str], file_path: Path | None) -> List[str]:
texts = [text for text in positional if text]
if file_path is not None:
if not file_path.exists():
raise FileNotFoundError(file_path)
with file_path.open("r", encoding="utf-8") as handle:
texts.extend([line.strip() for line in handle if line.strip()])
if not texts:
raise ValueError("No input texts provided. Pass text arguments or use --file.")
return texts
def parse_args(argv: List[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the inference script.

    Args:
        argv: Argument strings to parse. When ``None`` (the default) argparse
            reads ``sys.argv[1:]``, so existing ``parse_args()`` callers are
            unaffected; passing a list allows programmatic use and testing.

    Returns:
        A namespace with the attributes ``text`` (list of positional texts),
        ``file``, ``checkpoint``, ``labels``, ``tokenizer``, ``model_config``
        (paths, with ``file`` and ``tokenizer`` defaulting to ``None``),
        ``device`` (string, default ``"cpu"``) and ``summary_max_length``
        (int or ``None``).
    """
    parser = argparse.ArgumentParser(description="Run LexiMind multitask inference.")
    parser.add_argument("text", nargs="*", help="Input text(s) to analyse.")
    parser.add_argument("--file", type=Path, help="Path to a file containing one text per line.")
    parser.add_argument(
        "--checkpoint",
        type=Path,
        default=Path("checkpoints/best.pt"),
        help="Path to the model checkpoint produced during training.",
    )
    parser.add_argument(
        "--labels",
        type=Path,
        default=Path("artifacts/labels.json"),
        help="JSON file containing emotion/topic label vocabularies.",
    )
    parser.add_argument(
        "--tokenizer",
        type=Path,
        default=None,
        help="Optional path to a tokenizer directory exported during training.",
    )
    parser.add_argument(
        "--model-config",
        type=Path,
        default=Path("configs/model/base.yaml"),
        help="Model architecture config used to rebuild the transformer stack.",
    )
    parser.add_argument("--device", default="cpu", help="Device to run inference on (cpu or cuda).")
    parser.add_argument(
        "--summary-max-length",
        type=int,
        default=None,
        help="Optional maximum length for generated summaries.",
    )
    # Forwarding None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
def main() -> None:
    """Run multitask inference on the requested texts and print the results.

    Reads the command-line options, gathers the input texts, builds the
    inference pipeline, runs one batched prediction and writes a JSON list to
    stdout. Each list entry holds the input ``text``, its ``summary``, the
    predicted ``emotion`` labels with scores, and the predicted ``topic``
    label with its confidence.
    """
    cli = parse_args()
    inputs = _load_texts(cli.text, cli.file)

    # An explicit --tokenizer takes precedence; otherwise use the locally
    # exported tokenizer directory when it exists, else pass no config at all.
    if cli.tokenizer is not None:
        tok_cfg = TokenizerConfig(pretrained_model_name=str(cli.tokenizer))
    else:
        fallback_dir = Path("artifacts/hf_tokenizer")
        tok_cfg = (
            TokenizerConfig(pretrained_model_name=str(fallback_dir))
            if fallback_dir.exists()
            else None
        )

    # The second element of the returned pair is not needed here.
    engine, _ = create_inference_pipeline(
        checkpoint_path=cli.checkpoint,
        labels_path=cli.labels,
        tokenizer_config=tok_cfg,
        model_config_path=cli.model_config,
        device=cli.device,
        summary_max_length=cli.summary_max_length,
    )

    outputs = engine.batch_predict(inputs)
    summary_list = cast(List[str], outputs["summaries"])
    emotion_list = cast(List[EmotionPrediction], outputs["emotion"])
    topic_list = cast(List[TopicPrediction], outputs["topic"])

    # Index by position (rather than zip) so that a result list shorter than
    # the inputs raises IndexError instead of being truncated silently.
    report = [
        {
            "text": source_text,
            "summary": summary_list[pos],
            "emotion": {
                "labels": emotion_list[pos].labels,
                "scores": emotion_list[pos].scores,
            },
            "topic": {
                "label": topic_list[pos].label,
                "confidence": topic_list[pos].confidence,
            },
        }
        for pos, source_text in enumerate(inputs)
    ]
    print(json.dumps(report, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()
|