mirror of
https://github.com/OpenSPG/KAG.git
synced 2025-07-10 18:46:23 +00:00
215 lines
6.8 KiB
Python
215 lines
6.8 KiB
Python
![]() |
# -*- coding: utf-8 -*-
|
||
|
# Copyright 2023 OpenSPG Authors
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||
|
# in compliance with the License. You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||
|
# or implied.
|
||
|
import os
|
||
|
import json
|
||
|
import argparse
|
||
|
import tempfile
|
||
|
import requests
|
||
|
from git import Repo
|
||
|
from kag.bin.base import Command
|
||
|
from kag.common.registry import Registrable
|
||
|
from kag.common.conf import KAG_PROJECT_CONF
|
||
|
from kag.common.utils import bold, green, reset
|
||
|
from openai import NotFoundError
|
||
|
|
||
|
|
||
|
@Command.register("submit_builder_job")
class BuilderJobSubmit(Command):
    """CLI command that submits a distributed KAG builder job to an SPG server.

    Each worker of the submitted job clones ``--git_url``, checks out
    ``--commit_id``, pip-installs the repository, optionally runs
    ``--init_script`` and then runs ``--entry_script`` with ``python``.
    The job spec is POSTed to ``<host_addr>/public/v1/builder/kag/submit``.
    """

    def add_to_parser(self, subparsers: argparse._SubParsersAction):
        """Register the ``builder`` sub-command and its options on *subparsers*."""
        parser = subparsers.add_parser(
            "builder", help="Submit distributed builder jobs to cluster"
        )
        parser.add_argument(
            "--user_number",
            type=str,
            help="User number",
        )
        parser.add_argument(
            "--host_addr",
            default=None,
            help="Host address of SPG server.",
        )
        parser.add_argument(
            "--project_id",
            default=None,
            help="Project ID in SPG server.",
        )
        parser.add_argument(
            "--git_url",
            required=True,
            type=str,
            help="Git repository URL containing project source code (supports SSH/HTTP)",
        )
        parser.add_argument(
            "--commit_id",
            required=True,
            type=str,
            help="Git commit id of the project source code to check out",
        )
        parser.add_argument(
            "--init_script",
            default=None,
            help="Bash script path for worker container initialization.",
        )
        parser.add_argument(
            "--entry_script",
            type=str,
            default=None,
            help="Python entry script path. \n"
            "Will be executed as: python <entry_script>",
        )
        parser.add_argument("--image", type=str, help="Worker image.")
        parser.add_argument("--pool", type=str, help="Worker resource pool.")
        parser.add_argument(
            "--num_workers",
            type=int,
            default=1,
            help="Number of parallel worker instances. \n",
        )
        parser.add_argument(
            "--num_gpus",
            type=int,
            default=0,
            help="GPUs per worker. Requires NVIDIA CUDA-enabled cluster. \n",
        )
        parser.add_argument(
            "--num_cpus", type=int, default=8, help="CPU cores per worker."
        )
        # Memory / storage resource configuration
        parser.add_argument(
            "--memory",
            type=int,
            default=8,
            help="Memory allocation per worker (GB).",
        )
        parser.add_argument(
            "--storage",
            type=int,
            default=50,
            help="Ephemeral disk space per worker (GB).",
        )
        parser.add_argument(
            "--env",
            type=str,
            default="",
            help="Environment variables, with each variable formatted as key=value and separated by commas: k1=v1, k2=v2",
        )
        parser.add_argument(
            "--validity_check",
            action="store_true",
            help="Perform validity check.",
        )
        parser.set_defaults(func=self.get_handler())

    @staticmethod
    def get_cls(cls_name):
        """Return the registered KAG interface class whose ``__name__`` is *cls_name*.

        Raises:
            ValueError: if no registered (non-leaf) class has that name.
        """
        interface_classes = Registrable.list_all_registered(with_leaf_classes=False)
        for item in interface_classes:
            if item.__name__ == cls_name:
                return item
        raise ValueError(f"class {cls_name} is not a valid kag configurable class")

    @staticmethod
    def validity_check(args: argparse.Namespace):
        """Check that the configured scripts exist in the repo at the given commit.

        Clones ``args.git_url`` into a temporary directory (removed afterwards),
        checks out ``args.commit_id`` and verifies that ``args.init_script`` and
        ``args.entry_script`` -- when given -- exist relative to the repo root,
        which is the directory the worker command runs them from.

        Raises:
            ValueError: if a configured script is missing from the repository.
        """
        with tempfile.TemporaryDirectory() as local_dir:
            repo = Repo.clone_from(args.git_url, local_dir)
            repo.git.checkout(args.commit_id)
            if args.init_script is not None and not os.path.exists(
                os.path.join(local_dir, args.init_script)
            ):
                raise ValueError(
                    f"init script {args.init_script} not found in git repo"
                )
            # --entry_script is optional (default None); os.path.join(None)
            # would raise TypeError, so only check it when it was provided.
            if args.entry_script is not None and not os.path.exists(
                os.path.join(local_dir, args.entry_script)
            ):
                raise ValueError(
                    f"entry script {args.entry_script} not found in git repo"
                )

    @staticmethod
    def _build_command(args: argparse.Namespace) -> str:
        """Build the ``&&``-chained shell command every worker executes."""
        work_dir = "src"
        cmds = [
            f"git clone {args.git_url} {work_dir}",
            f"cd {work_dir}",
            f"git checkout {args.commit_id}",
            # NOTE(review): dependencies are installed with /openspg_venv's
            # pip3.8, but the entry script below runs with whatever `python`
            # is on PATH -- confirm both refer to the same interpreter.
            "/openspg_venv/bin/pip3.8 install -e . -i https://artifacts.antgroup-inc.cn/artifact/repositories/simple-dev/",
        ]
        if args.init_script is not None:
            cmds.append(f"sh {args.init_script}")
        if args.entry_script is not None:
            entry_script_dir = os.path.dirname(args.entry_script)
            entry_script_name = os.path.basename(args.entry_script)
            # A script at the repo root has an empty dirname; a bare `cd`
            # would jump to $HOME, so only change directory when needed.
            if entry_script_dir:
                cmds.append(f"cd {entry_script_dir}")
            cmds.append(f"python {entry_script_name}")
        return " && ".join(cmds)

    @staticmethod
    def _parse_envs(env_str: str) -> dict:
        """Parse ``"k1=v1, k2=v2"`` into ``{"k1": "v1", "k2": "v2"}``.

        Keys and values are whitespace-stripped. Only the first ``=`` splits
        an entry, so values may themselves contain ``=``. Empty segments
        (e.g. from a trailing comma) are ignored.

        Raises:
            ValueError: if a non-empty entry contains no ``=``.
        """
        envs = {}
        if not env_str:
            return envs
        for kv in env_str.split(","):
            if not kv.strip():
                continue
            key, sep, value = kv.partition("=")
            if not sep:
                raise ValueError(
                    f"invalid env entry {kv.strip()!r}, expected key=value"
                )
            envs[key.strip()] = value.strip()
        return envs

    @staticmethod
    def handler(args: argparse.Namespace):
        """Submit the builder job described by *args* and print the server reply.

        ``--project_id`` / ``--host_addr`` fall back to ``KAG_PROJECT_CONF``
        when not given on the command line.

        Raises:
            ValueError: on a malformed ``--env`` value or a failed validity check.
            requests.HTTPError: if the server answers with an error status.
        """
        # Parse user input first so malformed --env fails before the
        # (potentially slow) validation clone.
        envs = BuilderJobSubmit._parse_envs(args.env)
        if args.validity_check:
            BuilderJobSubmit.validity_check(args)
        command = BuilderJobSubmit._build_command(args)

        if args.project_id is not None:
            project_id = int(args.project_id)
        else:
            project_id = int(KAG_PROJECT_CONF.project_id)
        req = {
            "projectId": project_id,
            "command": command,
            "workerNum": args.num_workers,
            "workerCpu": args.num_cpus,
            "workerGpu": args.num_gpus,
            # CLI takes GB; multiplied by 1024, presumably because the server
            # expects MB -- confirm against the server API.
            "workerMemory": args.memory * 1024,
            "workerStorage": args.storage * 1024,
            "envs": envs,
        }
        if args.image:
            req["image"] = args.image
        if args.pool:
            req["workerPool"] = args.pool
        if args.user_number:
            req["userNumber"] = args.user_number

        if args.host_addr is not None:
            host_addr = args.host_addr.rstrip("/")
        else:
            host_addr = KAG_PROJECT_CONF.host_addr.rstrip("/")
        url = host_addr + "/public/v1/builder/kag/submit"
        rsp = requests.post(url, json=req)
        rsp.raise_for_status()
        print(f"{bold}{green}Success submit job to server, info:{reset}")
        print(json.dumps(rsp.json(), indent=4))