|
| 1 | +import axios from 'axios'; |
| 2 | +import { Mistral } from '@mistralai/mistralai'; |
| 3 | +import { OpenAI } from 'openai'; |
| 4 | +import { GenerateCommitMessageErrorEnum } from '../generateCommitMessageFromGitDiff'; |
| 5 | +import { tokenCount } from '../utils/tokenCount'; |
| 6 | +import { AiEngine, AiEngineConfig } from './Engine'; |
| 7 | +import { |
| 8 | + AssistantMessage as MistralAssistantMessage, |
| 9 | + SystemMessage as MistralSystemMessage, |
| 10 | + ToolMessage as MistralToolMessage, |
| 11 | + UserMessage as MistralUserMessage |
| 12 | +} from '@mistralai/mistralai/models/components'; |
| 13 | + |
| 14 | +export interface MistralAiConfig extends AiEngineConfig {} |
| 15 | +export type MistralCompletionMessageParam = Array< |
| 16 | +| (MistralSystemMessage & { role: "system" }) |
| 17 | +| (MistralUserMessage & { role: "user" }) |
| 18 | +| (MistralAssistantMessage & { role: "assistant" }) |
| 19 | +| (MistralToolMessage & { role: "tool" }) |
| 20 | +> |
| 21 | + |
| 22 | +export class MistralAiEngine implements AiEngine { |
| 23 | + config: MistralAiConfig; |
| 24 | + client: Mistral; |
| 25 | + |
| 26 | + constructor(config: MistralAiConfig) { |
| 27 | + this.config = config; |
| 28 | + |
| 29 | + if (!config.baseURL) { |
| 30 | + this.client = new Mistral({ apiKey: config.apiKey }); |
| 31 | + } else { |
| 32 | + this.client = new Mistral({ apiKey: config.apiKey, serverURL: config.baseURL }); |
| 33 | + } |
| 34 | + } |
| 35 | + |
| 36 | + public generateCommitMessage = async ( |
| 37 | + messages: Array<OpenAI.Chat.Completions.ChatCompletionMessageParam> |
| 38 | + ): Promise<string | null> => { |
| 39 | + const params = { |
| 40 | + model: this.config.model, |
| 41 | + messages: messages as MistralCompletionMessageParam, |
| 42 | + topP: 0.1, |
| 43 | + maxTokens: this.config.maxTokensOutput |
| 44 | + }; |
| 45 | + |
| 46 | + try { |
| 47 | + const REQUEST_TOKENS = messages |
| 48 | + .map((msg) => tokenCount(msg.content as string) + 4) |
| 49 | + .reduce((a, b) => a + b, 0); |
| 50 | + |
| 51 | + if ( |
| 52 | + REQUEST_TOKENS > |
| 53 | + this.config.maxTokensInput - this.config.maxTokensOutput |
| 54 | + ) |
| 55 | + throw new Error(GenerateCommitMessageErrorEnum.tooMuchTokens); |
| 56 | + |
| 57 | + const completion = await this.client.chat.complete(params); |
| 58 | + |
| 59 | + if (!completion.choices) |
| 60 | + throw Error('No completion choice available.') |
| 61 | + |
| 62 | + const message = completion.choices[0].message; |
| 63 | + |
| 64 | + if (!message || !message.content) |
| 65 | + throw Error('No completion choice available.') |
| 66 | + |
| 67 | + return message.content as string; |
| 68 | + } catch (error) { |
| 69 | + const err = error as Error; |
| 70 | + if ( |
| 71 | + axios.isAxiosError<{ error?: { message: string } }>(error) && |
| 72 | + error.response?.status === 401 |
| 73 | + ) { |
| 74 | + const mistralError = error.response.data.error; |
| 75 | + |
| 76 | + if (mistralError) throw new Error(mistralError.message); |
| 77 | + } |
| 78 | + |
| 79 | + throw err; |
| 80 | + } |
| 81 | + }; |
| 82 | +} |
0 commit comments