import re
import pickle
import os

from gui.service_logger import setup_logging

# Directory that contains this module; data files such as spp_codes.psv are looked up relative to it.
basedir = os.path.dirname(__file__)
# Per-user application data folder: %APPDATA%/vitasis/tb-listener.
# NOTE(review): os.getenv('APPDATA') returns None when the variable is not set, in which case
# os.path.join raises TypeError at import time - confirm this module is only used where APPDATA exists.
user_appdata = os.path.join(os.getenv('APPDATA'), 'vitasis', 'tb-listener')

# Create the user application data folder and its 'templates' subfolder if they do not exist yet.
os.makedirs(user_appdata, exist_ok=True)
os.makedirs(os.path.join(user_appdata, 'templates'), exist_ok=True)

# Set up the module logger; it is given the name "corrector" and the application data folder.
logger = setup_logging("corrector", user_appdata)

def slovenian_number(n):
    """
    Spell out the number n in Slovenian.

    Returns:
    - a list with both spellings (["nič", "nula"]) for 0
    - a single word for 1..99, for the round hundreds and for 1000
    - "<hundreds> <rest>" (separated by one space) for the other values between 101 and 999
    - str(n) for everything that is not covered (negative values, values above 1000)
    """
    units = ("", "ena", "dva", "tri", "štiri", "pet", "šest", "sedem", "osem", "devet")
    # Values that have a dedicated word and are never composed from smaller parts.
    fixed = {
        0: ["nič", "nula"],
        10: "deset", 11: "enajst", 12: "dvanajst", 13: "trinajst", 14: "štirinajst",
        15: "petnajst", 16: "šestnajst", 17: "sedemnajst", 18: "osemnajst", 19: "devetnajst",
        20: "dvajset", 30: "trideset", 40: "štirideset", 50: "petdeset",
        60: "šestdeset", 70: "sedemdeset", 80: "osemdeset", 90: "devetdeset",
        100: "sto", 1000: "tisoč"
    }
    hundreds = {
        1: "sto", 2: "dvesto", 3: "tristo", 4: "štiristo", 5: "petsto",
        6: "šeststo", 7: "sedemsto", 8: "osemsto", 9: "devetsto",
    }

    if n in fixed:
        return fixed[n]

    if 1 <= n < 10:
        return units[n]

    if 21 <= n < 100:
        # Slovenian puts the unit first: 21 --> "ena" + "in" + "dvajset"
        tens, unit = divmod(n, 10)
        leading = "ena" if unit == 1 else units[unit]
        return leading + "in" + fixed[tens * 10]

    if 100 < n < 1000:
        count, remainder = divmod(n, 100)
        spelled_hundreds = hundreds.get(count, "")
        if remainder == 0:
            return spelled_hundreds
        return spelled_hundreds + " " + slovenian_number(remainder)

    # Outside of the supported range: fall back to the digits
    return str(n)

