- Downgrade LLamaSharp packages to 0.25.0 to match Vulkan backend availability
- Add LLamaSharp.Backend.Vulkan for AMD/Intel/NVIDIA GPU acceleration
- Fix _gpuLayers bug: was reading LlamaCppTimeout instead of a dedicated field
- Add GpuLayerCount to JournalConfig, sourced from JOURNAL_GPU_LAYERS env var
- Document AI/LLM notes in README (version pinning, known vulkaninfo issue)

Co-Authored-By: Oz <oz-agent@warp.dev>
330 lines
12 KiB
C#
330 lines
12 KiB
C#
using System.Text;
|
|
using System.Text.RegularExpressions;
|
|
using Journal.Core.Dtos;
|
|
using Journal.Core.Models;
|
|
using Journal.Core.Services.Ai;
|
|
using LLama;
|
|
using LLama.Common;
|
|
using LLama.Sampling;
|
|
|
|
namespace Journal.AI;
|
|
|
|
/// <summary>
/// <see cref="IAiService"/> implementation that runs a local GGUF model through LLamaSharp using the
/// Phi-3 chat template. The model is taken from the configured path when that file exists; otherwise a
/// per-user copy under LocalApplicationData/ProjectJournal/ai-models is used, downloading the default
/// Phi-3 mini model there on first use.
/// </summary>
/// <param name="config">
/// Supplies the GGUF model path, the requested context size (clamped to 512–4096 tokens) and the
/// number of layers to offload to the GPU.
/// </param>
public sealed partial class LlamaSharpAiService(JournalConfig config) : IAiService, IDisposable
{
    private const string DefaultModelUrl =
        "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf";

    private const string DefaultModelFileName = "Phi-3-mini-4k-instruct-q4.gguf";

    private const string ModelSubDirectory = "ai-models";

    // Token budget and sampling temperature used for conversational replies.
    private const int ChatMaxTokens = 512;
    private const float ChatTemperature = 0.7f;

    // Generation stops when the model starts emitting any of these template markers.
    private static readonly string[] StopSequences = ["<|user|>", "<|system|>", "<|end|>", "<|endoftext|>"];

    // Every Phi-3 template marker; removed from generated text because matched stop text can appear in it.
    private static readonly string[] TemplateMarkers =
        ["<|assistant|>", "<|user|>", "<|system|>", "<|end|>", "<|endoftext|>"];

    private readonly string _configuredModelPath = config.GgufModelPath;
    private readonly uint _contextSize = (uint)Math.Clamp(config.ModelContextTokens, 512, 4096);
    private readonly int _gpuLayers = config.GpuLayerCount;

    // Guards loading and disposal of _weights.
    private readonly Lock _sync = new();
    private string? _resolvedModelPath;
    private LLamaWeights? _weights;
    private bool _disposed;

    /// <summary>
    /// Reports whether a model file is available locally or already loaded. Never loads or downloads.
    /// </summary>
    public Task<AiHealthDto> HealthAsync(CancellationToken cancellationToken = default)
    {
        var resolved = _resolvedModelPath ?? _configuredModelPath;
        var modelExists = File.Exists(resolved) || File.Exists(GetDefaultModelPath());
        var loaded = _weights is not null;

        var message = loaded
            ? "Model loaded."
            : modelExists
                ? "Model found (will load on first use)."
                : "Model not found locally. It will be downloaded on first use.";

        return Task.FromResult(new AiHealthDto(
            Provider: "llamasharp",
            Enabled: true,
            Healthy: modelExists || loaded,
            Message: message));
    }

    /// <summary>
    /// Sends a single prompt with the conversational system prompt and returns the cleaned reply.
    /// </summary>
    /// <exception cref="ArgumentException"><paramref name="prompt"/> is null, empty or whitespace.</exception>
    public async Task<string> ChatAsync(string prompt, CancellationToken cancellationToken = default)
    {
        if (string.IsNullOrWhiteSpace(prompt))
            throw new ArgumentException("Prompt is required.", nameof(prompt));

        var raw = await RunSessionAsync(prompt, BuildChatSystemPrompt(), ChatMaxTokens,
            cancellationToken: cancellationToken);
        return CleanChatResponse(raw);
    }

