This commit is contained in:
Татьяна Фарбер 2026-02-04 17:52:10 +04:00
parent 0b9d5e4cfa
commit 54679d5e6d
8 changed files with 444 additions and 2 deletions

View File

@ -1,3 +1,22 @@
# gigchat
# GigaChat API Lua SDK
GigaChat API Lua SDK
Документация API: https://developers.sber.ru/docs/ru/gigachat/api/reference/rest/
## Использование
Необходимо переименовать gigachat.test_conf в gigachat.conf и заполнить в нем креды
```
local api = require('gigachat.api')
-- Получение списка моделей
api.models()
-- Запрос к чату (строка или массив сообщений)
api.chat.completions('тест')
```

52
config/config.lua Normal file
View File

@ -0,0 +1,52 @@
local json = require('cjson.safe')
local _M = {}
_M.data = {}
_M.comments = {}
_M.file = '' -- файл конфигурации
local key
function _M.read()
for line in io.lines(_M.file) do
key = string.match(line, '([%w_]+)::')
if (key) then
_M.data[key] = string.match(line, '::(.*) #')
_M.comments[key] = string.match(line, '#(.*)')
if string.find(_M.data[key],'%{%"') or string.find(_M.data[key],'%[%"') then
_M.data[key] = json.decode(_M.data[key])
end
end
end
end
function reprint(k,v)
if type(v) == 'table' then
for i,j in pairs(v) do
reprint(i,j)
end
else
print(k..': '..v)
end
end
function _M.data:write()
local config_file = io.open(_M.file, 'w')
for k,v in pairs(self) do
if type(v) ~= 'function' then
if type(v) == 'table' then v = json.encode(v) end
config_file:write(k..'::'..v..' #'.._M.comments[key]..'\n')
end
end
config_file:close()
end
function _M.data:print()
for k,v in pairs(self) do
if type(v) ~= 'function' then
if _M.comments[k] then print('\n'.._M.comments[k]:gsub("^%s*(.-)%s*$", "%1")..': \n') end
reprint(k,v)
end
end
end
return _M

5
config/gigachat.lua Normal file
View File

@ -0,0 +1,5 @@
local config = require('config.config')
local _M = config
_M.file = 'gigachat.conf' -- файл конфигурации
_M.read()
return _M.data

8
gigachat.test_conf Normal file
View File

@ -0,0 +1,8 @@
auth_key:: #Ключ авторизации
client_id:: #ClientID
scope::GIGACHAT_API_PERS #Scope
token_url::https://ngw.devices.sberbank.ru:9443/api/v2/oauth #Token endpoint
base_url::https://gigachat.devices.sberbank.ru/api/v1/ #Базовый url для информационных ответов
logs_path::/var/www/gigachat/logs/ #Путь к логам
db_path::/var/www/gigachat #Путь к бд хранения токенов

172
gigachat/api.lua Normal file
View File

@ -0,0 +1,172 @@
local json = require('cjson')
local cURL = require('cURL')
local log = require('utils.log')
local config = require('config.gigachat')
local uuid = require('utils.uuid')
local flatdb = require('utils.flatdb')
local db = flatdb(config.db_path)
if not db.token then
db.token = {}
db:save()
end
local _M = {}
_M.chat = {}
log.outfile = config.logs_path..'gigachat_'..os.date('%Y-%m-%d')..'.log'
log.level = 'trace'
local function poster(data)
local result = {}
for i,k in pairs(data) do table.insert(result, i..'='..k) end
return table.concat(result,'&')
end
local function get_result(str)
local result, err = pcall(json.decode,str)
if result then
result = json.decode(str)
else
log.error(err)
return nil, err
end
return result,err
end
-- Получение актуального токена доступа
function _M.token()
local token = table.remove(db.token, 1)
if token and tonumber(token.expires_at/1000) > os.time() then return token.access_token
else
local str = ''
local headers = {
'Content-type: application/x-www-form-urlencoded',
'Accept: application/json',
'RqUID: '..uuid,
'Authorization: Basic '..config.auth_key
}
local c = cURL.easy{
url = config.token_url,
post = true,
postfields = poster({scope=config.scope}),
httpheader = headers,
writefunction = function(st)
str = str..st
collectgarbage("collect")
return #st
end
}
local ok, err = c:perform()
local code = c:getinfo_response_code()
c:close()
if not ok then return nil, err end
if code ~= 200 then
log.error(str)
return nil,str
end
local res,err = get_result(str)
if res then
db.token = {res}
db:save()
return res.access_token
end
end
end
local function get(endpoint)
local str = ''
local url = config.base_url..endpoint
local headers = {
'Content-type: application/x-www-form-urlencoded',
'Accept: application/json',
'Authorization: Bearer '.._M.token()
}
local c = cURL.easy{
url = url,
httpheader = headers,
writefunction = function(st)
str = str..st
collectgarbage("collect")
return #st
end
}
local ok, err = c:perform()
local code = c:getinfo_response_code()
c:close()
if not ok then return nil, err end
if code ~= 200 then
log.error(str)
return nil,str
end
local res,err = get_result(str)
if res then return res end
return res,err
end
local function post(endpoint,data)
local str = ''
local url = config.base_url..endpoint
local headers = {
'Content-type: application/x-www-form-urlencoded',
'Accept: application/json',
'Authorization: Bearer '.._M.token()
}
local c = cURL.easy{
url = url,
httpheader = headers,
post = true,
postfields = json.encode(data),
writefunction = function(st)
str = str..st
collectgarbage("collect")
return #st
end
}
local ok, err = c:perform()
local code = c:getinfo_response_code()
c:close()
if not ok then return nil, err end
if code ~= 200 then
log.error(str)
return nil,str
end
local res,err = get_result(str)
if res then return res end
return res,err
end
-- Получение списка доступных моделей
function _M.models()
local models,err = get('models')
if models and models.data then return models.data end
return models,err
end
-- Генерация ответа модели
function _M.chat.completions(messages,model,max_tokens,repetition_penalty,stream,update_interval,temperature,top_p,function_call,functions)
if type(messages) == 'string' then messages = {{role='user',content=messages}} end
if temperature and top_p then return nil, 'Одновременно указаны top_p и температура выборки' end
return post('chat/completions',{
model = model or 'GigaChat-2',
stream = stream or false,
max_tokens = max_tokens or 512,
repetition_penalty = repetition_penalty or 1,
update_interval = update_interval or 0,
messages = messages
})
end
-- Создание эмбеддингов
function _M.embeddings(model)
return post('embeddings',{
model = model or 'embeddings'
})
end
return _M

