# Creates a unigram LM in ARPA format, following KenLM conventions.
import math


def calculate_log_probabilities(word_counts, num_sentences, n_smoothing=0.01):
    """
    Calculate log probabilities for each word in the corpus,
    including a special <unk> token for unknown words.
    """
    total_words = sum(word_counts.values())
    total_words += 2 * num_sentences  # add one <s> and one </s> per sentence
    total_words_with_unk = total_words + 1  # reserve one count for <unk>
    # Grow the denominator to account for the additive smoothing mass.
    total_words_with_unk = total_words_with_unk + total_words_with_unk * n_smoothing
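    # Worked example (illustrative numbers only): with word_counts == {"a": 3, "b": 1},
    # num_sentences == 2, and n_smoothing == 0.01, this denominator works out to
    # (3 + 1 + 2*2 + 1) * 1.01 == 9.09, so below P("a") == (3 + 0.01) / 9.09 and
    # P("<unk>") == 1 / 9.09; the function returns base-10 logs of these values.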

    # Additively smoothed probabilities for in-vocabulary words
    probabilities = {
        word: ((count + n_smoothing) / total_words_with_unk)
        for word, count in word_counts.items()
    }
    probabilities["<unk>"] = 1 / total_words_with_unk
    probabilities["<s>"] = (num_sentences + n_smoothing) / total_words_with_unk
    probabilities["</s>"] = (num_sentences + n_smoothing) / total_words_with_unk

    # Convert to log probabilities
    return {word: math.log10(prob) for word, prob in probabilities.items()}


def maybe_generate_pseudo_bigram_arpa(arpa_fpath):
    """Insert a single pseudo 2-gram into a unigram-only ARPA file so that
    consumers expecting an order >= 2 model can still load it."""
    with open(arpa_fpath, "r") as file:
        lines = file.readlines()

    # If the model already has an order >= 2, leave the file untouched.
    if any("2-grams:" in line for line in lines):
        return

    with open(arpa_fpath, "w") as file:
        for line in lines:
            if line.strip().startswith("ngram 1="):
                file.write(line)
                file.write("ngram 2=1\n")  # Add the new ngram line
                continue

            if line.strip() == "\\end\\":
                # Insert the pseudo 2-gram section just before the \end\ marker.
                file.write("\\2-grams:\n")
                file.write("-9.9999999\t</s> <s>\n\n")

            file.write(line)
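
# For reference (illustrative shape only), running the function above on a
# unigram-only ARPA file such as the one written by save_log_probabilities
# produces:
#
#     \data\
#     ngram 1=N
#     ngram 2=1
#
#     \1-grams:
#     ...
#
#     \2-grams:
#     -9.9999999	</s> <s>
#
#     \end\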


def save_log_probabilities(log_probabilities, file_path):
    """Write the unigram log10 probabilities to `file_path` in ARPA format."""
    with open(file_path, "w") as file:
        file.write("\\data\\\n")
        file.write(f"ngram 1={len(log_probabilities)}\n\n")
        file.write("\\1-grams:\n")
        for word, log_prob in log_probabilities.items():
            if word == "<s>":
                # <s> is only ever a context, never a predicted token, so its
                # unigram log-probability is written as 0.
                log_prob = 0
            file.write(f"{log_prob}\t{word}\n")
        file.write("\n")
        file.write("\\end\\\n")


def create_unigram_lm(word_counts, num_sentences, file_path, n_smoothing=0.01):
    """Build a smoothed unigram LM from word counts and save it as an ARPA file."""
    log_probs = calculate_log_probabilities(word_counts, num_sentences, n_smoothing)
    save_log_probabilities(log_probs, file_path)
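

# Minimal usage sketch (assumed inputs: the toy corpus and output path below
# are hypothetical and only illustrate how these helpers fit together):
#
#     from collections import Counter
#
#     sentences = ["the cat sat", "the dog ran"]
#     counts = Counter(word for sent in sentences for word in sent.split())
#     create_unigram_lm(counts, num_sentences=len(sentences), file_path="unigram.arpa")
#     maybe_generate_pseudo_bigram_arpa("unigram.arpa")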