Update linear_mapping.py
Browse files- linear_mapping.py +9 -4
linear_mapping.py
CHANGED
|
@@ -2,6 +2,7 @@ from config import LinearMappingConfig
|
|
| 2 |
from transformers import (
|
| 3 |
GPT2TokenizerFast, GPT2LMHeadModel, AutoModel,
|
| 4 |
CLIPVisionModel, AutoProcessor, BatchEncoding,
|
|
|
|
| 5 |
)
|
| 6 |
from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModelOutput
|
| 7 |
import torch
|
|
@@ -104,9 +105,11 @@ class ImagePrefix(nn.Module):
|
|
| 104 |
|
| 105 |
def __init__(self, config: LinearMappingConfig):
|
| 106 |
super().__init__()
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
|
|
|
|
|
|
| 110 |
|
| 111 |
if config.freeze_image_model:
|
| 112 |
for param in self.encoder.parameters():
|
|
@@ -128,7 +131,9 @@ class LinearMapping(nn.Module):
|
|
| 128 |
def __init__(self, config: LinearMappingConfig):
|
| 129 |
super().__init__()
|
| 130 |
self.image_prefix = ImagePrefix(config)
|
| 131 |
-
self.language_model = GPT2LMHeadModel.from_pretrained(config.text_model)
|
|
|
|
|
|
|
| 132 |
self.processor = LinearMappingProcessor(config)
|
| 133 |
self.tokenizer = self.processor.tokenizer
|
| 134 |
self.image_processor = self.processor.image_processor
|
|
|
|
| 2 |
from transformers import (
|
| 3 |
GPT2TokenizerFast, GPT2LMHeadModel, AutoModel,
|
| 4 |
CLIPVisionModel, AutoProcessor, BatchEncoding,
|
| 5 |
+
AutoConfig, CLIPVisionConfig
|
| 6 |
)
|
| 7 |
from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModelOutput
|
| 8 |
import torch
|
|
|
|
| 105 |
|
| 106 |
def __init__(self, config: LinearMappingConfig):
|
| 107 |
super().__init__()
|
| 108 |
+
clip_config = CLIPVisionConfig.from_pretrained(config.image_model)
|
| 109 |
+
|
| 110 |
+
self.encoder = CLIPVisionModel(clip_config)
|
| 111 |
+
if config.image_from_pretrained:
|
| 112 |
+
self.encoder = self.encoder.from_pretrained(config.image_model)
|
| 113 |
|
| 114 |
if config.freeze_image_model:
|
| 115 |
for param in self.encoder.parameters():
|
|
|
|
| 131 |
def __init__(self, config: LinearMappingConfig):
|
| 132 |
super().__init__()
|
| 133 |
self.image_prefix = ImagePrefix(config)
|
| 134 |
+
self.language_model = GPT2LMHeadModel(AutoConfig.from_pretrained(config.text_model))
|
| 135 |
+
if config.text_from_pretrained:
|
| 136 |
+
self.language_model = self.language_model.from_pretrained(config.text_model)
|
| 137 |
self.processor = LinearMappingProcessor(config)
|
| 138 |
self.tokenizer = self.processor.tokenizer
|
| 139 |
self.image_processor = self.processor.image_processor
|