"""
Use this module to compare TESTSET transcripts created with different models.
The script expects two result folders (configured in _CONFIG), each associated with one of the models in the comparison.
The subfolder structure of the two result folders must be the same.

Date: December 2021
Author: VITASIS
"""

import os

# Script configuration:
#   folderA       - result folder that is walked for transcript files
#   folderB       - result folder in which the counterpart of each file is looked up
#   exclude_files - file names that are skipped during the comparison
_CONFIG = {
    "folderA": "/Users/markob/Google Drive/FRI/LPT Projects/TESTSET/_results/json/NEMO_Conformer-BPE-20211206-KENLM",
    "folderB": "/Users/markob/Google Drive/FRI/LPT Projects/TESTSET/_results/json/NEMO_Conformer-BPE-20211206",
    "exclude_files": ["SC01.txt"]
}

def compare(fileA, fileB):
    """Align the word sequences of two transcript files.

    Each file is read as UTF-8 text, stripped of leading/trailing newline
    characters and split on single spaces. The two resulting token lists
    are then aligned with ``needle``.

    Args:
        fileA: Path to the first transcript file.
        fileB: Path to the second transcript file.

    Returns:
        Whatever ``needle`` returns for the two token lists.

    Raises:
        OSError: If either file cannot be opened or read.
    """
    # The context manager closes both files even when opening the second
    # one or reading either of them fails.
    with open(fileA, "r", encoding="utf-8") as inp1, \
            open(fileB, "r", encoding="utf-8") as inp2:
        # Tokenize the text itself. Converting the list returned by
        # readlines() with str() would tokenize the list's repr instead,
        # polluting tokens with brackets, quotes and escaped newlines.
        tokens_a = inp1.read().strip("\n").split(" ")
        tokens_b = inp2.read().strip("\n").split(" ")
    return needle(tokens_a, tokens_b)

def print_differences(aligned_data):
    """Print every aligned position at which the two sequences disagree.

    Args:
        aligned_data: Indexable result whose items 2 and 3 hold the two
            aligned sequences; positions are taken from item 2.
    """
    for pos, token_a in enumerate(aligned_data[2]):
        token_b = aligned_data[3][pos]
        if token_a != token_b:
            print(f"pos: {pos}, fileA: {token_a}, fileB: {token_b}")

# Scoring parameters for the aligner:
#   match_award      - returned by match_score() for two equal symbols
#   mismatch_penalty - returned by match_score() for two different non-gap symbols
#   gap_penalty      - charged for every gap position (no separate opening cost)
match_award      = 20
mismatch_penalty = -5
gap_penalty      = -5 # both for opening and extending

def zeros(shape):
    """Build a 2-D list of zeros.

    Args:
        shape: Indexable whose item 0 is the number of rows and item 1 the
            number of columns.

    Returns:
        A list of ``shape[0]`` independent row lists, each holding
        ``shape[1]`` zeros.
    """
    return [[0 for _ in range(shape[1])] for _ in range(shape[0])]

def match_score(alpha, beta):
    """Score a single pair of aligned symbols.

    Args:
        alpha: Symbol from the first sequence.
        beta: Symbol from the second sequence.

    Returns:
        ``match_award`` when the symbols are equal, ``gap_penalty`` when
        they differ and either one is the gap marker ``'-'``, and
        ``mismatch_penalty`` otherwise.
    """
    if alpha == beta:
        return match_award
    involves_gap = alpha == '-' or beta == '-'
    return gap_penalty if involves_gap else mismatch_penalty

def finalize(align1, align2):
    """Reverse a traced-back alignment and compute its identity and score.

    Both lists are reversed in place and the same list objects are
    returned, so callers that built them back-to-front get them in
    front-to-back order.

    Args:
        align1: Aligned symbols of the first sequence, in traceback order,
            with ``'-'`` marking gaps.
        align2: Aligned symbols of the second sequence, in traceback order,
            with ``'-'`` marking gaps. Must be at least as long as
            ``align1``.

    Returns:
        A tuple ``(identity, score, align1, align2)`` where ``identity`` is
        the fraction of positions holding equal symbols (``0.0`` for an
        empty alignment) and ``score`` is the summed per-position score.
    """
    align1.reverse()
    align2.reverse()

    score = 0
    identity = 0
    for pos, sym1 in enumerate(align1):
        sym2 = align2[pos]
        if sym1 == sym2:
            # Equal symbols count towards identity and earn the match score.
            identity += 1
            score += match_score(sym1, sym2)
        elif sym1 != '-' and sym2 != '-':
            # Two different real symbols: a mismatch.
            score += match_score(sym1, sym2)
        else:
            # Exactly one side is a gap.
            score += gap_penalty

    # An empty alignment has no positions to compare; report 0.0 instead
    # of dividing by zero.
    identity = float(identity) / len(align1) if align1 else 0.0

    return identity, score, align1, align2

