import torch, torch.nn
import tqdm

class Transformer(torch.nn.Module):
    def __init__(self, d_model = 512, d_ff = 2048, n_layers = 6, d_input = None, d_output = None, is_causal = False, dtype=float, device='cpu'):
        super().__init__()
        self.encoding = d_input and torch.nn.Linear(d_input, d_model, dtype=dtype, device=device)
        self.decoding = d_output and torch.nn.Linear(d_model, d_output, dtype=dtype, device=device)
        if n_layers > 1:
            self.recoding = torch.nn.Linear(d_model*2, d_model, dtype=dtype, device=device)
        self.n_layers = n_layers
        self.d_input = d_input or d_model
        self.d_output = d_output or d_model
        self.is_causal = is_causal
            
        self.transformer = torch.nn.Transformer(
            d_model = d_model, 
            nhead = 1,
            num_decoder_layers = 1,#n_layers,
            dim_feedforward = d_ff,
            dropout = 0,#.1,
            activation = torch.nn.functional.silu,
            custom_encoder = lambda input, *params, **kwparams: input,
            layer_norm_eps = 1e-5,
            batch_first = True,
            norm_first = True, # some interest in removing norm
            bias = True,
            device = device,
            dtype = dtype,
        )

        self.raw_size = sum([p.view(-1).shape[-1] for p in self.parameters()])
        self.raw_dims = max([len(p.shape) for p in self.parameters()])
        self.encoded_size = [self.raw_size, 2 + self.raw_dims]
        #self.raw_names = list({n for n, p in self.named_parameters()})
        #self.raw_names.sort()
        #self.raw_names = { name: idx for idx, name in enumerate(self.raw_names) }
        #self.raw_depth = max([n.count('.')+1 for n, p in self.named_parameters()])
        ##self.raw_count = sum([1 for p in self.parameters()])
        #self.raw_names = list({name for n, p in self.named_parameters() for name in n.split('.')})
        #self.raw_names.sort()
        #self.raw_types = list({str(type(m)) for m in self.modules()})
        #self.raw_types.sort()
    def forward(self, inp):
        if self.encoding:
            inp = self.encoding(inp)
        out = self.transformer(
            src=torch.empty(list(inp.shape[:-2])+[0,inp.shape[-1]], dtype=inp.dtype, device=inp.device),
            tgt=inp,
            tgt_is_causal=self.is_causal,
        )
        if self.n_layers > 1:
            out = out[:,None,...]
            for layer in range(self.n_layers - 1):
                out = torch.cat([
                    out,
                    self.transformer(
                        #src=self.recoding(out[:,-1,...]),
                        #tgt=inp,
                        src=torch.empty(list(inp.shape[:-2])+[0,inp.shape[-1]], dtype=inp.dtype, device=inp.device),
                        #tgt=torch.cat([self.recoding(out[:,-1,...]),inp],dim=-2),
                        #tgt=self.recoding(out[:,-1,...]),
                        tgt=self.recoding(torch.cat([inp,out[:,-1,...]], dim=-1)),
                        tgt_is_causal=self.is_causal,
                    )[:,None,-out.shape[-2]:,:]
                ], dim=1)
        if self.decoding:
            out = self.decoding(out)
        return out
    CONST_DATA = 1
    CONST_ENCODED = 2
    #def forward_encoded(self, inp, model, encoded_0 = None):
    #    encoded_0 = encoded_0 or model.to_encoded()
    #    inputs_n = inputs.shape[0] * inputs.shape[1]
    #    context = torch.cat([
    #        torch.full([inputs_n, 1], CONST_DATA, device=inputs.device), # labels
    #        torch.arange(end=inputs.shape[0], device=inputs.device)[:,None,None].expand([inputs.shape[0], inputs.shape[1], 1]).reshape(inputs_n, 1), # batch ids
    #        inputs.view([inputs_n, inputs.shape[-1]]), # data
    #        torch.zeros([inputs_n, t_0.shape[-1] - inputs.shape[-1] - 1], device=inputs.device), # padding
    #    ])
    #    out = self(inp)
    #    model.from_encoded(out)
    def from_raw(self, raw):
        idx = 0
        with torch.no_grad():
            for p in self.parameters():
                p = p.view(-1)
                size = p.shape[-1]
                p[:] = raw[idx:idx+size]
                idx += size
    def to_raw(self, out=None):
        idx = 0
        for p in self.parameters():
            if out is None:
                out = torch.empty([self.raw_size], dtype=p.dtype, device=p.device)
            p = p.view(-1)
            size = p.shape[-1]
            out[idx:idx+size] = p
            idx += size
        return out
    def to_encoded(self, out=None):
        idx = 0
        ct = 0
        names = {}
        types = {}
        for name, param in self.named_parameters():
            if out is None:
                #out = torch.empty([self.raw_size, 1 + self.raw_dims], dtype=p.dtype, device=p.device)
                out = torch.empty(self.encoded_size, dtype=param.dtype, device=param.device)
                    ##### # noting some [harmed, like starchy confusion, common sadly] space, so taking space/distance [didn't take enough]
            flattened = param.view(-1)
            size = flattened.shape[0]
            dims = len(param.shape)
            #dims = list(param.shape)
            out[idx:idx+size,0] = flattened
            out[idx:idx+size,1] = ct
            # might need to swizzle this one
            out[idx:idx+size,2:-dims] = 0
            out[idx:idx+size,-dims:] = torch.stack(torch.meshgrid([torch.arange(dim) for dim in param.shape], indexing='ij'), dim=-1).view(-1,dims)
            ct += 1
            idx += size
        return out
    def from_encoded(self, encoded):
        return self.from_raw(encoded[:,0])

    def train_data(self, trainer, inputs, outputs, context = None, accuracy = 0.5):
        return trainer.train(model=self, inputs=inputs, outputs=outputs, context=context, accuracy=accuracy)

    #def train_encoded(self, trainer, model, inputs, outputs, context = None, accuracy = 0.5):
    #    return trainer.train(model


