File size: 5,692 Bytes
422c1f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# model.py
import torch
import torch.nn as nn
from transformers import DistilBertTokenizer, DistilBertModel
from torchvision.models import efficientnet_b0
from config import AUTHENTICITY_CLASSES, CATEGORIES

class AuctionAuthenticityModel(nn.Module):
    """Multimodal (image + text) classifier for auction-item authenticity.

    Fuses EfficientNet-B0 image features (1280-d) with multilingual
    DistilBERT CLS embeddings (768-d) through a shared MLP encoder, then
    applies two linear heads: authenticity classes and item categories.
    """

    def __init__(self, num_classes=None, device='cpu'):
        """
        Args:
            num_classes: Number of authenticity classes. Defaults to
                ``len(AUTHENTICITY_CLASSES)`` from config.
            device: Device the tokenized text is moved to in ``forward``;
                should match the device the module itself lives on.
        """
        super().__init__()
        # Resolve head sizes from config when not given explicitly.
        if num_classes is None:
            num_classes = len(AUTHENTICITY_CLASSES)
        num_categories = len(CATEGORIES)
        self.device = device

        # Vision backbone: EfficientNet-B0 with its classifier removed so it
        # emits the 1280-d pooled feature vector.
        # NOTE(review): `pretrained=` is deprecated in newer torchvision in
        # favor of `weights=`; kept for compatibility with the pinned version.
        self.vision_model = efficientnet_b0(pretrained=True)
        self.vision_model.classifier = nn.Identity()
        vision_out_dim = 1280

        # Text backbone: multilingual DistilBERT. The CLS token embedding is
        # taken as the sentence representation in forward().
        self.text_model = DistilBertModel.from_pretrained(
            'distilbert-base-multilingual-cased'
        )
        text_out_dim = 768

        self.tokenizer = DistilBertTokenizer.from_pretrained(
            'distilbert-base-multilingual-cased'
        )

        # Shared fusion encoder feeding both classification heads.
        hidden_dim = 256
        self.fusion_encoder = nn.Sequential(
            nn.Linear(vision_out_dim + text_out_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # Task-specific heads.
        self.auth_head = nn.Linear(128, num_classes)
        self.cat_head = nn.Linear(128, num_categories)

        # Stored for reference by callers.
        self.num_classes = num_classes
        self.num_categories = num_categories

    def forward(self, images, texts):
        """Run the multimodal forward pass.

        Args:
            images: Float tensor of shape (batch, 3, H, W) on self.device.
            texts: List of raw strings (length = batch); tokenized here.

        Returns:
            Dict with 'auth_logits', 'auth_probs', 'cat_logits', 'cat_probs',
            each of shape (batch, num_classes) / (batch, num_categories).
        """
        vision_features = self.vision_model(images)
        # Tokenization happens inside forward so callers pass raw strings;
        # tokens must be moved to the module's device explicitly.
        tokens = self.tokenizer(
            texts, padding=True, truncation=True, max_length=512, return_tensors='pt'
        ).to(self.device)
        text_outputs = self.text_model(**tokens)
        # CLS token (position 0) as the sentence embedding.
        text_features = text_outputs.last_hidden_state[:, 0, :]

        combined = torch.cat([vision_features, text_features], dim=1)
        shared = self.fusion_encoder(combined)

        auth_logits = self.auth_head(shared)
        cat_logits = self.cat_head(shared)

        # Softmax probabilities provided for convenience; training should use
        # the raw logits with CrossEntropyLoss (see compute_loss).
        auth_probs = torch.softmax(auth_logits, dim=1)
        cat_probs = torch.softmax(cat_logits, dim=1)

        return {
            'auth_logits': auth_logits,
            'auth_probs': auth_probs,
            'cat_logits': cat_logits,
            'cat_probs': cat_probs,
        }

    def compute_loss(self, outputs, auth_labels=None, cat_labels=None, auth_weight=1.0, cat_weight=1.0):
        """Compute the combined weighted loss for the two heads.

        Args:
            outputs: Dict from ``forward`` ('auth_logits'/'cat_logits' used).
            auth_labels: LongTensor (N,) of authenticity indices, or None.
            cat_labels: LongTensor (N,) or (N, 1) of category indices; the
                sentinel -1 marks unknown categories, which are excluded
                from the loss. None skips the category loss entirely.
            auth_weight: Scalar multiplier for the authenticity term.
            cat_weight: Scalar multiplier for the category term.

        Returns:
            (loss, losses): combined scalar loss (0.0 when no labels given)
            and a dict with individual 'auth_loss'/'cat_loss' terms.
        """
        losses = {}
        loss = 0.0
        criterion = nn.CrossEntropyLoss()

        if auth_labels is not None:
            l_auth = criterion(outputs['auth_logits'], auth_labels)
            losses['auth_loss'] = l_auth
            loss = loss + auth_weight * l_auth

        if cat_labels is not None:
            # Flatten so targets are 1-D as CrossEntropyLoss requires.
            # BUGFIX: the previous code masked with a squeezed view but
            # indexed the unsqueezed (N, 1) labels, producing (M, 1)
            # targets that CrossEntropyLoss rejects at runtime.
            flat_labels = cat_labels.reshape(-1)
            mask = flat_labels >= 0  # sentinel -1 == unknown, ignored

            if mask.any():
                l_cat = criterion(outputs['cat_logits'][mask], flat_labels[mask])
                losses['cat_loss'] = l_cat
                loss = loss + cat_weight * l_cat
            else:
                # No valid category labels in this batch; report a zero on
                # the logits' device (self.device can desync after .to()).
                losses['cat_loss'] = torch.tensor(
                    0.0, device=outputs['cat_logits'].device
                )

        return loss, losses

    def count_parameters(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


if __name__ == '__main__':
    # Smoke test: build the model on CPU, run a dummy forward pass, and
    # report parameter count, output shapes, predictions, and model size.
    print("Testowanie modelu...")

    device = torch.device('cpu')
    model = AuctionAuthenticityModel(device=device).to(device)

    print(f"✓ Model stworzony")
    print(f"  - Parametrów: {model.count_parameters():,}")

    # Dummy batch: two 224x224 RGB images and two listing descriptions.
    dummy_img = torch.randn(2, 3, 224, 224).to(device)
    dummy_texts = ["Silver spoon antique", "Polish silverware 19th century"]

    with torch.no_grad():
        output = model(dummy_img, dummy_texts)

    # Print output shapes.
    print("✓ Forward pass:")
    print(f"  - auth_logits: {output['auth_logits'].shape}")
    print(f"  - auth_probs: {output['auth_probs'].shape}")
    print(f"  - cat_logits: {output['cat_logits'].shape}")
    print(f"  - cat_probs: {output['cat_probs'].shape}")

    # Show predicted labels and their top probabilities.
    auth_pred = torch.argmax(output['auth_probs'], dim=1)
    cat_pred = torch.argmax(output['cat_probs'], dim=1)

    for i in range(output['auth_probs'].shape[0]):
        a_idx = int(auth_pred[i].item())
        a_prob = float(output['auth_probs'][i, a_idx].item())
        c_idx = int(cat_pred[i].item())
        c_prob = float(output['cat_probs'][i, c_idx].item())
        # NOTE(review): .get() implies the config maps are dicts keyed by
        # class index — confirm against config.py.
        a_name = AUTHENTICITY_CLASSES.get(a_idx, str(a_idx))
        c_name = CATEGORIES.get(c_idx, str(c_idx))
        print(f"\nSample {i}:")
        print(f"  - Authenticity: {a_name} ({a_prob:.3f})")
        print(f"  - Category: {c_name} ({c_prob:.3f})")

    # Estimate serialized model size by saving to an in-memory buffer.
    # BUGFIX: the previous version wrote 'temp_model.pt' to the CWD and
    # removed it afterwards — an exception in between leaked the file,
    # and it could clobber an existing file of that name.
    import io
    print(f"\n📊 Rozmiar modelu:")
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    size_mb = buffer.getbuffer().nbytes / (1024 * 1024)
    print(f"  - {size_mb:.1f} MB")