From da79e986b6b8461007b5f73ba8fb36b83586cbc8 Mon Sep 17 00:00:00 2001
From: thinhlpg <thinhlpg@gmail.com>
Date: Tue, 1 Apr 2025 04:51:24 +0700
Subject: [PATCH] feat: add script and post-training step to save merged
 model in 16-bit format
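
Usage (both values below are the script's own defaults, shown here for
illustration):

    python scripts/save_merged_16bit.py trainer_output_example/checkpoint-101 \
        --output_dir trainer_output_example/model_merged_16bit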

---
 scripts/save_merged_16bit.py | 78 ++++++++++++++++++++++++++++++++++++++
 train_grpo.py                | 12 ++++++
 2 files changed, 90 insertions(+)
 create mode 100644 scripts/save_merged_16bit.py

diff --git a/scripts/save_merged_16bit.py b/scripts/save_merged_16bit.py
new file mode 100644
index 0000000..0449809
--- /dev/null
+++ b/scripts/save_merged_16bit.py
@@ -0,0 +1,78 @@
+"""
+Simple script to load unsloth checkpoint and save to FP16 format.
+"""
+
+import os
+
+from unsloth import FastLanguageModel
+
+
+def load_model(
+    model_name: str,
+    max_seq_length: int = 8192,
+    load_in_4bit: bool = True,
+    fast_inference: bool = True,
+    max_lora_rank: int = 64,
+    gpu_memory_utilization: float = 0.6,
+):
+    """Load model and tokenizer with unsloth."""
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        model_name=model_name,
+        max_seq_length=max_seq_length,
+        load_in_4bit=load_in_4bit,
+        fast_inference=fast_inference,
+        max_lora_rank=max_lora_rank,
+        gpu_memory_utilization=gpu_memory_utilization,
+    )
+    return model, tokenizer
+
+
+def save_to_fp16(checkpoint_dir: str, output_dir: str | None = None):
+    """
+    Load unsloth checkpoint and save to FP16 format.
+
+    Args:
+        checkpoint_dir: Directory containing the checkpoint
+        output_dir: Directory to save the FP16 model (default: model_merged_16bit in parent of checkpoint_dir)
+    """
+    if output_dir is None:
+        # Get the parent directory of the checkpoint and create model_merged_16bit
+        # there (normpath strips a trailing slash so dirname finds the real parent)
+        parent_dir = os.path.dirname(os.path.normpath(checkpoint_dir))
+        output_dir = os.path.join(parent_dir, "model_merged_16bit")
+
+    # Load model and tokenizer
+    print(f"Loading model from {checkpoint_dir}")
+    model, tokenizer = load_model(checkpoint_dir)
+
+    # Save to FP16
+    print(f"Saving model to FP16 in {output_dir}")
+    model.save_pretrained_merged(
+        output_dir,
+        tokenizer,
+        save_method="merged_16bit",
+    )
+    print("Done!")
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Save unsloth checkpoint to FP16")
+    parser.add_argument(
+        "checkpoint_dir",
+        nargs="?",
+        default="trainer_output_example/checkpoint-101",
+        help="Directory containing the checkpoint (default: trainer_output_example/checkpoint-101)",
+    )
+    parser.add_argument(
+        "--output_dir",
+        help="Directory to save the FP16 model (default: model_merged_16bit in parent of checkpoint_dir)",
+    )
+    args = parser.parse_args()
+
+    save_to_fp16(args.checkpoint_dir, args.output_dir)
diff --git a/train_grpo.py b/train_grpo.py
index bd9ee63..e29fdae 100644
--- a/train_grpo.py
+++ b/train_grpo.py
@@ -120,3 +120,15 @@ if __name__ == "__main__":
     trainer.train()
     logger.info("Training completed")
     logger.info(f"Model saved to {OUTPUT_DIR}")
+
+    # Save model to FP16 format
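+    # The merged 16-bit output is a standard Hugging Face checkpoint, so it
+    # can be reloaded later without unsloth (e.g. with plain transformers)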
+    logger.info("Saving model to FP16 format")
+    model_merged_dir = os.path.join(OUTPUT_DIR, "model_merged_16bit")
+    model.save_pretrained_merged(
+        model_merged_dir,
+        tokenizer,
+        save_method="merged_16bit",
+    )
+    logger.info(f"FP16 model saved to {model_merged_dir}")