from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    LogitsProcessor,
)
import torch
from scipy.special import gammainc, gammaincc, betainc
from scipy.optimize import fminbound

import os

hf_token = os.getenv('HF_TOKEN')


device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

def hash_tokens(input_ids: torch.LongTensor, key: int):
    """Rolling polynomial hash of a token window, reduced mod 2**64 - 1 so it fits a torch seed."""
    seed = key
    salt = 35317
    for i in input_ids:
        seed = (seed * salt + i.item()) % (2 ** 64 - 1)
    return seed
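
def _demo_hash_sync():
    # Illustrative sketch (not part of the scheme itself): reseeding a generator
    # with the same window hash reproduces the same pseudo-random draws, which is
    # exactly what keeps the embedder and the detector in sync.
    window = torch.tensor([17, 52, 3])  # made-up token window
    g = torch.Generator(device=device)
    g.manual_seed(hash_tokens(window, key=42))
    first_draw = torch.rand(4, generator=g, device=g.device)
    g.manual_seed(hash_tokens(window, key=42))
    assert torch.equal(first_draw, torch.rand(4, generator=g, device=g.device))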

class WatermarkingLogitsProcessor(LogitsProcessor):
    def __init__(self, n, key, messages, window_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.batch_size = len(messages)
        # one generator per sequence, so each message gets its own pseudo-random stream
        self.generators = [ torch.Generator(device=device) for _ in range(self.batch_size) ]

        self.n = n
        self.key = key
        self.window_size = window_size
        if not self.window_size:
            # no context window: seed once from the key and let the stream run over the whole generation
            for b in range(self.batch_size):
                self.generators[b].manual_seed(self.key)

        self.messages = messages

class WatermarkingAaronsonLogitsProcessor(WatermarkingLogitsProcessor):

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # draw log-uniform variables in the full space of n values, rolled by the
        # message so that the payload is embedded in the pseudo-random sequence
        B, V = scores.shape

        r = torch.zeros((B, self.n), device=scores.device)
        for b in range(B):
            if self.window_size:
                window = input_ids[b, -self.window_size:]
                seed = hash_tokens(window, self.key)
                self.generators[b].manual_seed(seed)
            r[b] = torch.rand(self.n, generator=self.generators[b],
                              device=self.generators[b].device).log().roll(-self.messages[b])
        # generate n values but keep only V, to stay in sync with the decoder's pseudo-random sequence
        r = r[:, :V]

        # Aaronson sampling: greedy decoding picks argmax_i r_i^(1/p_i) = argmax_i log(r_i) / p_i.
        # Dividing by exp(scores) instead of softmax(scores) only rescales each row by a
        # positive constant, so the greedy argmax is unchanged.
        return r / scores.exp()
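
def _demo_aaronson_sampling():
    # Sanity-check sketch (illustrative, not part of the scheme): with r ~ U(0,1),
    # argmax_i log(r_i) / p_i selects token i with probability p_i, so greedy
    # decoding over the processed scores still samples from the model's distribution.
    g = torch.Generator().manual_seed(0)
    p = torch.tensor([0.7, 0.2, 0.1])
    r = torch.rand(20000, 3, generator=g)
    picks = torch.argmax(r.log() / p, dim=1)
    print(torch.bincount(picks, minlength=3) / 20000.0)  # should be close to p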

class WatermarkingKirchenbauerLogitsProcessor(WatermarkingLogitsProcessor):
    def __init__(self, *args,
                 gamma = 0.5,
                 delta = 4.0,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.gamma = gamma
        self.delta = delta

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        B, V = scores.shape

        for b in range(B):
            if self.window_size:
                window = input_ids[b, -self.window_size:]
                seed = hash_tokens(window, self.key)
                self.generators[b].manual_seed(seed)
            # pseudo-random partition of the alphabet: the first gamma * n entries form the greenlist
            vocab_permutation = torch.randperm(self.n, generator=self.generators[b], device=self.generators[b].device)
            greenlist = vocab_permutation[:int(self.gamma * self.n)]
            bias = torch.zeros(self.n, device=scores.device)
            bias[greenlist] = self.delta
            # roll by the message to embed the payload, then truncate to the model vocabulary
            bias = bias.roll(-self.messages[b])[:V]
            scores[b] += bias # add delta to the greenlist logits

        return scores
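
def _demo_kirchenbauer_bias():
    # Illustrative sketch: adding delta to the greenlist logits shifts probability
    # mass onto green tokens, so watermarked text over-represents them; detection
    # then counts green tokens against a Binomial(k, gamma) null.
    gamma, delta = 0.5, 4.0
    logits = torch.randn(100)
    green = torch.randperm(100)[:int(gamma * 100)]
    biased = logits.clone()
    biased[green] += delta
    print(logits.softmax(-1)[green].sum().item(),  # green mass before, ~gamma on average
          biased.softmax(-1)[green].sum().item())  # green mass after, close to 1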

class Watermarker(object):
    def __init__(self, modelname="facebook/opt-350m", window_size=0, payload_bits=0, logits_processor=None, *args, **kwargs):
        self.tokenizer = AutoTokenizer.from_pretrained(modelname, token=hf_token)
        self.model = AutoModelForCausalLM.from_pretrained(modelname, token=hf_token).to(device)
        self.model.eval()
        self.window_size = window_size

        # preprocessing wrappers
        self.logits_processor = logits_processor or []

        self.payload_bits = payload_bits
        # effective alphabet size: large enough to carry the payload even when it exceeds the vocabulary
        self.V = max(2**payload_bits, self.model.config.vocab_size)
        self.generator = torch.Generator(device=device)


    def embed(self, key=42, messages=[1234], prompt="", max_length=30, method='aaronson'):

        B = len(messages) # batch size

        # check that every message fits within the payload capacity
        if self.payload_bits:
            assert all(0 <= message < 2**self.payload_bits for message in messages)

        # tokenize prompt
        inputs = self.tokenizer([ prompt ] * B, return_tensors="pt")

        if method == 'aaronson':
            # generate with greedy search
            generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=False,
                                                logits_processor = self.logits_processor + [
                                                    WatermarkingAaronsonLogitsProcessor(n=self.V,
                                                                                        key=key,
                                                                                        messages=messages,
                                                                                        window_size = self.window_size)])
        elif method == 'kirchenbauer':
            # use sampling
            generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=True,
                                                logits_processor = self.logits_processor + [
                                                    WatermarkingKirchenbauerLogitsProcessor(n=self.V,
                                                                                            key=key,
                                                                                            messages=messages,
                                                                                            window_size = self.window_size)])
        elif method == 'greedy':
            # generate with greedy search, without watermarking
            generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=False,
                                                logits_processor = self.logits_processor)
        elif method == 'sampling':
            # generate with sampling, without watermarking
            generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=True,
                                                logits_processor = self.logits_processor)
        else:
            raise ValueError('Unknown method %s' % method)
        decoded_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

        return decoded_texts
    
    def detect(self, attacked_texts, key=42, method='aaronson', gamma=0.5, prompts=None):
        if prompts is None:
            prompts = [""] * len(attacked_texts)

        generator = self.generator

        cdfs = []
        ms = []

        # number of candidate messages to test
        MAX = 2**self.payload_bits
        
        # tokenize input
        inputs = self.tokenizer(attacked_texts, return_tensors="pt", padding=True, return_attention_mask=True)
                        
        input_ids = inputs["input_ids"].to(self.model.device)
        attention_masks = inputs["attention_mask"].to(self.model.device)

        B,T = input_ids.shape

        if method == 'aaronson_neyman_pearson':
            # compute the model's probabilities for each observed token
            outputs = self.model.forward(input_ids, return_dict=True)
            logits = outputs['logits']
            # TODO: reapply the logits processors to recover the exact generation distribution
            #for i in range(T):
            #    for processor in self.logits_processor:
            #        logits[:,i] = processor(input_ids[:, :i], logits[:, i])

            probs = logits.softmax(dim=-1)
            # ps[b, i] = model probability of the token observed at position i+1
            ps = torch.gather(probs, 2, input_ids[:,1:,None]).squeeze_(-1)


        seq_len = input_ids.shape[1]

        V = self.V
                
        Z = torch.zeros(size=(B, V), dtype=torch.float32, device=device)


        # keep a history of contexts we have already seen,
        # to exclude them from score aggregation and allow
        # correct p-value computation under H0
        history = [set() for _ in range(B)]

        # tokenize the prompts to find how many leading tokens to exclude from scoring
        attention_masks_prompts = self.tokenizer(prompts, return_tensors="pt", padding=True, return_attention_mask=True)["attention_mask"]
        prompts_length = torch.sum(attention_masks_prompts, dim=1)
        for b in range(B):
            # mask out the prompt so that only generated tokens are scored
            attention_masks[b, :prompts_length[b]] = 0
            if not self.window_size:
                generator.manual_seed(key)
            # NOTE: the loop could start at prompts_length[b] instead of relying on the mask
            for i in range(seq_len-1):
                if self.window_size:
                    window = input_ids[b, max(0, i-self.window_size+1):i+1]
                    seed = hash_tokens(window, key)
                    if seed not in history[b]:
                        generator.manual_seed(seed)
                        history[b].add(seed)
                    else:
                        # context already seen: skip this token to keep the H0 p-value valid
                        attention_masks[b, i+1] = 0

                if not attention_masks[b,i+1]:
                    continue

                token = int(input_ids[b,i+1])

                if method in {'aaronson', 'aaronson_simplified', 'aaronson_neyman_pearson'}:
                    # redraw the same uniform variables the embedder used at this step
                    R = torch.rand(V, generator = generator, device = generator.device)

                if method == 'aaronson':
                    r = -(1-R).log()
                elif method in {'aaronson_simplified', 'aaronson_neyman_pearson'}:
                    r = -R.log()
                elif method == 'kirchenbauer':
                    r = torch.zeros(V, device=device)
                    vocab_permutation = torch.randperm(V, generator = generator, device=generator.device)
                    greenlist = vocab_permutation[:int(gamma * V)]
                    r[greenlist] = 1
                else:
                    raise ValueError('Unknown method %s' % method)

                if method in {'aaronson', 'aaronson_simplified', 'kirchenbauer'}:
                    # score all candidate messages at once: rolling by the observed token
                    # aligns r so that entry m is the contribution message m would receive
                    Z[b] += r.roll(-token)
                elif method == 'aaronson_neyman_pearson':
                    # Neyman-Pearson weighting by the model's confidence in the token
                    Z[b] += r.roll(-token) * (1/ps[b,i] - 1)

        for b in range(B):
            if method in {'aaronson', 'kirchenbauer'}:
                # the embedded message is the candidate with the highest score
                m = int(torch.argmax(Z[b,:MAX]))
            elif method in {'aaronson_simplified', 'aaronson_neyman_pearson'}:
                # here lower scores indicate the embedded message
                m = int(torch.argmin(Z[b,:MAX]))

            S = Z[b, m].item()

            # number of tokens actually scored
            k = torch.sum(attention_masks[b]).item() - 1

            if method == 'aaronson':
                # under H0, S is a sum of k Exp(1) variables: p-value = P(Gamma(k) > S)
                cdf = gammaincc(k, S)
            elif method == 'aaronson_simplified':
                cdf = gammainc(k, S)
            elif method == 'aaronson_neyman_pearson':
                # Chernoff bound on the p-value
                ratio = ps[b,:k] / (1 - ps[b,:k])
                E = (1/ratio).sum()

                if S > E:
                    cdf = 1.0
                else:
                    # to compute the p-value we must solve for c*:
                    # (1/(c* + ps/(1-ps))).sum() = S
                    func = lambda c : (((1 / (c + ratio)).sum() - S)**2).item()
                    c1 = (k / S - torch.min(ratio)).item()
                    c = fminbound(func, 0, c1)
                    # upper bound on the p-value
                    cdf = torch.exp(torch.sum(-torch.log(1 + c / ratio)) + c * S)
            elif method == 'kirchenbauer':
                # under H0, S green tokens out of k follow Binomial(k, gamma)
                cdf = betainc(S, k - S + 1, gamma)

            # correct for testing MAX candidate messages
            if cdf > min(1 / MAX, 1e-5):
                cdf = 1 - (1 - cdf)**MAX # true value
            else:
                cdf = cdf * MAX # numerically stable upper bound
            cdfs.append(float(cdf))
            ms.append(m)

        return cdfs, ms
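
if __name__ == "__main__":
    # Minimal end-to-end sketch with illustrative values: embed an 8-bit payload
    # with the Aaronson scheme, then recover it and a p-value from the text alone.
    # (opt-350m must fit on the selected device; short texts give weak p-values.)
    watermarker = Watermarker(modelname="facebook/opt-350m", window_size=0, payload_bits=8)
    texts = watermarker.embed(key=42, messages=[42], prompt="Once upon a time,",
                              max_length=100, method='aaronson')
    pvalues, decoded = watermarker.detect(texts, key=42, method='aaronson',
                                          prompts=["Once upon a time,"])
    print(texts)
    print("p-values:", pvalues, "decoded messages:", decoded)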