import math
from whisperx.conjunctions import get_conjunctions, get_comma

def normal_round(n):
    """Round ``n`` to the nearest integer, sending exact halves upward.

    A fractional part strictly below 0.5 rounds down to ``math.floor(n)``;
    anything else (including exactly 0.5) rounds up to ``math.ceil(n)``.
    This differs from the built-in ``round``, which rounds halves to even.
    """
    lower = math.floor(n)
    fractional_part = n - lower
    return lower if fractional_part < 0.5 else math.ceil(n)


def format_timestamp(seconds: float, is_vtt: bool = False):
    """Render a time offset in seconds as ``HH:MM:SS<sep>mmm``.

    Args:
        seconds: Non-negative offset in seconds; it is rounded to the
            nearest millisecond before formatting.
        is_vtt: When true the millisecond separator is ``.``; otherwise
            it is ``,``.

    Returns:
        The formatted timestamp string. The hours field is always present
        and zero-padded to at least two digits.

    Raises:
        AssertionError: If ``seconds`` is negative.
    """
    assert seconds >= 0, "non-negative timestamp expected"

    # Work in whole milliseconds and peel off each unit with divmod.
    total_ms = round(seconds * 1000.0)
    hours, remainder_ms = divmod(total_ms, 3_600_000)
    minutes, remainder_ms = divmod(remainder_ms, 60_000)
    whole_seconds, millis = divmod(remainder_ms, 1_000)

    decimal_mark = '.' if is_vtt else ','
    return f"{hours:02d}:{minutes:02d}:{whole_seconds:02d}{decimal_mark}{millis:03d}"



