working
This commit is contained in:
		
							parent
							
								
									911cc8d16b
								
							
						
					
					
						commit
						db04d02d34
					
				@ -47,6 +47,7 @@ machine_id_websocket_dict = {}
 | 
			
		||||
machine_id_status = {}
 | 
			
		||||
 | 
			
		||||
fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
 | 
			
		||||
civitai_api_key = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FlyReplayMiddleware(BaseHTTPMiddleware):
 | 
			
		||||
@ -272,7 +273,8 @@ async def upload_logic(body: UploadBody):
 | 
			
		||||
            "checkpoint_id": body.checkpoint_id,
 | 
			
		||||
            "volume_id": body.volume_id,
 | 
			
		||||
            "folder_path": upload_path,
 | 
			
		||||
        }
 | 
			
		||||
        },
 | 
			
		||||
        "civitai_api_key": os.environ.get('CIVITAI_API_KEY')
 | 
			
		||||
    }
 | 
			
		||||
    with open(f"{folder_path}/config.py", "w") as f:
 | 
			
		||||
        f.write("config = " + json.dumps(config))
 | 
			
		||||
 | 
			
		||||
@ -25,16 +25,13 @@ vol_name_to_links = config["volume_names"]
 | 
			
		||||
vol_name_to_path = config["volume_paths"]
 | 
			
		||||
callback_url = config["callback_url"]
 | 
			
		||||
callback_body = config["callback_body"]
 | 
			
		||||
civitai_key = config["civitai_api_key"]
 | 
			
		||||
 | 
			
		||||
volumes = create_volumes(vol_name_to_links, vol_name_to_path)
 | 
			
		||||
