import math

import jax
import numpy
import torch

import memory_efficient_attention

def moskomule_attention(queries, keys, values, query_chunk_size, key_chunk_size):
    """Run moskomule's memory-efficient attention backend.

    Inputs arrive as (batch, seq, heads, dim) and are permuted to
    (batch, heads, seq, dim) for this backend -- presumably what its API
    expects; TODO confirm against the moskomule implementation.
    `query_chunk_size` is unused: this backend only chunks along keys.
    """
    queries = queries.permute(0, 2, 1, 3)
    keys = keys.permute(0, 2, 1, 3)
    values = values.permute(0, 2, 1, 3)
    # Scale out-of-place. The original `queries /= ...` divided in place
    # through the permuted view, silently mutating the caller's tensor.
    queries = queries / math.sqrt(keys.shape[-1])
    return memory_efficient_attention.efficient_attention(
        queries, keys, values,
        chunk_size=key_chunk_size,
        checkpointing=True,
        out_of_place=False,
    )

def aminrezaei_attention(queries, keys, values, query_chunk_size, key_chunk_size):
    """Delegate to AminRezaei0x443's chunked dot-product attention (PyTorch).

    The two ``None`` positional arguments correspond to optional extras of
    the backend call (mask/bias, presumably -- verify against its signature).
    """
    chunking = {
        "query_chunk_size": query_chunk_size,
        "key_chunk_size": key_chunk_size,
    }
    return memory_efficient_attention.efficient_dot_product_attention_pt(
        queries, keys, values, None, None, **chunking
    )

def attention(queries, keys, values, query_chunk_size, key_chunk_size):
    """Dispatch to whichever memory-efficient attention backend is installed.

    Probes the imported ``memory_efficient_attention`` module for known entry
    points and forwards all arguments unchanged.

    Raises:
        RuntimeError: if the module exposes neither known entry point.
            (Previously this case fell through and silently returned None.)
    """
    if hasattr(memory_efficient_attention, 'efficient_attention'):
        return moskomule_attention(queries, keys, values, query_chunk_size, key_chunk_size)
    if hasattr(memory_efficient_attention, 'efficient_dot_product_attention_pt'):
        return aminrezaei_attention(queries, keys, values, query_chunk_size, key_chunk_size)
    raise RuntimeError(
        'memory_efficient_attention exposes no known attention entry point'
    )

if __name__ == '__main__':
    # Draw deterministic sample data with JAX and hand it to torch.
    # jax arrays' ``.to_py()`` was deprecated and removed in recent JAX
    # releases; ``numpy.asarray`` is the supported conversion path.
    sample = jax.random.normal(jax.random.PRNGKey(0), (3, 1, 64, 8, 16))
    # Unpacking the leading axis of size 3 yields queries, keys, values,
    # each shaped (1, 64, 8, 16).
    queries, keys, values = torch.from_numpy(numpy.asarray(sample))
    out = attention(queries, keys, values, query_chunk_size=4, key_chunk_size=4)
    print(out)
