viam.app.ml_training_client

Attributes

LOGGER

Classes

MLTrainingClient

gRPC client for working with ML training jobs.

Module Contents

viam.app.ml_training_client.LOGGER
class viam.app.ml_training_client.MLTrainingClient(channel: grpclib.client.Channel, metadata: Mapping[str, str])

gRPC client for working with ML training jobs.

The constructor is used by ViamClient to instantiate the relevant service stubs. Access an MLTrainingClient through ViamClient (viam_client.ml_training_client) rather than constructing one directly.

Establish a Connection:

import asyncio

from viam.rpc.dial import DialOptions
from viam.app.viam_client import ViamClient


async def connect() -> ViamClient:
    # Replace "<API-KEY>" (including brackets) with your API key and "<API-KEY-ID>" with your API key ID
    dial_options = DialOptions.with_api_key("<API-KEY>", "<API-KEY-ID>")
    return await ViamClient.create_from_dial_options(dial_options)


async def main():

    # Make a ViamClient
    viam_client = await connect()
    # Instantiate an MLTrainingClient to run ML training client API methods on
    ml_training_client = viam_client.ml_training_client

    viam_client.close()

if __name__ == '__main__':
    asyncio.run(main())

For more information, see ML Training Client API.
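The connection example above calls viam_client.close() only on the success path; if an API call raises first, the connection stays open. The sketch below shows one way to guarantee cleanup, assuming (as the example shows) that close() is a plain synchronous method. client_session is a hypothetical helper, not part of the Viam SDK; it wraps any async connect() function in an async context manager.

```python
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Awaitable, Callable


@asynccontextmanager
async def client_session(connect: Callable[[], Awaitable[Any]]) -> AsyncIterator[Any]:
    """Yield the client returned by connect() and always close it afterwards.

    Hypothetical helper: assumes the client has a synchronous close() method,
    as viam_client.close() is called without await in the example above.
    """
    client = await connect()
    try:
        yield client
    finally:
        # Runs on normal exit and also when the body of the `async with` raises.
        client.close()
```

With this helper, main() could use `async with client_session(connect) as viam_client:` around the ml_training_client calls and drop the explicit close().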

async submit_training_job(org_id: str, dataset_id: str, model_name: str, model_version: str, model_type: viam.proto.app.mltraining.ModelType.ValueType, tags: List[str]) → str

Submit a training job.

from viam.proto.app.mltraining import ModelType

job_id = await ml_training_client.submit_training_job(
    org_id="<organization-id>",
    dataset_id="<dataset-id>",
    model_name="<your-model-name>",
    model_version="1",
    model_type=ModelType.MODEL_TYPE_SINGLE_LABEL_CLASSIFICATION,
    tags=["tag1", "tag2"]
)
Parameters:
  • org_id (str) – the ID of the org to submit the training job to.

  • dataset_id (str) – the ID of the dataset to train the model on.

  • model_name (str) – the model name.

  • model_version (str) – the model version.

  • model_type (ModelType.ValueType) – the model type.

  • tags (List[str]) – the labels to train the model on.

Returns:

the ID of the training job.

Return type:

str

For more information, see ML Training Client API.
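Because submit_training_job is a coroutine, several jobs can be submitted concurrently. The sketch below is hypothetical: submit_training_jobs is an illustrative name, and each job is described as a dict of the keyword arguments documented above. Only submit_training_job itself comes from the SDK.

```python
import asyncio
from typing import Any, Dict, List, Sequence


async def submit_training_jobs(
    ml_training_client: Any, specs: Sequence[Dict[str, Any]]
) -> List[str]:
    """Submit several training jobs concurrently and return their IDs in spec order.

    Hypothetical helper. Each spec holds keyword arguments for
    submit_training_job: org_id, dataset_id, model_name, model_version,
    model_type and tags.
    """
    # gather() preserves the order of its arguments in the result list.
    job_ids = await asyncio.gather(
        *(ml_training_client.submit_training_job(**spec) for spec in specs)
    )
    return list(job_ids)
```

If any submission raises, asyncio.gather propagates the first exception; the helper does not cancel jobs that were already accepted by the server.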

async submit_custom_training_job(org_id: str, dataset_id: str, registry_item_id: str, registry_item_version: str, model_name: str, model_version: str) → str

Submit a custom training job.

job_id = await ml_training_client.submit_custom_training_job(
    org_id="<organization-id>",
    dataset_id="<dataset-id>",
    registry_item_id="viam:classification-tflite",
    registry_item_version="2024-08-13T12-11-54",
    model_name="<your-model-name>",
    model_version="1"
)
Parameters:
  • org_id (str) – the ID of the org to submit the training job to.

  • dataset_id (str) – the ID of the dataset to train the model on.

  • registry_item_id (str) – the ID of the training script from the registry.

  • registry_item_version (str) – the version of the training script from the registry.

  • model_name (str) – the model name.

  • model_version (str) – the model version.

Returns:

the ID of the training job.

Return type:

str

For more information, see ML Training Client API.
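All six parameters of submit_custom_training_job are strings, which makes a custom job easy to keep in a config file. A minimal sketch, assuming a hypothetical flat JSON layout whose keys match the parameter names above; load_custom_job_spec is an illustrative name, not an SDK function.

```python
import json
from typing import Dict

# Parameter names of submit_custom_training_job, taken from its signature above.
CUSTOM_JOB_FIELDS = (
    "org_id",
    "dataset_id",
    "registry_item_id",
    "registry_item_version",
    "model_name",
    "model_version",
)


def load_custom_job_spec(text: str) -> Dict[str, str]:
    """Parse a JSON job spec and check that it has exactly the expected string fields.

    Hypothetical helper: the JSON layout is an assumption, not a Viam file format.
    """
    spec = json.loads(text)
    missing = [f for f in CUSTOM_JOB_FIELDS if f not in spec]
    unknown = [k for k in spec if k not in CUSTOM_JOB_FIELDS]
    if missing or unknown:
        raise ValueError(f"missing fields {missing}, unknown fields {unknown}")
    for key, value in spec.items():
        # A JSON number such as 1 would arrive as int, but the API expects str.
        if not isinstance(value, str):
            raise TypeError(f"{key} must be a string, got {type(value).__name__}")
    return spec
```

Usage: `job_id = await ml_training_client.submit_custom_training_job(**load_custom_job_spec(text))`. Validating first surfaces a malformed spec locally instead of as a failed request.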

async get_training_job(id: str) → viam.proto.app.mltraining.TrainingJobMetadata

Gets training job data.

job_metadata = await ml_training_client.get_training_job(
    id="<job-id>")
Parameters:

id (str) – the ID of the requested training job.

Returns:

the training job data.

Return type:

viam.proto.app.mltraining.TrainingJobMetadata

For more information, see ML Training Client API.
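Training runs on the server after submission, so a common pattern is to poll get_training_job until the job reaches a final state. A minimal sketch, assuming TrainingJobMetadata exposes a status field holding a TrainingStatus value; the helper name and default intervals are hypothetical.

```python
import asyncio
from typing import Any, Collection


async def wait_for_training_job(
    ml_training_client: Any,
    job_id: str,
    terminal_statuses: Collection[int],
    poll_interval: float = 30.0,
    timeout: float = 3600.0,
) -> Any:
    """Poll get_training_job until the job's status is in terminal_statuses.

    Hypothetical helper. Assumes the returned metadata has a `status` field.
    Raises asyncio.TimeoutError if no terminal status is seen within `timeout` seconds.
    """

    async def poll() -> Any:
        while True:
            metadata = await ml_training_client.get_training_job(id=job_id)
            if metadata.status in terminal_statuses:
                return metadata
            await asyncio.sleep(poll_interval)

    # wait_for cancels the polling loop once the timeout expires.
    return await asyncio.wait_for(poll(), timeout=timeout)
```

In real use, terminal_statuses would come from viam.proto.app.mltraining.TrainingStatus, for example {TrainingStatus.TRAINING_STATUS_COMPLETED, TrainingStatus.TRAINING_STATUS_FAILED, TrainingStatus.TRAINING_STATUS_CANCELED}. These member names are an assumption; check them against the installed SDK.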

async list_training_jobs(org_id: str, training_status: viam.proto.app.mltraining.TrainingStatus.ValueType | None = None) → List[viam.proto.app.mltraining.TrainingJobMetadata]

Returns training job data for all jobs within an org, optionally filtered by training status.

jobs_metadata = await ml_training_client.list_training_jobs(
    org_id="<org-id>")

first_job_id = jobs_metadata[0].id
Parameters:
  • org_id (str) – the ID of the org to request training job data from.

  • training_status (Optional[TrainingStatus]) – the status to filter the training jobs list by. If unspecified, all training jobs will be returned.

Returns:

the list of training job data.

Return type:

List[viam.proto.app.mltraining.TrainingJobMetadata]

For more information, see ML Training Client API.
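training_status filters by a single status on the server. To see every status at once, one unfiltered request can instead be grouped on the client. A sketch, assuming each TrainingJobMetadata exposes id (read in the example above) and status fields; job_ids_by_status is a hypothetical name.

```python
from collections import defaultdict
from typing import Any, Dict, Iterable, List


def job_ids_by_status(jobs: Iterable[Any]) -> Dict[int, List[str]]:
    """Group training job IDs by their status value, preserving list order.

    Hypothetical helper: assumes each job object has `id` and `status` attributes.
    """
    grouped: Dict[int, List[str]] = defaultdict(list)
    for job in jobs:
        grouped[job.status].append(job.id)
    # Return a plain dict so missing statuses do not silently create entries.
    return dict(grouped)
```

Usage: `groups = job_ids_by_status(await ml_training_client.list_training_jobs(org_id="<org-id>"))`. One unfiltered call replaces one filtered call per status, at the cost of transferring every job in the org.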

async cancel_training_job(id: str) → None

Cancels the specified training job.

await ml_training_client.cancel_training_job(
    id="<job-id>")
Parameters:

id (str) – the ID of the job to cancel.

Raises:

GRPCError – if no training job exists with the given ID.

For more information, see ML Training Client API.
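list_training_jobs and cancel_training_job compose naturally, for example to cancel every pending job in an org. A hypothetical sketch: the helper name is illustrative, and training_status is meant to be a TrainingStatus value such as TrainingStatus.TRAINING_STATUS_PENDING (an assumed member name).

```python
import asyncio
from typing import Any, List


async def cancel_jobs_with_status(
    ml_training_client: Any, org_id: str, training_status: int
) -> List[str]:
    """Cancel every training job in an org that currently has the given status.

    Hypothetical helper built from list_training_jobs and cancel_training_job.
    Returns the IDs of the jobs it asked the server to cancel.
    """
    jobs = await ml_training_client.list_training_jobs(
        org_id=org_id, training_status=training_status
    )
    job_ids = [job.id for job in jobs]
    # Issue the cancel requests concurrently.
    await asyncio.gather(
        *(ml_training_client.cancel_training_job(id=job_id) for job_id in job_ids)
    )
    return job_ids
```

A job can change state between the list and cancel calls. cancel_training_job raises GRPCError for unknown IDs and may reject other requests too, in which case asyncio.gather propagates the first error; wrap each call in try/except GRPCError if partial success is acceptable.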

async delete_completed_training_job(id: str) → None

Delete a completed training job from the database, whether the job succeeded or failed.

await ml_training_client.delete_completed_training_job(
    id="<job-id>")
Parameters:

id (str) – the ID of the training job to delete.

For more information, see ML Training Client API.
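delete_completed_training_job removes one finished job per call. A hypothetical cleanup sketch that lists an org's jobs and deletes those whose status is in a caller-supplied set. Which statuses count as finished is left to the caller; per the description above, both succeeded and failed jobs qualify.

```python
from typing import Any, Collection, List


async def delete_finished_jobs(
    ml_training_client: Any, org_id: str, finished_statuses: Collection[int]
) -> List[str]:
    """Delete every job in the org whose status is in finished_statuses.

    Hypothetical helper built from list_training_jobs and
    delete_completed_training_job. Returns the IDs it deleted.
    """
    jobs = await ml_training_client.list_training_jobs(org_id=org_id)
    deleted: List[str] = []
    for job in jobs:
        if job.status in finished_statuses:
            # Sequential on purpose: the first failure stops the loop.
            await ml_training_client.delete_completed_training_job(id=job.id)
            deleted.append(job.id)
    return deleted
```

Because deletions run one at a time, an error stops the cleanup early; jobs deleted before the error stay deleted.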