didicout de35b970b7 model interfaces
(cherry picked from commit 4c556c39f6fdefd8755d9afcec247086f5d47380)
2023-12-11 23:16:12 +08:00

71 lines
1.5 KiB
Python

from abc import ABC, abstractmethod
class NNExecutor(ABC):
    """
    Entry point of model execution in a certain pod.

    Holds the backend model, backend tokenizer and initialization arguments
    supplied by the caller; subclasses provide the actual execution logic.
    """

    @classmethod
    def from_config(cls, nn_config: dict, **kwargs):
        """
        Alternate constructor that builds an executor from a configuration.

        Not implemented on this base class; subclasses are expected to
        override it.

        Args:
            nn_config: executor configuration. Its schema is not defined
                here -- TODO document once an implementation exists.
            **kwargs: additional keyword arguments for subclass overrides.

        Raises:
            NotImplementedError: always, on this base class.
        """
        # Raise rather than fall through to an implicit ``None`` so callers
        # never end up holding ``None`` where an executor instance is expected.
        raise NotImplementedError(
            f"{cls.__name__} does not support from_config."
        )

    def __init__(self,
                 backend_model,
                 backend_tokenizer,
                 init_args,
                 **kwargs):
        """
        Store the backend objects and arguments on the instance.

        Args:
            backend_model: the underlying model object; its type depends on
                the backend and is not constrained here.
            backend_tokenizer: tokenizer object used with the backend,
                stored as-is.
            init_args: initialization arguments, stored as-is.
            **kwargs: any extra keyword arguments, kept in ``self.kwargs``.
        """
        self.backend_model = backend_model
        self.backend_tokenizer = backend_tokenizer
        self.init_args = init_args
        self.kwargs = kwargs
class LLMExecutor(NNExecutor):
    """
    Executor interface for large language models.

    ``sft``, ``rl_tuning`` and ``inference`` are abstract and must be
    overridden by concrete subclasses. ``from_config`` and
    ``batch_inference`` are not implemented yet and raise
    ``NotImplementedError``.
    """

    @classmethod
    def from_config(cls, nn_config: dict, **kwargs):
        """
        Build an executor from a model configuration.

        Args:
            nn_config: model configuration dict. Its schema is not defined
                here yet -- TODO document once implemented.
            **kwargs: additional keyword arguments (currently unused).

        Raises:
            NotImplementedError: always, until this constructor is implemented.
        """
        # TODO: implement. Raise instead of silently returning None so callers
        # do not proceed with ``None`` in place of an executor instance.
        raise NotImplementedError(
            f"{cls.__name__} does not support from_config."
        )

    @abstractmethod
    def sft(self, args=None, callbacks=None, **kwargs):
        """
        The entry point of SFT execution in a certain pod.

        Subclasses must override this. The base implementation only raises,
        so an override may delegate to ``super().sft(...)`` to report that
        SFT is unsupported.

        Raises:
            NotImplementedError: from the base implementation.
        """
        raise NotImplementedError(f"{self.__class__.__name__} does not support SFT.")

    @abstractmethod
    def rl_tuning(self, args=None, callbacks=None, **kwargs):
        """
        The entry point of RL-Tuning execution in a certain pod.

        Subclasses must override this. The base implementation only raises,
        so an override may delegate to ``super().rl_tuning(...)`` to report
        that RL-Tuning is unsupported.

        Raises:
            NotImplementedError: from the base implementation.
        """
        raise NotImplementedError(f"{self.__class__.__name__} does not support RL-Tuning.")

    def batch_inference(self, args, **kwargs):
        """
        The entry point of batch inference.

        Not implemented on this class; subclasses may override it.

        Raises:
            NotImplementedError: unless overridden by a subclass.
        """
        # Consistent with sft/rl_tuning: report lack of support explicitly
        # instead of silently returning None.
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support batch inference."
        )

    @abstractmethod
    def inference(self, data, **kwargs):
        """
        The entry point of inference. Usually for local invokers or model services.

        Raises:
            NotImplementedError: from the base implementation.
        """
        raise NotImplementedError()
class HfLLMExecutor(NNExecutor):
    """
    Placeholder executor; judging by the name, intended for Hugging Face
    LLM backends (TODO confirm).

    It currently adds nothing to ``NNExecutor``. NOTE(review): it derives from
    ``NNExecutor`` rather than ``LLMExecutor``, so it does not carry the
    ``sft`` / ``rl_tuning`` / ``inference`` interface -- confirm this is
    intentional.
    """
class DeepKeExecutor(NNExecutor):
    """
    Placeholder executor; judging by the name, presumably for DeepKE-based
    models (TODO confirm).

    It currently adds nothing to ``NNExecutor``.
    """