| | from transformers import PreTrainedModel, PretrainedConfig |
| | from sentence_transformers import SentenceTransformer |
| | import torch |
| | import torch.nn as nn |
| | import numpy as np |
| |
|
| |
|
class ZeroShotEmbeddingConfig(PretrainedConfig):
    """Configuration shared by the zero-shot embedding head and its clustering wrapper.

    Attributes:
        input_size: Width of a single base embedding. The head concatenates a
            text embedding with a prompt embedding, so its first linear layer
            takes ``2 * input_size`` features.
        hidden_size: Width of the head's hidden layer.
        output_size: Width of the embedding produced by the head.
        base_embedding_model: Model name or path handed to
            ``SentenceTransformer`` by the clustering wrapper.
    """

    model_type = "embedding-head"

    def __init__(
        self,
        input_size=768,
        hidden_size=2048,
        output_size=128,
        base_embedding_model='all-mpnet-base-v2',
        **kwargs,
    ):
        # The head-specific fields are stored first; everything else is
        # forwarded untouched to the PretrainedConfig constructor.
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.base_embedding_model = base_embedding_model
        super().__init__(**kwargs)
| |
|
| |
|
class ZeroShotEmbedding(PreTrainedModel):
    """Two-layer head turning (text, prompt) embedding pairs into prompt-conditioned embeddings.

    Each text embedding is concatenated with the prompt embedding along the
    feature axis, passed through ``fc1 -> GELU -> fc2`` and L2-normalised.
    """

    config_class = ZeroShotEmbeddingConfig

    def __init__(self, config):
        """Build the projection layers from ``config``.

        Args:
            config: A ``ZeroShotEmbeddingConfig`` providing ``input_size``,
                ``hidden_size`` and ``output_size``.
        """
        super().__init__(config)

        self.input_size = config.input_size
        self.hidden_size = config.hidden_size
        self.output_size = config.output_size

        # fc1 sees a text embedding and a prompt embedding side by side,
        # hence twice the width of a single embedding.
        self.fc1 = nn.Linear(self.input_size * 2, self.hidden_size)
        self.fc2 = nn.Linear(self.hidden_size, self.output_size)
        self.gelu = nn.GELU()

    def _project(self, text_embedding, prompt_embedding):
        """Return the L2-normalised head output for one (text, prompt) batch."""
        joined = torch.cat((text_embedding, prompt_embedding), dim=1)
        projected = self.fc2(self.gelu(self.fc1(joined)))
        return nn.functional.normalize(projected, p=2, dim=1)

    def forward(self, prompt_embedding, text_a_embedding, text_b_embedding=None, labels=None, **kwargs):
        """Embed one text batch, or score pairs of text batches under a prompt.

        Args:
            prompt_embedding: Tensor of shape ``(batch, input_size)``.
            text_a_embedding: Tensor of shape ``(batch, input_size)``.
            text_b_embedding: Optional second text batch of the same shape.
            labels: Optional target similarities; only used when
                ``text_b_embedding`` is given. Expected to broadcast against a
                ``(batch,)`` tensor (presumably shape ``(batch,)`` itself).
            **kwargs: Accepted and ignored.

        Returns:
            * ``text_b_embedding`` is ``None``: the normalised embeddings of
              ``text_a_embedding``, shape ``(batch, output_size)``.
            * ``text_b_embedding`` given, ``labels`` is ``None``: the row-wise
              dot product of the two normalised embeddings (their cosine
              similarity), shape ``(batch,)``.
            * both given: ``(loss, dot_product)`` where ``loss`` is the mean
              squared error between ``dot_product`` and ``labels``.
        """
        x = self._project(text_a_embedding, prompt_embedding)
        if text_b_embedding is None:
            return x

        x2 = self._project(text_b_embedding, prompt_embedding)

        # bmm yields (batch, 1, 1). Only the two singleton matrix dims are
        # squeezed: a bare .squeeze() would also drop the batch axis when
        # batch == 1 and hand back a 0-d scalar instead of shape (1,).
        dot_product = torch.bmm(x.unsqueeze(1), x2.unsqueeze(2)).squeeze(2).squeeze(1)
        if labels is None:
            return dot_product

        loss = torch.mean((dot_product - labels) ** 2)
        return loss, dot_product
| |
|
| |
|
class ZeroShotEmbeddingForClustering(PreTrainedModel):
    """Sentence encoder plus ``ZeroShotEmbedding`` head producing a pairwise similarity matrix."""

    config_class = ZeroShotEmbeddingConfig

    def __init__(self, config):
        """Create the base sentence encoder and a freshly initialised head.

        Args:
            config: A ``ZeroShotEmbeddingConfig``; ``config.base_embedding_model``
                is passed to ``SentenceTransformer``.
        """
        super().__init__(config)
        self.base_embedding_model = SentenceTransformer(
            config.base_embedding_model)
        self.head_model = ZeroShotEmbedding(config)

    def forward(self, texts, prompt, **kwargs):
        """Return the pairwise similarity of ``texts`` as seen through ``prompt``.

        Args:
            texts: A sequence of strings to compare. A single string is
                treated as a one-element list.
            prompt: One prompt string applied to every text.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape ``(len(texts), len(texts))`` holding the dot
            products of the head's normalised embeddings.
        """
        # A bare string would make len(texts) count characters and yield a
        # 1-D embedding that cannot be concatenated along dim=1.
        if isinstance(texts, str):
            texts = [texts]

        text_embeddings = self.base_embedding_model.encode(texts)
        prompt_embedding = self.base_embedding_model.encode(prompt)
        # Repeat the single prompt vector once per text.
        prompt_embeddings = np.tile(prompt_embedding, (len(texts), 1))

        # Build the tensors where the head's weights live so the model also
        # works after being moved off the CPU.
        device = next(self.head_model.parameters()).device
        text_embeddings = torch.tensor(text_embeddings, device=device)
        prompt_embeddings = torch.tensor(prompt_embeddings, device=device)

        prompted_embeddings = self.head_model(
            prompt_embeddings, text_embeddings)
        similarity = torch.mm(prompted_embeddings,
                              prompted_embeddings.transpose(0, 1))
        return similarity

    @classmethod
    def from_pretrained_base(cls, pretrained_model_name_or_path):
        """Build a clustering model around a pretrained ``ZeroShotEmbedding`` head.

        Args:
            pretrained_model_name_or_path: Location passed to
                ``ZeroShotEmbedding.from_pretrained``.

        Returns:
            A new instance whose ``head_model`` is the loaded head and whose
            base encoder is chosen by the head's config.
        """
        head_model = ZeroShotEmbedding.from_pretrained(
            pretrained_model_name_or_path)
        model = cls(head_model.config)
        # Assign on the instance, not the class: a class attribute would be
        # shared by every instance and would leave the randomly initialised
        # head registered as this instance's submodule.
        model.head_model = head_model
        return model
| |
|
| |
|
# Register the custom config and both model classes with the transformers
# auto-class machinery.
ZeroShotEmbeddingConfig.register_for_auto_class()
for _model_cls in (ZeroShotEmbedding, ZeroShotEmbeddingForClustering):
    _model_cls.register_for_auto_class("AutoModel")
del _model_cls
| |
|