"""A wrapper around gemini, customized for our config"""

from logging import getLogger
from time import sleep
from typing import Any, Optional

from google.api_core.exceptions import ResourceExhausted
from google.generativeai import (
    configure as genai_config,
    GenerationConfig,
    GenerativeModel)
from google.generativeai.types.safety_types import HarmCategory

from typedefs import Llm, ChatResult

logger = getLogger()


class Gemini(Llm):
    """Represents an interaction with Gemini"""

    # Process-wide flag so genai_config() is called at most once, no matter
    # how many Gemini instances are created.
    __GENAI_INITIALIZED = False

    def __init__(self) -> None:
        """Set up the wrapper; the model itself is created on first use."""
        super().__init__()
        self.__model = None

    def __create_payload(self, message: str,
                         system_message: Optional[str] = None)\
            -> list[dict[str, Any]]:
        """Creates a list of messages we can send to an LLM.

        Args:
            message: The user message; must be non-empty.
            system_message: Optional system prompt. When given, it is placed
                before the user message.

        Returns:
            A list of ``{'role': ..., 'parts': [{'text': ...}]}`` dicts.
        """
        assert message, 'Message must be non-empty'
        messages = []
        if system_message:
            messages.append(
                # TODO: in other LLMs, the role would be 'system'.
                # What's the best fit in Gemini?
                # NOTE(review): GenerativeModel appears to accept a
                # `system_instruction` argument in recent SDK versions --
                # confirm and consider using it instead.
                {'role': 'user', 'parts': [{'text': system_message}]})
        messages.append({'role': 'user', 'parts': [{'text': message}]})
        logger.debug('Message object to be sent:\n%s', messages)
        return messages

    def __get_model(self) -> GenerativeModel:
        """Return this instance's model, creating and caching it on first use.

        Also calls genai_config() once per process before the first model
        is built.
        """
        if self.__model:
            return self.__model
        if not Gemini.__GENAI_INITIALIZED:
            logger.debug('Calling genai_config()')
            genai_config()
            Gemini.__GENAI_INITIALIZED = True
        gen_config = GenerationConfig(temperature=1.0)
        self.__model = GenerativeModel(
            'gemini-1.5-flash-latest',
            generation_config=gen_config,
        )
        return self.__model

    def __count_words(self, payload: list[dict[str, Any]]) -> int:
        """Count whitespace-separated words over every text part in payload.

        Entries without a 'parts' key and parts without a 'text' key are
        skipped.
        """
        return sum(
            len(part['text'].split())
            for entry in payload
            for part in entry.get('parts', ())
            if 'text' in part
        )

    def chat(
            self,
            message: str,
            system_message: Optional[str])\
            -> ChatResult:
        """Send a message to Gemini, backing off when the quota is exhausted.

        Args:
            message: The user message; must be non-empty.
            system_message: Optional system prompt sent ahead of the message.

        Returns:
            A ChatResult. ``success`` is False when the call raised
            ValueError, produced no response, or kept raising
            ResourceExhausted. ``words_sent`` counts the payload once per
            attempt, so retries are included.
        """
        payload = self.__create_payload(message, system_message)

        # see https://www.googlecloudcommunity.com/\
        # gc/AI-ML/Gemini-Pro-Quota-Exceeded/m-p/693185
        sleep_count = 0
        sleep_time = 2
        words_sent = 0
        words_received = 0
        payload_wordcount = self.__count_words(payload)
        while True:
            try:
                words_sent += payload_wordcount
                response = self.__get_model().generate_content(payload)
                # Read the text exactly once, inside the try: the `.text`
                # accessor is expected to raise ValueError when the response
                # carries no usable text (e.g. it was blocked) -- see the
                # google-generativeai docs.
                text = response.text if response else None
                if text:
                    words_received += len(text.split())
            except ValueError as ve:
                logger.warning(
                    "ValueError occurred, skipping topic: %s", ve)
                return ChatResult(
                    success=False,
                    text_response=None,
                    words_sent=words_sent,
                    words_received=words_received
                )
            except ResourceExhausted as re:
                logger.warning(
                    'ResourceExhausted exception occurred '
                    'while talking to gemini: %s', re)
                sleep_count += 1
                # Give up once more than five attempts have hit the quota.
                if sleep_count > 5:
                    logger.warning(
                        'ResourceExhausted exception occurred '
                        '5 times in a row. Exiting.')
                    return ChatResult(
                        success=False,
                        text_response=None,
                        words_sent=words_sent,
                        words_received=words_received
                    )
                logger.info(
                    'Too many requests, backing off for %s seconds',
                    sleep_time)
                sleep(sleep_time)
                sleep_time *= 2  # exponential backoff
            else:
                if text is None:
                    # No response object at all: report failure instead of
                    # crashing on an attribute access.
                    logger.warning('Gemini returned no response.')
                    return ChatResult(
                        success=False,
                        text_response=None,
                        words_sent=words_sent,
                        words_received=words_received
                    )
                return ChatResult(
                    success=True,
                    text_response=text.rstrip(),
                    words_sent=words_sent,
                    words_received=words_received
                )