import math
import operator
# Word tables for spelling out integers. Splitting a string that starts with a
# space yields '' first, so thousands_names[k] names the 1000**k group and the
# units group (k == 0) gets no scale word.
thousands_names = ' thousand million billion trillion quadrillion'.split(' ')
numeral_names = 'zero one two three four five six seven eight nine'.split(' ')
tens_names = 'zero ten twenty thirty forty fifty sixty seventy eighty ninety'.split(' ')
teens_names = 'ten eleven twelve thirteen fourteen fifteen sixteen seventeen eighteen nineteen'.split(' ')


# Convert integers to English words (short scale, no "and", no hyphens:
# 342 -> 'three hundred forty two').
def _group_to_words(value):
    """Spell out one three-digit group, 1 <= value <= 999."""
    hundred, rest = divmod(value, 100)
    ten, one = divmod(rest, 10)
    words = []
    if hundred:
        words += [numeral_names[hundred], 'hundred']
    if ten == 1:
        # 10..19 have their own names.
        words.append(teens_names[one])
    else:
        if ten:
            words.append(tens_names[ten])
        if one:
            words.append(numeral_names[one])
    return ' '.join(words)


def number_to_word(num):
    """Spell out an integer in English words.

    Examples: 0 -> 'zero', 1000001 -> 'one million one',
    -2512 -> 'negative two thousand five hundred twelve'.

    Args:
        num: an integer-like value accepted by ``operator.index`` (Python int,
            NumPy integer, single-element integer torch tensor, ...).

    Returns:
        The spelled-out number; negative numbers are prefixed with 'negative '.

    Raises:
        TypeError: if ``num`` is not integer-like (e.g. a float).
        ValueError: if ``abs(num)`` needs a scale word beyond the last entry
            of ``thousands_names``.
    """
    num = operator.index(num)
    if num == 0:
        return 'zero'
    prefix = 'negative ' if num < 0 else ''
    num = abs(num)
    if num >= 1000 ** len(thousands_names):
        raise ValueError(
            f'{num} is too large: no scale word above {thousands_names[-1]!r}'
        )
    # Peel off three-digit groups from the least significant end; zero groups
    # (e.g. the thousands in 1,000,001) contribute no words.
    groups = []
    for scale_name in thousands_names:
        if not num:
            break
        num, value = divmod(num, 1000)
        if value:
            part = _group_to_words(value)
            if scale_name:
                part += ' ' + scale_name
            groups.append(part)
    return prefix + ' '.join(reversed(groups))


import transformers, torch

class Model(transformers.PerceiverPreTrainedModel):
    """Perceiver encoder/decoder mapping a token sequence to per-position vocabulary logits.

    The structure follows transformers' ``PerceiverForMaskedLM``:
    - a ``PerceiverTextPreprocessor`` embeds the inputs;
    - a ``PerceiverBasicDecoder`` queries ``config.max_position_embeddings``
      output positions;
    - a ``PerceiverEmbeddingDecoder`` projects the decoder output onto the
      input embedding matrix to produce logits.

    ``config`` must carry a ``num_decoder_heads`` attribute (the decoder's
    attention head count) in addition to the usual PerceiverConfig fields.
    """

    def __init__(self, config):
        super().__init__(config)
        modeling_perceiver = transformers.models.perceiver.modeling_perceiver
        self.input_preprocessor = modeling_perceiver.PerceiverTextPreprocessor(config)
        self.decoder = modeling_perceiver.PerceiverBasicDecoder(
            config,
            output_num_channels=config.d_latents,
            # The decoder always emits this many output positions.
            output_index_dims=config.max_position_embeddings,
            num_channels=config.d_model,
            qk_channels=config.qk_channels,
            v_channels=config.d_model,
            num_heads=config.num_decoder_heads,
            use_query_residual=False,
            final_project=False,
            trainable_position_encoding_kwargs=dict(
                num_channels=self.input_preprocessor.num_channels,
                index_dims=config.max_position_embeddings,
            ),
        )
        # The same preprocessor/decoder modules are also registered inside
        # self.perceiver; parameters() yields each shared parameter once.
        self.perceiver = transformers.PerceiverModel(
            config,
            decoder=self.decoder,
            input_preprocessor=self.input_preprocessor,
        )
        self.embedding_decoder = modeling_perceiver.PerceiverEmbeddingDecoder(config)

        self.post_init()

    def forward(self, inputs=None, attention_mask=None, head_mask=None, output_attentions=None, output_hidden_states=None, labels=None):
        """Run the Perceiver and optionally compute a cross-entropy loss.

        Args:
            inputs: token ids for the text preprocessor. Presumably a
                (batch, seq_len) integer tensor; TODO confirm.
            attention_mask, head_mask, output_attentions, output_hidden_states:
                forwarded unchanged to ``transformers.PerceiverModel``.
            labels: optional target ids. They are flattened and compared with
                the logits flattened to ``(-1, config.vocab_size)``, so they
                need one entry per decoder output position.

        Returns:
            ``(loss, logits, *extra)`` when ``labels`` is given, otherwise
            ``(logits, *extra)``. ``extra`` holds the PerceiverModel tuple
            outputs after its first element.
        """
        # Tuple outputs are always requested; the indexing below relies on it.
        outputs = self.perceiver(
            inputs=inputs,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=False,
        )

        # Project decoder outputs onto the (tied) input embedding matrix.
        logits = self.embedding_decoder(
            outputs[0], embedding_layer=self.perceiver.input_preprocessor.embeddings
        )

        loss = None
        if labels is not None:
            loss = torch.nn.CrossEntropyLoss()(
                logits.reshape(-1, self.config.vocab_size), labels.reshape(-1)
            )

        # Keep every PerceiverModel output after the decoder output, including
        # outputs[1] (transformers' PerceiverForMaskedLM drops it via outputs[2:]).
        output = (logits,) + outputs[1:]
        if loss is None:
            return output
        return (loss,) + output