86
utils/flatdb.lua Normal file
View File

@ -0,0 +1,86 @@
local mp = require("MessagePack")
local function isFile(path)
local f = io.open(path, "r")
if f then
f:close()
return true
end
return false
end
local function isDir(path)
path = string.gsub(path.."/", "//", "/")
local ok, err, code = os.rename(path, path)
if ok or code == 13 then
return true
end
return false
end
local function load_page(path)
local ret
local f = io.open(path, "rb")
if f then
ret = mp.unpack(f:read("*a"))
f:close()
end
return ret
end
local function store_page(path, page)
if type(page) == "table" then
local f = io.open(path, "wb")
if f then
f:write(mp.pack(page))
f:close()
return true
end
end
return false
end
local pool = {}
local db_funcs = {
save = function(db, p)
if p then
if type(p) == "string" and type(db[p]) == "table" then
return store_page(pool[db].."/"..p, db[p])
else
return false
end
end
for p, page in pairs(db) do
if not store_page(pool[db].."/"..p, page) then
return false
end
end
return true
end
}
local mt = {
__index = function(db, k)
if db_funcs[k] then return db_funcs[k] end
if isFile(pool[db].."/"..k) then
db[k] = load_page(pool[db].."/"..k)
end
return rawget(db, k)
end
}
pool.hack = db_funcs
return setmetatable(pool, {
__mode = "kv",
__call = function(pool, path)
assert(isDir(path), path.." is not a directory.")
if pool[path] then return pool[path] end
local db = {}
setmetatable(db, mt)
pool[path] = db
pool[db] = path
return db
end
})

90
utils/log.lua Normal file
View File

@ -0,0 +1,90 @@
--
-- log.lua
--
-- Copyright (c) 2016 rxi
--
-- This library is free software; you can redistribute it and/or modify it
-- under the terms of the MIT license. See LICENSE for details.
--
local log = { _version = "0.1.0" }
log.usecolor = true
log.outfile = nil
log.level = "trace"
local modes = {
{ name = "trace", color = "\27[34m", },
{ name = "debug", color = "\27[36m", },
{ name = "info", color = "\27[32m", },
{ name = "warn", color = "\27[33m", },
{ name = "error", color = "\27[31m", },
{ name = "fatal", color = "\27[35m", },
}
local levels = {}
for i, v in ipairs(modes) do
levels[v.name] = i
end
local round = function(x, increment)
increment = increment or 1
x = x / increment
return (x > 0 and math.floor(x + .5) or math.ceil(x - .5)) * increment
end
local _tostring = tostring
local tostring = function(...)
local t = {}
for i = 1, select('#', ...) do
local x = select(i, ...)
if type(x) == "number" then
x = round(x, .01)
end
t[#t + 1] = _tostring(x)
end
return table.concat(t, " ")
end
for i, x in ipairs(modes) do
local nameupper = x.name:upper()
log[x.name] = function(...)
-- Return early if we're below the log level
if i < levels[log.level] then
return
end
local msg = tostring(...)
local info = debug.getinfo(2, "Sl")
local lineinfo = info.short_src .. ":" .. info.currentline
-- Output to console
print(string.format("%s[%-6s%s]%s %s: %s",
log.usecolor and x.color or "",
nameupper,
os.date("%H:%M:%S"),
log.usecolor and "\27[0m" or "",
lineinfo,
msg))
-- Output to log file
if log.outfile then
local fp = io.open(log.outfile, "a")
local str = string.format("[%-6s%s] %s: %s\n",
nameupper, os.date(), lineinfo, msg)
fp:write(str)
fp:close()
end
end
end
return log

10
utils/uuid.lua Normal file
View File

@ -0,0 +1,10 @@
local random = math.random
local function uuid()
local template ='xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'
return string.gsub(template, '[xy]', function (c)
local v = (c == 'x') and random(0, 0xf) or random(8, 0xb)
return string.format('%x', v)
end)
end
return uuid()