import io
import os
from collections import defaultdict
from glob import glob

import pandas as pd
from mutagen import MutagenError
from mutagen.mp3 import MP3

# Collect per-speaker statistics from every parquet shard:
#   speaker_dur   -- summed MP3 duration in seconds
#   speaker_count -- number of samples whose audio decoded successfully
#   speaker_tags  -- per-tag sample counts
# Rows with missing or undecodable audio are skipped and counted (reported
# below) instead of being silently dropped.
files = sorted(glob('/home/ubuntu/10-speakers-latest-v1/data/*.parquet'))

speaker_dur = defaultdict(float)  # speaker_id -> total seconds of decoded audio
speaker_count = defaultdict(int)  # speaker_id -> number of decoded samples
speaker_tags = defaultdict(lambda: defaultdict(int))  # speaker_id -> tag -> count
skipped = 0  # rows whose audio was missing or could not be parsed as MP3

for fi, f in enumerate(files):
    # Report progress against the number of shards actually found, not a fixed 15.
    print(f'Shard {fi+1}/{len(files)}...', flush=True)
    df = pd.read_parquet(f)
    # Zipping column values avoids the per-row Series construction of iterrows().
    for spk, tag, audio in zip(df['speaker_id'], df['tag'], df['audio']):
        audio_bytes = None if audio is None else audio['bytes']
        if not audio_bytes:
            skipped += 1
            continue
        try:
            # mutagen accepts a seekable file-like object, so decode straight from
            # memory instead of round-tripping every sample through a shared /tmp file.
            dur = MP3(io.BytesIO(audio_bytes)).info.length
        except MutagenError:
            # MutagenError is the base of mutagen's parse errors (e.g. no MP3 header
            # found). Anything else -- including Ctrl-C -- is allowed to propagate.
            skipped += 1
            continue
        speaker_dur[spk] += dur
        speaker_count[spk] += 1
        speaker_tags[spk][tag] += 1

if skipped:
    print(f'Skipped {skipped} samples with missing or undecodable audio.', flush=True)

# Print one row per speaker (alphabetical): sample count, hours of audio, and
# tags ordered from most to least frequent, followed by a grand-total row.
rule = '=' * 90

print()
print(f'{"Speaker":<12} {"Samples":>8} {"Hours":>8}  Tags')
print(rule)

total_hrs = 0
for speaker in sorted(speaker_dur):
    speaker_hrs = speaker_dur[speaker] / 3600
    total_hrs += speaker_hrs
    # reverse=True keeps ties in their original (insertion) order, matching a
    # stable ascending sort on the negated count.
    ranked_tags = sorted(speaker_tags[speaker].items(), key=lambda kv: kv[1], reverse=True)
    tags_str = ', '.join(f'{tag_name}:{n}' for tag_name, n in ranked_tags)
    print(f'{speaker:<12} {speaker_count[speaker]:>8} {speaker_hrs:>8.2f}  {tags_str}')

print(rule)
grand_samples = sum(speaker_count.values())
print(f'{"TOTAL":<12} {grand_samples:>8} {total_hrs:>8.2f}')
