base api
This commit is contained in:
parent
0b9d5e4cfa
commit
54679d5e6d
23
README.md
23
README.md
@ -1,3 +1,22 @@
|
||||
# gigchat
|
||||
# GigaChat API Lua SDK
|
||||
|
||||
GigaChat API Lua SDK
|
||||
Документация API: https://developers.sber.ru/docs/ru/gigachat/api/reference/rest/
|
||||
|
||||
## Использование
|
||||
|
||||
Необходимо переименовать gigachat.test_conf в gigachat.conf и заполнить в нем креды
|
||||
|
||||
```
|
||||
|
||||
local api = require('gigachat.api')
|
||||
|
||||
-- Получение списка моделей
|
||||
|
||||
api.models()
|
||||
|
||||
-- Запрос к чату (строка или массив сообщений)
|
||||
|
||||
api.chat.completions('тест')
|
||||
|
||||
|
||||
```
|
||||
|
||||
52
config/config.lua
Normal file
52
config/config.lua
Normal file
@ -0,0 +1,52 @@
|
||||
local json = require('cjson.safe')
|
||||
local _M = {}
|
||||
_M.data = {}
|
||||
_M.comments = {}
|
||||
_M.file = '' -- файл конфигурации
|
||||
local key
|
||||
|
||||
function _M.read()
|
||||
for line in io.lines(_M.file) do
|
||||
key = string.match(line, '([%w_]+)::')
|
||||
if (key) then
|
||||
_M.data[key] = string.match(line, '::(.*) #')
|
||||
_M.comments[key] = string.match(line, '#(.*)')
|
||||
if string.find(_M.data[key],'%{%"') or string.find(_M.data[key],'%[%"') then
|
||||
_M.data[key] = json.decode(_M.data[key])
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function reprint(k,v)
|
||||
if type(v) == 'table' then
|
||||
for i,j in pairs(v) do
|
||||
reprint(i,j)
|
||||
end
|
||||
else
|
||||
print(k..': '..v)
|
||||
end
|
||||
end
|
||||
|
||||
function _M.data:write()
|
||||
local config_file = io.open(_M.file, 'w')
|
||||
for k,v in pairs(self) do
|
||||
if type(v) ~= 'function' then
|
||||
if type(v) == 'table' then v = json.encode(v) end
|
||||
config_file:write(k..'::'..v..' #'.._M.comments[key]..'\n')
|
||||
end
|
||||
end
|
||||
config_file:close()
|
||||
end
|
||||
|
||||
function _M.data:print()
|
||||
for k,v in pairs(self) do
|
||||
if type(v) ~= 'function' then
|
||||
if _M.comments[k] then print('\n'.._M.comments[k]:gsub("^%s*(.-)%s*$", "%1")..': \n') end
|
||||
reprint(k,v)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return _M
|
||||
|
||||
5
config/gigachat.lua
Normal file
5
config/gigachat.lua
Normal file
@ -0,0 +1,5 @@
|
||||
local config = require('config.config')
|
||||
local _M = config
|
||||
_M.file = 'gigachat.conf' -- файл конфигурации
|
||||
_M.read()
|
||||
return _M.data
|
||||
8
gigachat.test_conf
Normal file
8
gigachat.test_conf
Normal file
@ -0,0 +1,8 @@
|
||||
auth_key:: #Ключ авторизации
|
||||
client_id:: #ClientID
|
||||
scope::GIGACHAT_API_PERS #Scope
|
||||
token_url::https://ngw.devices.sberbank.ru:9443/api/v2/oauth #Token endpoint
|
||||
base_url::https://gigachat.devices.sberbank.ru/api/v1/ #Базовый url для информационных ответов
|
||||
logs_path::/var/www/gigachat/logs/ #Путь к логам
|
||||
db_path::/var/www/gigachat #Путь к бд хранения токенов
|
||||
|
||||
172
gigachat/api.lua
Normal file
172
gigachat/api.lua
Normal file
@ -0,0 +1,172 @@
|
||||
local json = require('cjson')
|
||||
local cURL = require('cURL')
|
||||
local log = require('utils.log')
|
||||
local config = require('config.gigachat')
|
||||
local uuid = require('utils.uuid')
|
||||
local flatdb = require('utils.flatdb')
|
||||
local db = flatdb(config.db_path)
|
||||
|
||||
if not db.token then
|
||||
db.token = {}
|
||||
db:save()
|
||||
end
|
||||
|
||||
local _M = {}
|
||||
|
||||
_M.chat = {}
|
||||
|
||||
log.outfile = config.logs_path..'gigachat_'..os.date('%Y-%m-%d')..'.log'
|
||||
log.level = 'trace'
|
||||
|
||||
local function poster(data)
|
||||
local result = {}
|
||||
for i,k in pairs(data) do table.insert(result, i..'='..k) end
|
||||
return table.concat(result,'&')
|
||||
end
|
||||
|
||||
local function get_result(str)
|
||||
local result, err = pcall(json.decode,str)
|
||||
if result then
|
||||
result = json.decode(str)
|
||||
else
|
||||
log.error(err)
|
||||
return nil, err
|
||||
end
|
||||
return result,err
|
||||
end
|
||||
|
||||
-- Получение актуального токена доступа
|
||||
|
||||
function _M.token()
|
||||
local token = table.remove(db.token, 1)
|
||||
if token and tonumber(token.expires_at/1000) > os.time() then return token.access_token
|
||||
else
|
||||
local str = ''
|
||||
local headers = {
|
||||
'Content-type: application/x-www-form-urlencoded',
|
||||
'Accept: application/json',
|
||||
'RqUID: '..uuid,
|
||||
'Authorization: Basic '..config.auth_key
|
||||
}
|
||||
local c = cURL.easy{
|
||||
url = config.token_url,
|
||||
post = true,
|
||||
postfields = poster({scope=config.scope}),
|
||||
httpheader = headers,
|
||||
writefunction = function(st)
|
||||
str = str..st
|
||||
collectgarbage("collect")
|
||||
return #st
|
||||
end
|
||||
}
|
||||
local ok, err = c:perform()
|
||||
local code = c:getinfo_response_code()
|
||||
c:close()
|
||||
if not ok then return nil, err end
|
||||
if code ~= 200 then
|
||||
log.error(str)
|
||||
return nil,str
|
||||
end
|
||||
local res,err = get_result(str)
|
||||
if res then
|
||||
db.token = {res}
|
||||
db:save()
|
||||
return res.access_token
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function get(endpoint)
|
||||
local str = ''
|
||||
local url = config.base_url..endpoint
|
||||
local headers = {
|
||||
'Content-type: application/x-www-form-urlencoded',
|
||||
'Accept: application/json',
|
||||
'Authorization: Bearer '.._M.token()
|
||||
}
|
||||
local c = cURL.easy{
|
||||
url = url,
|
||||
httpheader = headers,
|
||||
writefunction = function(st)
|
||||
str = str..st
|
||||
collectgarbage("collect")
|
||||
return #st
|
||||
end
|
||||
}
|
||||
local ok, err = c:perform()
|
||||
local code = c:getinfo_response_code()
|
||||
c:close()
|
||||
if not ok then return nil, err end
|
||||
if code ~= 200 then
|
||||
log.error(str)
|
||||
return nil,str
|
||||
end
|
||||
local res,err = get_result(str)
|
||||
if res then return res end
|
||||
return res,err
|
||||
end
|
||||
|
||||
local function post(endpoint,data)
|
||||
local str = ''
|
||||
local url = config.base_url..endpoint
|
||||
local headers = {
|
||||
'Content-type: application/x-www-form-urlencoded',
|
||||
'Accept: application/json',
|
||||
'Authorization: Bearer '.._M.token()
|
||||
}
|
||||
local c = cURL.easy{
|
||||
url = url,
|
||||
httpheader = headers,
|
||||
post = true,
|
||||
postfields = json.encode(data),
|
||||
writefunction = function(st)
|
||||
str = str..st
|
||||
collectgarbage("collect")
|
||||
return #st
|
||||
end
|
||||
}
|
||||
local ok, err = c:perform()
|
||||
local code = c:getinfo_response_code()
|
||||
c:close()
|
||||
if not ok then return nil, err end
|
||||
if code ~= 200 then
|
||||
log.error(str)
|
||||
return nil,str
|
||||
end
|
||||
local res,err = get_result(str)
|
||||
if res then return res end
|
||||
return res,err
|
||||
end
|
||||
|
||||
-- Получение списка доступных моделей
|
||||
|
||||
function _M.models()
|
||||
local models,err = get('models')
|
||||
if models and models.data then return models.data end
|
||||
return models,err
|
||||
end
|
||||
|
||||
-- Генерация ответа модели
|
||||
|
||||
function _M.chat.completions(messages,model,max_tokens,repetition_penalty,stream,update_interval,temperature,top_p,function_call,functions)
|
||||
if type(messages) == 'string' then messages = {{role='user',content=messages}} end
|
||||
if temperature and top_p then return nil, 'Одновременно указаны top_p и температура выборки' end
|
||||
return post('chat/completions',{
|
||||
model = model or 'GigaChat-2',
|
||||
stream = stream or false,
|
||||
max_tokens = max_tokens or 512,
|
||||
repetition_penalty = repetition_penalty or 1,
|
||||
update_interval = update_interval or 0,
|
||||
messages = messages
|
||||
})
|
||||
end
|
||||
|
||||
-- Создание эмбеддингов
|
||||
|
||||
function _M.embeddings(model)
|
||||
return post('embeddings',{
|
||||
model = model or 'embeddings'
|
||||
})
|
||||
end
|
||||
|
||||
return _M
|
||||
86
utils/flatdb.lua
Normal file
86
utils/flatdb.lua
Normal file
@ -0,0 +1,86 @@
|
||||
local mp = require("MessagePack")
|
||||
|
||||
local function isFile(path)
|
||||
local f = io.open(path, "r")
|
||||
if f then
|
||||
f:close()
|
||||
return true
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
local function isDir(path)
|
||||
path = string.gsub(path.."/", "//", "/")
|
||||
local ok, err, code = os.rename(path, path)
|
||||
if ok or code == 13 then
|
||||
return true
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
local function load_page(path)
|
||||
local ret
|
||||
local f = io.open(path, "rb")
|
||||
if f then
|
||||
ret = mp.unpack(f:read("*a"))
|
||||
f:close()
|
||||
end
|
||||
return ret
|
||||
end
|
||||
|
||||
local function store_page(path, page)
|
||||
if type(page) == "table" then
|
||||
local f = io.open(path, "wb")
|
||||
if f then
|
||||
f:write(mp.pack(page))
|
||||
f:close()
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
local pool = {}
|
||||
|
||||
local db_funcs = {
|
||||
save = function(db, p)
|
||||
if p then
|
||||
if type(p) == "string" and type(db[p]) == "table" then
|
||||
return store_page(pool[db].."/"..p, db[p])
|
||||
else
|
||||
return false
|
||||
end
|
||||
end
|
||||
for p, page in pairs(db) do
|
||||
if not store_page(pool[db].."/"..p, page) then
|
||||
return false
|
||||
end
|
||||
end
|
||||
return true
|
||||
end
|
||||
}
|
||||
|
||||
local mt = {
|
||||
__index = function(db, k)
|
||||
if db_funcs[k] then return db_funcs[k] end
|
||||
if isFile(pool[db].."/"..k) then
|
||||
db[k] = load_page(pool[db].."/"..k)
|
||||
end
|
||||
return rawget(db, k)
|
||||
end
|
||||
}
|
||||
|
||||
pool.hack = db_funcs
|
||||
|
||||
return setmetatable(pool, {
|
||||
__mode = "kv",
|
||||
__call = function(pool, path)
|
||||
assert(isDir(path), path.." is not a directory.")
|
||||
if pool[path] then return pool[path] end
|
||||
local db = {}
|
||||
setmetatable(db, mt)
|
||||
pool[path] = db
|
||||
pool[db] = path
|
||||
return db
|
||||
end
|
||||
})
|
||||
90
utils/log.lua
Normal file
90
utils/log.lua
Normal file
@ -0,0 +1,90 @@
|
||||
--
|
||||
-- log.lua
|
||||
--
|
||||
-- Copyright (c) 2016 rxi
|
||||
--
|
||||
-- This library is free software; you can redistribute it and/or modify it
|
||||
-- under the terms of the MIT license. See LICENSE for details.
|
||||
--
|
||||
|
||||
local log = { _version = "0.1.0" }
|
||||
|
||||
log.usecolor = true
|
||||
log.outfile = nil
|
||||
log.level = "trace"
|
||||
|
||||
|
||||
local modes = {
|
||||
{ name = "trace", color = "\27[34m", },
|
||||
{ name = "debug", color = "\27[36m", },
|
||||
{ name = "info", color = "\27[32m", },
|
||||
{ name = "warn", color = "\27[33m", },
|
||||
{ name = "error", color = "\27[31m", },
|
||||
{ name = "fatal", color = "\27[35m", },
|
||||
}
|
||||
|
||||
|
||||
local levels = {}
|
||||
for i, v in ipairs(modes) do
|
||||
levels[v.name] = i
|
||||
end
|
||||
|
||||
|
||||
local round = function(x, increment)
|
||||
increment = increment or 1
|
||||
x = x / increment
|
||||
return (x > 0 and math.floor(x + .5) or math.ceil(x - .5)) * increment
|
||||
end
|
||||
|
||||
|
||||
local _tostring = tostring
|
||||
|
||||
local tostring = function(...)
|
||||
local t = {}
|
||||
for i = 1, select('#', ...) do
|
||||
local x = select(i, ...)
|
||||
if type(x) == "number" then
|
||||
x = round(x, .01)
|
||||
end
|
||||
t[#t + 1] = _tostring(x)
|
||||
end
|
||||
return table.concat(t, " ")
|
||||
end
|
||||
|
||||
|
||||
for i, x in ipairs(modes) do
|
||||
local nameupper = x.name:upper()
|
||||
log[x.name] = function(...)
|
||||
|
||||
-- Return early if we're below the log level
|
||||
if i < levels[log.level] then
|
||||
return
|
||||
end
|
||||
|
||||
local msg = tostring(...)
|
||||
local info = debug.getinfo(2, "Sl")
|
||||
local lineinfo = info.short_src .. ":" .. info.currentline
|
||||
|
||||
-- Output to console
|
||||
print(string.format("%s[%-6s%s]%s %s: %s",
|
||||
log.usecolor and x.color or "",
|
||||
nameupper,
|
||||
os.date("%H:%M:%S"),
|
||||
log.usecolor and "\27[0m" or "",
|
||||
lineinfo,
|
||||
msg))
|
||||
|
||||
-- Output to log file
|
||||
if log.outfile then
|
||||
local fp = io.open(log.outfile, "a")
|
||||
local str = string.format("[%-6s%s] %s: %s\n",
|
||||
nameupper, os.date(), lineinfo, msg)
|
||||
fp:write(str)
|
||||
fp:close()
|
||||
end
|
||||
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
return log
|
||||
10
utils/uuid.lua
Normal file
10
utils/uuid.lua
Normal file
@ -0,0 +1,10 @@
|
||||
local random = math.random
|
||||
local function uuid()
|
||||
local template ='xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'
|
||||
return string.gsub(template, '[xy]', function (c)
|
||||
local v = (c == 'x') and random(0, 0xf) or random(8, 0xb)
|
||||
return string.format('%x', v)
|
||||
end)
|
||||
end
|
||||
|
||||
return uuid()
|
||||
Loading…
x
Reference in New Issue
Block a user