Source code for racket.operations.load

import logging

from tensorflow_serving.apis import model_management_pb2
from tensorflow_serving.config import model_server_config_pb2

from racket.utils import Printer as p
from racket.models.channel import Channel
from racket.managers.server import ServerManager

log = logging.getLogger('root')


class ModelLoader:
    """
    This class provides the interface to load new models into TensorFlow Serving.
    This is implemented through a gRPC call to the TFS API, which triggers it to
    look for directories matching the name of the specified model.
    """

    channel = Channel.service_channel()
    request = model_management_pb2.ReloadConfigRequest()
    model_server_config = model_server_config_pb2.ModelServerConfig()
    conf = model_server_config_pb2.ModelConfigList()

    @classmethod
    def set_config(cls, model_name: str) -> None:
        config = cls.conf.config.add()
        config.name = model_name
        config.base_path = '/models/' + model_name
        config.model_platform = 'tensorflow'
        cls.model_server_config.model_config_list.CopyFrom(cls.conf)
        cls.request.config.CopyFrom(cls.model_server_config)
    @classmethod
    def load(cls, model_name: str) -> None:
        """Load model

        This will send the gRPC request. In particular, it will open a gRPC
        channel and communicate with the ReloadConfigRequest API to inform TFS
        of a change in configuration.

        Parameters
        ----------
        model_name : str
            Name of the model, as specified in the instantiated Learner class

        Returns
        -------
        None
        """
        cls.set_config(model_name)
        log.info(cls.request.IsInitialized())
        log.info(cls.request.ListFields())
        response = cls.channel.HandleReloadConfigRequest(cls.request, ServerManager.PREDICTION_TIMEOUT)
        if response.status.error_code == 0:
            p.print_success(f'Loaded model {model_name} successfully')
        else:
            p.print_error(f'Loading failed, {response.status.error_code}: {response.status.error_message}')
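# Illustrative note (not part of the original module): after a call such as
# ModelLoader.set_config('keras-simple-regression'), the request built above
# carries a ModelServerConfig equivalent to this text-format proto, which is
# what TensorFlow Serving's HandleReloadConfigRequest receives:
#
#     model_config_list {
#       config {
#         name: "keras-simple-regression"
#         base_path: "/models/keras-simple-regression"
#         model_platform: "tensorflow"
#       }
#     }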
if __name__ == '__main__':
    ModelLoader.load('keras-simple-regression')
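# Illustrative follow-up (a sketch, assuming Channel.service_channel() returns a
# TFS ModelService stub, as its direct use of HandleReloadConfigRequest above
# suggests): the same stub could poll model status after a reload, e.g.:
#
#     from tensorflow_serving.apis import get_model_status_pb2
#
#     status_request = get_model_status_pb2.GetModelStatusRequest()
#     status_request.model_spec.name = 'keras-simple-regression'
#     status = ModelLoader.channel.GetModelStatus(status_request, ServerManager.PREDICTION_TIMEOUT)
#     log.info(status.model_version_status)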