import gradio as gr
import pandas as pd
import numpy as np
import re
from typing import List, Tuple
# Lazy imports for heavy deps so the Space boots faster
from functools import lru_cache
def _lazy_imports():
    """Import heavy dependencies on first use so the Space boots faster."""
    global plt, WordCloud
    import matplotlib.pyplot as plt  # noqa: F401
    from datasets import load_dataset  # noqa: F401
    from transformers import pipeline as hf_pipeline  # noqa: F401
    try:
        from wordcloud import WordCloud  # noqa: F401
    except Exception:
        WordCloud = None
    return locals()
# ----------------------------
# Helpers
# ----------------------------
TARIFF_KEYWORDS_DEFAULT = [
    "tariff", "tariffs", "import tax", "trade war", "section 301", "section301",
    "customs duty", "custom duties", "duties", "anti-dumping", "countervailing",
    "steel tariff", "aluminum tariff", "aluminium tariff", "US tariff", "U.S. tariff",
    "tariff policy", "retaliatory tariff", "tariff hike", "tariff cut",
]
KEYWORD_PATTERN_CACHE = {}
def compile_keyword_pattern(keywords: List[str]) -> re.Pattern:
    key = "\u0001".join(sorted([k.strip().lower() for k in keywords if k.strip()]))
    if key in KEYWORD_PATTERN_CACHE:
        return KEYWORD_PATTERN_CACHE[key]
    escaped = [re.escape(k) for k in keywords if k.strip()]
    pattern = re.compile(r"(" + r"|".join(escaped) + r")", flags=re.IGNORECASE)
    KEYWORD_PATTERN_CACHE[key] = pattern
    return pattern
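# A minimal usage sketch. Note the pattern has no \b word boundaries, so matching
# is case-insensitive and substring-based (a deliberate recall-over-precision trade-off):
#   pat = compile_keyword_pattern(["tariff", "trade war"])
#   assert pat.search("New TARIFF hike announced")
#   assert pat.search("nontariff barriers")  # substring match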
def normalize_text(s: str) -> str:
    s = re.sub(r"https?://\S+", " ", s)        # drop urls
    s = re.sub(r"@[A-Za-z0-9_]+", " ", s)      # drop @mentions
    s = re.sub(r"#([A-Za-z0-9_]+)", r"\1", s)  # strip the '#' but keep the word, so keyword filtering still sees it
    s = re.sub(r"\s+", " ", s).strip()
    return s
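# Example (deterministic, pure regex):
#   normalize_text("Steel #tariffs coming https://t.co/abc @user")
#   -> "Steel tariffs coming"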
@lru_cache(maxsize=2)
def load_sentiment_pipeline(model_name: str = "cardiffnlp/twitter-roberta-base-sentiment-latest"):
    _ = _lazy_imports()
    from transformers import pipeline as hf_pipeline
    pipe = hf_pipeline(
        task="sentiment-analysis",
        model=model_name,
        tokenizer=model_name,
        truncation=True,
        max_length=256,
        return_all_scores=False,
        device=-1,  # CPU
    )
    return pipe
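# Because of @lru_cache, each model loads at most once per process. Swapping in
# another Hub checkpoint is just a matter of passing its id — a hypothetical example
# (any sentiment model whose labels contain "pos"/"neg" works with run_inference below):
#   pipe = load_sentiment_pipeline("distilbert-base-uncased-finetuned-sst-2-english")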
@lru_cache(maxsize=2)
def load_hf_dataset(name: str):
    _ = _lazy_imports()
    from datasets import load_dataset
    if name == "sentiment140":
        # 1.6M tweets; we'll sample a slice later to keep runtime reasonable
        ds = load_dataset("sentiment140", trust_remote_code=True)
        # columns: ['sentiment', 'ids', 'date', 'query', 'user', 'text']
        return ds
    elif name == "tweet_eval":
        # We'll use the sentiment subset
        ds = load_dataset("tweet_eval", "sentiment")
        # columns: ['text', 'label'] where label in {0: negative, 1: neutral, 2: positive}
        return ds
    else:
        raise ValueError("Unsupported dataset: " + name)
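# Quick interactive check (assumes Hub access; sentiment140 is a sizable first download):
#   ds = load_hf_dataset("tweet_eval")
#   ds["train"].to_pandas().head()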
def filter_and_sample(df: pd.DataFrame, keywords: List[str], sample_size: int, random_state: int = 42) -> pd.DataFrame:
    pat = compile_keyword_pattern(keywords)
    mask = df['text'].str.contains(pat, na=False)
    subset = df.loc[mask].copy()
    if subset.empty:
        return subset
    if sample_size > 0 and len(subset) > sample_size:
        subset = subset.sample(n=sample_size, random_state=random_state)
    return subset
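# Deterministic mini-example (sample_size=0 disables downsampling):
#   df = pd.DataFrame({"text": ["tariffs are up", "nice weather"]})
#   filter_and_sample(df, ["tariff"], sample_size=0)  # keeps only the first row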
def run_inference(texts: List[str], batch_size: int = 64) -> List[dict]:
    pipe = load_sentiment_pipeline()
    results = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        out = pipe(batch)
        # normalize labels to {positive, neutral, negative}
        for o in out:
            lab = o.get('label', '').lower()
            if 'pos' in lab:
                label = 'positive'
            elif 'neg' in lab:
                label = 'negative'
            else:
                label = 'neutral'
            results.append({'label': label, 'score': float(o.get('score', 0.0))})
    return results
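# Rough usage sketch (actual labels/scores depend on the model loaded at runtime):
#   run_inference(["these tariffs are killing my business"])
#   -> e.g. [{'label': 'negative', 'score': 0.9...}]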
def make_bar_plot(counts: pd.Series):
    import matplotlib.pyplot as plt
    fig = plt.figure(figsize=(5, 3.2), dpi=140)
    ax = fig.gca()
    counts = counts.reindex(['negative', 'neutral', 'positive']).fillna(0)
    ax.bar(counts.index, counts.values)
    total = int(counts.sum())
    ax.set_title(f"Sentiment distribution (n={total})")
    ax.set_xlabel("Sentiment")
    ax.set_ylabel("# Tweets")
    fig.tight_layout()
    return fig
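# Works on any label->count Series; missing classes render as zero-height bars:
#   make_bar_plot(pd.Series({"positive": 3, "negative": 1}))  # neutral bar shows 0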
def make_wordcloud(texts: List[str]):
    # Optional; will return None if wordcloud isn't available
    try:
        from wordcloud import WordCloud
    except Exception:
        return None
    joined = " ".join(texts)
    wc = WordCloud(width=800, height=320, background_color="white").generate(joined)
    import matplotlib.pyplot as plt
    fig = plt.figure(figsize=(8, 3.6), dpi=120)
    plt.imshow(wc)
    plt.axis("off")
    fig.tight_layout()
    return fig
# ----------------------------
# Core pipeline
# ----------------------------
def analyze(dataset_choice: str,
            keywords_csv: str,
            max_rows: int,
            include_wordcloud: bool) -> Tuple[str, "matplotlib.figure.Figure", "matplotlib.figure.Figure", pd.DataFrame]:
    """Return (summary_markdown, bar_fig, wordcloud_fig|None, table_df)"""
    ds = load_hf_dataset(dataset_choice)
    # Convert to pandas
    if dataset_choice == "sentiment140":
        # concatenate a manageable slice from train/test (to keep runtime reasonable)
        train = ds.get('train')
        test = ds.get('test')
        frames = []
        for split in [train, test]:
            if split is None:
                continue
            # Take a small random slice to keep the Space responsive
            n = len(split)
            take = min(n, 150_000)  # cap
            frames.append(split.shuffle(seed=42).select(range(take)).to_pandas()[['text', 'date']])
        df = pd.concat(frames, ignore_index=True)
    else:
        # tweet_eval sentiment
        frames = []
        for name in ['train', 'validation', 'test']:
            if name in ds:
                frames.append(ds[name].to_pandas()[['text']])
        df = pd.concat(frames, ignore_index=True)
    if 'date' not in df.columns:
        df['date'] = np.nan
    # Clean
    df['text'] = df['text'].astype(str).apply(normalize_text)
    # Keywords
    keywords = [k.strip() for k in (keywords_csv or "").split(',') if k.strip()] or TARIFF_KEYWORDS_DEFAULT
    # Filter + sample
    subset = filter_and_sample(df, keywords, sample_size=max_rows)
    if subset.empty:
        return (
            "### No matches found\nTry broadening keywords or increasing the sample size.",
            make_bar_plot(pd.Series(dtype=int)),
            None,
            pd.DataFrame(columns=['text', 'pred_label', 'pred_score', 'date']),
        )
    # Inference
    preds = run_inference(subset['text'].tolist())
    pred_df = pd.DataFrame(preds)
    subset = subset.reset_index(drop=True).copy()
    subset['pred_label'] = pred_df['label']
    subset['pred_score'] = pred_df['score']
    # Metrics
    counts = subset['pred_label'].value_counts()
    total = int(counts.sum())
    pct = (counts / max(total, 1) * 100).round(1)
    # Summary text
    sentiment_line = (
        f"**Negative:** {int(counts.get('negative', 0))} ({pct.get('negative', 0.0)}%) | "
        f"**Neutral:** {int(counts.get('neutral', 0))} ({pct.get('neutral', 0.0)}%) | "
        f"**Positive:** {int(counts.get('positive', 0))} ({pct.get('positive', 0.0)}%)"
    )
    summary = (
        "## Tariff Tweet Sentiment — Snapshot\n"
        f"Dataset: **{dataset_choice}** | Sampled tweets: **{total}**\n\n"
        f"Keyword filter: `{', '.join(keywords)}`\n\n"
        + sentiment_line +
        "\n\nTip: Neutral can be high when tweets are mostly informative (news/links) or ambiguous."
    )
    # Plots
    bar_fig = make_bar_plot(counts)
    wc_fig = make_wordcloud(subset['text'].tolist()) if include_wordcloud else None
    # Output table (already capped at max_rows by the sampling step)
    out_df = subset[['text', 'pred_label', 'pred_score', 'date']]
    return summary, bar_fig, wc_fig, out_df
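# A minimal smoke test outside the UI (assumes Hub access for the model/dataset downloads):
#   summary, bar_fig, wc_fig, df = analyze("tweet_eval", "tariff, trade war", 200, False)
#   print(summary)
#   print(df.head())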
# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks(title="Tariff Tweet Sentiment (No Twitter API)") as demo:
    gr.Markdown(
        """
        # Tariff Tweet Sentiment
        Analyze how people talk about **U.S. tariff policy** using public Twitter corpora (no API key required).

        **How it works**
        - Choose a public dataset (e.g., `sentiment140` or `tweet_eval/sentiment`).
        - Filter tweets by keywords like *tariff*, *trade war*, *Section 301*, etc.
        - Run a Twitter-optimized sentiment model.
        - View the distribution, word cloud, and the matching tweets.

        *Note:* Public corpora may skew older or topical; results are a **snapshot**, not a live feed.
        """
    )
    with gr.Row():
        dataset_choice = gr.Dropdown(
            choices=["sentiment140", "tweet_eval"],
            value="sentiment140",
            label="Dataset",
        )
        max_rows = gr.Slider(100, 5000, value=1500, step=50, label="Max tweets to analyze (after keyword filter)")
    keywords_csv = gr.Textbox(value=", ".join(TARIFF_KEYWORDS_DEFAULT), label="Keywords (comma-separated)")
    include_wordcloud = gr.Checkbox(value=True, label="Include word cloud (optional)")
    run_btn = gr.Button("Run Analysis", variant="primary")

    summary_md = gr.Markdown()
    bar_plot = gr.Plot(label="Sentiment distribution")
    wc_plot = gr.Plot(label="Word cloud (optional)")
    table = gr.Dataframe(headers=["text", "pred_label", "pred_score", "date"], wrap=True, interactive=False)
    csv_file = gr.File(label="Download CSV of results", visible=True)

    def _go(dataset_choice, keywords_csv, max_rows, include_wordcloud):
        summary, bar_fig, wc_fig, df = analyze(dataset_choice, keywords_csv, int(max_rows), bool(include_wordcloud))
        # Save results so gr.File can offer the CSV for download
        out_path = "tariff_tweets_sentiment.csv"
        df.to_csv(out_path, index=False)
        return summary, bar_fig, wc_fig, df, out_path

    run_btn.click(
        _go,
        [dataset_choice, keywords_csv, max_rows, include_wordcloud],
        [summary_md, bar_plot, wc_plot, table, csv_file],
    )
if __name__ == "__main__":
    demo.launch()