Update app.py
Browse files
app.py
CHANGED
|
@@ -112,9 +112,10 @@ class RAGSystem:
|
|
| 112 |
return ""
|
| 113 |
|
| 114 |
class ImageAnalyzer:
|
| 115 |
-
def __init__(self):
|
| 116 |
-
self.device = "cpu"
|
| 117 |
self.defect_classes = ["spalling", "structural_cracks", "surface_deterioration"]
|
|
|
|
| 118 |
self._model = None
|
| 119 |
self._feature_extractor = None
|
| 120 |
|
|
@@ -127,22 +128,60 @@ class ImageAnalyzer:
|
|
| 127 |
@property
|
| 128 |
def feature_extractor(self):
|
| 129 |
if self._feature_extractor is None:
|
| 130 |
-
self._feature_extractor =
|
| 131 |
return self._feature_extractor
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
def _load_model(self):
|
| 134 |
try:
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
|
|
|
| 141 |
with torch.no_grad():
|
| 142 |
-
model
|
| 143 |
-
in_features=model.classifier.in_features
|
| 144 |
-
|
| 145 |
-
)
|
|
|
|
|
|
|
|
|
|
| 146 |
return model
|
| 147 |
except Exception as e:
|
| 148 |
logger.error(f"Model initialization error: {e}")
|
|
@@ -150,7 +189,7 @@ class ImageAnalyzer:
|
|
| 150 |
|
| 151 |
def preprocess_image(self, image_bytes):
|
| 152 |
"""Preprocess image for model input"""
|
| 153 |
-
return _cached_preprocess_image(image_bytes)
|
| 154 |
|
| 155 |
def analyze_image(self, image):
|
| 156 |
"""Analyze image for defects"""
|
|
@@ -187,20 +226,26 @@ class ImageAnalyzer:
|
|
| 187 |
return None
|
| 188 |
|
| 189 |
@st.cache_data
|
| 190 |
-
def _cached_preprocess_image(image_bytes):
|
| 191 |
"""Cached version of image preprocessing"""
|
| 192 |
try:
|
| 193 |
image = Image.open(image_bytes)
|
| 194 |
if image.mode != 'RGB':
|
| 195 |
image = image.convert('RGB')
|
| 196 |
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
image = image.resize((width, height), Image.Resampling.LANCZOS)
|
| 199 |
return image
|
| 200 |
except Exception as e:
|
| 201 |
logger.error(f"Image preprocessing error: {e}")
|
| 202 |
return None
|
| 203 |
-
|
|
|
|
| 204 |
def get_groq_response(query: str, context: str) -> str:
|
| 205 |
"""Get response from Groq LLM with caching"""
|
| 206 |
try:
|
|
|
|
| 112 |
return ""
|
| 113 |
|
| 114 |
class ImageAnalyzer:
|
| 115 |
+
def __init__(self, model_name="microsoft/swin-base-patch4-window7-224-in22k"):
|
| 116 |
+
self.device = "cpu"
|
| 117 |
self.defect_classes = ["spalling", "structural_cracks", "surface_deterioration"]
|
| 118 |
+
self.model_name = model_name
|
| 119 |
self._model = None
|
| 120 |
self._feature_extractor = None
|
| 121 |
|
|
|
|
| 128 |
@property
|
| 129 |
def feature_extractor(self):
|
| 130 |
if self._feature_extractor is None:
|
| 131 |
+
self._feature_extractor = self._load_feature_extractor()
|
| 132 |
return self._feature_extractor
|
| 133 |
|
| 134 |
+
def _load_feature_extractor(self):
|
| 135 |
+
"""Load the appropriate feature extractor based on model type"""
|
| 136 |
+
try:
|
| 137 |
+
if "swin" in self.model_name:
|
| 138 |
+
from transformers import AutoFeatureExtractor
|
| 139 |
+
return AutoFeatureExtractor.from_pretrained(self.model_name)
|
| 140 |
+
elif "convnext" in self.model_name:
|
| 141 |
+
from transformers import ConvNextFeatureExtractor
|
| 142 |
+
return ConvNextFeatureExtractor.from_pretrained(self.model_name)
|
| 143 |
+
else:
|
| 144 |
+
from transformers import ViTFeatureExtractor
|
| 145 |
+
return ViTFeatureExtractor.from_pretrained(self.model_name)
|
| 146 |
+
except Exception as e:
|
| 147 |
+
logger.error(f"Feature extractor initialization error: {e}")
|
| 148 |
+
return None
|
| 149 |
+
|
| 150 |
def _load_model(self):
|
| 151 |
try:
|
| 152 |
+
if "swin" in self.model_name:
|
| 153 |
+
from transformers import SwinForImageClassification
|
| 154 |
+
model = SwinForImageClassification.from_pretrained(
|
| 155 |
+
self.model_name,
|
| 156 |
+
num_labels=len(self.defect_classes),
|
| 157 |
+
ignore_mismatched_sizes=True
|
| 158 |
+
)
|
| 159 |
+
elif "convnext" in self.model_name:
|
| 160 |
+
from transformers import ConvNextForImageClassification
|
| 161 |
+
model = ConvNextForImageClassification.from_pretrained(
|
| 162 |
+
self.model_name,
|
| 163 |
+
num_labels=len(self.defect_classes),
|
| 164 |
+
ignore_mismatched_sizes=True
|
| 165 |
+
)
|
| 166 |
+
else:
|
| 167 |
+
from transformers import ViTForImageClassification
|
| 168 |
+
model = ViTForImageClassification.from_pretrained(
|
| 169 |
+
self.model_name,
|
| 170 |
+
num_labels=len(self.defect_classes),
|
| 171 |
+
ignore_mismatched_sizes=True
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
model = model.to(self.device)
|
| 175 |
|
| 176 |
+
# Reinitialize the classifier layer
|
| 177 |
with torch.no_grad():
|
| 178 |
+
if hasattr(model, 'classifier'):
|
| 179 |
+
in_features = model.classifier.in_features
|
| 180 |
+
model.classifier = torch.nn.Linear(in_features, len(self.defect_classes))
|
| 181 |
+
elif hasattr(model, 'head'):
|
| 182 |
+
in_features = model.head.in_features
|
| 183 |
+
model.head = torch.nn.Linear(in_features, len(self.defect_classes))
|
| 184 |
+
|
| 185 |
return model
|
| 186 |
except Exception as e:
|
| 187 |
logger.error(f"Model initialization error: {e}")
|
|
|
|
| 189 |
|
| 190 |
def preprocess_image(self, image_bytes):
|
| 191 |
"""Preprocess image for model input"""
|
| 192 |
+
return _cached_preprocess_image(image_bytes, self.model_name)
|
| 193 |
|
| 194 |
def analyze_image(self, image):
|
| 195 |
"""Analyze image for defects"""
|
|
|
|
| 226 |
return None
|
| 227 |
|
| 228 |
@st.cache_data
|
| 229 |
+
def _cached_preprocess_image(image_bytes, model_name):
|
| 230 |
"""Cached version of image preprocessing"""
|
| 231 |
try:
|
| 232 |
image = Image.open(image_bytes)
|
| 233 |
if image.mode != 'RGB':
|
| 234 |
image = image.convert('RGB')
|
| 235 |
|
| 236 |
+
# Adjust size based on model requirements
|
| 237 |
+
if "convnext" in model_name:
|
| 238 |
+
width, height = 384, 384
|
| 239 |
+
else:
|
| 240 |
+
width, height = 224, 224
|
| 241 |
+
|
| 242 |
image = image.resize((width, height), Image.Resampling.LANCZOS)
|
| 243 |
return image
|
| 244 |
except Exception as e:
|
| 245 |
logger.error(f"Image preprocessing error: {e}")
|
| 246 |
return None
|
| 247 |
+
|
| 248 |
+
@st.cache_data
|
| 249 |
def get_groq_response(query: str, context: str) -> str:
|
| 250 |
"""Get response from Groq LLM with caching"""
|
| 251 |
try:
|