# Default Perceiver hyperparameters. Model.__init__ reads
# `config.num_decoder_heads` for its decoder's attention heads, so mirror the
# cross-attention head count into it before building the model.
config = transformers.PerceiverConfig()
config.num_decoder_heads = config.num_cross_attention_heads
model = Model(config=config)

import torch
# Shuffle the integers 0..9,999,999 and cut them into fixed-size batches; the
# final batch-sized slice is held out as the test set.
data = torch.randperm(10000000)
# NOTE(review): this makes each batch 625,000 numbers; confirm that is the
# intended training batch size.
batch_size = len(data) // 16
# Drop any ragged tail so the training part reshapes evenly into rows.
data = data[:len(data) - len(data) % batch_size]
tt_split = batch_size
train_numbers = data[:-tt_split].view(-1, batch_size)
test_numbers = data[-tt_split:]

def batch_words_masks(batch_numbers):
    """Spell out a batch of integers as fixed-width byte tensors plus masks.

    Args:
        batch_numbers: an iterable of integers, e.g. a 1-D integer tensor such
            as a row of ``train_numbers``.

    Returns:
        ``(words, attention_masks)``:
        - ``words``: a uint8 tensor of shape ``(len(batch_numbers), maxlen)``
          holding the ISO-8859-1 bytes of each spelled-out number,
          right-padded with spaces to the longest word in the batch.
        - ``attention_masks``: a tensor of the same shape in the default
          float dtype, 1 for real characters and 0 for padding.

        An empty batch yields two ``(0, 0)`` tensors.
    """
    # Convert tensors to Python ints once; per-element 0-d tensor arithmetic
    # inside number_to_word is far slower.
    if isinstance(batch_numbers, torch.Tensor):
        batch_numbers = batch_numbers.tolist()
    words = [number_to_word(number) for number in batch_numbers]
    if not words:
        return (
            torch.empty((0, 0), dtype=torch.uint8),
            torch.empty((0, 0), dtype=torch.get_default_dtype()),
        )
    maxlen = max(len(word) for word in words)

    lengths = torch.tensor([len(word) for word in words])
    attention_masks = (torch.arange(maxlen) < lengths[:, None]).to(torch.get_default_dtype())

    # One writable buffer for the whole batch. torch.frombuffer warns on
    # read-only `bytes` objects, so wrap the bytes in a bytearray.
    padded = b''.join(word.ljust(maxlen).encode('iso-8859-1') for word in words)
    words = torch.frombuffer(bytearray(padded), dtype=torch.uint8).view(len(words), maxlen)
    return words, attention_masks

# Plan: at one end of the model, the number itself is read or emitted as a single 'token';
# at the other end, its spelled-out word is emitted or read as many tokens.

# Train with plain SGD over the shuffled training batches.
# TODO(review): the training objective is not decided yet. As a placeholder,
# the model reconstructs each spelled-out number from itself; replace `labels`
# with the intended target once the number <-> word design is settled.
model.train()
optim = torch.optim.SGD(model.parameters(), lr=0.0001)
for batch in train_numbers:
    optim.zero_grad()
    words, attention_mask = batch_words_masks(batch)
    inputs = words.long()  # embedding lookups need integer (long) indices
    # The decoder emits config.max_position_embeddings positions, so labels are
    # padded to that length. -100 (CrossEntropyLoss's default ignore_index)
    # keeps padding out of the loss.
    labels = inputs.masked_fill(attention_mask == 0, -100)
    labels = torch.nn.functional.pad(
        labels, (0, model.config.max_position_embeddings - labels.shape[1]), value=-100
    )
    loss = model(inputs, attention_mask=attention_mask, labels=labels)[0]
    loss.backward()
    optim.step()
