|
|
""" |
|
|
HuggingFace Sentiment Provider - AI-powered text analysis |
|
|
|
|
|
Provides: |
|
|
- Sentiment analysis using transformer models |
|
|
- Text summarization |
|
|
- Named entity recognition |
|
|
- Zero-shot classification |
|
|
|
|
|
Uses HuggingFace Inference API for model inference. |
|
|
API Documentation: https://huggingface.co/docs/api-inference/ |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
from typing import Any, Dict, List, Optional |
|
|
|
|
|
import os |
|
|
|
|
|
from .base import BaseProvider, create_success_response, create_error_response |
|
|
|
|
|
|
|
|
class HFSentimentProvider(BaseProvider):
    """HuggingFace Inference API provider for AI-powered analysis.

    Supports sentiment analysis, text summarization, named entity
    recognition and zero-shot classification via the HuggingFace
    Inference API router.
    """

    # Resolved once at class-definition time; pass ``api_key`` to
    # ``__init__`` to override per instance.
    DEFAULT_API_KEY = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN") or ""

    # Default model id for each supported task.
    MODELS = {
        "sentiment": "distilbert-base-uncased-finetuned-sst-2-english",
        "sentiment_financial": "mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis",
        "summarization": "sshleifer/distilbart-cnn-12-6",
        "ner": "dslim/bert-base-NER",
        "classification": "facebook/bart-large-mnli",
        "text_generation": "gpt2"
    }

    # Maps raw model output labels (which vary between models) to a
    # canonical positive/neutral/negative vocabulary.  Hoisted to a class
    # constant so it is not rebuilt on every parse call.
    _LABEL_MAP = {
        "label_0": "negative",
        "label_1": "neutral",
        "label_2": "positive",
        "negative": "negative",
        "neutral": "neutral",
        "positive": "positive",
        "pos": "positive",
        "neg": "negative",
        "neu": "neutral"
    }

    def __init__(self, api_key: Optional[str] = None):
        super().__init__(
            name="huggingface",
            base_url="https://router.huggingface.co/hf-inference/models",
            api_key=api_key if api_key is not None else self.DEFAULT_API_KEY,
            timeout=15.0,
            cache_ttl=60.0
        )

    def _get_default_headers(self) -> Dict[str, str]:
        """Get headers with HuggingFace authorization."""
        return {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

    def _check_model_error(self, data: Any) -> Optional[Dict[str, Any]]:
        """Return a standardized error response if *data* is a model-error payload.

        The Inference API signals failures as ``{"error": "..."}`` dicts.
        A "model is loading" error gets a dedicated retry hint; any other
        error message is passed through verbatim.  Returns None when *data*
        carries no error.
        """
        if isinstance(data, dict) and data.get("error"):
            error_msg = data.get("error", "Model error")
            if "loading" in error_msg.lower():
                return create_error_response(
                    self.name,
                    "Model is loading",
                    "Please retry in a few seconds"
                )
            return create_error_response(self.name, error_msg)
        return None

    @staticmethod
    def _preview(text: str) -> str:
        """Truncate *text* to 100 characters (with ellipsis) for response echoes."""
        return text[:100] + "..." if len(text) > 100 else text

    async def analyze_sentiment(
        self,
        text: str,
        model: Optional[str] = None,
        use_financial_model: bool = False
    ) -> Dict[str, Any]:
        """
        Analyze sentiment of text using HuggingFace models.

        Args:
            text: Text to analyze
            model: Custom model to use (optional)
            use_financial_model: Use the financial-news sentiment model

        Returns:
            Standardized response with sentiment analysis
        """
        if not text or len(text.strip()) < 3:
            return create_error_response(
                self.name,
                "Invalid text",
                "Text must be at least 3 characters"
            )

        # Keep request payloads bounded; 1000 chars is plenty for sentiment.
        text = text[:1000]

        if model:
            model_id = model
        elif use_financial_model:
            model_id = self.MODELS["sentiment_financial"]
        else:
            model_id = self.MODELS["sentiment"]

        response = await self.post(model_id, json_data={"inputs": text})
        if not response.get("success"):
            return response

        data = response.get("data", [])
        error = self._check_model_error(data)
        if error:
            return error

        results = self._parse_sentiment_results(data, model_id)
        return create_success_response(
            self.name,
            {
                "text": self._preview(text),
                "model": model_id,
                "sentiment": results
            }
        )

    def _parse_sentiment_results(self, data: Any, model_id: str) -> Dict[str, Any]:
        """Parse sentiment results from different model output formats.

        Accepts either ``[{"label": ..., "score": ...}, ...]`` or the
        nested ``[[...]]`` form some models return.  Labels are normalized
        via ``_LABEL_MAP``.  ``model_id`` is kept for interface
        compatibility (currently unused).
        """
        if not data:
            return {"label": "unknown", "score": 0.0}

        if isinstance(data, list) and len(data) > 0:
            # Some models wrap the score list in an extra list.
            if isinstance(data[0], list):
                data = data[0]
            if not data:
                # Guard: an empty unwrapped inner list would make max()
                # raise ValueError.
                return {"label": "unknown", "score": 0.0}

            best = max(data, key=lambda x: x.get("score", 0))
            label = best.get("label", "unknown").lower()
            score = best.get("score", 0.0)
            normalized_label = self._LABEL_MAP.get(label, label)

            return {
                "label": normalized_label,
                "score": round(score, 4),
                "allScores": [
                    {"label": item.get("label"), "score": round(item.get("score", 0), 4)}
                    for item in data
                ]
            }

        return {"label": "unknown", "score": 0.0}

    async def summarize_text(
        self,
        text: str,
        max_length: int = 150,
        min_length: int = 30,
        model: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Summarize text using HuggingFace summarization model.

        Args:
            text: Text to summarize
            max_length: Maximum summary length
            min_length: Minimum summary length
            model: Custom model to use
        """
        if not text or len(text.strip()) < 50:
            return create_error_response(
                self.name,
                "Text too short",
                "Text must be at least 50 characters for summarization"
            )

        # Bound input size for the summarization model.
        text = text[:3000]

        model_id = model or self.MODELS["summarization"]
        payload = {
            "inputs": text,
            "parameters": {
                "max_length": max_length,
                "min_length": min_length,
                "do_sample": False
            }
        }

        response = await self.post(model_id, json_data=payload)
        if not response.get("success"):
            return response

        data = response.get("data", [])
        error = self._check_model_error(data)
        if error:
            return error

        # The API returns either a list of result dicts or a single dict.
        summary = ""
        if isinstance(data, list) and len(data) > 0:
            summary = data[0].get("summary_text", "")
        elif isinstance(data, dict):
            summary = data.get("summary_text", "")

        return create_success_response(
            self.name,
            {
                "originalLength": len(text),
                "summaryLength": len(summary),
                "model": model_id,
                "summary": summary
            }
        )

    async def extract_entities(
        self,
        text: str,
        model: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Extract named entities from text.

        Args:
            text: Text to analyze
            model: Custom NER model to use
        """
        if not text or len(text.strip()) < 3:
            return create_error_response(
                self.name,
                "Invalid text",
                "Text must be at least 3 characters"
            )

        text = text[:1000]
        model_id = model or self.MODELS["ner"]

        response = await self.post(model_id, json_data={"inputs": text})
        if not response.get("success"):
            return response

        data = response.get("data", [])
        error = self._check_model_error(data)
        if error:
            return error

        entities = []
        if isinstance(data, list):
            for entity in data:
                entities.append({
                    "word": entity.get("word"),
                    # Aggregated outputs use "entity_group"; raw token-level
                    # outputs use "entity".
                    "entity": entity.get("entity_group") or entity.get("entity"),
                    "score": round(entity.get("score", 0), 4),
                    "start": entity.get("start"),
                    "end": entity.get("end")
                })

        return create_success_response(
            self.name,
            {
                "text": self._preview(text),
                "model": model_id,
                "entities": entities,
                "count": len(entities)
            }
        )

    async def classify_text(
        self,
        text: str,
        candidate_labels: List[str],
        model: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Zero-shot text classification.

        Args:
            text: Text to classify
            candidate_labels: List of possible labels (at least 2; only the
                first 10 are sent to the API)
            model: Custom classification model
        """
        if not text or len(text.strip()) < 3:
            return create_error_response(
                self.name,
                "Invalid text",
                "Text must be at least 3 characters"
            )

        if not candidate_labels or len(candidate_labels) < 2:
            return create_error_response(
                self.name,
                "Invalid labels",
                "At least 2 candidate labels required"
            )

        text = text[:500]
        model_id = model or self.MODELS["classification"]

        payload = {
            "inputs": text,
            "parameters": {
                # Cap at 10 labels to keep the request small.
                "candidate_labels": candidate_labels[:10]
            }
        }

        response = await self.post(model_id, json_data=payload)
        if not response.get("success"):
            return response

        data = response.get("data", {})
        error = self._check_model_error(data)
        if error:
            return error

        # Zero-shot output carries parallel "labels"/"scores" lists sorted
        # best-first.
        labels = data.get("labels", [])
        scores = data.get("scores", [])
        classifications = [
            {"label": label, "score": round(score, 4)}
            for label, score in zip(labels, scores)
        ]

        return create_success_response(
            self.name,
            {
                "text": self._preview(text),
                "model": model_id,
                "classifications": classifications,
                "bestLabel": labels[0] if labels else None,
                "bestScore": round(scores[0], 4) if scores else 0.0
            }
        )

    async def get_available_models(self) -> Dict[str, Any]:
        """Get list of available models for each task."""
        return create_success_response(
            self.name,
            {
                "models": self.MODELS,
                "tasks": list(self.MODELS.keys())
            }
        )
|
|
|