enc_hiddens_proj = self.att_projection(enc_hiddens) # precompute the attention projection once for all decoder steps; shape (b, src_len, h)
Y = self.model_embeddings.target(target_padded) # embed the target tokens; shape (tgt_len, b, e)
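# Assumed initialization, not shown in this excerpt: o_prev starts as zeros and
# combined_outputs collects one combined output per decoder time step. A minimal
# sketch assuming `self.hidden_size` holds the decoder hidden size.
batch_size = enc_hiddens.size(0)
o_prev = torch.zeros(batch_size, self.hidden_size, device=enc_hiddens.device) # (b, h)
combined_outputs = []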
for Y_t in torch.split(Y, 1, dim=0): # step through the target sequence one time step at a time; each slice has shape (1, b, e)
    Y_t = Y_t.squeeze(0) # drop the time dimension: (b, e)
    Ybar_t = torch.cat((Y_t, o_prev), dim=1) # concatenate the current embedding with the previous combined output; shape (b, e+h)
    dec_state, o_t, e_t = self.step(Ybar_t, dec_state, enc_hiddens, enc_hiddens_proj, enc_masks) # new (hidden, cell) state, combined output o_t, attention scores e_t
    combined_outputs.append(o_t) # shape (b, h)
    o_prev = o_t # feed this step's combined output into the next step
combined_outputs = torch.stack(combined_outputs, dim=0) # (tgt_len, b, h)
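# A minimal sketch of what `step` computes per time step, assuming Luong-style
# multiplicative attention over a bidirectional encoder (enc_hiddens: (b, src_len, 2h),
# enc_masks marking pad positions with 1). The submodule names `decoder`,
# `combined_output_projection`, and `dropout` are assumptions; step's body is not
# part of this excerpt.
def step(self, Ybar_t, dec_state, enc_hiddens, enc_hiddens_proj, enc_masks):
    dec_state = self.decoder(Ybar_t, dec_state) # LSTMCell update: (h_t, c_t), each (b, h)
    dec_hidden, dec_cell = dec_state
    # attention scores: dot product of each projected encoder state with h_t; (b, src_len)
    e_t = torch.bmm(enc_hiddens_proj, dec_hidden.unsqueeze(2)).squeeze(2)
    if enc_masks is not None:
        e_t = e_t.masked_fill(enc_masks.bool(), -float('inf')) # ignore pad positions
    alpha_t = torch.softmax(e_t, dim=1) # attention distribution; (b, src_len)
    a_t = torch.bmm(alpha_t.unsqueeze(1), enc_hiddens).squeeze(1) # context vector; (b, 2h)
    U_t = torch.cat((dec_hidden, a_t), dim=1) # (b, 3h)
    V_t = self.combined_output_projection(U_t) # (b, h)
    O_t = self.dropout(torch.tanh(V_t)) # combined output o_t; (b, h)
    return dec_state, O_t, e_t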