118 lines
4.2 KiB
Python
118 lines
4.2 KiB
Python
"""A wrapper around gemini, customized for our config"""
|
|
|
|
from google.api_core.exceptions import ResourceExhausted
|
|
from google.generativeai import (
|
|
configure as genai_config, GenerationConfig, GenerativeModel)
|
|
from google.generativeai.types.safety_types import HarmCategory
|
|
from logging import getLogger
|
|
from time import sleep
|
|
from typedefs import Llm, ChatResult
|
|
from typing import Optional
|
|
|
|
|
|
logger = getLogger()
|
|
|
|
|
|
class Gemini(Llm):
|
|
"""Represents an interaction with Gemini"""
|
|
|
|
__GENAI_INITIALIZED = False
|
|
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.__model = None
|
|
|
|
def __create_payload(self,
|
|
message: str,
|
|
system_message: Optional[str] = None)\
|
|
-> list[dict[str, str]]:
|
|
"""Creates a list of messages we can send to an LLM."""
|
|
assert message, 'Message must be non-empty'
|
|
messages = []
|
|
if system_message:
|
|
messages.append(
|
|
# TODO: in other LLMs, the role would be 'system'.
|
|
# What's the best fit in Gemini?
|
|
{'role': 'user', 'parts': [{'text': message}]})
|
|
messages.append({'role': 'user', 'parts': [{'text': message}]})
|
|
logger.debug('Message object to be sent:\n%s' % messages)
|
|
return messages
|
|
|
|
def __get_model(self) -> GenerativeModel:
|
|
if self.__model:
|
|
return self.__model
|
|
if not Gemini.__GENAI_INITIALIZED:
|
|
logger.debug('Calling genai_config()')
|
|
genai_config()
|
|
Gemini.__GENAI_INITIALIZED = True
|
|
gen_config = GenerationConfig(temperature=1.0)
|
|
self.__model = GenerativeModel('gemini-1.5-flash-latest',
|
|
generation_config=gen_config,
|
|
)
|
|
return self.__model
|
|
|
|
def __count_words(self, payload: list[dict[str, any]]) -> int:
|
|
count = 0
|
|
for i in payload:
|
|
if 'parts' not in i:
|
|
continue
|
|
for p in i['parts']:
|
|
if 'text' not in p:
|
|
continue
|
|
count += len(p['text'].split())
|
|
return count
|
|
|
|
def chat(
|
|
self,
|
|
message: str,
|
|
system_message: Optional[str])\
|
|
-> ChatResult:
|
|
payload = self.__create_payload(message, system_message)
|
|
# see https://www.googlecloudcommunity.com/\
|
|
# gc/AI-ML/Gemini-Pro-Quota-Exceeded/m-p/693185
|
|
sleep_count = 0
|
|
sleep_time = 2
|
|
words_sent = 0
|
|
words_received = 0
|
|
payload_wordcount = self.__count_words(payload)
|
|
while True:
|
|
try:
|
|
words_sent += payload_wordcount
|
|
response = self.__get_model().generate_content(payload)
|
|
if response and response.text:
|
|
words_received += len(response.text.split())
|
|
except ValueError as ve:
|
|
logger.warn("ValueError occurred, skipping topic: %s" % ve)
|
|
return ChatResult(
|
|
success=False,
|
|
text_response=None,
|
|
words_sent=words_sent,
|
|
words_received=words_received
|
|
)
|
|
except ResourceExhausted as re:
|
|
logger.warn(
|
|
'ResourceExhausted exception occurred '
|
|
'while talking to gemini: %s' % re)
|
|
sleep_count += 1
|
|
if sleep_count > 5:
|
|
logger.warn(
|
|
'ResourceExhausted exception occurred '
|
|
'5 times in a row. Exiting.')
|
|
return ChatResult(
|
|
success=False,
|
|
text_response=None,
|
|
words_sent=words_sent,
|
|
words_received=words_received
|
|
)
|
|
logger.info(
|
|
'Too many requests, backing off for %s seconds'
|
|
% sleep_time)
|
|
sleep(sleep_time)
|
|
sleep_time *= 2
|
|
else:
|
|
return ChatResult(
|
|
success=True,
|
|
text_response=response.text.rstrip(),
|
|
words_sent=words_sent,
|
|
words_received=words_received
|
|
)
|