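"""Tariff Tweet Sentiment: a Gradio Space.

Filters public Twitter corpora (sentiment140 or tweet_eval/sentiment) for
tariff-related keywords and scores the matching tweets with a Twitter-tuned
sentiment model. No Twitter API key is required.
"""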
import gradio as gr
import pandas as pd
import numpy as np
import re
from functools import lru_cache
from typing import List, Tuple


# Lazy imports for heavy deps so the Space boots faster
def _lazy_imports():
    """Import heavy dependencies on first use so the Space boots faster."""
    global plt, WordCloud
    import matplotlib.pyplot as plt  # noqa: F401
    from datasets import load_dataset  # noqa: F401
    from transformers import pipeline as hf_pipeline  # noqa: F401
    try:
        from wordcloud import WordCloud  # noqa: F401
    except Exception:
        WordCloud = None
    return locals()
# ----------------------------
# Helpers
# ----------------------------

TARIFF_KEYWORDS_DEFAULT = [
    "tariff", "tariffs", "import tax", "trade war", "section 301", "section301",
    "customs duty", "custom duties", "duties", "anti-dumping", "countervailing",
    "steel tariff", "aluminum tariff", "aluminium tariff", "US tariff", "U.S. tariff",
    "tariff policy", "retaliatory tariff", "tariff hike", "tariff cut",
]
KEYWORD_PATTERN_CACHE = {}


def compile_keyword_pattern(keywords: List[str]) -> re.Pattern:
    key = "\u0001".join(sorted([k.strip().lower() for k in keywords if k.strip()]))
    if key in KEYWORD_PATTERN_CACHE:
        return KEYWORD_PATTERN_CACHE[key]
    escaped = [re.escape(k) for k in keywords if k.strip()]
    pattern = re.compile(r"(" + r"|".join(escaped) + r")", flags=re.IGNORECASE)
    KEYWORD_PATTERN_CACHE[key] = pattern
    return pattern
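# Usage sketch (hypothetical values): compile_keyword_pattern(["tariff", "trade war"])
# yields a case-insensitive alternation pattern, so "New TARIFFS announced" and
# "a trade war looms" both match via df['text'].str.contains(pattern).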
def normalize_text(s: str) -> str:
    s = re.sub(r"https?://\S+", " ", s)     # drop urls
    s = re.sub(r"@[A-Za-z0-9_]+", " ", s)   # drop @mentions
    s = re.sub(r"#[A-Za-z0-9_]+", " ", s)   # drop hashtags (we'll match keywords separately)
    s = re.sub(r"\s+", " ", s).strip()
    return s
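# Example (illustrative): normalize_text("Tariffs again?! @user https://t.co/abc #tradewar")
# -> "Tariffs again?!" (URL, mention, and hashtag stripped, whitespace collapsed).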
@lru_cache(maxsize=2)
def load_sentiment_pipeline(model_name: str = "cardiffnlp/twitter-roberta-base-sentiment-latest"):
    """Load the Twitter-tuned sentiment pipeline once and reuse it across runs."""
    _ = _lazy_imports()
    from transformers import pipeline as hf_pipeline
    pipe = hf_pipeline(
        task="sentiment-analysis",
        model=model_name,
        tokenizer=model_name,
        truncation=True,
        max_length=256,
        return_all_scores=False,
        device=-1,
    )
    return pipe
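# Note: the pipeline returns one {'label': ..., 'score': ...} dict per input text.
# Label casing/wording varies by model, which is why run_inference() later in this
# file normalizes labels to {positive, neutral, negative}.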
def load_hf_dataset(name: str):
    _ = _lazy_imports()
    from datasets import load_dataset
    if name == "sentiment140":
        # 1.6M tweets; we'll stream and sample later
        ds = load_dataset("sentiment140", trust_remote_code=True)
        # columns: ['sentiment','ids','date','query','user','text']
        return ds
    elif name == "tweet_eval":
        # We'll use the sentiment subset
        ds = load_dataset("tweet_eval", "sentiment")
        # columns: ['text','label'] where label in {0:negative, 1:neutral, 2:positive}
        return ds
    else:
        raise ValueError("Unsupported dataset: " + name)
def filter_and_sample(df: pd.DataFrame, keywords: List[str], sample_size: int, random_state: int = 42) -> pd.DataFrame:
    pat = compile_keyword_pattern(keywords)
    mask = df['text'].str.contains(pat, na=False)
    subset = df.loc[mask].copy()
    if subset.empty:
        return subset
    if sample_size > 0 and len(subset) > sample_size:
        subset = subset.sample(n=sample_size, random_state=random_state)
    return subset
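# Usage sketch (hypothetical numbers): filter_and_sample(df, ["tariff"], sample_size=500)
# keeps only rows whose 'text' matches the keyword pattern and, if more than 500 match,
# returns a reproducible random sample of 500 rows (random_state=42).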
def run_inference(texts: List[str], batch_size: int = 64) -> List[dict]:
    pipe = load_sentiment_pipeline()
    results = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        out = pipe(batch)
        # Normalize labels to {positive, neutral, negative}
        for o in out:
            lab = o.get('label', '').lower()
            if 'pos' in lab:
                label = 'positive'
            elif 'neg' in lab:
                label = 'negative'
            else:
                label = 'neutral'
            results.append({'label': label, 'score': float(o.get('score', 0.0))})
    return results
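# Example (illustrative): run_inference(["I hate these tariffs", "Tariff talks resume Monday"])
# might return [{'label': 'negative', 'score': 0.97}, {'label': 'neutral', 'score': 0.88}];
# the scores shown here are placeholders and depend on the model.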
def make_bar_plot(counts: pd.Series):
    import matplotlib.pyplot as plt
    fig = plt.figure(figsize=(5, 3.2), dpi=140)
    ax = fig.gca()
    counts = counts.reindex(['negative', 'neutral', 'positive']).fillna(0)
    ax.bar(counts.index, counts.values)
    total = int(counts.sum())
    ax.set_title(f"Sentiment distribution (n={total})")
    ax.set_xlabel("Sentiment")
    ax.set_ylabel("# Tweets")
    fig.tight_layout()
    return fig
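# Example (illustrative): make_bar_plot(pd.Series({"negative": 40, "neutral": 35, "positive": 25}))
# returns a matplotlib Figure with bars ordered negative / neutral / positive.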
def make_wordcloud(texts: List[str]):
    # Optional; will return None if wordcloud isn't available
    try:
        from wordcloud import WordCloud
    except Exception:
        return None
    joined = " ".join(texts)
    wc = WordCloud(width=800, height=320, background_color="white").generate(joined)
    import matplotlib.pyplot as plt
    fig = plt.figure(figsize=(8, 3.6), dpi=120)
    plt.imshow(wc)
    plt.axis("off")
    fig.tight_layout()
    return fig
# ----------------------------
# Core pipeline
# ----------------------------

def analyze(dataset_choice: str,
            keywords_csv: str,
            max_rows: int,
            include_wordcloud: bool) -> Tuple[str, "matplotlib.figure.Figure", "matplotlib.figure.Figure", pd.DataFrame]:
    """Return (summary_markdown, bar_fig, wordcloud_fig|None, table_df)."""
    ds = load_hf_dataset(dataset_choice)

    # Convert to pandas
    if dataset_choice == "sentiment140":
        # Concatenate a manageable slice from train/test (to keep runtime reasonable)
        train = ds.get('train')
        test = ds.get('test')
        frames = []
        for split in [train, test]:
            if split is None:
                continue
            # Take a small random slice to keep the Space responsive
            n = len(split)
            take = min(n, 150_000)  # cap
            frames.append(split.shuffle(seed=42).select(range(take)).to_pandas()[['text', 'date']])
        df = pd.concat(frames, ignore_index=True)
    else:
        # tweet_eval sentiment
        frames = []
        for name in ['train', 'validation', 'test']:
            if name in ds:
                frames.append(ds[name].to_pandas()[['text']])
        df = pd.concat(frames, ignore_index=True)
    if 'date' not in df.columns:
        df['date'] = np.nan
    # Clean
    df['text'] = df['text'].astype(str).apply(normalize_text)

    # Keywords
    keywords = [k.strip() for k in (keywords_csv or "").split(',') if k.strip()] or TARIFF_KEYWORDS_DEFAULT

    # Filter + sample
    subset = filter_and_sample(df, keywords, sample_size=max_rows)
    if subset.empty:
        return (
            "### No matches found\nTry broadening keywords or increasing the sample size.",
            make_bar_plot(pd.Series(dtype=int)),
            None,
            pd.DataFrame(columns=['text', 'pred_label', 'pred_score', 'date'])
        )

    # Inference
    preds = run_inference(subset['text'].tolist())
    pred_df = pd.DataFrame(preds)
    subset = subset.reset_index(drop=True).copy()
    subset['pred_label'] = pred_df['label']
    subset['pred_score'] = pred_df['score']

    # Metrics
    counts = subset['pred_label'].value_counts()
    total = int(counts.sum())
    pct = (counts / max(total, 1) * 100).round(1)

    # Summary text
    sentiment_line = (
        f"**Negative:** {int(counts.get('negative', 0))} ({pct.get('negative', 0.0)}%) | "
        f"**Neutral:** {int(counts.get('neutral', 0))} ({pct.get('neutral', 0.0)}%) | "
        f"**Positive:** {int(counts.get('positive', 0))} ({pct.get('positive', 0.0)}%)"
    )
    summary = (
        "## Tariff Tweet Sentiment — Snapshot\n"
        f"Dataset: **{dataset_choice}** | Sampled tweets: **{total}**\n\n"
        f"Keyword filter: `{', '.join(keywords)}`\n\n"
        + sentiment_line +
        "\n\nTip: Neutral can be high when tweets are mostly informative (news/links) or ambiguous."
    )

    # Plots
    bar_fig = make_bar_plot(counts)
    wc_fig = make_wordcloud(subset['text'].tolist()) if include_wordcloud else None
    # Output table (returned in full; the UI callback also writes it to CSV)
    out_df = subset[['text', 'pred_label', 'pred_score', 'date']]
    return summary, bar_fig, wc_fig, out_df
# ----------------------------
# Gradio UI
# ----------------------------

with gr.Blocks(title="Tariff Tweet Sentiment (No Twitter API)") as demo:
    gr.Markdown(
        """
# Tariff Tweet Sentiment

Analyze how people talk about **U.S. tariff policy** using public Twitter corpora (no API key required).

**How it works**
- Choose a public dataset (e.g., `sentiment140` or `tweet_eval/sentiment`).
- Filter tweets by keywords like *tariff*, *trade war*, *Section 301*, etc.
- Run a Twitter-optimized sentiment model.
- View the distribution, word cloud, and the matching tweets.

*Note:* Public corpora may skew older or topical; results are a **snapshot**, not a live feed.
"""
    )

    with gr.Row():
        dataset_choice = gr.Dropdown(
            choices=["sentiment140", "tweet_eval"],
            value="sentiment140",
            label="Dataset"
        )
        max_rows = gr.Slider(100, 5000, value=1500, step=50, label="Max tweets to analyze (after keyword filter)")

    keywords_csv = gr.Textbox(value=", ".join(TARIFF_KEYWORDS_DEFAULT), label="Keywords (comma‑separated)")
    include_wordcloud = gr.Checkbox(value=True, label="Include word cloud (optional)")
    run_btn = gr.Button("Run Analysis", variant="primary")

    summary_md = gr.Markdown()
    bar_plot = gr.Plot(label="Sentiment distribution")
    wc_plot = gr.Plot(label="Word cloud (optional)")
    table = gr.Dataframe(headers=["text", "pred_label", "pred_score", "date"], wrap=True, interactive=False)
    csv = gr.File(label="Download CSV of results", visible=True)

    def _go(dataset_choice, keywords_csv, max_rows, include_wordcloud):
        summary, bar_fig, wc_fig, df = analyze(dataset_choice, keywords_csv, int(max_rows), bool(include_wordcloud))
        # Save CSV so the results can be downloaded from the UI
        out_path = "tariff_tweets_sentiment.csv"
        df.to_csv(out_path, index=False)
        return summary, bar_fig, wc_fig, df, out_path

    run_btn.click(
        _go,
        [dataset_choice, keywords_csv, max_rows, include_wordcloud],
        [summary_md, bar_plot, wc_plot, table, csv],
    )

if __name__ == "__main__":
    demo.launch()