| | import torch |
| | import torch.utils.data |
| |
|
| | from seq2struct.models import abstract_preproc |
| | from seq2struct.utils import registry |
| |
|
class ZippedDataset(torch.utils.data.Dataset):
    """Dataset yielding tuples of aligned items from several equal-length datasets.

    Indexing with ``i`` returns ``(components[0][i], components[1][i], ...)``,
    analogous to ``zip`` over the component datasets.
    """

    def __init__(self, *components):
        # At least one component is required, and all must report the same length.
        assert len(components) >= 1
        sizes = [len(part) for part in components]
        assert all(
            size == sizes[0] for size in sizes), "Lengths don't match: {}".format(sizes)
        self.components = components

    def __getitem__(self, idx):
        # One aligned item from every component, packed as a tuple.
        return tuple(part[idx] for part in self.components)

    def __len__(self):
        # All components share a length; report the first one's.
        return len(self.components[0])
| |
|
| |
|
@registry.register('model', 'EncDec')
class EncDecModel(torch.nn.Module):
    """Encoder-decoder model whose encoder and decoder are built via the registry.

    The encoder may process whole batches (``encoder.batched`` truthy) or one
    item at a time; the appropriate loss routine is selected at construction.
    """

    class Preproc(abstract_preproc.AbstractPreproc):
        """Composite preprocessor delegating to the encoder's and decoder's Preproc.

        Args:
            encoder: config dict for the encoder; ``encoder['name']`` selects
                the registered encoder class.
            decoder: config dict for the decoder; ``decoder['name']`` selects
                the registered decoder class.
            encoder_preproc: kwargs for the encoder's ``Preproc`` constructor.
            decoder_preproc: kwargs for the decoder's ``Preproc`` constructor.
        """

        def __init__(
                self,
                encoder,
                decoder,
                encoder_preproc,
                decoder_preproc):
            super().__init__()

            self.enc_preproc = registry.lookup('encoder', encoder['name']).Preproc(**encoder_preproc)
            self.dec_preproc = registry.lookup('decoder', decoder['name']).Preproc(**decoder_preproc)

        def validate_item(self, item, section):
            """Return (ok, info) where ok requires BOTH sub-preprocs to accept the item."""
            enc_result, enc_info = self.enc_preproc.validate_item(item, section)
            dec_result, dec_info = self.dec_preproc.validate_item(item, section)

            return enc_result and dec_result, (enc_info, dec_info)

        def add_item(self, item, section, validation_info):
            """Forward the item to both sub-preprocs, unpacking their validation infos."""
            enc_info, dec_info = validation_info
            self.enc_preproc.add_item(item, section, enc_info)
            self.dec_preproc.add_item(item, section, dec_info)

        def clear_items(self):
            self.enc_preproc.clear_items()
            self.dec_preproc.clear_items()

        def save(self):
            self.enc_preproc.save()
            self.dec_preproc.save()

        def load(self):
            self.enc_preproc.load()
            self.dec_preproc.load()

        def dataset(self, section):
            """Zip the encoder-side and decoder-side datasets into (enc, dec) pairs."""
            return ZippedDataset(self.enc_preproc.dataset(section), self.dec_preproc.dataset(section))

    def __init__(self, preproc, device, encoder, decoder):
        super().__init__()
        self.preproc = preproc
        self.encoder = registry.construct(
            'encoder', encoder, device=device, preproc=preproc.enc_preproc)
        self.decoder = registry.construct(
            'decoder', decoder, device=device, preproc=preproc.dec_preproc)
        self.decoder.visualize_flag = False

        # BUGFIX: two-argument getattr is identical to plain attribute access
        # and raised AttributeError for encoders that never define `batched`.
        # Default to False (treat such encoders as unbatched); encoders that do
        # set the attribute behave exactly as before.
        if getattr(self.encoder, 'batched', False):
            self.compute_loss = self._compute_loss_enc_batched
        else:
            self.compute_loss = self._compute_loss_unbatched

    def _compute_loss_enc_batched(self, batch, debug=False):
        """Encode the whole batch at once, then decode item by item.

        Returns the per-item loss list when ``debug`` is True, else the mean loss.
        """
        losses = []
        d = [enc_input for enc_input, dec_output in batch]
        enc_states = self.encoder(d)

        for enc_state, (enc_input, dec_output) in zip(enc_states, batch):
            loss = self.decoder.compute_loss(enc_input, dec_output, enc_state, debug)
            losses.append(loss)
        if debug:
            return losses
        else:
            return torch.mean(torch.stack(losses, dim=0), dim=0)

    def _compute_loss_enc_batched2(self, batch, debug=False):
        """Variant of the batched loss that encodes one item per encoder call.

        NOTE(review): not selected by __init__; appears to be kept for
        debugging/comparison against _compute_loss_enc_batched.
        """
        losses = []
        for enc_input, dec_output in batch:
            # Single-element batch; the trailing comma unpacks the lone state.
            enc_state, = self.encoder([enc_input])
            loss = self.decoder.compute_loss(enc_input, dec_output, enc_state, debug)
            losses.append(loss)
        if debug:
            return losses
        else:
            return torch.mean(torch.stack(losses, dim=0), dim=0)

    def _compute_loss_unbatched(self, batch, debug=False):
        """Loss for encoders that take one item at a time (no list wrapper)."""
        losses = []
        for enc_input, dec_output in batch:
            enc_state = self.encoder(enc_input)
            loss = self.decoder.compute_loss(enc_input, dec_output, enc_state, debug)
            losses.append(loss)
        if debug:
            return losses
        else:
            return torch.mean(torch.stack(losses, dim=0), dim=0)

    def eval_on_batch(self, batch):
        """Return summed loss and item count for evaluation-time aggregation."""
        mean_loss = self.compute_loss(batch).item()
        batch_size = len(batch)
        # Report the loss *sum* so callers can aggregate across batches exactly.
        result = {'loss': mean_loss * batch_size, 'total': batch_size}
        return result

    def begin_inference(self, orig_item, preproc_item):
        """Encode one preprocessed item and start the decoder's inference.

        ``preproc_item`` is an (enc_input, dec_output) pair as produced by
        Preproc.dataset; only the encoder side is used here.
        """
        enc_input, _ = preproc_item
        if self.decoder.visualize_flag:
            # enc_input is a mapping with at least these keys (see usage below).
            print('question:')
            print(enc_input['question'])
            print('columns:')
            print(enc_input['columns'])
            print('tables:')
            print(enc_input['tables'])
        # Same fix as in __init__: default missing `batched` to False.
        if getattr(self.encoder, 'batched', False):
            enc_state, = self.encoder([enc_input])
        else:
            enc_state = self.encoder(enc_input)
        return self.decoder.begin_inference(enc_state, orig_item)
| |
|