# model.py
import torch
import torch.nn as nn
from transformers import DistilBertTokenizer, DistilBertModel
from torchvision.models import efficientnet_b0
from config import AUTHENTICITY_CLASSES, CATEGORIES
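
# Note: config.py is not shown here. Based on the integer-keyed .get() lookups
# in the test block below, AUTHENTICITY_CLASSES and CATEGORIES are assumed to be
# dicts mapping class index -> label name, for example (values illustrative only):
#   AUTHENTICITY_CLASSES = {0: 'authentic', 1: 'suspicious', 2: 'counterfeit'}
#   CATEGORIES = {0: 'silverware', 1: 'jewelry', 2: 'paintings'}
# The actual contents live in config.py and may differ.
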
class AuctionAuthenticityModel(nn.Module):
    def __init__(self, num_classes=None, device='cpu'):
        # If num_classes is not specified, use config
        if num_classes is None:
            num_classes = len(AUTHENTICITY_CLASSES)
        # Category classes (separate head)
        num_categories = len(CATEGORIES)

        super().__init__()
        self.device = device

        # Vision
        self.vision_model = efficientnet_b0(pretrained=True)
        self.vision_model.classifier = nn.Identity()
        vision_out_dim = 1280

        # Text
        self.text_model = DistilBertModel.from_pretrained(
            'distilbert-base-multilingual-cased'
        )
        text_out_dim = 768
        self.tokenizer = DistilBertTokenizer.from_pretrained(
            'distilbert-base-multilingual-cased'
        )

        # Fusion encoder (shared) -> then two heads (authenticity + category)
        hidden_dim = 256
        self.fusion_encoder = nn.Sequential(
            nn.Linear(vision_out_dim + text_out_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # Heads
        self.auth_head = nn.Linear(128, num_classes)
        self.cat_head = nn.Linear(128, num_categories)

        # Store sizes for reference
        self.num_classes = num_classes
        self.num_categories = num_categories
    def forward(self, images, texts):
        # Vision features from the EfficientNet backbone
        vision_features = self.vision_model(images)

        # Tokenize raw strings on the fly and move the tensors to the model device
        tokens = self.tokenizer(
            texts, padding=True, truncation=True, max_length=512, return_tensors='pt'
        ).to(self.device)
        text_outputs = self.text_model(**tokens)
        # Use the [CLS] token embedding as the text representation
        text_features = text_outputs.last_hidden_state[:, 0, :]

        # Fuse both modalities and run the two heads
        combined = torch.cat([vision_features, text_features], dim=1)
        shared = self.fusion_encoder(combined)
        auth_logits = self.auth_head(shared)
        cat_logits = self.cat_head(shared)

        # Probabilities
        auth_probs = torch.softmax(auth_logits, dim=1)
        cat_probs = torch.softmax(cat_logits, dim=1)

        return {
            'auth_logits': auth_logits,
            'auth_probs': auth_probs,
            'cat_logits': cat_logits,
            'cat_probs': cat_probs,
        }
    def compute_loss(self, outputs, auth_labels=None, cat_labels=None, auth_weight=1.0, cat_weight=1.0):
        """Compute the combined loss for the two heads.

        Labels should be LongTensors on the same device as the model.
        Returns the combined scalar loss and a dict with the individual losses.
        """
        losses = {}
        loss = 0.0
        criterion = nn.CrossEntropyLoss()

        if auth_labels is not None:
            l_auth = criterion(outputs['auth_logits'], auth_labels)
            losses['auth_loss'] = l_auth
            loss = loss + auth_weight * l_auth

        if cat_labels is not None:
            # Allow sentinel -1 for unknown/uncertain categories and ignore them
            if cat_labels.dim() == 1:
                mask = cat_labels >= 0
            else:
                mask = (cat_labels.squeeze(-1) >= 0)
            if mask.sum().item() > 0:
                selected_logits = outputs['cat_logits'][mask]
                selected_labels = cat_labels[mask]
                l_cat = criterion(selected_logits, selected_labels)
                losses['cat_loss'] = l_cat
                loss = loss + cat_weight * l_cat
            else:
                # No valid category labels in this batch
                losses['cat_loss'] = torch.tensor(0.0, device=self.device)

        return loss, losses
    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

if __name__ == '__main__':
    print("Testing the model...")
    device = torch.device('cpu')
    model = AuctionAuthenticityModel(device=device).to(device)
    print("✓ Model created")
    print(f" - Parameters: {model.count_parameters():,}")
    # Dummy test
    dummy_img = torch.randn(2, 3, 224, 224).to(device)
    dummy_texts = ["Silver spoon antique", "Polish silverware 19th century"]
    with torch.no_grad():
        output = model(dummy_img, dummy_texts)

    # Print shapes
    print("✓ Forward pass:")
    print(f" - auth_logits: {output['auth_logits'].shape}")
    print(f" - auth_probs: {output['auth_probs'].shape}")
    print(f" - cat_logits: {output['cat_logits'].shape}")
    print(f" - cat_probs: {output['cat_probs'].shape}")

    # Show predicted labels and top probabilities
    auth_pred = torch.argmax(output['auth_probs'], dim=1)
    cat_pred = torch.argmax(output['cat_probs'], dim=1)
    for i in range(output['auth_probs'].shape[0]):
        a_idx = int(auth_pred[i].item())
        a_prob = float(output['auth_probs'][i, a_idx].item())
        c_idx = int(cat_pred[i].item())
        c_prob = float(output['cat_probs'][i, c_idx].item())
        a_name = AUTHENTICITY_CLASSES.get(a_idx, str(a_idx))
        c_name = CATEGORIES.get(c_idx, str(c_idx))
        print(f"\nSample {i}:")
        print(f" - Authenticity: {a_name} ({a_prob:.3f})")
        print(f" - Category: {c_name} ({c_prob:.3f})")
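
    # Illustrative sketch: exercise compute_loss with dummy labels. It assumes
    # class index 0 is valid for both heads; -1 marks an unknown category that
    # the loss masks out.
    dummy_auth_labels = torch.zeros(2, dtype=torch.long, device=device)
    dummy_cat_labels = torch.tensor([0, -1], dtype=torch.long, device=device)
    loss, loss_parts = model.compute_loss(output, dummy_auth_labels, dummy_cat_labels)
    print(f"\n✓ compute_loss: total={loss.item():.4f}, "
          f"auth={loss_parts['auth_loss'].item():.4f}, "
          f"cat={loss_parts['cat_loss'].item():.4f}")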
    # Estimate model size on disk
    print("\n📊 Model size:")
    import os
    torch.save(model.state_dict(), 'temp_model.pt')
    size_mb = os.path.getsize('temp_model.pt') / (1024 * 1024)
    print(f" - {size_mb:.1f} MB")
    os.remove('temp_model.pt')