# - make it single-pass (so the loss is of the output's forward)
# - make it save state

# If the training approach is abstracted into a class, it can be passed to a function that trains on,
# e.g., plain data or the result of an encoded model. That function can then be first-class, with the
# training approach as a parameter, which might simplify some of the work around the concept of interest.
# Open question: would this be simpler to reason about if train_rect were put into the Transformer class?
class RectTrainer:
    """Full-batch training loop with a halve-on-regression learning-rate rule.

    Called "rectangular" training because all data is trained equally
    (it would probably be faster not to do it fully rectangular).

    Args:
        optim: Optimizer class, constructed as ``optim(model.parameters(), lr=lr)``.
        lr: Initial learning rate.
        loss_fn: Called as ``loss_fn(prediction, target)``; must return a scalar tensor.
        layer_loss_fn: Called as ``layer_loss_fn(loss, idx, total)`` to weight the
            loss of pass ``idx`` (out of ``total``) when the model returns one
            output per pass.
    """

    def __init__(self, optim = torch.optim.SGD, lr=1e-3, loss_fn = torch.nn.functional.mse_loss, layer_loss_fn = lambda loss, idx, total: loss ** idx):
        self.optim = optim
        self.lr = lr
        self.loss_fn = loss_fn
        self.layer_loss_fn = layer_loss_fn

    def _compute_loss(self, attempt, outputs, output_offset):
        """Return ``(loss_tensor, loss_value)`` for one model output.

        ``loss_tensor`` is what gets backpropagated; ``loss_value`` is the float
        compared against the accuracy target.  Positions before ``output_offset``
        on the sequence axis (the context rows) are excluded.
        """
        if attempt.dim() <= 3:
            # Single output: the loss is used directly.
            ls = self.loss_fn(attempt[:, output_offset:, :], outputs)
            return ls, ls.item()
        # One output per pass along dim 1.  Earlier layers are included in the
        # loss so that depth can be changed to exchange computation time for
        # output accuracy; progress is judged on the last pass only.
        n_passes = attempt.shape[1]
        per_pass = [
            self.loss_fn(attempt[:, idx, output_offset:, :], outputs)
            for idx in range(n_passes)
        ]
        combined = torch.stack([
            self.layer_loss_fn(pass_loss, idx, n_passes)
            for idx, pass_loss in enumerate(per_pass)
        ]).sum()
        return combined, per_pass[-1].item()

    def train(self, model, inputs, outputs, context = None, accuracy = 0.5):
        """Train ``model`` until its loss is at most ``accuracy`` or ``lr`` reaches 0.

        If ``context`` is given it is prepended to ``inputs`` along dim -2 and the
        corresponding leading output positions are ignored by the loss.

        Each iteration evaluates the model on the full batch.  If the loss is
        higher than the last accepted loss, the learning rate is halved and no
        optimizer step is taken for that evaluation (the previously applied step
        is not rolled back); otherwise a step is taken and the loss is accepted.

        Returns:
            The last accepted loss value, or ``None`` if no step was taken.
        """
        lr = self.lr
        optim = self.optim(model.parameters(), lr=lr)
        last_ls = None
        # we might add acceleration once something basic works.
        if context is not None:
            ctx_inputs = torch.cat([context, inputs], dim=-2)
            output_offset = context.shape[-2]
        else:
            ctx_inputs = inputs
            output_offset = 0
        with tqdm.tqdm(total=0.5, unit='ls') as pbar:
            while (last_ls is None or last_ls > accuracy) and lr:
                attempt = model(ctx_inputs)
                ls, ls_item = self._compute_loss(attempt, outputs, output_offset)
                if last_ls is not None and last_ls < ls_item:
                    # Loss went up: shrink the step size and re-evaluate.
                    lr /= 2
                    optim.param_groups[0]['lr'] = lr
                    pbar.display()
                else:
                    ls.backward()
                    optim.step()
                    optim.zero_grad()
                    last_ls = ls_item
                pbar.desc = f'lr={lr}, acc={last_ls}'
                # The bar's total grows to fit the largest loss seen; its position
                # reaches the total when the loss equals the accuracy target.
                if ls_item + accuracy > pbar.total:
                    pbar.total = ls_item + accuracy
                pbar.update(pbar.total - ls_item + accuracy - pbar.n)
        return last_ls

