parent
5c06f7a9e7
commit
55449c8934
@ -0,0 +1,157 @@
|
|||||||
|
# `BioGPT` Documentation
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
1. [Introduction](#introduction)
|
||||||
|
2. [Overview](#overview)
|
||||||
|
3. [Installation](#installation)
|
||||||
|
4. [Usage](#usage)
|
||||||
|
1. [BioGPT Class](#biogpt-class)
|
||||||
|
2. [Examples](#examples)
|
||||||
|
5. [Additional Information](#additional-information)
|
||||||
|
6. [Conclusion](#conclusion)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. Introduction <a name="introduction"></a>
|
||||||
|
|
||||||
|
The `BioGPT` module is a domain-specific generative language model designed for the biomedical domain. It is built upon the powerful Transformer architecture and pretrained on a large corpus of biomedical literature. This documentation provides an extensive guide on using the `BioGPT` module, explaining its purpose, parameters, and usage.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Overview <a name="overview"></a>
|
||||||
|
|
||||||
|
The `BioGPT` module addresses the need for a language model specialized in the biomedical domain. Unlike general-purpose language models, `BioGPT` excels in generating coherent and contextually relevant text specific to biomedical terms and concepts. It has been evaluated on various biomedical natural language processing tasks and has demonstrated superior performance.
|
||||||
|
|
||||||
|
Key features and parameters of the `BioGPT` module include:
|
||||||
|
- `model_name`: Name of the pretrained model.
|
||||||
|
- `max_length`: Maximum length of generated text.
|
||||||
|
- `num_return_sequences`: Number of sequences to return.
|
||||||
|
- `do_sample`: Whether to use sampling in generation.
|
||||||
|
- `min_length`: Minimum length of generated text.
|
||||||
|
|
||||||
|
The `BioGPT` module is equipped with features for generating text, extracting features, and more.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. Installation <a name="installation"></a>
|
||||||
|
|
||||||
|
Before using the `BioGPT` module, ensure you have the required dependencies installed, including the Transformers library and Torch. You can install these dependencies using pip:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install transformers
|
||||||
|
pip install torch
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. Usage <a name="usage"></a>
|
||||||
|
|
||||||
|
In this section, we'll cover how to use the `BioGPT` module effectively. It consists of the `BioGPT` class and provides examples to demonstrate its usage.
|
||||||
|
|
||||||
|
### 4.1. `BioGPT` Class <a name="biogpt-class"></a>
|
||||||
|
|
||||||
|
The `BioGPT` class is the core component of the `BioGPT` module. It is used to create a `BioGPT` instance, which can generate text, extract features, and more.
|
||||||
|
|
||||||
|
#### Parameters:
|
||||||
|
- `model_name` (str): Name of the pretrained model.
|
||||||
|
- `max_length` (int): Maximum length of generated text.
|
||||||
|
- `num_return_sequences` (int): Number of sequences to return.
|
||||||
|
- `do_sample` (bool): Whether or not to use sampling in generation.
|
||||||
|
- `min_length` (int): Minimum length of generated text.
|
||||||
|
|
||||||
|
### 4.2. Examples <a name="examples"></a>
|
||||||
|
|
||||||
|
Let's explore how to use the `BioGPT` class with different scenarios and applications.
|
||||||
|
|
||||||
|
#### Example 1: Generating Biomedical Text
|
||||||
|
|
||||||
|
```python
|
||||||
|
from biogpt import BioGPT
|
||||||
|
|
||||||
|
# Initialize the BioGPT model
|
||||||
|
biogpt = BioGPT()
|
||||||
|
|
||||||
|
# Generate biomedical text
|
||||||
|
input_text = "The patient has a fever"
|
||||||
|
generated_text = biogpt(input_text)
|
||||||
|
|
||||||
|
print(generated_text)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Example 2: Extracting Features
|
||||||
|
|
||||||
|
```python
|
||||||
|
from biogpt import BioGPT
|
||||||
|
|
||||||
|
# Initialize the BioGPT model
|
||||||
|
biogpt = BioGPT()
|
||||||
|
|
||||||
|
# Extract features from a biomedical text
|
||||||
|
input_text = "The patient has a fever"
|
||||||
|
features = biogpt.get_features(input_text)
|
||||||
|
|
||||||
|
print(features)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Example 3: Using Beam Search Decoding
|
||||||
|
|
||||||
|
```python
|
||||||
|
from biogpt import BioGPT
|
||||||
|
|
||||||
|
# Initialize the BioGPT model
|
||||||
|
biogpt = BioGPT()
|
||||||
|
|
||||||
|
# Generate biomedical text using beam search decoding
|
||||||
|
input_text = "The patient has a fever"
|
||||||
|
generated_text = biogpt.beam_search_decoding(input_text)
|
||||||
|
|
||||||
|
print(generated_text)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.3. Additional Features
|
||||||
|
|
||||||
|
The `BioGPT` class also provides additional features:
|
||||||
|
|
||||||
|
#### Set a New Pretrained Model
|
||||||
|
```python
|
||||||
|
biogpt.set_pretrained_model("new_pretrained_model")
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Get the Model's Configuration
|
||||||
|
```python
|
||||||
|
config = biogpt.get_config()
|
||||||
|
print(config)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Save and Load the Model
|
||||||
|
```python
|
||||||
|
# Save the model and tokenizer to a directory
|
||||||
|
biogpt.save_model("saved_model")
|
||||||
|
|
||||||
|
# Load a model and tokenizer from a directory
|
||||||
|
biogpt.load_from_path("saved_model")
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Print the Model's Architecture
|
||||||
|
```python
|
||||||
|
biogpt.print_model()
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. Additional Information <a name="additional-information"></a>
|
||||||
|
|
||||||
|
- **Biomedical Text Generation**: The `BioGPT` module is designed specifically for generating biomedical text, making it a valuable tool for various biomedical natural language processing tasks.
|
||||||
|
- **Feature Extraction**: It also provides the capability to extract features from biomedical text.
|
||||||
|
- **Beam Search Decoding**: Beam search decoding is available for generating text with improved quality.
|
||||||
|
- **Customization**: You can set a new pretrained model and save/load models for customization.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Conclusion <a name="conclusion"></a>
|
||||||
|
|
||||||
|
The `BioGPT` module is a powerful and specialized tool for generating and working with biomedical text. This documentation has provided a comprehensive guide on its usage, parameters, and examples, enabling you to effectively leverage it for various biomedical natural language processing tasks.
|
||||||
|
|
||||||
|
By using `BioGPT`, you can enhance your biomedical text generation and analysis tasks with contextually relevant and coherent text.
|
||||||
|
|
||||||
|
*Please check the official `BioGPT` repository and documentation for any updates beyond the knowledge cutoff date.*
|
@ -0,0 +1,208 @@
|
|||||||
|
"""
|
||||||
|
BioGPT
|
||||||
|
Pre-trained language models have attracted increasing attention in the biomedical domain,
|
||||||
|
inspired by their great success in the general natural language domain.
|
||||||
|
Among the two main branches of pre-trained language models in the general language domain, i.e. BERT (and its variants) and GPT (and its variants),
|
||||||
|
the first one has been extensively studied in the biomedical domain, such as BioBERT and PubMedBERT.
|
||||||
|
While they have achieved great success on a variety of discriminative downstream biomedical tasks,
|
||||||
|
the lack of generation ability constrains their application scope.
|
||||||
|
In this paper, we propose BioGPT, a domain-specific generative Transformer language model
|
||||||
|
pre-trained on large-scale biomedical literature.
|
||||||
|
We evaluate BioGPT on six biomedical natural language processing tasks
|
||||||
|
and demonstrate that our model outperforms previous models on most tasks.
|
||||||
|
Especially, we get 44.98%, 38.42% and 40.76% F1 score on BC5CDR, KD-DTI and DDI
|
||||||
|
end-to-end relation extraction tasks, respectively, and 78.2% accuracy on PubMedQA,
|
||||||
|
creating a new record. Our case study on text generation further demonstrates the
|
||||||
|
advantage of BioGPT on biomedical literature to generate fluent descriptions for biomedical terms.
|
||||||
|
|
||||||
|
|
||||||
|
@article{10.1093/bib/bbac409,
|
||||||
|
author = {Luo, Renqian and Sun, Liai and Xia, Yingce and Qin, Tao and Zhang, Sheng and Poon, Hoifung and Liu, Tie-Yan},
|
||||||
|
title = "{BioGPT: generative pre-trained transformer for biomedical text generation and mining}",
|
||||||
|
journal = {Briefings in Bioinformatics},
|
||||||
|
volume = {23},
|
||||||
|
number = {6},
|
||||||
|
year = {2022},
|
||||||
|
month = {09},
|
||||||
|
abstract = "{Pre-trained language models have attracted increasing attention in the biomedical domain, inspired by their great success in the general natural language domain. Among the two main branches of pre-trained language models in the general language domain, i.e. BERT (and its variants) and GPT (and its variants), the first one has been extensively studied in the biomedical domain, such as BioBERT and PubMedBERT. While they have achieved great success on a variety of discriminative downstream biomedical tasks, the lack of generation ability constrains their application scope. In this paper, we propose BioGPT, a domain-specific generative Transformer language model pre-trained on large-scale biomedical literature. We evaluate BioGPT on six biomedical natural language processing tasks and demonstrate that our model outperforms previous models on most tasks. Especially, we get 44.98\%, 38.42\% and 40.76\% F1 score on BC5CDR, KD-DTI and DDI end-to-end relation extraction tasks, respectively, and 78.2\% accuracy on PubMedQA, creating a new record. Our case study on text generation further demonstrates the advantage of BioGPT on biomedical literature to generate fluent descriptions for biomedical terms.}",
|
||||||
|
issn = {1477-4054},
|
||||||
|
doi = {10.1093/bib/bbac409},
|
||||||
|
url = {https://doi.org/10.1093/bib/bbac409},
|
||||||
|
note = {bbac409},
|
||||||
|
eprint = {https://academic.oup.com/bib/article-pdf/23/6/bbac409/47144271/bbac409.pdf},
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import pipeline, set_seed, BioGptTokenizer, BioGptForCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
class BioGPT:
|
||||||
|
"""
|
||||||
|
A wrapper class for the BioGptForCausalLM model from the transformers library.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
model_name (str): Name of the pretrained model.
|
||||||
|
model (BioGptForCausalLM): The pretrained BioGptForCausalLM model.
|
||||||
|
tokenizer (BioGptTokenizer): The tokenizer for the BioGptForCausalLM model.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
__call__: Generate text based on the given input.
|
||||||
|
get_features: Get the features of a given text.
|
||||||
|
beam_search_decoding: Generate text using beam search decoding.
|
||||||
|
set_pretrained_model: Set a new tokenizer and model.
|
||||||
|
get_config: Get the model's configuration.
|
||||||
|
save_model: Save the model and tokenizer to a directory.
|
||||||
|
load_from_path: Load a model and tokenizer from a directory.
|
||||||
|
print_model: Print the model's architecture.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
>>> from swarms.models.biogpt import BioGPTWrapper
|
||||||
|
>>> model = BioGPTWrapper()
|
||||||
|
>>> out = model("The patient has a fever")
|
||||||
|
>>> print(out)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str = "microsoft/biogpt",
|
||||||
|
max_length: int = 500,
|
||||||
|
num_return_sequences: int = 5,
|
||||||
|
do_sample: bool = True,
|
||||||
|
min_length: int = 100,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the wrapper class with a model name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name (str): Name of the pretrained model. Default is "microsoft/biogpt".
|
||||||
|
"""
|
||||||
|
self.model_name = model_name
|
||||||
|
self.max_length = max_length
|
||||||
|
self.num_return_sequences = num_return_sequences
|
||||||
|
self.do_sample = do_sample
|
||||||
|
self.min_length = min_length
|
||||||
|
|
||||||
|
self.model = BioGptForCausalLM.from_pretrained(self.model_name)
|
||||||
|
self.tokenizer = BioGptTokenizer.from_pretrained(self.model_name)
|
||||||
|
|
||||||
|
def __call__(self, text: str):
|
||||||
|
"""
|
||||||
|
Generate text based on the given input.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (str): The input text to generate from.
|
||||||
|
max_length (int): Maximum length of the generated text.
|
||||||
|
num_return_sequences (int): Number of sequences to return.
|
||||||
|
do_sample (bool): Whether or not to use sampling in generation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[dict]: A list of generated texts.
|
||||||
|
"""
|
||||||
|
set_seed(42)
|
||||||
|
generator = pipeline(
|
||||||
|
"text-generation", model=self.model, tokenizer=self.tokenizer
|
||||||
|
)
|
||||||
|
return generator(
|
||||||
|
text,
|
||||||
|
max_length=self.max_length,
|
||||||
|
num_return_sequences=self.num_return_sequences,
|
||||||
|
do_sample=self.do_sample,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_features(self, text):
|
||||||
|
"""
|
||||||
|
Get the features of a given text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (str): Input text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BaseModelOutputWithPastAndCrossAttentions: Model output.
|
||||||
|
"""
|
||||||
|
encoded_input = self.tokenizer(text, return_tensors="pt")
|
||||||
|
return self.model(**encoded_input)
|
||||||
|
|
||||||
|
def beam_search_decoding(
|
||||||
|
self,
|
||||||
|
sentence,
|
||||||
|
num_beams=5,
|
||||||
|
early_stopping=True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate text using beam search decoding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sentence (str): The input sentence to generate from.
|
||||||
|
min_length (int): Minimum length of the generated text.
|
||||||
|
max_length (int): Maximum length of the generated text.
|
||||||
|
num_beams (int): Number of beams for beam search.
|
||||||
|
early_stopping (bool): Whether to stop early during beam search.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The generated text.
|
||||||
|
"""
|
||||||
|
inputs = self.tokenizer(sentence, return_tensors="pt")
|
||||||
|
set_seed(42)
|
||||||
|
with torch.no_grad():
|
||||||
|
beam_output = self.model.generate(
|
||||||
|
**inputs,
|
||||||
|
min_length=self.min_length,
|
||||||
|
max_length=self.max_length,
|
||||||
|
num_beams=num_beams,
|
||||||
|
early_stopping=early_stopping
|
||||||
|
)
|
||||||
|
return self.tokenizer.decode(beam_output[0], skip_special_tokens=True)
|
||||||
|
|
||||||
|
# Feature 1: Set a new tokenizer and model
|
||||||
|
def set_pretrained_model(self, model_name):
|
||||||
|
"""
|
||||||
|
Set a new tokenizer and model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name (str): Name of the pretrained model.
|
||||||
|
"""
|
||||||
|
self.model_name = model_name
|
||||||
|
self.model = BioGptForCausalLM.from_pretrained(self.model_name)
|
||||||
|
self.tokenizer = BioGptTokenizer.from_pretrained(self.model_name)
|
||||||
|
|
||||||
|
# Feature 2: Get the model's config details
|
||||||
|
def get_config(self):
|
||||||
|
"""
|
||||||
|
Get the model's configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PretrainedConfig: The configuration of the model.
|
||||||
|
"""
|
||||||
|
return self.model.config
|
||||||
|
|
||||||
|
# Feature 3: Save the model and tokenizer to disk
|
||||||
|
def save_model(self, path):
|
||||||
|
"""
|
||||||
|
Save the model and tokenizer to a directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): Path to the directory.
|
||||||
|
"""
|
||||||
|
self.model.save_pretrained(path)
|
||||||
|
self.tokenizer.save_pretrained(path)
|
||||||
|
|
||||||
|
# Feature 4: Load a model from a custom path
|
||||||
|
def load_from_path(self, path):
|
||||||
|
"""
|
||||||
|
Load a model and tokenizer from a directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): Path to the directory.
|
||||||
|
"""
|
||||||
|
self.model = BioGptForCausalLM.from_pretrained(path)
|
||||||
|
self.tokenizer = BioGptTokenizer.from_pretrained(path)
|
||||||
|
|
||||||
|
# Feature 5: Print the model's architecture
|
||||||
|
def print_model(self):
|
||||||
|
"""
|
||||||
|
Print the model's architecture.
|
||||||
|
"""
|
||||||
|
print(self.model)
|
@ -0,0 +1,159 @@
|
|||||||
|
"""
|
||||||
|
SpeechT5 (TTS task)
|
||||||
|
SpeechT5 model fine-tuned for speech synthesis (text-to-speech) on LibriTTS.
|
||||||
|
|
||||||
|
This model was introduced in SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei.
|
||||||
|
|
||||||
|
SpeechT5 was first released in this repository, original weights. The license used is MIT.
|
||||||
|
|
||||||
|
Model Description
|
||||||
|
Motivated by the success of T5 (Text-To-Text Transfer Transformer) in pre-trained natural language processing models, we propose a unified-modal SpeechT5 framework that explores the encoder-decoder pre-training for self-supervised speech/text representation learning. The SpeechT5 framework consists of a shared encoder-decoder network and six modal-specific (speech/text) pre/post-nets. After preprocessing the input speech/text through the pre-nets, the shared encoder-decoder network models the sequence-to-sequence transformation, and then the post-nets generate the output in the speech/text modality based on the output of the decoder.
|
||||||
|
|
||||||
|
Leveraging large-scale unlabeled speech and text data, we pre-train SpeechT5 to learn a unified-modal representation, hoping to improve the modeling capability for both speech and text. To align the textual and speech information into this unified semantic space, we propose a cross-modal vector quantization approach that randomly mixes up speech/text states with latent units as the interface between encoder and decoder.
|
||||||
|
|
||||||
|
Extensive evaluations show the superiority of the proposed SpeechT5 framework on a wide variety of spoken language processing tasks, including automatic speech recognition, speech synthesis, speech translation, voice conversion, speech enhancement, and speaker identification.
|
||||||
|
|
||||||
|
Developed by: Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei.
|
||||||
|
Shared by [optional]: Matthijs Hollemans
|
||||||
|
Model type: text-to-speech
|
||||||
|
Language(s) (NLP): [More Information Needed]
|
||||||
|
License: MIT
|
||||||
|
Finetuned from model [optional]: [More Information Needed]
|
||||||
|
Model Sources [optional]
|
||||||
|
Repository: [https://github.com/microsoft/SpeechT5/]
|
||||||
|
Paper: [https://arxiv.org/pdf/2110.07205.pdf]
|
||||||
|
Blog Post: [https://huggingface.co/blog/speecht5]
|
||||||
|
Demo: [https://huggingface.co/spaces/Matthijs/speecht5-tts-demo]
|
||||||
|
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import soundfile as sf
|
||||||
|
from transformers import (
|
||||||
|
pipeline,
|
||||||
|
SpeechT5Processor,
|
||||||
|
SpeechT5ForTextToSpeech,
|
||||||
|
SpeechT5HifiGan,
|
||||||
|
)
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechT5:
|
||||||
|
"""
|
||||||
|
SpeechT5Wrapper
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name (str, optional): Model name or path. Defaults to "microsoft/speecht5_tts".
|
||||||
|
vocoder_name (str, optional): Vocoder name or path. Defaults to "microsoft/speecht5_hifigan".
|
||||||
|
dataset_name (str, optional): Dataset name or path. Defaults to "Matthijs/cmu-arctic-xvectors".
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
model_name (str): Model name or path.
|
||||||
|
vocoder_name (str): Vocoder name or path.
|
||||||
|
dataset_name (str): Dataset name or path.
|
||||||
|
processor (SpeechT5Processor): Processor for the SpeechT5 model.
|
||||||
|
model (SpeechT5ForTextToSpeech): SpeechT5 model.
|
||||||
|
vocoder (SpeechT5HifiGan): SpeechT5 vocoder.
|
||||||
|
embeddings_dataset (datasets.Dataset): Dataset containing speaker embeddings.
|
||||||
|
|
||||||
|
Methods
|
||||||
|
__call__: Synthesize speech from text.
|
||||||
|
save_speech: Save speech to a file.
|
||||||
|
set_model: Change the model.
|
||||||
|
set_vocoder: Change the vocoder.
|
||||||
|
set_embeddings_dataset: Change the embeddings dataset.
|
||||||
|
get_sampling_rate: Get the sampling rate of the model.
|
||||||
|
print_model_details: Print details of the model.
|
||||||
|
quick_synthesize: Customize pipeline method for quick synthesis.
|
||||||
|
change_dataset_split: Change dataset split (train, validation, test).
|
||||||
|
load_custom_embedding: Load a custom speaker embedding (xvector) for the text.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
>>> speechT5 = SpeechT5Wrapper()
|
||||||
|
>>> result = speechT5("Hello, how are you?")
|
||||||
|
>>> speechT5.save_speech(result)
|
||||||
|
>>> print("Speech saved successfully!")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name="microsoft/speecht5_tts",
|
||||||
|
vocoder_name="microsoft/speecht5_hifigan",
|
||||||
|
dataset_name="Matthijs/cmu-arctic-xvectors",
|
||||||
|
):
|
||||||
|
self.model_name = model_name
|
||||||
|
self.vocoder_name = vocoder_name
|
||||||
|
self.dataset_name = dataset_name
|
||||||
|
self.processor = SpeechT5Processor.from_pretrained(self.model_name)
|
||||||
|
self.model = SpeechT5ForTextToSpeech.from_pretrained(self.model_name)
|
||||||
|
self.vocoder = SpeechT5HifiGan.from_pretrained(self.vocoder_name)
|
||||||
|
self.embeddings_dataset = load_dataset(self.dataset_name, split="validation")
|
||||||
|
|
||||||
|
def __call__(self, text: str, speaker_id: float = 7306):
|
||||||
|
"""Call the model on some text and return the speech."""
|
||||||
|
speaker_embedding = torch.tensor(
|
||||||
|
self.embeddings_dataset[speaker_id]["xvector"]
|
||||||
|
).unsqueeze(0)
|
||||||
|
inputs = self.processor(text=text, return_tensors="pt")
|
||||||
|
speech = self.model.generate_speech(
|
||||||
|
inputs["input_ids"], speaker_embedding, vocoder=self.vocoder
|
||||||
|
)
|
||||||
|
return speech
|
||||||
|
|
||||||
|
def save_speech(self, speech, filename="speech.wav"):
|
||||||
|
"""Save Speech to a file."""
|
||||||
|
sf.write(filename, speech.numpy(), samplerate=16000)
|
||||||
|
|
||||||
|
def set_model(self, model_name: str):
|
||||||
|
"""Set the model to a new model."""
|
||||||
|
self.model_name = model_name
|
||||||
|
self.processor = SpeechT5Processor.from_pretrained(self.model_name)
|
||||||
|
self.model = SpeechT5ForTextToSpeech.from_pretrained(self.model_name)
|
||||||
|
|
||||||
|
def set_vocoder(self, vocoder_name):
|
||||||
|
"""Set the vocoder to a new vocoder."""
|
||||||
|
self.vocoder_name = vocoder_name
|
||||||
|
self.vocoder = SpeechT5HifiGan.from_pretrained(self.vocoder_name)
|
||||||
|
|
||||||
|
def set_embeddings_dataset(self, dataset_name):
|
||||||
|
"""Set the embeddings dataset to a new dataset."""
|
||||||
|
self.dataset_name = dataset_name
|
||||||
|
self.embeddings_dataset = load_dataset(self.dataset_name, split="validation")
|
||||||
|
|
||||||
|
# Feature 1: Get sampling rate
|
||||||
|
def get_sampling_rate(self):
|
||||||
|
"""Get sampling rate of the model."""
|
||||||
|
return 16000
|
||||||
|
|
||||||
|
# Feature 2: Print details of the model
|
||||||
|
def print_model_details(self):
|
||||||
|
"""Print details of the model."""
|
||||||
|
print(f"Model Name: {self.model_name}")
|
||||||
|
print(f"Vocoder Name: {self.vocoder_name}")
|
||||||
|
|
||||||
|
# Feature 3: Customize pipeline method for quick synthesis
|
||||||
|
def quick_synthesize(self, text):
|
||||||
|
"""Customize pipeline method for quick synthesis."""
|
||||||
|
synthesiser = pipeline("text-to-speech", self.model_name)
|
||||||
|
speech = synthesiser(text)
|
||||||
|
return speech
|
||||||
|
|
||||||
|
# Feature 4: Change dataset split (train, validation, test)
|
||||||
|
def change_dataset_split(self, split="train"):
|
||||||
|
"""Change dataset split (train, validation, test)."""
|
||||||
|
self.embeddings_dataset = load_dataset(self.dataset_name, split=split)
|
||||||
|
|
||||||
|
# Feature 5: Load a custom speaker embedding (xvector) for the text
|
||||||
|
def load_custom_embedding(self, xvector):
|
||||||
|
"""Load a custom speaker embedding (xvector) for the text."""
|
||||||
|
return torch.tensor(xvector).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# speechT5 = SpeechT5Wrapper()
|
||||||
|
# result = speechT5("Hello, how are you?")
|
||||||
|
# speechT5.save_speech(result)
|
||||||
|
# print("Speech saved successfully!")
|
Loading…
Reference in new issue