feat: add new script and functionality in train script to save model in 16-bit format

main
thinhlpg 1 month ago
parent f6b6cca2ce
commit da79e986b6

@ -0,0 +1,73 @@
"""
Simple script to load unsloth checkpoint and save to FP16 format.
"""
import os
from unsloth import FastLanguageModel
def load_model(
    model_name: str,
    max_seq_length: int = 8192,
    load_in_4bit: bool = True,
    fast_inference: bool = True,
    max_lora_rank: int = 64,
    gpu_memory_utilization: float = 0.6,
):
    """Return a (model, tokenizer) pair loaded through unsloth.

    Args:
        model_name: HF model id or local checkpoint directory.
        max_seq_length: Maximum sequence length for the loaded model.
        load_in_4bit: Whether to load quantized 4-bit weights.
        fast_inference: Whether to enable unsloth's fast-inference path.
        max_lora_rank: Maximum LoRA rank supported by the loaded model.
        gpu_memory_utilization: Fraction of GPU memory to reserve.
    """
    # Gather the pass-through options once so the loader call stays readable.
    loader_kwargs = dict(
        model_name=model_name,
        max_seq_length=max_seq_length,
        load_in_4bit=load_in_4bit,
        fast_inference=fast_inference,
        max_lora_rank=max_lora_rank,
        gpu_memory_utilization=gpu_memory_utilization,
    )
    return FastLanguageModel.from_pretrained(**loader_kwargs)
def save_to_fp16(checkpoint_dir: str, output_dir: str | None = None):
    """
    Load an unsloth checkpoint and save it merged in FP16 format.

    Args:
        checkpoint_dir: Directory containing the checkpoint.
        output_dir: Directory to save the FP16 model. Defaults to a
            "model_merged_16bit" directory next to checkpoint_dir.
    """
    if output_dir is None:
        # normpath strips any trailing separator first; otherwise
        # dirname("run/checkpoint-101/") returns "run/checkpoint-101" and the
        # merged model would be written *inside* the checkpoint directory
        # instead of beside it.
        parent_dir = os.path.dirname(os.path.normpath(checkpoint_dir))
        output_dir = os.path.join(parent_dir, "model_merged_16bit")

    # Load model and tokenizer
    print(f"Loading model from {checkpoint_dir}")
    model, tokenizer = load_model(checkpoint_dir)

    # Merge adapters and save the FP16 weights
    print(f"Saving model to FP16 in {output_dir}")
    model.save_pretrained_merged(
        output_dir,
        tokenizer,
        save_method="merged_16bit",
    )
    print("Done!")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Save unsloth checkpoint to FP16")
parser.add_argument(
"checkpoint_dir",
nargs="?",
default="trainer_output_example/checkpoint-101",
help="Directory containing the checkpoint (default: trainer_output_example/checkpoint-101)",
)
parser.add_argument(
"--output_dir",
help="Directory to save the FP16 model (default: model_merged_16bit in parent of checkpoint_dir)",
)
args = parser.parse_args()
save_to_fp16(args.checkpoint_dir, args.output_dir)

@ -120,3 +120,13 @@ if __name__ == "__main__":
# NOTE(review): fragment of the training script's __main__ block; assumes
# `trainer`, `model`, `tokenizer`, `OUTPUT_DIR` and `logger` are defined
# earlier in that file — confirm against the full script.
trainer.train()
logger.info("Training completed")
logger.info(f"Model saved to {OUTPUT_DIR}")
# Save model to FP16 format
# Merge the trained weights and write a standalone FP16 copy into a
# "model_merged_16bit" subdirectory of the trainer output directory.
logger.info("Saving model to FP16 format")
model_merged_dir = os.path.join(OUTPUT_DIR, "model_merged_16bit")
model.save_pretrained_merged(
    model_merged_dir,
    tokenizer,
    save_method="merged_16bit",
)
logger.info(f"FP16 model saved to {model_merged_dir}")

Loading…
Cancel
Save