video node

This commit is contained in:
nick 2024-07-20 00:15:41 -07:00
parent 6e4532078f
commit 3c9d1865ca

View File

@ -1,10 +1,15 @@
# credit goes to https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite and is meant to work with # credit goes to https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite
# Intended to work with https://github.com/NicholasKao1029/ComfyUI-VideoHelperSuite/tree/main
import os import os
import itertools import itertools
import numpy as np import numpy as np
import torch import torch
from typing import Union
from torch import Tensor
import cv2 import cv2
import psutil
from collections.abc import Mapping
import folder_paths import folder_paths
from comfy.utils import common_upscale from comfy.utils import common_upscale
@ -90,13 +95,25 @@ if gifski_path is None:
gifski_path = shutil.which("gifski") gifski_path = shutil.which("gifski")
def is_safe_path(path):
if "VHS_STRICT_PATHS" not in os.environ:
return True
basedir = os.path.abspath(".")
try:
common_path = os.path.commonpath([basedir, path])
except:
# Different drive on windows
return False
return common_path == basedir
def get_sorted_dir_files_from_directory( def get_sorted_dir_files_from_directory(
directory: str, directory: str,
skip_first_images: int = 0, skip_first_images: int = 0,
select_every_nth: int = 1, select_every_nth: int = 1,
extensions: Iterable = None, extensions: Iterable = None,
): ):
directory = directory.strip() directory = strip_path(directory)
dir_files = os.listdir(directory) dir_files = os.listdir(directory)
dir_files = sorted(dir_files) dir_files = sorted(dir_files)
dir_files = [os.path.join(directory, x) for x in dir_files] dir_files = [os.path.join(directory, x) for x in dir_files]
@ -177,18 +194,59 @@ def requeue_workflow(requeue_required=(-1, True)):
def get_audio(file, start_time=0, duration=0): def get_audio(file, start_time=0, duration=0):
args = [ffmpeg_path, "-v", "error", "-i", file] args = [ffmpeg_path, "-i", file]
if start_time > 0: if start_time > 0:
args += ["-ss", str(start_time)] args += ["-ss", str(start_time)]
if duration > 0: if duration > 0:
args += ["-t", str(duration)] args += ["-t", str(duration)]
try: try:
# TODO: scan for sample rate and maintain
res = subprocess.run( res = subprocess.run(
args + ["-f", "wav", "-"], stdout=subprocess.PIPE, check=True args + ["-f", "f32le", "-"], capture_output=True, check=True
).stdout )
audio = torch.frombuffer(bytearray(res.stdout), dtype=torch.float32)
match = re.search(", (\\d+) Hz, (\\w+), ", res.stderr.decode("utf-8"))
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
return False raise Exception(
return res f"VHS failed to extract audio from {file}:\n" + e.stderr.decode("utf-8")
)
if match:
ar = int(match.group(1))
# NOTE: Just throwing an error for other channel types right now
# Will deal with issues if they come
ac = {"mono": 1, "stereo": 2}[match.group(2)]
else:
ar = 44100
ac = 2
audio = audio.reshape((-1, ac)).transpose(0, 1).unsqueeze(0)
return {"waveform": audio, "sample_rate": ar}
class LazyAudioMap(Mapping):
def __init__(self, file, start_time, duration):
self.file = file
self.start_time = start_time
self.duration = duration
self._dict = None
def __getitem__(self, key):
if self._dict is None:
self._dict = get_audio(self.file, self.start_time, self.duration)
return self._dict[key]
def __iter__(self):
if self._dict is None:
self._dict = get_audio(self.file, self.start_time, self.duration)
return iter(self._dict)
def __len__(self):
if self._dict is None:
self._dict = get_audio(self.file, self.start_time, self.duration)
return len(self._dict)
def lazy_get_audio(file, start_time=0, duration=0):
return LazyAudioMap(file, start_time, duration)
def lazy_eval(func): def lazy_eval(func):
@ -230,6 +288,19 @@ def validate_sequence(path):
return False return False
def strip_path(path):
# This leaves whitespace inside quotes and only a single "
# thus ' ""test"' -> '"test'
# consider path.strip(string.whitespace+"\"")
# or weightier re.fullmatch("[\\s\"]*(.+?)[\\s\"]*", path).group(1)
path = path.strip()
if path.startswith('"'):
path = path[1:]
if path.endswith('"'):
path = path[:-1]
return path
def hash_path(path): def hash_path(path):
if path is None: if path is None:
return "input" return "input"
@ -286,6 +357,145 @@ def target_size(
return (width, height) return (width, height)
def validate_index(
index: int,
length: int = 0,
is_range: bool = False,
allow_negative=False,
allow_missing=False,
) -> int:
# if part of range, do nothing
if is_range:
return index
# otherwise, validate index
# validate not out of range - only when latent_count is passed in
if length > 0 and index > length - 1 and not allow_missing:
raise IndexError(f"Index '{index}' out of range for {length} item(s).")
# if negative, validate not out of range
if index < 0:
if not allow_negative:
raise IndexError(f"Negative indeces not allowed, but was '{index}'.")
conv_index = length + index
if conv_index < 0 and not allow_missing:
raise IndexError(
f"Index '{index}', converted to '{conv_index}' out of range for {length} item(s)."
)
index = conv_index
return index
def convert_to_index_int(
raw_index: str,
length: int = 0,
is_range: bool = False,
allow_negative=False,
allow_missing=False,
) -> int:
try:
return validate_index(
int(raw_index),
length=length,
is_range=is_range,
allow_negative=allow_negative,
allow_missing=allow_missing,
)
except ValueError as e:
raise ValueError(f"Index '{raw_index}' must be an integer.", e)
def convert_str_to_indexes(
indexes_str: str, length: int = 0, allow_missing=False
) -> list[int]:
if not indexes_str:
return []
int_indexes = list(range(0, length))
allow_negative = length > 0
chosen_indexes = []
# parse string - allow positive ints, negative ints, and ranges separated by ':'
groups = indexes_str.split(",")
groups = [g.strip() for g in groups]
for g in groups:
# parse range of indeces (e.g. 2:16)
if ":" in g:
index_range = g.split(":", 2)
index_range = [r.strip() for r in index_range]
start_index = index_range[0]
if len(start_index) > 0:
start_index = convert_to_index_int(
start_index,
length=length,
is_range=True,
allow_negative=allow_negative,
allow_missing=allow_missing,
)
else:
start_index = 0
end_index = index_range[1]
if len(end_index) > 0:
end_index = convert_to_index_int(
end_index,
length=length,
is_range=True,
allow_negative=allow_negative,
allow_missing=allow_missing,
)
else:
end_index = length
# support step as well, to allow things like reversing, every-other, etc.
step = 1
if len(index_range) > 2:
step = index_range[2]
if len(step) > 0:
step = convert_to_index_int(
step,
length=length,
is_range=True,
allow_negative=True,
allow_missing=True,
)
else:
step = 1
# if latents were passed in, base indeces on known latent count
if len(int_indexes) > 0:
chosen_indexes.extend(int_indexes[start_index:end_index][::step])
# otherwise, assume indeces are valid
else:
chosen_indexes.extend(list(range(start_index, end_index, step)))
# parse individual indeces
else:
chosen_indexes.append(
convert_to_index_int(
g,
length=length,
allow_negative=allow_negative,
allow_missing=allow_missing,
)
)
return chosen_indexes
def select_indexes(input_obj: Union[Tensor, list], idxs: list):
if type(input_obj) == Tensor:
return input_obj[idxs]
else:
return [input_obj[i] for i in idxs]
def select_indexes_from_str(
input_obj: Union[Tensor, list], indexes: str, err_if_missing=True, err_if_empty=True
):
real_idxs = convert_str_to_indexes(
indexes, len(input_obj), allow_missing=not err_if_missing
)
if err_if_empty and len(real_idxs) == 0:
raise Exception(f"Nothing was selected based on indexes found in '{indexes}'.")
return select_indexes(input_obj, real_idxs)
###
def cv_frame_generator( def cv_frame_generator(
video, video,
force_rate, force_rate,
@ -295,9 +505,10 @@ def cv_frame_generator(
meta_batch=None, meta_batch=None,
unique_id=None, unique_id=None,
): ):
video_cap = cv2.VideoCapture(video) video_cap = cv2.VideoCapture(strip_path(video))
if not video_cap.isOpened(): if not video_cap.isOpened():
raise ValueError(f"{video} could not be loaded with cv.") raise ValueError(f"{video} could not be loaded with cv.")
pbar = None
# extract video metadata # extract video metadata
fps = video_cap.get(cv2.CAP_PROP_FPS) fps = video_cap.get(cv2.CAP_PROP_FPS)
@ -319,6 +530,8 @@ def cv_frame_generator(
target_frame_time = 1 / force_rate target_frame_time = 1 / force_rate
yield (width, height, fps, duration, total_frames, target_frame_time) yield (width, height, fps, duration, total_frames, target_frame_time)
if meta_batch is not None:
yield min(frame_load_cap, total_frames)
time_offset = target_frame_time - base_frame_time time_offset = target_frame_time - base_frame_time
while video_cap.isOpened(): while video_cap.isOpened():
@ -349,7 +562,8 @@ def cv_frame_generator(
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# convert frame to comfyui's expected format # convert frame to comfyui's expected format
# TODO: frame contains no exif information. Check if opencv2 has already applied # TODO: frame contains no exif information. Check if opencv2 has already applied
frame = np.array(frame, dtype=np.float32) / 255.0 frame = np.array(frame, dtype=np.float32)
torch.from_numpy(frame).div_(255)
if prev_frame is not None: if prev_frame is not None:
inp = yield prev_frame inp = yield prev_frame
if inp is not None: if inp is not None:
@ -357,6 +571,8 @@ def cv_frame_generator(
return return
prev_frame = frame prev_frame = frame
frames_added += 1 frames_added += 1
if pbar is not None:
pbar.update_absolute(frames_added, frame_load_cap)
# if cap exists and we've reached it, stop processing frames # if cap exists and we've reached it, stop processing frames
if frame_load_cap > 0 and frames_added >= frame_load_cap: if frame_load_cap > 0 and frames_added >= frame_load_cap:
break break
@ -367,6 +583,17 @@ def cv_frame_generator(
yield prev_frame yield prev_frame
def batched(it, n):
while batch := tuple(itertools.islice(it, n)):
yield batch
def batched_vae_encode(images, vae, frames_per_batch):
for batch in batched(images, frames_per_batch):
image_batch = torch.from_numpy(np.array(batch))
yield from vae.encode(image_batch).numpy()
def load_video_cv( def load_video_cv(
video: str, video: str,
force_rate: int, force_rate: int,
@ -378,6 +605,8 @@ def load_video_cv(
select_every_nth: int, select_every_nth: int,
meta_batch=None, meta_batch=None,
unique_id=None, unique_id=None,
memory_limit_mb=None,
vae=None,
): ):
if meta_batch is None or unique_id not in meta_batch.inputs: if meta_batch is None or unique_id not in meta_batch.inputs:
gen = cv_frame_generator( gen = cv_frame_generator(
@ -401,30 +630,89 @@ def load_video_cv(
total_frames, total_frames,
target_frame_time, target_frame_time,
) )
meta_batch.total_frames = min(meta_batch.total_frames, next(gen))
else: else:
(gen, width, height, fps, duration, total_frames, target_frame_time) = ( (gen, width, height, fps, duration, total_frames, target_frame_time) = (
meta_batch.inputs[unique_id] meta_batch.inputs[unique_id]
) )
if meta_batch is not None: memory_limit = None
gen = itertools.islice(gen, meta_batch.frames_per_batch) if memory_limit_mb is not None:
memory_limit *= 2**20
else:
# TODO: verify if garbage collection should be performed here.
# leaves ~128 MB unreserved for safety
try:
memory_limit = (
psutil.virtual_memory().available + psutil.swap_memory().free
) - 2**27
except:
print(
"Failed to calculate available memory. Memory load limit has been disabled"
)
if memory_limit is not None:
if vae is not None:
# space required to load as f32, exist as latent with wiggle room, decode to f32
max_loadable_frames = int(
memory_limit // (width * height * 3 * (4 + 4 + 1 / 10))
)
else:
# TODO: use better estimate for when vae is not None
# Consider completely ignoring for load_latent case?
max_loadable_frames = int(memory_limit // (width * height * 3 * (0.1)))
if meta_batch is not None:
if meta_batch.frames_per_batch > max_loadable_frames:
raise RuntimeError(
f"Meta Batch set to {meta_batch.frames_per_batch} frames but only {max_loadable_frames} can fit in memory"
)
gen = itertools.islice(gen, meta_batch.frames_per_batch)
else:
original_gen = gen
gen = itertools.islice(gen, max_loadable_frames)
downscale_ratio = getattr(vae, "downscale_ratio", 8)
frames_per_batch = (1920 * 1080 * 16) // (width * height) or 1
if force_size != "Disabled" or vae is not None:
new_size = target_size(
width, height, force_size, custom_width, custom_height, downscale_ratio
)
if new_size[0] != width or new_size[1] != height:
# Some minor wizardry to eliminate a copy and reduce max memory by a factor of ~2 def rescale(frame):
images = torch.from_numpy( s = torch.from_numpy(
np.fromiter(gen, np.dtype((np.float32, (height, width, 3)))) np.fromiter(frame, np.dtype((np.float32, (height, width, 3))))
) )
s = s.movedim(-1, 1)
s = common_upscale(s, new_size[0], new_size[1], "lanczos", "center")
return s.movedim(1, -1).numpy()
gen = itertools.chain.from_iterable(
map(rescale, batched(gen, frames_per_batch))
)
else:
new_size = width, height
if vae is not None:
gen = batched_vae_encode(gen, vae, frames_per_batch)
vw, vh = new_size[0] // downscale_ratio, new_size[1] // downscale_ratio
images = torch.from_numpy(np.fromiter(gen, np.dtype((np.float32, (4, vh, vw)))))
else:
# Some minor wizardry to eliminate a copy and reduce max memory by a factor of ~2
images = torch.from_numpy(
np.fromiter(gen, np.dtype((np.float32, (new_size[1], new_size[0], 3))))
)
if meta_batch is None and memory_limit is not None:
try:
next(original_gen)
raise RuntimeError(
f"Memory limit hit after loading {len(images)} frames. Stopping execution."
)
except StopIteration:
pass
if len(images) == 0: if len(images) == 0:
raise RuntimeError("No frames generated") raise RuntimeError("No frames generated")
if force_size != "Disabled":
new_size = target_size(width, height, force_size, custom_width, custom_height)
if new_size[0] != width or new_size[1] != height:
s = images.movedim(-1, 1)
s = common_upscale(s, new_size[0], new_size[1], "lanczos", "center")
images = s.movedim(1, -1)
# Setup lambda for lazy audio capture # Setup lambda for lazy audio capture
audio = lambda: get_audio( audio = lazy_get_audio(
video, video,
skip_first_frames * target_frame_time, skip_first_frames * target_frame_time,
frame_load_cap * target_frame_time * select_every_nth, frame_load_cap * target_frame_time * select_every_nth,
@ -440,11 +728,13 @@ def load_video_cv(
"loaded_fps": 1 / target_frame_time, "loaded_fps": 1 / target_frame_time,
"loaded_frame_count": len(images), "loaded_frame_count": len(images),
"loaded_duration": len(images) * target_frame_time, "loaded_duration": len(images) * target_frame_time,
"loaded_width": images.shape[2], "loaded_width": new_size[0],
"loaded_height": images.shape[1], "loaded_height": new_size[1],
} }
if vae is None:
return (images, len(images), lazy_eval(audio), video_info) return (images, len(images), audio, video_info, None)
else:
return (None, len(images), audio, video_info, {"samples": images})
class ComfyUIDeployExternalVideo: class ComfyUIDeployExternalVideo:
@ -457,68 +747,38 @@ class ComfyUIDeployExternalVideo:
file_parts = f.split(".") file_parts = f.split(".")
if len(file_parts) > 1 and (file_parts[-1] in video_extensions): if len(file_parts) > 1 and (file_parts[-1] in video_extensions):
files.append(f) files.append(f)
return { return {"required": {
"required": { "input_id": (
"input_id": ( "STRING",
"STRING", {"multiline": False, "default": "input_video"},
{"multiline": False, "default": "input_video"}, ),
), "force_rate": ("INT", {"default": 0, "min": 0, "max": 60, "step": 1}),
"force_rate": ("INT", {"default": 0, "min": 0, "max": 60, "step": 1}), "force_size": (["Disabled", "Custom Height", "Custom Width", "Custom", "256x?", "?x256", "256x256", "512x?", "?x512", "512x512"],),
"force_size": ( "custom_width": ("INT", {"default": 512, "min": 0, "max": DIMMAX, "step": 8}),
[ "custom_height": ("INT", {"default": 512, "min": 0, "max": DIMMAX, "step": 8}),
"Disabled", "frame_load_cap": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}),
"Custom Height", "skip_first_frames": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}),
"Custom Width", "select_every_nth": ("INT", {"default": 1, "min": 1, "max": BIGMAX, "step": 1}),
"Custom", },
"256x?", "optional": {
"?x256", "meta_batch": ("VHS_BatchManager",),
"256x256", "vae": ("VAE",),
"512x?", "default_value": (sorted(files),),
"?x512", },
"512x512", "hidden": {
], "unique_id": "UNIQUE_ID"
), },
"custom_width": ( }
"INT",
{"default": 512, "min": 0, "max": DIMMAX, "step": 8},
),
"custom_height": (
"INT",
{"default": 512, "min": 0, "max": DIMMAX, "step": 8},
),
"frame_load_cap": (
"INT",
{"default": 0, "min": 0, "max": BIGMAX, "step": 1},
),
"skip_first_frames": (
"INT",
{"default": 0, "min": 0, "max": BIGMAX, "step": 1},
),
"select_every_nth": (
"INT",
{"default": 1, "min": 1, "max": BIGMAX, "step": 1},
),
},
"optional": {
"meta_batch": ("VHS_BatchManager",),
"default_value": (sorted(files),),
},
"hidden": {"unique_id": "UNIQUE_ID"},
}
CATEGORY = "Video Helper Suite 🎥🅥🅗🅢" CATEGORY = "Video Helper Suite 🎥🅥🅗🅢"
RETURN_TYPES = ( RETURN_TYPES = ("IMAGE", "INT", "AUDIO", "VHS_VIDEOINFO", "LATENT")
"IMAGE",
"INT",
"AUDIO",
"VHS_VIDEOINFO",
)
RETURN_NAMES = ( RETURN_NAMES = (
"IMAGE", "IMAGE",
"frame_count", "frame_count",
"audio", "audio",
"video_info", "video_info",
"LATENT",
) )
FUNCTION = "load_video" FUNCTION = "load_video"