From 3c9d1865ca9e25554d93cb5fd2a2f6c6b3700c2a Mon Sep 17 00:00:00 2001 From: nick Date: Sat, 20 Jul 2024 00:15:41 -0700 Subject: [PATCH] video node --- comfy-nodes/external_video.py | 420 +++++++++++++++++++++++++++------- 1 file changed, 340 insertions(+), 80 deletions(-) diff --git a/comfy-nodes/external_video.py b/comfy-nodes/external_video.py index bbdf444..a8e2d2d 100644 --- a/comfy-nodes/external_video.py +++ b/comfy-nodes/external_video.py @@ -1,10 +1,15 @@ -# credit goes to https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite and is meant to work with +# credit goes to https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite +# Intended to work with https://github.com/NicholasKao1029/ComfyUI-VideoHelperSuite/tree/main import os import itertools import numpy as np import torch +from typing import Union +from torch import Tensor import cv2 +import psutil +from collections.abc import Mapping import folder_paths from comfy.utils import common_upscale @@ -90,13 +95,25 @@ if gifski_path is None: gifski_path = shutil.which("gifski") +def is_safe_path(path): + if "VHS_STRICT_PATHS" not in os.environ: + return True + basedir = os.path.abspath(".") + try: + common_path = os.path.commonpath([basedir, path]) + except: + # Different drive on windows + return False + return common_path == basedir + + def get_sorted_dir_files_from_directory( directory: str, skip_first_images: int = 0, select_every_nth: int = 1, extensions: Iterable = None, ): - directory = directory.strip() + directory = strip_path(directory) dir_files = os.listdir(directory) dir_files = sorted(dir_files) dir_files = [os.path.join(directory, x) for x in dir_files] @@ -177,18 +194,59 @@ def requeue_workflow(requeue_required=(-1, True)): def get_audio(file, start_time=0, duration=0): - args = [ffmpeg_path, "-v", "error", "-i", file] + args = [ffmpeg_path, "-i", file] if start_time > 0: args += ["-ss", str(start_time)] if duration > 0: args += ["-t", str(duration)] try: + # TODO: scan for sample rate and maintain res = subprocess.run( - args + ["-f", "wav", "-"], stdout=subprocess.PIPE, check=True - ).stdout + args + ["-f", "f32le", "-"], capture_output=True, check=True + ) + audio = torch.frombuffer(bytearray(res.stdout), dtype=torch.float32) + match = re.search(", (\\d+) Hz, (\\w+), ", res.stderr.decode("utf-8")) except subprocess.CalledProcessError as e: - return False - return res + raise Exception( + f"VHS failed to extract audio from {file}:\n" + e.stderr.decode("utf-8") + ) + if match: + ar = int(match.group(1)) + # NOTE: Just throwing an error for other channel types right now + # Will deal with issues if they come + ac = {"mono": 1, "stereo": 2}[match.group(2)] + else: + ar = 44100 + ac = 2 + audio = audio.reshape((-1, ac)).transpose(0, 1).unsqueeze(0) + return {"waveform": audio, "sample_rate": ar} + + +class LazyAudioMap(Mapping): + def __init__(self, file, start_time, duration): + self.file = file + self.start_time = start_time + self.duration = duration + self._dict = None + + def __getitem__(self, key): + if self._dict is None: + self._dict = get_audio(self.file, self.start_time, self.duration) + return self._dict[key] + + def __iter__(self): + if self._dict is None: + self._dict = get_audio(self.file, self.start_time, self.duration) + return iter(self._dict) + + def __len__(self): + if self._dict is None: + self._dict = get_audio(self.file, self.start_time, self.duration) + return len(self._dict) + + +def lazy_get_audio(file, start_time=0, duration=0): + return LazyAudioMap(file, start_time, duration) def lazy_eval(func): @@ -230,6 +288,19 @@ def validate_sequence(path): return False +def strip_path(path): + # This leaves whitespace inside quotes and only a single " + # thus ' ""test"' -> '"test' + # consider path.strip(string.whitespace+"\"") + # or weightier re.fullmatch("[\\s\"]*(.+?)[\\s\"]*", path).group(1) + path = path.strip() + if path.startswith('"'): + path = path[1:] + if path.endswith('"'): + path = path[:-1] + return path + + def hash_path(path): if path is None: return "input" @@ -286,6 +357,145 @@ def target_size( return (width, height) +def validate_index( + index: int, + length: int = 0, + is_range: bool = False, + allow_negative=False, + allow_missing=False, +) -> int: + # if part of range, do nothing + if is_range: + return index + # otherwise, validate index + # validate not out of range - only when latent_count is passed in + if length > 0 and index > length - 1 and not allow_missing: + raise IndexError(f"Index '{index}' out of range for {length} item(s).") + # if negative, validate not out of range + if index < 0: + if not allow_negative: + raise IndexError(f"Negative indeces not allowed, but was '{index}'.") + conv_index = length + index + if conv_index < 0 and not allow_missing: + raise IndexError( + f"Index '{index}', converted to '{conv_index}' out of range for {length} item(s)." + ) + index = conv_index + return index + + +def convert_to_index_int( + raw_index: str, + length: int = 0, + is_range: bool = False, + allow_negative=False, + allow_missing=False, +) -> int: + try: + return validate_index( + int(raw_index), + length=length, + is_range=is_range, + allow_negative=allow_negative, + allow_missing=allow_missing, + ) + except ValueError as e: + raise ValueError(f"Index '{raw_index}' must be an integer.", e) + + +def convert_str_to_indexes( + indexes_str: str, length: int = 0, allow_missing=False +) -> list[int]: + if not indexes_str: + return [] + int_indexes = list(range(0, length)) + allow_negative = length > 0 + chosen_indexes = [] + # parse string - allow positive ints, negative ints, and ranges separated by ':' + groups = indexes_str.split(",") + groups = [g.strip() for g in groups] + for g in groups: + # parse range of indeces (e.g. 2:16) + if ":" in g: + index_range = g.split(":", 2) + index_range = [r.strip() for r in index_range] + + start_index = index_range[0] + if len(start_index) > 0: + start_index = convert_to_index_int( + start_index, + length=length, + is_range=True, + allow_negative=allow_negative, + allow_missing=allow_missing, + ) + else: + start_index = 0 + end_index = index_range[1] + if len(end_index) > 0: + end_index = convert_to_index_int( + end_index, + length=length, + is_range=True, + allow_negative=allow_negative, + allow_missing=allow_missing, + ) + else: + end_index = length + # support step as well, to allow things like reversing, every-other, etc. + step = 1 + if len(index_range) > 2: + step = index_range[2] + if len(step) > 0: + step = convert_to_index_int( + step, + length=length, + is_range=True, + allow_negative=True, + allow_missing=True, + ) + else: + step = 1 + # if latents were passed in, base indeces on known latent count + if len(int_indexes) > 0: + chosen_indexes.extend(int_indexes[start_index:end_index][::step]) + # otherwise, assume indeces are valid + else: + chosen_indexes.extend(list(range(start_index, end_index, step))) + # parse individual indeces + else: + chosen_indexes.append( + convert_to_index_int( + g, + length=length, + allow_negative=allow_negative, + allow_missing=allow_missing, + ) + ) + return chosen_indexes + + +def select_indexes(input_obj: Union[Tensor, list], idxs: list): + if type(input_obj) == Tensor: + return input_obj[idxs] + else: + return [input_obj[i] for i in idxs] + + +def select_indexes_from_str( + input_obj: Union[Tensor, list], indexes: str, err_if_missing=True, err_if_empty=True +): + real_idxs = convert_str_to_indexes( + indexes, len(input_obj), allow_missing=not err_if_missing + ) + if err_if_empty and len(real_idxs) == 0: + raise Exception(f"Nothing was selected based on indexes found in '{indexes}'.") + return select_indexes(input_obj, real_idxs) + + +### + + def cv_frame_generator( video, force_rate, @@ -295,9 +505,10 @@ def cv_frame_generator( meta_batch=None, unique_id=None, ): - video_cap = cv2.VideoCapture(video) + video_cap = cv2.VideoCapture(strip_path(video)) if not video_cap.isOpened(): raise ValueError(f"{video} could not be loaded with cv.") + pbar = None # extract video metadata fps = video_cap.get(cv2.CAP_PROP_FPS) @@ -319,6 +530,8 @@ def cv_frame_generator( target_frame_time = 1 / force_rate yield (width, height, fps, duration, total_frames, target_frame_time) + if meta_batch is not None: + yield min(frame_load_cap, total_frames) time_offset = target_frame_time - base_frame_time while video_cap.isOpened(): @@ -349,7 +562,8 @@ def cv_frame_generator( frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # convert frame to comfyui's expected format # TODO: frame contains no exif information. Check if opencv2 has already applied - frame = np.array(frame, dtype=np.float32) / 255.0 + frame = np.array(frame, dtype=np.float32) + torch.from_numpy(frame).div_(255) if prev_frame is not None: inp = yield prev_frame if inp is not None: @@ -357,6 +571,8 @@ def cv_frame_generator( return prev_frame = frame frames_added += 1 + if pbar is not None: + pbar.update_absolute(frames_added, frame_load_cap) # if cap exists and we've reached it, stop processing frames if frame_load_cap > 0 and frames_added >= frame_load_cap: break @@ -367,6 +583,17 @@ def cv_frame_generator( yield prev_frame +def batched(it, n): + while batch := tuple(itertools.islice(it, n)): + yield batch + + +def batched_vae_encode(images, vae, frames_per_batch): + for batch in batched(images, frames_per_batch): + image_batch = torch.from_numpy(np.array(batch)) + yield from vae.encode(image_batch).numpy() + + def load_video_cv( video: str, force_rate: int, @@ -378,6 +605,8 @@ def load_video_cv( select_every_nth: int, meta_batch=None, unique_id=None, + memory_limit_mb=None, + vae=None, ): if meta_batch is None or unique_id not in meta_batch.inputs: gen = cv_frame_generator( @@ -401,30 +630,89 @@ def load_video_cv( total_frames, target_frame_time, ) + meta_batch.total_frames = min(meta_batch.total_frames, next(gen)) else: (gen, width, height, fps, duration, total_frames, target_frame_time) = ( meta_batch.inputs[unique_id] ) - if meta_batch is not None: - gen = itertools.islice(gen, meta_batch.frames_per_batch) + memory_limit = None + if memory_limit_mb is not None: + memory_limit *= 2**20 + else: + # TODO: verify if garbage collection should be performed here. + # leaves ~128 MB unreserved for safety + try: + memory_limit = ( + psutil.virtual_memory().available + psutil.swap_memory().free + ) - 2**27 + except: + print( + "Failed to calculate available memory. Memory load limit has been disabled" + ) + if memory_limit is not None: + if vae is not None: + # space required to load as f32, exist as latent with wiggle room, decode to f32 + max_loadable_frames = int( + memory_limit // (width * height * 3 * (4 + 4 + 1 / 10)) + ) + else: + # TODO: use better estimate for when vae is not None + # Consider completely ignoring for load_latent case? + max_loadable_frames = int(memory_limit // (width * height * 3 * (0.1))) + if meta_batch is not None: + if meta_batch.frames_per_batch > max_loadable_frames: + raise RuntimeError( + f"Meta Batch set to {meta_batch.frames_per_batch} frames but only {max_loadable_frames} can fit in memory" + ) + gen = itertools.islice(gen, meta_batch.frames_per_batch) + else: + original_gen = gen + gen = itertools.islice(gen, max_loadable_frames) + downscale_ratio = getattr(vae, "downscale_ratio", 8) + frames_per_batch = (1920 * 1080 * 16) // (width * height) or 1 + if force_size != "Disabled" or vae is not None: + new_size = target_size( + width, height, force_size, custom_width, custom_height, downscale_ratio + ) + if new_size[0] != width or new_size[1] != height: - # Some minor wizardry to eliminate a copy and reduce max memory by a factor of ~2 - images = torch.from_numpy( - np.fromiter(gen, np.dtype((np.float32, (height, width, 3)))) - ) + def rescale(frame): + s = torch.from_numpy( + np.fromiter(frame, np.dtype((np.float32, (height, width, 3)))) + ) + s = s.movedim(-1, 1) + s = common_upscale(s, new_size[0], new_size[1], "lanczos", "center") + return s.movedim(1, -1).numpy() + + gen = itertools.chain.from_iterable( + map(rescale, batched(gen, frames_per_batch)) + ) + else: + new_size = width, height + if vae is not None: + gen = batched_vae_encode(gen, vae, frames_per_batch) + vw, vh = new_size[0] // downscale_ratio, new_size[1] // downscale_ratio + images = torch.from_numpy(np.fromiter(gen, np.dtype((np.float32, (4, vh, vw))))) + else: + # Some minor wizardry to eliminate a copy and reduce max memory by a factor of ~2 + images = torch.from_numpy( + np.fromiter(gen, np.dtype((np.float32, (new_size[1], new_size[0], 3)))) + ) + if meta_batch is None and memory_limit is not None: + try: + next(original_gen) + raise RuntimeError( + f"Memory limit hit after loading {len(images)} frames. Stopping execution." + ) + except StopIteration: + pass if len(images) == 0: raise RuntimeError("No frames generated") - if force_size != "Disabled": - new_size = target_size(width, height, force_size, custom_width, custom_height) - if new_size[0] != width or new_size[1] != height: - s = images.movedim(-1, 1) - s = common_upscale(s, new_size[0], new_size[1], "lanczos", "center") - images = s.movedim(1, -1) # Setup lambda for lazy audio capture - audio = lambda: get_audio( + audio = lazy_get_audio( video, skip_first_frames * target_frame_time, frame_load_cap * target_frame_time * select_every_nth, @@ -440,11 +728,13 @@ def load_video_cv( "loaded_fps": 1 / target_frame_time, "loaded_frame_count": len(images), "loaded_duration": len(images) * target_frame_time, - "loaded_width": images.shape[2], - "loaded_height": images.shape[1], + "loaded_width": new_size[0], + "loaded_height": new_size[1], } - - return (images, len(images), lazy_eval(audio), video_info) + if vae is None: + return (images, len(images), audio, video_info, None) + else: + return (None, len(images), audio, video_info, {"samples": images}) class ComfyUIDeployExternalVideo: @@ -457,68 +747,38 @@ class ComfyUIDeployExternalVideo: file_parts = f.split(".") if len(file_parts) > 1 and (file_parts[-1] in video_extensions): files.append(f) - return { - "required": { - "input_id": ( - "STRING", - {"multiline": False, "default": "input_video"}, - ), - "force_rate": ("INT", {"default": 0, "min": 0, "max": 60, "step": 1}), - "force_size": ( - [ - "Disabled", - "Custom Height", - "Custom Width", - "Custom", - "256x?", - "?x256", - "256x256", - "512x?", - "?x512", - "512x512", - ], - ), - "custom_width": ( - "INT", - {"default": 512, "min": 0, "max": DIMMAX, "step": 8}, - ), - "custom_height": ( - "INT", - {"default": 512, "min": 0, "max": DIMMAX, "step": 8}, - ), - "frame_load_cap": ( - "INT", - {"default": 0, "min": 0, "max": BIGMAX, "step": 1}, - ), - "skip_first_frames": ( - "INT", - {"default": 0, "min": 0, "max": BIGMAX, "step": 1}, - ), - "select_every_nth": ( - "INT", - {"default": 1, "min": 1, "max": BIGMAX, "step": 1}, - ), - }, - "optional": { - "meta_batch": ("VHS_BatchManager",), - "default_value": (sorted(files),), - }, - "hidden": {"unique_id": "UNIQUE_ID"}, - } + return {"required": { + "input_id": ( + "STRING", + {"multiline": False, "default": "input_video"}, + ), + "force_rate": ("INT", {"default": 0, "min": 0, "max": 60, "step": 1}), + "force_size": (["Disabled", "Custom Height", "Custom Width", "Custom", "256x?", "?x256", "256x256", "512x?", "?x512", "512x512"],), + "custom_width": ("INT", {"default": 512, "min": 0, "max": DIMMAX, "step": 8}), + "custom_height": ("INT", {"default": 512, "min": 0, "max": DIMMAX, "step": 8}), + "frame_load_cap": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}), + "skip_first_frames": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}), + "select_every_nth": ("INT", {"default": 1, "min": 1, "max": BIGMAX, "step": 1}), + }, + "optional": { + "meta_batch": ("VHS_BatchManager",), + "vae": ("VAE",), + "default_value": (sorted(files),), + }, + "hidden": { + "unique_id": "UNIQUE_ID" + }, + } CATEGORY = "Video Helper Suite 🎥🅥🅗🅢" - RETURN_TYPES = ( - "IMAGE", - "INT", - "AUDIO", - "VHS_VIDEOINFO", - ) + RETURN_TYPES = ("IMAGE", "INT", "AUDIO", "VHS_VIDEOINFO", "LATENT") RETURN_NAMES = ( "IMAGE", "frame_count", "audio", "video_info", + "LATENT", ) FUNCTION = "load_video"