import torch

from src.models.decoder import TransformerDecoder
from src.models.encoder import TransformerEncoder
from src.models.heads import ClassificationHead, LMHead, TokenClassificationHead
from src.models.multitask import MultiTaskModel


def test_multitask_encoder_classification_forward_and_loss():
    torch.manual_seed(0)
    vocab_size = 30
    d_model = 32
    num_layers = 2
    num_heads = 4
    d_ff = 64
    batch_size = 3
    seq_len = 8
    num_labels = 5

    enc = TransformerEncoder(
        vocab_size=vocab_size,
        d_model=d_model,
        num_layers=num_layers,
        num_heads=num_heads,
        d_ff=d_ff,
        dropout=0.0,
        max_len=seq_len,
        pad_token_id=0,
    )
    mt = MultiTaskModel(encoder=enc)
    head = ClassificationHead(d_model=d_model, num_labels=num_labels, pooler="mean", dropout=0.0)
    mt.add_head("sentiment", head)

    input_ids = torch.randint(1, vocab_size, (batch_size, seq_len), dtype=torch.long)
    labels = torch.randint(0, num_labels, (batch_size,), dtype=torch.long)

    logits = mt.forward("sentiment", {"input_ids": input_ids})
    assert logits.shape == (batch_size, num_labels)

    loss, logits2 = mt.forward(
        "sentiment", {"input_ids": input_ids, "labels": labels}, return_loss=True
    )
    assert loss.item() >= 0

    # grads
    loss.backward()
    grads = [p.grad for p in mt.parameters() if p.requires_grad]
    assert any(g is not None for g in grads)


def test_multitask_seq2seq_lm_forward_and_loss():
    torch.manual_seed(1)
    vocab_size = 40
    d_model = 32
    num_layers = 2
    num_heads = 4
    d_ff = 64
    batch_size = 2
    src_len = 7
    tgt_len = 6

    enc = TransformerEncoder(
        vocab_size=vocab_size,
        d_model=d_model,
        num_layers=num_layers,
        num_heads=num_heads,
        d_ff=d_ff,
        dropout=0.0,
        max_len=src_len,
        pad_token_id=0,
    )
    dec = TransformerDecoder(
        vocab_size=vocab_size,
        d_model=d_model,
        num_layers=num_layers,
        num_heads=num_heads,
        d_ff=d_ff,
        dropout=0.0,
        max_len=tgt_len,
        pad_token_id=0,
    )
    mt = MultiTaskModel(encoder=enc, decoder=dec)
    lm_head = LMHead(d_model=d_model, vocab_size=vocab_size, tie_embedding=None)
    mt.add_head("summarize", lm_head)

    src_ids = torch.randint(1, vocab_size, (batch_size, src_len), dtype=torch.long)
    # for training: provide decoder inputs (typically shifted right) and labels
    tgt_ids = torch.randint(1, vocab_size, (batch_size, tgt_len), dtype=torch.long)
    labels = tgt_ids.clone()

    logits = mt.forward("summarize", {"src_ids": src_ids, "tgt_ids": tgt_ids})
    assert logits.shape == (batch_size, tgt_len, vocab_size)

    loss, logits2 = mt.forward(
        "summarize", {"src_ids": src_ids, "tgt_ids": tgt_ids, "labels": labels}, return_loss=True
    )
    assert loss.item() >= 0

    loss.backward()
    grads = [p.grad for p in mt.parameters() if p.requires_grad]
    assert any(g is not None for g in grads)
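

# The seq2seq test above notes that decoder inputs are "typically shifted
# right" but, for simplicity, feeds tgt_ids directly as both decoder inputs
# and labels. Below is a minimal sketch of explicit shift-right teacher
# forcing; the helper name `shift_right` and the use of token id 0 (the pad
# id in these tests) as a BOS-like start token are illustrative assumptions,
# not part of the MultiTaskModel API.
def shift_right(tgt_ids: torch.Tensor, start_token_id: int = 0) -> torch.Tensor:
    # Prepend the start token and drop the last position, so the decoder
    # input at step t is the label token from step t - 1.
    start = torch.full((tgt_ids.size(0), 1), start_token_id, dtype=tgt_ids.dtype)
    return torch.cat([start, tgt_ids[:, :-1]], dim=1)
    # usage sketch: decoder_inputs = shift_right(labels), then pass
    # {"src_ids": src_ids, "tgt_ids": decoder_inputs, "labels": labels}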


def test_token_classification_forward_and_loss():
    torch.manual_seed(2)
    vocab_size = 20
    d_model = 24
    num_layers = 2
    num_heads = 4
    d_ff = 64
    batch_size = 2
    seq_len = 5
    num_labels = 7

    enc = TransformerEncoder(
        vocab_size=vocab_size,
        d_model=d_model,
        num_layers=num_layers,
        num_heads=num_heads,
        d_ff=d_ff,
        dropout=0.0,
        max_len=seq_len,
        pad_token_id=0,
    )
    mt = MultiTaskModel(encoder=enc)
    head = TokenClassificationHead(d_model=d_model, num_labels=num_labels, dropout=0.0)
    mt.add_head("ner", head)

    input_ids = torch.randint(1, vocab_size, (batch_size, seq_len), dtype=torch.long)
    labels = torch.randint(0, num_labels, (batch_size, seq_len), dtype=torch.long)

    logits = mt.forward("ner", {"input_ids": input_ids})
    assert logits.shape == (batch_size, seq_len, num_labels)

    loss, logits2 = mt.forward("ner", {"input_ids": input_ids, "labels": labels}, return_loss=True)
    assert loss.item() >= 0

    loss.backward()
    grads = [p.grad for p in mt.parameters() if p.requires_grad]
    assert any(g is not None for g in grads)
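

# A hedged sketch of one joint multi-task optimization step across several of
# the heads exercised above. It relies only on the
# mt.forward(task, batch, return_loss=True) contract shown in these tests;
# the helper name `multitask_training_step`, the equal loss weighting, and
# the single backward over the summed losses are illustrative choices, not
# behavior mandated by MultiTaskModel.
def multitask_training_step(mt, optimizer, batches):
    # `batches` maps task name -> the batch dict mt.forward expects for it,
    # e.g. {"sentiment": {"input_ids": ..., "labels": ...}, "ner": {...}}.
    optimizer.zero_grad()
    total_loss = torch.zeros(())
    for task_name, batch in batches.items():
        loss, _ = mt.forward(task_name, batch, return_loss=True)
        total_loss = total_loss + loss  # equal weighting across tasks
    total_loss.backward()  # one backward pass through the shared encoder
    optimizer.step()
    return total_loss.detach()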