Source code for concepts.language.gpt_parsing.caption_sng
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : caption_sng.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 04/23/2023
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
from openai import OpenAI
client = OpenAI()
from dataclasses import dataclass
from typing import Any, Tuple, Dict
from concepts.language.gpt_parsing.utils import TagNotUniqueError, load_prompt, extract_tag
[docs]
@dataclass
class CaptionSceneGraph(object):
    """A scene graph extracted from an image caption: a set of objects and the relations between them."""

    # All object names mentioned in the caption. The collection is variable-length, hence `Tuple[str, ...]`
    # (a bare `Tuple[str]` would mean "a tuple of exactly one string").
    objects: Tuple[str, ...]
    # Relation triples; presumably (subject, predicate, object) — TODO confirm against the producer.
    relations: Tuple[Tuple[str, str, str], ...]
[docs]
class CaptionSNGParser(object):
    """Parse image captions by querying an OpenAI chat model with a fixed system prompt.

    The system prompt comes from ``load_prompt('gpt-35-turbo-chat-caption')``. Each caption is sent as a single
    user message, and the reply is post-processed by :meth:`extract`.
    """

    def __init__(self, max_tokens: int = 1024):
        """Load the prompt and store generation settings.

        Args:
            max_tokens: the maximum number of tokens the model may generate per request.
        """
        # NOTE(review): load_prompt presumably returns a list of chat messages; only element 0
        # (the system prompt) is used by `parse` — confirm against the prompt definition.
        self.prompt = load_prompt('gpt-35-turbo-chat-caption')
        self.max_tokens = max_tokens

    def parse(self, sentence: str) -> Dict[str, Any]:
        """Send one caption to the model and post-process the reply.

        Args:
            sentence: the caption to parse.

        Returns:
            A dict with keys:

            - ``'sentence'``: the input caption.
            - ``'raw_response'``: the completion object returned by the OpenAI client.
            - ``'parsing'``: the value returned by :meth:`extract`, or ``None`` if extraction raised
              :class:`TagNotUniqueError`.
            - ``'exception'``: the :class:`TagNotUniqueError` raised by :meth:`extract`, or ``None``.
        """
        response = client.chat.completions.create(
            model='gpt-3.5-turbo',
            messages=[
                self.prompt[0],  # the system prompt
                {'role': 'user', 'content': sentence},
            ],
            max_tokens=self.max_tokens,
        )

        parsing = None
        exception = None
        try:
            parsing = self.extract(response.choices[0].message.content)
        except TagNotUniqueError as e:
            # Ambiguous model output is reported back to the caller rather than raised.
            exception = e

        return {
            'sentence': sentence,
            'raw_response': response,
            'parsing': parsing,
            'exception': exception,
        }

    def extract(self, content: Any) -> Any:
        """Convert the model's reply text into the ``'parsing'`` field returned by :meth:`parse`.

        `parse` previously called this method without it being defined, so every call failed with
        AttributeError. The default implementation returns ``content`` unchanged, which may be ``None`` if
        the API returned no text. Subclasses may override it to produce structured output. An override may
        raise :class:`TagNotUniqueError`, which :meth:`parse` catches and reports.

        NOTE(review): this module imports ``extract_tag`` and defines ``CaptionSceneGraph`` without using
        them. Presumably this hook was meant to build a ``CaptionSceneGraph`` from tagged model output.
        The tag format is defined by the prompt file, which is not visible here — TODO implement once
        confirmed.

        Args:
            content: the text content of the first choice in the chat completion.

        Returns:
            The parsed representation (by default, ``content`` itself).
        """
        return content

    def parse_batch(self, sentences):
        """Parse several captions at once. Not implemented yet.

        Raises:
            NotImplementedError: always.
        """
        # TODO(Jiayuan Mao @ 2023/04/23): support batchified parsing.
        raise NotImplementedError()
# Shared parser instance used by `parse_caption`. It is constructed at import time, so importing this
# module already calls `load_prompt` (via CaptionSNGParser.__init__).
default_caption_sng_parser = CaptionSNGParser()
[docs]
def parse_caption(sentence: str) -> Dict[str, Any]:
    """Parse a caption using the module-level default :class:`CaptionSNGParser`.

    Args:
        sentence: the caption to parse.

    Returns:
        The result dict produced by :meth:`CaptionSNGParser.parse`.
    """
    parser = default_caption_sng_parser
    return parser.parse(sentence)
if __name__ == '__main__':
    # Manual smoke check: parse one example caption and print the full result dict.
    example_caption = 'A little girl and a woman are having their picture taken in front of a desert.'
    result = parse_caption(example_caption)
    print(result)