From 98fb684d18c17f5849b512ad525a4307c565d534 Mon Sep 17 00:00:00 2001 From: Kye Date: Sat, 8 Jul 2023 15:57:43 -0400 Subject: [PATCH] unit test for llmn class --- tests/LLM.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tests/LLM.py diff --git a/tests/LLM.py b/tests/LLM.py new file mode 100644 index 00000000..578723ce --- /dev/null +++ b/tests/LLM.py @@ -0,0 +1,48 @@ +import unittest +import os +from unittest.mock import patch, MagicMock +from langchain import PromptTemplate, HuggingFaceHub, ChatOpenAI, LLMChain + +from your_module import LLM # import the module where you defined the LLM class + +class TestLLM(unittest.TestCase): + @patch.object(HuggingFaceHub, '__init__', return_value=None) + @patch.object(ChatOpenAI, '__init__', return_value=None) + def setUp(self, mock_hf_init, mock_openai_init): + self.llm_openai = LLM(openai_api_key='mock_openai_key') + self.llm_hf = LLM(hf_repo_id='mock_repo_id', hf_api_token='mock_hf_token') + self.prompt = "Who won the FIFA World Cup in 1998?" + + def test_init(self): + self.assertEqual(self.llm_openai.openai_api_key, 'mock_openai_key') + self.assertEqual(self.llm_hf.hf_repo_id, 'mock_repo_id') + self.assertEqual(self.llm_hf.hf_api_token, 'mock_hf_token') + + @patch.object(HuggingFaceHub, 'run', return_value="France") + @patch.object(ChatOpenAI, 'run', return_value="France") + def test_run(self, mock_hf_run, mock_openai_run): + result_openai = self.llm_openai.run(self.prompt) + mock_openai_run.assert_called_once() + self.assertEqual(result_openai, "France") + + result_hf = self.llm_hf.run(self.prompt) + mock_hf_run.assert_called_once() + self.assertEqual(result_hf, "France") + + def test_error_on_no_keys(self): + with self.assertRaises(ValueError): + LLM() + + @patch.object(os, 'environ', {}) + def test_error_on_missing_hf_token(self): + with self.assertRaises(ValueError): + LLM(hf_repo_id='mock_repo_id') + + @patch.dict(os.environ, {"HUGGINGFACEHUB_API_TOKEN": "mock_hf_token"}) + def test_hf_token_from_env(self): + llm = LLM(hf_repo_id='mock_repo_id') + self.assertEqual(llm.hf_api_token, "mock_hf_token") + + +if __name__ == '__main__': + unittest.main()