mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-07-31 12:56:19 +00:00
71 lines
1.5 KiB
Python
71 lines
1.5 KiB
Python
from abc import ABC, abstractmethod
|
|
|
|
|
|
class NNExecutor(ABC):
    """
    Entry point of model execution in a certain pod.

    An instance keeps the backend model, its tokenizer, the initialization
    arguments and any extra keyword options handed to the constructor.
    """

    def __init__(self,
                 backend_model,
                 backend_tokenizer,
                 init_args,
                 **kwargs):
        # Every constructor input is kept under an attribute of the same name;
        # extra keyword options are retained as a dict in ``self.kwargs``.
        self.backend_model, self.backend_tokenizer = backend_model, backend_tokenizer
        self.init_args = init_args
        self.kwargs = kwargs

    @classmethod
    def from_config(cls, nn_config, **kwargs):
        """
        Alternate constructor from a configuration object.

        Not implemented at this level: it always returns ``None``.
        """
        return None
|
|
|
|
|
|
class LLMExecutor(NNExecutor):
    """
    Executor for large language models.

    Subclasses must implement :meth:`inference`. The tuning entry points
    (:meth:`sft` and :meth:`rl_tuning`) are optional. Their default
    implementations raise ``NotImplementedError`` naming the concrete
    subclass, so an executor that only supports inference can still be
    instantiated.
    """

    @classmethod
    def from_config(cls, nn_config: dict, **kwargs):
        """
        Build an LLM executor from a configuration dictionary.

        Args:
            nn_config: configuration dictionary. Its schema is not defined
                yet (TODO confirm once implemented).
            **kwargs: extra options; currently unused.

        Returns:
            ``None`` -- construction from config is not implemented yet.
        """
        # TODO: construct the executor from ``nn_config``.
        pass

    # Not abstract on purpose: executors that cannot run SFT inherit this
    # fallback instead of being forced to override it just to become
    # instantiable.
    def sft(self, args=None, callbacks=None, **kwargs):
        """
        The entry point of SFT execution in a certain pod.

        Raises:
            NotImplementedError: unless a subclass that supports SFT
                overrides this method.
        """
        raise NotImplementedError(f"{self.__class__.__name__} does not support SFT.")

    # Not abstract on purpose, for the same reason as ``sft``.
    def rl_tuning(self, args=None, callbacks=None, **kwargs):
        """
        The entry point of RL-Tuning execution in a certain pod.

        Raises:
            NotImplementedError: unless a subclass that supports RL-Tuning
                overrides this method.
        """
        raise NotImplementedError(f"{self.__class__.__name__} does not support RL-Tuning.")

    def batch_inference(self, args, **kwargs):
        """
        Batch inference entry point.

        Currently a no-op that returns ``None``; subclasses may override it.
        """
        pass

    @abstractmethod
    def inference(self, data, **kwargs):
        """
        The entry point of inference. Usually for local invokers or model services.

        Every concrete subclass must implement this method.
        """
        raise NotImplementedError()
|
|
|
|
|
|
class HfLLMExecutor(NNExecutor):
    """
    Placeholder executor that adds nothing beyond ``NNExecutor`` yet.

    NOTE(review): the name suggests a Hugging Face-backed LLM executor, and
    it presumably should derive from ``LLMExecutor`` once implemented -- TODO
    confirm.
    """
|
|
|
|
|
|
class DeepKeExecutor(NNExecutor):
    """
    Placeholder executor that adds nothing beyond ``NNExecutor`` yet.

    NOTE(review): the name suggests it will wrap DeepKE models -- TODO
    confirm once implemented.
    """
|
|
|
|
|