# Lookup table: Slovenian number word --> integer value (0..100 and 1000).
# Fixed values: dvaintrideset (was 33), sedemdeset, osemdeset and devetdeset (all three were 60).
# NOTE(review): the 'trin...' keys (trindvajset, trintrideset, ...) are kept exactly as they were -
# presumably tolerated variants of the regular 'triin...' spelling; confirm before removing them.
word_numbers = dict(
    nula=0, nič=0, ena=1, dva=2, tri=3, štiri=4, pet=5, šest=6, sedem=7, osem=8, devet=9, deset=10,
    enajst=11, dvanajst=12, trinajst=13, štirinajst=14, petnajst=15, šestnajst=16, sedemnajst=17, osemnajst=18, devetnajst=19, dvajset=20,
    enaindvajset=21, dvaindvajset=22, trindvajset=23, štiriindvajset=24, petindvajset=25, šestindvajset=26, sedemindvajset=27, osemindvajset=28, devetindvajset=29, trideset=30,
    enaintrideset=31, dvaintrideset=32, trintrideset=33, štiriintrideset=34, petintrideset=35, šestintrideset=36, sedemintrideset=37, osemintrideset=38, devetintrideset=39, štirideset=40,
    enainštirideset=41, dvainštirideset=42, trinštirideset=43, štiriinštirideset=44, petinštirideset=45, šestinštirideset=46, sedeminštirideset=47, oseminštirideset=48, devetinštirideset=49, petdeset=50,
    enainpetdeset=51, dvainpetdeset=52, trinpetdeset=53, štiriinpetdeset=54, petinpetdeset=55, šestinpetdeset=56, sedeminpetdeset=57, oseminpetdeset=58, devetinpetdeset=59, šestdeset=60,
    enainšestdeset=61, dvainšestdeset=62, trinšestdeset=63, štiriinšestdeset=64, petinšestdeset=65, šestinšestdeset=66, sedeminšestdeset=67, oseminšestdeset=68, devetinšestdeset=69, sedemdeset=70,
    enainsedemdeset=71, dvainsedemdeset=72, trinsedemdeset=73, štiriinsedemdeset=74, petinsedemdeset=75, šestinsedemdeset=76, sedeminsedemdeset=77, oseminsedemdeset=78, devetinsedemdeset=79, osemdeset=80,
    enainosemdeset=81, dvainosemdeset=82, trinosemdeset=83, štiriinosemdeset=84, petinosemdeset=85, šestinosemdeset=86, sedeminosemdeset=87, oseminosemdeset=88, devetinosemdeset=89, devetdeset=90,
    enaindevetdeset=91, dvaindevetdeset=92, trindevetdeset=93, štiriindevetdeset=94, petindevetdeset=95, šestindevetdeset=96, sedemindevetdeset=97, osemindevetdeset=98, devetindevetdeset=99, sto=100,
    tisoč=1000
)

# Add every spelling slovenian_number() produces for 0..1000 that is not in the table yet.
# setdefault() never overwrites, so values already present in word_numbers win.
for value in range(1001):
    spelled = slovenian_number(value)
    # slovenian_number() returns a list when a number has more than one spelling
    variants = spelled if isinstance(spelled, list) else [spelled]
    for variant in variants:
        word_numbers.setdefault(variant, value)

# Reverse lookup: digits (as a string) --> word. When several words share one value,
# the word that comes last in word_numbers is the one that is kept.
number_words = {str(value): spelled for spelled, value in word_numbers.items()}

# Lookup table: dictated word or phrase --> symbol.
# Fixed: "oglati oklepaj" (square bracket) produced "{" and its closing counterpart was missing
# because the key "zaviti zaklepaj" was listed twice instead.
word_symbols = {
    "minus":"-", "plus":"+", "krat":"×", "deljeno":"/", "je":"=", "enako":"=", "je enako":"=", "v oklepaju":"(", "zaklepaj":")", "oklepaj":"(", "zaviti oklepaj": "{", "zaviti zaklepaj":"}",
    "oglati oklepaj": "[", "oglati zaklepaj":"]"
}

def _nvl(t, rv):
    """
    Returns t when it is truthy, otherwise the replacement value rv
    """
    return t or rv
    

