Compare commits
2 Commits
a53f1f1242
...
2041dd9412
| Author | SHA1 | Date | |
|---|---|---|---|
| 2041dd9412 | |||
| f000f13672 |
@@ -28,10 +28,10 @@ COPY --from=builder /app/wheels /wheels
|
|||||||
RUN pip install --no-cache-dir /wheels/*
|
RUN pip install --no-cache-dir /wheels/*
|
||||||
|
|
||||||
# Copy application code
|
# Copy application code
|
||||||
COPY companion/ ./companion/
|
COPY src/companion/ ./companion/
|
||||||
COPY companion/forge/ ./companion/forge/
|
COPY src/companion/forge/ ./companion/forge/
|
||||||
COPY companion/indexer_daemon/ ./companion/indexer_daemon/
|
COPY src/companion/indexer_daemon/ ./companion/indexer_daemon/
|
||||||
COPY companion/rag/ ./companion/rag/
|
COPY src/companion/rag/ ./companion/rag/
|
||||||
|
|
||||||
# Create directories for data
|
# Create directories for data
|
||||||
RUN mkdir -p /data/vectors /data/memory /models
|
RUN mkdir -p /data/vectors /data/memory /models
|
||||||
|
|||||||
@@ -13,9 +13,9 @@ RUN pip install --no-cache-dir \
|
|||||||
pydantic lancedb pyarrow requests watchdog typer rich numpy httpx
|
pydantic lancedb pyarrow requests watchdog typer rich numpy httpx
|
||||||
|
|
||||||
# Copy application code
|
# Copy application code
|
||||||
COPY companion/ ./companion/
|
COPY src/companion/ ./companion/
|
||||||
COPY companion/indexer_daemon/ ./companion/indexer_daemon/
|
COPY src/companion/indexer_daemon/ ./companion/indexer_daemon/
|
||||||
COPY companion/rag/ ./companion/rag/
|
COPY src/companion/rag/ ./companion/rag/
|
||||||
|
|
||||||
# Create directories for data
|
# Create directories for data
|
||||||
RUN mkdir -p /data/vectors
|
RUN mkdir -p /data/vectors
|
||||||
|
|||||||
@@ -55,7 +55,7 @@
|
|||||||
"rag": {
|
"rag": {
|
||||||
"embedding": {
|
"embedding": {
|
||||||
"provider": "ollama",
|
"provider": "ollama",
|
||||||
"model": "mxbai-embed-large",
|
"model": "mxbai-embed-large:335m",
|
||||||
"base_url": "http://localhost:11434",
|
"base_url": "http://localhost:11434",
|
||||||
"dimensions": 1024,
|
"dimensions": 1024,
|
||||||
"batch_size": 32
|
"batch_size": 32
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
version: '3.8'
|
version: "3.8"
|
||||||
|
|
||||||
services:
|
services:
|
||||||
companion-api:
|
companion-api:
|
||||||
@@ -20,7 +20,13 @@ services:
|
|||||||
- companion-network
|
- companion-network
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "python", "-c", "import requests; requests.get('http://localhost:7373/health')"]
|
test:
|
||||||
|
[
|
||||||
|
"CMD",
|
||||||
|
"python",
|
||||||
|
"-c",
|
||||||
|
"import requests; requests.get('http://localhost:7373/health')",
|
||||||
|
]
|
||||||
interval: 30s
|
interval: 30s
|
||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 3
|
retries: 3
|
||||||
@@ -34,7 +40,7 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
- ./config.json:/app/config.json:ro
|
- ./config.json:/app/config.json:ro
|
||||||
- companion-data:/data
|
- companion-data:/data
|
||||||
- /home/san/KnowledgeVault:/vault:ro # Mount Obsidian vault as read-only
|
- ./sample-data/Default:/app/sample-data/Default:ro # Mount Obsidian vault as read-only
|
||||||
environment:
|
environment:
|
||||||
- COMPANION_CONFIG=/app/config.json
|
- COMPANION_CONFIG=/app/config.json
|
||||||
- COMPANION_DATA_DIR=/data
|
- COMPANION_DATA_DIR=/data
|
||||||
|
|||||||
97
docs/gpu-compatibility.md
Normal file
97
docs/gpu-compatibility.md
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
# GPU Compatibility Guide
|
||||||
|
|
||||||
|
## RTX 50-Series (Blackwell) Compatibility Notice
|
||||||
|
|
||||||
|
### Issue
|
||||||
|
NVIDIA RTX 50-series GPUs (RTX 5070, 5080, 5090) use CUDA capability `sm_120` (Blackwell architecture). PyTorch stable releases (up to 2.5.1) only officially support up to `sm_90` (Hopper/Ada).
|
||||||
|
|
||||||
|
**Warning you'll see:**
|
||||||
|
```
|
||||||
|
NVIDIA GeForce RTX 5070 with CUDA capability sm_120 is not compatible with the current PyTorch installation.
|
||||||
|
The current PyTorch install supports CUDA capabilities sm_50 sm_60 sm_61 sm_70 sm_75 sm_80 sm_86 sm_90.
|
||||||
|
```
|
||||||
|
|
||||||
|
### Current Status
|
||||||
|
- ✅ PyTorch detects the GPU
|
||||||
|
- ✅ CUDA operations generally work
|
||||||
|
- ⚠️ Some operations may fail or fall back to CPU
|
||||||
|
- ⚠️ Performance may not be optimal
|
||||||
|
|
||||||
|
### Workarounds
|
||||||
|
|
||||||
|
#### Option 1: Use PyTorch Nightly (Recommended for RTX 50-series)
|
||||||
|
```bash
|
||||||
|
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Option 2: Use Current Stable with Known Limitations
|
||||||
|
Many workloads work fine despite the warning. Test your specific use case.
|
||||||
|
|
||||||
|
#### Option 3: Wait for PyTorch 2.7
|
||||||
|
Full sm_120 support is expected in the next stable release.
|
||||||
|
|
||||||
|
### Installation Steps for KV-RAG with GPU
|
||||||
|
|
||||||
|
1. **Install CUDA-enabled PyTorch:**
|
||||||
|
```bash
|
||||||
|
pip install torch==2.5.1+cu121 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Install unsloth without dependencies:**
|
||||||
|
```bash
|
||||||
|
pip install unsloth --no-deps
|
||||||
|
pip install unsloth_zoo
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Install remaining training dependencies:**
|
||||||
|
```bash
|
||||||
|
pip install bitsandbytes accelerate peft transformers datasets trl
|
||||||
|
```
|
||||||
|
Note: Skip `xformers` as it may overwrite torch. Unsloth works without it.
|
||||||
|
|
||||||
|
### Verify GPU is Working
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
print(f"CUDA available: {torch.cuda.is_available()}")
|
||||||
|
print(f"GPU: {torch.cuda.get_device_name(0)}")
|
||||||
|
print(f"CUDA version: {torch.version.cuda}")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Ollama GPU Status
|
||||||
|
|
||||||
|
Ollama runs **natively on Windows** and uses GPU automatically when available:
|
||||||
|
- Check with: `nvidia-smi` (look for `ollama.exe` processes)
|
||||||
|
- Embedding model (`mxbai-embed-large:335m`) runs on GPU
|
||||||
|
- Chat models also use GPU when loaded
|
||||||
|
|
||||||
|
### Forge Training GPU Status
|
||||||
|
|
||||||
|
The training script uses `unsloth` + `trl` for QLoRA fine-tuning:
|
||||||
|
- Requires CUDA-enabled PyTorch
|
||||||
|
- Optimized for 12GB VRAM (RTX 5070)
|
||||||
|
- Uses 4-bit quantization + LoRA adapters
|
||||||
|
- See `src/companion/forge/train.py` for implementation
|
||||||
|
|
||||||
|
### Troubleshooting
|
||||||
|
|
||||||
|
**Issue:** `CUDA available: False` after installation
|
||||||
|
**Fix:** PyTorch was overwritten by a package dependency. Reinstall:
|
||||||
|
```bash
|
||||||
|
pip install torch==2.5.1+cu121 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 --force-reinstall
|
||||||
|
```
|
||||||
|
|
||||||
|
**Issue:** `xformers` overwrites torch
|
||||||
|
**Fix:** Skip xformers or install matching wheel:
|
||||||
|
```bash
|
||||||
|
# Skip for now - unsloth works without it
|
||||||
|
# Or install specific version matching your torch
|
||||||
|
pip install xformers==0.0.28.post3 --index-url https://download.pytorch.org/whl/cu121
|
||||||
|
```
|
||||||
|
|
||||||
|
### References
|
||||||
|
|
||||||
|
- [PyTorch CUDA Compatibility](https://pytorch.org/get-started/locally/)
|
||||||
|
- [NVIDIA CUDA Capability Matrix](https://developer.nvidia.com/cuda-gpus)
|
||||||
|
- [Unsloth Documentation](https://github.com/unsloth/unsloth)
|
||||||
|
- [RTX 50-Series Architecture](https://www.nvidia.com/en-us/geforce/graphics-cards/50-series/)
|
||||||
@@ -37,7 +37,7 @@ train = [
|
|||||||
"trl>=0.7.0",
|
"trl>=0.7.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.hatchling]
|
[tool.hatch.build.targets.wheel]
|
||||||
packages = ["src/companion"]
|
packages = ["src/companion"]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from contextlib import asynccontextmanager
|
|||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException, APIRouter
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sse_starlette.sse import EventSourceResponse
|
from sse_starlette.sse import EventSourceResponse
|
||||||
@@ -37,6 +37,12 @@ class ChatResponse(BaseModel):
|
|||||||
sources: list[dict] | None = None
|
sources: list[dict] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ReloadModelRequest(BaseModel):
|
||||||
|
"""Model reload request."""
|
||||||
|
|
||||||
|
model_path: str
|
||||||
|
|
||||||
|
|
||||||
# Global instances
|
# Global instances
|
||||||
config: Config
|
config: Config
|
||||||
vector_store: VectorStore
|
vector_store: VectorStore
|
||||||
@@ -70,8 +76,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
orchestrator = ChatOrchestrator(
|
orchestrator = ChatOrchestrator(
|
||||||
config=config,
|
config=config,
|
||||||
search_engine=search_engine,
|
search_engine=search_engine,
|
||||||
memory=memory,
|
session_memory=memory,
|
||||||
http_client=http_client,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
yield
|
yield
|
||||||
@@ -99,8 +104,11 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Create API router with /api prefix
|
||||||
|
api_router = APIRouter(prefix="/api")
|
||||||
|
|
||||||
@app.get("/health")
|
|
||||||
|
@api_router.get("/health")
|
||||||
async def health_check() -> dict:
|
async def health_check() -> dict:
|
||||||
"""Health check endpoint."""
|
"""Health check endpoint."""
|
||||||
return {
|
return {
|
||||||
@@ -110,7 +118,7 @@ async def health_check() -> dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/chat")
|
@api_router.post("/chat")
|
||||||
async def chat(request: ChatRequest) -> EventSourceResponse:
|
async def chat(request: ChatRequest) -> EventSourceResponse:
|
||||||
"""Chat endpoint with SSE streaming."""
|
"""Chat endpoint with SSE streaming."""
|
||||||
if not request.message.strip():
|
if not request.message.strip():
|
||||||
@@ -168,7 +176,7 @@ async def chat(request: ChatRequest) -> EventSourceResponse:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/sessions/{session_id}/history")
|
@api_router.get("/sessions/{session_id}/history")
|
||||||
async def get_session_history(session_id: str) -> dict:
|
async def get_session_history(session_id: str) -> dict:
|
||||||
"""Get conversation history for a session."""
|
"""Get conversation history for a session."""
|
||||||
history = memory.get_history(session_id)
|
history = memory.get_history(session_id)
|
||||||
@@ -185,13 +193,7 @@ async def get_session_history(session_id: str) -> dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ReloadModelRequest(BaseModel):
|
@api_router.post("/admin/reload-model")
|
||||||
"""Model reload request."""
|
|
||||||
|
|
||||||
model_path: str
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/admin/reload-model")
|
|
||||||
async def reload_model_endpoint(request: ReloadModelRequest) -> dict:
|
async def reload_model_endpoint(request: ReloadModelRequest) -> dict:
|
||||||
"""Reload the model with a new fine-tuned version (admin only)."""
|
"""Reload the model with a new fine-tuned version (admin only)."""
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -214,6 +216,10 @@ async def reload_model_endpoint(request: ReloadModelRequest) -> dict:
|
|||||||
raise HTTPException(status_code=500, detail=f"Failed to reload model: {e}")
|
raise HTTPException(status_code=500, detail=f"Failed to reload model: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# Include the API router
|
||||||
|
app.include_router(api_router)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
|
|||||||
@@ -28,16 +28,27 @@ class ChunkingRule:
|
|||||||
|
|
||||||
|
|
||||||
def sliding_window_chunks(text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
|
def sliding_window_chunks(text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
|
||||||
words = text.split()
|
"""Split text into chunks based on character count (not word count)."""
|
||||||
if not words:
|
if not text:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
step = chunk_size - chunk_overlap
|
step = chunk_size - chunk_overlap
|
||||||
for i in range(0, len(words), step):
|
start = 0
|
||||||
chunk_words = words[i : i + chunk_size]
|
text_len = len(text)
|
||||||
chunks.append(" ".join(chunk_words))
|
|
||||||
if i + chunk_size >= len(words):
|
while start < text_len:
|
||||||
|
end = min(start + chunk_size, text_len)
|
||||||
|
# Try to break at word boundary
|
||||||
|
if end < text_len:
|
||||||
|
# Look for whitespace to break at
|
||||||
|
while end > start and not text[end].isspace():
|
||||||
|
end -= 1
|
||||||
|
if end == start: # No good break found, force cut
|
||||||
|
end = min(start + chunk_size, text_len)
|
||||||
|
chunks.append(text[start:end].strip())
|
||||||
|
start += step
|
||||||
|
if end >= text_len:
|
||||||
break
|
break
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|||||||
Reference in New Issue
Block a user