| | from functools import lru_cache |
| |
|
| | import torch |
| | from loguru import logger |
| | from sentence_transformers import SentenceTransformer |
| | from transformers import AutoTokenizer, AutoModel |
| |
|
# Run on the GPU when torch reports a usable CUDA device, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Pretrained model identifiers kept for reference; nothing in the visible part
# of this module iterates over the list.
list_models = [
    # sentence-transformers checkpoints
    'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
    'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
    'sentence-transformers/all-mpnet-base-v2',
    'sentence-transformers/all-MiniLM-L12-v2',
    # other checkpoints
    'cyclone/simcse-chinese-roberta-wwm-ext',
    'bert-base-chinese',
    'IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese',
]
| |
|
| |
|
class SBert:
    """Callable wrapper around a ``SentenceTransformer`` that memoizes embeddings.

    Calling an instance encodes ``x`` with the wrapped model
    (``convert_to_tensor=True``) and returns the result.  Results for hashable
    inputs are kept in an LRU cache that belongs to the instance, so dropping
    the instance also releases its cache and the model it holds.
    """

    # Maximum number of distinct inputs remembered per instance.
    _CACHE_SIZE = 10000

    def __init__(self, path):
        """Load the sentence-transformer found at ``path`` onto ``DEVICE``.

        Args:
            path: Model name or path handed to ``SentenceTransformer``.
        """
        logger.info(f'Start loading {self.__class__} from {path} ...')
        self.model = SentenceTransformer(path, device=DEVICE)
        logger.info(f'Load {self.__class__} from {path} ...')

    def _encode(self, x) -> torch.Tensor:
        """Encode ``x`` with the wrapped model, bypassing the cache."""
        return self.model.encode(x, convert_to_tensor=True)

    def __call__(self, x) -> torch.Tensor:
        """Return the embedding of ``x``, reusing a cached result when possible.

        Args:
            x: Input forwarded to ``self.model.encode``.  Hashable inputs
                (e.g. a single string) are cached; unhashable inputs (e.g. a
                list of sentences) are encoded directly on every call.

        Returns:
            Whatever ``encode(..., convert_to_tensor=True)`` returns.  Cached
            results are returned by reference, so callers should not modify
            them in place.
        """
        try:
            hash(x)
        except TypeError:
            # lru_cache cannot key on unhashable values; encode without caching
            # instead of failing before the model is even consulted.
            return self._encode(x)

        # The cache is created lazily and stored on the instance rather than on
        # the class: a class-level lru_cache keys on ``self`` and would keep
        # every instance (and its loaded model) alive.
        try:
            cached_encode = self._cached_encode
        except AttributeError:
            cached_encode = lru_cache(maxsize=self._CACHE_SIZE)(self._encode)
            self._cached_encode = cached_encode
        return cached_encode(x)
| |
|
| |
|
class ModelWithPooling:
    """Callable wrapper that embeds text with a tokenizer/model pair plus pooling.

    Calling an instance tokenizes ``text``, runs the model without gradient
    tracking and reduces the hidden states with the requested pooling
    strategy.  Results are memoized per instance (up to 100 entries).
    """

    # Maximum number of (text, pooling) pairs remembered per instance.
    _CACHE_SIZE = 100
    # Pooling strategies accepted by ``__call__``.
    _POOLINGS = ('cls', 'pooler', 'mean', 'last-avg', 'first-last-avg')

    def __init__(self, path):
        """Load the tokenizer and model found at ``path``.

        Args:
            path: Model name or path handed to ``AutoTokenizer`` and
                ``AutoModel``.
        """
        logger.info(f'Start loading {self.__class__} from {path} ...')
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # NOTE(review): the model is not moved to DEVICE here, unlike SBert,
        # which passes device=DEVICE -- confirm whether that is intended.
        self.model = AutoModel.from_pretrained(path)
        logger.info(f'Load {self.__class__} from {path} ...')

    @torch.no_grad()
    def _embed(self, text, pooling):
        """Compute the pooled embedding of ``text`` without using the cache."""
        inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
        outputs = self.model(**inputs, output_hidden_states=True)

        # Hidden states are treated as (batch, tokens, hidden): dim 1 is the
        # token axis that gets indexed or averaged away.
        if pooling == 'cls':
            pooled = outputs.last_hidden_state[:, 0]
        elif pooling == 'pooler':
            pooled = outputs.pooler_output
        elif pooling in ('mean', 'last-avg'):
            # Plain mean over all token positions (no attention-mask weighting).
            pooled = outputs.last_hidden_state.mean(dim=1)
        elif pooling == 'first-last-avg':
            # Index 1 is presumably the first transformer layer (index 0 being
            # the embedding output) -- confirm against the model in use.
            first_avg = outputs.hidden_states[1].mean(dim=1)
            last_avg = outputs.hidden_states[-1].mean(dim=1)
            pooled = (first_avg + last_avg) / 2
        else:
            raise ValueError(f'Unknown pooling {pooling}')

        # Drop the batch axis when it has size one (single input text).
        return pooled.squeeze(0)

    def __call__(self, text: str, pooling='mean'):
        """Return the pooled embedding of ``text``.

        Args:
            text: Text to embed; must be hashable because it is used as a
                cache key.
            pooling: One of ``'cls'`` (first token of the last layer),
                ``'pooler'`` (the model's ``pooler_output``), ``'mean'`` /
                ``'last-avg'`` (token mean of the last layer) or
                ``'first-last-avg'`` (average of the token means of
                ``hidden_states[1]`` and ``hidden_states[-1]``).

        Returns:
            The pooled tensor with a leading size-one batch axis removed.
            Cached results are returned by reference, so callers should not
            modify them in place.

        Raises:
            ValueError: If ``pooling`` is not a known strategy.  The check
                happens before tokenizing or running the model.
        """
        if pooling not in self._POOLINGS:
            raise ValueError(f'Unknown pooling {pooling}')

        # The cache is created lazily and stored on the instance rather than on
        # the class: a class-level lru_cache keys on ``self`` and would keep
        # every instance (and its loaded model) alive.
        try:
            cached_embed = self._cached_embed
        except AttributeError:
            cached_embed = lru_cache(maxsize=self._CACHE_SIZE)(self._embed)
            self._cached_embed = cached_embed
        # Always pass both arguments positionally so m(t), m(t, 'mean') and
        # m(t, pooling='mean') share one cache entry.
        return cached_embed(text, pooling)
| |
|
| |
|
def test_sbert():
    """Smoke test: ``SBert`` embeds 'hello' into a 768-element vector.

    Builds ``SBert`` from 'bert-base-chinese', so that model must be loadable.
    """
    encoder = SBert('bert-base-chinese')
    embedding = encoder('hello')
    shape = embedding.size()
    print(shape)
    assert shape == (768,)
| |
|
| |
|
def test_hf_model():
    """Smoke test: ``ModelWithPooling`` with 'cls' pooling yields a 768-element vector.

    Builds the wrapper from 'IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese', so that
    model must be loadable.
    """
    embedder = ModelWithPooling('IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese')
    embedding = embedder('hello', pooling='cls')
    shape = embedding.size()
    print(shape)
    assert shape == (768,)
| |
|