51 lines
1.4 KiB
Python
51 lines
1.4 KiB
Python
import inspect
import shutil
import socket
import time
from dataclasses import dataclass, field
from typing import List, Optional
|
|
|
|
# GPU type string for this deployment. NOTE(review): presumably handed to
# Modal when the container is requested — confirm at the use site.
MODAL_GPU: str = "A10G"

# Shell command that runs /webui.sh, pointing each model category (LoRAs,
# embeddings, checkpoints, VAEs) at its directory under /models.
START_CMD: str = (
    "bash /webui.sh -f"
    " --lora-dir '/models/loras'"
    " --embeddings-dir '/models/embeddings'"
    " --ckpt-dir '/models/checkpoints/'"
    " --vae-dir '/models/vaes/'"
    " --xformers"
)

# Feature flag; its consumer is not visible here.
# NOTE(review): name suggests it forces fetching the newest A1111 code — confirm.
ALWAYS_GET_LATEST_A1111: bool = True
|
|
|
|
|
|
def wait_for_port(port: int, timeout: Optional[float] = None, host: str = "127.0.0.1"):
    """Block until a TCP server accepts connections on ``host:port``.

    Retries every 0.1 s after each failed connection attempt. Each attempt
    opens a connection and closes it immediately.

    Args:
        port: TCP port to probe.
        timeout: Maximum total number of seconds to wait. ``None`` (the
            default) waits indefinitely, as the function always did.
        host: Address to probe; defaults to the IPv4 loopback address.

    Raises:
        TimeoutError: if ``timeout`` is given and elapses before the port
            accepts a connection.
    """
    # Monotonic clock so wall-clock adjustments cannot shorten or extend the wait.
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        connect_timeout = 5.0
        if deadline is not None:
            # Don't let a single hanging connect overshoot the caller's deadline.
            connect_timeout = max(0.01, min(connect_timeout, deadline - time.monotonic()))
        try:
            with socket.create_connection((host, port), timeout=connect_timeout):
                return
        except OSError as exc:
            # Refused/unreachable/timed-out connects all land here; keep polling
            # unless the caller's deadline has passed.
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    f"{host}:{port} did not accept connections within {timeout} s"
                ) from exc
            time.sleep(0.1)
|
|
|
|
|
|
@dataclass(frozen=True)
class UserModels:
    """URLs of user-supplied models, grouped by model category.

    Every field is declared as a list of URL strings and defaults to an empty
    list. The dataclass is frozen, so fields cannot be reassigned after
    construction.
    """

    checkpoints_urls: List[str] = field(default_factory=list)
    vae_urls: List[str] = field(default_factory=list)
    loras_urls: List[str] = field(default_factory=list)
    embeddings_urls: List[str] = field(default_factory=list)

    @classmethod
    def from_dict(cls, env: Optional[dict]) -> "UserModels":
        """Build an instance from a mapping, ignoring keys that are not fields.

        Args:
            env: Mapping from field name to a list of URLs. ``None`` or an
                empty mapping yields an instance with all-empty lists (useful
                because ``yaml.safe_load`` returns ``None`` for an empty
                document). A key whose value is ``None`` — e.g. a YAML key
                with no items — falls back to that field's empty default.

        Returns:
            A new ``UserModels`` built from the recognised keys of ``env``.
        """
        if not env:
            return cls()
        # Resolve the constructor's parameter names once, not once per key.
        accepted = inspect.signature(cls).parameters
        return cls(**{
            key: value
            for key, value in env.items()
            if key in accepted and value is not None
        })
|
|
|
|
|
|
def get_urls(config_path: str = "./config.yaml", example_path: str = "./config.example.yaml"):
    """Load the user's model URLs from a local YAML config file.

    Should only be run locally: it reads (and may create) files at the given
    paths, relative to the current working directory by default, and imports
    ``yaml`` lazily.

    Behaviour:
      * ``config_path`` missing: ``example_path`` is copied to ``config_path``
        as a template for the user to edit, and an empty ``UserModels`` is
        returned.
      * ``config_path`` is not valid YAML: the parse error is printed and an
        empty ``UserModels`` is returned.
      * ``config_path`` is empty: an empty ``UserModels`` is returned.

    Args:
        config_path: Path of the YAML config to read.
        example_path: Template copied to ``config_path`` when it is missing.

    Returns:
        UserModels: the URLs found in the config (keys that are not
        ``UserModels`` fields are dropped by ``UserModels.from_dict``).
    """
    import yaml

    try:
        with open(config_path, encoding="utf-8") as f:
            config = yaml.safe_load(f)
    except FileNotFoundError:
        # First run: seed the config from the example so the user can edit it.
        shutil.copyfile(example_path, config_path)
        return UserModels()
    except yaml.YAMLError as exc:
        # Previously this path fell through and implicitly returned None.
        print(exc)
        return UserModels()

    # yaml.safe_load returns None for an empty document.
    return UserModels.from_dict(config or {})
|