o
    ॵiN                     @   s  d dl Z d dlZd dlmZ d dlZd dlZd dlmZ d dlm	Z	m
Z
 d dlZd dlZd dlZd dlm  mZ d dlZd dlmZ d dlmZ d dlmZmZ d dlmZ d dlmZ d d	lm Z  d d
l!m"Z" d dl#m$Z$ d dl%m&Z& d dl'm(Z( e( Z)e$j*ej+dG dd de"Z,dS )    N)datetime)DictOptional)Trainers)NeRFReconPreprocessor)BlenderDatasetColmapDataset)	NeRFModel)ObjectSegmenter)PSNR)BaseTrainer)TRAINERS)	ModelFile)
get_logger)module_namec                       s   e Zd ZdZ												ddedef fddZdd Zd	d
 Zdd Z	dde	e de
eef fddZdd Zdd ZdddZdd ZdddZ  ZS ) NeRFReconAccTrainera  initialize the acceleration version of nerf reconstruction model for object.

    Args:
        model (str): the model path.
        cfg_file (str): cfg json file
        data_type (str): only support 'blender' or 'colmap'
        use_mask (bool): whether use mask of objects, default True
        max_step (int): max train steps, default 30000
        train_num_rays (int): init number of rays in training, default 256
        num_samples_per_ray (int): sampling numbers for each ray, default 1024
        max_train_num_rays (int): max number of rays in training, default 8192
        test_ray_chunk (int): chunk size for rendering, default 1024
        dynamic_ray_sampling (bool): whether use dynamic ray sampling when training, default True
        max_size (int): max size of (width, height) when training, default 800
        n_test_traj_steps (int): number of testing images, default 120
        log_every_n_steps (int): print log info every n steps, default 1000
        work_dir (str): dir to save ckpt and other results
        render_images (bool): whether to render test image after training
        save_mesh (bool): whether to save the reconstructed mesh of object, default False
        save_ckpt (bool): whether to save the checkpoints in data_dir, default False
        network_cfg (dict): args of network config
        match_type (str): colmap feature matching type, only for colmap data
        frame_count (str): extract number of frames, only for video input
        use_distortion (bool): whether run colmap undistortion
    Nmodelcfg_filec                    s  |  |}|| _|d u rt|tj}t | tj	
 s"tdi | _|   | D ]	\}}|| j|< q-|d ur@|| jd< |d urI|| jd< |d urR|| jd< |d ur[|| jd< |d urd|| jd< |d urm|| jd< |	d urv|	| jd< |
d ur|
| jd	< |d ur|| jd
< |d ur|| jd< |d ur|| jd< | jd | _| jdkr| jdkrtd| j| jd | _| jd | _| jd | _| jd | _| j| j | _| jd | _| jd | _| jd | _| jd | _| jd | _| jd | _| jd	 | _| jd | _| jd
 | _| jd | _| jd | _ | jd | _!| jd | _"| jd | _#t$%d| j t&j'(| js<t&)| j t*| j| j| j!| j"| j#d| _+| jrb| jdkrbt&j'| jd}t,|| _-| jdkryd| _.d| j d< d| _/d| j d< n5| jdkrd | _.| j| _| j| _d | j d< | jrd| _/d| j d< t$%d! nd"| _/d"| j d< t$%d# t$%| j  t0| j | j| jd$	 | _1tj2j3| j14 d%d&d'| _5tj	j67d(| _8tj2j9j:| j5| jd) | jd* d+ | jd, d- gd.d/| _;t< | _=| >d0 d S )1NzGPU is required	data_typeuse_maskmax_steptrain_num_raysmax_train_num_rayslog_every_n_stepswork_dirrender_images	save_ckptframe_countuse_distortionblendercolmapz%data type {} is not support currentlynum_samples_per_raytest_ray_chunkdynamic_ray_samplingmax_sizen_test_traj_steps	save_meshnetwork_cfg
match_typez	params:{})r   r   r(   r   r   z
matting.pb)   r)   g      ?radiuswhite
backgroundg      ?zrun nerf with mask datarandomzrun nerf without mask data)r!   r"   g{Gz?gV瞯<)lrepsi            	   
   gQ?)
milestonesgamma*   )?get_or_download_model_dir	model_dirospjoinr   CONFIGURATIONsuper__init__torchcudais_available	Exceptionparams_override_params_from_fileitemsr   formatr   r   r   r!   train_num_samplesr   r"   r#   r$   r%   r   r   r&   r   r   r'   r(   r   r   loggerinfoospathexistsmakedirsr   preprocessorr
   	segmenterimg_whr,   r	   r   optimAdam