class TextCorrector():
    """
    Rule based post-processor for dictated text.

    It bundles regex driven clean-ups: spacing around punctuation, repair of partially recognised
    commands, capital letters at sentence start, user defined translations, number handling,
    word <--> digit/symbol conversion and SPP code lookup.
    """
    def __init__(
            self,
            left_binded_symbols: str = ".,!?:;)/$€°",
            right_binded_symbols: str = "(",
            named_entities_pickle_file = None,
            translations = None,
        ):
        """
        Args:
        - left_binded_symbols: characters that must not be preceded by whitespace
        - right_binded_symbols: characters that must not be followed by whitespace
        - named_entities_pickle_file: optional path to a pickle file with named entity statistics
        - translations: optional dict of source --> target mappings; its keys define the words
            handle_translations() searches for
        """
        self.left_binded_symbols = left_binded_symbols
        self.right_binded_symbols = right_binded_symbols
        self.named_entity_file = named_entities_pickle_file
        self.named_entities = dict()
        self.spp_codes = self.read_spp_codes()
        if self.named_entity_file:
            self.named_entities = self.load_named_entities()
        self.translations = translations if translations else dict()
        self._compile_patterns()
        # Built here and not in _compile_patterns() because it depends on the module level word_symbols
        # table. Longest phrases go first so that "je enako" wins over "je".
        self.ptrn_word2symbol = re.compile(r'\b(?:' + '|'.join(re.escape(k) for k in sorted(word_symbols, key=len, reverse=True)) + r')\b')

    @staticmethod
    def _build_translation_pattern(keys):
        """
        Compiles the pattern that finds translation sources. A source must start at the beginning
        of the text or after whitespace or '>' and must end before whitespace, '<' or end of text.
        Keys are escaped (so that 'c++' or 'dr.' are taken literally) and the longest key is tried first.
        """
        sources = sorted((k for k in keys if k), key=len, reverse=True)
        if not sources:
            # nothing to translate: a pattern that never matches
            return re.compile(r'(?!)')
        alternatives = "|".join(re.escape(k) for k in sources)
        return re.compile(rf'(?:(?<=^)|(?<=\s)|(?<=>))({alternatives})(?=\s|$|<)', flags=re.IGNORECASE)

    def _compile_patterns(self):
        """
        Compiles all regex patterns that depend only on the instance configuration
        (left_binded_symbols, right_binded_symbols and translations must be set before the call)
        """
        # REGEX_PATTERNS --------------------------------------------------------------------------------------------------------------------------------
        # Applied by correct_typos() in this order. An item has either 'res' (replacement template) or 'func' (replacement function)
        self.corrector_regex_patterns = dict(
            multiple_spaces  = dict(ptr=re.compile(r'[ ]{2,}'), res=' '),
            extra_space_before = dict(ptr=re.compile(rf'(\s)([{self.left_binded_symbols}])'), res=r"\g<2>"),
            extra_space_after = dict(ptr=re.compile(rf'([{self.right_binded_symbols}])(\s)'), res=r"\g<1>"),
            lr_binded_symbol = dict(ptr=re.compile(r'(.)(<->|<\+>|<\*>|<=>|<–>)(.)'), func=self.remove_leading_and_trailing_space),
        )
        # NEW LINE NEW PARAGRAPH MISMATCH
        self.ptrn_new_line_mismatch = re.compile(r'(<nl>|<np>)(\snova|\snov)*(\svrsta|\sodstavek)(\s|$|<)', flags=re.IGNORECASE)
        # SYMBOLS
        self.ptrn_leftside_binded_symbols = re.compile(r'(.|^)(<)([\%\)°$€])(>)(.|$)')
        self.ptrn_binded_symbols = re.compile(r'(.|^)(<)([–])(>)(.|$)')
        # SENTENCE RESTART
        self.ptrn_new_sentence = re.compile(r'([\.\?\!])(\s)([a-zščćžđ])')
        # NUMBERS
        # digits followed by a number word. The delimiter after the word is only looked at (not consumed),
        # so the replacement cannot swallow the space before the next word
        self.ptrn_numbers_to_translate1 = re.compile(r'(\d+)(\s)(tisoč|sto|deset)(?=[\s,.!?]|$)', flags=re.IGNORECASE)
        # number word followed by digits. \b makes sure that only whole words are taken (not the 'sto' in 'tristo')
        self.ptrn_numbers_to_translate2 = re.compile(r'\b(tisoč|sto|deset|nič|nula|dva|tri|štiri|pet|šest|sedem|osem|devet)(\s)(\d+)', flags=re.IGNORECASE)
        self.ptrn_leading_zero = re.compile(r'(^|\s)(nič)(\s)(\d+)', flags=re.IGNORECASE)
        self.ptrn_numbers_to_join = re.compile(r'(^|\s)(\d+)(\s)(\d+)')
        # TRANSLATIONS
        self.ptrn_translation_sources = self._build_translation_pattern(self.translations.keys())
        # COMPOSITE COMMANDS - there are only two composites at the moment: DEL n and INS n. The insertion is handled as a modifier since after INS command
        # there must always be a template name, while in the DEL case the DEL command is already a standalone command
        self.ptrn_composite_DEL = re.compile(r'(<delw>)(\s)*(\d+|ena|dva|tri|štiri|pet|šest|sedem|osem|devet|deset)(.*)', flags=re.IGNORECASE)
        # ------------------------------------------------------------------------------------------------------------------------------------------------

    @staticmethod
    def _sub_until_stable(pattern, repl, text):
        """
        Replaces the first match of pattern with repl(match) and repeats until the text does not change
        any more. Stopping on 'no change' also protects against endless loops when a replacement
        returns the matched text unchanged.
        """
        while True:
            updated = pattern.sub(repl, text, count=1)
            if updated == text:
                return text
            text = updated

    def read_spp_codes(self):
        """
        Read SPP codes from the pipe separated file spp_codes.psv located next to this module.
        The second field of a line is the code, the third one its description.
        Returns a dict code --> description or None when the file does not exist.
        """
        spp_file = os.path.join(basedir, "spp_codes.psv")
        if not os.path.isfile(spp_file):
            logger.error(f"File gui/spp_codes.psv is missing! SPP codes will not be detected")
            return None

        spp = dict()
        with open(spp_file, "r", encoding="utf-8") as fspp:
            for line in fspp:
                fields = line.split("|")
                if len(fields) < 3:
                    # blank or incomplete line: skip it instead of failing with IndexError
                    continue
                code, description = fields[1], fields[2].strip()
                spp[code] = description
                # add a variation without dot notation: Z10.0 --> Z100
                if '.' in code:
                    spp[code.replace(".", "")] = description
        logger.info(f"{len(spp)} SPP codes and their variations imported from {os.path.join(basedir, 'spp_codes.psv')}")
        return spp

    def load_named_entities(self):
        """
        Load named entity data from a pickle file
        """
        # NOTE(security): pickle.load can execute arbitrary code while loading. Use trusted files only.
        with open(self.named_entity_file, 'rb') as infile:
            ne = pickle.load(infile)
        print(f"{len(ne)} named entities loaded for text corrector!")
        return ne

    def is_named_entity(self, t, treshold=0.75):
        """
        Returns True if t is a named entity
        Args:
        - t: token; it is lowercased and all non-letter characters are removed before the lookup
        - treshold: if the ratio between lowercase and all other occurrences of the token is above
            'treshold', then return False
        self.named_entities holds for each token the values L (lowercased), U (uppercased), C (capital) and M (mixed).
        Each value tells how many occurrences of the token appeared in lower, upper, capital and mixed case.
        Tokens that are not in self.named_entities are never named entities.
        """
        token = "".join([c for c in t.lower() if c.isalpha()])

        counts = self.named_entities.get(token)
        if counts is None:
            return False

        not_lowercased = counts['C'] + counts['U'] + counts['M']
        if not_lowercased == 0:
            # seen in lowercase only
            return False
        return not (counts['L'] / not_lowercased > treshold)

    def uppercase_match(self, match):
        """
        Returns uppercase letter for a match where lowercased token is after .?!
        """
        return match.group(1) + match.group(2).upper()

    def remove_leading_and_trailing_space(self, match):
        """
        Removes space before/after symbols that are left and right binded.
        Group 2 is the symbol in angle brackets (e.g. <+>), the brackets are removed as well.
        """
        left_part = "" if match.group(1) == " " else match.group(1)
        right_part = "" if match.group(3) == " " else match.group(3)
        return left_part + match.group(2)[1:-1] + right_part

    def lowercase_match(self, match):
        """
        Returns lowercase letter for a match where uppercased token is after ,:; and the token is not named entity
        """
        if self.is_named_entity(match.group(2)):
            return match.group(0)
        return match.group(1) + match.group(2).lower()

    def correct_typos(self, t):
        """
        Iterates over all patterns and makes substitutions where matches found
        """
        for item in self.corrector_regex_patterns.values():
            replacement = item['func'] if 'func' in item else item['res']
            t = item['ptr'].sub(replacement, t)
        return t

    def correct_commands_mismatch(self, t):
        """
        Sometimes commands are identified partially resulting to text like this: '<nl> vrsta' or '<np> odstavek'
        These mistakes are handled here: the leftover words are removed and the command is kept.
        TODO: these are issues that BE must eliminate!!!
        """
        # repeat, since removing one leftover can uncover the next one
        while self.ptrn_new_line_mismatch.search(t):
            t = self.ptrn_new_line_mismatch.sub(r'\g<1>\g<4>', t)
        return t

    def start_with_capital(self, t):
        """
        If a new sentence is detected within an ASR response, make sure it starts with a capital letter
        """
        def capitalize_first_letter(m):
            return f'{m.group(1)}{m.group(2)}{m.group(3).upper()}'

        return self.ptrn_new_sentence.sub(capitalize_first_letter, t)

    def handle_translations(self, t, translations):
        """
        This will apply mappings for all source, target pairs from the settings file. Translations are initialized as a main
        window variable.
        Args:
        - t: text to translate
        - translations: dict source --> target. The lookup is done with the lowercased match.
            Sources are found with the pattern compiled from the translations passed to the constructor.
        """
        def apply_translation(match):
            return translations.get(match.group(0).lower(), match.group(0))

        # first make corrections that are case sensitive
        res = t.replace('<UC>', '<ucc>').replace('<LC>', '<lcc>')
        # rest of the mappings (applied on top of the case sensitive corrections)
        return self.ptrn_translation_sources.sub(apply_translation, res)

    def handle_symbols(self, t):
        """
        If symbols like %, °, – etc. come in angle brackets (e.g. <°>), remove the brackets and handle correct spacing
        """
        def drop_space(part):
            return '' if part == ' ' else part

        def binded_symbols_correction(match):
            # no space on either side of the symbol
            return drop_space(match.group(1)) + match.group(3) + drop_space(match.group(5))

        def leftside_binded_symbols_correction(match):
            # no space before the symbol, whatever follows is kept
            return drop_space(match.group(1)) + match.group(3) + match.group(5)

        res = self.ptrn_binded_symbols.sub(binded_symbols_correction, t)
        return self.ptrn_leftside_binded_symbols.sub(leftside_binded_symbols_correction, res)

    def handle_composite_commands(self, t):
        """
        Composite commands are detected and translated to a specific command identifier:
        '<delw> 3 ...' or '<delw> tri ...' --> '<deln_3> ...'
        Text without a composite command is returned unchanged.
        """
        def to_deln(match):
            arg = match.group(3).strip().lower()
            # digits are used as they are, words are converted with the word_numbers table
            number = arg if arg.isdigit() else word_numbers.get(arg, arg)
            return f"<deln_{number}>{match.group(4)}"

        # The replacement is built directly from the match, so underscores in the text cause no trouble
        return self.ptrn_composite_DEL.sub(to_deln, t)

    def handle_numbers(self, final):
        """
        Handle number replacements such as 1900 70 --> 1970 and similar
        RULE 1: if two numbers where one written with letters other with digits, the word is replaced by
            its digits, like 3 tisoč --> 3 1000 or dva 5 --> 2 5 (RULE 3 joins them afterwards)
        RULE 2: if leading zero with letters, like 'nič 1' or 'nič 2' --> 01, 02
        RULE 3:
            - if one number after another without a punctuation, like: 20 15 --> change by joining to 2015
            - if first number is a multiple of 100 and the second one is shorter, like: 100 15 --> change by sum to 115
            - TODO: if first number is 2, 3, 4, 5, 6, 7, 8, 9 and second is 100 or 1000 --> change by multiplication, e.g. 3 1000 --> 3000
              (until then '3 1000' is joined to '31000')
        Returns a tuple: corrected text and a list of log entries, one for each step that changed the text.
        """
        translations = []

        def convert_number(t):
            return word_numbers.get(t.lower(), t)

        def word_then_digits(m):
            return f"{convert_number(m.group(1))} {m.group(3)}"

        def digits_then_word(m):
            return f"{m.group(1)} {convert_number(m.group(3))}"

        def leading_zero(m):
            # group 1 is the whitespace in front of 'nič' and must be kept
            return f"{m.group(1)}0{m.group(4)}"

        def join_numbers(m):
            x1 = int(m.group(2))
            x2 = int(m.group(4))
            if x1 != 0 and x1 % 100 == 0 and len(str(x2)) < len(str(x1)):
                return m.group(1) + str(x1 + x2)
            # the original digits are joined, so leading zeros are kept
            return m.group(1) + m.group(2) + m.group(4)

        # Each match is replaced with its own replacement, one match at a time
        steps = (
            ("R1", self.ptrn_numbers_to_translate2, word_then_digits),
            ("R1", self.ptrn_numbers_to_translate1, digits_then_word),
            ("R2", self.ptrn_leading_zero, leading_zero),
            ("R3", self.ptrn_numbers_to_join, join_numbers),
        )
        for rule, pattern, replacement in steps:
            old_final = final
            final = self._sub_until_stable(pattern, replacement, final)
            if old_final != final:
                translations.append(f"{rule}: {old_final} --> {final}")

        return final, translations

    def handle_word2digit(self, final):
        """
        Iterates over all tokens and replace words to digits.
        Tokens are split on whitespace and joined with a single space.
        """
        return " ".join(str(word_numbers.get(token.lower(), token)) for token in final.split())

    def handle_digit2word(self, final):
        """
        Iterates over all tokens and replace digits with words.
        Tokens are split on whitespace and joined with a single space.
        """
        return " ".join(str(number_words.get(token, token)) for token in final.split())

    def handle_word2symbols(self, final):
        """
        Converts words to symbols
        """
        return self.ptrn_word2symbol.sub(lambda m: word_symbols[m.group(0)], final)

    def handle_spp_codes(self, final):
        """
        Returns SPP code description for the provided SPP code.
        I expect the code is dictated as a standalone final.
        Returns:
        - final unchanged, when no SPP codes are loaded
        - '<CODE>: <description>', when the uppercased final without spaces is a known code
        - None otherwise
        """
        if not self.spp_codes:
            return final
        formated_spp_code = final.strip().upper().replace(" ", "")
        print(f"FINAL: {final}, FORMATED: {formated_spp_code}")
        spp_code_desc = self.spp_codes.get(formated_spp_code)
        if not spp_code_desc:
            return None
        return f"{formated_spp_code}: {spp_code_desc}"

    def run_tests(self, tid=None):
        """
        This script is used for testing functionality of the TextCorrector class.
        Add tests for each functionality.
        Args:
        - tid: index of the text test to run; all text tests are run when it is None
        """
        text = [
            "To je primer . napačnega teksta!",
            "Še tole. z malo !",
            "Ali   tole , kako?",
            "( danes smo bili na pregledu v UKC).",
            "Primer stavka, Ko se začne z veliko po vejici",
            "Še to, UKC je bolnica",
            "Ne damo: Apgar",
            "Neverjetno je bilo; Marko"
        ]

        ne = ['marko', 'aspirin', 'maja', 'december', 'klinični']

        print("Running text correction tests")
        print("==================================================")
        # 'is not None' so that tid=0 selects the first test
        if tid is not None:
            print(f"TEST {tid}: {text[tid]} ==> {self.correct_typos(text[tid])}")
        else:
            for c, t in enumerate(text):
                print(f"TEST {c}: {t} ==> {self.correct_typos(t)}")

        print("\nRunning named entity tests")
        print("==================================================")
        for c, t in enumerate(ne):
            print(f"TEST {c}: {t} ==> {self.is_named_entity(t)}")
        

if __name__ == "__main__":
    # Manual run: print the module folder, build a corrector with the bundled
    # named entity file and print the results of the built-in tests
    print(os.path.dirname(__file__))
    named_entities_path = os.path.join(basedir, 'resources', 'named_entities_ver20210621.pckl')
    tc = TextCorrector(named_entities_pickle_file=named_entities_path)
    tc.run_tests()
