viam.app.ml_training_client
Attributes
- LOGGER
Classes
- MLTrainingClient: gRPC client for working with ML training jobs.
Module Contents
- viam.app.ml_training_client.LOGGER
- class viam.app.ml_training_client.MLTrainingClient(channel: grpclib.client.Channel, metadata: Mapping[str, str])
gRPC client for working with ML training jobs.
The constructor is used by ViamClient to instantiate the relevant service stubs. Make calls to MLTrainingClient methods through ViamClient rather than constructing the client directly.
Establish a Connection:
    import asyncio

    from viam.rpc.dial import DialOptions, Credentials
    from viam.app.viam_client import ViamClient


    async def connect() -> ViamClient:
        # Replace "<API-KEY>" (including brackets) with your API key and
        # "<API-KEY-ID>" with your API key ID
        dial_options = DialOptions.with_api_key("<API-KEY>", "<API-KEY-ID>")
        return await ViamClient.create_from_dial_options(dial_options)


    async def main():
        # Make a ViamClient
        viam_client = await connect()
        # Instantiate an MLTrainingClient to run ML training client API methods on
        ml_training_client = viam_client.ml_training_client

        viam_client.close()

    if __name__ == '__main__':
        asyncio.run(main())
For more information, see ML Training Client API.
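The connection example above closes the client only if everything before `close()` succeeds. A minimal sketch of a wrapper that always closes the connection is shown here. The helper name `run_with_ml_training_client` and its two callable parameters are illustrative assumptions, not part of the SDK; `connect` is expected to behave like the `connect()` coroutine above.

```python
from typing import Any, Awaitable, Callable


async def run_with_ml_training_client(
    connect: Callable[[], Awaitable[Any]],
    work: Callable[[Any], Awaitable[Any]],
) -> Any:
    """Connect, run `work` against the ML training client, then always close.

    Hypothetical helper: `connect` returns a ViamClient-like object and
    `work` is a coroutine function that receives its `ml_training_client`.
    """
    viam_client = await connect()
    try:
        return await work(viam_client.ml_training_client)
    finally:
        # Runs on success and on error, so the connection is not leaked.
        viam_client.close()
```

With the `connect()` coroutine from the example above, this could be called as `asyncio.run(run_with_ml_training_client(connect, my_work))`, where `my_work` is your own coroutine function.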
- async submit_training_job(org_id: str, dataset_id: str, model_name: str, model_version: str, model_type: viam.proto.app.mltraining.ModelType.ValueType, tags: List[str]) -> str
Submit a training job.
    from viam.proto.app.mltraining import ModelType

    job_id = await ml_training_client.submit_training_job(
        org_id="<organization-id>",
        dataset_id="<dataset-id>",
        model_name="<your-model-name>",
        model_version="1",
        model_type=ModelType.MODEL_TYPE_SINGLE_LABEL_CLASSIFICATION,
        tags=["tag1", "tag2"]
    )
- Parameters:
org_id (str) – the ID of the org to submit the training job to.
dataset_id (str) – the ID of the dataset to train the model on.
model_name (str) – the model name.
model_version (str) – the model version.
model_type (ModelType.ValueType) – the model type.
tags (List[str]) – the labels to train the model on.
- Returns:
the ID of the training job.
- Return type:
str
For more information, see ML Training Client API.
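As an illustration of how the returned job ID might be used, the sketch below submits one training job per dataset and records which job ID belongs to which dataset. The helper name `submit_job_per_dataset` and the `<prefix>-<dataset_id>` model naming scheme are assumptions made for this example; the client is passed in as a parameter and only `submit_training_job` with the keyword arguments documented above is called.

```python
from typing import Any, Dict, Iterable, List


async def submit_job_per_dataset(
    ml_training_client: Any,
    org_id: str,
    dataset_ids: Iterable[str],
    model_name_prefix: str,
    model_version: str,
    model_type: Any,
    tags: List[str],
) -> Dict[str, str]:
    """Submit one job per dataset and return {dataset_id: job_id}.

    Hypothetical helper; jobs are submitted one after another.
    """
    job_ids: Dict[str, str] = {}
    for dataset_id in dataset_ids:
        job_ids[dataset_id] = await ml_training_client.submit_training_job(
            org_id=org_id,
            dataset_id=dataset_id,
            # Illustrative naming scheme, not required by the API.
            model_name=f"{model_name_prefix}-{dataset_id}",
            model_version=model_version,
            model_type=model_type,
            tags=tags,
        )
    return job_ids
```

In real use, `model_type` would be a `ModelType` value such as `ModelType.MODEL_TYPE_SINGLE_LABEL_CLASSIFICATION` from the example above.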
- async submit_custom_training_job(org_id: str, dataset_id: str, registry_item_id: str, registry_item_version: str, model_name: str, model_version: str) -> str
Submit a custom training job.
    job_id = await ml_training_client.submit_custom_training_job(
        org_id="<organization-id>",
        dataset_id="<dataset-id>",
        registry_item_id="viam:classification-tflite",
        registry_item_version="2024-08-13T12-11-54",
        model_name="<your-model-name>",
        model_version="1"
    )
- Parameters:
org_id (str) – the ID of the org to submit the training job to.
dataset_id (str) – the ID of the dataset to train the model on.
registry_item_id (str) – the ID of the training script from the registry.
registry_item_version (str) – the version of the training script from the registry.
model_name (str) – the model name.
model_version (str) – the model version.
- Returns:
the ID of the training job.
- Return type:
str
For more information, see ML Training Client API.
- async get_training_job(id: str) -> viam.proto.app.mltraining.TrainingJobMetadata
Gets training job data.
    job_metadata = await ml_training_client.get_training_job(
        id="<job-id>")
- Parameters:
id (str) – the ID of the requested training job.
- Returns:
the training job data.
- Return type:
viam.proto.app.mltraining.TrainingJobMetadata
For more information, see ML Training Client API.
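Because training runs asynchronously after submission, a common pattern is to poll `get_training_job` until the job reaches a finished state. The sketch below is one way to do that under stated assumptions: the helper name `wait_for_training_job`, the polling interval, and the poll limit are all illustrative, and the caller supplies an `is_finished` predicate so the helper itself makes no assumption about the fields of `TrainingJobMetadata`.

```python
import asyncio
from typing import Any, Callable


async def wait_for_training_job(
    ml_training_client: Any,
    job_id: str,
    is_finished: Callable[[Any], bool],
    poll_interval_s: float = 30.0,
    max_polls: int = 120,
) -> Any:
    """Poll get_training_job until `is_finished(metadata)` is true.

    Hypothetical helper. Raises TimeoutError after `max_polls` attempts.
    """
    for attempt in range(max_polls):
        if attempt > 0:
            # Wait between polls, but not before the first one.
            await asyncio.sleep(poll_interval_s)
        job_metadata = await ml_training_client.get_training_job(id=job_id)
        if is_finished(job_metadata):
            return job_metadata
    raise TimeoutError(
        f"training job {job_id} did not finish after {max_polls} polls"
    )
```

A predicate would typically compare a status field on the returned metadata against `TrainingStatus` values; check the `TrainingJobMetadata` proto definition for the exact field and enum names before relying on them.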
- async list_training_jobs(org_id: str, training_status: viam.proto.app.mltraining.TrainingStatus.ValueType | None = None) -> List[viam.proto.app.mltraining.TrainingJobMetadata]
Returns training job data for all jobs within an org.
    jobs_metadata = await ml_training_client.list_training_jobs(
        org_id="<org-id>")

    # Python lists are zero-indexed, so the first job is at index 0
    first_job_id = jobs_metadata[0].id
- Parameters:
org_id (str) – the ID of the org to request training job data from.
training_status (Optional[TrainingStatus]) – the status to filter the training jobs list by. If unspecified, all training jobs will be returned.
- Returns:
the list of training job data.
- Return type:
List[viam.proto.app.mltraining.TrainingJobMetadata]
For more information, see ML Training Client API.
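Indexing into the returned list raises `IndexError` when the org has no matching jobs. The following sketch guards against that case and also passes the optional `training_status` filter through. The helper name `first_training_job_id` is an assumption for illustration; it relies only on `list_training_jobs` and the `id` field used in the example above.

```python
from typing import Any, Optional


async def first_training_job_id(
    ml_training_client: Any,
    org_id: str,
    training_status: Optional[Any] = None,
) -> Optional[str]:
    """Return the ID of the first listed job, or None if there are no jobs.

    Hypothetical helper. `training_status=None` means no filtering.
    """
    jobs_metadata = await ml_training_client.list_training_jobs(
        org_id=org_id,
        training_status=training_status,
    )
    # An empty list is falsy, so this avoids an IndexError.
    return jobs_metadata[0].id if jobs_metadata else None
```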
- async cancel_training_job(id: str) -> None
Cancels the specified training job.
    await ml_training_client.cancel_training_job(
        id="<job-id>")
- Parameters:
id (str) – the ID of the job to cancel.
- Raises:
GRPCError – if no training job exists with the given ID.
For more information, see ML Training Client API.
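Since `cancel_training_job` raises when the job does not exist, callers that treat a missing job as a non-error may want to catch the exception. The sketch below shows one way to do that. The helper name `cancel_if_exists` and its `not_found_errors` parameter are assumptions for illustration; in real use you would likely pass `(GRPCError,)` imported from `grpclib`, and the exception types are left as a parameter here so the example does not depend on that package.

```python
from typing import Any, Tuple, Type


async def cancel_if_exists(
    ml_training_client: Any,
    job_id: str,
    not_found_errors: Tuple[Type[BaseException], ...],
) -> bool:
    """Cancel a job; return False instead of raising if it was not found.

    Hypothetical helper. Any exception type not listed in
    `not_found_errors` still propagates to the caller.
    """
    try:
        await ml_training_client.cancel_training_job(id=job_id)
    except not_found_errors:
        return False
    return True
```

Note that a gRPC error can also signal failures other than a missing job (for example, a network problem), so a stricter version would inspect the error before deciding to swallow it.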
- async delete_completed_training_job(id: str) -> None
Delete a completed training job from the database, whether the job succeeded or failed.
    await ml_training_client.delete_completed_training_job(
        id="<job-id>")
- Parameters:
id (str) – the ID of the training job to delete.
For more information, see ML Training Client API.
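Deletion can be combined with `list_training_jobs` to clean up finished jobs in bulk. The following is a minimal sketch of that idea, assuming the caller passes a `TrainingStatus` value that corresponds to a finished state (succeeded or failed). The helper name `delete_training_jobs_with_status` is hypothetical; it uses only the two client methods and the `id` field documented above.

```python
from typing import Any, List


async def delete_training_jobs_with_status(
    ml_training_client: Any,
    org_id: str,
    training_status: Any,
) -> List[str]:
    """Delete every job in the org with the given status; return their IDs.

    Hypothetical helper. `training_status` should describe a finished
    job, because only completed jobs can be deleted with this method.
    """
    jobs_metadata = await ml_training_client.list_training_jobs(
        org_id=org_id,
        training_status=training_status,
    )
    deleted_ids: List[str] = []
    for job in jobs_metadata:
        await ml_training_client.delete_completed_training_job(id=job.id)
        deleted_ids.append(job.id)
    return deleted_ids
```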