parameters	optimizeramp
GradScalergrad_scalerlr_schedulerMultiStepLR	schedulerr   
criterionsset_random_seed)selfr   r   r   r   r   r   r   r   r   r   r   r   r   argskwargskeyvaluesegment_path	__class__ a/home/ubuntu/.local/lib/python3.10/site-packages/modelscope/trainers/cv/nerf_recon_acc_trainer.pyr>   =   s   


















zNeRFReconAccTrainer.__init__c                 C   s$   t | tj | t| d S N)r-   seednpr?   manual_seed)r]   rh   re   re   rf   r\      s   
z#NeRFReconAccTrainer.set_random_seedc                 C   s  | j d d | jd< | j d d | jd< | j d d | jd< | j d d | jd< | j d d | jd< | j d d | jd< | j d d | jd< | j d d	 | jd	< | j d d
 | jd
< | j d d | jd< | j d d | jd< | j d d | jd< | j d d | jd< | j d d | jd< | j d d | jd< | j d d | jd< | j d d | jd< | j d d | jd< | j d d | jd< d S )Ntrainr   r   r   r   r   r#   r   r   r   r   r   r!   r"   r$   r%   r&   r'   rN   r(   r   r   )cfgrC   )r]   re   re   rf   rD      s>   




z.NeRFReconAccTrainer._override_params_from_filec                 O   s  t d i }| jdkr d|vrtd|d }||d< d|d< | jdkrod|v r7|d }| j|d< ||d< n8d|v rk|d }tj|d}tj|rgt		d	
|}t|d
kr^td||d< d|d< ntdtd| |}|d }t d
| | jdkr| jrtjtj|drtj|d}	tj|d}
ntj|d}	tj|d}
tj|
dd t		d	
|	t		d
|	 }|D ]}t|}| j|}tj|
tj|}t|| qt d | jdkrt|d| j| j| jd| _t|d| j| jd| _n$| jdkr0t|d| j| j| j| jd| _t|d| j| j| j| jd| _d
}t }|| jk rt t| jD ]}| j!"  | j| }| j!#| |d $ }|d $ }| !|}|d d
kroqCd }| j%r| j&t'|d  }t(| j| }t)t(| jd! |d"  | j*| _| j+| j t,-|d# |d$  ||d$  }||7 }| .|d# |}| j/0  | j12|3  | j/4  | j54  || j6 d
krt | }t d%|d&d'| d(|d)d*| jd+d,|d)d- |d.7 }qC|| jk s<| j7r(tj| jd/}t89| j| j!: | j/: d0| t d1
| | j;r;tj| jd2}| <| j| t d3 d S )4Nz"Begin nerf reconstruction trainingr   data_dirz.Please specify data_dir of nerf_synthetic data video_input_pathr    imagesz{}/*.*gr   zno images found in images dirz images dir not found in data_dirz;Please specify video_path or images path for colmap processz3nerf reconstruction preprocess done, data_dir is {}
preprocesszpreprocess/imageszpreprocess/masksmasksTexist_okz{}/*.*Gzsegment images done!rk   )root_fpsplitrP   num_rayscolor_bkgd_augtest)ru   rv   rP   rw   )ru   rv   rP   r$   rw   rx   )ru   rv   rP   r$   rw   r%   rayspixelsnum_samplesg        g?g?comp_rgb
rays_validzelapsed_time=z.2fz	s | step=z | loss=z.4fz | train/num_rays=dz |PSNR=    z
model.ckpt)global_stepnetwork_state_dictoptimizer_state_dictz"save checkpoints done, saved as {}z
render.mp4zNeRF reconstruction finish)=rH   rI   r   rB   r   rJ   rK   r;   rL   globrF   lenrN   r   rM   cv2imreadrO   run_maskbasenameimwriter   rP   r   r,   train_datasettest_datasetr   r$   r%   timer   ranger   rk   update_stepr@   r#   rG   sumintminr   update_num_raysFsmooth_l1_lossr[   rT   	zero_gradrW   scalebackwardsteprZ   r   r   r?   save
state_dictr   render_video)r]   r^   r_   processor_inputrm   ro   
images_dir
image_listprocessor_output	image_dirsave_mask_dirimg_listimg_pathimgmaskoutpathr   ticidatarz   r{   outlosstempr   loss_rgbpsnrelapsed_timesave_ckpt_namesave_video_pathre   re   rf   rk      s  










		






*zNeRFReconAccTrainer.traincheckpoint_pathreturnc                 O   s   t d)a  evaluate a dataset

        evaluate a dataset via a specific model from the `checkpoint_path` path, if the `checkpoint_path`
        does not exist, read from the config file.

        Args:
            checkpoint_path (Optional[str], optional): the model path. Defaults to None.

        Returns:
            Dict[str, float]: the results about the evaluation
            Example:
            {"accuracy": 0.5091743119266054, "f1": 0.673780487804878}
        z#evaluate is not supported currently)NotImplementedError)r]   r   r^   r_   re   re   rf   evaluate}  s   zNeRFReconAccTrainer.evaluatec                 C   s  | j   t  d}ttt| jD ]S}| j| }|d  }|d  }|d }| j 	|}	|| 
|	d |7 }|\}
}|	d ||
d}tj|d}tj|dd	 tj||d
d}| || q| || td|t| j  td| | jr| j  }tj|d}| ||d |d  td| W d    d S W d    d S 1 sw   Y  d S )Nr   rz   r{   image_whr}   r1   renderTrs   r   z.pngztest psnr: {}z#save render video done. saved as {}z
render.objv_pos	t_pos_idxz"save render mesh done. saved as {})r   evalr?   no_gradtqdmr   r   r   r@   	inferencer[   viewrJ   rK   r;   rM   
save_image
save_videorH   rI   rF   r&   
isosurfacesave_obj)r]   save_dirr   r   r   r   rz   r{   r   r   WHr   save_img_dirsave_img_pathmeshsave_mesh_pathre   re   rf   r     s>   



"z NeRFReconAccTrainer.render_videoc                 C   s\   | dd  }|d tj}t|tj}t	j
|}t	j|dd t|| d S )Nr   r   g     o@Trs   )clipcpunumpyastyperi   uint8r   cvtColorCOLOR_RGB2BGRrJ   rK   dirnamerM   r   )r]   filenamer   r   re   re   rf   r     s   zNeRFReconAccTrainer.save_image   c                 C   sz   t  d|}t|dd d}dd |D }|d j\}}}t|tjd |||fd	}	|D ]}
|	|
 q/|	  d S )
Nz{}/*.pngc                 S   s   t tj| d d S )N)r   rJ   rK   r   )fre   re   rf   <lambda>  s    z0NeRFReconAccTrainer.save_video.<locals>.<lambda>)r`   c                 S   s   g | ]}t |qS re   )r   r   ).0r   re   re   rf   
<listcomp>  s    z2NeRFReconAccTrainer.save_video.<locals>.<listcomp>r   mp4vT)	r   rF   sortedshaper   VideoWriterVideoWriter_fourccwriterelease)r]   r   img_dirfps	img_pathsimgsr   r   _writerr   re   re   rf   r     s   
zNeRFReconAccTrainer.save_videoc           
   
   C   s  t |dy}|D ]}|d|d |d |d  q|d ur=t|t|ks)J |D ]}|d|d d|d   q+tt|D ]2}|d td	D ]!}	|d
t|| |	 d |d u rbdn	t|| |	 d f  qN|d qCW d    d S 1 sw   Y  d S )Nwzv {} {} {} 
r   r   r0   z
vt {} {} 
g      ?zf r1   z %s/%srn   
)openr   rF   r   r   str)
r]   r   r   r   v_tex	t_tex_idxr   vr   jre   re   rf   	write_obj  s(   " 
"zNeRFReconAccTrainer.write_objc                 C   s   |   }|   }tj|}tj|dd |d ur8|d ur8|   }|   }| ||||| d S | ||||| d S )NTrs   )r   r   rJ   rK   r   rM   r   )r]   r   r   r   r   r   r   re   re   rf   r     s   zNeRFReconAccTrainer.save_obj)NNNNNNNNNNNNrg   )r   )NN)__name__
__module____qualname____doc__r   r>   r\   rD   rk   r   r   floatr   r   r   r   r   r   __classcell__re   re   rc   rf   r   !   sF      

!
r   )-r   rJ   os.pathrK   r:   r-   r   r   typingr   r   r   r   ri   r?   torch.nn.functionalnn
functionalr   r   modelscope.metainfor   #modelscope.models.cv.nerf_recon_accr   ;modelscope.models.cv.nerf_recon_acc.dataloader.nerf_datasetr   r   0modelscope.models.cv.nerf_recon_acc.network.nerfr	   5modelscope.models.cv.nerf_recon_acc.network.segmenterr
   1modelscope.models.cv.nerf_recon_acc.network.utilsr   modelscope.trainers.baser   modelscope.trainers.builderr   modelscope.utils.constantr   modelscope.utils.loggerr   rH   register_modulenerf_recon_accr   re   re   re   rf   <module>   s2   