hyfiml.models.text_classifier


Module: text_classifier.py

This module provides a TextClassifier class for text classification with Hugging Face transformer models. The class supports loading datasets, preprocessing data, training models, making predictions, and identifying potential label errors.

Classes:
  • TrainingConfig: A Pydantic BaseModel class for configuring training arguments.

  • DatasetConfig: A Pydantic BaseModel class for configuring dataset arguments.

  • CrossValidateConfig: A Pydantic BaseModel class for configuring cross-validation arguments.

  • TextClassifier: The main class for text classification tasks.

The TextClassifier class includes the following methods:
  • load_dataset: Loads a dataset using the configuration specified in dataset_config.

  • preprocess_dataset: Preprocesses the dataset by tokenizing the text and converting labels.

  • split_dataset: Splits the dataset into train, test, and optionally dev sets based on the configuration specified in dataset_config.

  • compute_metrics: Computes evaluation metrics during training.

  • train: Trains the model on the provided dataset using the specified training configuration.

  • predict: Makes predictions on a new dataset using the trained model.

  • save_model: Saves the trained model to a specified directory.

  • load_model: Loads a trained model from a specified directory.

  • plot_confusion_matrix: Plots the confusion matrix for a given dataset.

  • cross_validate_and_predict: Performs cross-validation and prediction using the trained model.

  • find_potential_label_errors: Finds potential label errors using cleanlab’s find_label_issues function.
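
This page does not show how find_potential_label_errors calls cleanlab. As a rough illustration of the idea behind this kind of check, the sketch below flags examples whose predicted probabilities confidently favor a class other than the given label, using each class's average self-confidence as its threshold. It is a simplified, pure-Python stand-in loosely modeled on confident learning, not cleanlab's implementation, and the names per_class_thresholds and flag_suspect_labels are hypothetical.

```python
def per_class_thresholds(labels, pred_probs, num_classes):
    """Average predicted probability of class k over examples labeled k."""
    thresholds = []
    for k in range(num_classes):
        probs_k = [p[k] for y, p in zip(labels, pred_probs) if y == k]
        # A class with no labeled examples gets a threshold nothing can exceed.
        thresholds.append(sum(probs_k) / len(probs_k) if probs_k else 1.0)
    return thresholds


def flag_suspect_labels(labels, pred_probs):
    """Return indices whose confidently predicted class differs from the label."""
    num_classes = len(pred_probs[0])
    thresholds = per_class_thresholds(labels, pred_probs, num_classes)
    suspects = []
    for i, (y, p) in enumerate(zip(labels, pred_probs)):
        # Classes whose probability clears that class's own threshold.
        confident = [k for k in range(num_classes) if p[k] >= thresholds[k]]
        if not confident:
            continue  # no confident class, so no evidence against the label
        best = max(confident, key=lambda k: p[k])
        if best != y:
            suspects.append(i)
    return suspects


# Toy data (illustrative only): example 2 is labeled 0 but predicted as class 1.
labels = [0, 0, 0, 1, 1, 1]
pred_probs = [
    [0.9, 0.1],
    [0.8, 0.2],
    [0.1, 0.9],
    [0.2, 0.8],
    [0.1, 0.9],
    [0.3, 0.7],
]
print(flag_suspect_labels(labels, pred_probs))  # → [2]
```

For a check like this to be meaningful, the probabilities should come from a model that did not see the example during training, which is why label-error detection is usually paired with cross-validated predictions.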

The TextClassifier class takes the following arguments during initialization:
  • model_name: The name of the transformer model to use.

  • num_labels: The number of labels in the classification task.

  • dataset_config: An instance of the DatasetConfig class specifying the dataset configuration.

  • training_config: An instance of the TrainingConfig class specifying the training configuration.

  • cross_validate_config: An instance of the CrossValidateConfig class specifying the cross-validation configuration.

Example usage:

    from datasets import load_dataset

    from hyfiml.models.text_classifier import (
        CrossValidateConfig,
        DatasetConfig,
        TextClassifier,
        TrainingConfig,
    )

    # Create dataset, training, and cross-validation configurations
    dataset_config = DatasetConfig(
        dataset_name="imdb",
        text_column_name="text",
        label_column_name="label",
        num_labels=2,
    )
    training_config = TrainingConfig(
        output_dir="output",
        num_train_epochs=3,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=64,
    )
    cross_validate_config = CrossValidateConfig(
        n_splits=5,
        validation_size=0.1,
        random_state=42,
        shuffle=True,
    )

    # Create a TextClassifier instance
    classifier = TextClassifier(
        model_name="bert-base-uncased",
        dataset_config=dataset_config,
        training_config=training_config,
        cross_validate_config=cross_validate_config,
    )

    # Load the dataset
    dataset = classifier.load_dataset()

    # Train the model
    classifier.train(dataset)

    # Make predictions on a new dataset
    new_dataset = load_dataset("imdb", split="test")
    predictions = classifier.predict(new_dataset)

    # Perform cross-validation and find potential label errors
    predictions = classifier.cross_validate_and_predict(dataset)
    label_issues_info = classifier.find_potential_label_errors(predictions, dataset["label"])
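
The internals of cross_validate_and_predict are not documented here. As a general illustration of what the n_splits, shuffle, and random_state settings usually control, the sketch below builds k-fold train and held-out index splits in plain Python. The function name k_fold_indices is hypothetical, and the splitting logic is an assumption about typical k-fold behavior, not hyfiml's actual code.

```python
import random


def k_fold_indices(n_samples, n_splits=5, shuffle=True, random_state=42):
    """Return a list of (train_indices, held_out_indices) pairs, one per fold."""
    indices = list(range(n_samples))
    if shuffle:
        # A seeded generator keeps the shuffle reproducible across runs.
        random.Random(random_state).shuffle(indices)

    # Spread the remainder over the first folds so sizes differ by at most one.
    base, extra = divmod(n_samples, n_splits)
    fold_sizes = [base + (1 if i < extra else 0) for i in range(n_splits)]

    folds = []
    start = 0
    for size in fold_sizes:
        held_out = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        folds.append((train, held_out))
        start += size
    return folds


folds = k_fold_indices(10, n_splits=3)
for train_idx, held_out_idx in folds:
    print(len(train_idx), len(held_out_idx))
```

Every sample lands in exactly one held-out fold, so predicting each fold with a model trained on the remaining folds gives an out-of-sample prediction for every example.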

Classes

  • CrossValidateConfig(*[, n_splits, ...]): Configuration for cross-validation arguments.

  • DatasetConfig(*, dataset_name[, ...]): Configuration for dataset arguments.

  • TextClassifier(*, model_name, ...): Text classifier based on transformer models from Hugging Face.

  • TrainingConfig(*, output_dir, num_train_epochs): Configuration for training arguments.