from typing import List, Mapping, Optional
from grpclib.client import Channel
from viam import logging
from viam.proto.app.mltraining import (
CancelTrainingJobRequest,
DeleteCompletedTrainingJobRequest,
GetTrainingJobRequest,
GetTrainingJobResponse,
ListTrainingJobsRequest,
ListTrainingJobsResponse,
MLTrainingServiceStub,
ModelType,
SubmitCustomTrainingJobRequest,
SubmitCustomTrainingJobResponse,
SubmitTrainingJobRequest,
SubmitTrainingJobResponse,
TrainingJobMetadata,
TrainingStatus,
)
LOGGER = logging.getLogger(__name__)
[docs]class MLTrainingClient:
"""gRPC client for working with ML training jobs.
Constructor is used by `ViamClient` to instantiate relevant service stubs.
Calls to `MLTrainingClient` methods should be made through `ViamClient`.
Establish a Connection::
import asyncio
from viam.rpc.dial import DialOptions, Credentials
from viam.app.viam_client import ViamClient
async def connect() -> ViamClient:
# Replace "<API-KEY>" (including brackets) with your API key and "<API-KEY-ID>" with your API key ID
dial_options = DialOptions.with_api_key("<API-KEY>", "<API-KEY-ID>")
return await ViamClient.create_from_dial_options(dial_options)
async def main():
# Make a ViamClient
viam_client = await connect()
# Instantiate an MLTrainingClient to run ML training client API methods on
ml_training_client = viam_client.ml_training_client
viam_client.close()
if __name__ == '__main__':
asyncio.run(main())
For more information, see `ML Training Client API <https://docs.viam.com/appendix/apis/ml-training-client/>`_.
"""
def __init__(self, channel: Channel, metadata: Mapping[str, str]):
"""Create a `MLTrainingClient` that maintains a connection to app.
Args:
channel (grpclib.client.Channel): Connection to app.
metadata (Mapping[str, str]): Required authorization token to send requests to app.
For more information, see `ML Training Client API <https://docs.viam.com/appendix/apis/ml-training-client/>`_.
"""
self._metadata = metadata
self._ml_training_client = MLTrainingServiceStub(channel)
self._channel = channel
[docs] async def submit_training_job(
self,
org_id: str,
dataset_id: str,
model_name: str,
model_version: str,
model_type: ModelType.ValueType,
tags: List[str],
) -> str:
"""Submit a training job.
::
from viam.proto.app.mltraining import ModelType
job_id = await ml_training_client.submit_training_job(
org_id="<organization-id>",
dataset_id="<dataset-id>",
model_name="<your-model-name>",
model_version="1",
model_type=ModelType.MODEL_TYPE_SINGLE_LABEL_CLASSIFICATION,
tags=["tag1", "tag2"]
)
Args:
org_id (str): the ID of the org to submit the training job to.
dataset_id (str): the ID of the dataset to train the model on.
model_name (str): the model name.
model_version (str): the model version.
model_type (ModelType.ValueType): the model type.
tags (List[str]): the labels to train the model on.
Returns:
str: the ID of the training job.
For more information, see `ML Training Client API <https://docs.viam.com/appendix/apis/ml-training-client/>`_.
"""
request = SubmitTrainingJobRequest(
dataset_id=dataset_id,
organization_id=org_id,
model_name=model_name,
model_version=model_version,
model_type=model_type,
tags=tags,
)
response: SubmitTrainingJobResponse = await self._ml_training_client.SubmitTrainingJob(request, metadata=self._metadata)
return response.id
[docs] async def submit_custom_training_job(
self, org_id: str, dataset_id: str, registry_item_id: str, registry_item_version: str, model_name: str, model_version: str
) -> str:
"""Submit a custom training job.
::
job_id = await ml_training_client.submit_custom_training_job(
org_id="<organization-id>",
dataset_id="<dataset-id>",
registry_item_id="viam:classification-tflite",
registry_item_version="2024-08-13T12-11-54",
model_name="<your-model-name>",
model_version="1"
)
Args:
org_id (str): the ID of the org to submit the training job to.
dataset_id (str): the ID of the dataset to train the model on.
registry_item_id (str): the ID of the training script from the registry.
registry_item_version (str): the version of the training script from the registry.
model_name (str): the model name.
model_version (str): the model version.
Returns:
str: the ID of the training job.
For more information, see `ML Training Client API <https://docs.viam.com/appendix/apis/ml-training-client/>`_.
"""
request = SubmitCustomTrainingJobRequest(
dataset_id=dataset_id,
registry_item_id=registry_item_id,
registry_item_version=registry_item_version,
organization_id=org_id,
model_name=model_name,
model_version=model_version,
)
response: SubmitCustomTrainingJobResponse = await self._ml_training_client.SubmitCustomTrainingJob(request, metadata=self._metadata)
return response.id
[docs] async def get_training_job(self, id: str) -> TrainingJobMetadata:
"""Gets training job data.
::
job_metadata = await ml_training_client.get_training_job(
id="<job-id>")
Args:
id (str): the ID of the requested training job.
Returns:
viam.proto.app.mltraining.TrainingJobMetadata: the training job data.
For more information, see `ML Training Client API <https://docs.viam.com/appendix/apis/ml-training-client/>`_.
"""
request = GetTrainingJobRequest(id=id)
response: GetTrainingJobResponse = await self._ml_training_client.GetTrainingJob(request, metadata=self._metadata)
return response.metadata
[docs] async def list_training_jobs(
self,
org_id: str,
training_status: Optional[TrainingStatus.ValueType] = None,
) -> List[TrainingJobMetadata]:
"""Returns training job data for all jobs within an org.
::
jobs_metadata = await ml_training_client.list_training_jobs(
org_id="<org-id>")
first_job_id = jobs_metadata[1].id
Args:
org_id (str): the ID of the org to request training job data from.
training_status (Optional[TrainingStatus]): the status to filter the training jobs list by.
If unspecified, all training jobs will be returned.
Returns:
List[viam.proto.app.mltraining.TrainingJobMetadata]: the list of training job data.
For more information, see `ML Training Client API <https://docs.viam.com/appendix/apis/ml-training-client/>`_.
"""
training_status = training_status if training_status else TrainingStatus.TRAINING_STATUS_UNSPECIFIED
request = ListTrainingJobsRequest(organization_id=org_id, status=training_status)
response: ListTrainingJobsResponse = await self._ml_training_client.ListTrainingJobs(request, metadata=self._metadata)
return list(response.jobs)
[docs] async def cancel_training_job(self, id: str) -> None:
"""Cancels the specified training job.
::
await ml_training_client.cancel_training_job(
id="<job-id>")
Args:
id (str): the ID of the job to cancel.
Raises:
GRPCError: if no training job exists with the given ID.
For more information, see `ML Training Client API <https://docs.viam.com/appendix/apis/ml-training-client/>`_.
"""
request = CancelTrainingJobRequest(id=id)
await self._ml_training_client.CancelTrainingJob(request, metadata=self._metadata)
[docs] async def delete_completed_training_job(self, id: str) -> None:
"""Delete a completed training job from the database, whether the job succeeded or failed.
::
await ml_training_client.delete_completed_training_job(
id="<job-id>")
Args:
id (str): the ID of the training job to delete.
For more information, see `ML Training Client API <https://docs.viam.com/appendix/apis/ml-training-client/>`_.
"""
request = DeleteCompletedTrainingJobRequest(id=id)
await self._ml_training_client.DeleteCompletedTrainingJob(request, metadata=self._metadata)