    /// <summary>
    /// Sends <paramref name="prompt"/> after replaying <paramref name="history"/> as prior turns.
    /// A role equal to "user" (case-insensitive) becomes a user turn; any other role becomes an assistant turn.
    /// </summary>
    /// <exception cref="ArgumentException"><paramref name="prompt"/> is null, empty or whitespace.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="history"/> is null.</exception>
    public async Task<string> ChatWithHistoryAsync(IReadOnlyList<(string Role, string Text)> history,
        string prompt, CancellationToken cancellationToken = default)
    {
        if (string.IsNullOrWhiteSpace(prompt))
            throw new ArgumentException("Prompt is required.", nameof(prompt));
        // Fail before any model download/load rather than deep inside prompt building.
        ArgumentNullException.ThrowIfNull(history);

        var fullPrompt = BuildPhi3Prompt(BuildChatSystemPrompt(), history, prompt);
        var raw = await InferAsync(fullPrompt, ChatMaxTokens, ChatTemperature, cancellationToken);
        return CleanChatResponse(raw);
    }

    /// <summary>
    /// Sends a prompt with a system prompt demanding a single JSON object, at low temperature.
    /// Returns the model output with template markers stripped; the JSON is not validated here.
    /// </summary>
    /// <exception cref="ArgumentException"><paramref name="prompt"/> is null, empty or whitespace.</exception>
    internal async Task<string> ChatJsonAsync(string prompt, CancellationToken cancellationToken = default)
    {
        if (string.IsNullOrWhiteSpace(prompt))
            throw new ArgumentException("Prompt is required.", nameof(prompt));

        return await RunSessionAsync(prompt, BuildJsonSystemPrompt(),
            maxTokens: 2048, temperature: 0.2f, cancellationToken: cancellationToken);
    }

    /// <summary>Summarizes one journal entry, optionally naming it by <paramref name="fileStem"/>.</summary>
    /// <exception cref="ArgumentException"><paramref name="content"/> is null, empty or whitespace.</exception>
    public async Task<string> SummarizeEntryAsync(string content, string? fileStem = null,
        CancellationToken cancellationToken = default)
    {
        if (string.IsNullOrWhiteSpace(content))
            throw new ArgumentException("Entry content is required.", nameof(content));

        var prompt = fileStem is not null
            ? $"Summarize this journal entry ({fileStem}) concisely:\n\n{content}"
            : $"Summarize this journal entry concisely:\n\n{content}";

        return await ChatAsync(prompt, cancellationToken);
    }

    /// <summary>
    /// Summarizes all entries in one prompt, joined by "---" separators. Returns a fixed message for
    /// a null or empty list without touching the model.
    /// </summary>
    public async Task<string> SummarizeAllAsync(IReadOnlyList<string> entries,
        CancellationToken cancellationToken = default)
    {
        if (entries is null || entries.Count == 0)
            return "No entries to summarize.";

        var combined = string.Join("\n\n---\n\n", entries);
        var prompt = $"Summarize the following {entries.Count} journal entries into a concise overview:\n\n{combined}";

        return await ChatAsync(prompt, cancellationToken);
    }

    /// <summary>
    /// Computes an embedding for <paramref name="content"/>. Every vector returned by the embedder is
    /// flattened into one list. Embedding failures yield an empty list (best effort); cancellation
    /// and use after disposal are propagated.
    /// </summary>
    /// <exception cref="ArgumentException"><paramref name="content"/> is null, empty or whitespace.</exception>
    public async Task<IReadOnlyList<double>> EmbedAsync(string content,
        CancellationToken cancellationToken = default)
    {
        if (string.IsNullOrWhiteSpace(content))
            throw new ArgumentException("Content is required.", nameof(content));

        var modelPath = await EnsureModelAsync(cancellationToken);

        try
        {
            var weights = EnsureWeights(modelPath);
            // The embedder owns a native context; dispose it when done instead of leaking it per call.
            using var embedder = new LLamaEmbedder(weights, CreateModelParams(modelPath, embeddings: true));

            var embeddingArrays = await embedder.GetEmbeddings(content, cancellationToken);

            var result = new List<double>();
            foreach (var vector in embeddingArrays)
            {
                foreach (var component in vector)
                    result.Add(component);
            }

            return result;
        }
        catch (Exception ex) when (ex is not OperationCanceledException and not ObjectDisposedException)
        {
            return [];
        }
    }

