base api
This commit is contained in:
parent
0b9d5e4cfa
commit
54679d5e6d
23
README.md
23
README.md
@ -1,3 +1,22 @@
|
|||||||
# gigchat
|
# GigaChat API Lua SDK
|
||||||
|
|
||||||
GigaChat API Lua SDK
|
Документация API: https://developers.sber.ru/docs/ru/gigachat/api/reference/rest/
|
||||||
|
|
||||||
|
## Использование
|
||||||
|
|
||||||
|
Необходимо переименовать gigachat.test_conf в gigachat.conf и заполнить в нем креды
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
local api = require('gigachat.api')
|
||||||
|
|
||||||
|
-- Получение списка моделей
|
||||||
|
|
||||||
|
api.models()
|
||||||
|
|
||||||
|
-- Запрос к чату (строка или массив сообщений)
|
||||||
|
|
||||||
|
api.chat.completions('тест')
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|||||||
52
config/config.lua
Normal file
52
config/config.lua
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
local json = require('cjson.safe')
|
||||||
|
local _M = {}
|
||||||
|
_M.data = {}
|
||||||
|
_M.comments = {}
|
||||||
|
_M.file = '' -- файл конфигурации
|
||||||
|
local key
|
||||||
|
|
||||||
|
function _M.read()
|
||||||
|
for line in io.lines(_M.file) do
|
||||||
|
key = string.match(line, '([%w_]+)::')
|
||||||
|
if (key) then
|
||||||
|
_M.data[key] = string.match(line, '::(.*) #')
|
||||||
|
_M.comments[key] = string.match(line, '#(.*)')
|
||||||
|
if string.find(_M.data[key],'%{%"') or string.find(_M.data[key],'%[%"') then
|
||||||
|
_M.data[key] = json.decode(_M.data[key])
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function reprint(k,v)
|
||||||
|
if type(v) == 'table' then
|
||||||
|
for i,j in pairs(v) do
|
||||||
|
reprint(i,j)
|
||||||
|
end
|
||||||
|
else
|
||||||
|
print(k..': '..v)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function _M.data:write()
|
||||||
|
local config_file = io.open(_M.file, 'w')
|
||||||
|
for k,v in pairs(self) do
|
||||||
|
if type(v) ~= 'function' then
|
||||||
|
if type(v) == 'table' then v = json.encode(v) end
|
||||||
|
config_file:write(k..'::'..v..' #'.._M.comments[key]..'\n')
|
||||||
|
end
|
||||||
|
end
|
||||||
|
config_file:close()
|
||||||
|
end
|
||||||
|
|
||||||
|
function _M.data:print()
|
||||||
|
for k,v in pairs(self) do
|
||||||
|
if type(v) ~= 'function' then
|
||||||
|
if _M.comments[k] then print('\n'.._M.comments[k]:gsub("^%s*(.-)%s*$", "%1")..': \n') end
|
||||||
|
reprint(k,v)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return _M
|
||||||
|
|
||||||
5
config/gigachat.lua
Normal file
5
config/gigachat.lua
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
local config = require('config.config')
|
||||||
|
local _M = config
|
||||||
|
_M.file = 'gigachat.conf' -- файл конфигурации
|
||||||
|
_M.read()
|
||||||
|
return _M.data
|
||||||
8
gigachat.test_conf
Normal file
8
gigachat.test_conf
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
auth_key:: #Ключ авторизации
|
||||||
|
client_id:: #ClientID
|
||||||
|
scope::GIGACHAT_API_PERS #Scope
|
||||||
|
token_url::https://ngw.devices.sberbank.ru:9443/api/v2/oauth #Token endpoint
|
||||||
|
base_url::https://gigachat.devices.sberbank.ru/api/v1/ #Базовый url для информационных ответов
|
||||||
|
logs_path::/var/www/gigachat/logs/ #Путь к логам
|
||||||
|
db_path::/var/www/gigachat #Путь к бд хранения токенов
|
||||||
|
|
||||||
172
gigachat/api.lua
Normal file
172
gigachat/api.lua
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
local json = require('cjson')
|
||||||
|
local cURL = require('cURL')
|
||||||
|
local log = require('utils.log')
|
||||||
|
local config = require('config.gigachat')
|
||||||
|
local uuid = require('utils.uuid')
|
||||||
|
local flatdb = require('utils.flatdb')
|
||||||
|
local db = flatdb(config.db_path)
|
||||||
|
|
||||||
|
if not db.token then
|
||||||
|
db.token = {}
|
||||||
|
db:save()
|
||||||
|
end
|
||||||
|
|
||||||
|
local _M = {}
|
||||||
|
|
||||||
|
_M.chat = {}
|
||||||
|
|
||||||
|
log.outfile = config.logs_path..'gigachat_'..os.date('%Y-%m-%d')..'.log'
|
||||||
|
log.level = 'trace'
|
||||||
|
|
||||||
|
local function poster(data)
|
||||||
|
local result = {}
|
||||||
|
for i,k in pairs(data) do table.insert(result, i..'='..k) end
|
||||||
|
return table.concat(result,'&')
|
||||||
|
end
|
||||||
|
|
||||||
|
local function get_result(str)
|
||||||
|
local result, err = pcall(json.decode,str)
|
||||||
|
if result then
|
||||||
|
result = json.decode(str)
|
||||||
|
else
|
||||||
|
log.error(err)
|
||||||
|
return nil, err
|
||||||
|
end
|
||||||
|
return result,err
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Получение актуального токена доступа
|
||||||
|
|
||||||
|
function _M.token()
|
||||||
|
local token = table.remove(db.token, 1)
|
||||||
|
if token and tonumber(token.expires_at/1000) > os.time() then return token.access_token
|
||||||
|
else
|
||||||
|
local str = ''
|
||||||
|
local headers = {
|
||||||
|
'Content-type: application/x-www-form-urlencoded',
|
||||||
|
'Accept: application/json',
|
||||||
|
'RqUID: '..uuid,
|
||||||
|
'Authorization: Basic '..config.auth_key
|
||||||
|
}
|
||||||
|
local c = cURL.easy{
|
||||||
|
url = config.token_url,
|
||||||
|
post = true,
|
||||||
|
postfields = poster({scope=config.scope}),
|
||||||
|
httpheader = headers,
|
||||||
|
writefunction = function(st)
|
||||||
|
str = str..st
|
||||||
|
collectgarbage("collect")
|
||||||
|
return #st
|
||||||
|
end
|
||||||
|
}
|
||||||
|
local ok, err = c:perform()
|
||||||
|
local code = c:getinfo_response_code()
|
||||||
|
c:close()
|
||||||
|
if not ok then return nil, err end
|
||||||
|
if code ~= 200 then
|
||||||
|
log.error(str)
|
||||||
|
return nil,str
|
||||||
|
end
|
||||||
|
local res,err = get_result(str)
|
||||||
|
if res then
|
||||||
|
db.token = {res}
|
||||||
|
db:save()
|
||||||
|
return res.access_token
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
local function get(endpoint)
|
||||||
|
local str = ''
|
||||||
|
local url = config.base_url..endpoint
|
||||||
|
local headers = {
|
||||||
|
'Content-type: application/x-www-form-urlencoded',
|
||||||
|
'Accept: application/json',
|
||||||
|
'Authorization: Bearer '.._M.token()
|
||||||
|
}
|
||||||
|
local c = cURL.easy{
|
||||||
|
url = url,
|
||||||
|
httpheader = headers,
|
||||||
|
writefunction = function(st)
|
||||||
|
str = str..st
|
||||||
|
collectgarbage("collect")
|
||||||
|
return #st
|
||||||
|
end
|
||||||
|
}
|
||||||
|
local ok, err = c:perform()
|
||||||
|
local code = c:getinfo_response_code()
|
||||||
|
c:close()
|
||||||
|
if not ok then return nil, err end
|
||||||
|
if code ~= 200 then
|
||||||
|
log.error(str)
|
||||||
|
return nil,str
|
||||||
|
end
|
||||||
|
local res,err = get_result(str)
|
||||||
|
if res then return res end
|
||||||
|
return res,err
|
||||||
|
end
|
||||||
|
|
||||||
|
local function post(endpoint,data)
|
||||||
|
local str = ''
|
||||||
|
local url = config.base_url..endpoint
|
||||||
|
local headers = {
|
||||||
|
'Content-type: application/x-www-form-urlencoded',
|
||||||
|
'Accept: application/json',
|
||||||
|
'Authorization: Bearer '.._M.token()
|
||||||
|
}
|
||||||
|
local c = cURL.easy{
|
||||||
|
url = url,
|
||||||
|
httpheader = headers,
|
||||||
|
post = true,
|
||||||
|
postfields = json.encode(data),
|
||||||
|
writefunction = function(st)
|
||||||
|
str = str..st
|
||||||
|
collectgarbage("collect")
|
||||||
|
return #st
|
||||||
|
end
|
||||||
|
}
|
||||||
|
local ok, err = c:perform()
|
||||||
|
local code = c:getinfo_response_code()
|
||||||
|
c:close()
|
||||||
|
if not ok then return nil, err end
|
||||||
|
if code ~= 200 then
|
||||||
|
log.error(str)
|
||||||
|
return nil,str
|
||||||
|
end
|
||||||
|
local res,err = get_result(str)
|
||||||
|
if res then return res end
|
||||||
|
return res,err
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Получение списка доступных моделей
|
||||||
|
|
||||||
|
function _M.models()
|
||||||
|
local models,err = get('models')
|
||||||
|
if models and models.data then return models.data end
|
||||||
|
return models,err
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Генерация ответа модели
|
||||||
|
|
||||||
|
function _M.chat.completions(messages,model,max_tokens,repetition_penalty,stream,update_interval,temperature,top_p,function_call,functions)
|
||||||
|
if type(messages) == 'string' then messages = {{role='user',content=messages}} end
|
||||||
|
if temperature and top_p then return nil, 'Одновременно указаны top_p и температура выборки' end
|
||||||
|
return post('chat/completions',{
|
||||||
|
model = model or 'GigaChat-2',
|
||||||
|
stream = stream or false,
|
||||||
|
max_tokens = max_tokens or 512,
|
||||||
|
repetition_penalty = repetition_penalty or 1,
|
||||||
|
update_interval = update_interval or 0,
|
||||||
|
messages = messages
|
||||||
|
})
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Создание эмбеддингов
|
||||||
|
|
||||||
|
function _M.embeddings(model)
|
||||||
|
return post('embeddings',{
|
||||||
|
model = model or 'embeddings'
|
||||||
|
})
|
||||||
|
end
|
||||||
|
|
||||||
|
return _M
|
||||||
86
utils/flatdb.lua
Normal file
86
utils/flatdb.lua
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
local mp = require("MessagePack")
|
||||||
|
|
||||||
|
local function isFile(path)
|
||||||
|
local f = io.open(path, "r")
|
||||||
|
if f then
|
||||||
|
f:close()
|
||||||
|
return true
|
||||||
|
end
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
|
||||||
|
local function isDir(path)
|
||||||
|
path = string.gsub(path.."/", "//", "/")
|
||||||
|
local ok, err, code = os.rename(path, path)
|
||||||
|
if ok or code == 13 then
|
||||||
|
return true
|
||||||
|
end
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
|
||||||
|
local function load_page(path)
|
||||||
|
local ret
|
||||||
|
local f = io.open(path, "rb")
|
||||||
|
if f then
|
||||||
|
ret = mp.unpack(f:read("*a"))
|
||||||
|
f:close()
|
||||||
|
end
|
||||||
|
return ret
|
||||||
|
end
|
||||||
|
|
||||||
|
local function store_page(path, page)
|
||||||
|
if type(page) == "table" then
|
||||||
|
local f = io.open(path, "wb")
|
||||||
|
if f then
|
||||||
|
f:write(mp.pack(page))
|
||||||
|
f:close()
|
||||||
|
return true
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
|
||||||
|
local pool = {}
|
||||||
|
|
||||||
|
local db_funcs = {
|
||||||
|
save = function(db, p)
|
||||||
|
if p then
|
||||||
|
if type(p) == "string" and type(db[p]) == "table" then
|
||||||
|
return store_page(pool[db].."/"..p, db[p])
|
||||||
|
else
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
end
|
||||||
|
for p, page in pairs(db) do
|
||||||
|
if not store_page(pool[db].."/"..p, page) then
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return true
|
||||||
|
end
|
||||||
|
}
|
||||||
|
|
||||||
|
local mt = {
|
||||||
|
__index = function(db, k)
|
||||||
|
if db_funcs[k] then return db_funcs[k] end
|
||||||
|
if isFile(pool[db].."/"..k) then
|
||||||
|
db[k] = load_page(pool[db].."/"..k)
|
||||||
|
end
|
||||||
|
return rawget(db, k)
|
||||||
|
end
|
||||||
|
}
|
||||||
|
|
||||||
|
pool.hack = db_funcs
|
||||||
|
|
||||||
|
return setmetatable(pool, {
|
||||||
|
__mode = "kv",
|
||||||
|
__call = function(pool, path)
|
||||||
|
assert(isDir(path), path.." is not a directory.")
|
||||||
|
if pool[path] then return pool[path] end
|
||||||
|
local db = {}
|
||||||
|
setmetatable(db, mt)
|
||||||
|
pool[path] = db
|
||||||
|
pool[db] = path
|
||||||
|
return db
|
||||||
|
end
|
||||||
|
})
|
||||||
90
utils/log.lua
Normal file
90
utils/log.lua
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
--
|
||||||
|
-- log.lua
|
||||||
|
--
|
||||||
|
-- Copyright (c) 2016 rxi
|
||||||
|
--
|
||||||
|
-- This library is free software; you can redistribute it and/or modify it
|
||||||
|
-- under the terms of the MIT license. See LICENSE for details.
|
||||||
|
--
|
||||||
|
|
||||||
|
local log = { _version = "0.1.0" }
|
||||||
|
|
||||||
|
log.usecolor = true
|
||||||
|
log.outfile = nil
|
||||||
|
log.level = "trace"
|
||||||
|
|
||||||
|
|
||||||
|
local modes = {
|
||||||
|
{ name = "trace", color = "\27[34m", },
|
||||||
|
{ name = "debug", color = "\27[36m", },
|
||||||
|
{ name = "info", color = "\27[32m", },
|
||||||
|
{ name = "warn", color = "\27[33m", },
|
||||||
|
{ name = "error", color = "\27[31m", },
|
||||||
|
{ name = "fatal", color = "\27[35m", },
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
local levels = {}
|
||||||
|
for i, v in ipairs(modes) do
|
||||||
|
levels[v.name] = i
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
local round = function(x, increment)
|
||||||
|
increment = increment or 1
|
||||||
|
x = x / increment
|
||||||
|
return (x > 0 and math.floor(x + .5) or math.ceil(x - .5)) * increment
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
local _tostring = tostring
|
||||||
|
|
||||||
|
local tostring = function(...)
|
||||||
|
local t = {}
|
||||||
|
for i = 1, select('#', ...) do
|
||||||
|
local x = select(i, ...)
|
||||||
|
if type(x) == "number" then
|
||||||
|
x = round(x, .01)
|
||||||
|
end
|
||||||
|
t[#t + 1] = _tostring(x)
|
||||||
|
end
|
||||||
|
return table.concat(t, " ")
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
for i, x in ipairs(modes) do
|
||||||
|
local nameupper = x.name:upper()
|
||||||
|
log[x.name] = function(...)
|
||||||
|
|
||||||
|
-- Return early if we're below the log level
|
||||||
|
if i < levels[log.level] then
|
||||||
|
return
|
||||||
|
end
|
||||||
|
|
||||||
|
local msg = tostring(...)
|
||||||
|
local info = debug.getinfo(2, "Sl")
|
||||||
|
local lineinfo = info.short_src .. ":" .. info.currentline
|
||||||
|
|
||||||
|
-- Output to console
|
||||||
|
print(string.format("%s[%-6s%s]%s %s: %s",
|
||||||
|
log.usecolor and x.color or "",
|
||||||
|
nameupper,
|
||||||
|
os.date("%H:%M:%S"),
|
||||||
|
log.usecolor and "\27[0m" or "",
|
||||||
|
lineinfo,
|
||||||
|
msg))
|
||||||
|
|
||||||
|
-- Output to log file
|
||||||
|
if log.outfile then
|
||||||
|
local fp = io.open(log.outfile, "a")
|
||||||
|
local str = string.format("[%-6s%s] %s: %s\n",
|
||||||
|
nameupper, os.date(), lineinfo, msg)
|
||||||
|
fp:write(str)
|
||||||
|
fp:close()
|
||||||
|
end
|
||||||
|
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
return log
|
||||||
10
utils/uuid.lua
Normal file
10
utils/uuid.lua
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
local random = math.random
|
||||||
|
local function uuid()
|
||||||
|
local template ='xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'
|
||||||
|
return string.gsub(template, '[xy]', function (c)
|
||||||
|
local v = (c == 'x') and random(0, 0xf) or random(8, 0xb)
|
||||||
|
return string.format('%x', v)
|
||||||
|
end)
|
||||||
|
end
|
||||||
|
|
||||||
|
return uuid()
|
||||||
Loading…
x
Reference in New Issue
Block a user