image = ( 
 | 
			
		||||
   modal.Image.debian_slim().apt_install("wget").pip_install("requests")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
print(vol_name_to_links)
 | 
			
		||||
print(vol_name_to_path)
 | 
			
		||||
print(volumes)
 | 
			
		||||
 | 
			
		||||
# download config { "download_url": "", "folder_path": ""}
 | 
			
		||||
timeout=5000
 | 
			
		||||
@stub.function(volumes=volumes, image=image, timeout=timeout, gpu=None)
 | 
			
		||||
@ -42,20 +39,23 @@ def download_model(volume_name, download_config):
 | 
			
		||||
    import requests
 | 
			
		||||
    download_url = download_config["download_url"]
 | 
			
		||||
    folder_path = download_config["folder_path"]
 | 
			
		||||
 | 
			
		||||
    volume_base_path = vol_name_to_path[volume_name]
 | 
			
		||||
    model_store_path = os.path.join(volume_base_path, folder_path)
 | 
			
		||||
    modified_download_url = download_url + ("&" if "?" in download_url else "?") + "token=" + civitai_key
 | 
			
		||||
    print('downlodaing', modified_download_url)
 | 
			
		||||
 | 
			
		||||
    subprocess.run(["wget", download_url, "--content-disposition", "-P", model_store_path])
 | 
			
		||||
    subprocess.run(["wget", modified_download_url , "--content-disposition", "-P", model_store_path])
 | 
			
		||||
    subprocess.run(["ls", "-la", volume_base_path])
 | 
			
		||||
    subprocess.run(["ls", "-la", model_store_path])
 | 
			
		||||
    volumes[volume_base_path].commit()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    status =  {"status": "success"}
 | 
			
		||||
    requests.post(callback_url, json={**status, **callback_body})
 | 
			
		||||
    print(f"finished! sending to {callback_url}")
 | 
			
		||||
    pprint({**status, **callback_body})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@stub.local_entrypoint()
 | 
			
		||||
def simple_download():
 | 
			
		||||
    import requests
 | 
			
		||||
 | 
			
		||||
@ -13,5 +13,6 @@ config = {
 | 
			
		||||
        "checkpoint_id": "",
 | 
			
		||||
        "volume_id": "",
 | 
			
		||||
        "folder_path": "images",
 | 
			
		||||
    }
 | 
			
		||||
    }, 
 | 
			
		||||
    "civitai_api_key": "",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -26,7 +26,8 @@ export async function POST(request: Request) {
 | 
			
		||||
      .update(checkpointTable)
 | 
			
		||||
      .set({
 | 
			
		||||
        status: "success",
 | 
			
		||||
        folder_path 
 | 
			
		||||
        folder_path,
 | 
			
		||||
        updated_at: new Date(),
 | 
			
		||||
        // build_log: build_log,
 | 
			
		||||
      })
 | 
			
		||||
      .where(eq(checkpointTable.id, checkpoint_id));
 | 
			
		||||
@ -37,6 +38,7 @@ export async function POST(request: Request) {
 | 
			
		||||
      .set({
 | 
			
		||||
        status: "failed",
 | 
			
		||||
        error_log, 
 | 
			
		||||
        updated_at: new Date(),
 | 
			
		||||
        // status: "error",
 | 
			
		||||
        // build_log: build_log,
 | 
			
		||||
      })
 | 
			
		||||
 | 
			
		||||
@ -116,6 +116,38 @@ export const columns: ColumnDef<CheckpointItemList>[] = [
 | 
			
		||||
          {row.original.status}
 | 
			
		||||
        </Badge>
 | 
			
		||||
      );
 | 
			
		||||
      // const oneHourAgo = new Date(new Date().getTime() - (60 * 60 * 1000));
 | 
			
		||||
      // const lastUpdated = new Date(row.original.updated_at);
 | 
			
		||||
      // const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo;
 | 
			
		||||
      // const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo;
 | 
			
		||||
      // cell: ({ row }) => {
 | 
			
		||||
      //   // const oneHourAgo = new Date(new Date().getTime() - (60 * 60 * 1000));
 | 
			
		||||
      //   // const lastUpdated = new Date(row.original.updated_at);
 | 
			
		||||
      //   // const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo;
 | 
			
		||||
      //   const canReDownload = true;
 | 
			
		||||
      //
 | 
			
		||||
      //   return (
 | 
			
		||||
      //     <div className="flex items-center space-x-2">
 | 
			
		||||
      //       <Badge
 | 
			
		||||
      //         variant={row.original.status === "failed"
 | 
			
		||||
      //           ? "red"
 | 
			
		||||
      //           : row.original.status === "started"
 | 
			
		||||
      //           ? "yellow"
 | 
			
		||||
      //           : "green"}
 | 
			
		||||
      //       >
 | 
			
		||||
      //         {row.original.status}
 | 
			
		||||
      //       </Badge>
 | 
			
		||||
      //       {canReDownload && (
 | 
			
		||||
      //         <RefreshCcw
 | 
			
		||||
      //           onClick={() => {
 | 
			
		||||
      //             redownloadCheckpoint(row.original);
 | 
			
		||||
      //           }}
 | 
			
		||||
      //           className="h-4 w-4 cursor-pointer" // Adjust the size with h-x and w-x classes
 | 
			
		||||
      //         />
 | 
			
		||||
      //       )}
 | 
			
		||||
      //     </div>
 | 
			
		||||
      //   );
 | 
			
		||||
      // },
 | 
			
		||||
    },
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
 | 
			
		||||
@ -406,13 +406,47 @@ export const checkpointTable = dbSchema.table("checkpoints", {
 | 
			
		||||
 | 
			
		||||
  is_public: boolean("is_public").notNull().default(false),
 | 
			
		||||
  status: resourceUpload("status").notNull().default("started"),
 | 
			
		||||
  upload_machine_id: text("upload_machine_id"),
 | 
			
		||||
  upload_machine_id: text("upload_machine_id"), // TODO: review if actually needed
 | 
			
		||||
  upload_type: modelUploadType("upload_type").notNull(),
 | 
			
		||||
  error_log: text("error_log"),
 | 
			
		||||
  created_at: timestamp("created_at").defaultNow().notNull(),
 | 
			
		||||
  updated_at: timestamp("updated_at").defaultNow().notNull(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", {
 | 
			
		||||
  id: uuid("id").primaryKey().defaultRandom().notNull(),
 | 
			
		||||
  user_id: text("user_id").references(() => usersTable.id, {
 | 
			
		||||
    // onDelete: "cascade",
 | 
			
		||||
  }),
 | 
			
		||||
  org_id: text("org_id"),
 | 
			
		||||
  volume_name: text("volume_name").notNull(),
 | 
			
		||||
  created_at: timestamp("created_at").defaultNow().notNull(),
 | 
			
		||||
  updated_at: timestamp("updated_at").defaultNow().notNull(),
 | 
			
		||||
  disabled: boolean("disabled").default(false).notNull(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const checkpointRelations = relations(checkpointTable, ({ one }) => ({
 | 
			
		||||
  user: one(usersTable, {
 | 
			
		||||
    fields: [checkpointTable.user_id],
 | 
			
		||||
    references: [usersTable.id],
 | 
			
		||||
  }),
 | 
			
		||||
  volume: one(checkpointVolumeTable, {
 | 
			
		||||
    fields: [checkpointTable.checkpoint_volume_id],
 | 
			
		||||
    references: [checkpointVolumeTable.id],
 | 
			
		||||
  }),
 | 
			
		||||
}));
 | 
			
		||||
 | 
			
		||||
export const checkpointVolumeRelations = relations(
 | 
			
		||||
  checkpointVolumeTable,
 | 
			
		||||
  ({ many, one }) => ({
 | 
			
		||||
    checkpoint: many(checkpointTable),
 | 
			
		||||
    user: one(usersTable, {
 | 
			
		||||
      fields: [checkpointVolumeTable.user_id],
 | 
			
		||||
      references: [usersTable.id],
 | 
			
		||||
    }),
 | 
			
		||||
  })
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
export const subscriptionPlan = pgEnum("subscription_plan", [
 | 
			
		||||
  "basic",
 | 
			
		||||
  "pro",
 | 
			
		||||
@ -452,40 +486,6 @@ export const insertCivitaiCheckpointSchema = createInsertSchema(
 | 
			
		||||
  }
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", {
 | 
			
		||||
  id: uuid("id").primaryKey().defaultRandom().notNull(),
 | 
			
		||||
  user_id: text("user_id").references(() => usersTable.id, {
 | 
			
		||||
    // onDelete: "cascade",
 | 
			
		||||
  }),
 | 
			
		||||
  org_id: text("org_id"),
 | 
			
		||||
  volume_name: text("volume_name").notNull(),
 | 
			
		||||
  created_at: timestamp("created_at").defaultNow().notNull(),
 | 
			
		||||
  updated_at: timestamp("updated_at").defaultNow().notNull(),
 | 
			
		||||
  disabled: boolean("disabled").default(false).notNull(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const checkpointRelations = relations(checkpointTable, ({ one }) => ({
 | 
			
		||||
  user: one(usersTable, {
 | 
			
		||||
    fields: [checkpointTable.user_id],
 | 
			
		||||
    references: [usersTable.id],
 | 
			
		||||
  }),
 | 
			
		||||
  volume: one(checkpointVolumeTable, {
 | 
			
		||||
    fields: [checkpointTable.checkpoint_volume_id],
 | 
			
		||||
    references: [checkpointVolumeTable.id],
 | 
			
		||||
  }),
 | 
			
		||||
}));
 | 
			
		||||
 | 
			
		||||
export const checkpointVolumeRelations = relations(
 | 
			
		||||
  checkpointVolumeTable,
 | 
			
		||||
  ({ many, one }) => ({
 | 
			
		||||
    checkpoint: many(checkpointTable),
 | 
			
		||||
    user: one(usersTable, {
 | 
			
		||||
      fields: [checkpointVolumeTable.user_id],
 | 
			
		||||
      references: [usersTable.id],
 | 
			
		||||
    }),
 | 
			
		||||
  })
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
export type UserType = InferSelectModel<typeof usersTable>;
 | 
			
		||||
export type WorkflowType = InferSelectModel<typeof workflowTable>;
 | 
			
		||||
export type MachineType = InferSelectModel<typeof machinesTable>;
 | 
			
		||||
 | 
			
		||||
@ -15,6 +15,7 @@ import { headers } from "next/headers";
 | 
			
		||||
import { addCivitaiCheckpointSchema } from "./addCheckpointSchema";
 | 
			
		||||
import { and, eq, isNull } from "drizzle-orm";
 | 
			
		||||
import { CivitaiModelResponse } from "@/types/civitai";
 | 
			
		||||
import { CheckpointItemList } from "@/components/CheckpointList";
 | 
			
		||||
 | 
			
		||||
export async function getCheckpoints() {
 | 
			
		||||
  const { userId, orgId } = auth();
 | 
			
		||||
@ -74,12 +75,12 @@ export async function getCheckpointVolumes() {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export async function retrieveCheckpointVolumes() {
 | 
			
		||||
  let volumes = await getCheckpointVolumes()
 | 
			
		||||
  let volumes = await getCheckpointVolumes();
 | 
			
		||||
  if (volumes.length === 0) {
 | 
			
		||||
    // create volume if not already created
 | 
			
		||||
    volumes = await addCheckpointVolume()
 | 
			
		||||
  } 
 | 
			
		||||
  return volumes
 | 
			
		||||
    volumes = await addCheckpointVolume();
 | 
			
		||||
  }
 | 
			
		||||
  return volumes;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export async function addCheckpointVolume() {
 | 
			
		||||
@ -97,7 +98,7 @@ export async function addCheckpointVolume() {
 | 
			
		||||
      disabled: false, // Default value
 | 
			
		||||
    })
 | 
			
		||||
    .returning(); // Returns the inserted row
 | 
			
		||||
return insertedVolume;
 | 
			
		||||
  return insertedVolume;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
function getUrl(civitai_url: string) {
 | 
			
		||||
@ -115,30 +116,24 @@ function getUrl(civitai_url: string) {
 | 
			
		||||
export const addCivitaiCheckpoint = withServerPromise(
 | 
			
		||||
  async (data: z.infer<typeof addCivitaiCheckpointSchema>) => {
 | 
			
		||||
    const { userId, orgId } = auth();
 | 
			
		||||
    console.log("START")
 | 
			
		||||
    console.log("1");
 | 
			
		||||
 | 
			
		||||
    if (!data.civitai_url) return { error: "no civitai_url" };
 | 
			
		||||
    console.log("2");
 | 
			
		||||
    if (!userId) return { error: "No user id" };
 | 
			
		||||
    console.log("3");
 | 
			
		||||
 | 
			
		||||
    const { url, modelVersionId } = getUrl(data?.civitai_url);
 | 
			
		||||
    console.log("4", url, modelVersionId);
 | 
			
		||||
    const civitaiModelRes = await fetch(url)
 | 
			
		||||
      .then((x) => x.json())
 | 
			
		||||
      .then((a) => {
 | 
			
		||||
        console.log(a)
 | 
			
		||||
        console.log(a);
 | 
			
		||||
        return CivitaiModelResponse.parse(a);
 | 
			
		||||
      });
 | 
			
		||||
    console.log("5");
 | 
			
		||||
 | 
			
		||||
    if (civitaiModelRes?.modelVersions?.length === 0) {
 | 
			
		||||
      console.log("6");
 | 
			
		||||
      return; // no versions to download
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    console.log("7");
 | 
			
		||||
    let selectedModelVersion;
 | 
			
		||||
    let selectedModelVersionId: string | null = modelVersionId;
 | 
			
		||||
    if (!selectedModelVersionId) {
 | 
			
		||||
@ -153,23 +148,16 @@ export const addCivitaiCheckpoint = withServerPromise(
 | 
			
		||||
      }
 | 
			
		||||
      selectedModelVersionId = selectedModelVersion?.id.toString();
 | 
			
		||||
    }
 | 
			
		||||
    console.log("8");
 | 
			
		||||
 | 
			
		||||
    const checkpointVolumes = await getCheckpointVolumes();
 | 
			
		||||
    console.log("9");
 | 
			
		||||
    let cVolume;
 | 
			
		||||
    if (checkpointVolumes.length === 0) {
 | 
			
		||||
      console.log("10");
 | 
			
		||||
      const volume = await addCheckpointVolume();
 | 
			
		||||
      console.log("11");
 | 
			
		||||
      cVolume = volume[0];
 | 
			
		||||
    } else {
 | 
			
		||||
      console.log("12");
 | 
			
		||||
      cVolume = checkpointVolumes[0];
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    console.log("13");
 | 
			
		||||
 | 
			
		||||
    const a = await db
 | 
			
		||||
      .insert(checkpointTable)
 | 
			
		||||
      .values({
 | 
			
		||||
@ -183,18 +171,52 @@ export const addCivitaiCheckpoint = withServerPromise(
 | 
			
		||||
        civitai_download_url: selectedModelVersion.files[0].downloadUrl,
 | 
			
		||||
        civitai_model_response: civitaiModelRes,
 | 
			
		||||
        checkpoint_volume_id: cVolume.id,
 | 
			
		||||
        updated_at: new Date(),
 | 
			
		||||
      })
 | 
			
		||||
      .returning();
 | 
			
		||||
    console.log("14");
 | 
			
		||||
 | 
			
		||||
    const b = a[0];
 | 
			
		||||
 | 
			
		||||
    await uploadCheckpoint(data, b, cVolume);
 | 
			
		||||
    console.log("15");
 | 
			
		||||
    // redirect(`/checkpoints/${b.id}`);
 | 
			
		||||
  },
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
// export const redownloadCheckpoint = withServerPromise(
 | 
			
		||||
//   async (data: CheckpointItemList) => {
 | 
			
		||||
//     const { userId } = auth();
 | 
			
		||||
//     if (!userId) return { error: "No user id" };
 | 
			
		||||
//
 | 
			
		||||
//     const checkpointVolumes = await getCheckpointVolumes();
 | 
			
		||||
//     let cVolume;
 | 
			
		||||
//     if (checkpointVolumes.length === 0) {
 | 
			
		||||
//       const volume = await addCheckpointVolume();
 | 
			
		||||
//       cVolume = volume[0];
 | 
			
		||||
//     } else {
 | 
			
		||||
//       cVolume = checkpointVolumes[0];
 | 
			
		||||
//     }
 | 
			
		||||
//
 | 
			
		||||
//     console.log("data");
 | 
			
		||||
//     console.log(data);
 | 
			
		||||
//
 | 
			
		||||
//     const a = await db
 | 
			
		||||
//       .update(checkpointTable)
 | 
			
		||||
//       .set({
 | 
			
		||||
//         // status: "started",
 | 
			
		||||
//         // updated_at: new Date(),
 | 
			
		||||
//       })
 | 
			
		||||
//       .returning();
 | 
			
		||||
//
 | 
			
		||||
//     const b = a[0];
 | 
			
		||||
//
 | 
			
		||||
//     console.log("b");
 | 
			
		||||
//     console.log(b);
 | 
			
		||||
//
 | 
			
		||||
//     await uploadCheckpoint(data, b, cVolume);
 | 
			
		||||
//     // redirect(`/checkpoints/${b.id}`);
 | 
			
		||||
//   },
 | 
			
		||||
// );
 | 
			
		||||
 | 
			
		||||
async function uploadCheckpoint(
 | 
			
		||||
  data: z.infer<typeof addCivitaiCheckpointSchema>,
 | 
			
		||||
  c: CheckpointType,
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user