    // ── Prompt construction ────────────────────────────────────────────────

    // Local calendar date with English month names, matching the English prompt text around it.
    private static string CurrentDateText() =>
        DateTime.Now.ToString("MMMM d, yyyy", System.Globalization.CultureInfo.InvariantCulture);

    private static string BuildChatSystemPrompt() =>
        "You are a supportive conversational coach inside a private journaling app. " +
        $"Today's date is {CurrentDateText()}. " +
        "Reply in plain natural language only. Never output JSON, code blocks, or structured data. " +
        "Be warm, practical, and concise. Do not repeat yourself.";

    private static string BuildJsonSystemPrompt() =>
        "You are a coaching assistant inside a private journaling app. " +
        $"Today's date is {CurrentDateText()}. " +
        "You MUST respond with ONLY a single valid JSON object. " +
        "Do NOT write any text, explanation, or commentary before or after the JSON. " +
        "Output MUST start with { and end with }.";

    // Renders system prompt, prior turns and the new user prompt in the Phi-3 chat template,
    // ending with an open assistant turn for the model to complete.
    private static string BuildPhi3Prompt(string systemPrompt,
        IEnumerable<(string Role, string Text)> history, string prompt)
    {
        var sb = new StringBuilder();
        sb.Append("<|system|>\n").Append(systemPrompt).Append("<|end|>\n");

        foreach (var (role, text) in history)
        {
            var tag = string.Equals(role, "user", StringComparison.OrdinalIgnoreCase) ? "user" : "assistant";
            sb.Append("<|").Append(tag).Append("|>\n").Append(text).Append("<|end|>\n");
        }

        sb.Append("<|user|>\n").Append(prompt).Append("<|end|>\n");
        sb.Append("<|assistant|>\n");
        return sb.ToString();
    }

    // ── Model download (mirrors LocalWhisperS2TService.EnsureModelAsync) ───

