From da79e986b6b8461007b5f73ba8fb36b83586cbc8 Mon Sep 17 00:00:00 2001
From: thinhlpg <thinhlpg@gmail.com>
Date: Tue, 1 Apr 2025 04:51:24 +0700
Subject: [PATCH] feat: add new script and functionality in train script to
 save model in 16-bit format

---
scripts/save_merged_16bit.py | 73 ++++++++++++++++++++++++++++++++++++
train_grpo.py | 10 +++++
2 files changed, 83 insertions(+)
create mode 100644 scripts/save_merged_16bit.py
diff --git a/scripts/save_merged_16bit.py b/scripts/save_merged_16bit.py
new file mode 100644
index 0000000..0449809
--- /dev/null
+++ b/scripts/save_merged_16bit.py
@@ -0,0 +1,73 @@
+"""
+Simple script to load unsloth checkpoint and save to FP16 format.
+"""
+
+import os
+
+from unsloth import FastLanguageModel
+
+
+def load_model(
+ model_name: str,
+ max_seq_length: int = 8192,
+ load_in_4bit: bool = True,
+ fast_inference: bool = True,
+ max_lora_rank: int = 64,
+ gpu_memory_utilization: float = 0.6,
+):
+ """Load model and tokenizer with unsloth."""
+ model, tokenizer = FastLanguageModel.from_pretrained(
+ model_name=model_name,
+ max_seq_length=max_seq_length,
+ load_in_4bit=load_in_4bit,
+ fast_inference=fast_inference,
+ max_lora_rank=max_lora_rank,
+ gpu_memory_utilization=gpu_memory_utilization,
+ )
+ return model, tokenizer
+
+
+def save_to_fp16(checkpoint_dir: str, output_dir: str | None = None):
+ """
+ Load unsloth checkpoint and save to FP16 format.
+
+ Args:
+ checkpoint_dir: Directory containing the checkpoint
+ output_dir: Directory to save the FP16 model (default: model_merged_16bit in parent of checkpoint_dir)
+ """
+ if output_dir is None:
+ # Get parent directory of checkpoint and create model_merged_16bit there
+ parent_dir = os.path.dirname(checkpoint_dir)
+ output_dir = os.path.join(parent_dir, "model_merged_16bit")
+
+ # Load model and tokenizer
+ print(f"Loading model from {checkpoint_dir}")
+ model, tokenizer = load_model(checkpoint_dir)
+
+ # Save to FP16
+ print(f"Saving model to FP16 in {output_dir}")
+ model.save_pretrained_merged(
+ output_dir,
+ tokenizer,
+ save_method="merged_16bit",
+ )
+ print("Done!")
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(description="Save unsloth checkpoint to FP16")
+ parser.add_argument(
+ "checkpoint_dir",
+ nargs="?",
+ default="trainer_output_example/checkpoint-101",
+ help="Directory containing the checkpoint (default: trainer_output_example/checkpoint-101)",
+ )
+ parser.add_argument(
+ "--output_dir",
+ help="Directory to save the FP16 model (default: model_merged_16bit in parent of checkpoint_dir)",
+ )
+ args = parser.parse_args()
+
+ save_to_fp16(args.checkpoint_dir, args.output_dir)
diff --git a/train_grpo.py b/train_grpo.py
index bd9ee63..e29fdae 100644
--- a/train_grpo.py
+++ b/train_grpo.py
@@ -120,3 +120,13 @@ if __name__ == "__main__":
trainer.train()
logger.info("Training completed")
logger.info(f"Model saved to {OUTPUT_DIR}")
+
+ # Save model to FP16 format
+ logger.info("Saving model to FP16 format")
+ model_merged_dir = os.path.join(OUTPUT_DIR, "model_merged_16bit")
+ model.save_pretrained_merged(
+ model_merged_dir,
+ tokenizer,
+ save_method="merged_16bit",
+ )
+ logger.info(f"FP16 model saved to {model_merged_dir}")