def needle(seq1, seq2):
    """Globally align two sequences and return the finalized alignment.

    A dynamic-programming table is filled using ``match_score`` for paired
    symbols and ``gap_penalty`` for every gap position, then traced back
    from the bottom-right cell towards the top-left one.

    Args:
        seq1: First sequence (must support ``len`` and indexing).
        seq2: Second sequence (must support ``len`` and indexing).

    Returns:
        The result of ``finalize`` applied to the traced-back alignment.
    """
    rows, cols = len(seq1), len(seq2)

    # table[r][c] holds the best score for aligning seq1[:r] with seq2[:c].
    table = zeros((rows + 1, cols + 1))

    # First column / first row: a prefix aligned against nothing but gaps.
    for r in range(rows + 1):
        table[r][0] = gap_penalty * r
    for c in range(cols + 1):
        table[0][c] = gap_penalty * c

    for r in range(1, rows + 1):
        for c in range(1, cols + 1):
            table[r][c] = max(
                table[r - 1][c - 1] + match_score(seq1[r - 1], seq2[c - 1]),
                table[r - 1][c] + gap_penalty,
                table[r][c - 1] + gap_penalty,
            )

    # Walk back from the bottom-right cell until an edge is reached.
    # Ties are resolved in this order: diagonal, gap in seq2, gap in seq1.
    out1, out2 = [], []
    r, c = rows, cols
    while r > 0 and c > 0:
        here = table[r][c]
        if here == table[r - 1][c - 1] + match_score(seq1[r - 1], seq2[c - 1]):
            out1.append(seq1[r - 1])
            out2.append(seq2[c - 1])
            r -= 1
            c -= 1
        elif here == table[r - 1][c] + gap_penalty:
            out1.append(seq1[r - 1])
            out2.append('-')
            r -= 1
        elif here == table[r][c - 1] + gap_penalty:
            out1.append('-')
            out2.append(seq2[c - 1])
            c -= 1

    # Consume whatever is left along the edge that was reached.
    for k in range(r, 0, -1):
        out1.append(seq1[k - 1])
        out2.append('-')
    for k in range(c, 0, -1):
        out1.append('-')
        out2.append(seq2[k - 1])

    return finalize(out1, out2)

def water(seq1, seq2):
    """Locally align two sequences and return the finalized alignment.

    A dynamic-programming table with scores floored at zero is filled using
    ``match_score`` and ``gap_penalty``. A pointer table records which
    neighbour produced each cell. The traceback starts at the last cell
    holding the maximum score and follows pointers until a cell whose
    pointer is 0.

    Args:
        seq1: First sequence (must support ``len`` and indexing).
        seq2: Second sequence (must support ``len`` and indexing).

    Returns:
        The result of ``finalize`` applied to the traced-back alignment.
        The alignment handed to ``finalize`` is empty when no path is
        found, e.g. when either sequence is empty.
    """
    m, n = len(seq1), len(seq2)

    score = zeros((m + 1, n + 1))      # the DP table
    pointer = zeros((m + 1, n + 1))    # traceback directions, 0 = stop

    # Default the traceback start to the origin so that max_i/max_j are
    # defined even when a sequence is empty and the loops never run.
    max_score = 0
    max_i, max_j = 0, 0

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            score_diagonal = score[i - 1][j - 1] + match_score(seq1[i - 1], seq2[j - 1])
            score_up = score[i][j - 1] + gap_penalty
            score_left = score[i - 1][j] + gap_penalty
            best = max(0, score_left, score_up, score_diagonal)
            score[i][j] = best

            # Pointer priority on ties: diagonal (3), then 2, then 1;
            # 0 only when none of the three neighbours produced the value.
            if best == score_diagonal:
                pointer[i][j] = 3   # consume a symbol from both sequences
            elif best == score_up:
                pointer[i][j] = 2   # consume a symbol from seq2 only
            elif best == score_left:
                pointer[i][j] = 1   # consume a symbol from seq1 only
            else:
                pointer[i][j] = 0   # end of the path

            # ">=" keeps the last cell that reaches the maximum.
            if best >= max_score:
                max_i = i
                max_j = j
                max_score = best

    align1, align2 = [], []
    i, j = max_i, max_j

    # Follow the pointers back from the best-scoring cell.
    while pointer[i][j] != 0:
        if pointer[i][j] == 3:
            align1.append(seq1[i - 1])
            align2.append(seq2[j - 1])
            i -= 1
            j -= 1
        elif pointer[i][j] == 2:
            align1.append('-')
            align2.append(seq2[j - 1])
            j -= 1
        elif pointer[i][j] == 1:
            align1.append(seq1[i - 1])
            align2.append('-')
            i -= 1

    return finalize(align1, align2)

if __name__ == "__main__":

    # Walk folderA and compare each eligible transcript with the file found
    # at the same relative path inside folderB.
    for root, _dirs, files in os.walk(_CONFIG["folderA"]):
        for file in files:
            # Skip hidden/temporary files, non-.txt files and excluded names.
            if file[0] in ('.', '~') or file.split(".")[-1] != "txt" or file in _CONFIG["exclude_files"]:
                continue
            # Only files whose name starts with "DZ" are compared.
            if file[0:2] != "DZ":
                continue

            source_path = os.path.join(root, file)
            # Map through the relative path instead of a textual replace,
            # which breaks when folderA has a trailing separator or when
            # its text occurs again further down the path.
            relative_path = os.path.relpath(source_path, _CONFIG["folderA"])
            destf = os.path.join(_CONFIG["folderB"], relative_path)

            if os.path.isfile(destf):
                print(f"Comparing file {file}")
                result = compare(source_path, destf)
                print_differences(result)
            else:
                print(f"File {file} not found in the destination folder! Skipping.")