    private static string GetDefaultModelPath()
    {
        var modelDirectory = Path.Combine(
            Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData),
            "ProjectJournal",
            ModelSubDirectory);
        return Path.Combine(modelDirectory, DefaultModelFileName);
    }

    /// <summary>
    /// Returns the configured model path if that file exists, else the default cache path, downloading
    /// the default model there first when it is missing. The download is written to a ".download"
    /// file and moved into place only after it completes; a failed download's file is deleted.
    /// </summary>
    private async Task<string> EnsureModelAsync(CancellationToken cancellationToken = default)
    {
        if (File.Exists(_configuredModelPath))
            return _configuredModelPath;

        var defaultPath = GetDefaultModelPath();
        if (File.Exists(defaultPath))
            return defaultPath;

        Directory.CreateDirectory(Path.GetDirectoryName(defaultPath)!);

        var tempPath = defaultPath + ".download";
        using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
        cts.CancelAfter(TimeSpan.FromMinutes(30));

        using var httpClient = new HttpClient { Timeout = TimeSpan.FromMinutes(30) };
        using var response = await httpClient.GetAsync(DefaultModelUrl,
            HttpCompletionOption.ResponseHeadersRead, cts.Token);
        response.EnsureSuccessStatusCode();

        await using var contentStream = await response.Content.ReadAsStreamAsync(cts.Token);

        // Opened outside the try: if we cannot even create the temp file we own nothing to clean up.
        var fileStream = new FileStream(tempPath, FileMode.Create, FileAccess.Write, FileShare.None);
        try
        {
            await using (fileStream)
            {
                await contentStream.CopyToAsync(fileStream, cts.Token);
                await fileStream.FlushAsync(cts.Token);
            }
        }
        catch
        {
            // The stream is closed by now, so the truncated download can be removed.
            TryDeleteFile(tempPath);
            throw;
        }

        File.Move(tempPath, defaultPath, overwrite: true);
        return defaultPath;
    }

    // Best-effort delete; a failure here must not mask the original download error.
    private static void TryDeleteFile(string path)
    {
        try
        {
            File.Delete(path);
        }
        catch (IOException)
        {
        }
        catch (UnauthorizedAccessException)
        {
        }
    }

    // ── Session / weights lifecycle ────────────────────────────────────────

    // Parameters shared by every context this service creates; embeddings are enabled only for EmbedAsync.
    private ModelParams CreateModelParams(string modelPath, bool embeddings = false) => new(modelPath)
    {
        ContextSize = _contextSize,
        GpuLayerCount = _gpuLayers,
        Embeddings = embeddings
    };

    // Single-turn completion: system prompt + user prompt, no history.
    private async Task<string> RunSessionAsync(string prompt, string systemPrompt, int maxTokens,
        float temperature = ChatTemperature, CancellationToken cancellationToken = default)
    {
        var fullPrompt = BuildPhi3Prompt(systemPrompt, [], prompt);
        return await InferAsync(fullPrompt, maxTokens, temperature, cancellationToken);
    }

    // Runs one stateless completion of an already-templated prompt and strips template markers.
    private async Task<string> InferAsync(string fullPrompt, int maxTokens, float temperature,
        CancellationToken cancellationToken)
    {
        var modelPath = await EnsureModelAsync(cancellationToken);
        var weights = EnsureWeights(modelPath);

        // StatelessExecutor creates its own context for each inference, so hand it the parameters
        // directly rather than allocating an extra, otherwise unused context (and KV cache) up front.
        var executor = new StatelessExecutor(weights, CreateModelParams(modelPath));

        var inferenceParams = new InferenceParams
        {
            MaxTokens = maxTokens,
            AntiPrompts = [.. StopSequences],
            SamplingPipeline = new DefaultSamplingPipeline { Temperature = temperature }
        };

        var output = new StringBuilder();
        await foreach (var token in executor.InferAsync(fullPrompt, inferenceParams, cancellationToken))
            output.Append(token);

        return StripSpecialTokens(output.ToString());
    }

    private static string StripSpecialTokens(string raw)
    {
        var text = raw;
        foreach (var marker in TemplateMarkers)
            text = text.Replace(marker, "", StringComparison.Ordinal);
        return text.Trim();
    }

    // Strips template markers, "System:/Assistant:/User:" labels (optionally bolded) and "**" emphasis,
    // and collapses runs of three or more newlines into one blank line.
    private static string CleanChatResponse(string raw)
    {
        var text = StripSpecialTokens(raw);
        text = RoleMarkerPattern().Replace(text, "");
        text = text.Replace("**", "", StringComparison.Ordinal);
        text = ExcessBlankLinesPattern().Replace(text, "\n\n");
        return text.Trim();
    }

    /// <summary>
    /// Returns the loaded weights, loading them from <paramref name="modelPath"/> on first use.
    /// </summary>
    /// <exception cref="ObjectDisposedException">The service has been disposed.</exception>
    private LLamaWeights EnsureWeights(string modelPath)
    {
        lock (_sync)
        {
            // Without this check, calls after Dispose would reach freed native weights.
            ObjectDisposedException.ThrowIf(_disposed, this);

            if (_weights is not null)
                return _weights;

            var weights = LLamaWeights.LoadFromFile(CreateModelParams(modelPath));
            // Record the path only once loading has actually succeeded.
            _resolvedModelPath = modelPath;
            _weights = weights;
            return weights;
        }
    }

    /// <summary>Releases the loaded model weights. Safe to call more than once.</summary>
    public void Dispose()
    {
        lock (_sync)
        {
            if (_disposed) return;
            _disposed = true;
            _weights?.Dispose();
            // Clear the reference so HealthAsync no longer reports a disposed model as loaded.
            _weights = null;
        }
    }

    [GeneratedRegex(@"\*{0,2}(System|Assistant|User):\*{0,2}", RegexOptions.IgnoreCase, "en-US")]
    private static partial Regex RoleMarkerPattern();

    [GeneratedRegex(@"\n{3,}")]
    private static partial Regex ExcessBlankLinesPattern();
}
|