class SubtitlesProcessor:
    """Split transcription segments into subtitle cues and write them out.

    Each segment is a dict with ``'start'``, ``'end'`` and ``'text'`` keys
    and, optionally, a ``'words'`` list. Entries of ``'words'`` are dicts
    with a ``'word'`` key and optional ``'start'``/``'end'`` times; missing
    word times are estimated in place.

    Attributes:
        comma: Value returned by ``get_comma(lang)``; a word ending with it
            is a candidate split point.
        conjunctions: Set built from ``get_conjunctions(lang)``; a word whose
            lower-cased text is in it is a candidate split point.
        segments: The segments passed to the constructor.
        lang: Language code passed to the constructor.
        max_line_length: Character count at which a line is force-split.
        min_char_length_splitter: Minimum characters required on either side
            of a comma/conjunction split (and before a forced split).
        is_vtt: Whether output uses the WebVTT header and ``.`` separator.
    """

    # Languages for which the constructor tightens the line-length limits.
    _COMPLEX_SCRIPT_LANGUAGES = frozenset({
        'th', 'lo', 'my', 'km', 'am', 'ko', 'ja', 'zh', 'ti', 'ta',
        'te', 'kn', 'ml', 'hi', 'ne', 'mr', 'ar', 'fa', 'ur', 'ka',
    })
    # Languages whose words are joined and counted without a separating space.
    _NO_SPACE_LANGUAGES = frozenset({'zh', 'ja'})
    # Seconds assumed per character when a word's duration must be guessed.
    _SECONDS_PER_CHAR = 0.25

    def __init__(self, segments, lang, max_line_length=45, min_char_length_splitter=30, is_vtt=False):
        """Store the segments and the language-specific splitting settings.

        Args:
            segments: Sequence of segment dicts to process.
            lang: Language code used to look up the comma and conjunctions.
            max_line_length: Line length that triggers a forced split.
            min_char_length_splitter: Minimum characters around a split.
            is_vtt: Emit WebVTT-style output instead of SRT-style.

        Note:
            For languages in ``_COMPLEX_SCRIPT_LANGUAGES`` the two length
            arguments are overridden with 30 and 20 respectively.
        """
        self.comma = get_comma(lang)
        self.conjunctions = set(get_conjunctions(lang))
        self.segments = segments
        self.lang = lang
        self.max_line_length = max_line_length
        self.min_char_length_splitter = min_char_length_splitter
        self.is_vtt = is_vtt
        if self.lang in self._COMPLEX_SCRIPT_LANGUAGES:
            self.max_line_length = 30
            self.min_char_length_splitter = 20

    @staticmethod
    def _segment_words(segment):
        """Return the segment's word list, falling back to splitting its text.

        ``segment['text']`` is only read when no word list is available, so a
        segment that carries ``'words'`` but no ``'text'`` is accepted.
        """
        words = segment.get('words')
        if words is None:
            words = segment['text'].split()
        return words

    def estimate_timestamp_for_word(self, words, i, next_segment_start_time=None):
        """Fill in ``'start'`` and ``'end'`` for ``words[i]`` in place.

        The estimate is taken from the neighbouring words when they carry
        times, otherwise from the start of the next segment, otherwise both
        values are set to 0. Estimated times are never negative.

        Args:
            words: List of word dicts; ``words[i]`` is mutated.
            i: Index of the word to estimate.
            next_segment_start_time: Start time of the following segment.
                ``None`` (or 0) is treated as unknown.
        """
        word = words[i]
        has_prev_end = i > 0 and 'end' in words[i - 1]
        has_next_start = i < len(words) - 1 and 'start' in words[i + 1]

        if has_prev_end:
            prev_end = words[i - 1]['end']
            word['start'] = prev_end
            if has_next_start:
                word['end'] = words[i + 1]['start']
            elif next_segment_start_time:
                # Run up to the next segment when it is close; otherwise
                # stop half a second before it.
                if next_segment_start_time - prev_end <= 1:
                    word['end'] = next_segment_start_time
                else:
                    word['end'] = next_segment_start_time - 0.5
            else:
                word['end'] = word['start'] + len(word['word']) * self._SECONDS_PER_CHAR

        elif has_next_start:
            next_start = words[i + 1]['start']
            # Clamp: a long word right at the beginning would otherwise get a
            # negative start, which format_timestamp rejects.
            word['start'] = max(0, next_start - len(word['word']) * self._SECONDS_PER_CHAR)
            word['end'] = next_start

        elif next_segment_start_time:
            # Clamp for the same reason when the next segment starts early.
            word['start'] = max(0, next_segment_start_time - 1)
            word['end'] = max(0, next_segment_start_time - 0.5)

        else:
            word['start'] = 0
            word['end'] = 0

    def process_segments(self, advanced_splitting=True):
        """Convert ``self.segments`` into a list of subtitle dicts.

        Args:
            advanced_splitting: When true, each segment is split into several
                cues at the computed split points. When false, each segment
                becomes exactly one cue (missing word times are still
                estimated in place).

        Returns:
            A list of dicts with ``'start'``, ``'end'`` and ``'text'`` keys.
        """
        subtitles = []
        segment_count = len(self.segments)
        for seg_idx, segment in enumerate(self.segments):
            has_next = seg_idx + 1 < segment_count
            next_segment_start_time = self.segments[seg_idx + 1]['start'] if has_next else None

            if advanced_splitting:
                split_points = self.determine_advanced_split_points(segment, next_segment_start_time)
                subtitles.extend(
                    self.generate_subtitles_from_split_points(segment, split_points, next_segment_start_time)
                )
            else:
                # A segment without word-level data is still a valid cue.
                words = segment.get('words') or []
                for word_idx, word in enumerate(words):
                    if 'start' not in word or 'end' not in word:
                        self.estimate_timestamp_for_word(words, word_idx, next_segment_start_time)

                subtitles.append({
                    'start': segment['start'],
                    'end': segment['end'],
                    'text': segment['text'],
                })

        return subtitles

    def determine_advanced_split_points(self, segment, next_segment_start_time=None):
        """Compute the word indices after which ``segment`` should be split.

        A split is placed (a) at the midpoint of the current line once it
        reaches ``max_line_length``, (b) after a word ending with the comma,
        or (c) before a conjunction. Cases (b) and (c) require at least
        ``min_char_length_splitter`` characters on both sides. Word dicts
        lacking times get them estimated in place as a side effect.

        Args:
            segment: Segment dict; uses ``'words'`` or, failing that, the
                whitespace-split ``'text'``.
            next_segment_start_time: Forwarded to the timestamp estimator.

        Returns:
            List of inclusive end indices, one per split.
        """
        split_points = []
        last_split_point = 0
        char_count = 0

        words = self._segment_words(segment)
        add_space = 0 if self.lang in self._NO_SPACE_LANGUAGES else 1

        def word_length_of(item):
            # Every word, dict or plain string, costs its text plus the
            # separator, so totals agree with the per-word increments below.
            text = item['word'] if isinstance(item, dict) else item
            return len(text) + add_space

        char_count_after = sum(word_length_of(item) for item in words)

        for i, word in enumerate(words):
            word_text = word['word'] if isinstance(word, dict) else word
            word_length = len(word_text) + add_space
            char_count += word_length
            char_count_after -= word_length

            char_count_before = char_count - word_length

            if isinstance(word, dict) and ('start' not in word or 'end' not in word):
                self.estimate_timestamp_for_word(words, i, next_segment_start_time)

            has_room_before = char_count_before >= self.min_char_length_splitter
            has_room_after = char_count_after >= self.min_char_length_splitter

            if char_count >= self.max_line_length:
                if has_room_before:
                    midpoint = normal_round((last_split_point + i) / 2)
                    split_points.append(midpoint)
                    last_split_point = midpoint + 1
                    # Restart the count with the words carried onto the new line.
                    char_count = sum(word_length_of(words[j]) for j in range(last_split_point, i + 1))

            elif word_text.endswith(self.comma) and has_room_before and has_room_after:
                split_points.append(i)
                last_split_point = i + 1
                char_count = 0

            elif word_text.lower() in self.conjunctions and has_room_before and has_room_after:
                # Split before the conjunction so it opens the next line.
                split_points.append(i - 1)
                last_split_point = i
                char_count = word_length

        return split_points

    def generate_subtitles_from_split_points(self, segment, split_points, next_start_time=None):
        """Build subtitle dicts for ``segment`` from its split points.

        When words are dicts, cue times come from the first and last word of
        each fragment. When words are plain strings, the segment duration is
        shared out in proportion to each fragment's word count. A cue's end
        is stretched to the start of whatever follows when the gap is at most
        0.8 seconds.

        Args:
            segment: Segment dict with ``'start'``, ``'end'`` and either
                ``'words'`` or ``'text'``.
            split_points: Inclusive end indices of each fragment.
            next_start_time: Start time of the following segment, if any.

        Returns:
            A list of dicts with ``'start'``, ``'end'`` and ``'text'`` keys.
        """
        subtitles = []

        words = self._segment_words(segment)
        total_word_count = len(words)
        total_time = segment['end'] - segment['start']
        elapsed_time = segment['start']
        prefix = '' if self.lang in self._NO_SPACE_LANGUAGES else ' '
        start_idx = 0

        for split_point in split_points:
            fragment_words = words[start_idx:split_point + 1]

            if isinstance(fragment_words[0], dict):
                start_time = fragment_words[0]['start']
                end_time = fragment_words[-1]['end']
                next_start_time_for_word = words[split_point + 1]['start'] if split_point + 1 < len(words) else None
                if next_start_time_for_word and (next_start_time_for_word - end_time) <= 0.8:
                    end_time = next_start_time_for_word
                text = prefix.join(word['word'] for word in fragment_words)
            else:
                text = prefix.join(fragment_words).strip()
                current_duration = (len(fragment_words) / total_word_count) * total_time
                start_time = elapsed_time
                end_time = elapsed_time + current_duration
                elapsed_time += current_duration

            subtitles.append({
                'start': start_time,
                'end': end_time,
                'text': text,
            })

            start_idx = split_point + 1

        # Handle the last fragment
        if start_idx < len(words):
            fragment_words = words[start_idx:]

            if isinstance(fragment_words[0], dict):
                start_time = fragment_words[0]['start']
                end_time = fragment_words[-1]['end']
                text = prefix.join(word['word'] for word in fragment_words)
            else:
                text = prefix.join(fragment_words).strip()
                current_duration = (len(fragment_words) / total_word_count) * total_time
                start_time = elapsed_time
                end_time = elapsed_time + current_duration

            if next_start_time and (next_start_time - end_time) <= 0.8:
                end_time = next_start_time

            subtitles.append({
                'start': start_time,
                'end': end_time if end_time is not None else segment['end'],
                'text': text,
            })

        return subtitles

    def save(self, filename="subtitles.srt", advanced_splitting=True):
        """Process the segments and write the cues to ``filename``.

        The file is written as UTF-8. When ``is_vtt`` is set, a ``WEBVTT``
        header precedes the cues. Each cue is its 1-based index, a
        ``start --> end`` line and the stripped text, followed by a blank
        line. Cues are written for both splitting modes.

        Args:
            filename: Destination path.
            advanced_splitting: Forwarded to ``process_segments``.

        Returns:
            The number of cues written.
        """
        subtitles = self.process_segments(advanced_splitting)

        with open(filename, 'w', encoding='utf-8') as file:
            if self.is_vtt:
                file.write("WEBVTT\n\n")

            for idx, subtitle in enumerate(subtitles, 1):
                start_time = format_timestamp(subtitle['start'], self.is_vtt)
                end_time = format_timestamp(subtitle['end'], self.is_vtt)
                text = subtitle['text'].strip()
                file.write(f"{idx}\n")
                file.write(f"{start_time} --> {end_time}\n")
                file.write(text + "\n\n")

        return len(subtitles)