You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
52 lines
1.8 KiB
52 lines
1.8 KiB
2 years ago
|
import unittest
|
||
|
import os
|
||
2 years ago
|
from unittest.mock import patch
|
||
|
from langchain import HuggingFaceHub, ChatOpenAI
|
||
2 years ago
|
|
||
1 year ago
|
from swarms.models.llm import LLM
|
||
2 years ago
|
|
||
1 year ago
|
|
||
2 years ago
|
class TestLLM(unittest.TestCase):
    """Tests for the ``LLM`` wrapper with OpenAI and HuggingFace Hub backends.

    The backend classes are patched so the tests do not construct real
    clients. Stacked ``patch`` decorators inject their mocks bottom-up: the
    decorator closest to the ``def`` supplies the first mock argument.
    """

    @patch.object(HuggingFaceHub, "__init__", return_value=None)
    @patch.object(ChatOpenAI, "__init__", return_value=None)
    def setUp(self, mock_openai_init, mock_hf_init):
        """Build one ``LLM`` per backend with both constructors stubbed out.

        The patches are only active while ``setUp`` itself runs; tests that
        create further ``LLM`` instances must apply their own patches.
        """
        self.llm_openai = LLM(openai_api_key="mock_openai_key")
        self.llm_hf = LLM(
            hf_repo_id="mock_repo_id", hf_api_token="mock_hf_token"
        )
        self.prompt = "Who won the FIFA World Cup in 1998?"

    def test_init(self):
        """Constructor arguments are stored on the instance unchanged."""
        self.assertEqual(self.llm_openai.openai_api_key, "mock_openai_key")
        self.assertEqual(self.llm_hf.hf_repo_id, "mock_repo_id")
        self.assertEqual(self.llm_hf.hf_api_token, "mock_hf_token")

    @patch.object(HuggingFaceHub, "run", return_value="France")
    @patch.object(ChatOpenAI, "run", return_value="France")
    def test_run(self, mock_openai_run, mock_hf_run):
        """``run`` delegates to the backend that matches the configuration."""
        # Bottom decorator (ChatOpenAI) -> first mock argument, so each
        # assertion below checks the mock of the backend just exercised.
        result_openai = self.llm_openai.run(self.prompt)
        mock_openai_run.assert_called_once()
        self.assertEqual(result_openai, "France")

        result_hf = self.llm_hf.run(self.prompt)
        mock_hf_run.assert_called_once()
        self.assertEqual(result_hf, "France")

    # Clear the environment so a key exported in the developer's shell cannot
    # mask the error (assumes LLM may fall back to env vars -- TODO confirm).
    @patch.dict(os.environ, {}, clear=True)
    def test_error_on_no_keys(self):
        """Constructing ``LLM`` with no credentials raises ``ValueError``."""
        with self.assertRaises(ValueError):
            LLM()

    # patch.dict(..., clear=True) empties the real os.environ mapping and
    # restores it afterwards, instead of swapping in a plain dict object.
    @patch.dict(os.environ, {}, clear=True)
    def test_error_on_missing_hf_token(self):
        """A HuggingFace repo id without any token raises ``ValueError``."""
        with self.assertRaises(ValueError):
            LLM(hf_repo_id="mock_repo_id")

    @patch.dict(os.environ, {"HUGGINGFACEHUB_API_TOKEN": "mock_hf_token"})
    @patch.object(HuggingFaceHub, "__init__", return_value=None)
    def test_hf_token_from_env(self, mock_hf_init):
        """The HuggingFace token falls back to ``HUGGINGFACEHUB_API_TOKEN``."""
        # setUp's constructor patch has already been undone here, so stub the
        # backend constructor again to keep this test free of real clients.
        llm = LLM(hf_repo_id="mock_repo_id")
        self.assertEqual(llm.hf_api_token, "mock_hf_token")
|
||
|
|
||
|
|
||
1 year ago
|
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|