#class TransformerTransformer(Transformer):
    #def __init__(self, d_model = 512, d_ff = 2048, n_layers = 6, d_input = None, d_output = None, dtype=float, device='cpu'):

if __name__ == '__main__':
    # Demo: train a small model `t` on random data, then train a second model
    # `t2` to produce t's trained parameters from (training inputs, t's
    # initial encoded parameters).

    # Use the first CUDA device when present; otherwise fall back to the CPU
    # so the script still runs on machines without a GPU.
    device = 0 if torch.cuda.is_available() else 'cpu'

    trainer = RectTrainer(lr=0.0001)
    t = Transformer(d_model=8, d_ff=16, n_layers=2, d_input=2, d_output=1, dtype=torch.float32, device=device)
    t_0 = t.to_encoded()  # snapshot of t's parameters before training
    t2 = Transformer(d_model=8, d_ff=16, n_layers=2, d_input=1+t.encoded_size[-1], d_output=1, dtype=torch.float32, device=device)

    inputs = torch.rand([3,16,2],device=device)
    inputs[:,:,1] = torch.arange(end=inputs.shape[-2], device=device)  # second feature is the sequence position
    outputs = torch.rand([3,16,1],device=device)
    t.train_data(trainer, inputs, outputs, accuracy=0.1)
    # move to next step. it works well enough. put one inside another.
    t_f = t.to_raw()  # t's parameters after training

    # Same label values as the class-level constants.
    CONST_DATA = Transformer.CONST_DATA
    CONST_ENCODED = Transformer.CONST_ENCODED

    # inputs are a sequence of vectors of properties, many of these sequences collected into a batch
    # dims are seq_idx, prop_idx
    inputs_n = inputs.shape[0] * inputs.shape[1]
    inputs2_data = torch.cat([
        torch.full([inputs_n, 1], CONST_DATA, device=inputs.device), # labels
        torch.arange(end=inputs.shape[0], device=inputs.device)[:,None,None].expand([inputs.shape[0], inputs.shape[1], 1]).reshape(inputs_n, 1), # batch ids
        inputs.view([inputs_n, inputs.shape[-1]]), # data
        torch.zeros([inputs_n, t_0.shape[-1] - inputs.shape[-1] - 1], device=inputs.device), # padding
    ], dim=-1)[None,...]
    inputs2_encoded = torch.cat([
        torch.full([t_0.shape[0], 1], CONST_ENCODED, device=inputs.device), # labels
        t_0.detach(), # data
    ], dim=-1)[None,...]
    outputs2 = t_f.detach()[None,:,None]
    t2.train_data(trainer, context=inputs2_data, inputs=inputs2_encoded, outputs=outputs2, accuracy=0.1)
    # next we want to try doing away with the first training step and train the second model directly on its data.
    # that will be in a new file.
    # this could be important because the loss is bound directly to the output. it may help ensure an approach is stabilized.
